updates
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
"""add_anonymous_gdpr_support
|
||||
|
||||
Revision ID: 6f7f8689fc98
|
||||
Revises: 7a899ef55e3b
|
||||
Create Date: 2025-12-01 04:15:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '6f7f8689fc98'
|
||||
down_revision = '7a899ef55e3b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update gdpr_requests table to support anonymous users
|
||||
op.alter_column('gdpr_requests', 'user_id',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True)
|
||||
op.add_column('gdpr_requests', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
|
||||
op.create_index(op.f('ix_gdpr_requests_is_anonymous'), 'gdpr_requests', ['is_anonymous'], unique=False)
|
||||
|
||||
# Update consents table to support anonymous users
|
||||
op.alter_column('consents', 'user_id',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True)
|
||||
op.add_column('consents', sa.Column('user_email', sa.String(length=255), nullable=True))
|
||||
op.add_column('consents', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
|
||||
op.create_index(op.f('ix_consents_user_email'), 'consents', ['user_email'], unique=False)
|
||||
op.create_index(op.f('ix_consents_is_anonymous'), 'consents', ['is_anonymous'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_consents_is_anonymous'), table_name='consents')
|
||||
op.drop_index(op.f('ix_consents_user_email'), table_name='consents')
|
||||
op.drop_column('consents', 'is_anonymous')
|
||||
op.drop_column('consents', 'user_email')
|
||||
op.alter_column('consents', 'user_id',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False)
|
||||
|
||||
op.drop_index(op.f('ix_gdpr_requests_is_anonymous'), table_name='gdpr_requests')
|
||||
op.drop_column('gdpr_requests', 'is_anonymous')
|
||||
op.alter_column('gdpr_requests', 'user_id',
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False)
|
||||
@@ -0,0 +1,173 @@
|
||||
"""add_comprehensive_gdpr_tables
|
||||
|
||||
Revision ID: 7a899ef55e3b
|
||||
Revises: dbafe747c931
|
||||
Create Date: 2025-12-01 04:10:25.699589
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7a899ef55e3b'
|
||||
down_revision = 'dbafe747c931'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Consent table
|
||||
op.create_table(
|
||||
'consents',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('consent_type', sa.Enum('marketing', 'analytics', 'necessary', 'preferences', 'third_party_sharing', 'profiling', 'automated_decision_making', name='consenttype'), nullable=False),
|
||||
sa.Column('status', sa.Enum('granted', 'withdrawn', 'pending', 'expired', name='consentstatus'), nullable=False),
|
||||
sa.Column('granted_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('withdrawn_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('legal_basis', sa.String(length=100), nullable=True),
|
||||
sa.Column('consent_method', sa.String(length=50), nullable=True),
|
||||
sa.Column('consent_version', sa.String(length=20), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=255), nullable=True),
|
||||
sa.Column('source', sa.String(length=100), nullable=True),
|
||||
sa.Column('extra_metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_consents_id'), 'consents', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_consents_user_id'), 'consents', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_consents_consent_type'), 'consents', ['consent_type'], unique=False)
|
||||
op.create_index(op.f('ix_consents_status'), 'consents', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_consents_created_at'), 'consents', ['created_at'], unique=False)
|
||||
|
||||
# Data processing records table
|
||||
op.create_table(
|
||||
'data_processing_records',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('processing_category', sa.Enum('collection', 'storage', 'usage', 'sharing', 'deletion', 'anonymization', 'transfer', name='processingcategory'), nullable=False),
|
||||
sa.Column('legal_basis', sa.Enum('consent', 'contract', 'legal_obligation', 'vital_interests', 'public_task', 'legitimate_interests', name='legalbasis'), nullable=False),
|
||||
sa.Column('purpose', sa.Text(), nullable=False),
|
||||
sa.Column('data_categories', sa.JSON(), nullable=True),
|
||||
sa.Column('data_subjects', sa.JSON(), nullable=True),
|
||||
sa.Column('recipients', sa.JSON(), nullable=True),
|
||||
sa.Column('third_parties', sa.JSON(), nullable=True),
|
||||
sa.Column('transfers_to_third_countries', sa.Boolean(), nullable=False),
|
||||
sa.Column('transfer_countries', sa.JSON(), nullable=True),
|
||||
sa.Column('safeguards', sa.Text(), nullable=True),
|
||||
sa.Column('retention_period', sa.String(length=100), nullable=True),
|
||||
sa.Column('retention_criteria', sa.Text(), nullable=True),
|
||||
sa.Column('security_measures', sa.Text(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=True),
|
||||
sa.Column('related_booking_id', sa.Integer(), nullable=True),
|
||||
sa.Column('related_payment_id', sa.Integer(), nullable=True),
|
||||
sa.Column('processed_by', sa.Integer(), nullable=True),
|
||||
sa.Column('processing_timestamp', sa.DateTime(), nullable=False),
|
||||
sa.Column('extra_metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['processed_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_data_processing_records_id'), 'data_processing_records', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_data_processing_records_processing_category'), 'data_processing_records', ['processing_category'], unique=False)
|
||||
op.create_index(op.f('ix_data_processing_records_legal_basis'), 'data_processing_records', ['legal_basis'], unique=False)
|
||||
op.create_index(op.f('ix_data_processing_records_user_id'), 'data_processing_records', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_data_processing_records_processing_timestamp'), 'data_processing_records', ['processing_timestamp'], unique=False)
|
||||
op.create_index(op.f('ix_data_processing_records_created_at'), 'data_processing_records', ['created_at'], unique=False)
|
||||
|
||||
# Data breaches table
|
||||
op.create_table(
|
||||
'data_breaches',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('breach_type', sa.Enum('confidentiality', 'integrity', 'availability', name='breachtype'), nullable=False),
|
||||
sa.Column('status', sa.Enum('detected', 'investigating', 'contained', 'reported_to_authority', 'notified_data_subjects', 'resolved', name='breachstatus'), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=False),
|
||||
sa.Column('affected_data_categories', sa.JSON(), nullable=True),
|
||||
sa.Column('affected_data_subjects', sa.JSON(), nullable=True),
|
||||
sa.Column('detected_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('occurred_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('contained_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('reported_to_authority_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('authority_reference', sa.String(length=255), nullable=True),
|
||||
sa.Column('notified_data_subjects_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('notification_method', sa.String(length=100), nullable=True),
|
||||
sa.Column('likely_consequences', sa.Text(), nullable=True),
|
||||
sa.Column('measures_proposed', sa.Text(), nullable=True),
|
||||
sa.Column('risk_level', sa.String(length=20), nullable=True),
|
||||
sa.Column('reported_by', sa.Integer(), nullable=False),
|
||||
sa.Column('investigated_by', sa.Integer(), nullable=True),
|
||||
sa.Column('extra_metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['investigated_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['reported_by'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_data_breaches_id'), 'data_breaches', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_data_breaches_breach_type'), 'data_breaches', ['breach_type'], unique=False)
|
||||
op.create_index(op.f('ix_data_breaches_status'), 'data_breaches', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_data_breaches_detected_at'), 'data_breaches', ['detected_at'], unique=False)
|
||||
|
||||
# Retention rules table
|
||||
op.create_table(
|
||||
'retention_rules',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('data_category', sa.String(length=100), nullable=False),
|
||||
sa.Column('retention_period_days', sa.Integer(), nullable=False),
|
||||
sa.Column('retention_period_months', sa.Integer(), nullable=True),
|
||||
sa.Column('retention_period_years', sa.Integer(), nullable=True),
|
||||
sa.Column('legal_basis', sa.Text(), nullable=True),
|
||||
sa.Column('legal_requirement', sa.Text(), nullable=True),
|
||||
sa.Column('action_after_retention', sa.String(length=50), nullable=False),
|
||||
sa.Column('conditions', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('created_by', sa.Integer(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('data_category')
|
||||
)
|
||||
op.create_index(op.f('ix_retention_rules_id'), 'retention_rules', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_retention_rules_data_category'), 'retention_rules', ['data_category'], unique=True)
|
||||
op.create_index(op.f('ix_retention_rules_is_active'), 'retention_rules', ['is_active'], unique=False)
|
||||
|
||||
# Data retention logs table
|
||||
op.create_table(
|
||||
'data_retention_logs',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('retention_rule_id', sa.Integer(), nullable=False),
|
||||
sa.Column('data_category', sa.String(length=100), nullable=False),
|
||||
sa.Column('action_taken', sa.String(length=50), nullable=False),
|
||||
sa.Column('records_affected', sa.Integer(), nullable=False),
|
||||
sa.Column('affected_ids', sa.JSON(), nullable=True),
|
||||
sa.Column('executed_by', sa.Integer(), nullable=True),
|
||||
sa.Column('executed_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('success', sa.Boolean(), nullable=False),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('extra_metadata', sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['executed_by'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['retention_rule_id'], ['retention_rules.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_data_retention_logs_id'), 'data_retention_logs', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_data_retention_logs_retention_rule_id'), 'data_retention_logs', ['retention_rule_id'], unique=False)
|
||||
op.create_index(op.f('ix_data_retention_logs_data_category'), 'data_retention_logs', ['data_category'], unique=False)
|
||||
op.create_index(op.f('ix_data_retention_logs_executed_at'), 'data_retention_logs', ['executed_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop foreign keys first, then indexes, then tables
|
||||
op.drop_table('data_retention_logs')
|
||||
|
||||
op.drop_table('retention_rules')
|
||||
op.drop_table('data_breaches')
|
||||
op.drop_table('data_processing_records')
|
||||
op.drop_table('consents')
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,26 +1,28 @@
|
||||
fastapi==0.104.1
|
||||
fastapi==0.123.0
|
||||
uvicorn[standard]==0.24.0
|
||||
python-dotenv==1.0.0
|
||||
sqlalchemy==2.0.23
|
||||
pymysql==1.1.0
|
||||
pymysql==1.1.2
|
||||
cryptography>=41.0.7
|
||||
python-jose[cryptography]==3.3.0
|
||||
python-jose[cryptography]==3.5.0
|
||||
bcrypt==4.1.2
|
||||
python-multipart==0.0.6
|
||||
python-multipart==0.0.20
|
||||
aiofiles==23.2.1
|
||||
email-validator==2.1.0
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.1.0
|
||||
slowapi==0.1.9
|
||||
pillow==10.1.0
|
||||
pillow==12.0.0
|
||||
aiosmtplib==3.0.1
|
||||
jinja2==3.1.2
|
||||
jinja2==3.1.6
|
||||
alembic==1.12.1
|
||||
stripe>=13.2.0
|
||||
paypal-checkout-serversdk>=1.0.3
|
||||
pyotp==2.9.0
|
||||
qrcode[pil]==7.4.2
|
||||
httpx==0.25.2
|
||||
httpx==0.28.1
|
||||
httpcore==1.0.9
|
||||
h11==0.16.0
|
||||
cryptography>=41.0.7
|
||||
bleach==6.1.0
|
||||
|
||||
|
||||
@@ -1,13 +1,38 @@
|
||||
import uvicorn
|
||||
import signal
|
||||
import sys
|
||||
from src.shared.config.settings import settings
|
||||
from src.shared.config.logging_config import setup_logging, get_logger
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""Handle Ctrl+C gracefully."""
|
||||
logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Register signal handler for graceful shutdown on Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
logger.info(f'Starting {settings.APP_NAME} on {settings.HOST}:{settings.PORT}')
|
||||
import os
|
||||
from pathlib import Path
|
||||
base_dir = Path(__file__).parent
|
||||
src_dir = str(base_dir / 'src')
|
||||
use_reload = False
|
||||
uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=use_reload, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if use_reload else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=1.0)
|
||||
# Enable hot reload in development mode or if explicitly enabled via environment variable
|
||||
use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
|
||||
if use_reload:
|
||||
logger.info('Hot reload enabled - server will restart on code changes')
|
||||
logger.info('Press Ctrl+C to stop the server')
|
||||
uvicorn.run(
|
||||
'src.main:app',
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=use_reload,
|
||||
log_level=settings.LOG_LEVEL.lower(),
|
||||
reload_dirs=[src_dir] if use_reload else None,
|
||||
reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
|
||||
reload_delay=0.5
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -7,6 +7,7 @@ import uuid
|
||||
import os
|
||||
from ...shared.config.database import get_db
|
||||
from ..services.auth_service import auth_service
|
||||
from ..services.session_service import session_service
|
||||
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
|
||||
from ...security.middleware.auth import get_current_user
|
||||
from ..models.user import User
|
||||
@@ -85,6 +86,26 @@ async def register(
|
||||
path='/'
|
||||
)
|
||||
|
||||
# Create user session for new registration
|
||||
try:
|
||||
# Extract device info from user agent
|
||||
device_info = None
|
||||
if user_agent:
|
||||
device_info = {'user_agent': user_agent}
|
||||
|
||||
session_service.create_session(
|
||||
db=db,
|
||||
user_id=result['user']['id'],
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
device_info=str(device_info) if device_info else None
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail registration if session creation fails
|
||||
from ...shared.config.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(f'Failed to create session during registration: {str(e)}')
|
||||
|
||||
# Log successful registration
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
@@ -171,6 +192,26 @@ async def login(
|
||||
path='/'
|
||||
)
|
||||
|
||||
# Create user session
|
||||
try:
|
||||
# Extract device info from user agent
|
||||
device_info = None
|
||||
if user_agent:
|
||||
device_info = {'user_agent': user_agent}
|
||||
|
||||
session_service.create_session(
|
||||
db=db,
|
||||
user_id=result['user']['id'],
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
device_info=str(device_info) if device_info else None
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail login if session creation fails
|
||||
from ...shared.config.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(f'Failed to create session during login: {str(e)}')
|
||||
|
||||
# Log successful login
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
@@ -394,16 +435,23 @@ async def upload_avatar(request: Request, image: UploadFile=File(...), current_u
|
||||
|
||||
# Validate file completely (MIME type, size, magic bytes, integrity)
|
||||
content = await validate_uploaded_image(image, max_avatar_size)
|
||||
upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'avatars'
|
||||
# Use same path calculation as main.py: go from Backend/src/auth/routes/auth_routes.py
|
||||
# to Backend/uploads/avatars
|
||||
upload_dir = Path(__file__).parent.parent.parent.parent / 'uploads' / 'avatars'
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
if current_user.avatar:
|
||||
old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
|
||||
old_avatar_path = Path(__file__).parent.parent.parent.parent / current_user.avatar.lstrip('/')
|
||||
if old_avatar_path.exists() and old_avatar_path.is_file():
|
||||
try:
|
||||
old_avatar_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
ext = Path(image.filename).suffix or '.png'
|
||||
# Sanitize filename to prevent path traversal attacks
|
||||
from ...shared.utils.sanitization import sanitize_filename
|
||||
original_filename = image.filename or 'avatar.png'
|
||||
sanitized_filename = sanitize_filename(original_filename)
|
||||
ext = Path(sanitized_filename).suffix or '.png'
|
||||
# Generate secure filename with user ID and UUID to prevent collisions
|
||||
filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
|
||||
file_path = upload_dir / filename
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""
|
||||
User session management routes.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Cookie
|
||||
from sqlalchemy.orm import Session
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...shared.config.settings import settings
|
||||
from ...security.middleware.auth import get_current_user
|
||||
from ...auth.models.user import User
|
||||
from ...auth.models.user_session import UserSession
|
||||
from ...auth.services.session_service import session_service
|
||||
from ...shared.utils.response_helpers import success_response
|
||||
from jose import jwt
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix='/sessions', tags=['sessions'])
|
||||
@@ -44,13 +47,15 @@ async def get_my_sessions(
|
||||
@router.delete('/{session_id}')
|
||||
async def revoke_session(
|
||||
session_id: int,
|
||||
request: Request,
|
||||
response: Response,
|
||||
current_user: User = Depends(get_current_user),
|
||||
access_token: str = Cookie(None, alias='accessToken'),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Revoke a specific session."""
|
||||
try:
|
||||
# Verify session belongs to user
|
||||
from ...auth.models.user_session import UserSession
|
||||
session = db.query(UserSession).filter(
|
||||
UserSession.id == session_id,
|
||||
UserSession.user_id == current_user.id
|
||||
@@ -59,10 +64,62 @@ async def revoke_session(
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail='Session not found')
|
||||
|
||||
# Check if this is the current session being revoked
|
||||
# We detect this by checking if:
|
||||
# 1. The session IP matches the request IP (if available)
|
||||
# 2. The session is the most recent active session
|
||||
is_current_session = False
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent', '')
|
||||
|
||||
# Check if session matches current request characteristics
|
||||
if client_ip and session.ip_address == client_ip:
|
||||
# Also check if it's the most recent session
|
||||
recent_session = db.query(UserSession).filter(
|
||||
UserSession.user_id == current_user.id,
|
||||
UserSession.is_active == True
|
||||
).order_by(UserSession.last_activity.desc()).first()
|
||||
|
||||
if recent_session and recent_session.id == session_id:
|
||||
is_current_session = True
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not determine if session is current: {str(e)}')
|
||||
# If we can't determine, check if it's the only active session
|
||||
active_sessions_count = db.query(UserSession).filter(
|
||||
UserSession.user_id == current_user.id,
|
||||
UserSession.is_active == True
|
||||
).count()
|
||||
if active_sessions_count <= 1:
|
||||
is_current_session = True
|
||||
|
||||
success = session_service.revoke_session(db=db, session_token=session.session_token)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail='Session not found')
|
||||
|
||||
# If this was the current session, clear cookies and indicate logout needed
|
||||
if is_current_session:
|
||||
from ...shared.config.settings import settings
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
# Clear access token cookie
|
||||
response.delete_cookie(
|
||||
key='accessToken',
|
||||
path='/',
|
||||
samesite=samesite_value,
|
||||
secure=settings.is_production
|
||||
)
|
||||
# Clear refresh token cookie
|
||||
response.delete_cookie(
|
||||
key='refreshToken',
|
||||
path='/',
|
||||
samesite=samesite_value,
|
||||
secure=settings.is_production
|
||||
)
|
||||
return success_response(
|
||||
message='Session revoked successfully. You have been logged out.',
|
||||
data={'logout_required': True}
|
||||
)
|
||||
|
||||
return success_response(message='Session revoked successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -72,19 +129,41 @@ async def revoke_session(
|
||||
|
||||
@router.post('/revoke-all')
|
||||
async def revoke_all_sessions(
|
||||
request: Request,
|
||||
response: Response,
|
||||
current_user: User = Depends(get_current_user),
|
||||
access_token: str = Cookie(None, alias='accessToken'),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Revoke all sessions for current user."""
|
||||
try:
|
||||
count = session_service.revoke_all_user_sessions(
|
||||
db=db,
|
||||
user_id=current_user.id
|
||||
user_id=current_user.id,
|
||||
exclude_token=None # Don't exclude current session, revoke all
|
||||
)
|
||||
|
||||
# Clear cookies since all sessions (including current) are revoked
|
||||
from ...shared.config.settings import settings
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
# Clear access token cookie
|
||||
response.delete_cookie(
|
||||
key='accessToken',
|
||||
path='/',
|
||||
samesite=samesite_value,
|
||||
secure=settings.is_production
|
||||
)
|
||||
# Clear refresh token cookie
|
||||
response.delete_cookie(
|
||||
key='refreshToken',
|
||||
path='/',
|
||||
samesite=samesite_value,
|
||||
secure=settings.is_production
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={'revoked_count': count},
|
||||
message=f'Revoked {count} session(s)'
|
||||
data={'revoked_count': count, 'logout_required': True},
|
||||
message=f'Revoked {count} session(s). You have been logged out.'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error revoking all sessions: {str(e)}', exc_info=True)
|
||||
|
||||
Binary file not shown.
@@ -29,19 +29,13 @@ class AuthService:
|
||||
if not self.jwt_secret:
|
||||
error_msg = (
|
||||
'CRITICAL: JWT_SECRET is not configured. '
|
||||
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
|
||||
'Please set JWT_SECRET environment variable to a secure random string (minimum 64 characters). '
|
||||
'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
# In development, generate a secure secret but warn
|
||||
import secrets
|
||||
self.jwt_secret = secrets.token_urlsafe(64)
|
||||
logger.warning(
|
||||
f'JWT_SECRET not configured. Auto-generated secret for development. '
|
||||
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
|
||||
)
|
||||
# SECURITY: Always fail if JWT_SECRET is not configured, even in development
|
||||
# This prevents accidental deployment without proper secrets
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Validate JWT secret strength
|
||||
if len(self.jwt_secret) < 32:
|
||||
@@ -65,14 +59,37 @@ class AuthService:
|
||||
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")
|
||||
|
||||
def generate_tokens(self, user_id: int) -> dict:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# SECURITY: Add standard JWT claims for better security
|
||||
now = datetime.utcnow()
|
||||
access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
refresh_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
access_payload = {
|
||||
"userId": user_id,
|
||||
"exp": access_expires, # Expiration time
|
||||
"iat": now, # Issued at
|
||||
"iss": settings.APP_NAME, # Issuer
|
||||
"type": "access" # Token type
|
||||
}
|
||||
|
||||
refresh_payload = {
|
||||
"userId": user_id,
|
||||
"exp": refresh_expires, # Expiration time
|
||||
"iat": now, # Issued at
|
||||
"iss": settings.APP_NAME, # Issuer
|
||||
"type": "refresh" # Token type
|
||||
}
|
||||
|
||||
access_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
access_payload,
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
refresh_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
refresh_payload,
|
||||
self.jwt_refresh_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
@@ -316,8 +333,22 @@ class AuthService:
|
||||
db.commit()
|
||||
raise ValueError("Refresh token expired")
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# SECURITY: Add standard JWT claims when refreshing token
|
||||
now = datetime.utcnow()
|
||||
access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
access_payload = {
|
||||
"userId": decoded["userId"],
|
||||
"exp": access_expires, # Expiration time
|
||||
"iat": now, # Issued at
|
||||
"iss": settings.APP_NAME, # Issuer
|
||||
"type": "access" # Token type
|
||||
}
|
||||
|
||||
access_token = jwt.encode(
|
||||
{"userId": decoded["userId"]},
|
||||
access_payload,
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
Binary file not shown.
@@ -4,7 +4,7 @@ from sqlalchemy import and_, or_, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import random
|
||||
import secrets
|
||||
import os
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.settings import settings
|
||||
@@ -37,7 +37,8 @@ def _generate_invoice_email_html(invoice: dict, is_proforma: bool=False) -> str:
|
||||
def generate_booking_number() -> str:
|
||||
prefix = 'BK'
|
||||
ts = int(datetime.utcnow().timestamp() * 1000)
|
||||
rand = random.randint(1000, 9999)
|
||||
# Use cryptographically secure random number to prevent enumeration attacks
|
||||
rand = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
|
||||
return f'{prefix}-{ts}-{rand}'
|
||||
|
||||
def calculate_booking_payment_balance(booking: Booking) -> dict:
|
||||
|
||||
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
from decimal import Decimal
|
||||
from ..models.group_booking import (
|
||||
@@ -21,11 +21,13 @@ class GroupBookingService:
|
||||
|
||||
@staticmethod
|
||||
def generate_group_booking_number(db: Session) -> str:
|
||||
"""Generate unique group booking number"""
|
||||
"""Generate unique group booking number using cryptographically secure random"""
|
||||
max_attempts = 10
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
for _ in range(max_attempts):
|
||||
timestamp = datetime.utcnow().strftime('%Y%m%d')
|
||||
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
|
||||
# Use secrets.choice() instead of random.choices() for security
|
||||
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(6))
|
||||
booking_number = f"GRP-{timestamp}-{random_suffix}"
|
||||
|
||||
existing = db.query(GroupBooking).filter(
|
||||
@@ -35,8 +37,9 @@ class GroupBookingService:
|
||||
if not existing:
|
||||
return booking_number
|
||||
|
||||
# Fallback
|
||||
return f"GRP-{int(datetime.utcnow().timestamp())}"
|
||||
# Fallback with secure random suffix
|
||||
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
|
||||
return f"GRP-{int(datetime.utcnow().timestamp())}{random_suffix}"
|
||||
|
||||
@staticmethod
|
||||
def calculate_group_discount(
|
||||
@@ -405,17 +408,19 @@ class GroupBookingService:
|
||||
# Use proportional share
|
||||
booking_price = group_booking.total_price / group_booking.total_rooms
|
||||
|
||||
# Generate booking number
|
||||
import random
|
||||
# Generate booking number using cryptographically secure random
|
||||
prefix = 'BK'
|
||||
ts = int(datetime.utcnow().timestamp() * 1000)
|
||||
rand = random.randint(1000, 9999)
|
||||
# Use secrets.randbelow() instead of random.randint() for security
|
||||
rand = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
|
||||
booking_number = f'{prefix}-{ts}-{rand}'
|
||||
|
||||
# Ensure uniqueness
|
||||
existing = db.query(Booking).filter(Booking.booking_number == booking_number).first()
|
||||
if existing:
|
||||
booking_number = f'{prefix}-{ts}-{rand + 1}'
|
||||
# If collision, generate new secure random number
|
||||
rand = secrets.randbelow(9000) + 1000
|
||||
booking_number = f'{prefix}-{ts}-{rand}'
|
||||
|
||||
# Create booking
|
||||
booking = Booking(
|
||||
|
||||
26
Backend/src/compliance/models/__init__.py
Normal file
26
Backend/src/compliance/models/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
GDPR Compliance Models.
|
||||
"""
|
||||
from .gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
|
||||
from .consent import Consent, ConsentType, ConsentStatus
|
||||
from .data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
|
||||
from .data_breach import DataBreach, BreachType, BreachStatus
|
||||
from .data_retention import RetentionRule, DataRetentionLog
|
||||
|
||||
__all__ = [
|
||||
'GDPRRequest',
|
||||
'GDPRRequestType',
|
||||
'GDPRRequestStatus',
|
||||
'Consent',
|
||||
'ConsentType',
|
||||
'ConsentStatus',
|
||||
'DataProcessingRecord',
|
||||
'ProcessingCategory',
|
||||
'LegalBasis',
|
||||
'DataBreach',
|
||||
'BreachType',
|
||||
'BreachStatus',
|
||||
'RetentionRule',
|
||||
'DataRetentionLog',
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
64
Backend/src/compliance/models/consent.py
Normal file
64
Backend/src/compliance/models/consent.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
GDPR Consent Management Model.
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class ConsentType(str, enum.Enum):
|
||||
"""Types of consent that can be given or withdrawn."""
|
||||
marketing = "marketing"
|
||||
analytics = "analytics"
|
||||
necessary = "necessary"
|
||||
preferences = "preferences"
|
||||
third_party_sharing = "third_party_sharing"
|
||||
profiling = "profiling"
|
||||
automated_decision_making = "automated_decision_making"
|
||||
|
||||
class ConsentStatus(str, enum.Enum):
|
||||
"""Status of consent."""
|
||||
granted = "granted"
|
||||
withdrawn = "withdrawn"
|
||||
pending = "pending"
|
||||
expired = "expired"
|
||||
|
||||
class Consent(Base):
|
||||
"""Model for tracking user consent for GDPR compliance."""
|
||||
__tablename__ = 'consents'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # Nullable for anonymous users
|
||||
user_email = Column(String(255), nullable=True, index=True) # Email for anonymous users
|
||||
is_anonymous = Column(Boolean, default=False, nullable=False, index=True) # Flag for anonymous consent
|
||||
consent_type = Column(Enum(ConsentType), nullable=False, index=True)
|
||||
status = Column(Enum(ConsentStatus), default=ConsentStatus.granted, nullable=False, index=True)
|
||||
|
||||
# Consent details
|
||||
granted_at = Column(DateTime, nullable=True)
|
||||
withdrawn_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True) # For time-limited consent
|
||||
|
||||
# Legal basis (Article 6 GDPR)
|
||||
legal_basis = Column(String(100), nullable=True) # consent, contract, legal_obligation, vital_interests, public_task, legitimate_interests
|
||||
|
||||
# Consent method
|
||||
consent_method = Column(String(50), nullable=True) # explicit, implicit, pre_checked
|
||||
consent_version = Column(String(20), nullable=True) # Version of privacy policy when consent was given
|
||||
|
||||
# Metadata
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(String(255), nullable=True)
|
||||
source = Column(String(100), nullable=True) # Where consent was given (registration, cookie_banner, etc.)
|
||||
|
||||
# Additional data
|
||||
extra_metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', foreign_keys=[user_id])
|
||||
|
||||
70
Backend/src/compliance/models/data_breach.py
Normal file
70
Backend/src/compliance/models/data_breach.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
GDPR Data Breach Notification Model (Article 33-34 GDPR).
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class BreachType(str, enum.Enum):
|
||||
"""Types of data breaches."""
|
||||
confidentiality = "confidentiality" # Unauthorized disclosure
|
||||
integrity = "integrity" # Unauthorized alteration
|
||||
availability = "availability" # Unauthorized destruction or loss
|
||||
|
||||
class BreachStatus(str, enum.Enum):
|
||||
"""Status of breach notification."""
|
||||
detected = "detected"
|
||||
investigating = "investigating"
|
||||
contained = "contained"
|
||||
reported_to_authority = "reported_to_authority"
|
||||
notified_data_subjects = "notified_data_subjects"
|
||||
resolved = "resolved"
|
||||
|
||||
class DataBreach(Base):
|
||||
"""Data breach notification record (Articles 33-34 GDPR)."""
|
||||
__tablename__ = 'data_breaches'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
|
||||
# Breach details
|
||||
breach_type = Column(Enum(BreachType), nullable=False, index=True)
|
||||
status = Column(Enum(BreachStatus), default=BreachStatus.detected, nullable=False, index=True)
|
||||
|
||||
# Description
|
||||
description = Column(Text, nullable=False) # Nature of the breach
|
||||
affected_data_categories = Column(JSON, nullable=True) # Categories of personal data affected
|
||||
affected_data_subjects = Column(JSON, nullable=True) # Approximate number of affected individuals
|
||||
|
||||
# Timeline
|
||||
detected_at = Column(DateTime, nullable=False, index=True)
|
||||
occurred_at = Column(DateTime, nullable=True) # When breach occurred (if known)
|
||||
contained_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Notification
|
||||
reported_to_authority_at = Column(DateTime, nullable=True) # Article 33 - 72 hours
|
||||
authority_reference = Column(String(255), nullable=True) # Reference from supervisory authority
|
||||
notified_data_subjects_at = Column(DateTime, nullable=True) # Article 34 - without undue delay
|
||||
notification_method = Column(String(100), nullable=True) # email, public_notice, etc.
|
||||
|
||||
# Risk assessment
|
||||
likely_consequences = Column(Text, nullable=True)
|
||||
measures_proposed = Column(Text, nullable=True) # Measures to address the breach
|
||||
risk_level = Column(String(20), nullable=True) # low, medium, high
|
||||
|
||||
# Reporting
|
||||
reported_by = Column(Integer, ForeignKey('users.id'), nullable=False) # Who detected/reported
|
||||
investigated_by = Column(Integer, ForeignKey('users.id'), nullable=True) # DPO or responsible person
|
||||
|
||||
# Additional details
|
||||
extra_metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
reporter = relationship('User', foreign_keys=[reported_by])
|
||||
investigator = relationship('User', foreign_keys=[investigated_by])
|
||||
|
||||
78
Backend/src/compliance/models/data_processing_record.py
Normal file
78
Backend/src/compliance/models/data_processing_record.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
GDPR Data Processing Records Model (Article 30 GDPR).
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class ProcessingCategory(str, enum.Enum):
|
||||
"""Categories of data processing."""
|
||||
collection = "collection"
|
||||
storage = "storage"
|
||||
usage = "usage"
|
||||
sharing = "sharing"
|
||||
deletion = "deletion"
|
||||
anonymization = "anonymization"
|
||||
transfer = "transfer"
|
||||
|
||||
class LegalBasis(str, enum.Enum):
|
||||
"""Legal basis for processing (Article 6 GDPR)."""
|
||||
consent = "consent"
|
||||
contract = "contract"
|
||||
legal_obligation = "legal_obligation"
|
||||
vital_interests = "vital_interests"
|
||||
public_task = "public_task"
|
||||
legitimate_interests = "legitimate_interests"
|
||||
|
||||
class DataProcessingRecord(Base):
|
||||
"""Record of data processing activities (Article 30 GDPR requirement)."""
|
||||
__tablename__ = 'data_processing_records'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
|
||||
# Processing details
|
||||
processing_category = Column(Enum(ProcessingCategory), nullable=False, index=True)
|
||||
legal_basis = Column(Enum(LegalBasis), nullable=False, index=True)
|
||||
purpose = Column(Text, nullable=False) # Purpose of processing
|
||||
|
||||
# Data categories
|
||||
data_categories = Column(JSON, nullable=True) # List of data categories processed
|
||||
data_subjects = Column(JSON, nullable=True) # Categories of data subjects
|
||||
|
||||
# Recipients
|
||||
recipients = Column(JSON, nullable=True) # Categories of recipients (internal, third_party, etc.)
|
||||
third_parties = Column(JSON, nullable=True) # Specific third parties if any
|
||||
|
||||
# Transfers
|
||||
transfers_to_third_countries = Column(Boolean, default=False, nullable=False)
|
||||
transfer_countries = Column(JSON, nullable=True) # List of countries
|
||||
safeguards = Column(Text, nullable=True) # Safeguards for transfers
|
||||
|
||||
# Retention
|
||||
retention_period = Column(String(100), nullable=True) # How long data is retained
|
||||
retention_criteria = Column(Text, nullable=True) # Criteria for determining retention period
|
||||
|
||||
# Security measures
|
||||
security_measures = Column(Text, nullable=True)
|
||||
|
||||
# Related entities
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # If specific to a user
|
||||
related_booking_id = Column(Integer, nullable=True, index=True)
|
||||
related_payment_id = Column(Integer, nullable=True, index=True)
|
||||
|
||||
# Processing details
|
||||
processed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # Staff/admin who processed
|
||||
processing_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
# Additional metadata
|
||||
extra_metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', foreign_keys=[user_id])
|
||||
processor = relationship('User', foreign_keys=[processed_by])
|
||||
|
||||
75
Backend/src/compliance/models/data_retention.py
Normal file
75
Backend/src/compliance/models/data_retention.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
GDPR Data Retention Policy Model.
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timedelta
|
||||
import enum
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class RetentionRule(Base):
|
||||
"""Data retention rules for different data types."""
|
||||
__tablename__ = 'retention_rules'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
|
||||
# Rule details
|
||||
data_category = Column(String(100), nullable=False, unique=True, index=True) # user_data, booking_data, payment_data, etc.
|
||||
retention_period_days = Column(Integer, nullable=False) # Number of days to retain
|
||||
retention_period_months = Column(Integer, nullable=True) # Alternative: months
|
||||
retention_period_years = Column(Integer, nullable=True) # Alternative: years
|
||||
|
||||
# Legal basis
|
||||
legal_basis = Column(Text, nullable=True) # Why we retain for this period
|
||||
legal_requirement = Column(Text, nullable=True) # Specific legal requirement if any
|
||||
|
||||
# Action after retention
|
||||
action_after_retention = Column(String(50), nullable=False, default='anonymize') # delete, anonymize, archive
|
||||
|
||||
# Conditions
|
||||
conditions = Column(JSON, nullable=True) # Additional conditions (e.g., active bookings)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Metadata
|
||||
description = Column(Text, nullable=True)
|
||||
created_by = Column(Integer, ForeignKey('users.id'), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
creator = relationship('User', foreign_keys=[created_by])
|
||||
|
||||
class DataRetentionLog(Base):
|
||||
"""Log of data retention actions performed."""
|
||||
__tablename__ = 'data_retention_logs'
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
|
||||
# Retention action
|
||||
retention_rule_id = Column(Integer, ForeignKey('retention_rules.id'), nullable=False, index=True)
|
||||
data_category = Column(String(100), nullable=False, index=True)
|
||||
action_taken = Column(String(50), nullable=False) # deleted, anonymized, archived
|
||||
|
||||
# Affected records
|
||||
records_affected = Column(Integer, nullable=False, default=0)
|
||||
affected_ids = Column(JSON, nullable=True) # IDs of affected records (for audit)
|
||||
|
||||
# Execution
|
||||
executed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # System or admin
|
||||
executed_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
# Results
|
||||
success = Column(Boolean, default=True, nullable=False)
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Metadata
|
||||
extra_metadata = Column(JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
retention_rule = relationship('RetentionRule', foreign_keys=[retention_rule_id])
|
||||
executor = relationship('User', foreign_keys=[executed_by])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
GDPR compliance models for data export and deletion requests.
|
||||
"""
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
import enum
|
||||
@@ -27,9 +27,10 @@ class GDPRRequest(Base):
|
||||
request_type = Column(Enum(GDPRRequestType), nullable=False, index=True)
|
||||
status = Column(Enum(GDPRRequestStatus), default=GDPRRequestStatus.pending, nullable=False, index=True)
|
||||
|
||||
# User making the request
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
|
||||
user_email = Column(String(255), nullable=False) # Store email even if user is deleted
|
||||
# User making the request (nullable for anonymous users)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True)
|
||||
user_email = Column(String(255), nullable=False) # Required: email for anonymous or registered users
|
||||
is_anonymous = Column(Boolean, default=False, nullable=False, index=True) # Flag for anonymous requests
|
||||
|
||||
# Request details
|
||||
request_data = Column(JSON, nullable=True) # Additional request parameters
|
||||
|
||||
Binary file not shown.
Binary file not shown.
340
Backend/src/compliance/routes/gdpr_admin_routes.py
Normal file
340
Backend/src/compliance/routes/gdpr_admin_routes.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Admin routes for GDPR compliance management.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...security.middleware.auth import authorize_roles
|
||||
from ...auth.models.user import User
|
||||
from ..services.breach_service import breach_service
|
||||
from ..services.retention_service import retention_service
|
||||
from ..services.data_processing_service import data_processing_service
|
||||
from ..models.data_breach import BreachType, BreachStatus
|
||||
from ...shared.utils.response_helpers import success_response
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix='/gdpr/admin', tags=['gdpr-admin'])
|
||||
|
||||
# Data Breach Management
|
||||
|
||||
class BreachCreateRequest(BaseModel):
|
||||
breach_type: str
|
||||
description: str
|
||||
affected_data_categories: Optional[List[str]] = None
|
||||
affected_data_subjects: Optional[int] = None
|
||||
occurred_at: Optional[str] = None
|
||||
likely_consequences: Optional[str] = None
|
||||
measures_proposed: Optional[str] = None
|
||||
risk_level: Optional[str] = None
|
||||
|
||||
@router.post('/breaches')
|
||||
async def create_breach(
|
||||
breach_data: BreachCreateRequest,
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a data breach record (admin only)."""
|
||||
try:
|
||||
try:
|
||||
breach_type_enum = BreachType(breach_data.breach_type)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f'Invalid breach type: {breach_data.breach_type}')
|
||||
|
||||
occurred_at = None
|
||||
if breach_data.occurred_at:
|
||||
occurred_at = datetime.fromisoformat(breach_data.occurred_at.replace('Z', '+00:00'))
|
||||
|
||||
breach = await breach_service.create_breach(
|
||||
db=db,
|
||||
breach_type=breach_type_enum,
|
||||
description=breach_data.description,
|
||||
reported_by=current_user.id,
|
||||
affected_data_categories=breach_data.affected_data_categories,
|
||||
affected_data_subjects=breach_data.affected_data_subjects,
|
||||
occurred_at=occurred_at,
|
||||
likely_consequences=breach_data.likely_consequences,
|
||||
measures_proposed=breach_data.measures_proposed,
|
||||
risk_level=breach_data.risk_level
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'breach_id': breach.id,
|
||||
'status': breach.status.value,
|
||||
'detected_at': breach.detected_at.isoformat()
|
||||
},
|
||||
message='Data breach record created'
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating breach: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get('/breaches')
|
||||
async def get_breaches(
|
||||
status: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all data breaches (admin only)."""
|
||||
try:
|
||||
status_enum = None
|
||||
if status:
|
||||
try:
|
||||
status_enum = BreachStatus(status)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f'Invalid status: {status}')
|
||||
|
||||
offset = (page - 1) * limit
|
||||
breaches = breach_service.get_breaches(
|
||||
db=db,
|
||||
status=status_enum,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return success_response(data={
|
||||
'breaches': [{
|
||||
'id': breach.id,
|
||||
'breach_type': breach.breach_type.value,
|
||||
'status': breach.status.value,
|
||||
'description': breach.description,
|
||||
'risk_level': breach.risk_level,
|
||||
'detected_at': breach.detected_at.isoformat(),
|
||||
'reported_to_authority_at': breach.reported_to_authority_at.isoformat() if breach.reported_to_authority_at else None,
|
||||
'notified_data_subjects_at': breach.notified_data_subjects_at.isoformat() if breach.notified_data_subjects_at else None,
|
||||
} for breach in breaches]
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting breaches: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post('/breaches/{breach_id}/report-authority')
|
||||
async def report_breach_to_authority(
|
||||
breach_id: int,
|
||||
authority_reference: str = Body(...),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Report breach to supervisory authority (admin only)."""
|
||||
try:
|
||||
breach = await breach_service.report_to_authority(
|
||||
db=db,
|
||||
breach_id=breach_id,
|
||||
authority_reference=authority_reference,
|
||||
reported_by=current_user.id
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'breach_id': breach.id,
|
||||
'authority_reference': breach.authority_reference,
|
||||
'reported_at': breach.reported_to_authority_at.isoformat()
|
||||
},
|
||||
message='Breach reported to supervisory authority'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f'Error reporting breach: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post('/breaches/{breach_id}/notify-subjects')
|
||||
async def notify_data_subjects(
|
||||
breach_id: int,
|
||||
notification_method: str = Body(...),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Notify affected data subjects (admin only)."""
|
||||
try:
|
||||
breach = await breach_service.notify_data_subjects(
|
||||
db=db,
|
||||
breach_id=breach_id,
|
||||
notification_method=notification_method,
|
||||
notified_by=current_user.id
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'breach_id': breach.id,
|
||||
'notification_method': breach.notification_method,
|
||||
'notified_at': breach.notified_data_subjects_at.isoformat()
|
||||
},
|
||||
message='Data subjects notified'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f'Error notifying subjects: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Data Retention Management
|
||||
|
||||
class RetentionRuleCreateRequest(BaseModel):
|
||||
data_category: str
|
||||
retention_period_days: int
|
||||
retention_period_months: Optional[int] = None
|
||||
retention_period_years: Optional[int] = None
|
||||
legal_basis: Optional[str] = None
|
||||
legal_requirement: Optional[str] = None
|
||||
action_after_retention: str = 'anonymize'
|
||||
conditions: Optional[Dict[str, Any]] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
@router.post('/retention-rules')
|
||||
async def create_retention_rule(
|
||||
rule_data: RetentionRuleCreateRequest,
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a data retention rule (admin only)."""
|
||||
try:
|
||||
rule = retention_service.create_retention_rule(
|
||||
db=db,
|
||||
data_category=rule_data.data_category,
|
||||
retention_period_days=rule_data.retention_period_days,
|
||||
retention_period_months=rule_data.retention_period_months,
|
||||
retention_period_years=rule_data.retention_period_years,
|
||||
legal_basis=rule_data.legal_basis,
|
||||
legal_requirement=rule_data.legal_requirement,
|
||||
action_after_retention=rule_data.action_after_retention,
|
||||
conditions=rule_data.conditions,
|
||||
description=rule_data.description,
|
||||
created_by=current_user.id
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'rule_id': rule.id,
|
||||
'data_category': rule.data_category,
|
||||
'retention_period_days': rule.retention_period_days
|
||||
},
|
||||
message='Retention rule created successfully'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating retention rule: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get('/retention-rules')
|
||||
async def get_retention_rules(
|
||||
is_active: Optional[bool] = Query(None),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get retention rules (admin only)."""
|
||||
try:
|
||||
rules = retention_service.get_retention_rules(db=db, is_active=is_active)
|
||||
|
||||
return success_response(data={
|
||||
'rules': [{
|
||||
'id': rule.id,
|
||||
'data_category': rule.data_category,
|
||||
'retention_period_days': rule.retention_period_days,
|
||||
'action_after_retention': rule.action_after_retention,
|
||||
'is_active': rule.is_active,
|
||||
'legal_basis': rule.legal_basis
|
||||
} for rule in rules]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting retention rules: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get('/retention-logs')
|
||||
async def get_retention_logs(
|
||||
data_category: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get retention action logs (admin only)."""
|
||||
try:
|
||||
offset = (page - 1) * limit
|
||||
logs = retention_service.get_retention_logs(
|
||||
db=db,
|
||||
data_category=data_category,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return success_response(data={
|
||||
'logs': [{
|
||||
'id': log.id,
|
||||
'data_category': log.data_category,
|
||||
'action_taken': log.action_taken,
|
||||
'records_affected': log.records_affected,
|
||||
'executed_at': log.executed_at.isoformat(),
|
||||
'success': log.success
|
||||
} for log in logs]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting retention logs: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Data Processing Records (Admin View)
|
||||
|
||||
@router.get('/processing-records')
|
||||
async def get_all_processing_records(
|
||||
user_id: Optional[int] = Query(None),
|
||||
processing_category: Optional[str] = Query(None),
|
||||
legal_basis: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all data processing records (admin only)."""
|
||||
try:
|
||||
from ..models.data_processing_record import ProcessingCategory, LegalBasis
|
||||
|
||||
category_enum = None
|
||||
if processing_category:
|
||||
try:
|
||||
category_enum = ProcessingCategory(processing_category)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f'Invalid processing category: {processing_category}')
|
||||
|
||||
basis_enum = None
|
||||
if legal_basis:
|
||||
try:
|
||||
basis_enum = LegalBasis(legal_basis)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f'Invalid legal basis: {legal_basis}')
|
||||
|
||||
offset = (page - 1) * limit
|
||||
records = data_processing_service.get_processing_records(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
processing_category=category_enum,
|
||||
legal_basis=basis_enum,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return success_response(data={
|
||||
'records': [{
|
||||
'id': record.id,
|
||||
'processing_category': record.processing_category.value,
|
||||
'legal_basis': record.legal_basis.value,
|
||||
'purpose': record.purpose,
|
||||
'processing_timestamp': record.processing_timestamp.isoformat(),
|
||||
'user_id': record.user_id
|
||||
} for record in records]
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -3,46 +3,78 @@ GDPR compliance routes for data export and deletion.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from sqlalchemy.orm import Session, noload
|
||||
from sqlalchemy import or_
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...security.middleware.auth import get_current_user, authorize_roles
|
||||
from ...security.middleware.auth import get_current_user, authorize_roles, get_current_user_optional
|
||||
from ...auth.models.user import User
|
||||
from ..services.gdpr_service import gdpr_service
|
||||
from ..services.consent_service import consent_service
|
||||
from ..services.data_processing_service import data_processing_service
|
||||
from ..models.gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
|
||||
from ..models.consent import ConsentType, ConsentStatus
|
||||
from ...shared.utils.response_helpers import success_response
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix='/gdpr', tags=['gdpr'])
|
||||
|
||||
class AnonymousExportRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@router.post('/export')
|
||||
async def request_data_export(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
anonymous_request: Optional[AnonymousExportRequest] = None,
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Request export of user's personal data (GDPR)."""
|
||||
"""Request export of user's personal data (GDPR) - supports both authenticated and anonymous users."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
gdpr_request = await gdpr_service.create_data_export_request(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
# Check if authenticated or anonymous
|
||||
if current_user:
|
||||
# Authenticated user
|
||||
gdpr_request = await gdpr_service.create_data_export_request(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
is_anonymous=False
|
||||
)
|
||||
elif anonymous_request and anonymous_request.email:
|
||||
# Anonymous user - requires email
|
||||
gdpr_request = await gdpr_service.create_data_export_request(
|
||||
db=db,
|
||||
user_email=anonymous_request.email,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
is_anonymous=True
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail='Either authentication required or email must be provided for anonymous requests'
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'request_id': gdpr_request.id,
|
||||
'verification_token': gdpr_request.verification_token,
|
||||
'status': gdpr_request.status.value,
|
||||
'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None
|
||||
'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None,
|
||||
'is_anonymous': gdpr_request.is_anonymous
|
||||
},
|
||||
message='Data export request created. You will receive an email with download link once ready.'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating data export request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -51,20 +83,26 @@ async def request_data_export(
|
||||
async def get_export_data(
|
||||
request_id: int,
|
||||
verification_token: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get exported user data."""
|
||||
"""Get exported user data - supports both authenticated and anonymous users via verification token."""
|
||||
try:
|
||||
gdpr_request = db.query(GDPRRequest).options(
|
||||
# Build query - verification token is required for both authenticated and anonymous
|
||||
query = db.query(GDPRRequest).options(
|
||||
noload(GDPRRequest.user),
|
||||
noload(GDPRRequest.processor)
|
||||
).filter(
|
||||
GDPRRequest.id == request_id,
|
||||
GDPRRequest.user_id == current_user.id,
|
||||
GDPRRequest.verification_token == verification_token,
|
||||
GDPRRequest.request_type == GDPRRequestType.data_export
|
||||
).first()
|
||||
)
|
||||
|
||||
# For authenticated users, also verify user_id matches
|
||||
if current_user:
|
||||
query = query.filter(GDPRRequest.user_id == current_user.id)
|
||||
|
||||
gdpr_request = query.first()
|
||||
|
||||
if not gdpr_request:
|
||||
raise HTTPException(status_code=404, detail='Export request not found or invalid token')
|
||||
@@ -73,8 +111,10 @@ async def get_export_data(
|
||||
# Process export
|
||||
export_data = await gdpr_service.export_user_data(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
request_id=request_id
|
||||
user_id=gdpr_request.user_id,
|
||||
user_email=gdpr_request.user_email,
|
||||
request_id=request_id,
|
||||
is_anonymous=gdpr_request.is_anonymous
|
||||
)
|
||||
return success_response(data=export_data)
|
||||
elif gdpr_request.status == GDPRRequestStatus.completed and gdpr_request.export_file_path:
|
||||
@@ -97,32 +137,57 @@ async def get_export_data(
|
||||
logger.error(f'Error getting export data: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class AnonymousDeletionRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@router.post('/delete')
|
||||
async def request_data_deletion(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
anonymous_request: Optional[AnonymousDeletionRequest] = None,
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Request deletion of user's personal data (GDPR - Right to be Forgotten)."""
|
||||
"""Request deletion of user's personal data (GDPR - Right to be Forgotten) - supports anonymous users."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
gdpr_request = await gdpr_service.create_data_deletion_request(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
# Check if authenticated or anonymous
|
||||
if current_user:
|
||||
# Authenticated user
|
||||
gdpr_request = await gdpr_service.create_data_deletion_request(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
is_anonymous=False
|
||||
)
|
||||
elif anonymous_request and anonymous_request.email:
|
||||
# Anonymous user - requires email
|
||||
gdpr_request = await gdpr_service.create_data_deletion_request(
|
||||
db=db,
|
||||
user_email=anonymous_request.email,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
is_anonymous=True
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail='Either authentication required or email must be provided for anonymous requests'
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'request_id': gdpr_request.id,
|
||||
'verification_token': gdpr_request.verification_token,
|
||||
'status': gdpr_request.status.value
|
||||
'status': gdpr_request.status.value,
|
||||
'is_anonymous': gdpr_request.is_anonymous
|
||||
},
|
||||
message='Data deletion request created. Please verify via email to proceed.'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating data deletion request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -131,21 +196,27 @@ async def request_data_deletion(
|
||||
async def confirm_data_deletion(
|
||||
request_id: int,
|
||||
verification_token: str = Query(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: Optional[User] = Depends(get_current_user_optional),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Confirm and process data deletion request."""
|
||||
"""Confirm and process data deletion request - supports anonymous users via verification token."""
|
||||
try:
|
||||
gdpr_request = db.query(GDPRRequest).options(
|
||||
# Build query - verification token is required for both authenticated and anonymous
|
||||
query = db.query(GDPRRequest).options(
|
||||
noload(GDPRRequest.user),
|
||||
noload(GDPRRequest.processor)
|
||||
).filter(
|
||||
GDPRRequest.id == request_id,
|
||||
GDPRRequest.user_id == current_user.id,
|
||||
GDPRRequest.verification_token == verification_token,
|
||||
GDPRRequest.request_type == GDPRRequestType.data_deletion,
|
||||
GDPRRequest.status == GDPRRequestStatus.pending
|
||||
).first()
|
||||
)
|
||||
|
||||
# For authenticated users, also verify user_id matches
|
||||
if current_user:
|
||||
query = query.filter(GDPRRequest.user_id == current_user.id)
|
||||
|
||||
gdpr_request = query.first()
|
||||
|
||||
if not gdpr_request:
|
||||
raise HTTPException(status_code=404, detail='Deletion request not found or already processed')
|
||||
@@ -153,14 +224,16 @@ async def confirm_data_deletion(
|
||||
# Process deletion
|
||||
deletion_log = await gdpr_service.delete_user_data(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
user_id=gdpr_request.user_id,
|
||||
user_email=gdpr_request.user_email,
|
||||
request_id=request_id,
|
||||
processed_by=current_user.id
|
||||
processed_by=current_user.id if current_user else None,
|
||||
is_anonymous=gdpr_request.is_anonymous
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=deletion_log,
|
||||
message='Your data has been deleted successfully.'
|
||||
message=deletion_log.get('summary', {}).get('message', 'Your data has been deleted successfully.')
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -173,13 +246,17 @@ async def get_user_gdpr_requests(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get user's GDPR requests."""
|
||||
"""Get user's GDPR requests (both authenticated and anonymous requests by email)."""
|
||||
try:
|
||||
# Get requests by user_id (authenticated) or by email (includes anonymous)
|
||||
requests = db.query(GDPRRequest).options(
|
||||
noload(GDPRRequest.user),
|
||||
noload(GDPRRequest.processor)
|
||||
).filter(
|
||||
GDPRRequest.user_id == current_user.id
|
||||
or_(
|
||||
GDPRRequest.user_id == current_user.id,
|
||||
GDPRRequest.user_email == current_user.email
|
||||
)
|
||||
).order_by(GDPRRequest.created_at.desc()).all()
|
||||
|
||||
return success_response(data={
|
||||
@@ -187,6 +264,7 @@ async def get_user_gdpr_requests(
|
||||
'id': req.id,
|
||||
'request_type': req.request_type.value,
|
||||
'status': req.status.value,
|
||||
'is_anonymous': req.is_anonymous,
|
||||
'created_at': req.created_at.isoformat() if req.created_at else None,
|
||||
'processed_at': req.processed_at.isoformat() if req.processed_at else None,
|
||||
} for req in requests]
|
||||
@@ -270,3 +348,272 @@ async def delete_gdpr_request(
|
||||
logger.error(f'Error deleting GDPR request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# GDPR Rights - Additional Routes
|
||||
|
||||
class DataRectificationRequest(BaseModel):
|
||||
corrections: Dict[str, Any] # e.g., {"full_name": "New Name", "email": "new@email.com"}
|
||||
|
||||
@router.post('/rectify')
|
||||
async def request_data_rectification(
|
||||
request: Request,
|
||||
rectification_data: DataRectificationRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Request data rectification (Article 16 GDPR - Right to rectification)."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
gdpr_request = await gdpr_service.request_data_rectification(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
corrections=rectification_data.corrections,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'request_id': gdpr_request.id,
|
||||
'verification_token': gdpr_request.verification_token,
|
||||
'status': gdpr_request.status.value
|
||||
},
|
||||
message='Data rectification request created. An admin will review and process your request.'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating rectification request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class ProcessingRestrictionRequest(BaseModel):
|
||||
reason: str
|
||||
|
||||
@router.post('/restrict')
|
||||
async def request_processing_restriction(
|
||||
request: Request,
|
||||
restriction_data: ProcessingRestrictionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Request restriction of processing (Article 18 GDPR)."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
gdpr_request = await gdpr_service.request_processing_restriction(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
reason=restriction_data.reason,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'request_id': gdpr_request.id,
|
||||
'verification_token': gdpr_request.verification_token,
|
||||
'status': gdpr_request.status.value
|
||||
},
|
||||
message='Processing restriction request created. Your account has been temporarily restricted.'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating restriction request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class ProcessingObjectionRequest(BaseModel):
|
||||
processing_purpose: str
|
||||
reason: Optional[str] = None
|
||||
|
||||
@router.post('/object')
|
||||
async def request_processing_objection(
|
||||
request: Request,
|
||||
objection_data: ProcessingObjectionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Object to processing (Article 21 GDPR - Right to object)."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
gdpr_request = await gdpr_service.request_processing_objection(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
processing_purpose=objection_data.processing_purpose,
|
||||
reason=objection_data.reason,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'request_id': gdpr_request.id,
|
||||
'verification_token': gdpr_request.verification_token,
|
||||
'status': gdpr_request.status.value
|
||||
},
|
||||
message='Processing objection registered. We will review your objection and stop processing for the specified purpose if valid.'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating objection request: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Consent Management Routes
|
||||
|
||||
class ConsentUpdateRequest(BaseModel):
|
||||
consents: Dict[str, bool] # e.g., {"marketing": true, "analytics": false}
|
||||
|
||||
@router.get('/consents')
|
||||
async def get_user_consents(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get user's consent status for all consent types."""
|
||||
try:
|
||||
consents = consent_service.get_user_consents(db=db, user_id=current_user.id, include_withdrawn=True)
|
||||
|
||||
consent_status = {}
|
||||
for consent_type in ConsentType:
|
||||
consent_status[consent_type.value] = {
|
||||
'has_consent': consent_service.has_consent(db=db, user_id=current_user.id, consent_type=consent_type),
|
||||
'granted_at': None,
|
||||
'withdrawn_at': None,
|
||||
'status': 'none'
|
||||
}
|
||||
|
||||
for consent in consents:
|
||||
consent_status[consent.consent_type.value] = {
|
||||
'has_consent': consent.status == ConsentStatus.granted and (not consent.expires_at or consent.expires_at > datetime.utcnow()),
|
||||
'granted_at': consent.granted_at.isoformat() if consent.granted_at else None,
|
||||
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None,
|
||||
'status': consent.status.value,
|
||||
'expires_at': consent.expires_at.isoformat() if consent.expires_at else None
|
||||
}
|
||||
|
||||
return success_response(data={'consents': consent_status})
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting consents: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post('/consents')
|
||||
async def update_consents(
|
||||
request: Request,
|
||||
consent_data: ConsentUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update user consent preferences."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
# Convert string keys to ConsentType enum
|
||||
consents_dict = {}
|
||||
for key, value in consent_data.consents.items():
|
||||
try:
|
||||
consent_type = ConsentType(key)
|
||||
consents_dict[consent_type] = value
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
results = await consent_service.update_consent_preferences(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
consents=consents_dict,
|
||||
legal_basis='consent',
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
source='gdpr_page'
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={'updated_consents': len(results)},
|
||||
message='Consent preferences updated successfully'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error updating consents: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post('/consents/{consent_type}/withdraw')
|
||||
async def withdraw_consent(
|
||||
request: Request,
|
||||
consent_type: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Withdraw specific consent (Article 7(3) GDPR)."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
|
||||
try:
|
||||
consent_type_enum = ConsentType(consent_type)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f'Invalid consent type: {consent_type}')
|
||||
|
||||
consent = await consent_service.withdraw_consent(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
consent_type=consent_type_enum,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'consent_id': consent.id,
|
||||
'consent_type': consent.consent_type.value,
|
||||
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
|
||||
},
|
||||
message=f'Consent for {consent_type} withdrawn successfully'
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f'Error withdrawing consent: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Data Processing Records (User View)
|
||||
|
||||
@router.get('/processing-records')
|
||||
async def get_user_processing_records(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get data processing records for the user (Article 15 GDPR - Right of access)."""
|
||||
try:
|
||||
summary = data_processing_service.get_user_processing_summary(
|
||||
db=db,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
return success_response(data=summary)
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Admin Routes for Processing Requests
|
||||
|
||||
@router.post('/admin/rectify/{request_id}/process')
|
||||
async def process_rectification(
|
||||
request_id: int,
|
||||
current_user: User = Depends(authorize_roles('admin')),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Process data rectification request (admin only)."""
|
||||
try:
|
||||
result = await gdpr_service.process_data_rectification(
|
||||
db=db,
|
||||
request_id=request_id,
|
||||
processed_by=current_user.id
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=result,
|
||||
message='Data rectification processed successfully'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f'Error processing rectification: {str(e)}', exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
169
Backend/src/compliance/services/breach_service.py
Normal file
169
Backend/src/compliance/services/breach_service.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Data Breach Notification Service (Articles 33-34 GDPR).
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from ..models.data_breach import DataBreach, BreachType, BreachStatus
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...analytics.services.audit_service import audit_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class BreachService:
|
||||
"""Service for managing data breach notifications (Articles 33-34 GDPR)."""
|
||||
|
||||
NOTIFICATION_DEADLINE_HOURS = 72 # Article 33 - 72 hours to notify authority
|
||||
|
||||
@staticmethod
|
||||
async def create_breach(
|
||||
db: Session,
|
||||
breach_type: BreachType,
|
||||
description: str,
|
||||
reported_by: int,
|
||||
affected_data_categories: Optional[List[str]] = None,
|
||||
affected_data_subjects: Optional[int] = None,
|
||||
occurred_at: Optional[datetime] = None,
|
||||
likely_consequences: Optional[str] = None,
|
||||
measures_proposed: Optional[str] = None,
|
||||
risk_level: Optional[str] = None,
|
||||
extra_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> DataBreach:
|
||||
"""Create a data breach record."""
|
||||
breach = DataBreach(
|
||||
breach_type=breach_type,
|
||||
status=BreachStatus.detected,
|
||||
description=description,
|
||||
affected_data_categories=affected_data_categories or [],
|
||||
affected_data_subjects=affected_data_subjects,
|
||||
detected_at=datetime.utcnow(),
|
||||
occurred_at=occurred_at or datetime.utcnow(),
|
||||
likely_consequences=likely_consequences,
|
||||
measures_proposed=measures_proposed,
|
||||
risk_level=risk_level or 'medium',
|
||||
reported_by=reported_by,
|
||||
extra_metadata=extra_metadata
|
||||
)
|
||||
|
||||
db.add(breach)
|
||||
db.commit()
|
||||
db.refresh(breach)
|
||||
|
||||
# Log breach detection
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='data_breach_detected',
|
||||
resource_type='data_breach',
|
||||
user_id=reported_by,
|
||||
resource_id=breach.id,
|
||||
details={
|
||||
'breach_type': breach_type.value,
|
||||
'risk_level': risk_level,
|
||||
'affected_subjects': affected_data_subjects
|
||||
},
|
||||
status='warning'
|
||||
)
|
||||
|
||||
logger.warning(f'Data breach detected: {breach.id} - {breach_type.value}')
|
||||
return breach
|
||||
|
||||
@staticmethod
|
||||
async def report_to_authority(
|
||||
db: Session,
|
||||
breach_id: int,
|
||||
authority_reference: str,
|
||||
reported_by: int
|
||||
) -> DataBreach:
|
||||
"""Report breach to supervisory authority (Article 33)."""
|
||||
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
|
||||
if not breach:
|
||||
raise ValueError('Breach not found')
|
||||
|
||||
breach.status = BreachStatus.reported_to_authority
|
||||
breach.reported_to_authority_at = datetime.utcnow()
|
||||
breach.authority_reference = authority_reference
|
||||
|
||||
db.commit()
|
||||
db.refresh(breach)
|
||||
|
||||
# Check if within deadline
|
||||
time_since_detection = datetime.utcnow() - breach.detected_at
|
||||
if time_since_detection > timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS):
|
||||
logger.warning(f'Breach {breach_id} reported after {BreachService.NOTIFICATION_DEADLINE_HOURS} hour deadline')
|
||||
|
||||
# Log report
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='breach_reported_to_authority',
|
||||
resource_type='data_breach',
|
||||
user_id=reported_by,
|
||||
resource_id=breach_id,
|
||||
details={'authority_reference': authority_reference},
|
||||
status='success'
|
||||
)
|
||||
|
||||
logger.info(f'Breach {breach_id} reported to authority: {authority_reference}')
|
||||
return breach
|
||||
|
||||
@staticmethod
|
||||
async def notify_data_subjects(
|
||||
db: Session,
|
||||
breach_id: int,
|
||||
notification_method: str,
|
||||
notified_by: int
|
||||
) -> DataBreach:
|
||||
"""Notify affected data subjects (Article 34)."""
|
||||
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
|
||||
if not breach:
|
||||
raise ValueError('Breach not found')
|
||||
|
||||
breach.status = BreachStatus.notified_data_subjects
|
||||
breach.notified_data_subjects_at = datetime.utcnow()
|
||||
breach.notification_method = notification_method
|
||||
|
||||
db.commit()
|
||||
db.refresh(breach)
|
||||
|
||||
# Log notification
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='breach_subjects_notified',
|
||||
resource_type='data_breach',
|
||||
user_id=notified_by,
|
||||
resource_id=breach_id,
|
||||
details={'notification_method': notification_method},
|
||||
status='success'
|
||||
)
|
||||
|
||||
logger.info(f'Data subjects notified for breach {breach_id}')
|
||||
return breach
|
||||
|
||||
@staticmethod
|
||||
def get_breaches(
|
||||
db: Session,
|
||||
status: Optional[BreachStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[DataBreach]:
|
||||
"""Get data breaches with optional filters."""
|
||||
query = db.query(DataBreach)
|
||||
|
||||
if status:
|
||||
query = query.filter(DataBreach.status == status)
|
||||
|
||||
return query.order_by(DataBreach.detected_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_breaches_requiring_notification(
|
||||
db: Session
|
||||
) -> List[DataBreach]:
|
||||
"""Get breaches that require notification (not yet reported)."""
|
||||
deadline = datetime.utcnow() - timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS)
|
||||
|
||||
return db.query(DataBreach).filter(
|
||||
DataBreach.status.in_([BreachStatus.detected, BreachStatus.investigating]),
|
||||
DataBreach.detected_at < deadline
|
||||
).all()
|
||||
|
||||
breach_service = BreachService()
|
||||
|
||||
202
Backend/src/compliance/services/consent_service.py
Normal file
202
Backend/src/compliance/services/consent_service.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
GDPR Consent Management Service.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from ..models.consent import Consent, ConsentType, ConsentStatus
|
||||
from ...auth.models.user import User
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...analytics.services.audit_service import audit_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class ConsentService:
|
||||
"""Service for managing user consent (Article 7 GDPR)."""
|
||||
|
||||
@staticmethod
|
||||
async def grant_consent(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
consent_type: ConsentType,
|
||||
legal_basis: str,
|
||||
consent_method: str = 'explicit',
|
||||
consent_version: Optional[str] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
extra_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Consent:
|
||||
"""Grant consent for a specific purpose."""
|
||||
# Withdraw any existing consent of this type
|
||||
existing = db.query(Consent).filter(
|
||||
Consent.user_id == user_id,
|
||||
Consent.consent_type == consent_type,
|
||||
Consent.status == ConsentStatus.granted
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.status = ConsentStatus.withdrawn
|
||||
existing.withdrawn_at = datetime.utcnow()
|
||||
|
||||
# Create new consent
|
||||
consent = Consent(
|
||||
user_id=user_id,
|
||||
consent_type=consent_type,
|
||||
status=ConsentStatus.granted,
|
||||
granted_at=datetime.utcnow(),
|
||||
expires_at=expires_at,
|
||||
legal_basis=legal_basis,
|
||||
consent_method=consent_method,
|
||||
consent_version=consent_version,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
source=source,
|
||||
extra_metadata=extra_metadata
|
||||
)
|
||||
|
||||
db.add(consent)
|
||||
db.commit()
|
||||
db.refresh(consent)
|
||||
|
||||
# Log consent grant
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='consent_granted',
|
||||
resource_type='consent',
|
||||
user_id=user_id,
|
||||
resource_id=consent.id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
details={
|
||||
'consent_type': consent_type.value,
|
||||
'legal_basis': legal_basis,
|
||||
'consent_method': consent_method
|
||||
},
|
||||
status='success'
|
||||
)
|
||||
|
||||
logger.info(f'Consent granted: {consent_type.value} for user {user_id}')
|
||||
return consent
|
||||
|
||||
@staticmethod
|
||||
async def withdraw_consent(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
consent_type: ConsentType,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None
|
||||
) -> Consent:
|
||||
"""Withdraw consent (Article 7(3) GDPR)."""
|
||||
consent = db.query(Consent).filter(
|
||||
Consent.user_id == user_id,
|
||||
Consent.consent_type == consent_type,
|
||||
Consent.status == ConsentStatus.granted
|
||||
).order_by(Consent.granted_at.desc()).first()
|
||||
|
||||
if not consent:
|
||||
raise ValueError(f'No active consent found for {consent_type.value}')
|
||||
|
||||
consent.status = ConsentStatus.withdrawn
|
||||
consent.withdrawn_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(consent)
|
||||
|
||||
# Log consent withdrawal
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='consent_withdrawn',
|
||||
resource_type='consent',
|
||||
user_id=user_id,
|
||||
resource_id=consent.id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
details={'consent_type': consent_type.value},
|
||||
status='success'
|
||||
)
|
||||
|
||||
logger.info(f'Consent withdrawn: {consent_type.value} for user {user_id}')
|
||||
return consent
|
||||
|
||||
@staticmethod
|
||||
def get_user_consents(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
include_withdrawn: bool = False
|
||||
) -> List[Consent]:
|
||||
"""Get all consents for a user."""
|
||||
query = db.query(Consent).filter(Consent.user_id == user_id)
|
||||
|
||||
if not include_withdrawn:
|
||||
query = query.filter(Consent.status == ConsentStatus.granted)
|
||||
|
||||
return query.order_by(Consent.granted_at.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
def has_consent(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
consent_type: ConsentType
|
||||
) -> bool:
|
||||
"""Check if user has active consent for a specific type."""
|
||||
consent = db.query(Consent).filter(
|
||||
Consent.user_id == user_id,
|
||||
Consent.consent_type == consent_type,
|
||||
Consent.status == ConsentStatus.granted
|
||||
).first()
|
||||
|
||||
if not consent:
|
||||
return False
|
||||
|
||||
# Check if expired
|
||||
if consent.expires_at and consent.expires_at < datetime.utcnow():
|
||||
consent.status = ConsentStatus.expired
|
||||
db.commit()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def update_consent_preferences(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
consents: Dict[ConsentType, bool],
|
||||
legal_basis: str = 'consent',
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
) -> List[Consent]:
|
||||
"""Update multiple consent preferences at once."""
|
||||
results = []
|
||||
|
||||
for consent_type, granted in consents.items():
|
||||
if granted:
|
||||
consent = await ConsentService.grant_consent(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
consent_type=consent_type,
|
||||
legal_basis=legal_basis,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
source=source
|
||||
)
|
||||
results.append(consent)
|
||||
else:
|
||||
try:
|
||||
consent = await ConsentService.withdraw_consent(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
consent_type=consent_type,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
results.append(consent)
|
||||
except ValueError:
|
||||
# No active consent to withdraw
|
||||
pass
|
||||
|
||||
return results
|
||||
|
||||
consent_service = ConsentService()
|
||||
|
||||
128
Backend/src/compliance/services/data_processing_service.py
Normal file
128
Backend/src/compliance/services/data_processing_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Data Processing Records Service (Article 30 GDPR).
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from ..models.data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
|
||||
from ...shared.config.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class DataProcessingService:
|
||||
"""Service for maintaining data processing records (Article 30 GDPR)."""
|
||||
|
||||
@staticmethod
|
||||
async def create_processing_record(
|
||||
db: Session,
|
||||
processing_category: ProcessingCategory,
|
||||
legal_basis: LegalBasis,
|
||||
purpose: str,
|
||||
data_categories: Optional[List[str]] = None,
|
||||
data_subjects: Optional[List[str]] = None,
|
||||
recipients: Optional[List[str]] = None,
|
||||
third_parties: Optional[List[str]] = None,
|
||||
transfers_to_third_countries: bool = False,
|
||||
transfer_countries: Optional[List[str]] = None,
|
||||
safeguards: Optional[str] = None,
|
||||
retention_period: Optional[str] = None,
|
||||
retention_criteria: Optional[str] = None,
|
||||
security_measures: Optional[str] = None,
|
||||
user_id: Optional[int] = None,
|
||||
related_booking_id: Optional[int] = None,
|
||||
related_payment_id: Optional[int] = None,
|
||||
processed_by: Optional[int] = None,
|
||||
extra_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> DataProcessingRecord:
|
||||
"""Create a data processing record."""
|
||||
record = DataProcessingRecord(
|
||||
processing_category=processing_category,
|
||||
legal_basis=legal_basis,
|
||||
purpose=purpose,
|
||||
data_categories=data_categories or [],
|
||||
data_subjects=data_subjects or [],
|
||||
recipients=recipients or [],
|
||||
third_parties=third_parties or [],
|
||||
transfers_to_third_countries=transfers_to_third_countries,
|
||||
transfer_countries=transfer_countries or [],
|
||||
safeguards=safeguards,
|
||||
retention_period=retention_period,
|
||||
retention_criteria=retention_criteria,
|
||||
security_measures=security_measures,
|
||||
user_id=user_id,
|
||||
related_booking_id=related_booking_id,
|
||||
related_payment_id=related_payment_id,
|
||||
processed_by=processed_by,
|
||||
processing_timestamp=datetime.utcnow(),
|
||||
extra_metadata=extra_metadata
|
||||
)
|
||||
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
|
||||
logger.info(f'Data processing record created: {record.id}')
|
||||
return record
|
||||
|
||||
@staticmethod
|
||||
def get_processing_records(
|
||||
db: Session,
|
||||
user_id: Optional[int] = None,
|
||||
processing_category: Optional[ProcessingCategory] = None,
|
||||
legal_basis: Optional[LegalBasis] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[DataProcessingRecord]:
|
||||
"""Get data processing records with optional filters."""
|
||||
query = db.query(DataProcessingRecord)
|
||||
|
||||
if user_id:
|
||||
query = query.filter(DataProcessingRecord.user_id == user_id)
|
||||
|
||||
if processing_category:
|
||||
query = query.filter(DataProcessingRecord.processing_category == processing_category)
|
||||
|
||||
if legal_basis:
|
||||
query = query.filter(DataProcessingRecord.legal_basis == legal_basis)
|
||||
|
||||
return query.order_by(DataProcessingRecord.processing_timestamp.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_user_processing_summary(
|
||||
db: Session,
|
||||
user_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Get a summary of all data processing activities for a user."""
|
||||
records = db.query(DataProcessingRecord).filter(
|
||||
DataProcessingRecord.user_id == user_id
|
||||
).all()
|
||||
|
||||
summary = {
|
||||
'total_records': len(records),
|
||||
'by_category': {},
|
||||
'by_legal_basis': {},
|
||||
'third_party_sharing': [],
|
||||
'transfers_to_third_countries': []
|
||||
}
|
||||
|
||||
for record in records:
|
||||
# By category
|
||||
category = record.processing_category.value
|
||||
summary['by_category'][category] = summary['by_category'].get(category, 0) + 1
|
||||
|
||||
# By legal basis
|
||||
basis = record.legal_basis.value
|
||||
summary['by_legal_basis'][basis] = summary['by_legal_basis'].get(basis, 0) + 1
|
||||
|
||||
# Third party sharing
|
||||
if record.third_parties:
|
||||
summary['third_party_sharing'].extend(record.third_parties)
|
||||
|
||||
# Transfers
|
||||
if record.transfers_to_third_countries:
|
||||
summary['transfers_to_third_countries'].extend(record.transfer_countries or [])
|
||||
|
||||
return summary
|
||||
|
||||
data_processing_service = DataProcessingService()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
141
Backend/src/compliance/services/retention_service.py
Normal file
141
Backend/src/compliance/services/retention_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Data Retention Service for GDPR compliance.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from ..models.data_retention import RetentionRule, DataRetentionLog
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...analytics.services.audit_service import audit_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class RetentionService:
|
||||
"""Service for managing data retention policies and cleanup."""
|
||||
|
||||
@staticmethod
|
||||
def create_retention_rule(
|
||||
db: Session,
|
||||
data_category: str,
|
||||
retention_period_days: int,
|
||||
retention_period_months: Optional[int] = None,
|
||||
retention_period_years: Optional[int] = None,
|
||||
legal_basis: Optional[str] = None,
|
||||
legal_requirement: Optional[str] = None,
|
||||
action_after_retention: str = 'anonymize',
|
||||
conditions: Optional[Dict[str, Any]] = None,
|
||||
description: Optional[str] = None,
|
||||
created_by: Optional[int] = None
|
||||
) -> RetentionRule:
|
||||
"""Create a data retention rule."""
|
||||
rule = RetentionRule(
|
||||
data_category=data_category,
|
||||
retention_period_days=retention_period_days,
|
||||
retention_period_months=retention_period_months,
|
||||
retention_period_years=retention_period_years,
|
||||
legal_basis=legal_basis,
|
||||
legal_requirement=legal_requirement,
|
||||
action_after_retention=action_after_retention,
|
||||
conditions=conditions,
|
||||
description=description,
|
||||
created_by=created_by,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
db.add(rule)
|
||||
db.commit()
|
||||
db.refresh(rule)
|
||||
|
||||
logger.info(f'Retention rule created: {data_category} - {retention_period_days} days')
|
||||
return rule
|
||||
|
||||
@staticmethod
|
||||
def get_retention_rules(
|
||||
db: Session,
|
||||
is_active: Optional[bool] = None
|
||||
) -> List[RetentionRule]:
|
||||
"""Get retention rules."""
|
||||
query = db.query(RetentionRule)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(RetentionRule.is_active == is_active)
|
||||
|
||||
return query.order_by(RetentionRule.data_category).all()
|
||||
|
||||
@staticmethod
|
||||
def get_retention_rule(
|
||||
db: Session,
|
||||
data_category: str
|
||||
) -> Optional[RetentionRule]:
|
||||
"""Get retention rule for a specific data category."""
|
||||
return db.query(RetentionRule).filter(
|
||||
RetentionRule.data_category == data_category,
|
||||
RetentionRule.is_active == True
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
async def log_retention_action(
|
||||
db: Session,
|
||||
retention_rule_id: int,
|
||||
data_category: str,
|
||||
action_taken: str,
|
||||
records_affected: int,
|
||||
affected_ids: Optional[List[int]] = None,
|
||||
executed_by: Optional[int] = None,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
extra_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> DataRetentionLog:
|
||||
"""Log a data retention action."""
|
||||
log = DataRetentionLog(
|
||||
retention_rule_id=retention_rule_id,
|
||||
data_category=data_category,
|
||||
action_taken=action_taken,
|
||||
records_affected=records_affected,
|
||||
affected_ids=affected_ids or [],
|
||||
executed_by=executed_by,
|
||||
executed_at=datetime.utcnow(),
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
extra_metadata=extra_metadata
|
||||
)
|
||||
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
|
||||
# Log to audit trail
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='data_retention_action',
|
||||
resource_type='retention_log',
|
||||
user_id=executed_by,
|
||||
resource_id=log.id,
|
||||
details={
|
||||
'data_category': data_category,
|
||||
'action_taken': action_taken,
|
||||
'records_affected': records_affected
|
||||
},
|
||||
status='success' if success else 'error'
|
||||
)
|
||||
|
||||
logger.info(f'Retention action logged: {action_taken} on {data_category} - {records_affected} records')
|
||||
return log
|
||||
|
||||
@staticmethod
|
||||
def get_retention_logs(
|
||||
db: Session,
|
||||
data_category: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[DataRetentionLog]:
|
||||
"""Get retention action logs."""
|
||||
query = db.query(DataRetentionLog)
|
||||
|
||||
if data_category:
|
||||
query = query.filter(DataRetentionLog.data_category == data_category)
|
||||
|
||||
return query.order_by(DataRetentionLog.executed_at.desc()).offset(offset).limit(limit).all()
|
||||
|
||||
retention_service = RetentionService()
|
||||
|
||||
Binary file not shown.
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import random
|
||||
import secrets
|
||||
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.logging_config import get_logger
|
||||
@@ -33,7 +33,8 @@ router = APIRouter(prefix="/service-bookings", tags=["service-bookings"])
|
||||
def generate_service_booking_number() -> str:
|
||||
prefix = "SB"
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d")
|
||||
random_suffix = random.randint(1000, 9999)
|
||||
# Use cryptographically secure random number to prevent enumeration attacks
|
||||
random_suffix = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
|
||||
return f"{prefix}{timestamp}{random_suffix}"
|
||||
|
||||
@router.post("/")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Optional
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
from ..models.user_loyalty import UserLoyalty
|
||||
from ..models.loyalty_tier import LoyaltyTier, TierLevel
|
||||
@@ -78,19 +78,23 @@ class LoyaltyService:
|
||||
|
||||
@staticmethod
|
||||
def generate_referral_code(db: Session, user_id: int, length: int = 8) -> str:
|
||||
"""Generate unique referral code for user"""
|
||||
"""Generate unique referral code for user using cryptographically secure random"""
|
||||
max_attempts = 10
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
for _ in range(max_attempts):
|
||||
# Generate code: USER1234 format
|
||||
code = f"USER{user_id:04d}{''.join(random.choices(string.ascii_uppercase + string.digits, k=length-8))}"
|
||||
# Generate code: USER1234 format using cryptographically secure random
|
||||
# Use secrets.choice() instead of random.choices() for security
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length-8))
|
||||
code = f"USER{user_id:04d}{random_part}"
|
||||
|
||||
# Check if code exists
|
||||
existing = db.query(UserLoyalty).filter(UserLoyalty.referral_code == code).first()
|
||||
if not existing:
|
||||
return code
|
||||
|
||||
# Fallback: timestamp-based
|
||||
return f"REF{int(datetime.utcnow().timestamp())}{user_id}"
|
||||
# Fallback: timestamp-based with secure random suffix
|
||||
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
|
||||
return f"REF{int(datetime.utcnow().timestamp())}{user_id}{random_suffix}"
|
||||
|
||||
@staticmethod
|
||||
def create_default_tiers(db: Session):
|
||||
@@ -340,14 +344,18 @@ class LoyaltyService:
|
||||
|
||||
@staticmethod
|
||||
def generate_redemption_code(db: Session, length: int = 12) -> str:
|
||||
"""Generate unique redemption code"""
|
||||
"""Generate unique redemption code using cryptographically secure random"""
|
||||
max_attempts = 10
|
||||
alphabet = string.ascii_uppercase + string.digits
|
||||
for _ in range(max_attempts):
|
||||
code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))
|
||||
# Use secrets.choice() instead of random.choices() for security
|
||||
code = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
existing = db.query(RewardRedemption).filter(RewardRedemption.code == code).first()
|
||||
if not existing:
|
||||
return code
|
||||
return f"RED{int(datetime.utcnow().timestamp())}"
|
||||
# Fallback with secure random suffix
|
||||
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
|
||||
return f"RED{int(datetime.utcnow().timestamp())}{random_suffix}"
|
||||
|
||||
@staticmethod
|
||||
def process_referral(
|
||||
|
||||
@@ -95,10 +95,16 @@ else:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f'Allowed CORS origins: {", ".join(settings.CORS_ORIGINS)}')
|
||||
|
||||
app.add_middleware(CORSMiddleware, allow_origins=settings.CORS_ORIGINS or [], allow_credentials=True, allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], allow_headers=['*'])
|
||||
# SECURITY: Use explicit headers instead of wildcard to prevent header injection
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS or [],
|
||||
allow_credentials=True,
|
||||
allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-XSRF-TOKEN', 'X-Requested-With', 'X-Request-ID', 'Accept', 'Accept-Language']
|
||||
)
|
||||
uploads_dir = Path(__file__).parent.parent / settings.UPLOAD_DIR
|
||||
uploads_dir.mkdir(exist_ok=True)
|
||||
app.mount('/uploads', StaticFiles(directory=str(uploads_dir)), name='uploads')
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(IntegrityError, integrity_error_handler)
|
||||
@@ -108,18 +114,18 @@ app.add_exception_handler(Exception, general_exception_handler)
|
||||
@app.get('/health', tags=['health'])
|
||||
@app.get('/api/health', tags=['health'])
|
||||
async def health_check(db: Session=Depends(get_db)):
|
||||
"""Comprehensive health check endpoint"""
|
||||
"""
|
||||
Public health check endpoint.
|
||||
Returns minimal information for security - no sensitive details exposed.
|
||||
"""
|
||||
health_status = {
|
||||
'status': 'healthy',
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'service': settings.APP_NAME,
|
||||
'version': settings.APP_VERSION,
|
||||
'environment': settings.ENVIRONMENT,
|
||||
# SECURITY: Don't expose service name, version, or environment in public endpoint
|
||||
'checks': {
|
||||
'api': 'ok',
|
||||
'database': 'unknown',
|
||||
'disk_space': 'unknown',
|
||||
'memory': 'unknown'
|
||||
'database': 'unknown'
|
||||
# SECURITY: Don't expose disk_space or memory details publicly
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,60 +137,26 @@ async def health_check(db: Session=Depends(get_db)):
|
||||
except OperationalError as e:
|
||||
health_status['status'] = 'unhealthy'
|
||||
health_status['checks']['database'] = 'error'
|
||||
health_status['error'] = str(e)
|
||||
# SECURITY: Don't expose database error details publicly
|
||||
logger.error(f'Database health check failed: {str(e)}')
|
||||
# Remove error details from response
|
||||
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
|
||||
except Exception as e:
|
||||
health_status['status'] = 'unhealthy'
|
||||
health_status['checks']['database'] = 'error'
|
||||
health_status['error'] = str(e)
|
||||
# SECURITY: Don't expose error details publicly
|
||||
logger.error(f'Health check failed: {str(e)}')
|
||||
# Remove error details from response
|
||||
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
|
||||
|
||||
# Check disk space (if available)
|
||||
try:
|
||||
import shutil
|
||||
disk = shutil.disk_usage('/')
|
||||
free_percent = (disk.free / disk.total) * 100
|
||||
if free_percent < 10:
|
||||
health_status['checks']['disk_space'] = 'warning'
|
||||
health_status['status'] = 'degraded'
|
||||
else:
|
||||
health_status['checks']['disk_space'] = 'ok'
|
||||
health_status['disk_space'] = {
|
||||
'free_gb': round(disk.free / (1024**3), 2),
|
||||
'total_gb': round(disk.total / (1024**3), 2),
|
||||
'free_percent': round(free_percent, 2)
|
||||
}
|
||||
except Exception:
|
||||
health_status['checks']['disk_space'] = 'unknown'
|
||||
|
||||
# Check memory (if available)
|
||||
try:
|
||||
import psutil
|
||||
memory = psutil.virtual_memory()
|
||||
if memory.percent > 90:
|
||||
health_status['checks']['memory'] = 'warning'
|
||||
if health_status['status'] == 'healthy':
|
||||
health_status['status'] = 'degraded'
|
||||
else:
|
||||
health_status['checks']['memory'] = 'ok'
|
||||
health_status['memory'] = {
|
||||
'used_percent': round(memory.percent, 2),
|
||||
'available_gb': round(memory.available / (1024**3), 2),
|
||||
'total_gb': round(memory.total / (1024**3), 2)
|
||||
}
|
||||
except ImportError:
|
||||
# psutil not available, skip memory check
|
||||
health_status['checks']['memory'] = 'unavailable'
|
||||
except Exception:
|
||||
health_status['checks']['memory'] = 'unknown'
|
||||
# SECURITY: Disk space and memory checks removed from public endpoint
|
||||
# These details should only be available on internal/admin health endpoint
|
||||
|
||||
# Determine overall status
|
||||
if health_status['status'] == 'healthy' and any(
|
||||
check == 'warning' for check in health_status['checks'].values()
|
||||
check == 'error' for check in health_status['checks'].values()
|
||||
):
|
||||
health_status['status'] = 'degraded'
|
||||
health_status['status'] = 'unhealthy'
|
||||
|
||||
status_code = status.HTTP_200_OK
|
||||
if health_status['status'] == 'unhealthy':
|
||||
@@ -195,8 +167,110 @@ async def health_check(db: Session=Depends(get_db)):
|
||||
return JSONResponse(status_code=status_code, content=health_status)
|
||||
|
||||
@app.get('/metrics', tags=['monitoring'])
|
||||
async def metrics():
|
||||
return {'status': 'success', 'service': settings.APP_NAME, 'version': settings.APP_VERSION, 'environment': settings.ENVIRONMENT, 'timestamp': datetime.utcnow().isoformat()}
|
||||
async def metrics(
|
||||
current_user = Depends(lambda: None)
|
||||
):
|
||||
"""
|
||||
Protected metrics endpoint - requires admin or staff authentication.
|
||||
SECURITY: Prevents information disclosure to unauthorized users.
|
||||
"""
|
||||
from ..security.middleware.auth import authorize_roles
|
||||
|
||||
# Only allow admin and staff to access metrics
|
||||
# Use authorize_roles as dependency - it will check authorization automatically
|
||||
admin_or_staff = authorize_roles('admin', 'staff')
|
||||
# FastAPI will inject dependencies when this dependency is resolved
|
||||
current_user = admin_or_staff()
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'service': settings.APP_NAME,
|
||||
'version': settings.APP_VERSION,
|
||||
'environment': settings.ENVIRONMENT,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Custom route for serving uploads with CORS headers
|
||||
# This route takes precedence over the mount below
|
||||
from fastapi.responses import FileResponse
|
||||
import re
|
||||
|
||||
@app.options('/uploads/{file_path:path}')
|
||||
async def serve_upload_file_options(file_path: str, request: Request):
|
||||
"""Handle CORS preflight for upload files."""
|
||||
origin = request.headers.get('origin')
|
||||
if origin:
|
||||
if settings.is_development:
|
||||
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
|
||||
return JSONResponse(
|
||||
content={},
|
||||
headers={
|
||||
'Access-Control-Allow-Origin': origin,
|
||||
'Access-Control-Allow-Credentials': 'true',
|
||||
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
|
||||
'Access-Control-Allow-Headers': '*',
|
||||
'Access-Control-Max-Age': '3600'
|
||||
}
|
||||
)
|
||||
elif origin in (settings.CORS_ORIGINS or []):
|
||||
return JSONResponse(
|
||||
content={},
|
||||
headers={
|
||||
'Access-Control-Allow-Origin': origin,
|
||||
'Access-Control-Allow-Credentials': 'true',
|
||||
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
|
||||
'Access-Control-Allow-Headers': '*',
|
||||
'Access-Control-Max-Age': '3600'
|
||||
}
|
||||
)
|
||||
return JSONResponse(content={})
|
||||
|
||||
@app.get('/uploads/{file_path:path}')
|
||||
@app.head('/uploads/{file_path:path}')
|
||||
async def serve_upload_file(file_path: str, request: Request):
|
||||
"""Serve uploaded files with proper CORS headers."""
|
||||
file_location = uploads_dir / file_path
|
||||
|
||||
# Security: Prevent directory traversal
|
||||
try:
|
||||
resolved_path = file_location.resolve()
|
||||
resolved_uploads = uploads_dir.resolve()
|
||||
if not str(resolved_path).startswith(str(resolved_uploads)):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
except (ValueError, OSError):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
if not file_location.exists() or not file_location.is_file():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
# Get origin from request
|
||||
origin = request.headers.get('origin')
|
||||
|
||||
# Prepare response
|
||||
response = FileResponse(str(file_location))
|
||||
|
||||
# Add CORS headers if origin matches
|
||||
if origin:
|
||||
if settings.is_development:
|
||||
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
|
||||
response.headers['Access-Control-Allow-Origin'] = origin
|
||||
response.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
|
||||
response.headers['Access-Control-Allow-Headers'] = '*'
|
||||
response.headers['Access-Control-Expose-Headers'] = '*'
|
||||
elif origin in (settings.CORS_ORIGINS or []):
|
||||
response.headers['Access-Control-Allow-Origin'] = origin
|
||||
response.headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
|
||||
response.headers['Access-Control-Allow-Headers'] = '*'
|
||||
response.headers['Access-Control-Expose-Headers'] = '*'
|
||||
|
||||
return response
|
||||
|
||||
# Mount static files as fallback (routes take precedence)
|
||||
from starlette.staticfiles import StaticFiles
|
||||
app.mount('/uploads-static', StaticFiles(directory=str(uploads_dir)), name='uploads-static')
|
||||
|
||||
# Import all route modules from feature-based structure
|
||||
from .auth.routes import auth_routes, user_routes
|
||||
from .rooms.routes import room_routes, advanced_room_routes, rate_plan_routes
|
||||
@@ -219,6 +293,7 @@ from .security.routes import security_routes, compliance_routes
|
||||
from .system.routes import system_settings_routes, workflow_routes, task_routes, approval_routes, backup_routes
|
||||
from .ai.routes import ai_assistant_routes
|
||||
from .compliance.routes import gdpr_routes
|
||||
from .compliance.routes.gdpr_admin_routes import router as gdpr_admin_routes
|
||||
from .integrations.routes import webhook_routes, api_key_routes
|
||||
from .auth.routes import session_routes
|
||||
|
||||
@@ -274,6 +349,7 @@ app.include_router(blog_routes.router, prefix=api_prefix)
|
||||
app.include_router(ai_assistant_routes.router, prefix=api_prefix)
|
||||
app.include_router(approval_routes.router, prefix=api_prefix)
|
||||
app.include_router(gdpr_routes.router, prefix=api_prefix)
|
||||
app.include_router(gdpr_admin_routes, prefix=api_prefix)
|
||||
app.include_router(webhook_routes.router, prefix=api_prefix)
|
||||
app.include_router(api_key_routes.router, prefix=api_prefix)
|
||||
app.include_router(session_routes.router, prefix=api_prefix)
|
||||
@@ -281,57 +357,38 @@ app.include_router(backup_routes.router, prefix=api_prefix)
|
||||
logger.info('All routes registered successfully')
|
||||
|
||||
def ensure_jwt_secret():
|
||||
"""Generate and save JWT secret if it's using the default value.
|
||||
|
||||
In production, fail fast if default secret is used for security.
|
||||
In development, auto-generate a secure secret if needed.
|
||||
"""
|
||||
default_secret = 'dev-secret-key-change-in-production-12345'
|
||||
Validate JWT secret is properly configured.
|
||||
|
||||
SECURITY: JWT_SECRET must be explicitly set via environment variable.
|
||||
No default values are acceptable for security.
|
||||
"""
|
||||
current_secret = settings.JWT_SECRET
|
||||
|
||||
# Security check: Fail fast in production if using default secret
|
||||
if settings.is_production and (not current_secret or current_secret == default_secret):
|
||||
error_msg = (
|
||||
'CRITICAL SECURITY ERROR: JWT_SECRET is using default value in production! '
|
||||
'Please set a secure JWT_SECRET in your environment variables.'
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Development mode: Auto-generate if needed
|
||||
if not current_secret or current_secret == default_secret:
|
||||
new_secret = secrets.token_urlsafe(64)
|
||||
|
||||
os.environ['JWT_SECRET'] = new_secret
|
||||
|
||||
env_file = Path(__file__).parent.parent / '.env'
|
||||
if env_file.exists():
|
||||
try:
|
||||
env_content = env_file.read_text(encoding='utf-8')
|
||||
|
||||
jwt_pattern = re.compile(r'^JWT_SECRET=.*$', re.MULTILINE)
|
||||
|
||||
if jwt_pattern.search(env_content):
|
||||
env_content = jwt_pattern.sub(f'JWT_SECRET={new_secret}', env_content)
|
||||
else:
|
||||
jwt_section_pattern = re.compile(r'(# =+.*JWT.*=+.*\n)', re.IGNORECASE | re.MULTILINE)
|
||||
match = jwt_section_pattern.search(env_content)
|
||||
if match:
|
||||
insert_pos = match.end()
|
||||
env_content = env_content[:insert_pos] + f'JWT_SECRET={new_secret}\n' + env_content[insert_pos:]
|
||||
else:
|
||||
env_content += f'\nJWT_SECRET={new_secret}\n'
|
||||
|
||||
env_file.write_text(env_content, encoding='utf-8')
|
||||
logger.info('✓ JWT secret generated and saved to .env file')
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not update .env file: {e}')
|
||||
logger.info(f'Generated JWT secret (add to .env manually): JWT_SECRET={new_secret}')
|
||||
# SECURITY: JWT_SECRET validation is now handled in settings.py
|
||||
# This function is kept for backward compatibility and logging
|
||||
if not current_secret or current_secret.strip() == '':
|
||||
if settings.is_production:
|
||||
# This should not happen as settings validation should catch it
|
||||
error_msg = (
|
||||
'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
|
||||
'Please set JWT_SECRET environment variable before starting the application.'
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
logger.info(f'Generated JWT secret (add to .env file): JWT_SECRET={new_secret}')
|
||||
|
||||
logger.info('✓ Secure JWT secret generated automatically')
|
||||
logger.warning(
|
||||
'JWT_SECRET is not configured. Authentication will fail. '
|
||||
'Set JWT_SECRET environment variable before starting the application.'
|
||||
)
|
||||
else:
|
||||
# Validate secret strength
|
||||
if len(current_secret) < 64:
|
||||
if settings.is_production:
|
||||
logger.warning(
|
||||
f'JWT_SECRET is only {len(current_secret)} characters. '
|
||||
'Recommend using at least 64 characters for production security.'
|
||||
)
|
||||
logger.info('✓ JWT secret is configured')
|
||||
|
||||
@app.on_event('startup')
|
||||
@@ -375,7 +432,34 @@ async def shutdown_event():
|
||||
logger.info(f'{settings.APP_NAME} shutting down gracefully')
|
||||
if __name__ == '__main__':
|
||||
import uvicorn
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""Handle Ctrl+C gracefully."""
|
||||
logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handler for graceful shutdown on Ctrl+C
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
base_dir = Path(__file__).parent.parent
|
||||
src_dir = str(base_dir / 'src')
|
||||
uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=settings.is_development, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if settings.is_development else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=0.5)
|
||||
# Enable hot reload in development mode or if explicitly enabled via environment variable
|
||||
use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
|
||||
if use_reload:
|
||||
logger.info('Hot reload enabled - server will restart on code changes')
|
||||
logger.info('Press Ctrl+C to stop the server')
|
||||
uvicorn.run(
|
||||
'src.main:app',
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=use_reload,
|
||||
log_level=settings.LOG_LEVEL.lower(),
|
||||
reload_dirs=[src_dir] if use_reload else None,
|
||||
reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
|
||||
reload_delay=0.5
|
||||
)
|
||||
Binary file not shown.
@@ -174,10 +174,13 @@ class BoricaService:
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# NOTE: SHA1 is required by Borica payment gateway protocol
|
||||
# This is a known security trade-off required for payment gateway compatibility
|
||||
# Monitor for Borica protocol updates that support stronger algorithms
|
||||
signature = private_key.sign(
|
||||
data.encode('utf-8'),
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA1()
|
||||
hashes.SHA1() # nosec B303 # Required by Borica protocol - acceptable risk
|
||||
)
|
||||
return base64.b64encode(signature).decode('utf-8')
|
||||
except Exception as e:
|
||||
@@ -228,11 +231,13 @@ class BoricaService:
|
||||
public_key = cert.public_key()
|
||||
signature_bytes = base64.b64decode(signature)
|
||||
|
||||
# NOTE: SHA1 is required by Borica payment gateway protocol
|
||||
# This is a known security trade-off required for payment gateway compatibility
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
signature_data.encode('utf-8'),
|
||||
padding.PKCS1v15(),
|
||||
hashes.SHA1()
|
||||
hashes.SHA1() # nosec B303 # Required by Borica protocol - acceptable risk
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
Binary file not shown.
@@ -10,7 +10,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'}
|
||||
security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
|
||||
# Allow cross-origin resource sharing for uploads/images
|
||||
# This is needed for images to load from different origins in development
|
||||
if '/uploads/' in str(request.url):
|
||||
security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
|
||||
else:
|
||||
security_headers.setdefault('Cross-Origin-Resource-Policy', 'same-origin')
|
||||
if settings.is_production:
|
||||
# Enhanced CSP with stricter directives
|
||||
# Using 'strict-dynamic' for better security with nonce-based scripts
|
||||
|
||||
Binary file not shown.
@@ -10,14 +10,14 @@ class Settings(BaseSettings):
|
||||
ENVIRONMENT: str = Field(default='development', description='Environment: development, staging, production')
|
||||
DEBUG: bool = Field(default=False, description='Debug mode')
|
||||
API_V1_PREFIX: str = Field(default='/api/v1', description='API v1 prefix')
|
||||
HOST: str = Field(default='0.0.0.0', description='Server host')
|
||||
HOST: str = Field(default='0.0.0.0', description='Server host. WARNING: 0.0.0.0 binds to all interfaces. Use 127.0.0.1 for development or specific IP for production.') # nosec B104 # Acceptable default with validation warning in production
|
||||
PORT: int = Field(default=8000, description='Server port')
|
||||
DB_USER: str = Field(default='root', description='Database user')
|
||||
DB_PASS: str = Field(default='', description='Database password')
|
||||
DB_NAME: str = Field(default='hotel_db', description='Database name')
|
||||
DB_HOST: str = Field(default='localhost', description='Database host')
|
||||
DB_PORT: str = Field(default='3306', description='Database port')
|
||||
JWT_SECRET: str = Field(default='dev-secret-key-change-in-production-12345', description='JWT secret key')
|
||||
JWT_SECRET: str = Field(default='', description='JWT secret key - MUST be set via environment variable. Minimum 64 characters recommended for production.')
|
||||
JWT_ALGORITHM: str = Field(default='HS256', description='JWT algorithm')
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30, description='JWT access token expiration in minutes')
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=3, description='JWT refresh token expiration in days (reduced from 7 for better security)')
|
||||
@@ -97,6 +97,20 @@ class Settings(BaseSettings):
|
||||
IP_WHITELIST_ENABLED: bool = Field(default=False, description='Enable IP whitelisting for admin endpoints')
|
||||
ADMIN_IP_WHITELIST: List[str] = Field(default_factory=list, description='List of allowed IP addresses/CIDR ranges for admin endpoints')
|
||||
|
||||
def validate_host_configuration(self) -> None:
|
||||
"""
|
||||
Validate HOST configuration for security.
|
||||
Warns if binding to all interfaces (0.0.0.0) in production.
|
||||
"""
|
||||
if self.HOST == '0.0.0.0' and self.is_production:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
'SECURITY WARNING: HOST is set to 0.0.0.0 in production. '
|
||||
'This binds the server to all network interfaces. '
|
||||
'Consider using a specific IP address or ensure proper firewall rules are in place.'
|
||||
)
|
||||
|
||||
def validate_encryption_key(self) -> None:
|
||||
"""
|
||||
Validate encryption key is properly configured.
|
||||
@@ -139,3 +153,40 @@ class Settings(BaseSettings):
|
||||
logger.warning(f'Invalid ENCRYPTION_KEY format: {str(e)}')
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Validate JWT_SECRET on startup - fail fast if not configured
|
||||
def validate_jwt_secret():
|
||||
"""Validate JWT_SECRET is properly configured. Called on startup."""
|
||||
if not settings.JWT_SECRET or settings.JWT_SECRET.strip() == '':
|
||||
error_msg = (
|
||||
'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
|
||||
'Please set JWT_SECRET environment variable to a secure random string. '
|
||||
'Minimum 64 characters recommended for production. '
|
||||
'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
|
||||
)
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
logger.warning(
|
||||
'JWT_SECRET not configured. This will cause authentication to fail. '
|
||||
'Set JWT_SECRET environment variable before starting the application.'
|
||||
)
|
||||
|
||||
# Warn if using weak secret (less than 64 characters)
|
||||
if len(settings.JWT_SECRET) < 64:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
if settings.is_production:
|
||||
logger.warning(
|
||||
f'JWT_SECRET is only {len(settings.JWT_SECRET)} characters. '
|
||||
'Recommend using at least 64 characters for production security.'
|
||||
)
|
||||
else:
|
||||
logger.debug(f'JWT_SECRET length: {len(settings.JWT_SECRET)} characters')
|
||||
|
||||
# Validate on import
|
||||
validate_jwt_secret()
|
||||
settings.validate_host_configuration()
|
||||
Binary file not shown.
168
Backend/src/shared/utils/sanitization.py
Normal file
168
Backend/src/shared/utils/sanitization.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
HTML/XSS sanitization utilities using bleach library.
|
||||
Prevents stored XSS attacks by sanitizing user-generated content.
|
||||
"""
|
||||
import bleach
|
||||
from typing import Optional
|
||||
|
||||
# Allowed HTML tags for rich text content
|
||||
ALLOWED_TAGS = [
|
||||
'p', 'br', 'strong', 'em', 'u', 'b', 'i', 's', 'strike',
|
||||
'a', 'ul', 'ol', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
|
||||
'blockquote', 'pre', 'code', 'hr', 'div', 'span',
|
||||
'table', 'thead', 'tbody', 'tr', 'th', 'td',
|
||||
'img'
|
||||
]
|
||||
|
||||
# Allowed attributes for specific tags
|
||||
ALLOWED_ATTRIBUTES = {
|
||||
'a': ['href', 'title', 'target', 'rel'],
|
||||
'img': ['src', 'alt', 'title', 'width', 'height'],
|
||||
'div': ['class'],
|
||||
'span': ['class'],
|
||||
'p': ['class'],
|
||||
'table': ['class', 'border'],
|
||||
'th': ['colspan', 'rowspan'],
|
||||
'td': ['colspan', 'rowspan']
|
||||
}
|
||||
|
||||
# Allowed URL schemes
|
||||
ALLOWED_PROTOCOLS = ['http', 'https', 'mailto']
|
||||
|
||||
# Allowed CSS classes (optional - can be expanded)
|
||||
ALLOWED_STYLES = []
|
||||
|
||||
|
||||
def sanitize_html(content: Optional[str], strip: bool = False) -> str:
|
||||
"""
|
||||
Sanitize HTML content to prevent XSS attacks.
|
||||
|
||||
Args:
|
||||
content: The HTML content to sanitize (can be None)
|
||||
strip: If True, remove disallowed tags instead of escaping them
|
||||
|
||||
Returns:
|
||||
Sanitized HTML string
|
||||
"""
|
||||
if not content:
|
||||
return ''
|
||||
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
# Sanitize HTML
|
||||
sanitized = bleach.clean(
|
||||
content,
|
||||
tags=ALLOWED_TAGS,
|
||||
attributes=ALLOWED_ATTRIBUTES,
|
||||
protocols=ALLOWED_PROTOCOLS,
|
||||
strip=strip,
|
||||
strip_comments=True
|
||||
)
|
||||
|
||||
# Linkify URLs (convert plain URLs to links)
|
||||
# Only linkify if content doesn't already contain HTML links
|
||||
if '<a' not in sanitized:
|
||||
sanitized = bleach.linkify(
|
||||
sanitized,
|
||||
protocols=ALLOWED_PROTOCOLS,
|
||||
parse_email=True
|
||||
)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_text(content: Optional[str]) -> str:
|
||||
"""
|
||||
Strip all HTML tags from content, leaving only plain text.
|
||||
Useful for fields that should not contain any HTML.
|
||||
|
||||
Args:
|
||||
content: The content to sanitize (can be None)
|
||||
|
||||
Returns:
|
||||
Plain text string with all HTML removed
|
||||
"""
|
||||
if not content:
|
||||
return ''
|
||||
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
# Strip all HTML tags
|
||||
return bleach.clean(content, tags=[], strip=True)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize filename to prevent path traversal and other attacks.
|
||||
|
||||
Args:
|
||||
filename: The original filename
|
||||
|
||||
Returns:
|
||||
Sanitized filename safe for filesystem operations
|
||||
"""
|
||||
import os
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
if not filename:
|
||||
# Generate a random filename if none provided
|
||||
return f"{secrets.token_urlsafe(16)}.bin"
|
||||
|
||||
# Remove path components (prevent directory traversal)
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
# Remove dangerous characters
|
||||
# Keep only alphanumeric, dots, dashes, and underscores
|
||||
safe_chars = []
|
||||
for char in filename:
|
||||
if char.isalnum() or char in '._-':
|
||||
safe_chars.append(char)
|
||||
else:
|
||||
safe_chars.append('_')
|
||||
|
||||
filename = ''.join(safe_chars)
|
||||
|
||||
# Limit length (filesystem limit is typically 255)
|
||||
if len(filename) > 255:
|
||||
name, ext = os.path.splitext(filename)
|
||||
max_name_length = 255 - len(ext)
|
||||
filename = name[:max_name_length] + ext
|
||||
|
||||
# Ensure filename is not empty
|
||||
if not filename or filename == '.' or filename == '..':
|
||||
filename = f"{secrets.token_urlsafe(16)}.bin"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def sanitize_url(url: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Sanitize URL to ensure it uses allowed protocols.
|
||||
|
||||
Args:
|
||||
url: The URL to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized URL or None if invalid
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
if not isinstance(url, str):
|
||||
url = str(url)
|
||||
|
||||
# Check if URL uses allowed protocol
|
||||
url_lower = url.lower().strip()
|
||||
if any(url_lower.startswith(proto + ':') for proto in ALLOWED_PROTOCOLS):
|
||||
return url
|
||||
|
||||
# If no protocol, assume https
|
||||
if '://' not in url:
|
||||
return f'https://{url}'
|
||||
|
||||
# Invalid protocol - return None
|
||||
return None
|
||||
|
||||
BIN
Backend/src/system/routes/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
Backend/src/system/routes/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
7
Backend/venv/bin/bandit
Executable file
7
Backend/venv/bin/bandit
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from bandit.cli.main import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/bandit-baseline
Executable file
7
Backend/venv/bin/bandit-baseline
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from bandit.cli.baseline import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/bandit-config-generator
Executable file
7
Backend/venv/bin/bandit-config-generator
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from bandit.cli.config_generator import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/doesitcache
Executable file
7
Backend/venv/bin/doesitcache
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from cachecontrol._cmd import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/fastapi
Executable file
7
Backend/venv/bin/fastapi
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from fastapi.cli import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/markdown-it
Executable file
7
Backend/venv/bin/markdown-it
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from markdown_it.cli.parse import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/nltk
Executable file
7
Backend/venv/bin/nltk
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from nltk.cli import cli
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(cli())
|
||||
7
Backend/venv/bin/pip-audit
Executable file
7
Backend/venv/bin/pip-audit
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from pip_audit._cli import audit
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(audit())
|
||||
7
Backend/venv/bin/safety
Executable file
7
Backend/venv/bin/safety
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from safety.cli import cli
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(cli())
|
||||
7
Backend/venv/bin/tqdm
Executable file
7
Backend/venv/bin/tqdm
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from tqdm.cli import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
7
Backend/venv/bin/typer
Executable file
7
Backend/venv/bin/typer
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
|
||||
import sys
|
||||
from typer.cli import main
|
||||
if __name__ == '__main__':
|
||||
if sys.argv[0].endswith('.exe'):
|
||||
sys.argv[0] = sys.argv[0][:-4]
|
||||
sys.exit(main())
|
||||
@@ -1,59 +0,0 @@
|
||||
Jinja2-3.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
Jinja2-3.1.2.dist-info/LICENSE.rst,sha256=O0nc7kEF6ze6wQ-vG-JgQI_oXSUrjp3y4JefweCUQ3s,1475
|
||||
Jinja2-3.1.2.dist-info/METADATA,sha256=PZ6v2SIidMNixR7MRUX9f7ZWsPwtXanknqiZUmRbh4U,3539
|
||||
Jinja2-3.1.2.dist-info/RECORD,,
|
||||
Jinja2-3.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
Jinja2-3.1.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
||||
Jinja2-3.1.2.dist-info/entry_points.txt,sha256=zRd62fbqIyfUpsRtU7EVIFyiu1tPwfgO7EvPErnxgTE,59
|
||||
Jinja2-3.1.2.dist-info/top_level.txt,sha256=PkeVWtLb3-CqjWi1fO29OCbj55EhX_chhKrCdrVe_zs,7
|
||||
jinja2/__init__.py,sha256=8vGduD8ytwgD6GDSqpYc2m3aU-T7PKOAddvVXgGr_Fs,1927
|
||||
jinja2/__pycache__/__init__.cpython-312.pyc,,
|
||||
jinja2/__pycache__/_identifier.cpython-312.pyc,,
|
||||
jinja2/__pycache__/async_utils.cpython-312.pyc,,
|
||||
jinja2/__pycache__/bccache.cpython-312.pyc,,
|
||||
jinja2/__pycache__/compiler.cpython-312.pyc,,
|
||||
jinja2/__pycache__/constants.cpython-312.pyc,,
|
||||
jinja2/__pycache__/debug.cpython-312.pyc,,
|
||||
jinja2/__pycache__/defaults.cpython-312.pyc,,
|
||||
jinja2/__pycache__/environment.cpython-312.pyc,,
|
||||
jinja2/__pycache__/exceptions.cpython-312.pyc,,
|
||||
jinja2/__pycache__/ext.cpython-312.pyc,,
|
||||
jinja2/__pycache__/filters.cpython-312.pyc,,
|
||||
jinja2/__pycache__/idtracking.cpython-312.pyc,,
|
||||
jinja2/__pycache__/lexer.cpython-312.pyc,,
|
||||
jinja2/__pycache__/loaders.cpython-312.pyc,,
|
||||
jinja2/__pycache__/meta.cpython-312.pyc,,
|
||||
jinja2/__pycache__/nativetypes.cpython-312.pyc,,
|
||||
jinja2/__pycache__/nodes.cpython-312.pyc,,
|
||||
jinja2/__pycache__/optimizer.cpython-312.pyc,,
|
||||
jinja2/__pycache__/parser.cpython-312.pyc,,
|
||||
jinja2/__pycache__/runtime.cpython-312.pyc,,
|
||||
jinja2/__pycache__/sandbox.cpython-312.pyc,,
|
||||
jinja2/__pycache__/tests.cpython-312.pyc,,
|
||||
jinja2/__pycache__/utils.cpython-312.pyc,,
|
||||
jinja2/__pycache__/visitor.cpython-312.pyc,,
|
||||
jinja2/_identifier.py,sha256=_zYctNKzRqlk_murTNlzrju1FFJL7Va_Ijqqd7ii2lU,1958
|
||||
jinja2/async_utils.py,sha256=dHlbTeaxFPtAOQEYOGYh_PHcDT0rsDaUJAFDl_0XtTg,2472
|
||||
jinja2/bccache.py,sha256=mhz5xtLxCcHRAa56azOhphIAe19u1we0ojifNMClDio,14061
|
||||
jinja2/compiler.py,sha256=Gs-N8ThJ7OWK4-reKoO8Wh1ZXz95MVphBKNVf75qBr8,72172
|
||||
jinja2/constants.py,sha256=GMoFydBF_kdpaRKPoM5cl5MviquVRLVyZtfp5-16jg0,1433
|
||||
jinja2/debug.py,sha256=iWJ432RadxJNnaMOPrjIDInz50UEgni3_HKuFXi2vuQ,6299
|
||||
jinja2/defaults.py,sha256=boBcSw78h-lp20YbaXSJsqkAI2uN_mD_TtCydpeq5wU,1267
|
||||
jinja2/environment.py,sha256=6uHIcc7ZblqOMdx_uYNKqRnnwAF0_nzbyeMP9FFtuh4,61349
|
||||
jinja2/exceptions.py,sha256=ioHeHrWwCWNaXX1inHmHVblvc4haO7AXsjCp3GfWvx0,5071
|
||||
jinja2/ext.py,sha256=ivr3P7LKbddiXDVez20EflcO3q2aHQwz9P_PgWGHVqE,31502
|
||||
jinja2/filters.py,sha256=9js1V-h2RlyW90IhLiBGLM2U-k6SCy2F4BUUMgB3K9Q,53509
|
||||
jinja2/idtracking.py,sha256=GfNmadir4oDALVxzn3DL9YInhJDr69ebXeA2ygfuCGA,10704
|
||||
jinja2/lexer.py,sha256=DW2nX9zk-6MWp65YR2bqqj0xqCvLtD-u9NWT8AnFRxQ,29726
|
||||
jinja2/loaders.py,sha256=BfptfvTVpClUd-leMkHczdyPNYFzp_n7PKOJ98iyHOg,23207
|
||||
jinja2/meta.py,sha256=GNPEvifmSaU3CMxlbheBOZjeZ277HThOPUTf1RkppKQ,4396
|
||||
jinja2/nativetypes.py,sha256=DXgORDPRmVWgy034H0xL8eF7qYoK3DrMxs-935d0Fzk,4226
|
||||
jinja2/nodes.py,sha256=i34GPRAZexXMT6bwuf5SEyvdmS-bRCy9KMjwN5O6pjk,34550
|
||||
jinja2/optimizer.py,sha256=tHkMwXxfZkbfA1KmLcqmBMSaz7RLIvvItrJcPoXTyD8,1650
|
||||
jinja2/parser.py,sha256=nHd-DFHbiygvfaPtm9rcQXJChZG7DPsWfiEsqfwKerY,39595
|
||||
jinja2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
jinja2/runtime.py,sha256=5CmD5BjbEJxSiDNTFBeKCaq8qU4aYD2v6q2EluyExms,33476
|
||||
jinja2/sandbox.py,sha256=Y0xZeXQnH6EX5VjaV2YixESxoepnRbW_3UeQosaBU3M,14584
|
||||
jinja2/tests.py,sha256=Am5Z6Lmfr2XaH_npIfJJ8MdXtWsbLjMULZJulTAj30E,5905
|
||||
jinja2/utils.py,sha256=u9jXESxGn8ATZNVolwmkjUVu4SA-tLgV0W7PcSfPfdQ,23965
|
||||
jinja2/visitor.py,sha256=MH14C6yq24G_KVtWzjwaI7Wg14PCJIYlWW1kpkxYak0,3568
|
||||
@@ -1,2 +0,0 @@
|
||||
[babel.extractors]
|
||||
jinja2 = jinja2.ext:babel_extract[i18n]
|
||||
@@ -1 +0,0 @@
|
||||
jinja2
|
||||
291
Backend/venv/lib/python3.12/site-packages/PIL/AvifImagePlugin.py
Normal file
291
Backend/venv/lib/python3.12/site-packages/PIL/AvifImagePlugin.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
from . import ExifTags, Image, ImageFile
|
||||
|
||||
try:
|
||||
from . import _avif
|
||||
|
||||
SUPPORTED = True
|
||||
except ImportError:
|
||||
SUPPORTED = False
|
||||
|
||||
# Decoder options as module globals, until there is a way to pass parameters
|
||||
# to Image.open (see https://github.com/python-pillow/Pillow/issues/569)
|
||||
DECODE_CODEC_CHOICE = "auto"
|
||||
DEFAULT_MAX_THREADS = 0
|
||||
|
||||
|
||||
def get_codec_version(codec_name: str) -> str | None:
|
||||
versions = _avif.codec_versions()
|
||||
for version in versions.split(", "):
|
||||
if version.split(" [")[0] == codec_name:
|
||||
return version.split(":")[-1].split(" ")[0]
|
||||
return None
|
||||
|
||||
|
||||
def _accept(prefix: bytes) -> bool | str:
|
||||
if prefix[4:8] != b"ftyp":
|
||||
return False
|
||||
major_brand = prefix[8:12]
|
||||
if major_brand in (
|
||||
# coding brands
|
||||
b"avif",
|
||||
b"avis",
|
||||
# We accept files with AVIF container brands; we can't yet know if
|
||||
# the ftyp box has the correct compatible brands, but if it doesn't
|
||||
# then the plugin will raise a SyntaxError which Pillow will catch
|
||||
# before moving on to the next plugin that accepts the file.
|
||||
#
|
||||
# Also, because this file might not actually be an AVIF file, we
|
||||
# don't raise an error if AVIF support isn't properly compiled.
|
||||
b"mif1",
|
||||
b"msf1",
|
||||
):
|
||||
if not SUPPORTED:
|
||||
return (
|
||||
"image file could not be identified because AVIF support not installed"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_default_max_threads() -> int:
|
||||
if DEFAULT_MAX_THREADS:
|
||||
return DEFAULT_MAX_THREADS
|
||||
if hasattr(os, "sched_getaffinity"):
|
||||
return len(os.sched_getaffinity(0))
|
||||
else:
|
||||
return os.cpu_count() or 1
|
||||
|
||||
|
||||
class AvifImageFile(ImageFile.ImageFile):
|
||||
format = "AVIF"
|
||||
format_description = "AVIF image"
|
||||
__frame = -1
|
||||
|
||||
def _open(self) -> None:
|
||||
if not SUPPORTED:
|
||||
msg = "image file could not be opened because AVIF support not installed"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
if DECODE_CODEC_CHOICE != "auto" and not _avif.decoder_codec_available(
|
||||
DECODE_CODEC_CHOICE
|
||||
):
|
||||
msg = "Invalid opening codec"
|
||||
raise ValueError(msg)
|
||||
self._decoder = _avif.AvifDecoder(
|
||||
self.fp.read(),
|
||||
DECODE_CODEC_CHOICE,
|
||||
_get_default_max_threads(),
|
||||
)
|
||||
|
||||
# Get info from decoder
|
||||
self._size, self.n_frames, self._mode, icc, exif, exif_orientation, xmp = (
|
||||
self._decoder.get_info()
|
||||
)
|
||||
self.is_animated = self.n_frames > 1
|
||||
|
||||
if icc:
|
||||
self.info["icc_profile"] = icc
|
||||
if xmp:
|
||||
self.info["xmp"] = xmp
|
||||
|
||||
if exif_orientation != 1 or exif:
|
||||
exif_data = Image.Exif()
|
||||
if exif:
|
||||
exif_data.load(exif)
|
||||
original_orientation = exif_data.get(ExifTags.Base.Orientation, 1)
|
||||
else:
|
||||
original_orientation = 1
|
||||
if exif_orientation != original_orientation:
|
||||
exif_data[ExifTags.Base.Orientation] = exif_orientation
|
||||
exif = exif_data.tobytes()
|
||||
if exif:
|
||||
self.info["exif"] = exif
|
||||
self.seek(0)
|
||||
|
||||
def seek(self, frame: int) -> None:
|
||||
if not self._seek_check(frame):
|
||||
return
|
||||
|
||||
# Set tile
|
||||
self.__frame = frame
|
||||
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
|
||||
|
||||
def load(self) -> Image.core.PixelAccess | None:
|
||||
if self.tile:
|
||||
# We need to load the image data for this frame
|
||||
data, timescale, pts_in_timescales, duration_in_timescales = (
|
||||
self._decoder.get_frame(self.__frame)
|
||||
)
|
||||
self.info["timestamp"] = round(1000 * (pts_in_timescales / timescale))
|
||||
self.info["duration"] = round(1000 * (duration_in_timescales / timescale))
|
||||
|
||||
if self.fp and self._exclusive_fp:
|
||||
self.fp.close()
|
||||
self.fp = BytesIO(data)
|
||||
|
||||
return super().load()
|
||||
|
||||
def load_seek(self, pos: int) -> None:
|
||||
pass
|
||||
|
||||
def tell(self) -> int:
|
||||
return self.__frame
|
||||
|
||||
|
||||
def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
_save(im, fp, filename, save_all=True)
|
||||
|
||||
|
||||
def _save(
|
||||
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
|
||||
) -> None:
|
||||
info = im.encoderinfo.copy()
|
||||
if save_all:
|
||||
append_images = list(info.get("append_images", []))
|
||||
else:
|
||||
append_images = []
|
||||
|
||||
total = 0
|
||||
for ims in [im] + append_images:
|
||||
total += getattr(ims, "n_frames", 1)
|
||||
|
||||
quality = info.get("quality", 75)
|
||||
if not isinstance(quality, int) or quality < 0 or quality > 100:
|
||||
msg = "Invalid quality setting"
|
||||
raise ValueError(msg)
|
||||
|
||||
duration = info.get("duration", 0)
|
||||
subsampling = info.get("subsampling", "4:2:0")
|
||||
speed = info.get("speed", 6)
|
||||
max_threads = info.get("max_threads", _get_default_max_threads())
|
||||
codec = info.get("codec", "auto")
|
||||
if codec != "auto" and not _avif.encoder_codec_available(codec):
|
||||
msg = "Invalid saving codec"
|
||||
raise ValueError(msg)
|
||||
range_ = info.get("range", "full")
|
||||
tile_rows_log2 = info.get("tile_rows", 0)
|
||||
tile_cols_log2 = info.get("tile_cols", 0)
|
||||
alpha_premultiplied = bool(info.get("alpha_premultiplied", False))
|
||||
autotiling = bool(info.get("autotiling", tile_rows_log2 == tile_cols_log2 == 0))
|
||||
|
||||
icc_profile = info.get("icc_profile", im.info.get("icc_profile"))
|
||||
exif_orientation = 1
|
||||
if exif := info.get("exif"):
|
||||
if isinstance(exif, Image.Exif):
|
||||
exif_data = exif
|
||||
else:
|
||||
exif_data = Image.Exif()
|
||||
exif_data.load(exif)
|
||||
if ExifTags.Base.Orientation in exif_data:
|
||||
exif_orientation = exif_data.pop(ExifTags.Base.Orientation)
|
||||
exif = exif_data.tobytes() if exif_data else b""
|
||||
elif isinstance(exif, Image.Exif):
|
||||
exif = exif_data.tobytes()
|
||||
|
||||
xmp = info.get("xmp")
|
||||
|
||||
if isinstance(xmp, str):
|
||||
xmp = xmp.encode("utf-8")
|
||||
|
||||
advanced = info.get("advanced")
|
||||
if advanced is not None:
|
||||
if isinstance(advanced, dict):
|
||||
advanced = advanced.items()
|
||||
try:
|
||||
advanced = tuple(advanced)
|
||||
except TypeError:
|
||||
invalid = True
|
||||
else:
|
||||
invalid = any(not isinstance(v, tuple) or len(v) != 2 for v in advanced)
|
||||
if invalid:
|
||||
msg = (
|
||||
"advanced codec options must be a dict of key-value string "
|
||||
"pairs or a series of key-value two-tuples"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Setup the AVIF encoder
|
||||
enc = _avif.AvifEncoder(
|
||||
im.size,
|
||||
subsampling,
|
||||
quality,
|
||||
speed,
|
||||
max_threads,
|
||||
codec,
|
||||
range_,
|
||||
tile_rows_log2,
|
||||
tile_cols_log2,
|
||||
alpha_premultiplied,
|
||||
autotiling,
|
||||
icc_profile or b"",
|
||||
exif or b"",
|
||||
exif_orientation,
|
||||
xmp or b"",
|
||||
advanced,
|
||||
)
|
||||
|
||||
# Add each frame
|
||||
frame_idx = 0
|
||||
frame_duration = 0
|
||||
cur_idx = im.tell()
|
||||
is_single_frame = total == 1
|
||||
try:
|
||||
for ims in [im] + append_images:
|
||||
# Get number of frames in this image
|
||||
nfr = getattr(ims, "n_frames", 1)
|
||||
|
||||
for idx in range(nfr):
|
||||
ims.seek(idx)
|
||||
|
||||
# Make sure image mode is supported
|
||||
frame = ims
|
||||
rawmode = ims.mode
|
||||
if ims.mode not in {"RGB", "RGBA"}:
|
||||
rawmode = "RGBA" if ims.has_transparency_data else "RGB"
|
||||
frame = ims.convert(rawmode)
|
||||
|
||||
# Update frame duration
|
||||
if isinstance(duration, (list, tuple)):
|
||||
frame_duration = duration[frame_idx]
|
||||
else:
|
||||
frame_duration = duration
|
||||
|
||||
# Append the frame to the animation encoder
|
||||
enc.add(
|
||||
frame.tobytes("raw", rawmode),
|
||||
frame_duration,
|
||||
frame.size,
|
||||
rawmode,
|
||||
is_single_frame,
|
||||
)
|
||||
|
||||
# Update frame index
|
||||
frame_idx += 1
|
||||
|
||||
if not save_all:
|
||||
break
|
||||
|
||||
finally:
|
||||
im.seek(cur_idx)
|
||||
|
||||
# Get the final output from the encoder
|
||||
data = enc.finish()
|
||||
if data is None:
|
||||
msg = "cannot write file as AVIF (encoder returned None)"
|
||||
raise OSError(msg)
|
||||
|
||||
fp.write(data)
|
||||
|
||||
|
||||
Image.register_open(AvifImageFile.format, AvifImageFile, _accept)
|
||||
if SUPPORTED:
|
||||
Image.register_save(AvifImageFile.format, _save)
|
||||
Image.register_save_all(AvifImageFile.format, _save_all)
|
||||
Image.register_extensions(AvifImageFile.format, [".avif", ".avifs"])
|
||||
Image.register_mime(AvifImageFile.format, "image/avif")
|
||||
@@ -20,29 +20,30 @@
|
||||
"""
|
||||
Parse X Bitmap Distribution Format (BDF)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import BinaryIO
|
||||
|
||||
from . import FontFile, Image
|
||||
|
||||
bdf_slant = {
|
||||
"R": "Roman",
|
||||
"I": "Italic",
|
||||
"O": "Oblique",
|
||||
"RI": "Reverse Italic",
|
||||
"RO": "Reverse Oblique",
|
||||
"OT": "Other",
|
||||
}
|
||||
|
||||
bdf_spacing = {"P": "Proportional", "M": "Monospaced", "C": "Cell"}
|
||||
|
||||
|
||||
def bdf_char(f):
|
||||
def bdf_char(
|
||||
f: BinaryIO,
|
||||
) -> (
|
||||
tuple[
|
||||
str,
|
||||
int,
|
||||
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]],
|
||||
Image.Image,
|
||||
]
|
||||
| None
|
||||
):
|
||||
# skip to STARTCHAR
|
||||
while True:
|
||||
s = f.readline()
|
||||
if not s:
|
||||
return None
|
||||
if s[:9] == b"STARTCHAR":
|
||||
if s.startswith(b"STARTCHAR"):
|
||||
break
|
||||
id = s[9:].strip().decode("ascii")
|
||||
|
||||
@@ -50,19 +51,18 @@ def bdf_char(f):
|
||||
props = {}
|
||||
while True:
|
||||
s = f.readline()
|
||||
if not s or s[:6] == b"BITMAP":
|
||||
if not s or s.startswith(b"BITMAP"):
|
||||
break
|
||||
i = s.find(b" ")
|
||||
props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")
|
||||
|
||||
# load bitmap
|
||||
bitmap = []
|
||||
bitmap = bytearray()
|
||||
while True:
|
||||
s = f.readline()
|
||||
if not s or s[:7] == b"ENDCHAR":
|
||||
if not s or s.startswith(b"ENDCHAR"):
|
||||
break
|
||||
bitmap.append(s[:-1])
|
||||
bitmap = b"".join(bitmap)
|
||||
bitmap += s[:-1]
|
||||
|
||||
# The word BBX
|
||||
# followed by the width in x (BBw), height in y (BBh),
|
||||
@@ -92,11 +92,11 @@ def bdf_char(f):
|
||||
class BdfFontFile(FontFile.FontFile):
|
||||
"""Font file plugin for the X11 BDF format."""
|
||||
|
||||
def __init__(self, fp):
|
||||
def __init__(self, fp: BinaryIO) -> None:
|
||||
super().__init__()
|
||||
|
||||
s = fp.readline()
|
||||
if s[:13] != b"STARTFONT 2.1":
|
||||
if not s.startswith(b"STARTFONT 2.1"):
|
||||
msg = "not a valid BDF file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
@@ -105,7 +105,7 @@ class BdfFontFile(FontFile.FontFile):
|
||||
|
||||
while True:
|
||||
s = fp.readline()
|
||||
if not s or s[:13] == b"ENDPROPERTIES":
|
||||
if not s or s.startswith(b"ENDPROPERTIES"):
|
||||
break
|
||||
i = s.find(b" ")
|
||||
props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")
|
||||
|
||||
@@ -29,10 +29,14 @@ BLP files come in many different flavours:
|
||||
- DXT5 compression is used if alpha_encoding == 7.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import os
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
from . import Image, ImageFile
|
||||
|
||||
@@ -53,11 +57,13 @@ class AlphaEncoding(IntEnum):
|
||||
DXT5 = 7
|
||||
|
||||
|
||||
def unpack_565(i):
|
||||
def unpack_565(i: int) -> tuple[int, int, int]:
|
||||
return ((i >> 11) & 0x1F) << 3, ((i >> 5) & 0x3F) << 2, (i & 0x1F) << 3
|
||||
|
||||
|
||||
def decode_dxt1(data, alpha=False):
|
||||
def decode_dxt1(
|
||||
data: bytes, alpha: bool = False
|
||||
) -> tuple[bytearray, bytearray, bytearray, bytearray]:
|
||||
"""
|
||||
input: one "row" of data (i.e. will produce 4*width pixels)
|
||||
"""
|
||||
@@ -65,9 +71,9 @@ def decode_dxt1(data, alpha=False):
|
||||
blocks = len(data) // 8 # number of blocks in row
|
||||
ret = (bytearray(), bytearray(), bytearray(), bytearray())
|
||||
|
||||
for block in range(blocks):
|
||||
for block_index in range(blocks):
|
||||
# Decode next 8-byte block.
|
||||
idx = block * 8
|
||||
idx = block_index * 8
|
||||
color0, color1, bits = struct.unpack_from("<HHI", data, idx)
|
||||
|
||||
r0, g0, b0 = unpack_565(color0)
|
||||
@@ -112,7 +118,7 @@ def decode_dxt1(data, alpha=False):
|
||||
return ret
|
||||
|
||||
|
||||
def decode_dxt3(data):
|
||||
def decode_dxt3(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
|
||||
"""
|
||||
input: one "row" of data (i.e. will produce 4*width pixels)
|
||||
"""
|
||||
@@ -120,8 +126,8 @@ def decode_dxt3(data):
|
||||
blocks = len(data) // 16 # number of blocks in row
|
||||
ret = (bytearray(), bytearray(), bytearray(), bytearray())
|
||||
|
||||
for block in range(blocks):
|
||||
idx = block * 16
|
||||
for block_index in range(blocks):
|
||||
idx = block_index * 16
|
||||
block = data[idx : idx + 16]
|
||||
# Decode next 16-byte block.
|
||||
bits = struct.unpack_from("<8B", block)
|
||||
@@ -165,7 +171,7 @@ def decode_dxt3(data):
|
||||
return ret
|
||||
|
||||
|
||||
def decode_dxt5(data):
|
||||
def decode_dxt5(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
|
||||
"""
|
||||
input: one "row" of data (i.e. will produce 4 * width pixels)
|
||||
"""
|
||||
@@ -173,8 +179,8 @@ def decode_dxt5(data):
|
||||
blocks = len(data) // 16 # number of blocks in row
|
||||
ret = (bytearray(), bytearray(), bytearray(), bytearray())
|
||||
|
||||
for block in range(blocks):
|
||||
idx = block * 16
|
||||
for block_index in range(blocks):
|
||||
idx = block_index * 16
|
||||
block = data[idx : idx + 16]
|
||||
# Decode next 16-byte block.
|
||||
a0, a1 = struct.unpack_from("<BB", block)
|
||||
@@ -239,8 +245,8 @@ class BLPFormatError(NotImplementedError):
|
||||
pass
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] in (b"BLP1", b"BLP2")
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith((b"BLP1", b"BLP2"))
|
||||
|
||||
|
||||
class BlpImageFile(ImageFile.ImageFile):
|
||||
@@ -251,60 +257,65 @@ class BlpImageFile(ImageFile.ImageFile):
|
||||
format = "BLP"
|
||||
format_description = "Blizzard Mipmap Format"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
self.magic = self.fp.read(4)
|
||||
|
||||
self.fp.seek(5, os.SEEK_CUR)
|
||||
(self._blp_alpha_depth,) = struct.unpack("<b", self.fp.read(1))
|
||||
|
||||
self.fp.seek(2, os.SEEK_CUR)
|
||||
self._size = struct.unpack("<II", self.fp.read(8))
|
||||
|
||||
if self.magic in (b"BLP1", b"BLP2"):
|
||||
decoder = self.magic.decode()
|
||||
else:
|
||||
if not _accept(self.magic):
|
||||
msg = f"Bad BLP magic {repr(self.magic)}"
|
||||
raise BLPFormatError(msg)
|
||||
|
||||
self._mode = "RGBA" if self._blp_alpha_depth else "RGB"
|
||||
self.tile = [(decoder, (0, 0) + self.size, 0, (self.mode, 0, 1))]
|
||||
compression = struct.unpack("<i", self.fp.read(4))[0]
|
||||
if self.magic == b"BLP1":
|
||||
alpha = struct.unpack("<I", self.fp.read(4))[0] != 0
|
||||
else:
|
||||
encoding = struct.unpack("<b", self.fp.read(1))[0]
|
||||
alpha = struct.unpack("<b", self.fp.read(1))[0] != 0
|
||||
alpha_encoding = struct.unpack("<b", self.fp.read(1))[0]
|
||||
self.fp.seek(1, os.SEEK_CUR) # mips
|
||||
|
||||
self._size = struct.unpack("<II", self.fp.read(8))
|
||||
|
||||
args: tuple[int, int, bool] | tuple[int, int, bool, int]
|
||||
if self.magic == b"BLP1":
|
||||
encoding = struct.unpack("<i", self.fp.read(4))[0]
|
||||
self.fp.seek(4, os.SEEK_CUR) # subtype
|
||||
|
||||
args = (compression, encoding, alpha)
|
||||
offset = 28
|
||||
else:
|
||||
args = (compression, encoding, alpha, alpha_encoding)
|
||||
offset = 20
|
||||
|
||||
decoder = self.magic.decode()
|
||||
|
||||
self._mode = "RGBA" if alpha else "RGB"
|
||||
self.tile = [ImageFile._Tile(decoder, (0, 0) + self.size, offset, args)]
|
||||
|
||||
|
||||
class _BLPBaseDecoder(ImageFile.PyDecoder):
|
||||
class _BLPBaseDecoder(abc.ABC, ImageFile.PyDecoder):
|
||||
_pulls_fd = True
|
||||
|
||||
def decode(self, buffer):
|
||||
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
|
||||
try:
|
||||
self._read_blp_header()
|
||||
self._read_header()
|
||||
self._load()
|
||||
except struct.error as e:
|
||||
msg = "Truncated BLP file"
|
||||
raise OSError(msg) from e
|
||||
return -1, 0
|
||||
|
||||
def _read_blp_header(self):
|
||||
self.fd.seek(4)
|
||||
(self._blp_compression,) = struct.unpack("<i", self._safe_read(4))
|
||||
@abc.abstractmethod
|
||||
def _load(self) -> None:
|
||||
pass
|
||||
|
||||
(self._blp_encoding,) = struct.unpack("<b", self._safe_read(1))
|
||||
(self._blp_alpha_depth,) = struct.unpack("<b", self._safe_read(1))
|
||||
(self._blp_alpha_encoding,) = struct.unpack("<b", self._safe_read(1))
|
||||
self.fd.seek(1, os.SEEK_CUR) # mips
|
||||
def _read_header(self) -> None:
|
||||
self._offsets = struct.unpack("<16I", self._safe_read(16 * 4))
|
||||
self._lengths = struct.unpack("<16I", self._safe_read(16 * 4))
|
||||
|
||||
self.size = struct.unpack("<II", self._safe_read(8))
|
||||
|
||||
if isinstance(self, BLP1Decoder):
|
||||
# Only present for BLP1
|
||||
(self._blp_encoding,) = struct.unpack("<i", self._safe_read(4))
|
||||
self.fd.seek(4, os.SEEK_CUR) # subtype
|
||||
|
||||
self._blp_offsets = struct.unpack("<16I", self._safe_read(16 * 4))
|
||||
self._blp_lengths = struct.unpack("<16I", self._safe_read(16 * 4))
|
||||
|
||||
def _safe_read(self, length):
|
||||
def _safe_read(self, length: int) -> bytes:
|
||||
assert self.fd is not None
|
||||
return ImageFile._safe_read(self.fd, length)
|
||||
|
||||
def _read_palette(self):
|
||||
def _read_palette(self) -> list[tuple[int, int, int, int]]:
|
||||
ret = []
|
||||
for i in range(256):
|
||||
try:
|
||||
@@ -314,110 +325,115 @@ class _BLPBaseDecoder(ImageFile.PyDecoder):
|
||||
ret.append((b, g, r, a))
|
||||
return ret
|
||||
|
||||
def _read_bgra(self, palette):
|
||||
def _read_bgra(
|
||||
self, palette: list[tuple[int, int, int, int]], alpha: bool
|
||||
) -> bytearray:
|
||||
data = bytearray()
|
||||
_data = BytesIO(self._safe_read(self._blp_lengths[0]))
|
||||
_data = BytesIO(self._safe_read(self._lengths[0]))
|
||||
while True:
|
||||
try:
|
||||
(offset,) = struct.unpack("<B", _data.read(1))
|
||||
except struct.error:
|
||||
break
|
||||
b, g, r, a = palette[offset]
|
||||
d = (r, g, b)
|
||||
if self._blp_alpha_depth:
|
||||
d: tuple[int, ...] = (r, g, b)
|
||||
if alpha:
|
||||
d += (a,)
|
||||
data.extend(d)
|
||||
return data
|
||||
|
||||
|
||||
class BLP1Decoder(_BLPBaseDecoder):
|
||||
def _load(self):
|
||||
if self._blp_compression == Format.JPEG:
|
||||
def _load(self) -> None:
|
||||
self._compression, self._encoding, alpha = self.args
|
||||
|
||||
if self._compression == Format.JPEG:
|
||||
self._decode_jpeg_stream()
|
||||
|
||||
elif self._blp_compression == 1:
|
||||
if self._blp_encoding in (4, 5):
|
||||
elif self._compression == 1:
|
||||
if self._encoding in (4, 5):
|
||||
palette = self._read_palette()
|
||||
data = self._read_bgra(palette)
|
||||
self.set_as_raw(bytes(data))
|
||||
data = self._read_bgra(palette, alpha)
|
||||
self.set_as_raw(data)
|
||||
else:
|
||||
msg = f"Unsupported BLP encoding {repr(self._blp_encoding)}"
|
||||
msg = f"Unsupported BLP encoding {repr(self._encoding)}"
|
||||
raise BLPFormatError(msg)
|
||||
else:
|
||||
msg = f"Unsupported BLP compression {repr(self._blp_encoding)}"
|
||||
msg = f"Unsupported BLP compression {repr(self._encoding)}"
|
||||
raise BLPFormatError(msg)
|
||||
|
||||
def _decode_jpeg_stream(self):
|
||||
def _decode_jpeg_stream(self) -> None:
|
||||
from .JpegImagePlugin import JpegImageFile
|
||||
|
||||
(jpeg_header_size,) = struct.unpack("<I", self._safe_read(4))
|
||||
jpeg_header = self._safe_read(jpeg_header_size)
|
||||
self._safe_read(self._blp_offsets[0] - self.fd.tell()) # What IS this?
|
||||
data = self._safe_read(self._blp_lengths[0])
|
||||
assert self.fd is not None
|
||||
self._safe_read(self._offsets[0] - self.fd.tell()) # What IS this?
|
||||
data = self._safe_read(self._lengths[0])
|
||||
data = jpeg_header + data
|
||||
data = BytesIO(data)
|
||||
image = JpegImageFile(data)
|
||||
image = JpegImageFile(BytesIO(data))
|
||||
Image._decompression_bomb_check(image.size)
|
||||
if image.mode == "CMYK":
|
||||
decoder_name, extents, offset, args = image.tile[0]
|
||||
image.tile = [(decoder_name, extents, offset, (args[0], "CMYK"))]
|
||||
r, g, b = image.convert("RGB").split()
|
||||
image = Image.merge("RGB", (b, g, r))
|
||||
self.set_as_raw(image.tobytes())
|
||||
args = image.tile[0].args
|
||||
assert isinstance(args, tuple)
|
||||
image.tile = [image.tile[0]._replace(args=(args[0], "CMYK"))]
|
||||
self.set_as_raw(image.convert("RGB").tobytes(), "BGR")
|
||||
|
||||
|
||||
class BLP2Decoder(_BLPBaseDecoder):
|
||||
def _load(self):
|
||||
def _load(self) -> None:
|
||||
self._compression, self._encoding, alpha, self._alpha_encoding = self.args
|
||||
|
||||
palette = self._read_palette()
|
||||
|
||||
self.fd.seek(self._blp_offsets[0])
|
||||
assert self.fd is not None
|
||||
self.fd.seek(self._offsets[0])
|
||||
|
||||
if self._blp_compression == 1:
|
||||
if self._compression == 1:
|
||||
# Uncompressed or DirectX compression
|
||||
|
||||
if self._blp_encoding == Encoding.UNCOMPRESSED:
|
||||
data = self._read_bgra(palette)
|
||||
if self._encoding == Encoding.UNCOMPRESSED:
|
||||
data = self._read_bgra(palette, alpha)
|
||||
|
||||
elif self._blp_encoding == Encoding.DXT:
|
||||
elif self._encoding == Encoding.DXT:
|
||||
data = bytearray()
|
||||
if self._blp_alpha_encoding == AlphaEncoding.DXT1:
|
||||
linesize = (self.size[0] + 3) // 4 * 8
|
||||
for yb in range((self.size[1] + 3) // 4):
|
||||
for d in decode_dxt1(
|
||||
self._safe_read(linesize), alpha=bool(self._blp_alpha_depth)
|
||||
):
|
||||
if self._alpha_encoding == AlphaEncoding.DXT1:
|
||||
linesize = (self.state.xsize + 3) // 4 * 8
|
||||
for yb in range((self.state.ysize + 3) // 4):
|
||||
for d in decode_dxt1(self._safe_read(linesize), alpha):
|
||||
data += d
|
||||
|
||||
elif self._blp_alpha_encoding == AlphaEncoding.DXT3:
|
||||
linesize = (self.size[0] + 3) // 4 * 16
|
||||
for yb in range((self.size[1] + 3) // 4):
|
||||
elif self._alpha_encoding == AlphaEncoding.DXT3:
|
||||
linesize = (self.state.xsize + 3) // 4 * 16
|
||||
for yb in range((self.state.ysize + 3) // 4):
|
||||
for d in decode_dxt3(self._safe_read(linesize)):
|
||||
data += d
|
||||
|
||||
elif self._blp_alpha_encoding == AlphaEncoding.DXT5:
|
||||
linesize = (self.size[0] + 3) // 4 * 16
|
||||
for yb in range((self.size[1] + 3) // 4):
|
||||
elif self._alpha_encoding == AlphaEncoding.DXT5:
|
||||
linesize = (self.state.xsize + 3) // 4 * 16
|
||||
for yb in range((self.state.ysize + 3) // 4):
|
||||
for d in decode_dxt5(self._safe_read(linesize)):
|
||||
data += d
|
||||
else:
|
||||
msg = f"Unsupported alpha encoding {repr(self._blp_alpha_encoding)}"
|
||||
msg = f"Unsupported alpha encoding {repr(self._alpha_encoding)}"
|
||||
raise BLPFormatError(msg)
|
||||
else:
|
||||
msg = f"Unknown BLP encoding {repr(self._blp_encoding)}"
|
||||
msg = f"Unknown BLP encoding {repr(self._encoding)}"
|
||||
raise BLPFormatError(msg)
|
||||
|
||||
else:
|
||||
msg = f"Unknown BLP compression {repr(self._blp_compression)}"
|
||||
msg = f"Unknown BLP compression {repr(self._compression)}"
|
||||
raise BLPFormatError(msg)
|
||||
|
||||
self.set_as_raw(bytes(data))
|
||||
self.set_as_raw(data)
|
||||
|
||||
|
||||
class BLPEncoder(ImageFile.PyEncoder):
|
||||
_pushes_fd = True
|
||||
|
||||
def _write_palette(self):
|
||||
def _write_palette(self) -> bytes:
|
||||
data = b""
|
||||
assert self.im is not None
|
||||
palette = self.im.getpalette("RGBA", "RGBA")
|
||||
for i in range(len(palette) // 4):
|
||||
r, g, b, a = palette[i * 4 : (i + 1) * 4]
|
||||
@@ -426,12 +442,13 @@ class BLPEncoder(ImageFile.PyEncoder):
|
||||
data += b"\x00" * 4
|
||||
return data
|
||||
|
||||
def encode(self, bufsize):
|
||||
def encode(self, bufsize: int) -> tuple[int, int, bytes]:
|
||||
palette_data = self._write_palette()
|
||||
|
||||
offset = 20 + 16 * 4 * 2 + len(palette_data)
|
||||
data = struct.pack("<16I", offset, *((0,) * 15))
|
||||
|
||||
assert self.im is not None
|
||||
w, h = self.im.size
|
||||
data += struct.pack("<16I", w * h, *((0,) * 15))
|
||||
|
||||
@@ -444,7 +461,7 @@ class BLPEncoder(ImageFile.PyEncoder):
|
||||
return len(data), 0, data
|
||||
|
||||
|
||||
def _save(im, fp, filename):
|
||||
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
if im.mode != "P":
|
||||
msg = "Unsupported BLP image mode"
|
||||
raise ValueError(msg)
|
||||
@@ -452,17 +469,23 @@ def _save(im, fp, filename):
|
||||
magic = b"BLP1" if im.encoderinfo.get("blp_version") == "BLP1" else b"BLP2"
|
||||
fp.write(magic)
|
||||
|
||||
assert im.palette is not None
|
||||
fp.write(struct.pack("<i", 1)) # Uncompressed or DirectX compression
|
||||
fp.write(struct.pack("<b", Encoding.UNCOMPRESSED))
|
||||
fp.write(struct.pack("<b", 1 if im.palette.mode == "RGBA" else 0))
|
||||
fp.write(struct.pack("<b", 0)) # alpha encoding
|
||||
fp.write(struct.pack("<b", 0)) # mips
|
||||
|
||||
alpha_depth = 1 if im.palette.mode == "RGBA" else 0
|
||||
if magic == b"BLP1":
|
||||
fp.write(struct.pack("<L", alpha_depth))
|
||||
else:
|
||||
fp.write(struct.pack("<b", Encoding.UNCOMPRESSED))
|
||||
fp.write(struct.pack("<b", alpha_depth))
|
||||
fp.write(struct.pack("<b", 0)) # alpha encoding
|
||||
fp.write(struct.pack("<b", 0)) # mips
|
||||
fp.write(struct.pack("<II", *im.size))
|
||||
if magic == b"BLP1":
|
||||
fp.write(struct.pack("<i", 5))
|
||||
fp.write(struct.pack("<i", 0))
|
||||
|
||||
ImageFile._save(im, fp, [("BLP", (0, 0) + im.size, 0, im.mode)])
|
||||
ImageFile._save(im, fp, [ImageFile._Tile("BLP", (0, 0) + im.size, 0, im.mode)])
|
||||
|
||||
|
||||
Image.register_open(BlpImageFile.format, BlpImageFile, _accept)
|
||||
|
||||
@@ -22,9 +22,10 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import IO, Any
|
||||
|
||||
from . import Image, ImageFile, ImagePalette
|
||||
from ._binary import i16le as i16
|
||||
@@ -47,13 +48,15 @@ BIT2MODE = {
|
||||
32: ("RGB", "BGRX"),
|
||||
}
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:2] == b"BM"
|
||||
USE_RAW_ALPHA = False
|
||||
|
||||
|
||||
def _dib_accept(prefix):
|
||||
return i32(prefix) in [12, 40, 64, 108, 124]
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(b"BM")
|
||||
|
||||
|
||||
def _dib_accept(prefix: bytes) -> bool:
|
||||
return i32(prefix) in [12, 40, 52, 56, 64, 108, 124]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -71,31 +74,41 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
for k, v in COMPRESSIONS.items():
|
||||
vars()[k] = v
|
||||
|
||||
def _bitmap(self, header=0, offset=0):
|
||||
def _bitmap(self, header: int = 0, offset: int = 0) -> None:
|
||||
"""Read relevant info about the BMP"""
|
||||
read, seek = self.fp.read, self.fp.seek
|
||||
if header:
|
||||
seek(header)
|
||||
# read bmp header size @offset 14 (this is part of the header size)
|
||||
file_info = {"header_size": i32(read(4)), "direction": -1}
|
||||
file_info: dict[str, bool | int | tuple[int, ...]] = {
|
||||
"header_size": i32(read(4)),
|
||||
"direction": -1,
|
||||
}
|
||||
|
||||
# -------------------- If requested, read header at a specific position
|
||||
# read the rest of the bmp header, without its size
|
||||
assert isinstance(file_info["header_size"], int)
|
||||
header_data = ImageFile._safe_read(self.fp, file_info["header_size"] - 4)
|
||||
|
||||
# -------------------------------------------------- IBM OS/2 Bitmap v1
|
||||
# ------------------------------- Windows Bitmap v2, IBM OS/2 Bitmap v1
|
||||
# ----- This format has different offsets because of width/height types
|
||||
# 12: BITMAPCOREHEADER/OS21XBITMAPHEADER
|
||||
if file_info["header_size"] == 12:
|
||||
file_info["width"] = i16(header_data, 0)
|
||||
file_info["height"] = i16(header_data, 2)
|
||||
file_info["planes"] = i16(header_data, 4)
|
||||
file_info["bits"] = i16(header_data, 6)
|
||||
file_info["compression"] = self.RAW
|
||||
file_info["compression"] = self.COMPRESSIONS["RAW"]
|
||||
file_info["palette_padding"] = 3
|
||||
|
||||
# --------------------------------------------- Windows Bitmap v2 to v5
|
||||
# v3, OS/2 v2, v4, v5
|
||||
elif file_info["header_size"] in (40, 64, 108, 124):
|
||||
# --------------------------------------------- Windows Bitmap v3 to v5
|
||||
# 40: BITMAPINFOHEADER
|
||||
# 52: BITMAPV2HEADER
|
||||
# 56: BITMAPV3HEADER
|
||||
# 64: BITMAPCOREHEADER2/OS22XBITMAPHEADER
|
||||
# 108: BITMAPV4HEADER
|
||||
# 124: BITMAPV5HEADER
|
||||
elif file_info["header_size"] in (40, 52, 56, 64, 108, 124):
|
||||
file_info["y_flip"] = header_data[7] == 0xFF
|
||||
file_info["direction"] = 1 if file_info["y_flip"] else -1
|
||||
file_info["width"] = i32(header_data, 0)
|
||||
@@ -115,12 +128,16 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
)
|
||||
file_info["colors"] = i32(header_data, 28)
|
||||
file_info["palette_padding"] = 4
|
||||
assert isinstance(file_info["pixels_per_meter"], tuple)
|
||||
self.info["dpi"] = tuple(x / 39.3701 for x in file_info["pixels_per_meter"])
|
||||
if file_info["compression"] == self.BITFIELDS:
|
||||
if len(header_data) >= 52:
|
||||
for idx, mask in enumerate(
|
||||
["r_mask", "g_mask", "b_mask", "a_mask"]
|
||||
):
|
||||
if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
|
||||
masks = ["r_mask", "g_mask", "b_mask"]
|
||||
if len(header_data) >= 48:
|
||||
if len(header_data) >= 52:
|
||||
masks.append("a_mask")
|
||||
else:
|
||||
file_info["a_mask"] = 0x0
|
||||
for idx, mask in enumerate(masks):
|
||||
file_info[mask] = i32(header_data, 36 + idx * 4)
|
||||
else:
|
||||
# 40 byte headers only have the three components in the
|
||||
@@ -132,8 +149,12 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
# location, but it is listed as a reserved component,
|
||||
# and it is not generally an alpha channel
|
||||
file_info["a_mask"] = 0x0
|
||||
for mask in ["r_mask", "g_mask", "b_mask"]:
|
||||
for mask in masks:
|
||||
file_info[mask] = i32(read(4))
|
||||
assert isinstance(file_info["r_mask"], int)
|
||||
assert isinstance(file_info["g_mask"], int)
|
||||
assert isinstance(file_info["b_mask"], int)
|
||||
assert isinstance(file_info["a_mask"], int)
|
||||
file_info["rgb_mask"] = (
|
||||
file_info["r_mask"],
|
||||
file_info["g_mask"],
|
||||
@@ -151,33 +172,39 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
|
||||
# ------------------ Special case : header is reported 40, which
|
||||
# ---------------------- is shorter than real size for bpp >= 16
|
||||
assert isinstance(file_info["width"], int)
|
||||
assert isinstance(file_info["height"], int)
|
||||
self._size = file_info["width"], file_info["height"]
|
||||
|
||||
# ------- If color count was not found in the header, compute from bits
|
||||
assert isinstance(file_info["bits"], int)
|
||||
file_info["colors"] = (
|
||||
file_info["colors"]
|
||||
if file_info.get("colors", 0)
|
||||
else (1 << file_info["bits"])
|
||||
)
|
||||
assert isinstance(file_info["colors"], int)
|
||||
if offset == 14 + file_info["header_size"] and file_info["bits"] <= 8:
|
||||
offset += 4 * file_info["colors"]
|
||||
|
||||
# ---------------------- Check bit depth for unusual unsupported values
|
||||
self._mode, raw_mode = BIT2MODE.get(file_info["bits"], (None, None))
|
||||
if self.mode is None:
|
||||
self._mode, raw_mode = BIT2MODE.get(file_info["bits"], ("", ""))
|
||||
if not self.mode:
|
||||
msg = f"Unsupported BMP pixel depth ({file_info['bits']})"
|
||||
raise OSError(msg)
|
||||
|
||||
# ---------------- Process BMP with Bitfields compression (not palette)
|
||||
decoder_name = "raw"
|
||||
if file_info["compression"] == self.BITFIELDS:
|
||||
SUPPORTED = {
|
||||
if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
|
||||
SUPPORTED: dict[int, list[tuple[int, ...]]] = {
|
||||
32: [
|
||||
(0xFF0000, 0xFF00, 0xFF, 0x0),
|
||||
(0xFF000000, 0xFF0000, 0xFF00, 0x0),
|
||||
(0xFF000000, 0xFF00, 0xFF, 0x0),
|
||||
(0xFF000000, 0xFF0000, 0xFF00, 0xFF),
|
||||
(0xFF, 0xFF00, 0xFF0000, 0xFF000000),
|
||||
(0xFF0000, 0xFF00, 0xFF, 0xFF000000),
|
||||
(0xFF000000, 0xFF00, 0xFF, 0xFF0000),
|
||||
(0x0, 0x0, 0x0, 0x0),
|
||||
],
|
||||
24: [(0xFF0000, 0xFF00, 0xFF)],
|
||||
@@ -186,9 +213,11 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
MASK_MODES = {
|
||||
(32, (0xFF0000, 0xFF00, 0xFF, 0x0)): "BGRX",
|
||||
(32, (0xFF000000, 0xFF0000, 0xFF00, 0x0)): "XBGR",
|
||||
(32, (0xFF000000, 0xFF00, 0xFF, 0x0)): "BGXR",
|
||||
(32, (0xFF000000, 0xFF0000, 0xFF00, 0xFF)): "ABGR",
|
||||
(32, (0xFF, 0xFF00, 0xFF0000, 0xFF000000)): "RGBA",
|
||||
(32, (0xFF0000, 0xFF00, 0xFF, 0xFF000000)): "BGRA",
|
||||
(32, (0xFF000000, 0xFF00, 0xFF, 0xFF0000)): "BGAR",
|
||||
(32, (0x0, 0x0, 0x0, 0x0)): "BGRA",
|
||||
(24, (0xFF0000, 0xFF00, 0xFF)): "BGR",
|
||||
(16, (0xF800, 0x7E0, 0x1F)): "BGR;16",
|
||||
@@ -199,12 +228,14 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
file_info["bits"] == 32
|
||||
and file_info["rgba_mask"] in SUPPORTED[file_info["bits"]]
|
||||
):
|
||||
assert isinstance(file_info["rgba_mask"], tuple)
|
||||
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgba_mask"])]
|
||||
self._mode = "RGBA" if "A" in raw_mode else self.mode
|
||||
elif (
|
||||
file_info["bits"] in (24, 16)
|
||||
and file_info["rgb_mask"] in SUPPORTED[file_info["bits"]]
|
||||
):
|
||||
assert isinstance(file_info["rgb_mask"], tuple)
|
||||
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgb_mask"])]
|
||||
else:
|
||||
msg = "Unsupported BMP bitfields layout"
|
||||
@@ -212,10 +243,15 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
else:
|
||||
msg = "Unsupported BMP bitfields layout"
|
||||
raise OSError(msg)
|
||||
elif file_info["compression"] == self.RAW:
|
||||
if file_info["bits"] == 32 and header == 22: # 32-bit .cur offset
|
||||
elif file_info["compression"] == self.COMPRESSIONS["RAW"]:
|
||||
if file_info["bits"] == 32 and (
|
||||
header == 22 or USE_RAW_ALPHA # 32-bit .cur offset
|
||||
):
|
||||
raw_mode, self._mode = "BGRA", "RGBA"
|
||||
elif file_info["compression"] in (self.RLE8, self.RLE4):
|
||||
elif file_info["compression"] in (
|
||||
self.COMPRESSIONS["RLE8"],
|
||||
self.COMPRESSIONS["RLE4"],
|
||||
):
|
||||
decoder_name = "bmp_rle"
|
||||
else:
|
||||
msg = f"Unsupported BMP compression ({file_info['compression']})"
|
||||
@@ -228,23 +264,24 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
msg = f"Unsupported BMP Palette size ({file_info['colors']})"
|
||||
raise OSError(msg)
|
||||
else:
|
||||
assert isinstance(file_info["palette_padding"], int)
|
||||
padding = file_info["palette_padding"]
|
||||
palette = read(padding * file_info["colors"])
|
||||
greyscale = True
|
||||
grayscale = True
|
||||
indices = (
|
||||
(0, 255)
|
||||
if file_info["colors"] == 2
|
||||
else list(range(file_info["colors"]))
|
||||
)
|
||||
|
||||
# ----------------- Check if greyscale and ignore palette if so
|
||||
# ----------------- Check if grayscale and ignore palette if so
|
||||
for ind, val in enumerate(indices):
|
||||
rgb = palette[ind * padding : ind * padding + 3]
|
||||
if rgb != o8(val) * 3:
|
||||
greyscale = False
|
||||
grayscale = False
|
||||
|
||||
# ------- If all colors are grey, white or black, ditch palette
|
||||
if greyscale:
|
||||
# ------- If all colors are gray, white or black, ditch palette
|
||||
if grayscale:
|
||||
self._mode = "1" if file_info["colors"] == 2 else "L"
|
||||
raw_mode = self.mode
|
||||
else:
|
||||
@@ -255,14 +292,15 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
|
||||
# ---------------------------- Finally set the tile data for the plugin
|
||||
self.info["compression"] = file_info["compression"]
|
||||
args = [raw_mode]
|
||||
args: list[Any] = [raw_mode]
|
||||
if decoder_name == "bmp_rle":
|
||||
args.append(file_info["compression"] == self.RLE4)
|
||||
args.append(file_info["compression"] == self.COMPRESSIONS["RLE4"])
|
||||
else:
|
||||
assert isinstance(file_info["width"], int)
|
||||
args.append(((file_info["width"] * file_info["bits"] + 31) >> 3) & (~3))
|
||||
args.append(file_info["direction"])
|
||||
self.tile = [
|
||||
(
|
||||
ImageFile._Tile(
|
||||
decoder_name,
|
||||
(0, 0, file_info["width"], file_info["height"]),
|
||||
offset or self.fp.tell(),
|
||||
@@ -270,7 +308,7 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
)
|
||||
]
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
"""Open file, check magic number and read header"""
|
||||
# read 14 bytes: magic number, filesize, reserved, header final offset
|
||||
head_data = self.fp.read(14)
|
||||
@@ -287,11 +325,13 @@ class BmpImageFile(ImageFile.ImageFile):
|
||||
class BmpRleDecoder(ImageFile.PyDecoder):
|
||||
_pulls_fd = True
|
||||
|
||||
def decode(self, buffer):
|
||||
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
|
||||
assert self.fd is not None
|
||||
rle4 = self.args[1]
|
||||
data = bytearray()
|
||||
x = 0
|
||||
while len(data) < self.state.xsize * self.state.ysize:
|
||||
dest_length = self.state.xsize * self.state.ysize
|
||||
while len(data) < dest_length:
|
||||
pixels = self.fd.read(1)
|
||||
byte = self.fd.read(1)
|
||||
if not pixels or not byte:
|
||||
@@ -351,7 +391,7 @@ class BmpRleDecoder(ImageFile.PyDecoder):
|
||||
if self.fd.tell() % 2 != 0:
|
||||
self.fd.seek(1, os.SEEK_CUR)
|
||||
rawmode = "L" if self.mode == "L" else "P"
|
||||
self.set_as_raw(bytes(data), (rawmode, 0, self.args[-1]))
|
||||
self.set_as_raw(bytes(data), rawmode, (0, self.args[-1]))
|
||||
return -1, 0
|
||||
|
||||
|
||||
@@ -362,7 +402,7 @@ class DibImageFile(BmpImageFile):
|
||||
format = "DIB"
|
||||
format_description = "Windows Bitmap"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
self._bitmap()
|
||||
|
||||
|
||||
@@ -380,11 +420,13 @@ SAVE = {
|
||||
}
|
||||
|
||||
|
||||
def _dib_save(im, fp, filename):
|
||||
def _dib_save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
_save(im, fp, filename, False)
|
||||
|
||||
|
||||
def _save(im, fp, filename, bitmap_header=True):
|
||||
def _save(
|
||||
im: Image.Image, fp: IO[bytes], filename: str | bytes, bitmap_header: bool = True
|
||||
) -> None:
|
||||
try:
|
||||
rawmode, bits, colors = SAVE[im.mode]
|
||||
except KeyError as e:
|
||||
@@ -396,16 +438,16 @@ def _save(im, fp, filename, bitmap_header=True):
|
||||
dpi = info.get("dpi", (96, 96))
|
||||
|
||||
# 1 meter == 39.3701 inches
|
||||
ppm = tuple(map(lambda x: int(x * 39.3701 + 0.5), dpi))
|
||||
ppm = tuple(int(x * 39.3701 + 0.5) for x in dpi)
|
||||
|
||||
stride = ((im.size[0] * bits + 7) // 8 + 3) & (~3)
|
||||
header = 40 # or 64 for OS/2 version 2
|
||||
image = stride * im.size[1]
|
||||
|
||||
if im.mode == "1":
|
||||
palette = b"".join(o8(i) * 4 for i in (0, 255))
|
||||
palette = b"".join(o8(i) * 3 + b"\x00" for i in (0, 255))
|
||||
elif im.mode == "L":
|
||||
palette = b"".join(o8(i) * 4 for i in range(256))
|
||||
palette = b"".join(o8(i) * 3 + b"\x00" for i in range(256))
|
||||
elif im.mode == "P":
|
||||
palette = im.im.getpalette("RGB", "BGRX")
|
||||
colors = len(palette) // 4
|
||||
@@ -446,7 +488,9 @@ def _save(im, fp, filename, bitmap_header=True):
|
||||
if palette:
|
||||
fp.write(palette)
|
||||
|
||||
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))])
|
||||
ImageFile._save(
|
||||
im, fp, [ImageFile._Tile("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))]
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
|
||||
@@ -8,13 +8,17 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import IO
|
||||
|
||||
from . import Image, ImageFile
|
||||
|
||||
_handler = None
|
||||
|
||||
|
||||
def register_handler(handler):
|
||||
def register_handler(handler: ImageFile.StubHandler | None) -> None:
|
||||
"""
|
||||
Install application-specific BUFR image handler.
|
||||
|
||||
@@ -28,22 +32,20 @@ def register_handler(handler):
|
||||
# Image adapter
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == b"BUFR" or prefix[:4] == b"ZCZC"
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith((b"BUFR", b"ZCZC"))
|
||||
|
||||
|
||||
class BufrStubImageFile(ImageFile.StubImageFile):
|
||||
format = "BUFR"
|
||||
format_description = "BUFR"
|
||||
|
||||
def _open(self):
|
||||
offset = self.fp.tell()
|
||||
|
||||
def _open(self) -> None:
|
||||
if not _accept(self.fp.read(4)):
|
||||
msg = "Not a BUFR file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
self.fp.seek(offset)
|
||||
self.fp.seek(-4, os.SEEK_CUR)
|
||||
|
||||
# make something up
|
||||
self._mode = "F"
|
||||
@@ -53,11 +55,11 @@ class BufrStubImageFile(ImageFile.StubImageFile):
|
||||
if loader:
|
||||
loader.open(self)
|
||||
|
||||
def _load(self):
|
||||
def _load(self) -> ImageFile.StubHandler | None:
|
||||
return _handler
|
||||
|
||||
|
||||
def _save(im, fp, filename):
|
||||
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
if _handler is None or not hasattr(_handler, "save"):
|
||||
msg = "BUFR save handler not installed"
|
||||
raise OSError(msg)
|
||||
|
||||
@@ -13,18 +13,20 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from collections.abc import Iterable
|
||||
from typing import IO, AnyStr, NoReturn
|
||||
|
||||
|
||||
class ContainerIO:
|
||||
class ContainerIO(IO[AnyStr]):
|
||||
"""
|
||||
A file object that provides read access to a part of an existing
|
||||
file (for example a TAR file).
|
||||
"""
|
||||
|
||||
def __init__(self, file, offset, length):
|
||||
def __init__(self, file: IO[AnyStr], offset: int, length: int) -> None:
|
||||
"""
|
||||
Create file object.
|
||||
|
||||
@@ -32,7 +34,7 @@ class ContainerIO:
|
||||
:param offset: Start of region, in bytes.
|
||||
:param length: Size of region, in bytes.
|
||||
"""
|
||||
self.fh = file
|
||||
self.fh: IO[AnyStr] = file
|
||||
self.pos = 0
|
||||
self.offset = offset
|
||||
self.length = length
|
||||
@@ -41,10 +43,13 @@ class ContainerIO:
|
||||
##
|
||||
# Always false.
|
||||
|
||||
def isatty(self):
|
||||
def isatty(self) -> bool:
|
||||
return False
|
||||
|
||||
def seek(self, offset, mode=io.SEEK_SET):
|
||||
def seekable(self) -> bool:
|
||||
return True
|
||||
|
||||
def seek(self, offset: int, mode: int = io.SEEK_SET) -> int:
|
||||
"""
|
||||
Move file pointer.
|
||||
|
||||
@@ -52,6 +57,7 @@ class ContainerIO:
|
||||
:param mode: Starting position. Use 0 for beginning of region, 1
|
||||
for current offset, and 2 for end of region. You cannot move
|
||||
the pointer outside the defined region.
|
||||
:returns: Offset from start of region, in bytes.
|
||||
"""
|
||||
if mode == 1:
|
||||
self.pos = self.pos + offset
|
||||
@@ -62,8 +68,9 @@ class ContainerIO:
|
||||
# clamp
|
||||
self.pos = max(0, min(self.pos, self.length))
|
||||
self.fh.seek(self.offset + self.pos)
|
||||
return self.pos
|
||||
|
||||
def tell(self):
|
||||
def tell(self) -> int:
|
||||
"""
|
||||
Get current file pointer.
|
||||
|
||||
@@ -71,44 +78,51 @@ class ContainerIO:
|
||||
"""
|
||||
return self.pos
|
||||
|
||||
def read(self, n=0):
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def read(self, n: int = -1) -> AnyStr:
|
||||
"""
|
||||
Read data.
|
||||
|
||||
:param n: Number of bytes to read. If omitted or zero,
|
||||
:param n: Number of bytes to read. If omitted, zero or negative,
|
||||
read until end of region.
|
||||
:returns: An 8-bit string.
|
||||
"""
|
||||
if n:
|
||||
if n > 0:
|
||||
n = min(n, self.length - self.pos)
|
||||
else:
|
||||
n = self.length - self.pos
|
||||
if not n: # EOF
|
||||
return b"" if "b" in self.fh.mode else ""
|
||||
if n <= 0: # EOF
|
||||
return b"" if "b" in self.fh.mode else "" # type: ignore[return-value]
|
||||
self.pos = self.pos + n
|
||||
return self.fh.read(n)
|
||||
|
||||
def readline(self):
|
||||
def readline(self, n: int = -1) -> AnyStr:
|
||||
"""
|
||||
Read a line of text.
|
||||
|
||||
:param n: Number of bytes to read. If omitted, zero or negative,
|
||||
read until end of line.
|
||||
:returns: An 8-bit string.
|
||||
"""
|
||||
s = b"" if "b" in self.fh.mode else ""
|
||||
s: AnyStr = b"" if "b" in self.fh.mode else "" # type: ignore[assignment]
|
||||
newline_character = b"\n" if "b" in self.fh.mode else "\n"
|
||||
while True:
|
||||
c = self.read(1)
|
||||
if not c:
|
||||
break
|
||||
s = s + c
|
||||
if c == newline_character:
|
||||
if c == newline_character or len(s) == n:
|
||||
break
|
||||
return s
|
||||
|
||||
def readlines(self):
|
||||
def readlines(self, n: int | None = -1) -> list[AnyStr]:
|
||||
"""
|
||||
Read multiple lines of text.
|
||||
|
||||
:param n: Number of lines to read. If omitted, zero, negative or None,
|
||||
read until end of region.
|
||||
:returns: A list of 8-bit strings.
|
||||
"""
|
||||
lines = []
|
||||
@@ -117,4 +131,43 @@ class ContainerIO:
|
||||
if not s:
|
||||
break
|
||||
lines.append(s)
|
||||
if len(lines) == n:
|
||||
break
|
||||
return lines
|
||||
|
||||
def writable(self) -> bool:
|
||||
return False
|
||||
|
||||
def write(self, b: AnyStr) -> NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def writelines(self, lines: Iterable[AnyStr]) -> NoReturn:
|
||||
raise NotImplementedError()
|
||||
|
||||
def truncate(self, size: int | None = None) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __enter__(self) -> ContainerIO[AnyStr]:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.close()
|
||||
|
||||
def __iter__(self) -> ContainerIO[AnyStr]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> AnyStr:
|
||||
line = self.readline()
|
||||
if not line:
|
||||
msg = "end of region"
|
||||
raise StopIteration(msg)
|
||||
return line
|
||||
|
||||
def fileno(self) -> int:
|
||||
return self.fh.fileno()
|
||||
|
||||
def flush(self) -> None:
|
||||
self.fh.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
self.fh.close()
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
from . import BmpImagePlugin, Image
|
||||
from ._binary import i16le as i16
|
||||
from ._binary import i32le as i32
|
||||
@@ -23,8 +25,8 @@ from ._binary import i32le as i32
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == b"\0\0\2\0"
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(b"\0\0\2\0")
|
||||
|
||||
|
||||
##
|
||||
@@ -35,7 +37,8 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
|
||||
format = "CUR"
|
||||
format_description = "Windows Cursor"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
assert self.fp is not None
|
||||
offset = self.fp.tell()
|
||||
|
||||
# check magic
|
||||
@@ -61,10 +64,7 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
|
||||
|
||||
# patch up the bitmap height
|
||||
self._size = self.size[0], self.size[1] // 2
|
||||
d, e, o, a = self.tile[0]
|
||||
self.tile[0] = d, (0, 0) + self.size, o, a
|
||||
|
||||
return
|
||||
self.tile = [self.tile[0]._replace(extents=(0, 0) + self.size)]
|
||||
|
||||
|
||||
#
|
||||
|
||||
@@ -20,15 +20,17 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
from . import Image
|
||||
from ._binary import i32le as i32
|
||||
from ._util import DeferredError
|
||||
from .PcxImagePlugin import PcxImageFile
|
||||
|
||||
MAGIC = 0x3ADE68B1 # QUIZ: what's this value, then?
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return len(prefix) >= 4 and i32(prefix) == MAGIC
|
||||
|
||||
|
||||
@@ -41,7 +43,7 @@ class DcxImageFile(PcxImageFile):
|
||||
format_description = "Intel DCX"
|
||||
_close_exclusive_fp_after_loading = False
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
# Header
|
||||
s = self.fp.read(4)
|
||||
if not _accept(s):
|
||||
@@ -57,20 +59,22 @@ class DcxImageFile(PcxImageFile):
|
||||
self._offset.append(offset)
|
||||
|
||||
self._fp = self.fp
|
||||
self.frame = None
|
||||
self.frame = -1
|
||||
self.n_frames = len(self._offset)
|
||||
self.is_animated = self.n_frames > 1
|
||||
self.seek(0)
|
||||
|
||||
def seek(self, frame):
|
||||
def seek(self, frame: int) -> None:
|
||||
if not self._seek_check(frame):
|
||||
return
|
||||
if isinstance(self._fp, DeferredError):
|
||||
raise self._fp.ex
|
||||
self.frame = frame
|
||||
self.fp = self._fp
|
||||
self.fp.seek(self._offset[frame])
|
||||
PcxImageFile._open(self)
|
||||
|
||||
def tell(self):
|
||||
def tell(self) -> int:
|
||||
return self.frame
|
||||
|
||||
|
||||
|
||||
@@ -1,118 +1,338 @@
|
||||
"""
|
||||
A Pillow loader for .dds files (S3TC-compressed aka DXTC)
|
||||
A Pillow plugin for .dds files (S3TC-compressed aka DXTC)
|
||||
Jerome Leclanche <jerome@leclan.ch>
|
||||
|
||||
Documentation:
|
||||
https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt
|
||||
https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt
|
||||
|
||||
The contents of this file are hereby released in the public domain (CC0)
|
||||
Full text of the CC0 license:
|
||||
https://creativecommons.org/publicdomain/zero/1.0/
|
||||
https://creativecommons.org/publicdomain/zero/1.0/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
from io import BytesIO
|
||||
import sys
|
||||
from enum import IntEnum, IntFlag
|
||||
from typing import IO
|
||||
|
||||
from . import Image, ImageFile, ImagePalette
|
||||
from ._binary import i32le as i32
|
||||
from ._binary import o8
|
||||
from ._binary import o32le as o32
|
||||
|
||||
# Magic ("DDS ")
|
||||
DDS_MAGIC = 0x20534444
|
||||
|
||||
|
||||
# DDS flags
|
||||
DDSD_CAPS = 0x1
|
||||
DDSD_HEIGHT = 0x2
|
||||
DDSD_WIDTH = 0x4
|
||||
DDSD_PITCH = 0x8
|
||||
DDSD_PIXELFORMAT = 0x1000
|
||||
DDSD_MIPMAPCOUNT = 0x20000
|
||||
DDSD_LINEARSIZE = 0x80000
|
||||
DDSD_DEPTH = 0x800000
|
||||
class DDSD(IntFlag):
|
||||
CAPS = 0x1
|
||||
HEIGHT = 0x2
|
||||
WIDTH = 0x4
|
||||
PITCH = 0x8
|
||||
PIXELFORMAT = 0x1000
|
||||
MIPMAPCOUNT = 0x20000
|
||||
LINEARSIZE = 0x80000
|
||||
DEPTH = 0x800000
|
||||
|
||||
|
||||
# DDS caps
|
||||
DDSCAPS_COMPLEX = 0x8
|
||||
DDSCAPS_TEXTURE = 0x1000
|
||||
DDSCAPS_MIPMAP = 0x400000
|
||||
class DDSCAPS(IntFlag):
|
||||
COMPLEX = 0x8
|
||||
TEXTURE = 0x1000
|
||||
MIPMAP = 0x400000
|
||||
|
||||
|
||||
class DDSCAPS2(IntFlag):
|
||||
CUBEMAP = 0x200
|
||||
CUBEMAP_POSITIVEX = 0x400
|
||||
CUBEMAP_NEGATIVEX = 0x800
|
||||
CUBEMAP_POSITIVEY = 0x1000
|
||||
CUBEMAP_NEGATIVEY = 0x2000
|
||||
CUBEMAP_POSITIVEZ = 0x4000
|
||||
CUBEMAP_NEGATIVEZ = 0x8000
|
||||
VOLUME = 0x200000
|
||||
|
||||
DDSCAPS2_CUBEMAP = 0x200
|
||||
DDSCAPS2_CUBEMAP_POSITIVEX = 0x400
|
||||
DDSCAPS2_CUBEMAP_NEGATIVEX = 0x800
|
||||
DDSCAPS2_CUBEMAP_POSITIVEY = 0x1000
|
||||
DDSCAPS2_CUBEMAP_NEGATIVEY = 0x2000
|
||||
DDSCAPS2_CUBEMAP_POSITIVEZ = 0x4000
|
||||
DDSCAPS2_CUBEMAP_NEGATIVEZ = 0x8000
|
||||
DDSCAPS2_VOLUME = 0x200000
|
||||
|
||||
# Pixel Format
|
||||
DDPF_ALPHAPIXELS = 0x1
|
||||
DDPF_ALPHA = 0x2
|
||||
DDPF_FOURCC = 0x4
|
||||
DDPF_PALETTEINDEXED8 = 0x20
|
||||
DDPF_RGB = 0x40
|
||||
DDPF_LUMINANCE = 0x20000
|
||||
|
||||
|
||||
# dds.h
|
||||
|
||||
DDS_FOURCC = DDPF_FOURCC
|
||||
DDS_RGB = DDPF_RGB
|
||||
DDS_RGBA = DDPF_RGB | DDPF_ALPHAPIXELS
|
||||
DDS_LUMINANCE = DDPF_LUMINANCE
|
||||
DDS_LUMINANCEA = DDPF_LUMINANCE | DDPF_ALPHAPIXELS
|
||||
DDS_ALPHA = DDPF_ALPHA
|
||||
DDS_PAL8 = DDPF_PALETTEINDEXED8
|
||||
|
||||
DDS_HEADER_FLAGS_TEXTURE = DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PIXELFORMAT
|
||||
DDS_HEADER_FLAGS_MIPMAP = DDSD_MIPMAPCOUNT
|
||||
DDS_HEADER_FLAGS_VOLUME = DDSD_DEPTH
|
||||
DDS_HEADER_FLAGS_PITCH = DDSD_PITCH
|
||||
DDS_HEADER_FLAGS_LINEARSIZE = DDSD_LINEARSIZE
|
||||
|
||||
DDS_HEIGHT = DDSD_HEIGHT
|
||||
DDS_WIDTH = DDSD_WIDTH
|
||||
|
||||
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS_TEXTURE
|
||||
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS_COMPLEX | DDSCAPS_MIPMAP
|
||||
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS_COMPLEX
|
||||
|
||||
DDS_CUBEMAP_POSITIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEX
|
||||
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEX
|
||||
DDS_CUBEMAP_POSITIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEY
|
||||
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEY
|
||||
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEZ
|
||||
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEZ
|
||||
|
||||
|
||||
# DXT1
|
||||
DXT1_FOURCC = 0x31545844
|
||||
|
||||
# DXT3
|
||||
DXT3_FOURCC = 0x33545844
|
||||
|
||||
# DXT5
|
||||
DXT5_FOURCC = 0x35545844
|
||||
class DDPF(IntFlag):
|
||||
ALPHAPIXELS = 0x1
|
||||
ALPHA = 0x2
|
||||
FOURCC = 0x4
|
||||
PALETTEINDEXED8 = 0x20
|
||||
RGB = 0x40
|
||||
LUMINANCE = 0x20000
|
||||
|
||||
|
||||
# dxgiformat.h
|
||||
class DXGI_FORMAT(IntEnum):
|
||||
UNKNOWN = 0
|
||||
R32G32B32A32_TYPELESS = 1
|
||||
R32G32B32A32_FLOAT = 2
|
||||
R32G32B32A32_UINT = 3
|
||||
R32G32B32A32_SINT = 4
|
||||
R32G32B32_TYPELESS = 5
|
||||
R32G32B32_FLOAT = 6
|
||||
R32G32B32_UINT = 7
|
||||
R32G32B32_SINT = 8
|
||||
R16G16B16A16_TYPELESS = 9
|
||||
R16G16B16A16_FLOAT = 10
|
||||
R16G16B16A16_UNORM = 11
|
||||
R16G16B16A16_UINT = 12
|
||||
R16G16B16A16_SNORM = 13
|
||||
R16G16B16A16_SINT = 14
|
||||
R32G32_TYPELESS = 15
|
||||
R32G32_FLOAT = 16
|
||||
R32G32_UINT = 17
|
||||
R32G32_SINT = 18
|
||||
R32G8X24_TYPELESS = 19
|
||||
D32_FLOAT_S8X24_UINT = 20
|
||||
R32_FLOAT_X8X24_TYPELESS = 21
|
||||
X32_TYPELESS_G8X24_UINT = 22
|
||||
R10G10B10A2_TYPELESS = 23
|
||||
R10G10B10A2_UNORM = 24
|
||||
R10G10B10A2_UINT = 25
|
||||
R11G11B10_FLOAT = 26
|
||||
R8G8B8A8_TYPELESS = 27
|
||||
R8G8B8A8_UNORM = 28
|
||||
R8G8B8A8_UNORM_SRGB = 29
|
||||
R8G8B8A8_UINT = 30
|
||||
R8G8B8A8_SNORM = 31
|
||||
R8G8B8A8_SINT = 32
|
||||
R16G16_TYPELESS = 33
|
||||
R16G16_FLOAT = 34
|
||||
R16G16_UNORM = 35
|
||||
R16G16_UINT = 36
|
||||
R16G16_SNORM = 37
|
||||
R16G16_SINT = 38
|
||||
R32_TYPELESS = 39
|
||||
D32_FLOAT = 40
|
||||
R32_FLOAT = 41
|
||||
R32_UINT = 42
|
||||
R32_SINT = 43
|
||||
R24G8_TYPELESS = 44
|
||||
D24_UNORM_S8_UINT = 45
|
||||
R24_UNORM_X8_TYPELESS = 46
|
||||
X24_TYPELESS_G8_UINT = 47
|
||||
R8G8_TYPELESS = 48
|
||||
R8G8_UNORM = 49
|
||||
R8G8_UINT = 50
|
||||
R8G8_SNORM = 51
|
||||
R8G8_SINT = 52
|
||||
R16_TYPELESS = 53
|
||||
R16_FLOAT = 54
|
||||
D16_UNORM = 55
|
||||
R16_UNORM = 56
|
||||
R16_UINT = 57
|
||||
R16_SNORM = 58
|
||||
R16_SINT = 59
|
||||
R8_TYPELESS = 60
|
||||
R8_UNORM = 61
|
||||
R8_UINT = 62
|
||||
R8_SNORM = 63
|
||||
R8_SINT = 64
|
||||
A8_UNORM = 65
|
||||
R1_UNORM = 66
|
||||
R9G9B9E5_SHAREDEXP = 67
|
||||
R8G8_B8G8_UNORM = 68
|
||||
G8R8_G8B8_UNORM = 69
|
||||
BC1_TYPELESS = 70
|
||||
BC1_UNORM = 71
|
||||
BC1_UNORM_SRGB = 72
|
||||
BC2_TYPELESS = 73
|
||||
BC2_UNORM = 74
|
||||
BC2_UNORM_SRGB = 75
|
||||
BC3_TYPELESS = 76
|
||||
BC3_UNORM = 77
|
||||
BC3_UNORM_SRGB = 78
|
||||
BC4_TYPELESS = 79
|
||||
BC4_UNORM = 80
|
||||
BC4_SNORM = 81
|
||||
BC5_TYPELESS = 82
|
||||
BC5_UNORM = 83
|
||||
BC5_SNORM = 84
|
||||
B5G6R5_UNORM = 85
|
||||
B5G5R5A1_UNORM = 86
|
||||
B8G8R8A8_UNORM = 87
|
||||
B8G8R8X8_UNORM = 88
|
||||
R10G10B10_XR_BIAS_A2_UNORM = 89
|
||||
B8G8R8A8_TYPELESS = 90
|
||||
B8G8R8A8_UNORM_SRGB = 91
|
||||
B8G8R8X8_TYPELESS = 92
|
||||
B8G8R8X8_UNORM_SRGB = 93
|
||||
BC6H_TYPELESS = 94
|
||||
BC6H_UF16 = 95
|
||||
BC6H_SF16 = 96
|
||||
BC7_TYPELESS = 97
|
||||
BC7_UNORM = 98
|
||||
BC7_UNORM_SRGB = 99
|
||||
AYUV = 100
|
||||
Y410 = 101
|
||||
Y416 = 102
|
||||
NV12 = 103
|
||||
P010 = 104
|
||||
P016 = 105
|
||||
OPAQUE_420 = 106
|
||||
YUY2 = 107
|
||||
Y210 = 108
|
||||
Y216 = 109
|
||||
NV11 = 110
|
||||
AI44 = 111
|
||||
IA44 = 112
|
||||
P8 = 113
|
||||
A8P8 = 114
|
||||
B4G4R4A4_UNORM = 115
|
||||
P208 = 130
|
||||
V208 = 131
|
||||
V408 = 132
|
||||
SAMPLER_FEEDBACK_MIN_MIP_OPAQUE = 189
|
||||
SAMPLER_FEEDBACK_MIP_REGION_USED_OPAQUE = 190
|
||||
|
||||
DXGI_FORMAT_R8G8B8A8_TYPELESS = 27
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM = 28
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = 29
|
||||
DXGI_FORMAT_BC5_TYPELESS = 82
|
||||
DXGI_FORMAT_BC5_UNORM = 83
|
||||
DXGI_FORMAT_BC5_SNORM = 84
|
||||
DXGI_FORMAT_BC6H_UF16 = 95
|
||||
DXGI_FORMAT_BC6H_SF16 = 96
|
||||
DXGI_FORMAT_BC7_TYPELESS = 97
|
||||
DXGI_FORMAT_BC7_UNORM = 98
|
||||
DXGI_FORMAT_BC7_UNORM_SRGB = 99
|
||||
|
||||
class D3DFMT(IntEnum):
|
||||
UNKNOWN = 0
|
||||
R8G8B8 = 20
|
||||
A8R8G8B8 = 21
|
||||
X8R8G8B8 = 22
|
||||
R5G6B5 = 23
|
||||
X1R5G5B5 = 24
|
||||
A1R5G5B5 = 25
|
||||
A4R4G4B4 = 26
|
||||
R3G3B2 = 27
|
||||
A8 = 28
|
||||
A8R3G3B2 = 29
|
||||
X4R4G4B4 = 30
|
||||
A2B10G10R10 = 31
|
||||
A8B8G8R8 = 32
|
||||
X8B8G8R8 = 33
|
||||
G16R16 = 34
|
||||
A2R10G10B10 = 35
|
||||
A16B16G16R16 = 36
|
||||
A8P8 = 40
|
||||
P8 = 41
|
||||
L8 = 50
|
||||
A8L8 = 51
|
||||
A4L4 = 52
|
||||
V8U8 = 60
|
||||
L6V5U5 = 61
|
||||
X8L8V8U8 = 62
|
||||
Q8W8V8U8 = 63
|
||||
V16U16 = 64
|
||||
A2W10V10U10 = 67
|
||||
D16_LOCKABLE = 70
|
||||
D32 = 71
|
||||
D15S1 = 73
|
||||
D24S8 = 75
|
||||
D24X8 = 77
|
||||
D24X4S4 = 79
|
||||
D16 = 80
|
||||
D32F_LOCKABLE = 82
|
||||
D24FS8 = 83
|
||||
D32_LOCKABLE = 84
|
||||
S8_LOCKABLE = 85
|
||||
L16 = 81
|
||||
VERTEXDATA = 100
|
||||
INDEX16 = 101
|
||||
INDEX32 = 102
|
||||
Q16W16V16U16 = 110
|
||||
R16F = 111
|
||||
G16R16F = 112
|
||||
A16B16G16R16F = 113
|
||||
R32F = 114
|
||||
G32R32F = 115
|
||||
A32B32G32R32F = 116
|
||||
CxV8U8 = 117
|
||||
A1 = 118
|
||||
A2B10G10R10_XR_BIAS = 119
|
||||
BINARYBUFFER = 199
|
||||
|
||||
UYVY = i32(b"UYVY")
|
||||
R8G8_B8G8 = i32(b"RGBG")
|
||||
YUY2 = i32(b"YUY2")
|
||||
G8R8_G8B8 = i32(b"GRGB")
|
||||
DXT1 = i32(b"DXT1")
|
||||
DXT2 = i32(b"DXT2")
|
||||
DXT3 = i32(b"DXT3")
|
||||
DXT4 = i32(b"DXT4")
|
||||
DXT5 = i32(b"DXT5")
|
||||
DX10 = i32(b"DX10")
|
||||
BC4S = i32(b"BC4S")
|
||||
BC4U = i32(b"BC4U")
|
||||
BC5S = i32(b"BC5S")
|
||||
BC5U = i32(b"BC5U")
|
||||
ATI1 = i32(b"ATI1")
|
||||
ATI2 = i32(b"ATI2")
|
||||
MULTI2_ARGB8 = i32(b"MET1")
|
||||
|
||||
|
||||
# Backward compatibility layer
|
||||
module = sys.modules[__name__]
|
||||
for item in DDSD:
|
||||
assert item.name is not None
|
||||
setattr(module, f"DDSD_{item.name}", item.value)
|
||||
for item1 in DDSCAPS:
|
||||
assert item1.name is not None
|
||||
setattr(module, f"DDSCAPS_{item1.name}", item1.value)
|
||||
for item2 in DDSCAPS2:
|
||||
assert item2.name is not None
|
||||
setattr(module, f"DDSCAPS2_{item2.name}", item2.value)
|
||||
for item3 in DDPF:
|
||||
assert item3.name is not None
|
||||
setattr(module, f"DDPF_{item3.name}", item3.value)
|
||||
|
||||
DDS_FOURCC = DDPF.FOURCC
|
||||
DDS_RGB = DDPF.RGB
|
||||
DDS_RGBA = DDPF.RGB | DDPF.ALPHAPIXELS
|
||||
DDS_LUMINANCE = DDPF.LUMINANCE
|
||||
DDS_LUMINANCEA = DDPF.LUMINANCE | DDPF.ALPHAPIXELS
|
||||
DDS_ALPHA = DDPF.ALPHA
|
||||
DDS_PAL8 = DDPF.PALETTEINDEXED8
|
||||
|
||||
DDS_HEADER_FLAGS_TEXTURE = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
|
||||
DDS_HEADER_FLAGS_MIPMAP = DDSD.MIPMAPCOUNT
|
||||
DDS_HEADER_FLAGS_VOLUME = DDSD.DEPTH
|
||||
DDS_HEADER_FLAGS_PITCH = DDSD.PITCH
|
||||
DDS_HEADER_FLAGS_LINEARSIZE = DDSD.LINEARSIZE
|
||||
|
||||
DDS_HEIGHT = DDSD.HEIGHT
|
||||
DDS_WIDTH = DDSD.WIDTH
|
||||
|
||||
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS.TEXTURE
|
||||
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS.COMPLEX | DDSCAPS.MIPMAP
|
||||
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS.COMPLEX
|
||||
|
||||
DDS_CUBEMAP_POSITIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEX
|
||||
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEX
|
||||
DDS_CUBEMAP_POSITIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEY
|
||||
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEY
|
||||
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEZ
|
||||
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEZ
|
||||
|
||||
DXT1_FOURCC = D3DFMT.DXT1
|
||||
DXT3_FOURCC = D3DFMT.DXT3
|
||||
DXT5_FOURCC = D3DFMT.DXT5
|
||||
|
||||
DXGI_FORMAT_R8G8B8A8_TYPELESS = DXGI_FORMAT.R8G8B8A8_TYPELESS
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM = DXGI_FORMAT.R8G8B8A8_UNORM
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = DXGI_FORMAT.R8G8B8A8_UNORM_SRGB
|
||||
DXGI_FORMAT_BC5_TYPELESS = DXGI_FORMAT.BC5_TYPELESS
|
||||
DXGI_FORMAT_BC5_UNORM = DXGI_FORMAT.BC5_UNORM
|
||||
DXGI_FORMAT_BC5_SNORM = DXGI_FORMAT.BC5_SNORM
|
||||
DXGI_FORMAT_BC6H_UF16 = DXGI_FORMAT.BC6H_UF16
|
||||
DXGI_FORMAT_BC6H_SF16 = DXGI_FORMAT.BC6H_SF16
|
||||
DXGI_FORMAT_BC7_TYPELESS = DXGI_FORMAT.BC7_TYPELESS
|
||||
DXGI_FORMAT_BC7_UNORM = DXGI_FORMAT.BC7_UNORM
|
||||
DXGI_FORMAT_BC7_UNORM_SRGB = DXGI_FORMAT.BC7_UNORM_SRGB
|
||||
|
||||
|
||||
class DdsImageFile(ImageFile.ImageFile):
|
||||
format = "DDS"
|
||||
format_description = "DirectDraw Surface"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
if not _accept(self.fp.read(4)):
|
||||
msg = "not a DDS file"
|
||||
raise SyntaxError(msg)
|
||||
@@ -124,172 +344,281 @@ class DdsImageFile(ImageFile.ImageFile):
|
||||
if len(header_bytes) != 120:
|
||||
msg = f"Incomplete header: {len(header_bytes)} bytes"
|
||||
raise OSError(msg)
|
||||
header = BytesIO(header_bytes)
|
||||
header = io.BytesIO(header_bytes)
|
||||
|
||||
flags, height, width = struct.unpack("<3I", header.read(12))
|
||||
self._size = (width, height)
|
||||
self._mode = "RGBA"
|
||||
extents = (0, 0) + self.size
|
||||
|
||||
pitch, depth, mipmaps = struct.unpack("<3I", header.read(12))
|
||||
struct.unpack("<11I", header.read(44)) # reserved
|
||||
|
||||
# pixel format
|
||||
pfsize, pfflags = struct.unpack("<2I", header.read(8))
|
||||
fourcc = header.read(4)
|
||||
(bitcount,) = struct.unpack("<I", header.read(4))
|
||||
masks = struct.unpack("<4I", header.read(16))
|
||||
if pfflags & DDPF_LUMINANCE:
|
||||
# Texture contains uncompressed L or LA data
|
||||
if pfflags & DDPF_ALPHAPIXELS:
|
||||
pfsize, pfflags, fourcc, bitcount = struct.unpack("<4I", header.read(16))
|
||||
n = 0
|
||||
rawmode = None
|
||||
if pfflags & DDPF.RGB:
|
||||
# Texture contains uncompressed RGB data
|
||||
if pfflags & DDPF.ALPHAPIXELS:
|
||||
self._mode = "RGBA"
|
||||
mask_count = 4
|
||||
else:
|
||||
self._mode = "RGB"
|
||||
mask_count = 3
|
||||
|
||||
masks = struct.unpack(f"<{mask_count}I", header.read(mask_count * 4))
|
||||
self.tile = [ImageFile._Tile("dds_rgb", extents, 0, (bitcount, masks))]
|
||||
return
|
||||
elif pfflags & DDPF.LUMINANCE:
|
||||
if bitcount == 8:
|
||||
self._mode = "L"
|
||||
elif bitcount == 16 and pfflags & DDPF.ALPHAPIXELS:
|
||||
self._mode = "LA"
|
||||
else:
|
||||
self._mode = "L"
|
||||
|
||||
self.tile = [("raw", (0, 0) + self.size, 0, (self.mode, 0, 1))]
|
||||
elif pfflags & DDPF_RGB:
|
||||
# Texture contains uncompressed RGB data
|
||||
masks = {mask: ["R", "G", "B", "A"][i] for i, mask in enumerate(masks)}
|
||||
rawmode = ""
|
||||
if pfflags & DDPF_ALPHAPIXELS:
|
||||
rawmode += masks[0xFF000000]
|
||||
else:
|
||||
self._mode = "RGB"
|
||||
rawmode += masks[0xFF0000] + masks[0xFF00] + masks[0xFF]
|
||||
|
||||
self.tile = [("raw", (0, 0) + self.size, 0, (rawmode[::-1], 0, 1))]
|
||||
elif pfflags & DDPF_PALETTEINDEXED8:
|
||||
msg = f"Unsupported bitcount {bitcount} for {pfflags}"
|
||||
raise OSError(msg)
|
||||
elif pfflags & DDPF.PALETTEINDEXED8:
|
||||
self._mode = "P"
|
||||
self.palette = ImagePalette.raw("RGBA", self.fp.read(1024))
|
||||
self.tile = [("raw", (0, 0) + self.size, 0, "L")]
|
||||
else:
|
||||
data_start = header_size + 4
|
||||
n = 0
|
||||
if fourcc == b"DXT1":
|
||||
self.palette.mode = "RGBA"
|
||||
elif pfflags & DDPF.FOURCC:
|
||||
offset = header_size + 4
|
||||
if fourcc == D3DFMT.DXT1:
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "DXT1"
|
||||
n = 1
|
||||
elif fourcc == b"DXT3":
|
||||
elif fourcc == D3DFMT.DXT3:
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "DXT3"
|
||||
n = 2
|
||||
elif fourcc == b"DXT5":
|
||||
elif fourcc == D3DFMT.DXT5:
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "DXT5"
|
||||
n = 3
|
||||
elif fourcc == b"ATI1":
|
||||
elif fourcc in (D3DFMT.BC4U, D3DFMT.ATI1):
|
||||
self._mode = "L"
|
||||
self.pixel_format = "BC4"
|
||||
n = 4
|
||||
self._mode = "L"
|
||||
elif fourcc in (b"ATI2", b"BC5U"):
|
||||
self.pixel_format = "BC5"
|
||||
n = 5
|
||||
elif fourcc == D3DFMT.BC5S:
|
||||
self._mode = "RGB"
|
||||
elif fourcc == b"BC5S":
|
||||
self.pixel_format = "BC5S"
|
||||
n = 5
|
||||
elif fourcc in (D3DFMT.BC5U, D3DFMT.ATI2):
|
||||
self._mode = "RGB"
|
||||
elif fourcc == b"DX10":
|
||||
data_start += 20
|
||||
self.pixel_format = "BC5"
|
||||
n = 5
|
||||
elif fourcc == D3DFMT.DX10:
|
||||
offset += 20
|
||||
# ignoring flags which pertain to volume textures and cubemaps
|
||||
(dxgi_format,) = struct.unpack("<I", self.fp.read(4))
|
||||
self.fp.read(16)
|
||||
if dxgi_format in (DXGI_FORMAT_BC5_TYPELESS, DXGI_FORMAT_BC5_UNORM):
|
||||
if dxgi_format in (
|
||||
DXGI_FORMAT.BC1_UNORM,
|
||||
DXGI_FORMAT.BC1_TYPELESS,
|
||||
):
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "BC1"
|
||||
n = 1
|
||||
elif dxgi_format in (DXGI_FORMAT.BC2_TYPELESS, DXGI_FORMAT.BC2_UNORM):
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "BC2"
|
||||
n = 2
|
||||
elif dxgi_format in (DXGI_FORMAT.BC3_TYPELESS, DXGI_FORMAT.BC3_UNORM):
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "BC3"
|
||||
n = 3
|
||||
elif dxgi_format in (DXGI_FORMAT.BC4_TYPELESS, DXGI_FORMAT.BC4_UNORM):
|
||||
self._mode = "L"
|
||||
self.pixel_format = "BC4"
|
||||
n = 4
|
||||
elif dxgi_format in (DXGI_FORMAT.BC5_TYPELESS, DXGI_FORMAT.BC5_UNORM):
|
||||
self._mode = "RGB"
|
||||
self.pixel_format = "BC5"
|
||||
n = 5
|
||||
elif dxgi_format == DXGI_FORMAT.BC5_SNORM:
|
||||
self._mode = "RGB"
|
||||
elif dxgi_format == DXGI_FORMAT_BC5_SNORM:
|
||||
self.pixel_format = "BC5S"
|
||||
n = 5
|
||||
elif dxgi_format == DXGI_FORMAT.BC6H_UF16:
|
||||
self._mode = "RGB"
|
||||
elif dxgi_format == DXGI_FORMAT_BC6H_UF16:
|
||||
self.pixel_format = "BC6H"
|
||||
n = 6
|
||||
elif dxgi_format == DXGI_FORMAT.BC6H_SF16:
|
||||
self._mode = "RGB"
|
||||
elif dxgi_format == DXGI_FORMAT_BC6H_SF16:
|
||||
self.pixel_format = "BC6HS"
|
||||
n = 6
|
||||
self._mode = "RGB"
|
||||
elif dxgi_format in (DXGI_FORMAT_BC7_TYPELESS, DXGI_FORMAT_BC7_UNORM):
|
||||
self.pixel_format = "BC7"
|
||||
n = 7
|
||||
elif dxgi_format == DXGI_FORMAT_BC7_UNORM_SRGB:
|
||||
self.pixel_format = "BC7"
|
||||
self.info["gamma"] = 1 / 2.2
|
||||
n = 7
|
||||
elif dxgi_format in (
|
||||
DXGI_FORMAT_R8G8B8A8_TYPELESS,
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM,
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB,
|
||||
DXGI_FORMAT.BC7_TYPELESS,
|
||||
DXGI_FORMAT.BC7_UNORM,
|
||||
DXGI_FORMAT.BC7_UNORM_SRGB,
|
||||
):
|
||||
self.tile = [("raw", (0, 0) + self.size, 0, ("RGBA", 0, 1))]
|
||||
if dxgi_format == DXGI_FORMAT_R8G8B8A8_UNORM_SRGB:
|
||||
self._mode = "RGBA"
|
||||
self.pixel_format = "BC7"
|
||||
n = 7
|
||||
if dxgi_format == DXGI_FORMAT.BC7_UNORM_SRGB:
|
||||
self.info["gamma"] = 1 / 2.2
|
||||
elif dxgi_format in (
|
||||
DXGI_FORMAT.R8G8B8A8_TYPELESS,
|
||||
DXGI_FORMAT.R8G8B8A8_UNORM,
|
||||
DXGI_FORMAT.R8G8B8A8_UNORM_SRGB,
|
||||
):
|
||||
self._mode = "RGBA"
|
||||
if dxgi_format == DXGI_FORMAT.R8G8B8A8_UNORM_SRGB:
|
||||
self.info["gamma"] = 1 / 2.2
|
||||
return
|
||||
else:
|
||||
msg = f"Unimplemented DXGI format {dxgi_format}"
|
||||
raise NotImplementedError(msg)
|
||||
else:
|
||||
msg = f"Unimplemented pixel format {repr(fourcc)}"
|
||||
raise NotImplementedError(msg)
|
||||
else:
|
||||
msg = f"Unknown pixel format flags {pfflags}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if n:
|
||||
self.tile = [
|
||||
("bcn", (0, 0) + self.size, data_start, (n, self.pixel_format))
|
||||
ImageFile._Tile("bcn", extents, offset, (n, self.pixel_format))
|
||||
]
|
||||
else:
|
||||
self.tile = [ImageFile._Tile("raw", extents, 0, rawmode or self.mode)]
|
||||
|
||||
def load_seek(self, pos):
|
||||
def load_seek(self, pos: int) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _save(im, fp, filename):
|
||||
class DdsRgbDecoder(ImageFile.PyDecoder):
|
||||
_pulls_fd = True
|
||||
|
||||
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
|
||||
assert self.fd is not None
|
||||
bitcount, masks = self.args
|
||||
|
||||
# Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
|
||||
# Calculate how many zeros each mask is padded with
|
||||
mask_offsets = []
|
||||
# And the maximum value of each channel without the padding
|
||||
mask_totals = []
|
||||
for mask in masks:
|
||||
offset = 0
|
||||
if mask != 0:
|
||||
while mask >> (offset + 1) << (offset + 1) == mask:
|
||||
offset += 1
|
||||
mask_offsets.append(offset)
|
||||
mask_totals.append(mask >> offset)
|
||||
|
||||
data = bytearray()
|
||||
bytecount = bitcount // 8
|
||||
dest_length = self.state.xsize * self.state.ysize * len(masks)
|
||||
while len(data) < dest_length:
|
||||
value = int.from_bytes(self.fd.read(bytecount), "little")
|
||||
for i, mask in enumerate(masks):
|
||||
masked_value = value & mask
|
||||
# Remove the zero padding, and scale it to 8 bits
|
||||
data += o8(
|
||||
int(((masked_value >> mask_offsets[i]) / mask_totals[i]) * 255)
|
||||
)
|
||||
self.set_as_raw(data)
|
||||
return -1, 0
|
||||
|
||||
|
||||
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
if im.mode not in ("RGB", "RGBA", "L", "LA"):
|
||||
msg = f"cannot write mode {im.mode} as DDS"
|
||||
raise OSError(msg)
|
||||
|
||||
rawmode = im.mode
|
||||
masks = [0xFF0000, 0xFF00, 0xFF]
|
||||
if im.mode in ("L", "LA"):
|
||||
pixel_flags = DDPF_LUMINANCE
|
||||
flags = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
|
||||
bitcount = len(im.getbands()) * 8
|
||||
pixel_format = im.encoderinfo.get("pixel_format")
|
||||
args: tuple[int] | str
|
||||
if pixel_format:
|
||||
codec_name = "bcn"
|
||||
flags |= DDSD.LINEARSIZE
|
||||
pitch = (im.width + 3) * 4
|
||||
rgba_mask = [0, 0, 0, 0]
|
||||
pixel_flags = DDPF.FOURCC
|
||||
if pixel_format == "DXT1":
|
||||
fourcc = D3DFMT.DXT1
|
||||
args = (1,)
|
||||
elif pixel_format == "DXT3":
|
||||
fourcc = D3DFMT.DXT3
|
||||
args = (2,)
|
||||
elif pixel_format == "DXT5":
|
||||
fourcc = D3DFMT.DXT5
|
||||
args = (3,)
|
||||
else:
|
||||
fourcc = D3DFMT.DX10
|
||||
if pixel_format == "BC2":
|
||||
args = (2,)
|
||||
dxgi_format = DXGI_FORMAT.BC2_TYPELESS
|
||||
elif pixel_format == "BC3":
|
||||
args = (3,)
|
||||
dxgi_format = DXGI_FORMAT.BC3_TYPELESS
|
||||
elif pixel_format == "BC5":
|
||||
args = (5,)
|
||||
dxgi_format = DXGI_FORMAT.BC5_TYPELESS
|
||||
if im.mode != "RGB":
|
||||
msg = "only RGB mode can be written as BC5"
|
||||
raise OSError(msg)
|
||||
else:
|
||||
msg = f"cannot write pixel format {pixel_format}"
|
||||
raise OSError(msg)
|
||||
else:
|
||||
pixel_flags = DDPF_RGB
|
||||
rawmode = rawmode[::-1]
|
||||
if im.mode in ("LA", "RGBA"):
|
||||
pixel_flags |= DDPF_ALPHAPIXELS
|
||||
masks.append(0xFF000000)
|
||||
codec_name = "raw"
|
||||
flags |= DDSD.PITCH
|
||||
pitch = (im.width * bitcount + 7) // 8
|
||||
|
||||
bitcount = len(masks) * 8
|
||||
while len(masks) < 4:
|
||||
masks.append(0)
|
||||
alpha = im.mode[-1] == "A"
|
||||
if im.mode[0] == "L":
|
||||
pixel_flags = DDPF.LUMINANCE
|
||||
args = im.mode
|
||||
if alpha:
|
||||
rgba_mask = [0x000000FF, 0x000000FF, 0x000000FF]
|
||||
else:
|
||||
rgba_mask = [0xFF000000, 0xFF000000, 0xFF000000]
|
||||
else:
|
||||
pixel_flags = DDPF.RGB
|
||||
args = im.mode[::-1]
|
||||
rgba_mask = [0x00FF0000, 0x0000FF00, 0x000000FF]
|
||||
|
||||
if alpha:
|
||||
r, g, b, a = im.split()
|
||||
im = Image.merge("RGBA", (a, r, g, b))
|
||||
if alpha:
|
||||
pixel_flags |= DDPF.ALPHAPIXELS
|
||||
rgba_mask.append(0xFF000000 if alpha else 0)
|
||||
|
||||
fourcc = D3DFMT.UNKNOWN
|
||||
fp.write(
|
||||
o32(DDS_MAGIC)
|
||||
+ o32(124) # header size
|
||||
+ o32(
|
||||
DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PITCH | DDSD_PIXELFORMAT
|
||||
) # flags
|
||||
+ o32(im.height)
|
||||
+ o32(im.width)
|
||||
+ o32((im.width * bitcount + 7) // 8) # pitch
|
||||
+ o32(0) # depth
|
||||
+ o32(0) # mipmaps
|
||||
+ o32(0) * 11 # reserved
|
||||
+ o32(32) # pfsize
|
||||
+ o32(pixel_flags) # pfflags
|
||||
+ o32(0) # fourcc
|
||||
+ o32(bitcount) # bitcount
|
||||
+ b"".join(o32(mask) for mask in masks) # rgbabitmask
|
||||
+ o32(DDSCAPS_TEXTURE) # dwCaps
|
||||
+ o32(0) # dwCaps2
|
||||
+ o32(0) # dwCaps3
|
||||
+ o32(0) # dwCaps4
|
||||
+ o32(0) # dwReserved2
|
||||
+ struct.pack(
|
||||
"<7I",
|
||||
124, # header size
|
||||
flags, # flags
|
||||
im.height,
|
||||
im.width,
|
||||
pitch,
|
||||
0, # depth
|
||||
0, # mipmaps
|
||||
)
|
||||
+ struct.pack("11I", *((0,) * 11)) # reserved
|
||||
# pfsize, pfflags, fourcc, bitcount
|
||||
+ struct.pack("<4I", 32, pixel_flags, fourcc, bitcount)
|
||||
+ struct.pack("<4I", *rgba_mask) # dwRGBABitMask
|
||||
+ struct.pack("<5I", DDSCAPS.TEXTURE, 0, 0, 0, 0)
|
||||
)
|
||||
if im.mode == "RGBA":
|
||||
r, g, b, a = im.split()
|
||||
im = Image.merge("RGBA", (a, r, g, b))
|
||||
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, 0, 1))])
|
||||
if fourcc == D3DFMT.DX10:
|
||||
fp.write(
|
||||
# dxgi_format, 2D resource, misc, array size, straight alpha
|
||||
struct.pack("<5I", dxgi_format, 3, 0, 0, 1)
|
||||
)
|
||||
ImageFile._save(im, fp, [ImageFile._Tile(codec_name, (0, 0) + im.size, 0, args)])
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == b"DDS "
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(b"DDS ")
|
||||
|
||||
|
||||
Image.register_open(DdsImageFile.format, DdsImageFile, _accept)
|
||||
Image.register_decoder("dds_rgb", DdsRgbDecoder)
|
||||
Image.register_save(DdsImageFile.format, _save)
|
||||
Image.register_extension(DdsImageFile.format, ".dds")
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
@@ -26,10 +27,10 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import IO
|
||||
|
||||
from . import Image, ImageFile
|
||||
from ._binary import i32le as i32
|
||||
from ._deprecate import deprecate
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
@@ -37,11 +38,11 @@ from ._deprecate import deprecate
|
||||
split = re.compile(r"^%%([^:]*):[ \t]*(.*)[ \t]*$")
|
||||
field = re.compile(r"^%[%!\w]([^:]*)[ \t]*$")
|
||||
|
||||
gs_binary = None
|
||||
gs_binary: str | bool | None = None
|
||||
gs_windows_binary = None
|
||||
|
||||
|
||||
def has_ghostscript():
|
||||
def has_ghostscript() -> bool:
|
||||
global gs_binary, gs_windows_binary
|
||||
if gs_binary is None:
|
||||
if sys.platform.startswith("win"):
|
||||
@@ -64,27 +65,32 @@ def has_ghostscript():
|
||||
return gs_binary is not False
|
||||
|
||||
|
||||
def Ghostscript(tile, size, fp, scale=1, transparency=False):
|
||||
def Ghostscript(
|
||||
tile: list[ImageFile._Tile],
|
||||
size: tuple[int, int],
|
||||
fp: IO[bytes],
|
||||
scale: int = 1,
|
||||
transparency: bool = False,
|
||||
) -> Image.core.ImagingCore:
|
||||
"""Render an image using Ghostscript"""
|
||||
global gs_binary
|
||||
if not has_ghostscript():
|
||||
msg = "Unable to locate Ghostscript on paths"
|
||||
raise OSError(msg)
|
||||
assert isinstance(gs_binary, str)
|
||||
|
||||
# Unpack decoder tile
|
||||
decoder, tile, offset, data = tile[0]
|
||||
length, bbox = data
|
||||
args = tile[0].args
|
||||
assert isinstance(args, tuple)
|
||||
length, bbox = args
|
||||
|
||||
# Hack to support hi-res rendering
|
||||
scale = int(scale) or 1
|
||||
# orig_size = size
|
||||
# orig_bbox = bbox
|
||||
size = (size[0] * scale, size[1] * scale)
|
||||
width = size[0] * scale
|
||||
height = size[1] * scale
|
||||
# resolution is dependent on bbox and size
|
||||
res = (
|
||||
72.0 * size[0] / (bbox[2] - bbox[0]),
|
||||
72.0 * size[1] / (bbox[3] - bbox[1]),
|
||||
)
|
||||
res_x = 72.0 * width / (bbox[2] - bbox[0])
|
||||
res_y = 72.0 * height / (bbox[3] - bbox[1])
|
||||
|
||||
out_fd, outfile = tempfile.mkstemp()
|
||||
os.close(out_fd)
|
||||
@@ -115,14 +121,20 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
|
||||
lengthfile -= len(s)
|
||||
f.write(s)
|
||||
|
||||
device = "pngalpha" if transparency else "ppmraw"
|
||||
if transparency:
|
||||
# "RGBA"
|
||||
device = "pngalpha"
|
||||
else:
|
||||
# "pnmraw" automatically chooses between
|
||||
# PBM ("1"), PGM ("L"), and PPM ("RGB").
|
||||
device = "pnmraw"
|
||||
|
||||
# Build Ghostscript command
|
||||
command = [
|
||||
gs_binary,
|
||||
"-q", # quiet mode
|
||||
"-g%dx%d" % size, # set output geometry (pixels)
|
||||
"-r%fx%f" % res, # set input DPI (dots per inch)
|
||||
f"-g{width:d}x{height:d}", # set output geometry (pixels)
|
||||
f"-r{res_x:f}x{res_y:f}", # set input DPI (dots per inch)
|
||||
"-dBATCH", # exit after processing
|
||||
"-dNOPAUSE", # don't pause between pages
|
||||
"-dSAFER", # safe mode
|
||||
@@ -145,8 +157,9 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
subprocess.check_call(command, startupinfo=startupinfo)
|
||||
out_im = Image.open(outfile)
|
||||
out_im.load()
|
||||
with Image.open(outfile) as out_im:
|
||||
out_im.load()
|
||||
return out_im.im.copy()
|
||||
finally:
|
||||
try:
|
||||
os.unlink(outfile)
|
||||
@@ -155,50 +168,11 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
im = out_im.im.copy()
|
||||
out_im.close()
|
||||
return im
|
||||
|
||||
|
||||
class PSFile:
|
||||
"""
|
||||
Wrapper for bytesio object that treats either CR or LF as end of line.
|
||||
This class is no longer used internally, but kept for backwards compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, fp):
|
||||
deprecate(
|
||||
"PSFile",
|
||||
11,
|
||||
action="If you need the functionality of this class "
|
||||
"you will need to implement it yourself.",
|
||||
)
|
||||
self.fp = fp
|
||||
self.char = None
|
||||
|
||||
def seek(self, offset, whence=io.SEEK_SET):
|
||||
self.char = None
|
||||
self.fp.seek(offset, whence)
|
||||
|
||||
def readline(self):
|
||||
s = [self.char or b""]
|
||||
self.char = None
|
||||
|
||||
c = self.fp.read(1)
|
||||
while (c not in b"\r\n") and len(c):
|
||||
s.append(c)
|
||||
c = self.fp.read(1)
|
||||
|
||||
self.char = self.fp.read(1)
|
||||
# line endings can be 1 or 2 of \r \n, in either order
|
||||
if self.char in b"\r\n":
|
||||
self.char = None
|
||||
|
||||
return b"".join(s).decode("latin-1")
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == b"%!PS" or (len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5)
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(b"%!PS") or (
|
||||
len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5
|
||||
)
|
||||
|
||||
|
||||
##
|
||||
@@ -214,14 +188,18 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
|
||||
mode_map = {1: "L", 2: "LAB", 3: "RGB", 4: "CMYK"}
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
(length, offset) = self._find_offset(self.fp)
|
||||
|
||||
# go to offset - start of "%!PS"
|
||||
self.fp.seek(offset)
|
||||
|
||||
self._mode = "RGB"
|
||||
self._size = None
|
||||
|
||||
# When reading header comments, the first comment is used.
|
||||
# When reading trailer comments, the last comment is used.
|
||||
bounding_box: list[int] | None = None
|
||||
imagedata_size: tuple[int, int] | None = None
|
||||
|
||||
byte_arr = bytearray(255)
|
||||
bytes_mv = memoryview(byte_arr)
|
||||
@@ -230,7 +208,12 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
reading_trailer_comments = False
|
||||
trailer_reached = False
|
||||
|
||||
def check_required_header_comments():
|
||||
def check_required_header_comments() -> None:
|
||||
"""
|
||||
The EPS specification requires that some headers exist.
|
||||
This should be checked when the header comments formally end,
|
||||
when image data starts, or when the file ends, whichever comes first.
|
||||
"""
|
||||
if "PS-Adobe" not in self.info:
|
||||
msg = 'EPS header missing "%!PS-Adobe" comment'
|
||||
raise SyntaxError(msg)
|
||||
@@ -238,41 +221,39 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
msg = 'EPS header missing "%%BoundingBox" comment'
|
||||
raise SyntaxError(msg)
|
||||
|
||||
def _read_comment(s):
|
||||
nonlocal reading_trailer_comments
|
||||
def read_comment(s: str) -> bool:
|
||||
nonlocal bounding_box, reading_trailer_comments
|
||||
try:
|
||||
m = split.match(s)
|
||||
except re.error as e:
|
||||
msg = "not an EPS file"
|
||||
raise SyntaxError(msg) from e
|
||||
|
||||
if m:
|
||||
k, v = m.group(1, 2)
|
||||
self.info[k] = v
|
||||
if k == "BoundingBox":
|
||||
if v == "(atend)":
|
||||
reading_trailer_comments = True
|
||||
elif not self._size or (
|
||||
trailer_reached and reading_trailer_comments
|
||||
):
|
||||
try:
|
||||
# Note: The DSC spec says that BoundingBox
|
||||
# fields should be integers, but some drivers
|
||||
# put floating point values there anyway.
|
||||
box = [int(float(i)) for i in v.split()]
|
||||
self._size = box[2] - box[0], box[3] - box[1]
|
||||
self.tile = [
|
||||
("eps", (0, 0) + self.size, offset, (length, box))
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
if not m:
|
||||
return False
|
||||
|
||||
k, v = m.group(1, 2)
|
||||
self.info[k] = v
|
||||
if k == "BoundingBox":
|
||||
if v == "(atend)":
|
||||
reading_trailer_comments = True
|
||||
elif not bounding_box or (trailer_reached and reading_trailer_comments):
|
||||
try:
|
||||
# Note: The DSC spec says that BoundingBox
|
||||
# fields should be integers, but some drivers
|
||||
# put floating point values there anyway.
|
||||
bounding_box = [int(float(i)) for i in v.split()]
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
while True:
|
||||
byte = self.fp.read(1)
|
||||
if byte == b"":
|
||||
# if we didn't read a byte we must be at the end of the file
|
||||
if bytes_read == 0:
|
||||
if reading_header_comments:
|
||||
check_required_header_comments()
|
||||
break
|
||||
elif byte in b"\r\n":
|
||||
# if we read a line ending character, ignore it and parse what
|
||||
@@ -312,11 +293,11 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
continue
|
||||
|
||||
s = str(bytes_mv[:bytes_read], "latin-1")
|
||||
if not _read_comment(s):
|
||||
if not read_comment(s):
|
||||
m = field.match(s)
|
||||
if m:
|
||||
k = m.group(1)
|
||||
if k[:8] == "PS-Adobe":
|
||||
if k.startswith("PS-Adobe"):
|
||||
self.info["PS-Adobe"] = k[9:]
|
||||
else:
|
||||
self.info[k] = ""
|
||||
@@ -331,6 +312,12 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
# Check for an "ImageData" descriptor
|
||||
# https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#50577413_pgfId-1035096
|
||||
|
||||
# If we've already read an "ImageData" descriptor,
|
||||
# don't read another one.
|
||||
if imagedata_size:
|
||||
bytes_read = 0
|
||||
continue
|
||||
|
||||
# Values:
|
||||
# columns
|
||||
# rows
|
||||
@@ -356,29 +343,39 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
else:
|
||||
break
|
||||
|
||||
self._size = columns, rows
|
||||
return
|
||||
# Parse the columns and rows after checking the bit depth and mode
|
||||
# in case the bit depth and/or mode are invalid.
|
||||
imagedata_size = columns, rows
|
||||
elif bytes_mv[:5] == b"%%EOF":
|
||||
break
|
||||
elif trailer_reached and reading_trailer_comments:
|
||||
# Load EPS trailer
|
||||
|
||||
# if this line starts with "%%EOF",
|
||||
# then we've reached the end of the file
|
||||
if bytes_mv[:5] == b"%%EOF":
|
||||
break
|
||||
|
||||
s = str(bytes_mv[:bytes_read], "latin-1")
|
||||
_read_comment(s)
|
||||
read_comment(s)
|
||||
elif bytes_mv[:9] == b"%%Trailer":
|
||||
trailer_reached = True
|
||||
elif bytes_mv[:14] == b"%%BeginBinary:":
|
||||
bytecount = int(byte_arr[14:bytes_read])
|
||||
self.fp.seek(bytecount, os.SEEK_CUR)
|
||||
bytes_read = 0
|
||||
|
||||
check_required_header_comments()
|
||||
|
||||
if not self._size:
|
||||
# A "BoundingBox" is always required,
|
||||
# even if an "ImageData" descriptor size exists.
|
||||
if not bounding_box:
|
||||
msg = "cannot determine EPS bounding box"
|
||||
raise OSError(msg)
|
||||
|
||||
def _find_offset(self, fp):
|
||||
# An "ImageData" size takes precedence over the "BoundingBox".
|
||||
self._size = imagedata_size or (
|
||||
bounding_box[2] - bounding_box[0],
|
||||
bounding_box[3] - bounding_box[1],
|
||||
)
|
||||
|
||||
self.tile = [
|
||||
ImageFile._Tile("eps", (0, 0) + self.size, offset, (length, bounding_box))
|
||||
]
|
||||
|
||||
def _find_offset(self, fp: IO[bytes]) -> tuple[int, int]:
|
||||
s = fp.read(4)
|
||||
|
||||
if s == b"%!PS":
|
||||
@@ -401,7 +398,9 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
|
||||
return length, offset
|
||||
|
||||
def load(self, scale=1, transparency=False):
|
||||
def load(
|
||||
self, scale: int = 1, transparency: bool = False
|
||||
) -> Image.core.PixelAccess | None:
|
||||
# Load EPS via Ghostscript
|
||||
if self.tile:
|
||||
self.im = Ghostscript(self.tile, self.size, self.fp, scale, transparency)
|
||||
@@ -410,7 +409,7 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
self.tile = []
|
||||
return Image.Image.load(self)
|
||||
|
||||
def load_seek(self, *args, **kwargs):
|
||||
def load_seek(self, pos: int) -> None:
|
||||
# we can't incrementally load, so force ImageFile.parser to
|
||||
# use our custom load method by defining this method.
|
||||
pass
|
||||
@@ -419,7 +418,7 @@ class EpsImageFile(ImageFile.ImageFile):
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
|
||||
def _save(im, fp, filename, eps=1):
|
||||
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes, eps: int = 1) -> None:
|
||||
"""EPS Writer for the Python Imaging Library."""
|
||||
|
||||
# make sure image data is available
|
||||
@@ -460,7 +459,7 @@ def _save(im, fp, filename, eps=1):
|
||||
if hasattr(fp, "flush"):
|
||||
fp.flush()
|
||||
|
||||
ImageFile._save(im, fp, [("eps", (0, 0) + im.size, 0, None)])
|
||||
ImageFile._save(im, fp, [ImageFile._Tile("eps", (0, 0) + im.size)])
|
||||
|
||||
fp.write(b"\n%%%%EndBinary\n")
|
||||
fp.write(b"grestore end\n")
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
This module provides constants and clear-text names for various
|
||||
well-known EXIF tags.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
@@ -302,38 +303,38 @@ TAGS = {
|
||||
|
||||
|
||||
class GPS(IntEnum):
|
||||
GPSVersionID = 0
|
||||
GPSLatitudeRef = 1
|
||||
GPSLatitude = 2
|
||||
GPSLongitudeRef = 3
|
||||
GPSLongitude = 4
|
||||
GPSAltitudeRef = 5
|
||||
GPSAltitude = 6
|
||||
GPSTimeStamp = 7
|
||||
GPSSatellites = 8
|
||||
GPSStatus = 9
|
||||
GPSMeasureMode = 10
|
||||
GPSDOP = 11
|
||||
GPSSpeedRef = 12
|
||||
GPSSpeed = 13
|
||||
GPSTrackRef = 14
|
||||
GPSTrack = 15
|
||||
GPSImgDirectionRef = 16
|
||||
GPSImgDirection = 17
|
||||
GPSMapDatum = 18
|
||||
GPSDestLatitudeRef = 19
|
||||
GPSDestLatitude = 20
|
||||
GPSDestLongitudeRef = 21
|
||||
GPSDestLongitude = 22
|
||||
GPSDestBearingRef = 23
|
||||
GPSDestBearing = 24
|
||||
GPSDestDistanceRef = 25
|
||||
GPSDestDistance = 26
|
||||
GPSProcessingMethod = 27
|
||||
GPSAreaInformation = 28
|
||||
GPSDateStamp = 29
|
||||
GPSDifferential = 30
|
||||
GPSHPositioningError = 31
|
||||
GPSVersionID = 0x00
|
||||
GPSLatitudeRef = 0x01
|
||||
GPSLatitude = 0x02
|
||||
GPSLongitudeRef = 0x03
|
||||
GPSLongitude = 0x04
|
||||
GPSAltitudeRef = 0x05
|
||||
GPSAltitude = 0x06
|
||||
GPSTimeStamp = 0x07
|
||||
GPSSatellites = 0x08
|
||||
GPSStatus = 0x09
|
||||
GPSMeasureMode = 0x0A
|
||||
GPSDOP = 0x0B
|
||||
GPSSpeedRef = 0x0C
|
||||
GPSSpeed = 0x0D
|
||||
GPSTrackRef = 0x0E
|
||||
GPSTrack = 0x0F
|
||||
GPSImgDirectionRef = 0x10
|
||||
GPSImgDirection = 0x11
|
||||
GPSMapDatum = 0x12
|
||||
GPSDestLatitudeRef = 0x13
|
||||
GPSDestLatitude = 0x14
|
||||
GPSDestLongitudeRef = 0x15
|
||||
GPSDestLongitude = 0x16
|
||||
GPSDestBearingRef = 0x17
|
||||
GPSDestBearing = 0x18
|
||||
GPSDestDistanceRef = 0x19
|
||||
GPSDestDistance = 0x1A
|
||||
GPSProcessingMethod = 0x1B
|
||||
GPSAreaInformation = 0x1C
|
||||
GPSDateStamp = 0x1D
|
||||
GPSDifferential = 0x1E
|
||||
GPSHPositioningError = 0x1F
|
||||
|
||||
|
||||
"""Maps EXIF GPS tags to tag names."""
|
||||
@@ -341,40 +342,41 @@ GPSTAGS = {i.value: i.name for i in GPS}
|
||||
|
||||
|
||||
class Interop(IntEnum):
|
||||
InteropIndex = 1
|
||||
InteropVersion = 2
|
||||
RelatedImageFileFormat = 4096
|
||||
RelatedImageWidth = 4097
|
||||
RleatedImageHeight = 4098
|
||||
InteropIndex = 0x0001
|
||||
InteropVersion = 0x0002
|
||||
RelatedImageFileFormat = 0x1000
|
||||
RelatedImageWidth = 0x1001
|
||||
RelatedImageHeight = 0x1002
|
||||
|
||||
|
||||
class IFD(IntEnum):
|
||||
Exif = 34665
|
||||
GPSInfo = 34853
|
||||
Makernote = 37500
|
||||
Interop = 40965
|
||||
Exif = 0x8769
|
||||
GPSInfo = 0x8825
|
||||
MakerNote = 0x927C
|
||||
Makernote = 0x927C # Deprecated
|
||||
Interop = 0xA005
|
||||
IFD1 = -1
|
||||
|
||||
|
||||
class LightSource(IntEnum):
|
||||
Unknown = 0
|
||||
Daylight = 1
|
||||
Fluorescent = 2
|
||||
Tungsten = 3
|
||||
Flash = 4
|
||||
Fine = 9
|
||||
Cloudy = 10
|
||||
Shade = 11
|
||||
DaylightFluorescent = 12
|
||||
DayWhiteFluorescent = 13
|
||||
CoolWhiteFluorescent = 14
|
||||
WhiteFluorescent = 15
|
||||
StandardLightA = 17
|
||||
StandardLightB = 18
|
||||
StandardLightC = 19
|
||||
D55 = 20
|
||||
D65 = 21
|
||||
D75 = 22
|
||||
D50 = 23
|
||||
ISO = 24
|
||||
Other = 255
|
||||
Unknown = 0x00
|
||||
Daylight = 0x01
|
||||
Fluorescent = 0x02
|
||||
Tungsten = 0x03
|
||||
Flash = 0x04
|
||||
Fine = 0x09
|
||||
Cloudy = 0x0A
|
||||
Shade = 0x0B
|
||||
DaylightFluorescent = 0x0C
|
||||
DayWhiteFluorescent = 0x0D
|
||||
CoolWhiteFluorescent = 0x0E
|
||||
WhiteFluorescent = 0x0F
|
||||
StandardLightA = 0x11
|
||||
StandardLightB = 0x12
|
||||
StandardLightC = 0x13
|
||||
D55 = 0x14
|
||||
D65 = 0x15
|
||||
D75 = 0x16
|
||||
D50 = 0x17
|
||||
ISO = 0x18
|
||||
Other = 0xFF
|
||||
|
||||
@@ -8,30 +8,52 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import gzip
|
||||
import math
|
||||
|
||||
from . import Image, ImageFile
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:6] == b"SIMPLE"
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(b"SIMPLE")
|
||||
|
||||
|
||||
class FitsImageFile(ImageFile.ImageFile):
|
||||
format = "FITS"
|
||||
format_description = "FITS"
|
||||
|
||||
def _open(self):
|
||||
headers = {}
|
||||
def _open(self) -> None:
|
||||
assert self.fp is not None
|
||||
|
||||
headers: dict[bytes, bytes] = {}
|
||||
header_in_progress = False
|
||||
decoder_name = ""
|
||||
while True:
|
||||
header = self.fp.read(80)
|
||||
if not header:
|
||||
msg = "Truncated FITS file"
|
||||
raise OSError(msg)
|
||||
keyword = header[:8].strip()
|
||||
if keyword == b"END":
|
||||
if keyword in (b"SIMPLE", b"XTENSION"):
|
||||
header_in_progress = True
|
||||
elif headers and not header_in_progress:
|
||||
# This is now a data unit
|
||||
break
|
||||
elif keyword == b"END":
|
||||
# Seek to the end of the header unit
|
||||
self.fp.seek(math.ceil(self.fp.tell() / 2880) * 2880)
|
||||
if not decoder_name:
|
||||
decoder_name, offset, args = self._parse_headers(headers)
|
||||
|
||||
header_in_progress = False
|
||||
continue
|
||||
|
||||
if decoder_name:
|
||||
# Keep going to read past the headers
|
||||
continue
|
||||
|
||||
value = header[8:].split(b"/")[0].strip()
|
||||
if value.startswith(b"="):
|
||||
value = value[1:].strip()
|
||||
@@ -40,34 +62,91 @@ class FitsImageFile(ImageFile.ImageFile):
|
||||
raise SyntaxError(msg)
|
||||
headers[keyword] = value
|
||||
|
||||
naxis = int(headers[b"NAXIS"])
|
||||
if naxis == 0:
|
||||
if not decoder_name:
|
||||
msg = "No image data"
|
||||
raise ValueError(msg)
|
||||
elif naxis == 1:
|
||||
self._size = 1, int(headers[b"NAXIS1"])
|
||||
else:
|
||||
self._size = int(headers[b"NAXIS1"]), int(headers[b"NAXIS2"])
|
||||
|
||||
number_of_bits = int(headers[b"BITPIX"])
|
||||
offset += self.fp.tell() - 80
|
||||
self.tile = [ImageFile._Tile(decoder_name, (0, 0) + self.size, offset, args)]
|
||||
|
||||
def _get_size(
|
||||
self, headers: dict[bytes, bytes], prefix: bytes
|
||||
) -> tuple[int, int] | None:
|
||||
naxis = int(headers[prefix + b"NAXIS"])
|
||||
if naxis == 0:
|
||||
return None
|
||||
|
||||
if naxis == 1:
|
||||
return 1, int(headers[prefix + b"NAXIS1"])
|
||||
else:
|
||||
return int(headers[prefix + b"NAXIS1"]), int(headers[prefix + b"NAXIS2"])
|
||||
|
||||
def _parse_headers(
|
||||
self, headers: dict[bytes, bytes]
|
||||
) -> tuple[str, int, tuple[str | int, ...]]:
|
||||
prefix = b""
|
||||
decoder_name = "raw"
|
||||
offset = 0
|
||||
if (
|
||||
headers.get(b"XTENSION") == b"'BINTABLE'"
|
||||
and headers.get(b"ZIMAGE") == b"T"
|
||||
and headers[b"ZCMPTYPE"] == b"'GZIP_1 '"
|
||||
):
|
||||
no_prefix_size = self._get_size(headers, prefix) or (0, 0)
|
||||
number_of_bits = int(headers[b"BITPIX"])
|
||||
offset = no_prefix_size[0] * no_prefix_size[1] * (number_of_bits // 8)
|
||||
|
||||
prefix = b"Z"
|
||||
decoder_name = "fits_gzip"
|
||||
|
||||
size = self._get_size(headers, prefix)
|
||||
if not size:
|
||||
return "", 0, ()
|
||||
|
||||
self._size = size
|
||||
|
||||
number_of_bits = int(headers[prefix + b"BITPIX"])
|
||||
if number_of_bits == 8:
|
||||
self._mode = "L"
|
||||
elif number_of_bits == 16:
|
||||
self._mode = "I"
|
||||
# rawmode = "I;16S"
|
||||
self._mode = "I;16"
|
||||
elif number_of_bits == 32:
|
||||
self._mode = "I"
|
||||
elif number_of_bits in (-32, -64):
|
||||
self._mode = "F"
|
||||
# rawmode = "F" if number_of_bits == -32 else "F;64F"
|
||||
|
||||
offset = math.ceil(self.fp.tell() / 2880) * 2880
|
||||
self.tile = [("raw", (0, 0) + self.size, offset, (self.mode, 0, -1))]
|
||||
args: tuple[str | int, ...]
|
||||
if decoder_name == "raw":
|
||||
args = (self.mode, 0, -1)
|
||||
else:
|
||||
args = (number_of_bits,)
|
||||
return decoder_name, offset, args
|
||||
|
||||
|
||||
class FitsGzipDecoder(ImageFile.PyDecoder):
|
||||
_pulls_fd = True
|
||||
|
||||
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
|
||||
assert self.fd is not None
|
||||
value = gzip.decompress(self.fd.read())
|
||||
|
||||
rows = []
|
||||
offset = 0
|
||||
number_of_bits = min(self.args[0] // 8, 4)
|
||||
for y in range(self.state.ysize):
|
||||
row = bytearray()
|
||||
for x in range(self.state.xsize):
|
||||
row += value[offset + (4 - number_of_bits) : offset + 4]
|
||||
offset += 4
|
||||
rows.append(row)
|
||||
self.set_as_raw(bytes([pixel for row in rows[::-1] for pixel in row]))
|
||||
return -1, 0
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Registry
|
||||
|
||||
Image.register_open(FitsImageFile.format, FitsImageFile, _accept)
|
||||
Image.register_decoder("fits_gzip", FitsGzipDecoder)
|
||||
|
||||
Image.register_extensions(FitsImageFile.format, [".fit", ".fits"])
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
@@ -21,14 +22,15 @@ from . import Image, ImageFile, ImagePalette
|
||||
from ._binary import i16le as i16
|
||||
from ._binary import i32le as i32
|
||||
from ._binary import o8
|
||||
from ._util import DeferredError
|
||||
|
||||
#
|
||||
# decoder
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return (
|
||||
len(prefix) >= 6
|
||||
len(prefix) >= 16
|
||||
and i16(prefix, 4) in [0xAF11, 0xAF12]
|
||||
and i16(prefix, 14) in [0, 3] # flags
|
||||
)
|
||||
@@ -44,10 +46,16 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
format_description = "Autodesk FLI/FLC Animation"
|
||||
_close_exclusive_fp_after_loading = False
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
# HEAD
|
||||
assert self.fp is not None
|
||||
s = self.fp.read(128)
|
||||
if not (_accept(s) and s[20:22] == b"\x00\x00"):
|
||||
if not (
|
||||
_accept(s)
|
||||
and s[20:22] == b"\x00" * 2
|
||||
and s[42:80] == b"\x00" * 38
|
||||
and s[88:] == b"\x00" * 40
|
||||
):
|
||||
msg = "not an FLI/FLC file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
@@ -75,13 +83,13 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
|
||||
if i16(s, 4) == 0xF100:
|
||||
# prefix chunk; ignore it
|
||||
self.__offset = self.__offset + i32(s)
|
||||
self.fp.seek(self.__offset + i32(s))
|
||||
s = self.fp.read(16)
|
||||
|
||||
if i16(s, 4) == 0xF1FA:
|
||||
# look for palette chunk
|
||||
number_of_subchunks = i16(s, 6)
|
||||
chunk_size = None
|
||||
chunk_size: int | None = None
|
||||
for _ in range(number_of_subchunks):
|
||||
if chunk_size is not None:
|
||||
self.fp.seek(chunk_size - 6, os.SEEK_CUR)
|
||||
@@ -94,8 +102,9 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
if not chunk_size:
|
||||
break
|
||||
|
||||
palette = [o8(r) + o8(g) + o8(b) for (r, g, b) in palette]
|
||||
self.palette = ImagePalette.raw("RGB", b"".join(palette))
|
||||
self.palette = ImagePalette.raw(
|
||||
"RGB", b"".join(o8(r) + o8(g) + o8(b) for (r, g, b) in palette)
|
||||
)
|
||||
|
||||
# set things up to decode first frame
|
||||
self.__frame = -1
|
||||
@@ -103,10 +112,11 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
self.__rewind = self.fp.tell()
|
||||
self.seek(0)
|
||||
|
||||
def _palette(self, palette, shift):
|
||||
def _palette(self, palette: list[tuple[int, int, int]], shift: int) -> None:
|
||||
# load palette
|
||||
|
||||
i = 0
|
||||
assert self.fp is not None
|
||||
for e in range(i16(self.fp.read(2))):
|
||||
s = self.fp.read(2)
|
||||
i = i + s[0]
|
||||
@@ -121,7 +131,7 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
palette[i] = (r, g, b)
|
||||
i += 1
|
||||
|
||||
def seek(self, frame):
|
||||
def seek(self, frame: int) -> None:
|
||||
if not self._seek_check(frame):
|
||||
return
|
||||
if frame < self.__frame:
|
||||
@@ -130,7 +140,9 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
for f in range(self.__frame + 1, frame + 1):
|
||||
self._seek(f)
|
||||
|
||||
def _seek(self, frame):
|
||||
def _seek(self, frame: int) -> None:
|
||||
if isinstance(self._fp, DeferredError):
|
||||
raise self._fp.ex
|
||||
if frame == 0:
|
||||
self.__frame = -1
|
||||
self._fp.seek(self.__rewind)
|
||||
@@ -150,16 +162,17 @@ class FliImageFile(ImageFile.ImageFile):
|
||||
|
||||
s = self.fp.read(4)
|
||||
if not s:
|
||||
raise EOFError
|
||||
msg = "missing frame size"
|
||||
raise EOFError(msg)
|
||||
|
||||
framesize = i32(s)
|
||||
|
||||
self.decodermaxblock = framesize
|
||||
self.tile = [("fli", (0, 0) + self.size, self.__offset, None)]
|
||||
self.tile = [ImageFile._Tile("fli", (0, 0) + self.size, self.__offset)]
|
||||
|
||||
self.__offset += framesize
|
||||
|
||||
def tell(self):
|
||||
def tell(self) -> int:
|
||||
return self.__frame
|
||||
|
||||
|
||||
|
||||
@@ -13,16 +13,19 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
from . import Image, _binary
|
||||
|
||||
WIDTH = 800
|
||||
|
||||
|
||||
def puti16(fp, values):
|
||||
def puti16(
|
||||
fp: BinaryIO, values: tuple[int, int, int, int, int, int, int, int, int, int]
|
||||
) -> None:
|
||||
"""Write network order (big-endian) 16-bit sequence"""
|
||||
for v in values:
|
||||
if v < 0:
|
||||
@@ -33,16 +36,32 @@ def puti16(fp, values):
|
||||
class FontFile:
|
||||
"""Base class for raster font file handlers."""
|
||||
|
||||
bitmap = None
|
||||
bitmap: Image.Image | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
self.glyph = [None] * 256
|
||||
def __init__(self) -> None:
|
||||
self.info: dict[bytes, bytes | int] = {}
|
||||
self.glyph: list[
|
||||
tuple[
|
||||
tuple[int, int],
|
||||
tuple[int, int, int, int],
|
||||
tuple[int, int, int, int],
|
||||
Image.Image,
|
||||
]
|
||||
| None
|
||||
] = [None] * 256
|
||||
|
||||
def __getitem__(self, ix):
|
||||
def __getitem__(self, ix: int) -> (
|
||||
tuple[
|
||||
tuple[int, int],
|
||||
tuple[int, int, int, int],
|
||||
tuple[int, int, int, int],
|
||||
Image.Image,
|
||||
]
|
||||
| None
|
||||
):
|
||||
return self.glyph[ix]
|
||||
|
||||
def compile(self):
|
||||
def compile(self) -> None:
|
||||
"""Create metrics and bitmap"""
|
||||
|
||||
if self.bitmap:
|
||||
@@ -51,7 +70,7 @@ class FontFile:
|
||||
# create bitmap large enough to hold all data
|
||||
h = w = maxwidth = 0
|
||||
lines = 1
|
||||
for glyph in self:
|
||||
for glyph in self.glyph:
|
||||
if glyph:
|
||||
d, dst, src, im = glyph
|
||||
h = max(h, src[3] - src[1])
|
||||
@@ -65,20 +84,22 @@ class FontFile:
|
||||
ysize = lines * h
|
||||
|
||||
if xsize == 0 and ysize == 0:
|
||||
return ""
|
||||
return
|
||||
|
||||
self.ysize = h
|
||||
|
||||
# paste glyphs into bitmap
|
||||
self.bitmap = Image.new("1", (xsize, ysize))
|
||||
self.metrics = [None] * 256
|
||||
self.metrics: list[
|
||||
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]]
|
||||
| None
|
||||
] = [None] * 256
|
||||
x = y = 0
|
||||
for i in range(256):
|
||||
glyph = self[i]
|
||||
if glyph:
|
||||
d, dst, src, im = glyph
|
||||
xx = src[2] - src[0]
|
||||
# yy = src[3] - src[1]
|
||||
x0, y0 = x, y
|
||||
x = x + xx
|
||||
if x > WIDTH:
|
||||
@@ -89,12 +110,15 @@ class FontFile:
|
||||
self.bitmap.paste(im.crop(src), s)
|
||||
self.metrics[i] = d, dst, s
|
||||
|
||||
def save(self, filename):
|
||||
def save(self, filename: str) -> None:
|
||||
"""Save font"""
|
||||
|
||||
self.compile()
|
||||
|
||||
# font data
|
||||
if not self.bitmap:
|
||||
msg = "No bitmap created"
|
||||
raise ValueError(msg)
|
||||
self.bitmap.save(os.path.splitext(filename)[0] + ".pbm", "PNG")
|
||||
|
||||
# font metrics
|
||||
@@ -105,6 +129,6 @@ class FontFile:
|
||||
for id in range(256):
|
||||
m = self.metrics[id]
|
||||
if not m:
|
||||
puti16(fp, [0] * 10)
|
||||
puti16(fp, (0,) * 10)
|
||||
else:
|
||||
puti16(fp, m[0] + m[1] + m[2])
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import olefile
|
||||
|
||||
from . import Image, ImageFile
|
||||
@@ -39,8 +41,8 @@ MODES = {
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:8] == olefile.MAGIC
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(olefile.MAGIC)
|
||||
|
||||
|
||||
##
|
||||
@@ -51,7 +53,7 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
format = "FPX"
|
||||
format_description = "FlashPix"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
#
|
||||
# read the OLE directory and see if this is a likely
|
||||
# to be a FlashPix file
|
||||
@@ -62,13 +64,14 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
msg = "not an FPX file; invalid OLE file"
|
||||
raise SyntaxError(msg) from e
|
||||
|
||||
if self.ole.root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B":
|
||||
root = self.ole.root
|
||||
if not root or root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B":
|
||||
msg = "not an FPX file; bad root CLSID"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
self._open_index(1)
|
||||
|
||||
def _open_index(self, index=1):
|
||||
def _open_index(self, index: int = 1) -> None:
|
||||
#
|
||||
# get the Image Contents Property Set
|
||||
|
||||
@@ -78,12 +81,14 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
|
||||
# size (highest resolution)
|
||||
|
||||
assert isinstance(prop[0x1000002], int)
|
||||
assert isinstance(prop[0x1000003], int)
|
||||
self._size = prop[0x1000002], prop[0x1000003]
|
||||
|
||||
size = max(self.size)
|
||||
i = 1
|
||||
while size > 64:
|
||||
size = size / 2
|
||||
size = size // 2
|
||||
i += 1
|
||||
self.maxid = i - 1
|
||||
|
||||
@@ -97,16 +102,14 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
|
||||
s = prop[0x2000002 | id]
|
||||
|
||||
colors = []
|
||||
bands = i32(s, 4)
|
||||
if bands > 4:
|
||||
if not isinstance(s, bytes) or (bands := i32(s, 4)) > 4:
|
||||
msg = "Invalid number of bands"
|
||||
raise OSError(msg)
|
||||
for i in range(bands):
|
||||
# note: for now, we ignore the "uncalibrated" flag
|
||||
colors.append(i32(s, 8 + i * 4) & 0x7FFFFFFF)
|
||||
|
||||
self._mode, self.rawmode = MODES[tuple(colors)]
|
||||
# note: for now, we ignore the "uncalibrated" flag
|
||||
colors = tuple(i32(s, 8 + i * 4) & 0x7FFFFFFF for i in range(bands))
|
||||
|
||||
self._mode, self.rawmode = MODES[colors]
|
||||
|
||||
# load JPEG tables, if any
|
||||
self.jpeg = {}
|
||||
@@ -117,7 +120,7 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
|
||||
self._open_subimage(1, self.maxid)
|
||||
|
||||
def _open_subimage(self, index=1, subimage=0):
|
||||
def _open_subimage(self, index: int = 1, subimage: int = 0) -> None:
|
||||
#
|
||||
# setup tile descriptors for a given subimage
|
||||
|
||||
@@ -163,18 +166,18 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
|
||||
if compression == 0:
|
||||
self.tile.append(
|
||||
(
|
||||
ImageFile._Tile(
|
||||
"raw",
|
||||
(x, y, x1, y1),
|
||||
i32(s, i) + 28,
|
||||
(self.rawmode,),
|
||||
self.rawmode,
|
||||
)
|
||||
)
|
||||
|
||||
elif compression == 1:
|
||||
# FIXME: the fill decoder is not implemented
|
||||
self.tile.append(
|
||||
(
|
||||
ImageFile._Tile(
|
||||
"fill",
|
||||
(x, y, x1, y1),
|
||||
i32(s, i) + 28,
|
||||
@@ -202,7 +205,7 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
jpegmode = rawmode
|
||||
|
||||
self.tile.append(
|
||||
(
|
||||
ImageFile._Tile(
|
||||
"jpeg",
|
||||
(x, y, x1, y1),
|
||||
i32(s, i) + 28,
|
||||
@@ -227,19 +230,20 @@ class FpxImageFile(ImageFile.ImageFile):
|
||||
break # isn't really required
|
||||
|
||||
self.stream = stream
|
||||
self._fp = self.fp
|
||||
self.fp = None
|
||||
|
||||
def load(self):
|
||||
def load(self) -> Image.core.PixelAccess | None:
|
||||
if not self.fp:
|
||||
self.fp = self.ole.openstream(self.stream[:2] + ["Subimage 0000 Data"])
|
||||
|
||||
return ImageFile.ImageFile.load(self)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.ole.close()
|
||||
super().close()
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.ole.close()
|
||||
super().__exit__()
|
||||
|
||||
|
||||
@@ -51,6 +51,8 @@ bytes for that mipmap level.
|
||||
Note: All data is stored in little-Endian (Intel) byte order.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
@@ -69,7 +71,7 @@ class FtexImageFile(ImageFile.ImageFile):
|
||||
format = "FTEX"
|
||||
format_description = "Texture File Format (IW2:EOC)"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
if not _accept(self.fp.read(4)):
|
||||
msg = "not an FTEX file"
|
||||
raise SyntaxError(msg)
|
||||
@@ -77,8 +79,6 @@ class FtexImageFile(ImageFile.ImageFile):
|
||||
self._size = struct.unpack("<2i", self.fp.read(8))
|
||||
mipmap_count, format_count = struct.unpack("<2i", self.fp.read(8))
|
||||
|
||||
self._mode = "RGB"
|
||||
|
||||
# Only support single-format files.
|
||||
# I don't know of any multi-format file.
|
||||
assert format_count == 1
|
||||
@@ -91,9 +91,10 @@ class FtexImageFile(ImageFile.ImageFile):
|
||||
|
||||
if format == Format.DXT1:
|
||||
self._mode = "RGBA"
|
||||
self.tile = [("bcn", (0, 0) + self.size, 0, 1)]
|
||||
self.tile = [ImageFile._Tile("bcn", (0, 0) + self.size, 0, (1,))]
|
||||
elif format == Format.UNCOMPRESSED:
|
||||
self.tile = [("raw", (0, 0) + self.size, 0, ("RGB", 0, 1))]
|
||||
self._mode = "RGB"
|
||||
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, "RGB")]
|
||||
else:
|
||||
msg = f"Invalid texture compression format: {repr(format)}"
|
||||
raise ValueError(msg)
|
||||
@@ -101,12 +102,12 @@ class FtexImageFile(ImageFile.ImageFile):
|
||||
self.fp.close()
|
||||
self.fp = BytesIO(data)
|
||||
|
||||
def load_seek(self, pos):
|
||||
def load_seek(self, pos: int) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == MAGIC
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith(MAGIC)
|
||||
|
||||
|
||||
Image.register_open(FtexImageFile.format, FtexImageFile, _accept)
|
||||
|
||||
@@ -23,12 +23,13 @@
|
||||
# Version 2 files are saved by GIMP v2.8 (at least)
|
||||
# Version 3 files have a format specifier of 18 for 16bit floats in
|
||||
# the color depth field. This is currently unsupported by Pillow.
|
||||
from __future__ import annotations
|
||||
|
||||
from . import Image, ImageFile
|
||||
from ._binary import i32be as i32
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return len(prefix) >= 8 and i32(prefix, 0) >= 20 and i32(prefix, 4) in (1, 2)
|
||||
|
||||
|
||||
@@ -40,7 +41,7 @@ class GbrImageFile(ImageFile.ImageFile):
|
||||
format = "GBR"
|
||||
format_description = "GIMP brush file"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
header_size = i32(self.fp.read(4))
|
||||
if header_size < 20:
|
||||
msg = "not a GIMP brush"
|
||||
@@ -53,7 +54,7 @@ class GbrImageFile(ImageFile.ImageFile):
|
||||
width = i32(self.fp.read(4))
|
||||
height = i32(self.fp.read(4))
|
||||
color_depth = i32(self.fp.read(4))
|
||||
if width <= 0 or height <= 0:
|
||||
if width == 0 or height == 0:
|
||||
msg = "not a GIMP brush"
|
||||
raise SyntaxError(msg)
|
||||
if color_depth not in (1, 4):
|
||||
@@ -70,7 +71,7 @@ class GbrImageFile(ImageFile.ImageFile):
|
||||
raise SyntaxError(msg)
|
||||
self.info["spacing"] = i32(self.fp.read(4))
|
||||
|
||||
comment = self.fp.read(comment_length)[:-1]
|
||||
self.info["comment"] = self.fp.read(comment_length)[:-1]
|
||||
|
||||
if color_depth == 1:
|
||||
self._mode = "L"
|
||||
@@ -79,16 +80,14 @@ class GbrImageFile(ImageFile.ImageFile):
|
||||
|
||||
self._size = width, height
|
||||
|
||||
self.info["comment"] = comment
|
||||
|
||||
# Image might not be small
|
||||
Image._decompression_bomb_check(self.size)
|
||||
|
||||
# Data is an uncompressed block of w * h * bytes/pixel
|
||||
self._data_size = width * height * color_depth
|
||||
|
||||
def load(self):
|
||||
if not self.im:
|
||||
def load(self) -> Image.core.PixelAccess | None:
|
||||
if self._im is None:
|
||||
self.im = Image.core.new(self.mode, self.size)
|
||||
self.frombytes(self.fp.read(self._data_size))
|
||||
return Image.Image.load(self)
|
||||
|
||||
@@ -25,11 +25,14 @@
|
||||
implementation is provided for convenience and demonstrational
|
||||
purposes only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import IO
|
||||
|
||||
from . import ImageFile, ImagePalette, UnidentifiedImageError
|
||||
from ._binary import i16be as i16
|
||||
from ._binary import i32be as i32
|
||||
from ._typing import StrOrBytesPath
|
||||
|
||||
|
||||
class GdImageFile(ImageFile.ImageFile):
|
||||
@@ -43,15 +46,17 @@ class GdImageFile(ImageFile.ImageFile):
|
||||
format = "GD"
|
||||
format_description = "GD uncompressed images"
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
# Header
|
||||
assert self.fp is not None
|
||||
|
||||
s = self.fp.read(1037)
|
||||
|
||||
if i16(s) not in [65534, 65535]:
|
||||
msg = "Not a valid GD 2.x .gd file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
self._mode = "L" # FIXME: "P"
|
||||
self._mode = "P"
|
||||
self._size = i16(s, 2), i16(s, 4)
|
||||
|
||||
true_color = s[6]
|
||||
@@ -63,20 +68,20 @@ class GdImageFile(ImageFile.ImageFile):
|
||||
self.info["transparency"] = tindex
|
||||
|
||||
self.palette = ImagePalette.raw(
|
||||
"XBGR", s[7 + true_color_offset + 4 : 7 + true_color_offset + 4 + 256 * 4]
|
||||
"RGBX", s[7 + true_color_offset + 6 : 7 + true_color_offset + 6 + 256 * 4]
|
||||
)
|
||||
|
||||
self.tile = [
|
||||
(
|
||||
ImageFile._Tile(
|
||||
"raw",
|
||||
(0, 0) + self.size,
|
||||
7 + true_color_offset + 4 + 256 * 4,
|
||||
("L", 0, 1),
|
||||
7 + true_color_offset + 6 + 256 * 4,
|
||||
"L",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def open(fp, mode="r"):
|
||||
def open(fp: StrOrBytesPath | IO[bytes], mode: str = "r") -> GdImageFile:
|
||||
"""
|
||||
Load texture from a GD image file.
|
||||
|
||||
|
||||
@@ -23,17 +23,36 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
from . import Image, ImageChops, ImageFile, ImagePalette, ImageSequence
|
||||
from . import (
|
||||
Image,
|
||||
ImageChops,
|
||||
ImageFile,
|
||||
ImageMath,
|
||||
ImageOps,
|
||||
ImagePalette,
|
||||
ImageSequence,
|
||||
)
|
||||
from ._binary import i16le as i16
|
||||
from ._binary import o8
|
||||
from ._binary import o16le as o16
|
||||
from ._util import DeferredError
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import IO, Literal
|
||||
|
||||
from . import _imaging
|
||||
from ._typing import Buffer
|
||||
|
||||
|
||||
class LoadingStrategy(IntEnum):
|
||||
@@ -51,8 +70,8 @@ LOADING_STRATEGY = LoadingStrategy.RGB_AFTER_FIRST
|
||||
# Identify/read GIF files
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:6] in [b"GIF87a", b"GIF89a"]
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return prefix.startswith((b"GIF87a", b"GIF89a"))
|
||||
|
||||
|
||||
##
|
||||
@@ -67,19 +86,19 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
|
||||
global_palette = None
|
||||
|
||||
def data(self):
|
||||
def data(self) -> bytes | None:
|
||||
s = self.fp.read(1)
|
||||
if s and s[0]:
|
||||
return self.fp.read(s[0])
|
||||
return None
|
||||
|
||||
def _is_palette_needed(self, p):
|
||||
def _is_palette_needed(self, p: bytes) -> bool:
|
||||
for i in range(0, len(p), 3):
|
||||
if not (i // 3 == p[i] == p[i + 1] == p[i + 2]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _open(self):
|
||||
def _open(self) -> None:
|
||||
# Screen
|
||||
s = self.fp.read(13)
|
||||
if not _accept(s):
|
||||
@@ -88,7 +107,6 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
|
||||
self.info["version"] = s[:6]
|
||||
self._size = i16(s, 6), i16(s, 8)
|
||||
self.tile = []
|
||||
flags = s[10]
|
||||
bits = (flags & 7) + 1
|
||||
|
||||
@@ -103,12 +121,11 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
|
||||
self._fp = self.fp # FIXME: hack
|
||||
self.__rewind = self.fp.tell()
|
||||
self._n_frames = None
|
||||
self._is_animated = None
|
||||
self._n_frames: int | None = None
|
||||
self._seek(0) # get ready to read first frame
|
||||
|
||||
@property
|
||||
def n_frames(self):
|
||||
def n_frames(self) -> int:
|
||||
if self._n_frames is None:
|
||||
current = self.tell()
|
||||
try:
|
||||
@@ -119,30 +136,29 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
self.seek(current)
|
||||
return self._n_frames
|
||||
|
||||
@property
|
||||
def is_animated(self):
|
||||
if self._is_animated is None:
|
||||
if self._n_frames is not None:
|
||||
self._is_animated = self._n_frames != 1
|
||||
else:
|
||||
current = self.tell()
|
||||
if current:
|
||||
self._is_animated = True
|
||||
else:
|
||||
try:
|
||||
self._seek(1, False)
|
||||
self._is_animated = True
|
||||
except EOFError:
|
||||
self._is_animated = False
|
||||
@cached_property
|
||||
def is_animated(self) -> bool:
|
||||
if self._n_frames is not None:
|
||||
return self._n_frames != 1
|
||||
|
||||
self.seek(current)
|
||||
return self._is_animated
|
||||
current = self.tell()
|
||||
if current:
|
||||
return True
|
||||
|
||||
def seek(self, frame):
|
||||
try:
|
||||
self._seek(1, False)
|
||||
is_animated = True
|
||||
except EOFError:
|
||||
is_animated = False
|
||||
|
||||
self.seek(current)
|
||||
return is_animated
|
||||
|
||||
def seek(self, frame: int) -> None:
|
||||
if not self._seek_check(frame):
|
||||
return
|
||||
if frame < self.__frame:
|
||||
self.im = None
|
||||
self._im = None
|
||||
self._seek(0)
|
||||
|
||||
last_frame = self.__frame
|
||||
@@ -154,11 +170,13 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
msg = "no more images in GIF file"
|
||||
raise EOFError(msg) from e
|
||||
|
||||
def _seek(self, frame, update_image=True):
|
||||
def _seek(self, frame: int, update_image: bool = True) -> None:
|
||||
if isinstance(self._fp, DeferredError):
|
||||
raise self._fp.ex
|
||||
if frame == 0:
|
||||
# rewind
|
||||
self.__offset = 0
|
||||
self.dispose = None
|
||||
self.dispose: _imaging.ImagingCore | None = None
|
||||
self.__frame = -1
|
||||
self._fp.seek(self.__rewind)
|
||||
self.disposal_method = 0
|
||||
@@ -183,11 +201,12 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
|
||||
s = self.fp.read(1)
|
||||
if not s or s == b";":
|
||||
raise EOFError
|
||||
msg = "no more images in GIF file"
|
||||
raise EOFError(msg)
|
||||
|
||||
palette = None
|
||||
palette: ImagePalette.ImagePalette | Literal[False] | None = None
|
||||
|
||||
info = {}
|
||||
info: dict[str, Any] = {}
|
||||
frame_transparency = None
|
||||
interlace = None
|
||||
frame_dispose_extent = None
|
||||
@@ -203,7 +222,7 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
#
|
||||
s = self.fp.read(1)
|
||||
block = self.data()
|
||||
if s[0] == 249:
|
||||
if s[0] == 249 and block is not None:
|
||||
#
|
||||
# graphic control extension
|
||||
#
|
||||
@@ -239,14 +258,14 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
info["comment"] = comment
|
||||
s = None
|
||||
continue
|
||||
elif s[0] == 255 and frame == 0:
|
||||
elif s[0] == 255 and frame == 0 and block is not None:
|
||||
#
|
||||
# application extension
|
||||
#
|
||||
info["extension"] = block, self.fp.tell()
|
||||
if block[:11] == b"NETSCAPE2.0":
|
||||
if block.startswith(b"NETSCAPE2.0"):
|
||||
block = self.data()
|
||||
if len(block) >= 3 and block[0] == 1:
|
||||
if block and len(block) >= 3 and block[0] == 1:
|
||||
self.info["loop"] = i16(block, 1)
|
||||
while self.data():
|
||||
pass
|
||||
@@ -280,15 +299,11 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
bits = self.fp.read(1)[0]
|
||||
self.__offset = self.fp.tell()
|
||||
break
|
||||
|
||||
else:
|
||||
pass
|
||||
# raise OSError, "illegal GIF tag `%x`" % s[0]
|
||||
s = None
|
||||
|
||||
if interlace is None:
|
||||
# self._fp = None
|
||||
raise EOFError
|
||||
msg = "image not found in GIF frame"
|
||||
raise EOFError(msg)
|
||||
|
||||
self.__frame = frame
|
||||
if not update_image:
|
||||
@@ -310,18 +325,20 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
else:
|
||||
self._mode = "L"
|
||||
|
||||
if not palette and self.global_palette:
|
||||
if palette:
|
||||
self.palette = palette
|
||||
elif self.global_palette:
|
||||
from copy import copy
|
||||
|
||||
palette = copy(self.global_palette)
|
||||
self.palette = palette
|
||||
self.palette = copy(self.global_palette)
|
||||
else:
|
||||
self.palette = None
|
||||
else:
|
||||
if self.mode == "P":
|
||||
if (
|
||||
LOADING_STRATEGY != LoadingStrategy.RGB_AFTER_DIFFERENT_PALETTE_ONLY
|
||||
or palette
|
||||
):
|
||||
self.pyaccess = None
|
||||
if "transparency" in self.info:
|
||||
self.im.putpalettealpha(self.info["transparency"], 0)
|
||||
self.im = self.im.convert("RGBA", Image.Dither.FLOYDSTEINBERG)
|
||||
@@ -331,58 +348,63 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
self._mode = "RGB"
|
||||
self.im = self.im.convert("RGB", Image.Dither.FLOYDSTEINBERG)
|
||||
|
||||
def _rgb(color):
|
||||
def _rgb(color: int) -> tuple[int, int, int]:
|
||||
if self._frame_palette:
|
||||
color = tuple(self._frame_palette.palette[color * 3 : color * 3 + 3])
|
||||
if color * 3 + 3 > len(self._frame_palette.palette):
|
||||
color = 0
|
||||
return cast(
|
||||
tuple[int, int, int],
|
||||
tuple(self._frame_palette.palette[color * 3 : color * 3 + 3]),
|
||||
)
|
||||
else:
|
||||
color = (color, color, color)
|
||||
return color
|
||||
return (color, color, color)
|
||||
|
||||
self.dispose_extent = frame_dispose_extent
|
||||
try:
|
||||
if self.disposal_method < 2:
|
||||
# do not dispose or none specified
|
||||
self.dispose = None
|
||||
elif self.disposal_method == 2:
|
||||
# replace with background colour
|
||||
self.dispose = None
|
||||
self.dispose_extent: tuple[int, int, int, int] | None = frame_dispose_extent
|
||||
if self.dispose_extent and self.disposal_method >= 2:
|
||||
try:
|
||||
if self.disposal_method == 2:
|
||||
# replace with background colour
|
||||
|
||||
# only dispose the extent in this frame
|
||||
x0, y0, x1, y1 = self.dispose_extent
|
||||
dispose_size = (x1 - x0, y1 - y0)
|
||||
|
||||
Image._decompression_bomb_check(dispose_size)
|
||||
|
||||
# by convention, attempt to use transparency first
|
||||
dispose_mode = "P"
|
||||
color = self.info.get("transparency", frame_transparency)
|
||||
if color is not None:
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGBA"
|
||||
color = _rgb(color) + (0,)
|
||||
else:
|
||||
color = self.info.get("background", 0)
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGB"
|
||||
color = _rgb(color)
|
||||
self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
|
||||
else:
|
||||
# replace with previous contents
|
||||
if self.im is not None:
|
||||
# only dispose the extent in this frame
|
||||
self.dispose = self._crop(self.im, self.dispose_extent)
|
||||
elif frame_transparency is not None:
|
||||
x0, y0, x1, y1 = self.dispose_extent
|
||||
dispose_size = (x1 - x0, y1 - y0)
|
||||
|
||||
Image._decompression_bomb_check(dispose_size)
|
||||
|
||||
# by convention, attempt to use transparency first
|
||||
dispose_mode = "P"
|
||||
color = frame_transparency
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGBA"
|
||||
color = _rgb(frame_transparency) + (0,)
|
||||
color = self.info.get("transparency", frame_transparency)
|
||||
if color is not None:
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGBA"
|
||||
color = _rgb(color) + (0,)
|
||||
else:
|
||||
color = self.info.get("background", 0)
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGB"
|
||||
color = _rgb(color)
|
||||
self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
# replace with previous contents
|
||||
if self._im is not None:
|
||||
# only dispose the extent in this frame
|
||||
self.dispose = self._crop(self.im, self.dispose_extent)
|
||||
elif frame_transparency is not None:
|
||||
x0, y0, x1, y1 = self.dispose_extent
|
||||
dispose_size = (x1 - x0, y1 - y0)
|
||||
|
||||
Image._decompression_bomb_check(dispose_size)
|
||||
dispose_mode = "P"
|
||||
color = frame_transparency
|
||||
if self.mode in ("RGB", "RGBA"):
|
||||
dispose_mode = "RGBA"
|
||||
color = _rgb(frame_transparency) + (0,)
|
||||
self.dispose = Image.core.fill(
|
||||
dispose_mode, dispose_size, color
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if interlace is not None:
|
||||
transparency = -1
|
||||
@@ -393,7 +415,7 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
elif self.mode not in ("RGB", "RGBA"):
|
||||
transparency = frame_transparency
|
||||
self.tile = [
|
||||
(
|
||||
ImageFile._Tile(
|
||||
"gif",
|
||||
(x0, y0, x1, y1),
|
||||
self.__offset,
|
||||
@@ -409,7 +431,7 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
elif k in self.info:
|
||||
del self.info[k]
|
||||
|
||||
def load_prepare(self):
|
||||
def load_prepare(self) -> None:
|
||||
temp_mode = "P" if self._frame_palette else "L"
|
||||
self._prev_im = None
|
||||
if self.__frame == 0:
|
||||
@@ -421,15 +443,22 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
self._prev_im = self.im
|
||||
if self._frame_palette:
|
||||
self.im = Image.core.fill("P", self.size, self._frame_transparency or 0)
|
||||
self.im.putpalette(*self._frame_palette.getdata())
|
||||
self.im.putpalette("RGB", *self._frame_palette.getdata())
|
||||
else:
|
||||
self.im = None
|
||||
self._im = None
|
||||
if not self._prev_im and self._im is not None and self.size != self.im.size:
|
||||
expanded_im = Image.core.fill(self.im.mode, self.size)
|
||||
if self._frame_palette:
|
||||
expanded_im.putpalette("RGB", *self._frame_palette.getdata())
|
||||
expanded_im.paste(self.im, (0, 0) + self.im.size)
|
||||
|
||||
self.im = expanded_im
|
||||
self._mode = temp_mode
|
||||
self._frame_palette = None
|
||||
|
||||
super().load_prepare()
|
||||
|
||||
def load_end(self):
|
||||
def load_end(self) -> None:
|
||||
if self.__frame == 0:
|
||||
if self.mode == "P" and LOADING_STRATEGY == LoadingStrategy.RGB_ALWAYS:
|
||||
if self._frame_transparency is not None:
|
||||
@@ -441,21 +470,37 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
return
|
||||
if not self._prev_im:
|
||||
return
|
||||
if self.size != self._prev_im.size:
|
||||
if self._frame_transparency is not None:
|
||||
expanded_im = Image.core.fill("RGBA", self.size)
|
||||
else:
|
||||
expanded_im = Image.core.fill("P", self.size)
|
||||
expanded_im.putpalette("RGB", "RGB", self.im.getpalette())
|
||||
expanded_im = expanded_im.convert("RGB")
|
||||
expanded_im.paste(self._prev_im, (0, 0) + self._prev_im.size)
|
||||
|
||||
self._prev_im = expanded_im
|
||||
assert self._prev_im is not None
|
||||
if self._frame_transparency is not None:
|
||||
self.im.putpalettealpha(self._frame_transparency, 0)
|
||||
frame_im = self.im.convert("RGBA")
|
||||
if self.mode == "L":
|
||||
frame_im = self.im.convert_transparent("LA", self._frame_transparency)
|
||||
else:
|
||||
self.im.putpalettealpha(self._frame_transparency, 0)
|
||||
frame_im = self.im.convert("RGBA")
|
||||
else:
|
||||
frame_im = self.im.convert("RGB")
|
||||
|
||||
assert self.dispose_extent is not None
|
||||
frame_im = self._crop(frame_im, self.dispose_extent)
|
||||
|
||||
self.im = self._prev_im
|
||||
self._mode = self.im.mode
|
||||
if frame_im.mode == "RGBA":
|
||||
if frame_im.mode in ("LA", "RGBA"):
|
||||
self.im.paste(frame_im, self.dispose_extent, frame_im)
|
||||
else:
|
||||
self.im.paste(frame_im, self.dispose_extent)
|
||||
|
||||
def tell(self):
|
||||
def tell(self) -> int:
|
||||
return self.__frame
|
||||
|
||||
|
||||
@@ -466,7 +511,7 @@ class GifImageFile(ImageFile.ImageFile):
|
||||
RAWMODE = {"1": "L", "L": "L", "P": "P"}
|
||||
|
||||
|
||||
def _normalize_mode(im):
|
||||
def _normalize_mode(im: Image.Image) -> Image.Image:
|
||||
"""
|
||||
Takes an image (or frame), returns an image in a mode that is appropriate
|
||||
for saving in a Gif.
|
||||
@@ -482,6 +527,7 @@ def _normalize_mode(im):
|
||||
return im
|
||||
if Image.getmodebase(im.mode) == "RGB":
|
||||
im = im.convert("P", palette=Image.Palette.ADAPTIVE)
|
||||
assert im.palette is not None
|
||||
if im.palette.mode == "RGBA":
|
||||
for rgba in im.palette.colors:
|
||||
if rgba[3] == 0:
|
||||
@@ -491,7 +537,12 @@ def _normalize_mode(im):
|
||||
return im.convert("L")
|
||||
|
||||
|
||||
def _normalize_palette(im, palette, info):
|
||||
_Palette = bytes | bytearray | list[int] | ImagePalette.ImagePalette
|
||||
|
||||
|
||||
def _normalize_palette(
|
||||
im: Image.Image, palette: _Palette | None, info: dict[str, Any]
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Normalizes the palette for image.
|
||||
- Sets the palette to the incoming palette, if provided.
|
||||
@@ -513,14 +564,18 @@ def _normalize_palette(im, palette, info):
|
||||
|
||||
if im.mode == "P":
|
||||
if not source_palette:
|
||||
source_palette = im.im.getpalette("RGB")[:768]
|
||||
im_palette = im.getpalette(None)
|
||||
assert im_palette is not None
|
||||
source_palette = bytearray(im_palette)
|
||||
else: # L-mode
|
||||
if not source_palette:
|
||||
source_palette = bytearray(i // 3 for i in range(768))
|
||||
im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette)
|
||||
assert source_palette is not None
|
||||
|
||||
if palette:
|
||||
used_palette_colors = []
|
||||
used_palette_colors: list[int | None] = []
|
||||
assert im.palette is not None
|
||||
for i in range(0, len(source_palette), 3):
|
||||
source_color = tuple(source_palette[i : i + 3])
|
||||
index = im.palette.colors.get(source_color)
|
||||
@@ -533,20 +588,38 @@ def _normalize_palette(im, palette, info):
|
||||
if j not in used_palette_colors:
|
||||
used_palette_colors[i] = j
|
||||
break
|
||||
im = im.remap_palette(used_palette_colors)
|
||||
dest_map: list[int] = []
|
||||
for index in used_palette_colors:
|
||||
assert index is not None
|
||||
dest_map.append(index)
|
||||
im = im.remap_palette(dest_map)
|
||||
else:
|
||||
used_palette_colors = _get_optimize(im, info)
|
||||
if used_palette_colors is not None:
|
||||
return im.remap_palette(used_palette_colors, source_palette)
|
||||
optimized_palette_colors = _get_optimize(im, info)
|
||||
if optimized_palette_colors is not None:
|
||||
im = im.remap_palette(optimized_palette_colors, source_palette)
|
||||
if "transparency" in info:
|
||||
try:
|
||||
info["transparency"] = optimized_palette_colors.index(
|
||||
info["transparency"]
|
||||
)
|
||||
except ValueError:
|
||||
del info["transparency"]
|
||||
return im
|
||||
|
||||
assert im.palette is not None
|
||||
im.palette.palette = source_palette
|
||||
return im
|
||||
|
||||
|
||||
def _write_single_frame(im, fp, palette):
|
||||
def _write_single_frame(
|
||||
im: Image.Image,
|
||||
fp: IO[bytes],
|
||||
palette: _Palette | None,
|
||||
) -> None:
|
||||
im_out = _normalize_mode(im)
|
||||
for k, v in im_out.info.items():
|
||||
im.encoderinfo.setdefault(k, v)
|
||||
if isinstance(k, str):
|
||||
im.encoderinfo.setdefault(k, v)
|
||||
im_out = _normalize_palette(im_out, palette, im.encoderinfo)
|
||||
|
||||
for s in _get_global_header(im_out, im.encoderinfo):
|
||||
@@ -559,26 +632,40 @@ def _write_single_frame(im, fp, palette):
|
||||
_write_local_header(fp, im, (0, 0), flags)
|
||||
|
||||
im_out.encoderconfig = (8, get_interlace(im))
|
||||
ImageFile._save(im_out, fp, [("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])])
|
||||
ImageFile._save(
|
||||
im_out, fp, [ImageFile._Tile("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])]
|
||||
)
|
||||
|
||||
fp.write(b"\0") # end of image data
|
||||
|
||||
|
||||
def _getbbox(base_im, im_frame):
|
||||
if _get_palette_bytes(im_frame) == _get_palette_bytes(base_im):
|
||||
delta = ImageChops.subtract_modulo(im_frame, base_im)
|
||||
else:
|
||||
delta = ImageChops.subtract_modulo(
|
||||
im_frame.convert("RGBA"), base_im.convert("RGBA")
|
||||
)
|
||||
return delta.getbbox(alpha_only=False)
|
||||
def _getbbox(
|
||||
base_im: Image.Image, im_frame: Image.Image
|
||||
) -> tuple[Image.Image, tuple[int, int, int, int] | None]:
|
||||
palette_bytes = [
|
||||
bytes(im.palette.palette) if im.palette else b"" for im in (base_im, im_frame)
|
||||
]
|
||||
if palette_bytes[0] != palette_bytes[1]:
|
||||
im_frame = im_frame.convert("RGBA")
|
||||
base_im = base_im.convert("RGBA")
|
||||
delta = ImageChops.subtract_modulo(im_frame, base_im)
|
||||
return delta, delta.getbbox(alpha_only=False)
|
||||
|
||||
|
||||
def _write_multiple_frames(im, fp, palette):
|
||||
class _Frame(NamedTuple):
|
||||
im: Image.Image
|
||||
bbox: tuple[int, int, int, int] | None
|
||||
encoderinfo: dict[str, Any]
|
||||
|
||||
|
||||
def _write_multiple_frames(
|
||||
im: Image.Image, fp: IO[bytes], palette: _Palette | None
|
||||
) -> bool:
|
||||
duration = im.encoderinfo.get("duration")
|
||||
disposal = im.encoderinfo.get("disposal", im.info.get("disposal"))
|
||||
|
||||
im_frames = []
|
||||
im_frames: list[_Frame] = []
|
||||
previous_im: Image.Image | None = None
|
||||
frame_count = 0
|
||||
background_im = None
|
||||
for imSequence in itertools.chain([im], im.encoderinfo.get("append_images", [])):
|
||||
@@ -589,12 +676,13 @@ def _write_multiple_frames(im, fp, palette):
|
||||
for k, v in im_frame.info.items():
|
||||
if k == "transparency":
|
||||
continue
|
||||
im.encoderinfo.setdefault(k, v)
|
||||
if isinstance(k, str):
|
||||
im.encoderinfo.setdefault(k, v)
|
||||
|
||||
encoderinfo = im.encoderinfo.copy()
|
||||
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
|
||||
if "transparency" in im_frame.info:
|
||||
encoderinfo.setdefault("transparency", im_frame.info["transparency"])
|
||||
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
|
||||
if isinstance(duration, (list, tuple)):
|
||||
encoderinfo["duration"] = duration[frame_count]
|
||||
elif duration is None and "duration" in im_frame.info:
|
||||
@@ -603,63 +691,116 @@ def _write_multiple_frames(im, fp, palette):
|
||||
encoderinfo["disposal"] = disposal[frame_count]
|
||||
frame_count += 1
|
||||
|
||||
if im_frames:
|
||||
diff_frame = None
|
||||
if im_frames and previous_im:
|
||||
# delta frame
|
||||
previous = im_frames[-1]
|
||||
bbox = _getbbox(previous["im"], im_frame)
|
||||
delta, bbox = _getbbox(previous_im, im_frame)
|
||||
if not bbox:
|
||||
# This frame is identical to the previous frame
|
||||
if encoderinfo.get("duration"):
|
||||
previous["encoderinfo"]["duration"] += encoderinfo["duration"]
|
||||
im_frames[-1].encoderinfo["duration"] += encoderinfo["duration"]
|
||||
continue
|
||||
if encoderinfo.get("disposal") == 2:
|
||||
if background_im is None:
|
||||
color = im.encoderinfo.get(
|
||||
"transparency", im.info.get("transparency", (0, 0, 0))
|
||||
)
|
||||
background = _get_background(im_frame, color)
|
||||
background_im = Image.new("P", im_frame.size, background)
|
||||
background_im.putpalette(im_frames[0]["im"].palette)
|
||||
bbox = _getbbox(background_im, im_frame)
|
||||
if im_frames[-1].encoderinfo.get("disposal") == 2:
|
||||
# To appear correctly in viewers using a convention,
|
||||
# only consider transparency, and not background color
|
||||
color = im.encoderinfo.get(
|
||||
"transparency", im.info.get("transparency")
|
||||
)
|
||||
if color is not None:
|
||||
if background_im is None:
|
||||
background = _get_background(im_frame, color)
|
||||
background_im = Image.new("P", im_frame.size, background)
|
||||
first_palette = im_frames[0].im.palette
|
||||
assert first_palette is not None
|
||||
background_im.putpalette(first_palette, first_palette.mode)
|
||||
bbox = _getbbox(background_im, im_frame)[1]
|
||||
else:
|
||||
bbox = (0, 0) + im_frame.size
|
||||
elif encoderinfo.get("optimize") and im_frame.mode != "1":
|
||||
if "transparency" not in encoderinfo:
|
||||
assert im_frame.palette is not None
|
||||
try:
|
||||
encoderinfo["transparency"] = (
|
||||
im_frame.palette._new_color_index(im_frame)
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
if "transparency" in encoderinfo:
|
||||
# When the delta is zero, fill the image with transparency
|
||||
diff_frame = im_frame.copy()
|
||||
fill = Image.new("P", delta.size, encoderinfo["transparency"])
|
||||
if delta.mode == "RGBA":
|
||||
r, g, b, a = delta.split()
|
||||
mask = ImageMath.lambda_eval(
|
||||
lambda args: args["convert"](
|
||||
args["max"](
|
||||
args["max"](
|
||||
args["max"](args["r"], args["g"]), args["b"]
|
||||
),
|
||||
args["a"],
|
||||
)
|
||||
* 255,
|
||||
"1",
|
||||
),
|
||||
r=r,
|
||||
g=g,
|
||||
b=b,
|
||||
a=a,
|
||||
)
|
||||
else:
|
||||
if delta.mode == "P":
|
||||
# Convert to L without considering palette
|
||||
delta_l = Image.new("L", delta.size)
|
||||
delta_l.putdata(delta.getdata())
|
||||
delta = delta_l
|
||||
mask = ImageMath.lambda_eval(
|
||||
lambda args: args["convert"](args["im"] * 255, "1"),
|
||||
im=delta,
|
||||
)
|
||||
diff_frame.paste(fill, mask=ImageOps.invert(mask))
|
||||
else:
|
||||
bbox = None
|
||||
im_frames.append({"im": im_frame, "bbox": bbox, "encoderinfo": encoderinfo})
|
||||
previous_im = im_frame
|
||||
im_frames.append(_Frame(diff_frame or im_frame, bbox, encoderinfo))
|
||||
|
||||
if len(im_frames) > 1:
|
||||
for frame_data in im_frames:
|
||||
im_frame = frame_data["im"]
|
||||
if not frame_data["bbox"]:
|
||||
# global header
|
||||
for s in _get_global_header(im_frame, frame_data["encoderinfo"]):
|
||||
fp.write(s)
|
||||
offset = (0, 0)
|
||||
else:
|
||||
# compress difference
|
||||
if not palette:
|
||||
frame_data["encoderinfo"]["include_color_table"] = True
|
||||
if len(im_frames) == 1:
|
||||
if "duration" in im.encoderinfo:
|
||||
# Since multiple frames will not be written, use the combined duration
|
||||
im.encoderinfo["duration"] = im_frames[0].encoderinfo["duration"]
|
||||
return False
|
||||
|
||||
im_frame = im_frame.crop(frame_data["bbox"])
|
||||
offset = frame_data["bbox"][:2]
|
||||
_write_frame_data(fp, im_frame, offset, frame_data["encoderinfo"])
|
||||
return True
|
||||
elif "duration" in im.encoderinfo and isinstance(
|
||||
im.encoderinfo["duration"], (list, tuple)
|
||||
):
|
||||
# Since multiple frames will not be written, add together the frame durations
|
||||
im.encoderinfo["duration"] = sum(im.encoderinfo["duration"])
|
||||
for frame_data in im_frames:
|
||||
im_frame = frame_data.im
|
||||
if not frame_data.bbox:
|
||||
# global header
|
||||
for s in _get_global_header(im_frame, frame_data.encoderinfo):
|
||||
fp.write(s)
|
||||
offset = (0, 0)
|
||||
else:
|
||||
# compress difference
|
||||
if not palette:
|
||||
frame_data.encoderinfo["include_color_table"] = True
|
||||
|
||||
if frame_data.bbox != (0, 0) + im_frame.size:
|
||||
im_frame = im_frame.crop(frame_data.bbox)
|
||||
offset = frame_data.bbox[:2]
|
||||
_write_frame_data(fp, im_frame, offset, frame_data.encoderinfo)
|
||||
return True
|
||||
|
||||
|
||||
def _save_all(im, fp, filename):
|
||||
def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
_save(im, fp, filename, save_all=True)
|
||||
|
||||
|
||||
def _save(im, fp, filename, save_all=False):
|
||||
def _save(
|
||||
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
|
||||
) -> None:
|
||||
# header
|
||||
if "palette" in im.encoderinfo or "palette" in im.info:
|
||||
palette = im.encoderinfo.get("palette", im.info.get("palette"))
|
||||
else:
|
||||
palette = None
|
||||
im.encoderinfo["optimize"] = im.encoderinfo.get("optimize", True)
|
||||
im.encoderinfo.setdefault("optimize", True)
|
||||
|
||||
if not save_all or not _write_multiple_frames(im, fp, palette):
|
||||
_write_single_frame(im, fp, palette)
|
||||
@@ -670,7 +811,7 @@ def _save(im, fp, filename, save_all=False):
|
||||
fp.flush()
|
||||
|
||||
|
||||
def get_interlace(im):
|
||||
def get_interlace(im: Image.Image) -> int:
|
||||
interlace = im.encoderinfo.get("interlace", 1)
|
||||
|
||||
# workaround for @PIL153
|
||||
@@ -680,23 +821,13 @@ def get_interlace(im):
|
||||
return interlace
|
||||
|
||||
|
||||
def _write_local_header(fp, im, offset, flags):
|
||||
transparent_color_exists = False
|
||||
def _write_local_header(
|
||||
fp: IO[bytes], im: Image.Image, offset: tuple[int, int], flags: int
|
||||
) -> None:
|
||||
try:
|
||||
transparency = int(im.encoderinfo["transparency"])
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
else:
|
||||
# optimize the block away if transparent color is not used
|
||||
transparent_color_exists = True
|
||||
|
||||
used_palette_colors = _get_optimize(im, im.encoderinfo)
|
||||
if used_palette_colors is not None:
|
||||
# adjust the transparency index after optimize
|
||||
try:
|
||||
transparency = used_palette_colors.index(transparency)
|
||||
except ValueError:
|
||||
transparent_color_exists = False
|
||||
transparency = im.encoderinfo["transparency"]
|
||||
except KeyError:
|
||||
transparency = None
|
||||
|
||||
if "duration" in im.encoderinfo:
|
||||
duration = int(im.encoderinfo["duration"] / 10)
|
||||
@@ -705,11 +836,9 @@ def _write_local_header(fp, im, offset, flags):
|
||||
|
||||
disposal = int(im.encoderinfo.get("disposal", 0))
|
||||
|
||||
if transparent_color_exists or duration != 0 or disposal:
|
||||
packed_flag = 1 if transparent_color_exists else 0
|
||||
if transparency is not None or duration != 0 or disposal:
|
||||
packed_flag = 1 if transparency is not None else 0
|
||||
packed_flag |= disposal << 2
|
||||
if not transparent_color_exists:
|
||||
transparency = 0
|
||||
|
||||
fp.write(
|
||||
b"!"
|
||||
@@ -717,7 +846,7 @@ def _write_local_header(fp, im, offset, flags):
|
||||
+ o8(4) # length
|
||||
+ o8(packed_flag) # packed fields
|
||||
+ o16(duration) # duration
|
||||
+ o8(transparency) # transparency index
|
||||
+ o8(transparency or 0) # transparency index
|
||||
+ o8(0)
|
||||
)
|
||||
|
||||
@@ -742,7 +871,7 @@ def _write_local_header(fp, im, offset, flags):
|
||||
fp.write(o8(8)) # bits
|
||||
|
||||
|
||||
def _save_netpbm(im, fp, filename):
|
||||
def _save_netpbm(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
# Unused by default.
|
||||
# To use, uncomment the register_save call at the end of the file.
|
||||
#
|
||||
@@ -773,6 +902,7 @@ def _save_netpbm(im, fp, filename):
|
||||
)
|
||||
|
||||
# Allow ppmquant to receive SIGPIPE if ppmtogif exits
|
||||
assert quant_proc.stdout is not None
|
||||
quant_proc.stdout.close()
|
||||
|
||||
retcode = quant_proc.wait()
|
||||
@@ -794,7 +924,7 @@ def _save_netpbm(im, fp, filename):
|
||||
_FORCE_OPTIMIZE = False
|
||||
|
||||
|
||||
def _get_optimize(im, info):
|
||||
def _get_optimize(im: Image.Image, info: dict[str, Any]) -> list[int] | None:
|
||||
"""
|
||||
Palette optimization is a potentially expensive operation.
|
||||
|
||||
@@ -805,7 +935,7 @@ def _get_optimize(im, info):
|
||||
:param info: encoderinfo
|
||||
:returns: list of indexes of palette entries in use, or None
|
||||
"""
|
||||
if im.mode in ("P", "L") and info and info.get("optimize", 0):
|
||||
if im.mode in ("P", "L") and info and info.get("optimize"):
|
||||
# Potentially expensive operation.
|
||||
|
||||
# The palette saves 3 bytes per color not used, but palette
|
||||
@@ -827,6 +957,7 @@ def _get_optimize(im, info):
|
||||
if optimise or max(used_palette_colors) >= len(used_palette_colors):
|
||||
return used_palette_colors
|
||||
|
||||
assert im.palette is not None
|
||||
num_palette_colors = len(im.palette.palette) // Image.getmodebands(
|
||||
im.palette.mode
|
||||
)
|
||||
@@ -838,9 +969,10 @@ def _get_optimize(im, info):
|
||||
and current_palette_size > 2
|
||||
):
|
||||
return used_palette_colors
|
||||
return None
|
||||
|
||||
|
||||
def _get_color_table_size(palette_bytes):
|
||||
def _get_color_table_size(palette_bytes: bytes) -> int:
|
||||
# calculate the palette size for the header
|
||||
if not palette_bytes:
|
||||
return 0
|
||||
@@ -850,7 +982,7 @@ def _get_color_table_size(palette_bytes):
|
||||
return math.ceil(math.log(len(palette_bytes) // 3, 2)) - 1
|
||||
|
||||
|
||||
def _get_header_palette(palette_bytes):
|
||||
def _get_header_palette(palette_bytes: bytes) -> bytes:
|
||||
"""
|
||||
Returns the palette, null padded to the next power of 2 (*3) bytes
|
||||
suitable for direct inclusion in the GIF header
|
||||
@@ -868,23 +1000,33 @@ def _get_header_palette(palette_bytes):
|
||||
return palette_bytes
|
||||
|
||||
|
||||
def _get_palette_bytes(im):
|
||||
def _get_palette_bytes(im: Image.Image) -> bytes:
|
||||
"""
|
||||
Gets the palette for inclusion in the gif header
|
||||
|
||||
:param im: Image object
|
||||
:returns: Bytes, len<=768 suitable for inclusion in gif header
|
||||
"""
|
||||
return im.palette.palette if im.palette else b""
|
||||
if not im.palette:
|
||||
return b""
|
||||
|
||||
palette = bytes(im.palette.palette)
|
||||
if im.palette.mode == "RGBA":
|
||||
palette = b"".join(palette[i * 4 : i * 4 + 3] for i in range(len(palette) // 3))
|
||||
return palette
|
||||
|
||||
|
||||
def _get_background(im, info_background):
|
||||
def _get_background(
|
||||
im: Image.Image,
|
||||
info_background: int | tuple[int, int, int] | tuple[int, int, int, int] | None,
|
||||
) -> int:
|
||||
background = 0
|
||||
if info_background:
|
||||
if isinstance(info_background, tuple):
|
||||
# WebPImagePlugin stores an RGBA value in info["background"]
|
||||
# So it must be converted to the same format as GifImagePlugin's
|
||||
# info["background"] - a global color table index
|
||||
assert im.palette is not None
|
||||
try:
|
||||
background = im.palette.getcolor(info_background, im)
|
||||
except ValueError as e:
|
||||
@@ -901,7 +1043,7 @@ def _get_background(im, info_background):
|
||||
return background
|
||||
|
||||
|
||||
def _get_global_header(im, info):
|
||||
def _get_global_header(im: Image.Image, info: dict[str, Any]) -> list[bytes]:
|
||||
"""Return a list of strings representing a GIF header"""
|
||||
|
||||
# Header Block
|
||||
@@ -963,7 +1105,12 @@ def _get_global_header(im, info):
|
||||
return header
|
||||
|
||||
|
||||
def _write_frame_data(fp, im_frame, offset, params):
|
||||
def _write_frame_data(
|
||||
fp: IO[bytes],
|
||||
im_frame: Image.Image,
|
||||
offset: tuple[int, int],
|
||||
params: dict[str, Any],
|
||||
) -> None:
|
||||
try:
|
||||
im_frame.encoderinfo = params
|
||||
|
||||
@@ -971,7 +1118,9 @@ def _write_frame_data(fp, im_frame, offset, params):
|
||||
_write_local_header(fp, im_frame, offset, 0)
|
||||
|
||||
ImageFile._save(
|
||||
im_frame, fp, [("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])]
|
||||
im_frame,
|
||||
fp,
|
||||
[ImageFile._Tile("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])],
|
||||
)
|
||||
|
||||
fp.write(b"\0") # end of image data
|
||||
@@ -983,7 +1132,9 @@ def _write_frame_data(fp, im_frame, offset, params):
|
||||
# Legacy GIF utilities
|
||||
|
||||
|
||||
def getheader(im, palette=None, info=None):
|
||||
def getheader(
|
||||
im: Image.Image, palette: _Palette | None = None, info: dict[str, Any] | None = None
|
||||
) -> tuple[list[bytes], list[int] | None]:
|
||||
"""
|
||||
Legacy Method to get Gif data from image.
|
||||
|
||||
@@ -995,11 +1146,11 @@ def getheader(im, palette=None, info=None):
|
||||
:returns: tuple of(list of header items, optimized palette)
|
||||
|
||||
"""
|
||||
used_palette_colors = _get_optimize(im, info)
|
||||
|
||||
if info is None:
|
||||
info = {}
|
||||
|
||||
used_palette_colors = _get_optimize(im, info)
|
||||
|
||||
if "background" not in info and "background" in im.info:
|
||||
info["background"] = im.info["background"]
|
||||
|
||||
@@ -1011,7 +1162,9 @@ def getheader(im, palette=None, info=None):
|
||||
return header, used_palette_colors
|
||||
|
||||
|
||||
def getdata(im, offset=(0, 0), **params):
|
||||
def getdata(
|
||||
im: Image.Image, offset: tuple[int, int] = (0, 0), **params: Any
|
||||
) -> list[bytes]:
|
||||
"""
|
||||
Legacy Method
|
||||
|
||||
@@ -1028,12 +1181,14 @@ def getdata(im, offset=(0, 0), **params):
|
||||
:returns: List of bytes containing GIF encoded frame data
|
||||
|
||||
"""
|
||||
from io import BytesIO
|
||||
|
||||
class Collector:
|
||||
class Collector(BytesIO):
|
||||
data = []
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: Buffer) -> int:
|
||||
self.data.append(data)
|
||||
return len(data)
|
||||
|
||||
im.load() # make sure raster data is available
|
||||
|
||||
|
||||
@@ -18,17 +18,22 @@ Stuff to translate curve segments to palette values (derived from
|
||||
the corresponding code in GIMP, written by Federico Mena Quintero.
|
||||
See the GIMP distribution for more information.)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from math import log, pi, sin, sqrt
|
||||
|
||||
from ._binary import o8
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from typing import IO
|
||||
|
||||
EPSILON = 1e-10
|
||||
"""""" # Enable auto-doc for data member
|
||||
|
||||
|
||||
def linear(middle, pos):
|
||||
def linear(middle: float, pos: float) -> float:
|
||||
if pos <= middle:
|
||||
if middle < EPSILON:
|
||||
return 0.0
|
||||
@@ -43,19 +48,19 @@ def linear(middle, pos):
|
||||
return 0.5 + 0.5 * pos / middle
|
||||
|
||||
|
||||
def curved(middle, pos):
|
||||
def curved(middle: float, pos: float) -> float:
|
||||
return pos ** (log(0.5) / log(max(middle, EPSILON)))
|
||||
|
||||
|
||||
def sine(middle, pos):
|
||||
def sine(middle: float, pos: float) -> float:
|
||||
return (sin((-pi / 2.0) + pi * linear(middle, pos)) + 1.0) / 2.0
|
||||
|
||||
|
||||
def sphere_increasing(middle, pos):
|
||||
def sphere_increasing(middle: float, pos: float) -> float:
|
||||
return sqrt(1.0 - (linear(middle, pos) - 1.0) ** 2)
|
||||
|
||||
|
||||
def sphere_decreasing(middle, pos):
|
||||
def sphere_decreasing(middle: float, pos: float) -> float:
|
||||
return 1.0 - sqrt(1.0 - linear(middle, pos) ** 2)
|
||||
|
||||
|
||||
@@ -64,9 +69,22 @@ SEGMENTS = [linear, curved, sine, sphere_increasing, sphere_decreasing]
|
||||
|
||||
|
||||
class GradientFile:
|
||||
gradient = None
|
||||
gradient: (
|
||||
list[
|
||||
tuple[
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
list[float],
|
||||
list[float],
|
||||
Callable[[float, float], float],
|
||||
]
|
||||
]
|
||||
| None
|
||||
) = None
|
||||
|
||||
def getpalette(self, entries=256):
|
||||
def getpalette(self, entries: int = 256) -> tuple[bytes, str]:
|
||||
assert self.gradient is not None
|
||||
palette = []
|
||||
|
||||
ix = 0
|
||||
@@ -101,8 +119,8 @@ class GradientFile:
|
||||
class GimpGradientFile(GradientFile):
|
||||
"""File handler for GIMP's gradient format."""
|
||||
|
||||
def __init__(self, fp):
|
||||
if fp.readline()[:13] != b"GIMP Gradient":
|
||||
def __init__(self, fp: IO[bytes]) -> None:
|
||||
if not fp.readline().startswith(b"GIMP Gradient"):
|
||||
msg = "not a GIMP gradient file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
@@ -114,7 +132,7 @@ class GimpGradientFile(GradientFile):
|
||||
|
||||
count = int(line)
|
||||
|
||||
gradient = []
|
||||
self.gradient = []
|
||||
|
||||
for i in range(count):
|
||||
s = fp.readline().split()
|
||||
@@ -132,6 +150,4 @@ class GimpGradientFile(GradientFile):
|
||||
msg = "cannot handle HSV colour space"
|
||||
raise OSError(msg)
|
||||
|
||||
gradient.append((x0, x1, xm, rgb0, rgb1, segment))
|
||||
|
||||
self.gradient = gradient
|
||||
self.gradient.append((x0, x1, xm, rgb0, rgb1, segment))
|
||||
|
||||
@@ -13,10 +13,14 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
from ._binary import o8
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import IO
|
||||
|
||||
|
||||
class GimpPaletteFile:
|
||||
@@ -24,14 +28,18 @@ class GimpPaletteFile:
|
||||
|
||||
rawmode = "RGB"
|
||||
|
||||
def __init__(self, fp):
|
||||
self.palette = [o8(i) * 3 for i in range(256)]
|
||||
|
||||
if fp.readline()[:12] != b"GIMP Palette":
|
||||
def _read(self, fp: IO[bytes], limit: bool = True) -> None:
|
||||
if not fp.readline().startswith(b"GIMP Palette"):
|
||||
msg = "not a GIMP palette file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
for i in range(256):
|
||||
palette: list[int] = []
|
||||
i = 0
|
||||
while True:
|
||||
if limit and i == 256 + 3:
|
||||
break
|
||||
|
||||
i += 1
|
||||
s = fp.readline()
|
||||
if not s:
|
||||
break
|
||||
@@ -39,18 +47,29 @@ class GimpPaletteFile:
|
||||
# skip fields and comment lines
|
||||
if re.match(rb"\w+:|#", s):
|
||||
continue
|
||||
if len(s) > 100:
|
||||
if limit and len(s) > 100:
|
||||
msg = "bad palette file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
v = tuple(map(int, s.split()[:3]))
|
||||
if len(v) != 3:
|
||||
v = s.split(maxsplit=3)
|
||||
if len(v) < 3:
|
||||
msg = "bad palette entry"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.palette[i] = o8(v[0]) + o8(v[1]) + o8(v[2])
|
||||
palette += (int(v[i]) for i in range(3))
|
||||
if limit and len(palette) == 768:
|
||||
break
|
||||
|
||||
self.palette = b"".join(self.palette)
|
||||
self.palette = bytes(palette)
|
||||
|
||||
def getpalette(self):
|
||||
def __init__(self, fp: IO[bytes]) -> None:
|
||||
self._read(fp)
|
||||
|
||||
@classmethod
|
||||
def frombytes(cls, data: bytes) -> GimpPaletteFile:
|
||||
self = cls.__new__(cls)
|
||||
self._read(BytesIO(data), False)
|
||||
return self
|
||||
|
||||
def getpalette(self) -> tuple[bytes, str]:
|
||||
return self.palette, self.rawmode
|
||||
|
||||
@@ -8,13 +8,17 @@
|
||||
#
|
||||
# See the README file for information on usage and redistribution.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import IO
|
||||
|
||||
from . import Image, ImageFile
|
||||
|
||||
_handler = None
|
||||
|
||||
|
||||
def register_handler(handler):
|
||||
def register_handler(handler: ImageFile.StubHandler | None) -> None:
|
||||
"""
|
||||
Install application-specific GRIB image handler.
|
||||
|
||||
@@ -28,22 +32,20 @@ def register_handler(handler):
|
||||
# Image adapter
|
||||
|
||||
|
||||
def _accept(prefix):
|
||||
return prefix[:4] == b"GRIB" and prefix[7] == 1
|
||||
def _accept(prefix: bytes) -> bool:
|
||||
return len(prefix) >= 8 and prefix.startswith(b"GRIB") and prefix[7] == 1
|
||||
|
||||
|
||||
class GribStubImageFile(ImageFile.StubImageFile):
|
||||
format = "GRIB"
|
||||
format_description = "GRIB"
|
||||
|
||||
def _open(self):
|
||||
offset = self.fp.tell()
|
||||
|
||||
def _open(self) -> None:
|
||||
if not _accept(self.fp.read(8)):
|
||||
msg = "Not a GRIB file"
|
||||
raise SyntaxError(msg)
|
||||
|
||||
self.fp.seek(offset)
|
||||
self.fp.seek(-8, os.SEEK_CUR)
|
||||
|
||||
# make something up
|
||||
self._mode = "F"
|
||||
@@ -53,11 +55,11 @@ class GribStubImageFile(ImageFile.StubImageFile):
|
||||
if loader:
|
||||
loader.open(self)
|
||||
|
||||
def _load(self):
|
||||
def _load(self) -> ImageFile.StubHandler | None:
|
||||
return _handler
|
||||
|
||||
|
||||
def _save(im, fp, filename):
|
||||
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
|
||||
if _handler is None or not hasattr(_handler, "save"):
|
||||
msg = "GRIB save handler not installed"
|
||||
raise OSError(msg)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user