Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,50 @@
"""add_anonymous_gdpr_support
Revision ID: 6f7f8689fc98
Revises: 7a899ef55e3b
Create Date: 2025-12-01 04:15:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6f7f8689fc98'
down_revision = '7a899ef55e3b'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Update gdpr_requests table to support anonymous users
op.alter_column('gdpr_requests', 'user_id',
existing_type=sa.Integer(),
nullable=True)
op.add_column('gdpr_requests', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
op.create_index(op.f('ix_gdpr_requests_is_anonymous'), 'gdpr_requests', ['is_anonymous'], unique=False)
# Update consents table to support anonymous users
op.alter_column('consents', 'user_id',
existing_type=sa.Integer(),
nullable=True)
op.add_column('consents', sa.Column('user_email', sa.String(length=255), nullable=True))
op.add_column('consents', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
op.create_index(op.f('ix_consents_user_email'), 'consents', ['user_email'], unique=False)
op.create_index(op.f('ix_consents_is_anonymous'), 'consents', ['is_anonymous'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_consents_is_anonymous'), table_name='consents')
op.drop_index(op.f('ix_consents_user_email'), table_name='consents')
op.drop_column('consents', 'is_anonymous')
op.drop_column('consents', 'user_email')
op.alter_column('consents', 'user_id',
existing_type=sa.Integer(),
nullable=False)
op.drop_index(op.f('ix_gdpr_requests_is_anonymous'), table_name='gdpr_requests')
op.drop_column('gdpr_requests', 'is_anonymous')
op.alter_column('gdpr_requests', 'user_id',
existing_type=sa.Integer(),
nullable=False)
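
For reference, a minimal sketch of driving this revision programmatically, assuming a standard Alembic setup with an alembic.ini at the Backend root (the CLI equivalents are 'alembic upgrade 6f7f8689fc98' and 'alembic downgrade 7a899ef55e3b'):

from alembic.config import Config
from alembic import command

cfg = Config('alembic.ini')  # path is an assumption; adjust to the project layout
command.upgrade(cfg, '6f7f8689fc98')      # apply the anonymous-GDPR columns above
# command.downgrade(cfg, '7a899ef55e3b')  # roll back to the previous revision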

View File

@@ -0,0 +1,173 @@
"""add_comprehensive_gdpr_tables
Revision ID: 7a899ef55e3b
Revises: dbafe747c931
Create Date: 2025-12-01 04:10:25.699589
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7a899ef55e3b'
down_revision = 'dbafe747c931'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Consent table
op.create_table(
'consents',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('consent_type', sa.Enum('marketing', 'analytics', 'necessary', 'preferences', 'third_party_sharing', 'profiling', 'automated_decision_making', name='consenttype'), nullable=False),
sa.Column('status', sa.Enum('granted', 'withdrawn', 'pending', 'expired', name='consentstatus'), nullable=False),
sa.Column('granted_at', sa.DateTime(), nullable=True),
sa.Column('withdrawn_at', sa.DateTime(), nullable=True),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('legal_basis', sa.String(length=100), nullable=True),
sa.Column('consent_method', sa.String(length=50), nullable=True),
sa.Column('consent_version', sa.String(length=20), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=255), nullable=True),
sa.Column('source', sa.String(length=100), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_consents_id'), 'consents', ['id'], unique=False)
op.create_index(op.f('ix_consents_user_id'), 'consents', ['user_id'], unique=False)
op.create_index(op.f('ix_consents_consent_type'), 'consents', ['consent_type'], unique=False)
op.create_index(op.f('ix_consents_status'), 'consents', ['status'], unique=False)
op.create_index(op.f('ix_consents_created_at'), 'consents', ['created_at'], unique=False)
# Data processing records table
op.create_table(
'data_processing_records',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('processing_category', sa.Enum('collection', 'storage', 'usage', 'sharing', 'deletion', 'anonymization', 'transfer', name='processingcategory'), nullable=False),
sa.Column('legal_basis', sa.Enum('consent', 'contract', 'legal_obligation', 'vital_interests', 'public_task', 'legitimate_interests', name='legalbasis'), nullable=False),
sa.Column('purpose', sa.Text(), nullable=False),
sa.Column('data_categories', sa.JSON(), nullable=True),
sa.Column('data_subjects', sa.JSON(), nullable=True),
sa.Column('recipients', sa.JSON(), nullable=True),
sa.Column('third_parties', sa.JSON(), nullable=True),
sa.Column('transfers_to_third_countries', sa.Boolean(), nullable=False),
sa.Column('transfer_countries', sa.JSON(), nullable=True),
sa.Column('safeguards', sa.Text(), nullable=True),
sa.Column('retention_period', sa.String(length=100), nullable=True),
sa.Column('retention_criteria', sa.Text(), nullable=True),
sa.Column('security_measures', sa.Text(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('related_booking_id', sa.Integer(), nullable=True),
sa.Column('related_payment_id', sa.Integer(), nullable=True),
sa.Column('processed_by', sa.Integer(), nullable=True),
sa.Column('processing_timestamp', sa.DateTime(), nullable=False),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['processed_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_processing_records_id'), 'data_processing_records', ['id'], unique=False)
op.create_index(op.f('ix_data_processing_records_processing_category'), 'data_processing_records', ['processing_category'], unique=False)
op.create_index(op.f('ix_data_processing_records_legal_basis'), 'data_processing_records', ['legal_basis'], unique=False)
op.create_index(op.f('ix_data_processing_records_user_id'), 'data_processing_records', ['user_id'], unique=False)
op.create_index(op.f('ix_data_processing_records_processing_timestamp'), 'data_processing_records', ['processing_timestamp'], unique=False)
op.create_index(op.f('ix_data_processing_records_created_at'), 'data_processing_records', ['created_at'], unique=False)
# Data breaches table
op.create_table(
'data_breaches',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('breach_type', sa.Enum('confidentiality', 'integrity', 'availability', name='breachtype'), nullable=False),
sa.Column('status', sa.Enum('detected', 'investigating', 'contained', 'reported_to_authority', 'notified_data_subjects', 'resolved', name='breachstatus'), nullable=False),
sa.Column('description', sa.Text(), nullable=False),
sa.Column('affected_data_categories', sa.JSON(), nullable=True),
sa.Column('affected_data_subjects', sa.JSON(), nullable=True),
sa.Column('detected_at', sa.DateTime(), nullable=False),
sa.Column('occurred_at', sa.DateTime(), nullable=True),
sa.Column('contained_at', sa.DateTime(), nullable=True),
sa.Column('reported_to_authority_at', sa.DateTime(), nullable=True),
sa.Column('authority_reference', sa.String(length=255), nullable=True),
sa.Column('notified_data_subjects_at', sa.DateTime(), nullable=True),
sa.Column('notification_method', sa.String(length=100), nullable=True),
sa.Column('likely_consequences', sa.Text(), nullable=True),
sa.Column('measures_proposed', sa.Text(), nullable=True),
sa.Column('risk_level', sa.String(length=20), nullable=True),
sa.Column('reported_by', sa.Integer(), nullable=False),
sa.Column('investigated_by', sa.Integer(), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['investigated_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['reported_by'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_breaches_id'), 'data_breaches', ['id'], unique=False)
op.create_index(op.f('ix_data_breaches_breach_type'), 'data_breaches', ['breach_type'], unique=False)
op.create_index(op.f('ix_data_breaches_status'), 'data_breaches', ['status'], unique=False)
op.create_index(op.f('ix_data_breaches_detected_at'), 'data_breaches', ['detected_at'], unique=False)
# Retention rules table
op.create_table(
'retention_rules',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('data_category', sa.String(length=100), nullable=False),
sa.Column('retention_period_days', sa.Integer(), nullable=False),
sa.Column('retention_period_months', sa.Integer(), nullable=True),
sa.Column('retention_period_years', sa.Integer(), nullable=True),
sa.Column('legal_basis', sa.Text(), nullable=True),
sa.Column('legal_requirement', sa.Text(), nullable=True),
sa.Column('action_after_retention', sa.String(length=50), nullable=False),
sa.Column('conditions', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('data_category')
)
op.create_index(op.f('ix_retention_rules_id'), 'retention_rules', ['id'], unique=False)
op.create_index(op.f('ix_retention_rules_data_category'), 'retention_rules', ['data_category'], unique=True)
op.create_index(op.f('ix_retention_rules_is_active'), 'retention_rules', ['is_active'], unique=False)
# Data retention logs table
op.create_table(
'data_retention_logs',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('retention_rule_id', sa.Integer(), nullable=False),
sa.Column('data_category', sa.String(length=100), nullable=False),
sa.Column('action_taken', sa.String(length=50), nullable=False),
sa.Column('records_affected', sa.Integer(), nullable=False),
sa.Column('affected_ids', sa.JSON(), nullable=True),
sa.Column('executed_by', sa.Integer(), nullable=True),
sa.Column('executed_at', sa.DateTime(), nullable=False),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(['executed_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['retention_rule_id'], ['retention_rules.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_retention_logs_id'), 'data_retention_logs', ['id'], unique=False)
op.create_index(op.f('ix_data_retention_logs_retention_rule_id'), 'data_retention_logs', ['retention_rule_id'], unique=False)
op.create_index(op.f('ix_data_retention_logs_data_category'), 'data_retention_logs', ['data_category'], unique=False)
op.create_index(op.f('ix_data_retention_logs_executed_at'), 'data_retention_logs', ['executed_at'], unique=False)
def downgrade() -> None:
    # Drop tables in reverse dependency order; each table's indexes and foreign keys are dropped with it
op.drop_table('data_retention_logs')
op.drop_table('retention_rules')
op.drop_table('data_breaches')
op.drop_table('data_processing_records')
op.drop_table('consents')
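
The tables above start empty. As an illustrative sketch (not part of this commit; the category name and period are hypothetical), a follow-up Alembic revision could seed a default retention rule like so:

from alembic import op
import sqlalchemy as sa
from datetime import datetime

def upgrade() -> None:
    retention_rules = sa.table(
        'retention_rules',
        sa.column('data_category', sa.String),
        sa.column('retention_period_days', sa.Integer),
        sa.column('action_after_retention', sa.String),
        sa.column('is_active', sa.Boolean),
        sa.column('created_at', sa.DateTime),
        sa.column('updated_at', sa.DateTime),
    )
    now = datetime.utcnow()
    op.bulk_insert(retention_rules, [{
        'data_category': 'user_data',      # hypothetical category name
        'retention_period_days': 3 * 365,  # e.g. three years
        'action_after_retention': 'anonymize',
        'is_active': True,
        'created_at': now,
        'updated_at': now,
    }])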

View File

@@ -1,26 +1,28 @@
-fastapi==0.104.1
+fastapi==0.123.0
 uvicorn[standard]==0.24.0
 python-dotenv==1.0.0
 sqlalchemy==2.0.23
-pymysql==1.1.0
+pymysql==1.1.2
 cryptography>=41.0.7
-python-jose[cryptography]==3.3.0
+python-jose[cryptography]==3.5.0
 bcrypt==4.1.2
-python-multipart==0.0.6
+python-multipart==0.0.20
 aiofiles==23.2.1
 email-validator==2.1.0
 pydantic==2.5.0
 pydantic-settings==2.1.0
 slowapi==0.1.9
-pillow==10.1.0
+pillow==12.0.0
 aiosmtplib==3.0.1
-jinja2==3.1.2
+jinja2==3.1.6
 alembic==1.12.1
 stripe>=13.2.0
 paypal-checkout-serversdk>=1.0.3
 pyotp==2.9.0
 qrcode[pil]==7.4.2
-httpx==0.25.2
+httpx==0.28.1
+httpcore==1.0.9
+h11==0.16.0
 cryptography>=41.0.7
 bleach==6.1.0

View File

@@ -1,13 +1,38 @@
 import uvicorn
+import signal
+import sys
 from src.shared.config.settings import settings
 from src.shared.config.logging_config import setup_logging, get_logger
 setup_logging()
 logger = get_logger(__name__)
+def signal_handler(sig, frame):
+    """Handle Ctrl+C gracefully."""
+    logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
+    sys.exit(0)
 if __name__ == '__main__':
+    # Register signal handler for graceful shutdown on Ctrl+C
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
     logger.info(f'Starting {settings.APP_NAME} on {settings.HOST}:{settings.PORT}')
     import os
     from pathlib import Path
     base_dir = Path(__file__).parent
     src_dir = str(base_dir / 'src')
-    use_reload = False
-    uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=use_reload, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if use_reload else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=1.0)
+    # Enable hot reload in development mode or if explicitly enabled via environment variable
+    use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
+    if use_reload:
+        logger.info('Hot reload enabled - server will restart on code changes')
+    logger.info('Press Ctrl+C to stop the server')
+    uvicorn.run(
+        'src.main:app',
+        host=settings.HOST,
+        port=settings.PORT,
+        reload=use_reload,
+        log_level=settings.LOG_LEVEL.lower(),
+        reload_dirs=[src_dir] if use_reload else None,
+        reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
+        reload_delay=0.5
+    )

View File

@@ -7,6 +7,7 @@ import uuid
 import os
 from ...shared.config.database import get_db
 from ..services.auth_service import auth_service
+from ..services.session_service import session_service
 from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
 from ...security.middleware.auth import get_current_user
 from ..models.user import User
@@ -85,6 +86,26 @@ async def register(
         path='/'
     )
+    # Create user session for new registration
+    try:
+        # Extract device info from user agent
+        device_info = None
+        if user_agent:
+            device_info = {'user_agent': user_agent}
+        session_service.create_session(
+            db=db,
+            user_id=result['user']['id'],
+            ip_address=client_ip,
+            user_agent=user_agent,
+            device_info=str(device_info) if device_info else None
+        )
+    except Exception as e:
+        # Log error but don't fail registration if session creation fails
+        from ...shared.config.logging_config import get_logger
+        logger = get_logger(__name__)
+        logger.warning(f'Failed to create session during registration: {str(e)}')
     # Log successful registration
     await audit_service.log_action(
         db=db,
@@ -171,6 +192,26 @@ async def login(
         path='/'
     )
+    # Create user session
+    try:
+        # Extract device info from user agent
+        device_info = None
+        if user_agent:
+            device_info = {'user_agent': user_agent}
+        session_service.create_session(
+            db=db,
+            user_id=result['user']['id'],
+            ip_address=client_ip,
+            user_agent=user_agent,
+            device_info=str(device_info) if device_info else None
+        )
+    except Exception as e:
+        # Log error but don't fail login if session creation fails
+        from ...shared.config.logging_config import get_logger
+        logger = get_logger(__name__)
+        logger.warning(f'Failed to create session during login: {str(e)}')
     # Log successful login
     await audit_service.log_action(
         db=db,
@@ -394,16 +435,23 @@ async def upload_avatar(request: Request, image: UploadFile=File(...), current_u
     # Validate file completely (MIME type, size, magic bytes, integrity)
     content = await validate_uploaded_image(image, max_avatar_size)
-    upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'avatars'
+    # Use same path calculation as main.py: go from Backend/src/auth/routes/auth_routes.py
+    # to Backend/uploads/avatars
+    upload_dir = Path(__file__).parent.parent.parent.parent / 'uploads' / 'avatars'
     upload_dir.mkdir(parents=True, exist_ok=True)
     if current_user.avatar:
-        old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
+        old_avatar_path = Path(__file__).parent.parent.parent.parent / current_user.avatar.lstrip('/')
         if old_avatar_path.exists() and old_avatar_path.is_file():
             try:
                 old_avatar_path.unlink()
             except Exception:
                 pass
-    ext = Path(image.filename).suffix or '.png'
+    # Sanitize filename to prevent path traversal attacks
+    from ...shared.utils.sanitization import sanitize_filename
+    original_filename = image.filename or 'avatar.png'
+    sanitized_filename = sanitize_filename(original_filename)
+    ext = Path(sanitized_filename).suffix or '.png'
+    # Generate secure filename with user ID and UUID to prevent collisions
     filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
     file_path = upload_dir / filename
     async with aiofiles.open(file_path, 'wb') as f:
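
The sanitize_filename helper lives in shared.utils.sanitization, which is not part of this diff. As a rough illustration only (an assumption, not the project's actual implementation), such a helper typically reduces the name to its final path component and a conservative character set:

import os
import re

def sanitize_filename(filename: str) -> str:
    # Keep only the last path component to defeat '../' traversal (assumption)
    name = os.path.basename(filename.replace('\\', '/'))
    # Replace anything outside a conservative whitelist
    name = re.sub(r'[^A-Za-z0-9._-]', '_', name)
    return name or 'upload'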

View File

@@ -1,14 +1,17 @@
""" """
User session management routes. User session management routes.
""" """
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request, Response, Cookie
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ...shared.config.database import get_db from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ...security.middleware.auth import get_current_user from ...security.middleware.auth import get_current_user
from ...auth.models.user import User from ...auth.models.user import User
from ...auth.models.user_session import UserSession
from ...auth.services.session_service import session_service from ...auth.services.session_service import session_service
from ...shared.utils.response_helpers import success_response from ...shared.utils.response_helpers import success_response
from jose import jwt
logger = get_logger(__name__) logger = get_logger(__name__)
router = APIRouter(prefix='/sessions', tags=['sessions']) router = APIRouter(prefix='/sessions', tags=['sessions'])
@@ -44,13 +47,15 @@ async def get_my_sessions(
 @router.delete('/{session_id}')
 async def revoke_session(
     session_id: int,
+    request: Request,
+    response: Response,
     current_user: User = Depends(get_current_user),
+    access_token: str = Cookie(None, alias='accessToken'),
     db: Session = Depends(get_db)
 ):
     """Revoke a specific session."""
     try:
         # Verify session belongs to user
-        from ...auth.models.user_session import UserSession
         session = db.query(UserSession).filter(
             UserSession.id == session_id,
             UserSession.user_id == current_user.id
@@ -59,10 +64,62 @@ async def revoke_session(
         if not session:
             raise HTTPException(status_code=404, detail='Session not found')
+        # Check if this is the current session being revoked.
+        # We detect this by checking that:
+        # 1. the session IP matches the request IP (if available), and
+        # 2. the session is the most recent active session.
+        is_current_session = False
+        try:
+            client_ip = request.client.host if request.client else None
+            user_agent = request.headers.get('User-Agent', '')
+            # Check if session matches current request characteristics
+            if client_ip and session.ip_address == client_ip:
+                # Also check if it's the most recent session
+                recent_session = db.query(UserSession).filter(
+                    UserSession.user_id == current_user.id,
+                    UserSession.is_active == True
+                ).order_by(UserSession.last_activity.desc()).first()
+                if recent_session and recent_session.id == session_id:
+                    is_current_session = True
+        except Exception as e:
+            logger.warning(f'Could not determine if session is current: {str(e)}')
+            # If we can't determine, treat the only remaining active session as current
+            active_sessions_count = db.query(UserSession).filter(
+                UserSession.user_id == current_user.id,
+                UserSession.is_active == True
+            ).count()
+            if active_sessions_count <= 1:
+                is_current_session = True
         success = session_service.revoke_session(db=db, session_token=session.session_token)
         if not success:
             raise HTTPException(status_code=404, detail='Session not found')
+        # If this was the current session, clear cookies and indicate logout needed
+        if is_current_session:
+            samesite_value = 'strict' if settings.is_production else 'lax'
+            # Clear access token cookie
+            response.delete_cookie(
+                key='accessToken',
+                path='/',
+                samesite=samesite_value,
+                secure=settings.is_production
+            )
+            # Clear refresh token cookie
+            response.delete_cookie(
+                key='refreshToken',
+                path='/',
+                samesite=samesite_value,
+                secure=settings.is_production
+            )
+            return success_response(
+                message='Session revoked successfully. You have been logged out.',
+                data={'logout_required': True}
+            )
         return success_response(message='Session revoked successfully')
     except HTTPException:
         raise
@@ -72,19 +129,41 @@ async def revoke_session(
 @router.post('/revoke-all')
 async def revoke_all_sessions(
+    request: Request,
+    response: Response,
     current_user: User = Depends(get_current_user),
+    access_token: str = Cookie(None, alias='accessToken'),
     db: Session = Depends(get_db)
 ):
     """Revoke all sessions for current user."""
     try:
         count = session_service.revoke_all_user_sessions(
             db=db,
-            user_id=current_user.id
+            user_id=current_user.id,
+            exclude_token=None  # Don't exclude the current session; revoke all
         )
+        # Clear cookies since all sessions (including the current one) are revoked
+        samesite_value = 'strict' if settings.is_production else 'lax'
+        # Clear access token cookie
+        response.delete_cookie(
+            key='accessToken',
+            path='/',
+            samesite=samesite_value,
+            secure=settings.is_production
+        )
+        # Clear refresh token cookie
+        response.delete_cookie(
+            key='refreshToken',
+            path='/',
+            samesite=samesite_value,
+            secure=settings.is_production
+        )
         return success_response(
-            data={'revoked_count': count},
-            message=f'Revoked {count} session(s)'
+            data={'revoked_count': count, 'logout_required': True},
+            message=f'Revoked {count} session(s). You have been logged out.'
         )
     except Exception as e:
         logger.error(f'Error revoking all sessions: {str(e)}', exc_info=True)
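
A minimal client-side sketch of the new logout_required contract (using httpx from requirements; the base URL and the 'data' envelope produced by success_response are assumptions):

import httpx

def revoke_session(session_id: int, cookies: dict) -> bool:
    # Hypothetical local URL; adjust host/prefix to your deployment
    resp = httpx.delete(f'http://localhost:8000/sessions/{session_id}', cookies=cookies)
    resp.raise_for_status()
    payload = resp.json()
    # True when the server revoked the caller's own session and cleared its cookies
    return bool(payload.get('data', {}).get('logout_required'))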

View File

@@ -29,19 +29,13 @@ class AuthService:
         if not self.jwt_secret:
             error_msg = (
                 'CRITICAL: JWT_SECRET is not configured. '
-                'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
+                'Please set JWT_SECRET environment variable to a secure random string (minimum 64 characters). '
+                'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
             )
             logger.error(error_msg)
-            if settings.is_production:
-                raise ValueError(error_msg)
-            else:
-                # In development, generate a secure secret but warn
-                import secrets
-                self.jwt_secret = secrets.token_urlsafe(64)
-                logger.warning(
-                    f'JWT_SECRET not configured. Auto-generated secret for development. '
-                    f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
-                )
+            # SECURITY: Always fail if JWT_SECRET is not configured, even in development
+            # This prevents accidental deployment without proper secrets
+            raise ValueError(error_msg)

         # Validate JWT secret strength
         if len(self.jwt_secret) < 32:
@@ -65,14 +59,37 @@ class AuthService:
         self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")

     def generate_tokens(self, user_id: int) -> dict:
+        from datetime import datetime, timedelta
+        # SECURITY: Add standard JWT claims for better security
+        now = datetime.utcnow()
+        access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
+        refresh_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
+        access_payload = {
+            "userId": user_id,
+            "exp": access_expires,  # Expiration time
+            "iat": now,  # Issued at
+            "iss": settings.APP_NAME,  # Issuer
+            "type": "access"  # Token type
+        }
+        refresh_payload = {
+            "userId": user_id,
+            "exp": refresh_expires,  # Expiration time
+            "iat": now,  # Issued at
+            "iss": settings.APP_NAME,  # Issuer
+            "type": "refresh"  # Token type
+        }
         access_token = jwt.encode(
-            {"userId": user_id},
+            access_payload,
             self.jwt_secret,
             algorithm="HS256"
         )
         refresh_token = jwt.encode(
-            {"userId": user_id},
+            refresh_payload,
             self.jwt_refresh_secret,
             algorithm="HS256"
         )
@@ -316,8 +333,22 @@ class AuthService:
             db.commit()
             raise ValueError("Refresh token expired")
+        from datetime import datetime, timedelta
+        # SECURITY: Add standard JWT claims when refreshing token
+        now = datetime.utcnow()
+        access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
+        access_payload = {
+            "userId": decoded["userId"],
+            "exp": access_expires,  # Expiration time
+            "iat": now,  # Issued at
+            "iss": settings.APP_NAME,  # Issuer
+            "type": "access"  # Token type
+        }
         access_token = jwt.encode(
-            {"userId": decoded["userId"]},
+            access_payload,
             self.jwt_secret,
             algorithm="HS256"
         )
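
The added claims only improve security if they are enforced at decode time. A minimal verification sketch with python-jose (which is in requirements; the function below is illustrative, not the service's actual decode path):

from jose import jwt, JWTError

def decode_access_token(token: str, secret: str, issuer: str) -> dict:
    try:
        # python-jose verifies exp automatically and checks iss via the issuer kwarg
        payload = jwt.decode(token, secret, algorithms=['HS256'], issuer=issuer)
    except JWTError as exc:
        raise ValueError(f'Invalid token: {exc}')
    if payload.get('type') != 'access':
        # Reject refresh tokens presented as access tokens
        raise ValueError('Wrong token type')
    return payload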

View File

@@ -4,7 +4,7 @@ from sqlalchemy import and_, or_, func
 from sqlalchemy.exc import IntegrityError
 from typing import Optional
 from datetime import datetime
-import random
+import secrets
 import os
 from ...shared.config.database import get_db
 from ...shared.config.settings import settings
@@ -37,7 +37,8 @@ def _generate_invoice_email_html(invoice: dict, is_proforma: bool=False) -> str:
 def generate_booking_number() -> str:
     prefix = 'BK'
     ts = int(datetime.utcnow().timestamp() * 1000)
-    rand = random.randint(1000, 9999)
+    # Use cryptographically secure random number to prevent enumeration attacks
+    rand = secrets.randbelow(9000) + 1000  # Random number between 1000-9999
     return f'{prefix}-{ts}-{rand}'

 def calculate_booking_payment_balance(booking: Booking) -> dict:

View File

@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session
 from datetime import datetime, timedelta
 from typing import Optional, List, Dict, Any
-import random
+import secrets
 import string
 from decimal import Decimal
 from ..models.group_booking import (
@@ -21,11 +21,13 @@ class GroupBookingService:
     @staticmethod
     def generate_group_booking_number(db: Session) -> str:
-        """Generate unique group booking number"""
+        """Generate unique group booking number using cryptographically secure random"""
         max_attempts = 10
+        alphabet = string.ascii_uppercase + string.digits
         for _ in range(max_attempts):
             timestamp = datetime.utcnow().strftime('%Y%m%d')
-            random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
+            # Use secrets.choice() instead of random.choices() for security
+            random_suffix = ''.join(secrets.choice(alphabet) for _ in range(6))
             booking_number = f"GRP-{timestamp}-{random_suffix}"
             existing = db.query(GroupBooking).filter(
@@ -35,8 +37,9 @@ class GroupBookingService:
             if not existing:
                 return booking_number
-        # Fallback
-        return f"GRP-{int(datetime.utcnow().timestamp())}"
+        # Fallback with secure random suffix
+        random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
+        return f"GRP-{int(datetime.utcnow().timestamp())}{random_suffix}"

     @staticmethod
     def calculate_group_discount(
@@ -405,17 +408,19 @@ class GroupBookingService:
             # Use proportional share
             booking_price = group_booking.total_price / group_booking.total_rooms
-        # Generate booking number
-        import random
+        # Generate booking number using cryptographically secure random
         prefix = 'BK'
         ts = int(datetime.utcnow().timestamp() * 1000)
-        rand = random.randint(1000, 9999)
+        # Use secrets.randbelow() instead of random.randint() for security
+        rand = secrets.randbelow(9000) + 1000  # Random number between 1000-9999
         booking_number = f'{prefix}-{ts}-{rand}'
         # Ensure uniqueness
         existing = db.query(Booking).filter(Booking.booking_number == booking_number).first()
         if existing:
-            booking_number = f'{prefix}-{ts}-{rand + 1}'
+            # If collision, generate new secure random number
+            rand = secrets.randbelow(9000) + 1000
+            booking_number = f'{prefix}-{ts}-{rand}'
         # Create booking
         booking = Booking(

View File

@@ -0,0 +1,26 @@
"""
GDPR Compliance Models.
"""
from .gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
from .consent import Consent, ConsentType, ConsentStatus
from .data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
from .data_breach import DataBreach, BreachType, BreachStatus
from .data_retention import RetentionRule, DataRetentionLog
__all__ = [
'GDPRRequest',
'GDPRRequestType',
'GDPRRequestStatus',
'Consent',
'ConsentType',
'ConsentStatus',
'DataProcessingRecord',
'ProcessingCategory',
'LegalBasis',
'DataBreach',
'BreachType',
'BreachStatus',
'RetentionRule',
'DataRetentionLog',
]

View File

@@ -0,0 +1,64 @@
"""
GDPR Consent Management Model.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class ConsentType(str, enum.Enum):
"""Types of consent that can be given or withdrawn."""
marketing = "marketing"
analytics = "analytics"
necessary = "necessary"
preferences = "preferences"
third_party_sharing = "third_party_sharing"
profiling = "profiling"
automated_decision_making = "automated_decision_making"
class ConsentStatus(str, enum.Enum):
"""Status of consent."""
granted = "granted"
withdrawn = "withdrawn"
pending = "pending"
expired = "expired"
class Consent(Base):
"""Model for tracking user consent for GDPR compliance."""
__tablename__ = 'consents'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # Nullable for anonymous users
user_email = Column(String(255), nullable=True, index=True) # Email for anonymous users
is_anonymous = Column(Boolean, default=False, nullable=False, index=True) # Flag for anonymous consent
consent_type = Column(Enum(ConsentType), nullable=False, index=True)
status = Column(Enum(ConsentStatus), default=ConsentStatus.granted, nullable=False, index=True)
# Consent details
granted_at = Column(DateTime, nullable=True)
withdrawn_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True) # For time-limited consent
# Legal basis (Article 6 GDPR)
legal_basis = Column(String(100), nullable=True) # consent, contract, legal_obligation, vital_interests, public_task, legitimate_interests
# Consent method
consent_method = Column(String(50), nullable=True) # explicit, implicit, pre_checked
consent_version = Column(String(20), nullable=True) # Version of privacy policy when consent was given
# Metadata
ip_address = Column(String(45), nullable=True)
user_agent = Column(String(255), nullable=True)
source = Column(String(100), nullable=True) # Where consent was given (registration, cookie_banner, etc.)
# Additional data
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship('User', foreign_keys=[user_id])
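
A minimal usage sketch (assuming a SQLAlchemy Session, e.g. from the project's get_db dependency; the helper name is illustrative) for recording an anonymous cookie-banner consent with the new fields:

from datetime import datetime
from typing import Optional

def record_anonymous_consent(db, email: str, ip: Optional[str]) -> Consent:
    consent = Consent(
        user_id=None,  # anonymous: no account on file
        user_email=email,
        is_anonymous=True,
        consent_type=ConsentType.analytics,
        status=ConsentStatus.granted,
        granted_at=datetime.utcnow(),
        legal_basis='consent',
        consent_method='explicit',
        source='cookie_banner',
        ip_address=ip,
    )
    db.add(consent)
    db.commit()
    return consent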

View File

@@ -0,0 +1,70 @@
"""
GDPR Data Breach Notification Model (Articles 33-34 GDPR).
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class BreachType(str, enum.Enum):
"""Types of data breaches."""
confidentiality = "confidentiality" # Unauthorized disclosure
integrity = "integrity" # Unauthorized alteration
availability = "availability" # Unauthorized destruction or loss
class BreachStatus(str, enum.Enum):
"""Status of breach notification."""
detected = "detected"
investigating = "investigating"
contained = "contained"
reported_to_authority = "reported_to_authority"
notified_data_subjects = "notified_data_subjects"
resolved = "resolved"
class DataBreach(Base):
"""Data breach notification record (Articles 33-34 GDPR)."""
__tablename__ = 'data_breaches'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Breach details
breach_type = Column(Enum(BreachType), nullable=False, index=True)
status = Column(Enum(BreachStatus), default=BreachStatus.detected, nullable=False, index=True)
# Description
description = Column(Text, nullable=False) # Nature of the breach
affected_data_categories = Column(JSON, nullable=True) # Categories of personal data affected
affected_data_subjects = Column(JSON, nullable=True) # Approximate number of affected individuals
# Timeline
detected_at = Column(DateTime, nullable=False, index=True)
occurred_at = Column(DateTime, nullable=True) # When breach occurred (if known)
contained_at = Column(DateTime, nullable=True)
# Notification
reported_to_authority_at = Column(DateTime, nullable=True) # Article 33 - 72 hours
authority_reference = Column(String(255), nullable=True) # Reference from supervisory authority
notified_data_subjects_at = Column(DateTime, nullable=True) # Article 34 - without undue delay
notification_method = Column(String(100), nullable=True) # email, public_notice, etc.
# Risk assessment
likely_consequences = Column(Text, nullable=True)
measures_proposed = Column(Text, nullable=True) # Measures to address the breach
risk_level = Column(String(20), nullable=True) # low, medium, high
# Reporting
reported_by = Column(Integer, ForeignKey('users.id'), nullable=False) # Who detected/reported
investigated_by = Column(Integer, ForeignKey('users.id'), nullable=True) # DPO or responsible person
# Additional details
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
reporter = relationship('User', foreign_keys=[reported_by])
investigator = relationship('User', foreign_keys=[investigated_by])
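
Article 33(1) GDPR sets a 72-hour window for notifying the supervisory authority after becoming aware of a breach. A small helper sketch (illustrative, not part of this commit) for surfacing that deadline from a DataBreach row:

from datetime import datetime, timedelta

def authority_deadline(breach: DataBreach) -> datetime:
    # 72 hours from detection (Article 33(1) GDPR)
    return breach.detected_at + timedelta(hours=72)

def authority_report_overdue(breach: DataBreach) -> bool:
    return breach.reported_to_authority_at is None and datetime.utcnow() > authority_deadline(breach)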

View File

@@ -0,0 +1,78 @@
"""
GDPR Data Processing Records Model (Article 30 GDPR).
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class ProcessingCategory(str, enum.Enum):
"""Categories of data processing."""
collection = "collection"
storage = "storage"
usage = "usage"
sharing = "sharing"
deletion = "deletion"
anonymization = "anonymization"
transfer = "transfer"
class LegalBasis(str, enum.Enum):
"""Legal basis for processing (Article 6 GDPR)."""
consent = "consent"
contract = "contract"
legal_obligation = "legal_obligation"
vital_interests = "vital_interests"
public_task = "public_task"
legitimate_interests = "legitimate_interests"
class DataProcessingRecord(Base):
"""Record of data processing activities (Article 30 GDPR requirement)."""
__tablename__ = 'data_processing_records'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Processing details
processing_category = Column(Enum(ProcessingCategory), nullable=False, index=True)
legal_basis = Column(Enum(LegalBasis), nullable=False, index=True)
purpose = Column(Text, nullable=False) # Purpose of processing
# Data categories
data_categories = Column(JSON, nullable=True) # List of data categories processed
data_subjects = Column(JSON, nullable=True) # Categories of data subjects
# Recipients
recipients = Column(JSON, nullable=True) # Categories of recipients (internal, third_party, etc.)
third_parties = Column(JSON, nullable=True) # Specific third parties if any
# Transfers
transfers_to_third_countries = Column(Boolean, default=False, nullable=False)
transfer_countries = Column(JSON, nullable=True) # List of countries
safeguards = Column(Text, nullable=True) # Safeguards for transfers
# Retention
retention_period = Column(String(100), nullable=True) # How long data is retained
retention_criteria = Column(Text, nullable=True) # Criteria for determining retention period
# Security measures
security_measures = Column(Text, nullable=True)
# Related entities
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # If specific to a user
related_booking_id = Column(Integer, nullable=True, index=True)
related_payment_id = Column(Integer, nullable=True, index=True)
# Processing details
processed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # Staff/admin who processed
processing_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Additional metadata
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Relationships
user = relationship('User', foreign_keys=[user_id])
processor = relationship('User', foreign_keys=[processed_by])
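
A minimal sketch of writing an Article 30 record for a booking-data collection event (assuming a SQLAlchemy session; the service layer that actually creates these records is not shown in this diff, and the field values here are illustrative):

from datetime import datetime

def log_booking_collection(db, user_id: int, booking_id: int) -> DataProcessingRecord:
    record = DataProcessingRecord(
        processing_category=ProcessingCategory.collection,
        legal_basis=LegalBasis.contract,  # processing needed to fulfil the booking
        purpose='Collect guest details required to fulfil a room booking',
        data_categories=['name', 'email', 'payment_reference'],
        user_id=user_id,
        related_booking_id=booking_id,
        processing_timestamp=datetime.utcnow(),
    )
    db.add(record)
    db.commit()
    return record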

View File

@@ -0,0 +1,75 @@
"""
GDPR Data Retention Policy Model.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
from ...shared.config.database import Base
class RetentionRule(Base):
"""Data retention rules for different data types."""
__tablename__ = 'retention_rules'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Rule details
data_category = Column(String(100), nullable=False, unique=True, index=True) # user_data, booking_data, payment_data, etc.
retention_period_days = Column(Integer, nullable=False) # Number of days to retain
retention_period_months = Column(Integer, nullable=True) # Alternative: months
retention_period_years = Column(Integer, nullable=True) # Alternative: years
# Legal basis
legal_basis = Column(Text, nullable=True) # Why we retain for this period
legal_requirement = Column(Text, nullable=True) # Specific legal requirement if any
# Action after retention
action_after_retention = Column(String(50), nullable=False, default='anonymize') # delete, anonymize, archive
# Conditions
conditions = Column(JSON, nullable=True) # Additional conditions (e.g., active bookings)
# Status
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Metadata
description = Column(Text, nullable=True)
created_by = Column(Integer, ForeignKey('users.id'), nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
creator = relationship('User', foreign_keys=[created_by])
class DataRetentionLog(Base):
"""Log of data retention actions performed."""
__tablename__ = 'data_retention_logs'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Retention action
retention_rule_id = Column(Integer, ForeignKey('retention_rules.id'), nullable=False, index=True)
data_category = Column(String(100), nullable=False, index=True)
action_taken = Column(String(50), nullable=False) # deleted, anonymized, archived
# Affected records
records_affected = Column(Integer, nullable=False, default=0)
affected_ids = Column(JSON, nullable=True) # IDs of affected records (for audit)
# Execution
executed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # System or admin
executed_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Results
success = Column(Boolean, default=True, nullable=False)
error_message = Column(Text, nullable=True)
# Metadata
extra_metadata = Column(JSON, nullable=True)
# Relationships
retention_rule = relationship('RetentionRule', foreign_keys=[retention_rule_id])
executor = relationship('User', foreign_keys=[executed_by])
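
The model stores a day count alongside optional month/year alternatives without defining precedence. One plausible reading (an assumption only; the retention service itself is not in this diff) is that the coarser units, when set, override the day count:

def effective_retention_days(rule: RetentionRule) -> int:
    # Assumption: years/months take precedence over the raw day count when provided
    if rule.retention_period_years:
        return rule.retention_period_years * 365
    if rule.retention_period_months:
        return rule.retention_period_months * 30
    return rule.retention_period_days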

View File

@@ -1,7 +1,7 @@
""" """
GDPR compliance models for data export and deletion requests. GDPR compliance models for data export and deletion requests.
""" """
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from datetime import datetime from datetime import datetime
import enum import enum
@@ -27,9 +27,10 @@ class GDPRRequest(Base):
     request_type = Column(Enum(GDPRRequestType), nullable=False, index=True)
     status = Column(Enum(GDPRRequestStatus), default=GDPRRequestStatus.pending, nullable=False, index=True)
-    # User making the request
-    user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
-    user_email = Column(String(255), nullable=False)  # Store email even if user is deleted
+    # User making the request (nullable for anonymous users)
+    user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True)
+    user_email = Column(String(255), nullable=False)  # Required: email for anonymous or registered users
+    is_anonymous = Column(Boolean, default=False, nullable=False, index=True)  # Flag for anonymous requests
     # Request details
     request_data = Column(JSON, nullable=True)  # Additional request parameters

View File

@@ -0,0 +1,340 @@
"""
Admin routes for GDPR compliance management.
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List
from pydantic import BaseModel
from datetime import datetime
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...security.middleware.auth import authorize_roles
from ...auth.models.user import User
from ..services.breach_service import breach_service
from ..services.retention_service import retention_service
from ..services.data_processing_service import data_processing_service
from ..models.data_breach import BreachType, BreachStatus
from ...shared.utils.response_helpers import success_response
logger = get_logger(__name__)
router = APIRouter(prefix='/gdpr/admin', tags=['gdpr-admin'])
# Data Breach Management
class BreachCreateRequest(BaseModel):
breach_type: str
description: str
affected_data_categories: Optional[List[str]] = None
affected_data_subjects: Optional[int] = None
occurred_at: Optional[str] = None
likely_consequences: Optional[str] = None
measures_proposed: Optional[str] = None
risk_level: Optional[str] = None
@router.post('/breaches')
async def create_breach(
breach_data: BreachCreateRequest,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Create a data breach record (admin only)."""
try:
try:
breach_type_enum = BreachType(breach_data.breach_type)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid breach type: {breach_data.breach_type}')
occurred_at = None
if breach_data.occurred_at:
occurred_at = datetime.fromisoformat(breach_data.occurred_at.replace('Z', '+00:00'))
breach = await breach_service.create_breach(
db=db,
breach_type=breach_type_enum,
description=breach_data.description,
reported_by=current_user.id,
affected_data_categories=breach_data.affected_data_categories,
affected_data_subjects=breach_data.affected_data_subjects,
occurred_at=occurred_at,
likely_consequences=breach_data.likely_consequences,
measures_proposed=breach_data.measures_proposed,
risk_level=breach_data.risk_level
)
return success_response(
data={
'breach_id': breach.id,
'status': breach.status.value,
'detected_at': breach.detected_at.isoformat()
},
message='Data breach record created'
)
except HTTPException:
raise
except Exception as e:
logger.error(f'Error creating breach: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/breaches')
async def get_breaches(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get all data breaches (admin only)."""
try:
status_enum = None
if status:
try:
status_enum = BreachStatus(status)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid status: {status}')
offset = (page - 1) * limit
breaches = breach_service.get_breaches(
db=db,
status=status_enum,
limit=limit,
offset=offset
)
return success_response(data={
'breaches': [{
'id': breach.id,
'breach_type': breach.breach_type.value,
'status': breach.status.value,
'description': breach.description,
'risk_level': breach.risk_level,
'detected_at': breach.detected_at.isoformat(),
'reported_to_authority_at': breach.reported_to_authority_at.isoformat() if breach.reported_to_authority_at else None,
'notified_data_subjects_at': breach.notified_data_subjects_at.isoformat() if breach.notified_data_subjects_at else None,
} for breach in breaches]
})
except HTTPException:
raise
except Exception as e:
logger.error(f'Error getting breaches: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/breaches/{breach_id}/report-authority')
async def report_breach_to_authority(
breach_id: int,
authority_reference: str = Body(...),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Report breach to supervisory authority (admin only)."""
try:
breach = await breach_service.report_to_authority(
db=db,
breach_id=breach_id,
authority_reference=authority_reference,
reported_by=current_user.id
)
return success_response(
data={
'breach_id': breach.id,
'authority_reference': breach.authority_reference,
'reported_at': breach.reported_to_authority_at.isoformat()
},
message='Breach reported to supervisory authority'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error reporting breach: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/breaches/{breach_id}/notify-subjects')
async def notify_data_subjects(
breach_id: int,
notification_method: str = Body(...),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Notify affected data subjects (admin only)."""
try:
breach = await breach_service.notify_data_subjects(
db=db,
breach_id=breach_id,
notification_method=notification_method,
notified_by=current_user.id
)
return success_response(
data={
'breach_id': breach.id,
'notification_method': breach.notification_method,
'notified_at': breach.notified_data_subjects_at.isoformat()
},
message='Data subjects notified'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error notifying subjects: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Retention Management
class RetentionRuleCreateRequest(BaseModel):
data_category: str
retention_period_days: int
retention_period_months: Optional[int] = None
retention_period_years: Optional[int] = None
legal_basis: Optional[str] = None
legal_requirement: Optional[str] = None
action_after_retention: str = 'anonymize'
conditions: Optional[Dict[str, Any]] = None
description: Optional[str] = None
@router.post('/retention-rules')
async def create_retention_rule(
rule_data: RetentionRuleCreateRequest,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Create a data retention rule (admin only)."""
try:
rule = retention_service.create_retention_rule(
db=db,
data_category=rule_data.data_category,
retention_period_days=rule_data.retention_period_days,
retention_period_months=rule_data.retention_period_months,
retention_period_years=rule_data.retention_period_years,
legal_basis=rule_data.legal_basis,
legal_requirement=rule_data.legal_requirement,
action_after_retention=rule_data.action_after_retention,
conditions=rule_data.conditions,
description=rule_data.description,
created_by=current_user.id
)
return success_response(
data={
'rule_id': rule.id,
'data_category': rule.data_category,
'retention_period_days': rule.retention_period_days
},
message='Retention rule created successfully'
)
except Exception as e:
logger.error(f'Error creating retention rule: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/retention-rules')
async def get_retention_rules(
is_active: Optional[bool] = Query(None),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get retention rules (admin only)."""
try:
rules = retention_service.get_retention_rules(db=db, is_active=is_active)
return success_response(data={
'rules': [{
'id': rule.id,
'data_category': rule.data_category,
'retention_period_days': rule.retention_period_days,
'action_after_retention': rule.action_after_retention,
'is_active': rule.is_active,
'legal_basis': rule.legal_basis
} for rule in rules]
})
except Exception as e:
logger.error(f'Error getting retention rules: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/retention-logs')
async def get_retention_logs(
data_category: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get retention action logs (admin only)."""
try:
offset = (page - 1) * limit
logs = retention_service.get_retention_logs(
db=db,
data_category=data_category,
limit=limit,
offset=offset
)
return success_response(data={
'logs': [{
'id': log.id,
'data_category': log.data_category,
'action_taken': log.action_taken,
'records_affected': log.records_affected,
'executed_at': log.executed_at.isoformat(),
'success': log.success
} for log in logs]
})
except Exception as e:
logger.error(f'Error getting retention logs: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Processing Records (Admin View)
@router.get('/processing-records')
async def get_all_processing_records(
user_id: Optional[int] = Query(None),
processing_category: Optional[str] = Query(None),
legal_basis: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get all data processing records (admin only)."""
try:
from ..models.data_processing_record import ProcessingCategory, LegalBasis
category_enum = None
if processing_category:
try:
category_enum = ProcessingCategory(processing_category)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid processing category: {processing_category}')
basis_enum = None
if legal_basis:
try:
basis_enum = LegalBasis(legal_basis)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid legal basis: {legal_basis}')
offset = (page - 1) * limit
records = data_processing_service.get_processing_records(
db=db,
user_id=user_id,
processing_category=category_enum,
legal_basis=basis_enum,
limit=limit,
offset=offset
)
return success_response(data={
'records': [{
'id': record.id,
'processing_category': record.processing_category.value,
'legal_basis': record.legal_basis.value,
'purpose': record.purpose,
'processing_timestamp': record.processing_timestamp.isoformat(),
'user_id': record.user_id
} for record in records]
})
except HTTPException:
raise
except Exception as e:
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -3,46 +3,78 @@ GDPR compliance routes for data export and deletion.
""" """
from fastapi import APIRouter, Depends, HTTPException, Query, Response from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy.orm import Session, noload from sqlalchemy.orm import Session, noload
from sqlalchemy import or_
from typing import Optional from typing import Optional
from datetime import datetime
from ...shared.config.database import get_db from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger from ...shared.config.logging_config import get_logger
from ...security.middleware.auth import get_current_user, authorize_roles from ...security.middleware.auth import get_current_user, authorize_roles, get_current_user_optional
from ...auth.models.user import User from ...auth.models.user import User
from ..services.gdpr_service import gdpr_service from ..services.gdpr_service import gdpr_service
from ..services.consent_service import consent_service
from ..services.data_processing_service import data_processing_service
from ..models.gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus from ..models.gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
from ..models.consent import ConsentType, ConsentStatus
from ...shared.utils.response_helpers import success_response from ...shared.utils.response_helpers import success_response
from fastapi import Request from fastapi import Request
from pydantic import BaseModel
from typing import Dict, Any, Optional, List
logger = get_logger(__name__) logger = get_logger(__name__)
router = APIRouter(prefix='/gdpr', tags=['gdpr']) router = APIRouter(prefix='/gdpr', tags=['gdpr'])
class AnonymousExportRequest(BaseModel):
email: str
@router.post('/export') @router.post('/export')
async def request_data_export( async def request_data_export(
request: Request, request: Request,
current_user: User = Depends(get_current_user), anonymous_request: Optional[AnonymousExportRequest] = None,
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Request export of user's personal data (GDPR).""" """Request export of user's personal data (GDPR) - supports both authenticated and anonymous users."""
try: try:
client_ip = request.client.host if request.client else None client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent') user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.create_data_export_request( # Check if authenticated or anonymous
db=db, if current_user:
user_id=current_user.id, # Authenticated user
ip_address=client_ip, gdpr_request = await gdpr_service.create_data_export_request(
user_agent=user_agent db=db,
) user_id=current_user.id,
ip_address=client_ip,
user_agent=user_agent,
is_anonymous=False
)
elif anonymous_request and anonymous_request.email:
# Anonymous user - requires email
gdpr_request = await gdpr_service.create_data_export_request(
db=db,
user_email=anonymous_request.email,
ip_address=client_ip,
user_agent=user_agent,
is_anonymous=True
)
else:
raise HTTPException(
status_code=400,
detail='Either authentication required or email must be provided for anonymous requests'
)
return success_response( return success_response(
data={ data={
'request_id': gdpr_request.id, 'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token, 'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value, 'status': gdpr_request.status.value,
'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None 'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None,
'is_anonymous': gdpr_request.is_anonymous
}, },
message='Data export request created. You will receive an email with download link once ready.' message='Data export request created. You will receive an email with download link once ready.'
) )
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception as e:
logger.error(f'Error creating data export request: {str(e)}', exc_info=True) logger.error(f'Error creating data export request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
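A short sketch of the anonymous path above: no Authorization header, just an email in the JSON body. Base URL and response envelope are assumed as before.
import httpx

BASE_URL = 'http://localhost:8000/api'   # hypothetical deployment URL

resp = httpx.post(f'{BASE_URL}/gdpr/export', json={'email': 'guest@example.com'})
resp.raise_for_status()
data = resp.json()['data']
print(data['request_id'], data['verification_token'], data['is_anonymous'])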
@@ -51,20 +83,26 @@ async def request_data_export(
async def get_export_data(
    request_id: int,
    verification_token: str = Query(...),
    current_user: Optional[User] = Depends(get_current_user_optional),
    db: Session = Depends(get_db)
):
    """Get exported user data - supports both authenticated and anonymous users via verification token."""
    try:
        # Build query - verification token is required for both authenticated and anonymous
        query = db.query(GDPRRequest).options(
            noload(GDPRRequest.user),
            noload(GDPRRequest.processor)
        ).filter(
            GDPRRequest.id == request_id,
            GDPRRequest.verification_token == verification_token,
            GDPRRequest.request_type == GDPRRequestType.data_export
        )
        # For authenticated users, also verify user_id matches
        if current_user:
            query = query.filter(GDPRRequest.user_id == current_user.id)
        gdpr_request = query.first()
        if not gdpr_request:
            raise HTTPException(status_code=404, detail='Export request not found or invalid token')
@@ -73,8 +111,10 @@ async def get_export_data(
            # Process export
            export_data = await gdpr_service.export_user_data(
                db=db,
                user_id=gdpr_request.user_id,
                user_email=gdpr_request.user_email,
                request_id=request_id,
                is_anonymous=gdpr_request.is_anonymous
            )
            return success_response(data=export_data)
        elif gdpr_request.status == GDPRRequestStatus.completed and gdpr_request.export_file_path:
@@ -97,32 +137,57 @@ async def get_export_data(
        logger.error(f'Error getting export data: {str(e)}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
class AnonymousDeletionRequest(BaseModel):
    email: str
@router.post('/delete')
async def request_data_deletion(
    request: Request,
    anonymous_request: Optional[AnonymousDeletionRequest] = None,
    current_user: Optional[User] = Depends(get_current_user_optional),
    db: Session = Depends(get_db)
):
    """Request deletion of user's personal data (GDPR - Right to be Forgotten) - supports anonymous users."""
    try:
        client_ip = request.client.host if request.client else None
        user_agent = request.headers.get('User-Agent')
        # Check if authenticated or anonymous
        if current_user:
            # Authenticated user
            gdpr_request = await gdpr_service.create_data_deletion_request(
                db=db,
                user_id=current_user.id,
                ip_address=client_ip,
                user_agent=user_agent,
                is_anonymous=False
            )
        elif anonymous_request and anonymous_request.email:
            # Anonymous user - requires email
            gdpr_request = await gdpr_service.create_data_deletion_request(
                db=db,
                user_email=anonymous_request.email,
                ip_address=client_ip,
                user_agent=user_agent,
                is_anonymous=True
            )
        else:
            raise HTTPException(
                status_code=400,
                detail='Either authentication required or email must be provided for anonymous requests'
            )
        return success_response(
            data={
                'request_id': gdpr_request.id,
                'verification_token': gdpr_request.verification_token,
                'status': gdpr_request.status.value,
                'is_anonymous': gdpr_request.is_anonymous
            },
            message='Data deletion request created. Please verify via email to proceed.'
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f'Error creating data deletion request: {str(e)}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@@ -131,21 +196,27 @@ async def request_data_deletion(
async def confirm_data_deletion(
    request_id: int,
    verification_token: str = Query(...),
    current_user: Optional[User] = Depends(get_current_user_optional),
    db: Session = Depends(get_db)
):
    """Confirm and process data deletion request - supports anonymous users via verification token."""
    try:
        # Build query - verification token is required for both authenticated and anonymous
        query = db.query(GDPRRequest).options(
            noload(GDPRRequest.user),
            noload(GDPRRequest.processor)
        ).filter(
            GDPRRequest.id == request_id,
            GDPRRequest.verification_token == verification_token,
            GDPRRequest.request_type == GDPRRequestType.data_deletion,
            GDPRRequest.status == GDPRRequestStatus.pending
        )
        # For authenticated users, also verify user_id matches
        if current_user:
            query = query.filter(GDPRRequest.user_id == current_user.id)
        gdpr_request = query.first()
        if not gdpr_request:
            raise HTTPException(status_code=404, detail='Deletion request not found or already processed')
@@ -153,14 +224,16 @@ async def confirm_data_deletion(
        # Process deletion
        deletion_log = await gdpr_service.delete_user_data(
            db=db,
            user_id=gdpr_request.user_id,
            user_email=gdpr_request.user_email,
            request_id=request_id,
            processed_by=current_user.id if current_user else None,
            is_anonymous=gdpr_request.is_anonymous
        )
        return success_response(
            data=deletion_log,
            message=deletion_log.get('summary', {}).get('message', 'Your data has been deleted successfully.')
        )
    except HTTPException:
        raise
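The two-step flow above, sketched end to end. The confirm route's exact path is not visible in this hunk, so /gdpr/delete/{id}/confirm is an assumption, as are the base URL and the response envelope.
import httpx

BASE_URL = 'http://localhost:8000/api'   # hypothetical deployment URL

created = httpx.post(f'{BASE_URL}/gdpr/delete',
                     json={'email': 'guest@example.com'}).json()['data']
# The token normally arrives by email; here it is reused directly.
confirmed = httpx.post(
    f'{BASE_URL}/gdpr/delete/{created["request_id"]}/confirm',  # assumed path
    params={'verification_token': created['verification_token']},
)
print(confirmed.json())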
@@ -173,13 +246,17 @@ async def get_user_gdpr_requests(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Get user's GDPR requests (both authenticated and anonymous requests by email)."""
    try:
        # Get requests by user_id (authenticated) or by email (includes anonymous)
        requests = db.query(GDPRRequest).options(
            noload(GDPRRequest.user),
            noload(GDPRRequest.processor)
        ).filter(
            or_(
                GDPRRequest.user_id == current_user.id,
                GDPRRequest.user_email == current_user.email
            )
        ).order_by(GDPRRequest.created_at.desc()).all()
        return success_response(data={
@@ -187,6 +264,7 @@ async def get_user_gdpr_requests(
            'id': req.id,
            'request_type': req.request_type.value,
            'status': req.status.value,
            'is_anonymous': req.is_anonymous,
            'created_at': req.created_at.isoformat() if req.created_at else None,
            'processed_at': req.processed_at.isoformat() if req.processed_at else None,
        } for req in requests]
@@ -270,3 +348,272 @@ async def delete_gdpr_request(
        logger.error(f'Error deleting GDPR request: {str(e)}', exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# GDPR Rights - Additional Routes
class DataRectificationRequest(BaseModel):
corrections: Dict[str, Any] # e.g., {"full_name": "New Name", "email": "new@email.com"}
@router.post('/rectify')
async def request_data_rectification(
request: Request,
rectification_data: DataRectificationRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Request data rectification (Article 16 GDPR - Right to rectification)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_data_rectification(
db=db,
user_id=current_user.id,
corrections=rectification_data.corrections,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Data rectification request created. An admin will review and process your request.'
)
except Exception as e:
logger.error(f'Error creating rectification request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class ProcessingRestrictionRequest(BaseModel):
reason: str
@router.post('/restrict')
async def request_processing_restriction(
request: Request,
restriction_data: ProcessingRestrictionRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Request restriction of processing (Article 18 GDPR)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_processing_restriction(
db=db,
user_id=current_user.id,
reason=restriction_data.reason,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Processing restriction request created. Your account has been temporarily restricted.'
)
except Exception as e:
logger.error(f'Error creating restriction request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class ProcessingObjectionRequest(BaseModel):
processing_purpose: str
reason: Optional[str] = None
@router.post('/object')
async def request_processing_objection(
request: Request,
objection_data: ProcessingObjectionRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Object to processing (Article 21 GDPR - Right to object)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_processing_objection(
db=db,
user_id=current_user.id,
processing_purpose=objection_data.processing_purpose,
reason=objection_data.reason,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Processing objection registered. We will review your objection and stop processing for the specified purpose if valid.'
)
except Exception as e:
logger.error(f'Error creating objection request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Consent Management Routes
class ConsentUpdateRequest(BaseModel):
consents: Dict[str, bool] # e.g., {"marketing": true, "analytics": false}
@router.get('/consents')
async def get_user_consents(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get user's consent status for all consent types."""
try:
consents = consent_service.get_user_consents(db=db, user_id=current_user.id, include_withdrawn=True)
consent_status = {}
for consent_type in ConsentType:
consent_status[consent_type.value] = {
'has_consent': consent_service.has_consent(db=db, user_id=current_user.id, consent_type=consent_type),
'granted_at': None,
'withdrawn_at': None,
'status': 'none'
}
for consent in consents:
consent_status[consent.consent_type.value] = {
'has_consent': consent.status == ConsentStatus.granted and (not consent.expires_at or consent.expires_at > datetime.utcnow()),
'granted_at': consent.granted_at.isoformat() if consent.granted_at else None,
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None,
'status': consent.status.value,
'expires_at': consent.expires_at.isoformat() if consent.expires_at else None
}
return success_response(data={'consents': consent_status})
except Exception as e:
logger.error(f'Error getting consents: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/consents')
async def update_consents(
request: Request,
consent_data: ConsentUpdateRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Update user consent preferences."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
# Convert string keys to ConsentType enum
consents_dict = {}
for key, value in consent_data.consents.items():
try:
consent_type = ConsentType(key)
consents_dict[consent_type] = value
except ValueError:
continue
results = await consent_service.update_consent_preferences(
db=db,
user_id=current_user.id,
consents=consents_dict,
legal_basis='consent',
ip_address=client_ip,
user_agent=user_agent,
source='gdpr_page'
)
return success_response(
data={'updated_consents': len(results)},
message='Consent preferences updated successfully'
)
except Exception as e:
logger.error(f'Error updating consents: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/consents/{consent_type}/withdraw')
async def withdraw_consent(
request: Request,
consent_type: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Withdraw specific consent (Article 7(3) GDPR)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
try:
consent_type_enum = ConsentType(consent_type)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid consent type: {consent_type}')
consent = await consent_service.withdraw_consent(
db=db,
user_id=current_user.id,
consent_type=consent_type_enum,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'consent_id': consent.id,
'consent_type': consent.consent_type.value,
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
},
message=f'Consent for {consent_type} withdrawn successfully'
)
except HTTPException:
raise
except Exception as e:
logger.error(f'Error withdrawing consent: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Processing Records (User View)
@router.get('/processing-records')
async def get_user_processing_records(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get data processing records for the user (Article 15 GDPR - Right of access)."""
try:
summary = data_processing_service.get_user_processing_summary(
db=db,
user_id=current_user.id
)
return success_response(data=summary)
except Exception as e:
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Admin Routes for Processing Requests
@router.post('/admin/rectify/{request_id}/process')
async def process_rectification(
request_id: int,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Process data rectification request (admin only)."""
try:
result = await gdpr_service.process_data_rectification(
db=db,
request_id=request_id,
processed_by=current_user.id
)
return success_response(
data=result,
message='Data rectification processed successfully'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error processing rectification: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
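The consent endpoints above, exercised from a client. The bearer token is a placeholder; the consent keys must match ConsentType values, since update_consents silently skips unknown keys.
import httpx

BASE_URL = 'http://localhost:8000/api'            # hypothetical deployment URL
HEADERS = {'Authorization': 'Bearer <user-jwt>'}  # placeholder credential

# Grant marketing and withdraw analytics in one call.
httpx.post(f'{BASE_URL}/gdpr/consents',
           json={'consents': {'marketing': True, 'analytics': False}},
           headers=HEADERS)
# Withdraw a single consent type (Article 7(3)).
httpx.post(f'{BASE_URL}/gdpr/consents/marketing/withdraw', headers=HEADERS)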

View File

@@ -0,0 +1,169 @@
"""
Data Breach Notification Service (Articles 33-34 GDPR).
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.data_breach import DataBreach, BreachType, BreachStatus
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class BreachService:
"""Service for managing data breach notifications (Articles 33-34 GDPR)."""
NOTIFICATION_DEADLINE_HOURS = 72 # Article 33 - 72 hours to notify authority
@staticmethod
async def create_breach(
db: Session,
breach_type: BreachType,
description: str,
reported_by: int,
affected_data_categories: Optional[List[str]] = None,
affected_data_subjects: Optional[int] = None,
occurred_at: Optional[datetime] = None,
likely_consequences: Optional[str] = None,
measures_proposed: Optional[str] = None,
risk_level: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataBreach:
"""Create a data breach record."""
breach = DataBreach(
breach_type=breach_type,
status=BreachStatus.detected,
description=description,
affected_data_categories=affected_data_categories or [],
affected_data_subjects=affected_data_subjects,
detected_at=datetime.utcnow(),
occurred_at=occurred_at or datetime.utcnow(),
likely_consequences=likely_consequences,
measures_proposed=measures_proposed,
risk_level=risk_level or 'medium',
reported_by=reported_by,
extra_metadata=extra_metadata
)
db.add(breach)
db.commit()
db.refresh(breach)
# Log breach detection
await audit_service.log_action(
db=db,
action='data_breach_detected',
resource_type='data_breach',
user_id=reported_by,
resource_id=breach.id,
details={
'breach_type': breach_type.value,
'risk_level': risk_level,
'affected_subjects': affected_data_subjects
},
status='warning'
)
logger.warning(f'Data breach detected: {breach.id} - {breach_type.value}')
return breach
@staticmethod
async def report_to_authority(
db: Session,
breach_id: int,
authority_reference: str,
reported_by: int
) -> DataBreach:
"""Report breach to supervisory authority (Article 33)."""
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
if not breach:
raise ValueError('Breach not found')
breach.status = BreachStatus.reported_to_authority
breach.reported_to_authority_at = datetime.utcnow()
breach.authority_reference = authority_reference
db.commit()
db.refresh(breach)
# Check if within deadline
time_since_detection = datetime.utcnow() - breach.detected_at
if time_since_detection > timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS):
logger.warning(f'Breach {breach_id} reported after {BreachService.NOTIFICATION_DEADLINE_HOURS} hour deadline')
# Log report
await audit_service.log_action(
db=db,
action='breach_reported_to_authority',
resource_type='data_breach',
user_id=reported_by,
resource_id=breach_id,
details={'authority_reference': authority_reference},
status='success'
)
logger.info(f'Breach {breach_id} reported to authority: {authority_reference}')
return breach
@staticmethod
async def notify_data_subjects(
db: Session,
breach_id: int,
notification_method: str,
notified_by: int
) -> DataBreach:
"""Notify affected data subjects (Article 34)."""
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
if not breach:
raise ValueError('Breach not found')
breach.status = BreachStatus.notified_data_subjects
breach.notified_data_subjects_at = datetime.utcnow()
breach.notification_method = notification_method
db.commit()
db.refresh(breach)
# Log notification
await audit_service.log_action(
db=db,
action='breach_subjects_notified',
resource_type='data_breach',
user_id=notified_by,
resource_id=breach_id,
details={'notification_method': notification_method},
status='success'
)
logger.info(f'Data subjects notified for breach {breach_id}')
return breach
@staticmethod
def get_breaches(
db: Session,
status: Optional[BreachStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[DataBreach]:
"""Get data breaches with optional filters."""
query = db.query(DataBreach)
if status:
query = query.filter(DataBreach.status == status)
return query.order_by(DataBreach.detected_at.desc()).offset(offset).limit(limit).all()
@staticmethod
def get_breaches_requiring_notification(
db: Session
) -> List[DataBreach]:
"""Get breaches that require notification (not yet reported)."""
deadline = datetime.utcnow() - timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS)
return db.query(DataBreach).filter(
DataBreach.status.in_([BreachStatus.detected, BreachStatus.investigating]),
DataBreach.detected_at < deadline
).all()
breach_service = BreachService()
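A hedged usage sketch for the service above, as it might appear in an async admin handler; the import path and the BreachType member are assumptions, not part of this commit.
from src.compliance.models.data_breach import BreachType  # assumed path

async def report_incident(db, reporter_id: int):
    breach = await breach_service.create_breach(
        db=db,
        breach_type=BreachType.confidentiality,  # assumed enum member
        description='Misdirected export email disclosed one guest profile',
        reported_by=reporter_id,
        affected_data_subjects=1,
        risk_level='low',
    )
    # Breaches still unreported after the 72-hour Article 33 window:
    overdue = breach_service.get_breaches_requiring_notification(db=db)
    return breach, overdue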

View File

@@ -0,0 +1,202 @@
"""
GDPR Consent Management Service.
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.consent import Consent, ConsentType, ConsentStatus
from ...auth.models.user import User
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class ConsentService:
"""Service for managing user consent (Article 7 GDPR)."""
@staticmethod
async def grant_consent(
db: Session,
user_id: int,
consent_type: ConsentType,
legal_basis: str,
consent_method: str = 'explicit',
consent_version: Optional[str] = None,
expires_at: Optional[datetime] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
source: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> Consent:
"""Grant consent for a specific purpose."""
# Withdraw any existing consent of this type
existing = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).first()
if existing:
existing.status = ConsentStatus.withdrawn
existing.withdrawn_at = datetime.utcnow()
# Create new consent
consent = Consent(
user_id=user_id,
consent_type=consent_type,
status=ConsentStatus.granted,
granted_at=datetime.utcnow(),
expires_at=expires_at,
legal_basis=legal_basis,
consent_method=consent_method,
consent_version=consent_version,
ip_address=ip_address,
user_agent=user_agent,
source=source,
extra_metadata=extra_metadata
)
db.add(consent)
db.commit()
db.refresh(consent)
# Log consent grant
await audit_service.log_action(
db=db,
action='consent_granted',
resource_type='consent',
user_id=user_id,
resource_id=consent.id,
ip_address=ip_address,
user_agent=user_agent,
details={
'consent_type': consent_type.value,
'legal_basis': legal_basis,
'consent_method': consent_method
},
status='success'
)
logger.info(f'Consent granted: {consent_type.value} for user {user_id}')
return consent
@staticmethod
async def withdraw_consent(
db: Session,
user_id: int,
consent_type: ConsentType,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> Consent:
"""Withdraw consent (Article 7(3) GDPR)."""
consent = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).order_by(Consent.granted_at.desc()).first()
if not consent:
raise ValueError(f'No active consent found for {consent_type.value}')
consent.status = ConsentStatus.withdrawn
consent.withdrawn_at = datetime.utcnow()
db.commit()
db.refresh(consent)
# Log consent withdrawal
await audit_service.log_action(
db=db,
action='consent_withdrawn',
resource_type='consent',
user_id=user_id,
resource_id=consent.id,
ip_address=ip_address,
user_agent=user_agent,
details={'consent_type': consent_type.value},
status='success'
)
logger.info(f'Consent withdrawn: {consent_type.value} for user {user_id}')
return consent
@staticmethod
def get_user_consents(
db: Session,
user_id: int,
include_withdrawn: bool = False
) -> List[Consent]:
"""Get all consents for a user."""
query = db.query(Consent).filter(Consent.user_id == user_id)
if not include_withdrawn:
query = query.filter(Consent.status == ConsentStatus.granted)
return query.order_by(Consent.granted_at.desc()).all()
@staticmethod
def has_consent(
db: Session,
user_id: int,
consent_type: ConsentType
) -> bool:
"""Check if user has active consent for a specific type."""
consent = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).first()
if not consent:
return False
# Check if expired
if consent.expires_at and consent.expires_at < datetime.utcnow():
consent.status = ConsentStatus.expired
db.commit()
return False
return True
@staticmethod
async def update_consent_preferences(
db: Session,
user_id: int,
consents: Dict[ConsentType, bool],
legal_basis: str = 'consent',
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
source: Optional[str] = None
) -> List[Consent]:
"""Update multiple consent preferences at once."""
results = []
for consent_type, granted in consents.items():
if granted:
consent = await ConsentService.grant_consent(
db=db,
user_id=user_id,
consent_type=consent_type,
legal_basis=legal_basis,
ip_address=ip_address,
user_agent=user_agent,
source=source
)
results.append(consent)
else:
try:
consent = await ConsentService.withdraw_consent(
db=db,
user_id=user_id,
consent_type=consent_type,
ip_address=ip_address,
user_agent=user_agent
)
results.append(consent)
except ValueError:
# No active consent to withdraw
pass
return results
consent_service = ConsentService()
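A sketch of gating an outbound campaign on recorded consent using the service above; audience_ids, the open db session, and the import path are assumed context.
from src.compliance.models.consent import ConsentType  # assumed path

def marketing_recipients(db, audience_ids: list[int]) -> list[int]:
    # Note: has_consent() also expires stale grants as a side effect.
    return [
        user_id for user_id in audience_ids
        if consent_service.has_consent(db=db, user_id=user_id,
                                       consent_type=ConsentType.marketing)
    ]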

View File

@@ -0,0 +1,128 @@
"""
Data Processing Records Service (Article 30 GDPR).
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime
from ..models.data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class DataProcessingService:
"""Service for maintaining data processing records (Article 30 GDPR)."""
@staticmethod
async def create_processing_record(
db: Session,
processing_category: ProcessingCategory,
legal_basis: LegalBasis,
purpose: str,
data_categories: Optional[List[str]] = None,
data_subjects: Optional[List[str]] = None,
recipients: Optional[List[str]] = None,
third_parties: Optional[List[str]] = None,
transfers_to_third_countries: bool = False,
transfer_countries: Optional[List[str]] = None,
safeguards: Optional[str] = None,
retention_period: Optional[str] = None,
retention_criteria: Optional[str] = None,
security_measures: Optional[str] = None,
user_id: Optional[int] = None,
related_booking_id: Optional[int] = None,
related_payment_id: Optional[int] = None,
processed_by: Optional[int] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataProcessingRecord:
"""Create a data processing record."""
record = DataProcessingRecord(
processing_category=processing_category,
legal_basis=legal_basis,
purpose=purpose,
data_categories=data_categories or [],
data_subjects=data_subjects or [],
recipients=recipients or [],
third_parties=third_parties or [],
transfers_to_third_countries=transfers_to_third_countries,
transfer_countries=transfer_countries or [],
safeguards=safeguards,
retention_period=retention_period,
retention_criteria=retention_criteria,
security_measures=security_measures,
user_id=user_id,
related_booking_id=related_booking_id,
related_payment_id=related_payment_id,
processed_by=processed_by,
processing_timestamp=datetime.utcnow(),
extra_metadata=extra_metadata
)
db.add(record)
db.commit()
db.refresh(record)
logger.info(f'Data processing record created: {record.id}')
return record
@staticmethod
def get_processing_records(
db: Session,
user_id: Optional[int] = None,
processing_category: Optional[ProcessingCategory] = None,
legal_basis: Optional[LegalBasis] = None,
limit: int = 100,
offset: int = 0
) -> List[DataProcessingRecord]:
"""Get data processing records with optional filters."""
query = db.query(DataProcessingRecord)
if user_id:
query = query.filter(DataProcessingRecord.user_id == user_id)
if processing_category:
query = query.filter(DataProcessingRecord.processing_category == processing_category)
if legal_basis:
query = query.filter(DataProcessingRecord.legal_basis == legal_basis)
return query.order_by(DataProcessingRecord.processing_timestamp.desc()).offset(offset).limit(limit).all()
@staticmethod
def get_user_processing_summary(
db: Session,
user_id: int
) -> Dict[str, Any]:
"""Get a summary of all data processing activities for a user."""
records = db.query(DataProcessingRecord).filter(
DataProcessingRecord.user_id == user_id
).all()
summary = {
'total_records': len(records),
'by_category': {},
'by_legal_basis': {},
'third_party_sharing': [],
'transfers_to_third_countries': []
}
for record in records:
# By category
category = record.processing_category.value
summary['by_category'][category] = summary['by_category'].get(category, 0) + 1
# By legal basis
basis = record.legal_basis.value
summary['by_legal_basis'][basis] = summary['by_legal_basis'].get(basis, 0) + 1
# Third party sharing
if record.third_parties:
summary['third_party_sharing'].extend(record.third_parties)
# Transfers
if record.transfers_to_third_countries:
summary['transfers_to_third_countries'].extend(record.transfer_countries or [])
return summary
data_processing_service = DataProcessingService()
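A sketch of recording one booking-related activity with the service above; the enum members and field values are illustrative assumptions, not taken from this commit.
from src.compliance.models.data_processing_record import (  # assumed path
    ProcessingCategory, LegalBasis)

async def record_booking_processing(db, guest_id: int, booking_id: int):
    return await data_processing_service.create_processing_record(
        db=db,
        processing_category=ProcessingCategory.booking,  # assumed member
        legal_basis=LegalBasis.contract,                 # assumed member
        purpose='Fulfil a confirmed room reservation',
        data_categories=['contact', 'payment'],
        retention_period='5 years',
        user_id=guest_id,
        related_booking_id=booking_id,
    )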

File diff suppressed because it is too large

View File

@@ -0,0 +1,141 @@
"""
Data Retention Service for GDPR compliance.
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.data_retention import RetentionRule, DataRetentionLog
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class RetentionService:
"""Service for managing data retention policies and cleanup."""
@staticmethod
def create_retention_rule(
db: Session,
data_category: str,
retention_period_days: int,
retention_period_months: Optional[int] = None,
retention_period_years: Optional[int] = None,
legal_basis: Optional[str] = None,
legal_requirement: Optional[str] = None,
action_after_retention: str = 'anonymize',
conditions: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
created_by: Optional[int] = None
) -> RetentionRule:
"""Create a data retention rule."""
rule = RetentionRule(
data_category=data_category,
retention_period_days=retention_period_days,
retention_period_months=retention_period_months,
retention_period_years=retention_period_years,
legal_basis=legal_basis,
legal_requirement=legal_requirement,
action_after_retention=action_after_retention,
conditions=conditions,
description=description,
created_by=created_by,
is_active=True
)
db.add(rule)
db.commit()
db.refresh(rule)
logger.info(f'Retention rule created: {data_category} - {retention_period_days} days')
return rule
@staticmethod
def get_retention_rules(
db: Session,
is_active: Optional[bool] = None
) -> List[RetentionRule]:
"""Get retention rules."""
query = db.query(RetentionRule)
if is_active is not None:
query = query.filter(RetentionRule.is_active == is_active)
return query.order_by(RetentionRule.data_category).all()
@staticmethod
def get_retention_rule(
db: Session,
data_category: str
) -> Optional[RetentionRule]:
"""Get retention rule for a specific data category."""
return db.query(RetentionRule).filter(
RetentionRule.data_category == data_category,
RetentionRule.is_active == True
).first()
@staticmethod
async def log_retention_action(
db: Session,
retention_rule_id: int,
data_category: str,
action_taken: str,
records_affected: int,
affected_ids: Optional[List[int]] = None,
executed_by: Optional[int] = None,
success: bool = True,
error_message: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataRetentionLog:
"""Log a data retention action."""
log = DataRetentionLog(
retention_rule_id=retention_rule_id,
data_category=data_category,
action_taken=action_taken,
records_affected=records_affected,
affected_ids=affected_ids or [],
executed_by=executed_by,
executed_at=datetime.utcnow(),
success=success,
error_message=error_message,
extra_metadata=extra_metadata
)
db.add(log)
db.commit()
db.refresh(log)
# Log to audit trail
await audit_service.log_action(
db=db,
action='data_retention_action',
resource_type='retention_log',
user_id=executed_by,
resource_id=log.id,
details={
'data_category': data_category,
'action_taken': action_taken,
'records_affected': records_affected
},
status='success' if success else 'error'
)
logger.info(f'Retention action logged: {action_taken} on {data_category} - {records_affected} records')
return log
@staticmethod
def get_retention_logs(
db: Session,
data_category: Optional[str] = None,
limit: int = 100,
offset: int = 0
) -> List[DataRetentionLog]:
"""Get retention action logs."""
query = db.query(DataRetentionLog)
if data_category:
query = query.filter(DataRetentionLog.data_category == data_category)
return query.order_by(DataRetentionLog.executed_at.desc()).offset(offset).limit(limit).all()
retention_service = RetentionService()
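How a nightly cleanup job might wire the rule lookup to the logger above; anonymize_old_audit_logs() is a hypothetical helper returning the affected ids.
from datetime import datetime, timedelta

async def run_audit_log_retention(db):
    rule = retention_service.get_retention_rule(db=db, data_category='audit_logs')
    if not rule:
        return
    cutoff = datetime.utcnow() - timedelta(days=rule.retention_period_days)
    affected_ids = anonymize_old_audit_logs(db, before=cutoff)  # hypothetical
    await retention_service.log_retention_action(
        db=db,
        retention_rule_id=rule.id,
        data_category=rule.data_category,
        action_taken=rule.action_after_retention,
        records_affected=len(affected_ids),
        affected_ids=affected_ids,
    )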

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session, joinedload
from typing import Optional
from datetime import datetime
import secrets
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
@@ -33,7 +33,8 @@ router = APIRouter(prefix="/service-bookings", tags=["service-bookings"])
def generate_service_booking_number() -> str:
    prefix = "SB"
    timestamp = datetime.utcnow().strftime("%Y%m%d")
    # Use cryptographically secure random number to prevent enumeration attacks
    random_suffix = secrets.randbelow(9000) + 1000  # Random number between 1000-9999
    return f"{prefix}{timestamp}{random_suffix}"
@router.post("/")

View File

@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from datetime import datetime, timedelta, date
from typing import Optional
import secrets
import string
from ..models.user_loyalty import UserLoyalty
from ..models.loyalty_tier import LoyaltyTier, TierLevel
@@ -78,19 +78,23 @@ class LoyaltyService:
    @staticmethod
    def generate_referral_code(db: Session, user_id: int, length: int = 8) -> str:
        """Generate unique referral code for user using cryptographically secure random"""
        max_attempts = 10
        alphabet = string.ascii_uppercase + string.digits
        for _ in range(max_attempts):
            # Generate code: USER1234 format using cryptographically secure random
            # Use secrets.choice() instead of random.choices() for security
            random_part = ''.join(secrets.choice(alphabet) for _ in range(length-8))
            code = f"USER{user_id:04d}{random_part}"
            # Check if code exists
            existing = db.query(UserLoyalty).filter(UserLoyalty.referral_code == code).first()
            if not existing:
                return code
        # Fallback: timestamp-based with secure random suffix
        random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
        return f"REF{int(datetime.utcnow().timestamp())}{user_id}{random_suffix}"
    @staticmethod
    def create_default_tiers(db: Session):
@@ -340,14 +344,18 @@ class LoyaltyService:
    @staticmethod
    def generate_redemption_code(db: Session, length: int = 12) -> str:
        """Generate unique redemption code using cryptographically secure random"""
        max_attempts = 10
        alphabet = string.ascii_uppercase + string.digits
        for _ in range(max_attempts):
            # Use secrets.choice() instead of random.choices() for security
            code = ''.join(secrets.choice(alphabet) for _ in range(length))
            existing = db.query(RewardRedemption).filter(RewardRedemption.code == code).first()
            if not existing:
                return code
        # Fallback with secure random suffix
        random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
        return f"RED{int(datetime.utcnow().timestamp())}{random_suffix}"
    @staticmethod
    def process_referral(

View File

@@ -95,10 +95,16 @@ else:
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f'Allowed CORS origins: {", ".join(settings.CORS_ORIGINS)}')
# SECURITY: Use explicit headers instead of wildcard to prevent header injection
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS or [],
    allow_credentials=True,
    allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
    allow_headers=['Content-Type', 'Authorization', 'X-XSRF-TOKEN', 'X-Requested-With', 'X-Request-ID', 'Accept', 'Accept-Language']
)
uploads_dir = Path(__file__).parent.parent / settings.UPLOAD_DIR
uploads_dir.mkdir(exist_ok=True)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(IntegrityError, integrity_error_handler)
@@ -108,18 +114,18 @@ app.add_exception_handler(Exception, general_exception_handler)
@app.get('/health', tags=['health'])
@app.get('/api/health', tags=['health'])
async def health_check(db: Session=Depends(get_db)):
    """
    Public health check endpoint.
    Returns minimal information for security - no sensitive details exposed.
    """
    health_status = {
        'status': 'healthy',
        'timestamp': datetime.utcnow().isoformat(),
        # SECURITY: Don't expose service name, version, or environment in public endpoint
        'checks': {
            'api': 'ok',
            'database': 'unknown'
            # SECURITY: Don't expose disk_space or memory details publicly
        }
    }
@@ -131,60 +137,26 @@ async def health_check(db: Session=Depends(get_db)):
    except OperationalError as e:
        health_status['status'] = 'unhealthy'
        health_status['checks']['database'] = 'error'
        # SECURITY: Don't expose database error details publicly
        logger.error(f'Database health check failed: {str(e)}')
        # Remove error details from response
        return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
    except Exception as e:
        health_status['status'] = 'unhealthy'
        health_status['checks']['database'] = 'error'
        # SECURITY: Don't expose error details publicly
        logger.error(f'Health check failed: {str(e)}')
        # Remove error details from response
        return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
    # SECURITY: Disk space and memory checks removed from public endpoint
    # These details should only be available on internal/admin health endpoint
    # Determine overall status
    if health_status['status'] == 'healthy' and any(
        check == 'error' for check in health_status['checks'].values()
    ):
        health_status['status'] = 'unhealthy'
    status_code = status.HTTP_200_OK
    if health_status['status'] == 'unhealthy':
@@ -195,8 +167,110 @@ async def health_check(db: Session=Depends(get_db)):
    return JSONResponse(status_code=status_code, content=health_status)
from .security.middleware.auth import authorize_roles
@app.get('/metrics', tags=['monitoring'])
async def metrics(
    # Only allow admin and staff to access metrics.
    # FastAPI resolves the authorize_roles dependency before the handler runs,
    # so unauthenticated or unauthorized callers are rejected automatically.
    current_user = Depends(authorize_roles('admin', 'staff'))
):
    """
    Protected metrics endpoint - requires admin or staff authentication.
    SECURITY: Prevents information disclosure to unauthorized users.
    """
return {
'status': 'success',
'service': settings.APP_NAME,
'version': settings.APP_VERSION,
'environment': settings.ENVIRONMENT,
'timestamp': datetime.utcnow().isoformat()
}
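A minimal sketch of the internal/admin health endpoint the comments above defer to; the route path and the psutil dependency are assumptions, not part of this commit.
import shutil
import psutil  # optional dependency, assumed available internally

@app.get('/health/details', tags=['health'])
async def health_details(current_user=Depends(authorize_roles('admin'))):
    disk = shutil.disk_usage('/')
    memory = psutil.virtual_memory()
    return {
        'disk_free_percent': round(disk.free / disk.total * 100, 2),
        'memory_used_percent': round(memory.percent, 2),
    }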
# Custom route for serving uploads with CORS headers
# This route takes precedence over the mount below
from fastapi.responses import FileResponse
import re
@app.options('/uploads/{file_path:path}')
async def serve_upload_file_options(file_path: str, request: Request):
"""Handle CORS preflight for upload files."""
origin = request.headers.get('origin')
if origin:
if settings.is_development:
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
return JSONResponse(
content={},
headers={
'Access-Control-Allow-Origin': origin,
'Access-Control-Allow-Credentials': 'true',
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
'Access-Control-Allow-Headers': '*',
'Access-Control-Max-Age': '3600'
}
)
elif origin in (settings.CORS_ORIGINS or []):
return JSONResponse(
content={},
headers={
'Access-Control-Allow-Origin': origin,
'Access-Control-Allow-Credentials': 'true',
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
'Access-Control-Allow-Headers': '*',
'Access-Control-Max-Age': '3600'
}
)
return JSONResponse(content={})
@app.get('/uploads/{file_path:path}')
@app.head('/uploads/{file_path:path}')
async def serve_upload_file(file_path: str, request: Request):
"""Serve uploaded files with proper CORS headers."""
file_location = uploads_dir / file_path
# Security: Prevent directory traversal
    try:
        resolved_path = file_location.resolve()
        resolved_uploads = uploads_dir.resolve()
        # Path.is_relative_to avoids prefix-collision bypasses (e.g. a sibling
        # directory named uploads2 passing a plain startswith() check)
        if not resolved_path.is_relative_to(resolved_uploads):
            raise HTTPException(status_code=403, detail="Access denied")
    except (ValueError, OSError):
        raise HTTPException(status_code=404, detail="File not found")
if not file_location.exists() or not file_location.is_file():
raise HTTPException(status_code=404, detail="File not found")
# Get origin from request
origin = request.headers.get('origin')
# Prepare response
response = FileResponse(str(file_location))
# Add CORS headers if origin matches
if origin:
if settings.is_development:
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
response.headers['Access-Control-Expose-Headers'] = '*'
elif origin in (settings.CORS_ORIGINS or []):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
response.headers['Access-Control-Expose-Headers'] = '*'
return response
# Mount static files as fallback (routes take precedence)
from starlette.staticfiles import StaticFiles
app.mount('/uploads-static', StaticFiles(directory=str(uploads_dir)), name='uploads-static')
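Exercising the preflight handler above from a client; the URL, file path, and dev origin are illustrative assumptions.
import httpx

resp = httpx.options(
    'http://localhost:8000/uploads/rooms/101.jpg',   # hypothetical file
    headers={'Origin': 'http://localhost:5173'},
)
print(resp.headers.get('Access-Control-Allow-Origin'))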
# Import all route modules from feature-based structure
from .auth.routes import auth_routes, user_routes
from .rooms.routes import room_routes, advanced_room_routes, rate_plan_routes
@@ -219,6 +293,7 @@ from .security.routes import security_routes, compliance_routes
from .system.routes import system_settings_routes, workflow_routes, task_routes, approval_routes, backup_routes
from .ai.routes import ai_assistant_routes
from .compliance.routes import gdpr_routes
from .compliance.routes.gdpr_admin_routes import router as gdpr_admin_routes
from .integrations.routes import webhook_routes, api_key_routes
from .auth.routes import session_routes
@@ -274,6 +349,7 @@ app.include_router(blog_routes.router, prefix=api_prefix)
app.include_router(ai_assistant_routes.router, prefix=api_prefix)
app.include_router(approval_routes.router, prefix=api_prefix)
app.include_router(gdpr_routes.router, prefix=api_prefix)
app.include_router(gdpr_admin_routes, prefix=api_prefix)
app.include_router(webhook_routes.router, prefix=api_prefix)
app.include_router(api_key_routes.router, prefix=api_prefix)
app.include_router(session_routes.router, prefix=api_prefix)
@@ -281,57 +357,38 @@ app.include_router(backup_routes.router, prefix=api_prefix)
logger.info('All routes registered successfully')
def ensure_jwt_secret():
    """
    Validate JWT secret is properly configured.
    SECURITY: JWT_SECRET must be explicitly set via environment variable.
    No default values are acceptable for security.
    """
    current_secret = settings.JWT_SECRET
    # SECURITY: JWT_SECRET validation is now handled in settings.py
    # This function is kept for backward compatibility and logging
    if not current_secret or current_secret.strip() == '':
        if settings.is_production:
            # This should not happen as settings validation should catch it
            error_msg = (
                'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
                'Please set JWT_SECRET environment variable before starting the application.'
            )
            logger.error(error_msg)
            raise ValueError(error_msg)
        else:
            logger.warning(
                'JWT_SECRET is not configured. Authentication will fail. '
                'Set JWT_SECRET environment variable before starting the application.'
            )
    else:
        # Validate secret strength
        if len(current_secret) < 64:
            if settings.is_production:
                logger.warning(
                    f'JWT_SECRET is only {len(current_secret)} characters. '
                    'Recommend using at least 64 characters for production security.'
                )
        logger.info('✓ JWT secret is configured')
@app.on_event('startup')
@@ -375,7 +432,34 @@ async def shutdown_event():
logger.info(f'{settings.APP_NAME} shutting down gracefully')
if __name__ == '__main__':
    import uvicorn
+    import os
+    import signal
+    import sys
    from pathlib import Path

+    def signal_handler(sig, frame):
+        """Handle Ctrl+C gracefully."""
+        logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
+        sys.exit(0)
+
+    # Register signal handler for graceful shutdown on Ctrl+C
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
    base_dir = Path(__file__).parent.parent
    src_dir = str(base_dir / 'src')
-    uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=settings.is_development, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if settings.is_development else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=0.5)
+    # Enable hot reload in development mode or if explicitly enabled via environment variable
+    use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
+    if use_reload:
+        logger.info('Hot reload enabled - server will restart on code changes')
+        logger.info('Press Ctrl+C to stop the server')
+    uvicorn.run(
+        'src.main:app',
+        host=settings.HOST,
+        port=settings.PORT,
+        reload=use_reload,
+        log_level=settings.LOG_LEVEL.lower(),
+        reload_dirs=[src_dir] if use_reload else None,
+        reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
+        reload_delay=0.5
+    )
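
The launcher above assumes JWT_SECRET is provisioned before startup. A minimal pre-launch sketch (the env-var handling here is illustrative, not part of this commit):

import os
import secrets

# Hypothetical one-off provisioning step mirroring the validation above:
# token_urlsafe(64) yields ~86 URL-safe characters, comfortably above the
# 64-character minimum that settings.py warns about.
if not os.getenv('JWT_SECRET'):
    print(f"export JWT_SECRET={secrets.token_urlsafe(64)}")
# Afterwards, start src/main.py; set ENABLE_RELOAD=true to opt into hot
# reload outside of development mode, as handled above.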

View File

@@ -174,10 +174,13 @@ class BoricaService:
                backend=default_backend()
            )
+            # NOTE: SHA1 is required by Borica payment gateway protocol
+            # This is a known security trade-off required for payment gateway compatibility
+            # Monitor for Borica protocol updates that support stronger algorithms
            signature = private_key.sign(
                data.encode('utf-8'),
                padding.PKCS1v15(),
-                hashes.SHA1()
+                hashes.SHA1()  # nosec B303 # Required by Borica protocol - acceptable risk
            )
            return base64.b64encode(signature).decode('utf-8')
        except Exception as e:
@@ -228,11 +231,13 @@ class BoricaService:
            public_key = cert.public_key()
            signature_bytes = base64.b64decode(signature)
+            # NOTE: SHA1 is required by Borica payment gateway protocol
+            # This is a known security trade-off required for payment gateway compatibility
            public_key.verify(
                signature_bytes,
                signature_data.encode('utf-8'),
                padding.PKCS1v15(),
-                hashes.SHA1()
+                hashes.SHA1()  # nosec B303 # Required by Borica protocol - acceptable risk
            )
            return True
        except Exception as e:
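
For context, a self-contained sketch of the PKCS#1 v1.5 / SHA-1 round trip this service performs (the generated key pair is illustrative; BoricaService loads Borica-issued key material from configuration):

import base64
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

# Illustrative throwaway key pair - the real service uses configured files
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()

data = 'TERMINAL=12345&TRTYPE=1&AMOUNT=10.00'.encode('utf-8')
signature = private_key.sign(data, padding.PKCS1v15(), hashes.SHA1())  # nosec B303
encoded = base64.b64encode(signature).decode('utf-8')

# verify() raises InvalidSignature on mismatch rather than returning False
public_key.verify(base64.b64decode(encoded), data, padding.PKCS1v15(), hashes.SHA1())  # nosec B303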

View File

@@ -10,7 +10,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'}
-        security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
+        # Allow cross-origin resource sharing for uploads/images
+        # This is needed for images to load from different origins in development
+        if '/uploads/' in str(request.url):
+            security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
+        else:
+            security_headers.setdefault('Cross-Origin-Resource-Policy', 'same-origin')
        if settings.is_production:
            # Enhanced CSP with stricter directives
            # Using 'strict-dynamic' for better security with nonce-based scripts
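
A rough behavioural check of the branch above - a sketch assuming the middleware is registered on a FastAPI app (routes and client usage are hypothetical):

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
app.add_middleware(SecurityHeadersMiddleware)  # the class defined above

@app.get('/uploads/logo.png')
def uploaded_image():
    return {}

@app.get('/api/v1/rooms')
def rooms():
    return {}

client = TestClient(app)
assert client.get('/uploads/logo.png').headers['Cross-Origin-Resource-Policy'] == 'cross-origin'
assert client.get('/api/v1/rooms').headers['Cross-Origin-Resource-Policy'] == 'same-origin'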

View File

@@ -10,14 +10,14 @@ class Settings(BaseSettings):
    ENVIRONMENT: str = Field(default='development', description='Environment: development, staging, production')
    DEBUG: bool = Field(default=False, description='Debug mode')
    API_V1_PREFIX: str = Field(default='/api/v1', description='API v1 prefix')
-    HOST: str = Field(default='0.0.0.0', description='Server host')
+    HOST: str = Field(default='0.0.0.0', description='Server host. WARNING: 0.0.0.0 binds to all interfaces. Use 127.0.0.1 for development or specific IP for production.')  # nosec B104 # Acceptable default with validation warning in production
    PORT: int = Field(default=8000, description='Server port')
    DB_USER: str = Field(default='root', description='Database user')
    DB_PASS: str = Field(default='', description='Database password')
    DB_NAME: str = Field(default='hotel_db', description='Database name')
    DB_HOST: str = Field(default='localhost', description='Database host')
    DB_PORT: str = Field(default='3306', description='Database port')
-    JWT_SECRET: str = Field(default='dev-secret-key-change-in-production-12345', description='JWT secret key')
+    JWT_SECRET: str = Field(default='', description='JWT secret key - MUST be set via environment variable. Minimum 64 characters recommended for production.')
    JWT_ALGORITHM: str = Field(default='HS256', description='JWT algorithm')
    JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30, description='JWT access token expiration in minutes')
    JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=3, description='JWT refresh token expiration in days (reduced from 7 for better security)')
@@ -97,6 +97,20 @@ class Settings(BaseSettings):
    IP_WHITELIST_ENABLED: bool = Field(default=False, description='Enable IP whitelisting for admin endpoints')
    ADMIN_IP_WHITELIST: List[str] = Field(default_factory=list, description='List of allowed IP addresses/CIDR ranges for admin endpoints')
+    def validate_host_configuration(self) -> None:
+        """
+        Validate HOST configuration for security.
+        Warns if binding to all interfaces (0.0.0.0) in production.
+        """
+        if self.HOST == '0.0.0.0' and self.is_production:
+            import logging
+            logger = logging.getLogger(__name__)
+            logger.warning(
+                'SECURITY WARNING: HOST is set to 0.0.0.0 in production. '
+                'This binds the server to all network interfaces. '
+                'Consider using a specific IP address or ensure proper firewall rules are in place.'
+            )
    def validate_encryption_key(self) -> None:
        """
        Validate encryption key is properly configured.
@@ -138,4 +152,41 @@ class Settings(BaseSettings):
        logger = logging.getLogger(__name__)
        logger.warning(f'Invalid ENCRYPTION_KEY format: {str(e)}')
settings = Settings()
+# Validate JWT_SECRET on startup - fail fast if not configured
+def validate_jwt_secret():
+    """Validate JWT_SECRET is properly configured. Called on startup."""
+    if not settings.JWT_SECRET or settings.JWT_SECRET.strip() == '':
+        error_msg = (
+            'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
+            'Please set JWT_SECRET environment variable to a secure random string. '
+            'Minimum 64 characters recommended for production. '
+            'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
+        )
+        import logging
+        logger = logging.getLogger(__name__)
+        logger.error(error_msg)
+        if settings.is_production:
+            raise ValueError(error_msg)
+        else:
+            logger.warning(
+                'JWT_SECRET not configured. This will cause authentication to fail. '
+                'Set JWT_SECRET environment variable before starting the application.'
+            )
+    # Warn if using weak secret (less than 64 characters)
+    if len(settings.JWT_SECRET) < 64:
+        import logging
+        logger = logging.getLogger(__name__)
+        if settings.is_production:
+            logger.warning(
+                f'JWT_SECRET is only {len(settings.JWT_SECRET)} characters. '
+                'Recommend using at least 64 characters for production security.'
+            )
+        else:
+            logger.debug(f'JWT_SECRET length: {len(settings.JWT_SECRET)} characters')
+
+# Validate on import
+validate_jwt_secret()
+settings.validate_host_configuration()
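
The generator command referenced in the error message, spelled out (output differs on every run; anything this long passes the 64-character check above):

import secrets
print(secrets.token_urlsafe(64))  # ~86 URL-safe characters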

View File

@@ -0,0 +1,168 @@
"""
HTML/XSS sanitization utilities using bleach library.
Prevents stored XSS attacks by sanitizing user-generated content.
"""
import bleach
from typing import Optional
# Allowed HTML tags for rich text content
ALLOWED_TAGS = [
'p', 'br', 'strong', 'em', 'u', 'b', 'i', 's', 'strike',
'a', 'ul', 'ol', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
'blockquote', 'pre', 'code', 'hr', 'div', 'span',
'table', 'thead', 'tbody', 'tr', 'th', 'td',
'img'
]
# Allowed attributes for specific tags
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'target', 'rel'],
'img': ['src', 'alt', 'title', 'width', 'height'],
'div': ['class'],
'span': ['class'],
'p': ['class'],
'table': ['class', 'border'],
'th': ['colspan', 'rowspan'],
'td': ['colspan', 'rowspan']
}
# Allowed URL schemes
ALLOWED_PROTOCOLS = ['http', 'https', 'mailto']
# Allowed inline CSS styles (none allowed by default - can be expanded)
ALLOWED_STYLES = []
def sanitize_html(content: Optional[str], strip: bool = False) -> str:
"""
Sanitize HTML content to prevent XSS attacks.
Args:
content: The HTML content to sanitize (can be None)
strip: If True, remove disallowed tags instead of escaping them
Returns:
Sanitized HTML string
"""
if not content:
return ''
if not isinstance(content, str):
content = str(content)
# Sanitize HTML
sanitized = bleach.clean(
content,
tags=ALLOWED_TAGS,
attributes=ALLOWED_ATTRIBUTES,
protocols=ALLOWED_PROTOCOLS,
strip=strip,
strip_comments=True
)
    # Linkify URLs (convert plain URLs to links)
    # Only linkify if content doesn't already contain HTML links
    if '<a' not in sanitized:
        # Note: bleach.linkify() does not accept a protocols argument;
        # scheme filtering is already handled by bleach.clean() above
        sanitized = bleach.linkify(
            sanitized,
            parse_email=True
        )
return sanitized
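
# Hypothetical usage sketch (illustrative, not part of this module):
#   sanitize_html('<p onclick="x()">Hi</p><script>alert(1)</script>')
#   -> '<p>Hi</p>&lt;script&gt;alert(1)&lt;/script&gt;' (approximately:
#      the onclick attribute is dropped, the script tag is escaped)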
def sanitize_text(content: Optional[str]) -> str:
"""
Strip all HTML tags from content, leaving only plain text.
Useful for fields that should not contain any HTML.
Args:
content: The content to sanitize (can be None)
Returns:
Plain text string with all HTML removed
"""
if not content:
return ''
if not isinstance(content, str):
content = str(content)
# Strip all HTML tags
return bleach.clean(content, tags=[], strip=True)
def sanitize_filename(filename: str) -> str:
"""
Sanitize filename to prevent path traversal and other attacks.
Args:
filename: The original filename
Returns:
Sanitized filename safe for filesystem operations
"""
import os
import secrets
from pathlib import Path
if not filename:
# Generate a random filename if none provided
return f"{secrets.token_urlsafe(16)}.bin"
# Remove path components (prevent directory traversal)
filename = os.path.basename(filename)
# Remove dangerous characters
# Keep only alphanumeric, dots, dashes, and underscores
safe_chars = []
for char in filename:
if char.isalnum() or char in '._-':
safe_chars.append(char)
else:
safe_chars.append('_')
filename = ''.join(safe_chars)
# Limit length (filesystem limit is typically 255)
if len(filename) > 255:
name, ext = os.path.splitext(filename)
max_name_length = 255 - len(ext)
filename = name[:max_name_length] + ext
# Ensure filename is not empty
if not filename or filename == '.' or filename == '..':
filename = f"{secrets.token_urlsafe(16)}.bin"
return filename
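
# Hypothetical usage sketch (illustrative, not part of this module):
#   sanitize_filename('../../etc/passwd')  -> 'passwd' (path components stripped)
#   sanitize_filename('my photo (1).png')  -> 'my_photo__1_.png'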
def sanitize_url(url: Optional[str]) -> Optional[str]:
"""
Sanitize URL to ensure it uses allowed protocols.
Args:
url: The URL to sanitize
Returns:
Sanitized URL or None if invalid
"""
if not url:
return None
if not isinstance(url, str):
url = str(url)
    # Check if URL uses an allowed protocol
    url_lower = url.lower().strip()
    if any(url_lower.startswith(proto + ':') for proto in ALLOWED_PROTOCOLS):
        return url
    # Explicitly reject dangerous schemes, which carry no '://' and would
    # otherwise be prefixed with https below (e.g. 'javascript:alert(1)')
    if url_lower.startswith(('javascript:', 'data:', 'vbscript:', 'file:', 'blob:')):
        return None
    # If no protocol, assume https
    if '://' not in url:
        return f'https://{url}'
    # Invalid protocol - return None
    return None
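
# Hypothetical usage sketch (illustrative, not part of this module):
#   sanitize_url('example.com/rooms')    -> 'https://example.com/rooms'
#   sanitize_url('javascript:alert(1)')  -> None (rejected by the scheme check)
#   sanitize_url('mailto:guest@example.com') -> returned unchanged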

Backend/venv/bin/bandit Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.main import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.baseline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.config_generator import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

Backend/venv/bin/doesitcache Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from cachecontrol._cmd import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

Backend/venv/bin/fastapi Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from fastapi.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

Backend/venv/bin/markdown-it Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from markdown_it.cli.parse import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

Backend/venv/bin/nltk Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from nltk.cli import cli
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(cli())

Backend/venv/bin/pip-audit Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pip_audit._cli import audit
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(audit())

Backend/venv/bin/safety Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from safety.cli import cli
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(cli())

Backend/venv/bin/tqdm Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from tqdm.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

Backend/venv/bin/typer Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from typer.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -1,59 +0,0 @@
Jinja2-3.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
Jinja2-3.1.2.dist-info/LICENSE.rst,sha256=O0nc7kEF6ze6wQ-vG-JgQI_oXSUrjp3y4JefweCUQ3s,1475
Jinja2-3.1.2.dist-info/METADATA,sha256=PZ6v2SIidMNixR7MRUX9f7ZWsPwtXanknqiZUmRbh4U,3539
Jinja2-3.1.2.dist-info/RECORD,,
Jinja2-3.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
Jinja2-3.1.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
Jinja2-3.1.2.dist-info/entry_points.txt,sha256=zRd62fbqIyfUpsRtU7EVIFyiu1tPwfgO7EvPErnxgTE,59
Jinja2-3.1.2.dist-info/top_level.txt,sha256=PkeVWtLb3-CqjWi1fO29OCbj55EhX_chhKrCdrVe_zs,7
jinja2/__init__.py,sha256=8vGduD8ytwgD6GDSqpYc2m3aU-T7PKOAddvVXgGr_Fs,1927
jinja2/__pycache__/__init__.cpython-312.pyc,,
jinja2/__pycache__/_identifier.cpython-312.pyc,,
jinja2/__pycache__/async_utils.cpython-312.pyc,,
jinja2/__pycache__/bccache.cpython-312.pyc,,
jinja2/__pycache__/compiler.cpython-312.pyc,,
jinja2/__pycache__/constants.cpython-312.pyc,,
jinja2/__pycache__/debug.cpython-312.pyc,,
jinja2/__pycache__/defaults.cpython-312.pyc,,
jinja2/__pycache__/environment.cpython-312.pyc,,
jinja2/__pycache__/exceptions.cpython-312.pyc,,
jinja2/__pycache__/ext.cpython-312.pyc,,
jinja2/__pycache__/filters.cpython-312.pyc,,
jinja2/__pycache__/idtracking.cpython-312.pyc,,
jinja2/__pycache__/lexer.cpython-312.pyc,,
jinja2/__pycache__/loaders.cpython-312.pyc,,
jinja2/__pycache__/meta.cpython-312.pyc,,
jinja2/__pycache__/nativetypes.cpython-312.pyc,,
jinja2/__pycache__/nodes.cpython-312.pyc,,
jinja2/__pycache__/optimizer.cpython-312.pyc,,
jinja2/__pycache__/parser.cpython-312.pyc,,
jinja2/__pycache__/runtime.cpython-312.pyc,,
jinja2/__pycache__/sandbox.cpython-312.pyc,,
jinja2/__pycache__/tests.cpython-312.pyc,,
jinja2/__pycache__/utils.cpython-312.pyc,,
jinja2/__pycache__/visitor.cpython-312.pyc,,
jinja2/_identifier.py,sha256=_zYctNKzRqlk_murTNlzrju1FFJL7Va_Ijqqd7ii2lU,1958
jinja2/async_utils.py,sha256=dHlbTeaxFPtAOQEYOGYh_PHcDT0rsDaUJAFDl_0XtTg,2472
jinja2/bccache.py,sha256=mhz5xtLxCcHRAa56azOhphIAe19u1we0ojifNMClDio,14061
jinja2/compiler.py,sha256=Gs-N8ThJ7OWK4-reKoO8Wh1ZXz95MVphBKNVf75qBr8,72172
jinja2/constants.py,sha256=GMoFydBF_kdpaRKPoM5cl5MviquVRLVyZtfp5-16jg0,1433
jinja2/debug.py,sha256=iWJ432RadxJNnaMOPrjIDInz50UEgni3_HKuFXi2vuQ,6299
jinja2/defaults.py,sha256=boBcSw78h-lp20YbaXSJsqkAI2uN_mD_TtCydpeq5wU,1267
jinja2/environment.py,sha256=6uHIcc7ZblqOMdx_uYNKqRnnwAF0_nzbyeMP9FFtuh4,61349
jinja2/exceptions.py,sha256=ioHeHrWwCWNaXX1inHmHVblvc4haO7AXsjCp3GfWvx0,5071
jinja2/ext.py,sha256=ivr3P7LKbddiXDVez20EflcO3q2aHQwz9P_PgWGHVqE,31502
jinja2/filters.py,sha256=9js1V-h2RlyW90IhLiBGLM2U-k6SCy2F4BUUMgB3K9Q,53509
jinja2/idtracking.py,sha256=GfNmadir4oDALVxzn3DL9YInhJDr69ebXeA2ygfuCGA,10704
jinja2/lexer.py,sha256=DW2nX9zk-6MWp65YR2bqqj0xqCvLtD-u9NWT8AnFRxQ,29726
jinja2/loaders.py,sha256=BfptfvTVpClUd-leMkHczdyPNYFzp_n7PKOJ98iyHOg,23207
jinja2/meta.py,sha256=GNPEvifmSaU3CMxlbheBOZjeZ277HThOPUTf1RkppKQ,4396
jinja2/nativetypes.py,sha256=DXgORDPRmVWgy034H0xL8eF7qYoK3DrMxs-935d0Fzk,4226
jinja2/nodes.py,sha256=i34GPRAZexXMT6bwuf5SEyvdmS-bRCy9KMjwN5O6pjk,34550
jinja2/optimizer.py,sha256=tHkMwXxfZkbfA1KmLcqmBMSaz7RLIvvItrJcPoXTyD8,1650
jinja2/parser.py,sha256=nHd-DFHbiygvfaPtm9rcQXJChZG7DPsWfiEsqfwKerY,39595
jinja2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
jinja2/runtime.py,sha256=5CmD5BjbEJxSiDNTFBeKCaq8qU4aYD2v6q2EluyExms,33476
jinja2/sandbox.py,sha256=Y0xZeXQnH6EX5VjaV2YixESxoepnRbW_3UeQosaBU3M,14584
jinja2/tests.py,sha256=Am5Z6Lmfr2XaH_npIfJJ8MdXtWsbLjMULZJulTAj30E,5905
jinja2/utils.py,sha256=u9jXESxGn8ATZNVolwmkjUVu4SA-tLgV0W7PcSfPfdQ,23965
jinja2/visitor.py,sha256=MH14C6yq24G_KVtWzjwaI7Wg14PCJIYlWW1kpkxYak0,3568

View File

@@ -1,2 +0,0 @@
[babel.extractors]
jinja2 = jinja2.ext:babel_extract[i18n]

View File

@@ -0,0 +1,291 @@
from __future__ import annotations
import os
from io import BytesIO
from typing import IO
from . import ExifTags, Image, ImageFile
try:
from . import _avif
SUPPORTED = True
except ImportError:
SUPPORTED = False
# Decoder options as module globals, until there is a way to pass parameters
# to Image.open (see https://github.com/python-pillow/Pillow/issues/569)
DECODE_CODEC_CHOICE = "auto"
DEFAULT_MAX_THREADS = 0
def get_codec_version(codec_name: str) -> str | None:
versions = _avif.codec_versions()
for version in versions.split(", "):
if version.split(" [")[0] == codec_name:
return version.split(":")[-1].split(" ")[0]
return None
def _accept(prefix: bytes) -> bool | str:
if prefix[4:8] != b"ftyp":
return False
major_brand = prefix[8:12]
if major_brand in (
# coding brands
b"avif",
b"avis",
# We accept files with AVIF container brands; we can't yet know if
# the ftyp box has the correct compatible brands, but if it doesn't
# then the plugin will raise a SyntaxError which Pillow will catch
# before moving on to the next plugin that accepts the file.
#
# Also, because this file might not actually be an AVIF file, we
# don't raise an error if AVIF support isn't properly compiled.
b"mif1",
b"msf1",
):
if not SUPPORTED:
return (
"image file could not be identified because AVIF support not installed"
)
return True
return False
def _get_default_max_threads() -> int:
if DEFAULT_MAX_THREADS:
return DEFAULT_MAX_THREADS
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))
else:
return os.cpu_count() or 1
class AvifImageFile(ImageFile.ImageFile):
format = "AVIF"
format_description = "AVIF image"
__frame = -1
def _open(self) -> None:
if not SUPPORTED:
msg = "image file could not be opened because AVIF support not installed"
raise SyntaxError(msg)
if DECODE_CODEC_CHOICE != "auto" and not _avif.decoder_codec_available(
DECODE_CODEC_CHOICE
):
msg = "Invalid opening codec"
raise ValueError(msg)
self._decoder = _avif.AvifDecoder(
self.fp.read(),
DECODE_CODEC_CHOICE,
_get_default_max_threads(),
)
# Get info from decoder
self._size, self.n_frames, self._mode, icc, exif, exif_orientation, xmp = (
self._decoder.get_info()
)
self.is_animated = self.n_frames > 1
if icc:
self.info["icc_profile"] = icc
if xmp:
self.info["xmp"] = xmp
if exif_orientation != 1 or exif:
exif_data = Image.Exif()
if exif:
exif_data.load(exif)
original_orientation = exif_data.get(ExifTags.Base.Orientation, 1)
else:
original_orientation = 1
if exif_orientation != original_orientation:
exif_data[ExifTags.Base.Orientation] = exif_orientation
exif = exif_data.tobytes()
if exif:
self.info["exif"] = exif
self.seek(0)
def seek(self, frame: int) -> None:
if not self._seek_check(frame):
return
# Set tile
self.__frame = frame
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
def load(self) -> Image.core.PixelAccess | None:
if self.tile:
# We need to load the image data for this frame
data, timescale, pts_in_timescales, duration_in_timescales = (
self._decoder.get_frame(self.__frame)
)
self.info["timestamp"] = round(1000 * (pts_in_timescales / timescale))
self.info["duration"] = round(1000 * (duration_in_timescales / timescale))
if self.fp and self._exclusive_fp:
self.fp.close()
self.fp = BytesIO(data)
return super().load()
def load_seek(self, pos: int) -> None:
pass
def tell(self) -> int:
return self.__frame
def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, save_all=True)
def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
) -> None:
info = im.encoderinfo.copy()
if save_all:
append_images = list(info.get("append_images", []))
else:
append_images = []
total = 0
for ims in [im] + append_images:
total += getattr(ims, "n_frames", 1)
quality = info.get("quality", 75)
if not isinstance(quality, int) or quality < 0 or quality > 100:
msg = "Invalid quality setting"
raise ValueError(msg)
duration = info.get("duration", 0)
subsampling = info.get("subsampling", "4:2:0")
speed = info.get("speed", 6)
max_threads = info.get("max_threads", _get_default_max_threads())
codec = info.get("codec", "auto")
if codec != "auto" and not _avif.encoder_codec_available(codec):
msg = "Invalid saving codec"
raise ValueError(msg)
range_ = info.get("range", "full")
tile_rows_log2 = info.get("tile_rows", 0)
tile_cols_log2 = info.get("tile_cols", 0)
alpha_premultiplied = bool(info.get("alpha_premultiplied", False))
autotiling = bool(info.get("autotiling", tile_rows_log2 == tile_cols_log2 == 0))
icc_profile = info.get("icc_profile", im.info.get("icc_profile"))
exif_orientation = 1
if exif := info.get("exif"):
if isinstance(exif, Image.Exif):
exif_data = exif
else:
exif_data = Image.Exif()
exif_data.load(exif)
if ExifTags.Base.Orientation in exif_data:
exif_orientation = exif_data.pop(ExifTags.Base.Orientation)
exif = exif_data.tobytes() if exif_data else b""
elif isinstance(exif, Image.Exif):
exif = exif_data.tobytes()
xmp = info.get("xmp")
if isinstance(xmp, str):
xmp = xmp.encode("utf-8")
advanced = info.get("advanced")
if advanced is not None:
if isinstance(advanced, dict):
advanced = advanced.items()
try:
advanced = tuple(advanced)
except TypeError:
invalid = True
else:
invalid = any(not isinstance(v, tuple) or len(v) != 2 for v in advanced)
if invalid:
msg = (
"advanced codec options must be a dict of key-value string "
"pairs or a series of key-value two-tuples"
)
raise ValueError(msg)
# Setup the AVIF encoder
enc = _avif.AvifEncoder(
im.size,
subsampling,
quality,
speed,
max_threads,
codec,
range_,
tile_rows_log2,
tile_cols_log2,
alpha_premultiplied,
autotiling,
icc_profile or b"",
exif or b"",
exif_orientation,
xmp or b"",
advanced,
)
# Add each frame
frame_idx = 0
frame_duration = 0
cur_idx = im.tell()
is_single_frame = total == 1
try:
for ims in [im] + append_images:
# Get number of frames in this image
nfr = getattr(ims, "n_frames", 1)
for idx in range(nfr):
ims.seek(idx)
# Make sure image mode is supported
frame = ims
rawmode = ims.mode
if ims.mode not in {"RGB", "RGBA"}:
rawmode = "RGBA" if ims.has_transparency_data else "RGB"
frame = ims.convert(rawmode)
# Update frame duration
if isinstance(duration, (list, tuple)):
frame_duration = duration[frame_idx]
else:
frame_duration = duration
# Append the frame to the animation encoder
enc.add(
frame.tobytes("raw", rawmode),
frame_duration,
frame.size,
rawmode,
is_single_frame,
)
# Update frame index
frame_idx += 1
if not save_all:
break
finally:
im.seek(cur_idx)
# Get the final output from the encoder
data = enc.finish()
if data is None:
msg = "cannot write file as AVIF (encoder returned None)"
raise OSError(msg)
fp.write(data)
Image.register_open(AvifImageFile.format, AvifImageFile, _accept)
if SUPPORTED:
Image.register_save(AvifImageFile.format, _save)
Image.register_save_all(AvifImageFile.format, _save_all)
Image.register_extensions(AvifImageFile.format, [".avif", ".avifs"])
Image.register_mime(AvifImageFile.format, "image/avif")
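
For orientation, a minimal sketch of exercising this plugin (assumes a Pillow build where the optional _avif extension imported successfully; filenames are hypothetical):

from PIL import Image, features

if features.check("avif"):  # False when the _avif extension is missing
    with Image.open("photo.png") as im:
        im.save("photo.avif", quality=60, speed=8)  # handled by _save above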

View File

@@ -20,29 +20,30 @@
""" """
Parse X Bitmap Distribution Format (BDF) Parse X Bitmap Distribution Format (BDF)
""" """
from __future__ import annotations
from typing import BinaryIO
from . import FontFile, Image from . import FontFile, Image
bdf_slant = {
"R": "Roman",
"I": "Italic",
"O": "Oblique",
"RI": "Reverse Italic",
"RO": "Reverse Oblique",
"OT": "Other",
}
bdf_spacing = {"P": "Proportional", "M": "Monospaced", "C": "Cell"} def bdf_char(
f: BinaryIO,
) -> (
def bdf_char(f): tuple[
str,
int,
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]],
Image.Image,
]
| None
):
# skip to STARTCHAR # skip to STARTCHAR
while True: while True:
s = f.readline() s = f.readline()
if not s: if not s:
return None return None
if s[:9] == b"STARTCHAR": if s.startswith(b"STARTCHAR"):
break break
id = s[9:].strip().decode("ascii") id = s[9:].strip().decode("ascii")
@@ -50,19 +51,18 @@ def bdf_char(f):
     props = {}
     while True:
         s = f.readline()
-        if not s or s[:6] == b"BITMAP":
+        if not s or s.startswith(b"BITMAP"):
             break
         i = s.find(b" ")
         props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")

     # load bitmap
-    bitmap = []
+    bitmap = bytearray()
     while True:
         s = f.readline()
-        if not s or s[:7] == b"ENDCHAR":
+        if not s or s.startswith(b"ENDCHAR"):
             break
-        bitmap.append(s[:-1])
-    bitmap = b"".join(bitmap)
+        bitmap += s[:-1]

     # The word BBX
     # followed by the width in x (BBw), height in y (BBh),
@@ -92,11 +92,11 @@ def bdf_char(f):
 class BdfFontFile(FontFile.FontFile):
     """Font file plugin for the X11 BDF format."""

-    def __init__(self, fp):
+    def __init__(self, fp: BinaryIO) -> None:
         super().__init__()

         s = fp.readline()
-        if s[:13] != b"STARTFONT 2.1":
+        if not s.startswith(b"STARTFONT 2.1"):
             msg = "not a valid BDF file"
             raise SyntaxError(msg)
@@ -105,7 +105,7 @@ class BdfFontFile(FontFile.FontFile):
         while True:
             s = fp.readline()
-            if not s or s[:13] == b"ENDPROPERTIES":
+            if not s or s.startswith(b"ENDPROPERTIES"):
                 break
             i = s.find(b" ")
             props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")
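
For context, the usual way this parser is driven - converting a BDF font so ImageFont can load it (paths are hypothetical):

from PIL import BdfFontFile, ImageFont

with open("courB12.bdf", "rb") as fp:        # hypothetical X11 BDF font
    font_file = BdfFontFile.BdfFontFile(fp)  # parses glyphs via bdf_char() above
font_file.save("courB12")                    # writes courB12.pil + courB12.pbm
font = ImageFont.load("courB12.pil")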

View File

@@ -29,10 +29,14 @@ BLP files come in many different flavours:
 - DXT5 compression is used if alpha_encoding == 7.
 """
+from __future__ import annotations

+import abc
 import os
 import struct
 from enum import IntEnum
 from io import BytesIO
+from typing import IO

 from . import Image, ImageFile
@@ -53,11 +57,13 @@ class AlphaEncoding(IntEnum):
     DXT5 = 7

-def unpack_565(i):
+def unpack_565(i: int) -> tuple[int, int, int]:
     return ((i >> 11) & 0x1F) << 3, ((i >> 5) & 0x3F) << 2, (i & 0x1F) << 3

-def decode_dxt1(data, alpha=False):
+def decode_dxt1(
+    data: bytes, alpha: bool = False
+) -> tuple[bytearray, bytearray, bytearray, bytearray]:
     """
     input: one "row" of data (i.e. will produce 4*width pixels)
     """
@@ -65,9 +71,9 @@ def decode_dxt1(data, alpha=False):
     blocks = len(data) // 8  # number of blocks in row
     ret = (bytearray(), bytearray(), bytearray(), bytearray())

-    for block in range(blocks):
+    for block_index in range(blocks):
         # Decode next 8-byte block.
-        idx = block * 8
+        idx = block_index * 8
         color0, color1, bits = struct.unpack_from("<HHI", data, idx)

         r0, g0, b0 = unpack_565(color0)
@@ -112,7 +118,7 @@ def decode_dxt1(data, alpha=False):
     return ret

-def decode_dxt3(data):
+def decode_dxt3(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
     """
     input: one "row" of data (i.e. will produce 4*width pixels)
     """
@@ -120,8 +126,8 @@ def decode_dxt3(data):
     blocks = len(data) // 16  # number of blocks in row
     ret = (bytearray(), bytearray(), bytearray(), bytearray())

-    for block in range(blocks):
-        idx = block * 16
+    for block_index in range(blocks):
+        idx = block_index * 16
         block = data[idx : idx + 16]
         # Decode next 16-byte block.
         bits = struct.unpack_from("<8B", block)
@@ -165,7 +171,7 @@ def decode_dxt3(data):
     return ret

-def decode_dxt5(data):
+def decode_dxt5(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
     """
     input: one "row" of data (i.e. will produce 4 * width pixels)
     """
@@ -173,8 +179,8 @@ def decode_dxt5(data):
     blocks = len(data) // 16  # number of blocks in row
     ret = (bytearray(), bytearray(), bytearray(), bytearray())

-    for block in range(blocks):
-        idx = block * 16
+    for block_index in range(blocks):
+        idx = block_index * 16
         block = data[idx : idx + 16]
         # Decode next 16-byte block.
         a0, a1 = struct.unpack_from("<BB", block)
@@ -239,8 +245,8 @@ class BLPFormatError(NotImplementedError):
     pass

-def _accept(prefix):
-    return prefix[:4] in (b"BLP1", b"BLP2")
+def _accept(prefix: bytes) -> bool:
+    return prefix.startswith((b"BLP1", b"BLP2"))

 class BlpImageFile(ImageFile.ImageFile):
@@ -251,60 +257,65 @@ class BlpImageFile(ImageFile.ImageFile):
format = "BLP" format = "BLP"
format_description = "Blizzard Mipmap Format" format_description = "Blizzard Mipmap Format"
def _open(self): def _open(self) -> None:
self.magic = self.fp.read(4) self.magic = self.fp.read(4)
if not _accept(self.magic):
self.fp.seek(5, os.SEEK_CUR)
(self._blp_alpha_depth,) = struct.unpack("<b", self.fp.read(1))
self.fp.seek(2, os.SEEK_CUR)
self._size = struct.unpack("<II", self.fp.read(8))
if self.magic in (b"BLP1", b"BLP2"):
decoder = self.magic.decode()
else:
msg = f"Bad BLP magic {repr(self.magic)}" msg = f"Bad BLP magic {repr(self.magic)}"
raise BLPFormatError(msg) raise BLPFormatError(msg)
self._mode = "RGBA" if self._blp_alpha_depth else "RGB" compression = struct.unpack("<i", self.fp.read(4))[0]
self.tile = [(decoder, (0, 0) + self.size, 0, (self.mode, 0, 1))] if self.magic == b"BLP1":
alpha = struct.unpack("<I", self.fp.read(4))[0] != 0
else:
encoding = struct.unpack("<b", self.fp.read(1))[0]
alpha = struct.unpack("<b", self.fp.read(1))[0] != 0
alpha_encoding = struct.unpack("<b", self.fp.read(1))[0]
self.fp.seek(1, os.SEEK_CUR) # mips
self._size = struct.unpack("<II", self.fp.read(8))
args: tuple[int, int, bool] | tuple[int, int, bool, int]
if self.magic == b"BLP1":
encoding = struct.unpack("<i", self.fp.read(4))[0]
self.fp.seek(4, os.SEEK_CUR) # subtype
args = (compression, encoding, alpha)
offset = 28
else:
args = (compression, encoding, alpha, alpha_encoding)
offset = 20
decoder = self.magic.decode()
self._mode = "RGBA" if alpha else "RGB"
self.tile = [ImageFile._Tile(decoder, (0, 0) + self.size, offset, args)]
-class _BLPBaseDecoder(ImageFile.PyDecoder):
+class _BLPBaseDecoder(abc.ABC, ImageFile.PyDecoder):
     _pulls_fd = True

-    def decode(self, buffer):
+    def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
         try:
-            self._read_blp_header()
+            self._read_header()
             self._load()
         except struct.error as e:
             msg = "Truncated BLP file"
             raise OSError(msg) from e
         return -1, 0

-    def _read_blp_header(self):
-        self.fd.seek(4)
-        (self._blp_compression,) = struct.unpack("<i", self._safe_read(4))
-        (self._blp_encoding,) = struct.unpack("<b", self._safe_read(1))
-        (self._blp_alpha_depth,) = struct.unpack("<b", self._safe_read(1))
-        (self._blp_alpha_encoding,) = struct.unpack("<b", self._safe_read(1))
-        self.fd.seek(1, os.SEEK_CUR)  # mips
-        self.size = struct.unpack("<II", self._safe_read(8))
-
-        if isinstance(self, BLP1Decoder):
-            # Only present for BLP1
-            (self._blp_encoding,) = struct.unpack("<i", self._safe_read(4))
-            self.fd.seek(4, os.SEEK_CUR)  # subtype
-
-        self._blp_offsets = struct.unpack("<16I", self._safe_read(16 * 4))
-        self._blp_lengths = struct.unpack("<16I", self._safe_read(16 * 4))
-
-    def _safe_read(self, length):
+    @abc.abstractmethod
+    def _load(self) -> None:
+        pass
+
+    def _read_header(self) -> None:
+        self._offsets = struct.unpack("<16I", self._safe_read(16 * 4))
+        self._lengths = struct.unpack("<16I", self._safe_read(16 * 4))
+
+    def _safe_read(self, length: int) -> bytes:
+        assert self.fd is not None
         return ImageFile._safe_read(self.fd, length)

-    def _read_palette(self):
+    def _read_palette(self) -> list[tuple[int, int, int, int]]:
         ret = []
         for i in range(256):
             try:
@@ -314,110 +325,115 @@ class _BLPBaseDecoder(ImageFile.PyDecoder):
                 ret.append((b, g, r, a))
         return ret

-    def _read_bgra(self, palette):
+    def _read_bgra(
+        self, palette: list[tuple[int, int, int, int]], alpha: bool
+    ) -> bytearray:
         data = bytearray()
-        _data = BytesIO(self._safe_read(self._blp_lengths[0]))
+        _data = BytesIO(self._safe_read(self._lengths[0]))
         while True:
             try:
                 (offset,) = struct.unpack("<B", _data.read(1))
             except struct.error:
                 break
             b, g, r, a = palette[offset]
-            d = (r, g, b)
-            if self._blp_alpha_depth:
+            d: tuple[int, ...] = (r, g, b)
+            if alpha:
                 d += (a,)
             data.extend(d)
         return data
 class BLP1Decoder(_BLPBaseDecoder):
-    def _load(self):
-        if self._blp_compression == Format.JPEG:
+    def _load(self) -> None:
+        self._compression, self._encoding, alpha = self.args
+        if self._compression == Format.JPEG:
             self._decode_jpeg_stream()

-        elif self._blp_compression == 1:
-            if self._blp_encoding in (4, 5):
+        elif self._compression == 1:
+            if self._encoding in (4, 5):
                 palette = self._read_palette()
-                data = self._read_bgra(palette)
-                self.set_as_raw(bytes(data))
+                data = self._read_bgra(palette, alpha)
+                self.set_as_raw(data)
             else:
-                msg = f"Unsupported BLP encoding {repr(self._blp_encoding)}"
+                msg = f"Unsupported BLP encoding {repr(self._encoding)}"
                 raise BLPFormatError(msg)
         else:
-            msg = f"Unsupported BLP compression {repr(self._blp_encoding)}"
+            msg = f"Unsupported BLP compression {repr(self._encoding)}"
             raise BLPFormatError(msg)

-    def _decode_jpeg_stream(self):
+    def _decode_jpeg_stream(self) -> None:
         from .JpegImagePlugin import JpegImageFile

         (jpeg_header_size,) = struct.unpack("<I", self._safe_read(4))
         jpeg_header = self._safe_read(jpeg_header_size)
-        self._safe_read(self._blp_offsets[0] - self.fd.tell())  # What IS this?
-        data = self._safe_read(self._blp_lengths[0])
+        assert self.fd is not None
+        self._safe_read(self._offsets[0] - self.fd.tell())  # What IS this?
+        data = self._safe_read(self._lengths[0])
         data = jpeg_header + data
-        data = BytesIO(data)
-        image = JpegImageFile(data)
+        image = JpegImageFile(BytesIO(data))
         Image._decompression_bomb_check(image.size)
         if image.mode == "CMYK":
-            decoder_name, extents, offset, args = image.tile[0]
-            image.tile = [(decoder_name, extents, offset, (args[0], "CMYK"))]
-            r, g, b = image.convert("RGB").split()
-            image = Image.merge("RGB", (b, g, r))
-        self.set_as_raw(image.tobytes())
+            args = image.tile[0].args
+            assert isinstance(args, tuple)
+            image.tile = [image.tile[0]._replace(args=(args[0], "CMYK"))]
+        self.set_as_raw(image.convert("RGB").tobytes(), "BGR")
 class BLP2Decoder(_BLPBaseDecoder):
-    def _load(self):
+    def _load(self) -> None:
+        self._compression, self._encoding, alpha, self._alpha_encoding = self.args
         palette = self._read_palette()
-        self.fd.seek(self._blp_offsets[0])
+        assert self.fd is not None
+        self.fd.seek(self._offsets[0])

-        if self._blp_compression == 1:
+        if self._compression == 1:
             # Uncompressed or DirectX compression
-            if self._blp_encoding == Encoding.UNCOMPRESSED:
-                data = self._read_bgra(palette)
+            if self._encoding == Encoding.UNCOMPRESSED:
+                data = self._read_bgra(palette, alpha)

-            elif self._blp_encoding == Encoding.DXT:
+            elif self._encoding == Encoding.DXT:
                 data = bytearray()
-                if self._blp_alpha_encoding == AlphaEncoding.DXT1:
-                    linesize = (self.size[0] + 3) // 4 * 8
-                    for yb in range((self.size[1] + 3) // 4):
-                        for d in decode_dxt1(
-                            self._safe_read(linesize), alpha=bool(self._blp_alpha_depth)
-                        ):
+                if self._alpha_encoding == AlphaEncoding.DXT1:
+                    linesize = (self.state.xsize + 3) // 4 * 8
+                    for yb in range((self.state.ysize + 3) // 4):
+                        for d in decode_dxt1(self._safe_read(linesize), alpha):
                             data += d

-                elif self._blp_alpha_encoding == AlphaEncoding.DXT3:
-                    linesize = (self.size[0] + 3) // 4 * 16
-                    for yb in range((self.size[1] + 3) // 4):
+                elif self._alpha_encoding == AlphaEncoding.DXT3:
+                    linesize = (self.state.xsize + 3) // 4 * 16
+                    for yb in range((self.state.ysize + 3) // 4):
                         for d in decode_dxt3(self._safe_read(linesize)):
                             data += d

-                elif self._blp_alpha_encoding == AlphaEncoding.DXT5:
-                    linesize = (self.size[0] + 3) // 4 * 16
-                    for yb in range((self.size[1] + 3) // 4):
+                elif self._alpha_encoding == AlphaEncoding.DXT5:
+                    linesize = (self.state.xsize + 3) // 4 * 16
+                    for yb in range((self.state.ysize + 3) // 4):
                         for d in decode_dxt5(self._safe_read(linesize)):
                             data += d
                 else:
-                    msg = f"Unsupported alpha encoding {repr(self._blp_alpha_encoding)}"
+                    msg = f"Unsupported alpha encoding {repr(self._alpha_encoding)}"
                     raise BLPFormatError(msg)
             else:
-                msg = f"Unknown BLP encoding {repr(self._blp_encoding)}"
+                msg = f"Unknown BLP encoding {repr(self._encoding)}"
                 raise BLPFormatError(msg)
         else:
-            msg = f"Unknown BLP compression {repr(self._blp_compression)}"
+            msg = f"Unknown BLP compression {repr(self._compression)}"
             raise BLPFormatError(msg)

-        self.set_as_raw(bytes(data))
+        self.set_as_raw(data)
 class BLPEncoder(ImageFile.PyEncoder):
     _pushes_fd = True

-    def _write_palette(self):
+    def _write_palette(self) -> bytes:
         data = b""
+        assert self.im is not None
         palette = self.im.getpalette("RGBA", "RGBA")
         for i in range(len(palette) // 4):
             r, g, b, a = palette[i * 4 : (i + 1) * 4]
@@ -426,12 +442,13 @@ class BLPEncoder(ImageFile.PyEncoder):
data += b"\x00" * 4 data += b"\x00" * 4
return data return data
def encode(self, bufsize): def encode(self, bufsize: int) -> tuple[int, int, bytes]:
palette_data = self._write_palette() palette_data = self._write_palette()
offset = 20 + 16 * 4 * 2 + len(palette_data) offset = 20 + 16 * 4 * 2 + len(palette_data)
data = struct.pack("<16I", offset, *((0,) * 15)) data = struct.pack("<16I", offset, *((0,) * 15))
assert self.im is not None
w, h = self.im.size w, h = self.im.size
data += struct.pack("<16I", w * h, *((0,) * 15)) data += struct.pack("<16I", w * h, *((0,) * 15))
@@ -444,7 +461,7 @@ class BLPEncoder(ImageFile.PyEncoder):
         return len(data), 0, data

-def _save(im, fp, filename):
+def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
     if im.mode != "P":
         msg = "Unsupported BLP image mode"
         raise ValueError(msg)
@@ -452,17 +469,23 @@ def _save(im, fp, filename):
magic = b"BLP1" if im.encoderinfo.get("blp_version") == "BLP1" else b"BLP2" magic = b"BLP1" if im.encoderinfo.get("blp_version") == "BLP1" else b"BLP2"
fp.write(magic) fp.write(magic)
assert im.palette is not None
fp.write(struct.pack("<i", 1)) # Uncompressed or DirectX compression fp.write(struct.pack("<i", 1)) # Uncompressed or DirectX compression
fp.write(struct.pack("<b", Encoding.UNCOMPRESSED))
fp.write(struct.pack("<b", 1 if im.palette.mode == "RGBA" else 0)) alpha_depth = 1 if im.palette.mode == "RGBA" else 0
fp.write(struct.pack("<b", 0)) # alpha encoding if magic == b"BLP1":
fp.write(struct.pack("<b", 0)) # mips fp.write(struct.pack("<L", alpha_depth))
else:
fp.write(struct.pack("<b", Encoding.UNCOMPRESSED))
fp.write(struct.pack("<b", alpha_depth))
fp.write(struct.pack("<b", 0)) # alpha encoding
fp.write(struct.pack("<b", 0)) # mips
fp.write(struct.pack("<II", *im.size)) fp.write(struct.pack("<II", *im.size))
if magic == b"BLP1": if magic == b"BLP1":
fp.write(struct.pack("<i", 5)) fp.write(struct.pack("<i", 5))
fp.write(struct.pack("<i", 0)) fp.write(struct.pack("<i", 0))
ImageFile._save(im, fp, [("BLP", (0, 0) + im.size, 0, im.mode)]) ImageFile._save(im, fp, [ImageFile._Tile("BLP", (0, 0) + im.size, 0, im.mode)])
Image.register_open(BlpImageFile.format, BlpImageFile, _accept) Image.register_open(BlpImageFile.format, BlpImageFile, _accept)
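
For orientation, a minimal sketch of the plugin's entry points (texture names are hypothetical):

from PIL import Image

with Image.open("minimap.blp") as im:  # routed here via _accept() above
    rgb = im.convert("RGB")
rgb.convert("P").save("copy.blp", blp_version="BLP2")  # _save() accepts mode "P" only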

View File

@@ -22,9 +22,10 @@
 #
 # See the README file for information on usage and redistribution.
 #
+from __future__ import annotations

 import os
+from typing import IO, Any

 from . import Image, ImageFile, ImagePalette
 from ._binary import i16le as i16
@@ -47,13 +48,15 @@ BIT2MODE = {
32: ("RGB", "BGRX"), 32: ("RGB", "BGRX"),
} }
USE_RAW_ALPHA = False
def _accept(prefix):
return prefix[:2] == b"BM"
def _dib_accept(prefix): def _accept(prefix: bytes) -> bool:
return i32(prefix) in [12, 40, 64, 108, 124] return prefix.startswith(b"BM")
def _dib_accept(prefix: bytes) -> bool:
return i32(prefix) in [12, 40, 52, 56, 64, 108, 124]
# ============================================================================= # =============================================================================
@@ -71,31 +74,41 @@ class BmpImageFile(ImageFile.ImageFile):
     for k, v in COMPRESSIONS.items():
         vars()[k] = v

-    def _bitmap(self, header=0, offset=0):
+    def _bitmap(self, header: int = 0, offset: int = 0) -> None:
         """Read relevant info about the BMP"""
         read, seek = self.fp.read, self.fp.seek
         if header:
             seek(header)
         # read bmp header size @offset 14 (this is part of the header size)
-        file_info = {"header_size": i32(read(4)), "direction": -1}
+        file_info: dict[str, bool | int | tuple[int, ...]] = {
+            "header_size": i32(read(4)),
+            "direction": -1,
+        }

         # -------------------- If requested, read header at a specific position
         # read the rest of the bmp header, without its size
+        assert isinstance(file_info["header_size"], int)
         header_data = ImageFile._safe_read(self.fp, file_info["header_size"] - 4)

-        # -------------------------------------------------- IBM OS/2 Bitmap v1
+        # ------------------------------- Windows Bitmap v2, IBM OS/2 Bitmap v1
         # ----- This format has different offsets because of width/height types
+        # 12: BITMAPCOREHEADER/OS21XBITMAPHEADER
         if file_info["header_size"] == 12:
             file_info["width"] = i16(header_data, 0)
             file_info["height"] = i16(header_data, 2)
             file_info["planes"] = i16(header_data, 4)
             file_info["bits"] = i16(header_data, 6)
-            file_info["compression"] = self.RAW
+            file_info["compression"] = self.COMPRESSIONS["RAW"]
             file_info["palette_padding"] = 3

-        # --------------------------------------------- Windows Bitmap v2 to v5
-        # v3, OS/2 v2, v4, v5
-        elif file_info["header_size"] in (40, 64, 108, 124):
+        # --------------------------------------------- Windows Bitmap v3 to v5
+        # 40: BITMAPINFOHEADER
+        # 52: BITMAPV2HEADER
+        # 56: BITMAPV3HEADER
+        # 64: BITMAPCOREHEADER2/OS22XBITMAPHEADER
+        # 108: BITMAPV4HEADER
+        # 124: BITMAPV5HEADER
+        elif file_info["header_size"] in (40, 52, 56, 64, 108, 124):
             file_info["y_flip"] = header_data[7] == 0xFF
             file_info["direction"] = 1 if file_info["y_flip"] else -1
             file_info["width"] = i32(header_data, 0)
@@ -115,12 +128,16 @@ class BmpImageFile(ImageFile.ImageFile):
                 )
                 file_info["colors"] = i32(header_data, 28)
                 file_info["palette_padding"] = 4
+                assert isinstance(file_info["pixels_per_meter"], tuple)
                 self.info["dpi"] = tuple(x / 39.3701 for x in file_info["pixels_per_meter"])
-                if file_info["compression"] == self.BITFIELDS:
-                    if len(header_data) >= 52:
-                        for idx, mask in enumerate(
-                            ["r_mask", "g_mask", "b_mask", "a_mask"]
-                        ):
+                if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
+                    masks = ["r_mask", "g_mask", "b_mask"]
+                    if len(header_data) >= 48:
+                        if len(header_data) >= 52:
+                            masks.append("a_mask")
+                        else:
+                            file_info["a_mask"] = 0x0
+                        for idx, mask in enumerate(masks):
                             file_info[mask] = i32(header_data, 36 + idx * 4)
                     else:
                         # 40 byte headers only have the three components in the
@@ -132,8 +149,12 @@ class BmpImageFile(ImageFile.ImageFile):
                         # location, but it is listed as a reserved component,
                         # and it is not generally an alpha channel
                         file_info["a_mask"] = 0x0
-                        for mask in ["r_mask", "g_mask", "b_mask"]:
+                        for mask in masks:
                             file_info[mask] = i32(read(4))
+                assert isinstance(file_info["r_mask"], int)
+                assert isinstance(file_info["g_mask"], int)
+                assert isinstance(file_info["b_mask"], int)
+                assert isinstance(file_info["a_mask"], int)
                 file_info["rgb_mask"] = (
                     file_info["r_mask"],
                     file_info["g_mask"],
@@ -151,33 +172,39 @@ class BmpImageFile(ImageFile.ImageFile):
             # ------------------ Special case : header is reported 40, which
             # ---------------------- is shorter than real size for bpp >= 16
+        assert isinstance(file_info["width"], int)
+        assert isinstance(file_info["height"], int)
         self._size = file_info["width"], file_info["height"]

         # ------- If color count was not found in the header, compute from bits
+        assert isinstance(file_info["bits"], int)
         file_info["colors"] = (
             file_info["colors"]
             if file_info.get("colors", 0)
             else (1 << file_info["bits"])
         )
+        assert isinstance(file_info["colors"], int)
         if offset == 14 + file_info["header_size"] and file_info["bits"] <= 8:
             offset += 4 * file_info["colors"]

         # ---------------------- Check bit depth for unusual unsupported values
-        self._mode, raw_mode = BIT2MODE.get(file_info["bits"], (None, None))
-        if self.mode is None:
+        self._mode, raw_mode = BIT2MODE.get(file_info["bits"], ("", ""))
+        if not self.mode:
             msg = f"Unsupported BMP pixel depth ({file_info['bits']})"
             raise OSError(msg)

         # ---------------- Process BMP with Bitfields compression (not palette)
         decoder_name = "raw"
-        if file_info["compression"] == self.BITFIELDS:
-            SUPPORTED = {
+        if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
+            SUPPORTED: dict[int, list[tuple[int, ...]]] = {
                 32: [
                     (0xFF0000, 0xFF00, 0xFF, 0x0),
                     (0xFF000000, 0xFF0000, 0xFF00, 0x0),
+                    (0xFF000000, 0xFF00, 0xFF, 0x0),
                     (0xFF000000, 0xFF0000, 0xFF00, 0xFF),
                     (0xFF, 0xFF00, 0xFF0000, 0xFF000000),
                     (0xFF0000, 0xFF00, 0xFF, 0xFF000000),
+                    (0xFF000000, 0xFF00, 0xFF, 0xFF0000),
                     (0x0, 0x0, 0x0, 0x0),
                 ],
                 24: [(0xFF0000, 0xFF00, 0xFF)],
@@ -186,9 +213,11 @@ class BmpImageFile(ImageFile.ImageFile):
             MASK_MODES = {
                 (32, (0xFF0000, 0xFF00, 0xFF, 0x0)): "BGRX",
                 (32, (0xFF000000, 0xFF0000, 0xFF00, 0x0)): "XBGR",
+                (32, (0xFF000000, 0xFF00, 0xFF, 0x0)): "BGXR",
                 (32, (0xFF000000, 0xFF0000, 0xFF00, 0xFF)): "ABGR",
                 (32, (0xFF, 0xFF00, 0xFF0000, 0xFF000000)): "RGBA",
                 (32, (0xFF0000, 0xFF00, 0xFF, 0xFF000000)): "BGRA",
+                (32, (0xFF000000, 0xFF00, 0xFF, 0xFF0000)): "BGAR",
                 (32, (0x0, 0x0, 0x0, 0x0)): "BGRA",
                 (24, (0xFF0000, 0xFF00, 0xFF)): "BGR",
                 (16, (0xF800, 0x7E0, 0x1F)): "BGR;16",
@@ -199,12 +228,14 @@ class BmpImageFile(ImageFile.ImageFile):
file_info["bits"] == 32 file_info["bits"] == 32
and file_info["rgba_mask"] in SUPPORTED[file_info["bits"]] and file_info["rgba_mask"] in SUPPORTED[file_info["bits"]]
): ):
assert isinstance(file_info["rgba_mask"], tuple)
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgba_mask"])] raw_mode = MASK_MODES[(file_info["bits"], file_info["rgba_mask"])]
self._mode = "RGBA" if "A" in raw_mode else self.mode self._mode = "RGBA" if "A" in raw_mode else self.mode
elif ( elif (
file_info["bits"] in (24, 16) file_info["bits"] in (24, 16)
and file_info["rgb_mask"] in SUPPORTED[file_info["bits"]] and file_info["rgb_mask"] in SUPPORTED[file_info["bits"]]
): ):
assert isinstance(file_info["rgb_mask"], tuple)
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgb_mask"])] raw_mode = MASK_MODES[(file_info["bits"], file_info["rgb_mask"])]
else: else:
msg = "Unsupported BMP bitfields layout" msg = "Unsupported BMP bitfields layout"
@@ -212,10 +243,15 @@ class BmpImageFile(ImageFile.ImageFile):
             else:
                 msg = "Unsupported BMP bitfields layout"
                 raise OSError(msg)
-        elif file_info["compression"] == self.RAW:
-            if file_info["bits"] == 32 and header == 22:  # 32-bit .cur offset
+        elif file_info["compression"] == self.COMPRESSIONS["RAW"]:
+            if file_info["bits"] == 32 and (
+                header == 22 or USE_RAW_ALPHA  # 32-bit .cur offset
+            ):
                 raw_mode, self._mode = "BGRA", "RGBA"
-        elif file_info["compression"] in (self.RLE8, self.RLE4):
+        elif file_info["compression"] in (
+            self.COMPRESSIONS["RLE8"],
+            self.COMPRESSIONS["RLE4"],
+        ):
             decoder_name = "bmp_rle"
         else:
             msg = f"Unsupported BMP compression ({file_info['compression']})"
@@ -228,23 +264,24 @@ class BmpImageFile(ImageFile.ImageFile):
msg = f"Unsupported BMP Palette size ({file_info['colors']})" msg = f"Unsupported BMP Palette size ({file_info['colors']})"
raise OSError(msg) raise OSError(msg)
else: else:
assert isinstance(file_info["palette_padding"], int)
padding = file_info["palette_padding"] padding = file_info["palette_padding"]
palette = read(padding * file_info["colors"]) palette = read(padding * file_info["colors"])
greyscale = True grayscale = True
indices = ( indices = (
(0, 255) (0, 255)
if file_info["colors"] == 2 if file_info["colors"] == 2
else list(range(file_info["colors"])) else list(range(file_info["colors"]))
) )
# ----------------- Check if greyscale and ignore palette if so # ----------------- Check if grayscale and ignore palette if so
for ind, val in enumerate(indices): for ind, val in enumerate(indices):
rgb = palette[ind * padding : ind * padding + 3] rgb = palette[ind * padding : ind * padding + 3]
if rgb != o8(val) * 3: if rgb != o8(val) * 3:
greyscale = False grayscale = False
# ------- If all colors are grey, white or black, ditch palette # ------- If all colors are gray, white or black, ditch palette
if greyscale: if grayscale:
self._mode = "1" if file_info["colors"] == 2 else "L" self._mode = "1" if file_info["colors"] == 2 else "L"
raw_mode = self.mode raw_mode = self.mode
else: else:
@@ -255,14 +292,15 @@ class BmpImageFile(ImageFile.ImageFile):
# ---------------------------- Finally set the tile data for the plugin # ---------------------------- Finally set the tile data for the plugin
self.info["compression"] = file_info["compression"] self.info["compression"] = file_info["compression"]
args = [raw_mode] args: list[Any] = [raw_mode]
if decoder_name == "bmp_rle": if decoder_name == "bmp_rle":
args.append(file_info["compression"] == self.RLE4) args.append(file_info["compression"] == self.COMPRESSIONS["RLE4"])
else: else:
assert isinstance(file_info["width"], int)
args.append(((file_info["width"] * file_info["bits"] + 31) >> 3) & (~3)) args.append(((file_info["width"] * file_info["bits"] + 31) >> 3) & (~3))
args.append(file_info["direction"]) args.append(file_info["direction"])
self.tile = [ self.tile = [
( ImageFile._Tile(
decoder_name, decoder_name,
(0, 0, file_info["width"], file_info["height"]), (0, 0, file_info["width"], file_info["height"]),
offset or self.fp.tell(), offset or self.fp.tell(),
@@ -270,7 +308,7 @@ class BmpImageFile(ImageFile.ImageFile):
) )
] ]
def _open(self): def _open(self) -> None:
"""Open file, check magic number and read header""" """Open file, check magic number and read header"""
# read 14 bytes: magic number, filesize, reserved, header final offset # read 14 bytes: magic number, filesize, reserved, header final offset
head_data = self.fp.read(14) head_data = self.fp.read(14)
@@ -287,11 +325,13 @@ class BmpImageFile(ImageFile.ImageFile):
class BmpRleDecoder(ImageFile.PyDecoder): class BmpRleDecoder(ImageFile.PyDecoder):
_pulls_fd = True _pulls_fd = True
def decode(self, buffer): def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
rle4 = self.args[1] rle4 = self.args[1]
data = bytearray() data = bytearray()
x = 0 x = 0
while len(data) < self.state.xsize * self.state.ysize: dest_length = self.state.xsize * self.state.ysize
while len(data) < dest_length:
pixels = self.fd.read(1) pixels = self.fd.read(1)
byte = self.fd.read(1) byte = self.fd.read(1)
if not pixels or not byte: if not pixels or not byte:
@@ -351,7 +391,7 @@ class BmpRleDecoder(ImageFile.PyDecoder):
if self.fd.tell() % 2 != 0: if self.fd.tell() % 2 != 0:
self.fd.seek(1, os.SEEK_CUR) self.fd.seek(1, os.SEEK_CUR)
rawmode = "L" if self.mode == "L" else "P" rawmode = "L" if self.mode == "L" else "P"
self.set_as_raw(bytes(data), (rawmode, 0, self.args[-1])) self.set_as_raw(bytes(data), rawmode, (0, self.args[-1]))
return -1, 0 return -1, 0
@@ -362,7 +402,7 @@ class DibImageFile(BmpImageFile):
format = "DIB" format = "DIB"
format_description = "Windows Bitmap" format_description = "Windows Bitmap"
def _open(self): def _open(self) -> None:
self._bitmap() self._bitmap()
@@ -380,11 +420,13 @@ SAVE = {
} }
def _dib_save(im, fp, filename): def _dib_save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, False) _save(im, fp, filename, False)
def _save(im, fp, filename, bitmap_header=True): def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, bitmap_header: bool = True
) -> None:
try: try:
rawmode, bits, colors = SAVE[im.mode] rawmode, bits, colors = SAVE[im.mode]
except KeyError as e: except KeyError as e:
@@ -396,16 +438,16 @@ def _save(im, fp, filename, bitmap_header=True):
dpi = info.get("dpi", (96, 96)) dpi = info.get("dpi", (96, 96))
# 1 meter == 39.3701 inches # 1 meter == 39.3701 inches
ppm = tuple(map(lambda x: int(x * 39.3701 + 0.5), dpi)) ppm = tuple(int(x * 39.3701 + 0.5) for x in dpi)
stride = ((im.size[0] * bits + 7) // 8 + 3) & (~3) stride = ((im.size[0] * bits + 7) // 8 + 3) & (~3)
header = 40 # or 64 for OS/2 version 2 header = 40 # or 64 for OS/2 version 2
image = stride * im.size[1] image = stride * im.size[1]
if im.mode == "1": if im.mode == "1":
palette = b"".join(o8(i) * 4 for i in (0, 255)) palette = b"".join(o8(i) * 3 + b"\x00" for i in (0, 255))
elif im.mode == "L": elif im.mode == "L":
palette = b"".join(o8(i) * 4 for i in range(256)) palette = b"".join(o8(i) * 3 + b"\x00" for i in range(256))
elif im.mode == "P": elif im.mode == "P":
palette = im.im.getpalette("RGB", "BGRX") palette = im.im.getpalette("RGB", "BGRX")
colors = len(palette) // 4 colors = len(palette) // 4
@@ -446,7 +488,9 @@ def _save(im, fp, filename, bitmap_header=True):
if palette: if palette:
fp.write(palette) fp.write(palette)
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))]) ImageFile._save(
im, fp, [ImageFile._Tile("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))]
)
# #
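Both stride formulas in the diff above — `((width * bits + 31) >> 3) & (~3)` in `_bitmap()` and `((im.size[0] * bits + 7) // 8 + 3) & (~3)` in `_save()` — round a BMP row up to the next 4-byte boundary. A minimal standalone sketch (not part of this commit) to sanity-check the arithmetic:

def bmp_stride(width: int, bits: int) -> int:
    # Equivalent to 4 * ceil(row_bits / 32): round the row up to
    # whole bytes and then up to a multiple of 4.
    return ((width * bits + 31) >> 3) & ~3

assert bmp_stride(1, 24) == 4     # 3 data bytes pad out to 4
assert bmp_stride(2, 1) == 4      # even 1-bit rows occupy a full 4 bytes
assert bmp_stride(5, 24) == 16    # 15 bytes pad out to 16
assert bmp_stride(100, 8) == 100  # already 4-byte aligned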
View File
@@ -8,13 +8,17 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import os
from typing import IO
from . import Image, ImageFile from . import Image, ImageFile
_handler = None _handler = None
def register_handler(handler): def register_handler(handler: ImageFile.StubHandler | None) -> None:
""" """
Install application-specific BUFR image handler. Install application-specific BUFR image handler.
@@ -28,22 +32,20 @@ def register_handler(handler):
# Image adapter # Image adapter
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:4] == b"BUFR" or prefix[:4] == b"ZCZC" return prefix.startswith((b"BUFR", b"ZCZC"))
class BufrStubImageFile(ImageFile.StubImageFile): class BufrStubImageFile(ImageFile.StubImageFile):
format = "BUFR" format = "BUFR"
format_description = "BUFR" format_description = "BUFR"
def _open(self): def _open(self) -> None:
offset = self.fp.tell()
if not _accept(self.fp.read(4)): if not _accept(self.fp.read(4)):
msg = "Not a BUFR file" msg = "Not a BUFR file"
raise SyntaxError(msg) raise SyntaxError(msg)
self.fp.seek(offset) self.fp.seek(-4, os.SEEK_CUR)
# make something up # make something up
self._mode = "F" self._mode = "F"
@@ -53,11 +55,11 @@ class BufrStubImageFile(ImageFile.StubImageFile):
if loader: if loader:
loader.open(self) loader.open(self)
def _load(self): def _load(self) -> ImageFile.StubHandler | None:
return _handler return _handler
def _save(im, fp, filename): def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if _handler is None or not hasattr(_handler, "save"): if _handler is None or not hasattr(_handler, "save"):
msg = "BUFR save handler not installed" msg = "BUFR save handler not installed"
raise OSError(msg) raise OSError(msg)
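The `_open()` change above swaps an absolute `seek(offset)` for a relative rewind after the 4-byte magic has been consumed. A small sketch of that behaviour, with `io.BytesIO` standing in for the plugin's file pointer:

import io
import os

fp = io.BytesIO(b"BUFR" + b"payload")
assert fp.read(4).startswith((b"BUFR", b"ZCZC"))  # same check as _accept()
fp.seek(-4, os.SEEK_CUR)   # rewind relative to the current position
assert fp.tell() == 0      # back at the start of the signature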
View File
@@ -13,18 +13,20 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import io import io
from collections.abc import Iterable
from typing import IO, AnyStr, NoReturn
class ContainerIO: class ContainerIO(IO[AnyStr]):
""" """
A file object that provides read access to a part of an existing A file object that provides read access to a part of an existing
file (for example a TAR file). file (for example a TAR file).
""" """
def __init__(self, file, offset, length): def __init__(self, file: IO[AnyStr], offset: int, length: int) -> None:
""" """
Create file object. Create file object.
@@ -32,7 +34,7 @@ class ContainerIO:
:param offset: Start of region, in bytes. :param offset: Start of region, in bytes.
:param length: Size of region, in bytes. :param length: Size of region, in bytes.
""" """
self.fh = file self.fh: IO[AnyStr] = file
self.pos = 0 self.pos = 0
self.offset = offset self.offset = offset
self.length = length self.length = length
@@ -41,10 +43,13 @@ class ContainerIO:
## ##
# Always false. # Always false.
def isatty(self): def isatty(self) -> bool:
return False return False
def seek(self, offset, mode=io.SEEK_SET): def seekable(self) -> bool:
return True
def seek(self, offset: int, mode: int = io.SEEK_SET) -> int:
""" """
Move file pointer. Move file pointer.
@@ -52,6 +57,7 @@ class ContainerIO:
:param mode: Starting position. Use 0 for beginning of region, 1 :param mode: Starting position. Use 0 for beginning of region, 1
for current offset, and 2 for end of region. You cannot move for current offset, and 2 for end of region. You cannot move
the pointer outside the defined region. the pointer outside the defined region.
:returns: Offset from start of region, in bytes.
""" """
if mode == 1: if mode == 1:
self.pos = self.pos + offset self.pos = self.pos + offset
@@ -62,8 +68,9 @@ class ContainerIO:
# clamp # clamp
self.pos = max(0, min(self.pos, self.length)) self.pos = max(0, min(self.pos, self.length))
self.fh.seek(self.offset + self.pos) self.fh.seek(self.offset + self.pos)
return self.pos
def tell(self): def tell(self) -> int:
""" """
Get current file pointer. Get current file pointer.
@@ -71,44 +78,51 @@ class ContainerIO:
""" """
return self.pos return self.pos
def read(self, n=0): def readable(self) -> bool:
return True
def read(self, n: int = -1) -> AnyStr:
""" """
Read data. Read data.
:param n: Number of bytes to read. If omitted or zero, :param n: Number of bytes to read. If omitted, zero or negative,
read until end of region. read until end of region.
:returns: An 8-bit string. :returns: An 8-bit string.
""" """
if n: if n > 0:
n = min(n, self.length - self.pos) n = min(n, self.length - self.pos)
else: else:
n = self.length - self.pos n = self.length - self.pos
if not n: # EOF if n <= 0: # EOF
return b"" if "b" in self.fh.mode else "" return b"" if "b" in self.fh.mode else "" # type: ignore[return-value]
self.pos = self.pos + n self.pos = self.pos + n
return self.fh.read(n) return self.fh.read(n)
def readline(self): def readline(self, n: int = -1) -> AnyStr:
""" """
Read a line of text. Read a line of text.
:param n: Number of bytes to read. If omitted, zero or negative,
read until end of line.
:returns: An 8-bit string. :returns: An 8-bit string.
""" """
s = b"" if "b" in self.fh.mode else "" s: AnyStr = b"" if "b" in self.fh.mode else "" # type: ignore[assignment]
newline_character = b"\n" if "b" in self.fh.mode else "\n" newline_character = b"\n" if "b" in self.fh.mode else "\n"
while True: while True:
c = self.read(1) c = self.read(1)
if not c: if not c:
break break
s = s + c s = s + c
if c == newline_character: if c == newline_character or len(s) == n:
break break
return s return s
def readlines(self): def readlines(self, n: int | None = -1) -> list[AnyStr]:
""" """
Read multiple lines of text. Read multiple lines of text.
:param n: Number of lines to read. If omitted, zero, negative or None,
read until end of region.
:returns: A list of 8-bit strings. :returns: A list of 8-bit strings.
""" """
lines = [] lines = []
@@ -117,4 +131,43 @@ class ContainerIO:
if not s: if not s:
break break
lines.append(s) lines.append(s)
if len(lines) == n:
break
return lines return lines
def writable(self) -> bool:
return False
def write(self, b: AnyStr) -> NoReturn:
raise NotImplementedError()
def writelines(self, lines: Iterable[AnyStr]) -> NoReturn:
raise NotImplementedError()
def truncate(self, size: int | None = None) -> int:
raise NotImplementedError()
def __enter__(self) -> ContainerIO[AnyStr]:
return self
def __exit__(self, *args: object) -> None:
self.close()
def __iter__(self) -> ContainerIO[AnyStr]:
return self
def __next__(self) -> AnyStr:
line = self.readline()
if not line:
msg = "end of region"
raise StopIteration(msg)
return line
def fileno(self) -> int:
return self.fh.fileno()
def flush(self) -> None:
self.fh.flush()
def close(self) -> None:
self.fh.close()
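With `ContainerIO` now subclassing `IO[AnyStr]`, a region behaves like a regular read-only file object, including iteration. A quick usage sketch; `BytesFile` is a small helper of ours, since `ContainerIO` inspects the wrapped file's `.mode` attribute, which plain `io.BytesIO` lacks:

import io

from PIL.ContainerIO import ContainerIO

class BytesFile(io.BytesIO):
    mode = "rb"   # ContainerIO checks .mode to choose bytes vs str results

outer = BytesFile(b"HEADERline1\nline2\nTRAILER")
region = ContainerIO(outer, offset=6, length=12)  # view onto b"line1\nline2\n"
region.seek(0)               # position the underlying file at the region start
for line in region:          # __iter__/__next__ stop at the end of the region
    print(line)              # b"line1\n", then b"line2\n"
assert region.read() == b""  # reads never run past `length`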
View File
@@ -15,6 +15,8 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
from . import BmpImagePlugin, Image from . import BmpImagePlugin, Image
from ._binary import i16le as i16 from ._binary import i16le as i16
from ._binary import i32le as i32 from ._binary import i32le as i32
@@ -23,8 +25,8 @@ from ._binary import i32le as i32
# -------------------------------------------------------------------- # --------------------------------------------------------------------
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:4] == b"\0\0\2\0" return prefix.startswith(b"\0\0\2\0")
## ##
@@ -35,7 +37,8 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
format = "CUR" format = "CUR"
format_description = "Windows Cursor" format_description = "Windows Cursor"
def _open(self): def _open(self) -> None:
assert self.fp is not None
offset = self.fp.tell() offset = self.fp.tell()
# check magic # check magic
@@ -61,10 +64,7 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
# patch up the bitmap height # patch up the bitmap height
self._size = self.size[0], self.size[1] // 2 self._size = self.size[0], self.size[1] // 2
d, e, o, a = self.tile[0] self.tile = [self.tile[0]._replace(extents=(0, 0) + self.size)]
self.tile[0] = d, (0, 0) + self.size, o, a
return
# #
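The one-line `self.tile[0]._replace(extents=...)` works because `ImageFile._Tile` is a named tuple. A standalone sketch with `Tile` as a simplified stand-in (not the real class):

from typing import NamedTuple

class Tile(NamedTuple):    # simplified stand-in for ImageFile._Tile
    codec_name: str
    extents: tuple[int, int, int, int]
    offset: int = 0
    args: object = None

tile = Tile("raw", (0, 0, 32, 64), 22, ("BGRA", 0, -1))
halved = tile._replace(extents=(0, 0, 32, 32))  # only the extents change
assert halved.offset == 22 and halved.args == ("BGRA", 0, -1)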
View File
@@ -20,15 +20,17 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
from . import Image from . import Image
from ._binary import i32le as i32 from ._binary import i32le as i32
from ._util import DeferredError
from .PcxImagePlugin import PcxImageFile from .PcxImagePlugin import PcxImageFile
MAGIC = 0x3ADE68B1 # QUIZ: what's this value, then? MAGIC = 0x3ADE68B1 # QUIZ: what's this value, then?
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return len(prefix) >= 4 and i32(prefix) == MAGIC return len(prefix) >= 4 and i32(prefix) == MAGIC
@@ -41,7 +43,7 @@ class DcxImageFile(PcxImageFile):
format_description = "Intel DCX" format_description = "Intel DCX"
_close_exclusive_fp_after_loading = False _close_exclusive_fp_after_loading = False
def _open(self): def _open(self) -> None:
# Header # Header
s = self.fp.read(4) s = self.fp.read(4)
if not _accept(s): if not _accept(s):
@@ -57,20 +59,22 @@ class DcxImageFile(PcxImageFile):
self._offset.append(offset) self._offset.append(offset)
self._fp = self.fp self._fp = self.fp
self.frame = None self.frame = -1
self.n_frames = len(self._offset) self.n_frames = len(self._offset)
self.is_animated = self.n_frames > 1 self.is_animated = self.n_frames > 1
self.seek(0) self.seek(0)
def seek(self, frame): def seek(self, frame: int) -> None:
if not self._seek_check(frame): if not self._seek_check(frame):
return return
if isinstance(self._fp, DeferredError):
raise self._fp.ex
self.frame = frame self.frame = frame
self.fp = self._fp self.fp = self._fp
self.fp.seek(self._offset[frame]) self.fp.seek(self._offset[frame])
PcxImageFile._open(self) PcxImageFile._open(self)
def tell(self): def tell(self) -> int:
return self.frame return self.frame
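Frame navigation after this change uses the standard Pillow multi-frame API; a usage sketch, where "scans.dcx" is a hypothetical multi-page DCX file:

from PIL import Image

with Image.open("scans.dcx") as im:
    print(im.n_frames, im.is_animated)
    for frame in range(im.n_frames):
        im.seek(frame)           # repositions self.fp at the stored offset
        assert im.tell() == frame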
View File
@@ -1,118 +1,338 @@
""" """
A Pillow loader for .dds files (S3TC-compressed aka DXTC) A Pillow plugin for .dds files (S3TC-compressed aka DXTC)
Jerome Leclanche <jerome@leclan.ch> Jerome Leclanche <jerome@leclan.ch>
Documentation: Documentation:
https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt
The contents of this file are hereby released in the public domain (CC0) The contents of this file are hereby released in the public domain (CC0)
Full text of the CC0 license: Full text of the CC0 license:
https://creativecommons.org/publicdomain/zero/1.0/ https://creativecommons.org/publicdomain/zero/1.0/
""" """
from __future__ import annotations
import io
import struct import struct
from io import BytesIO import sys
from enum import IntEnum, IntFlag
from typing import IO
from . import Image, ImageFile, ImagePalette from . import Image, ImageFile, ImagePalette
from ._binary import i32le as i32
from ._binary import o8
from ._binary import o32le as o32 from ._binary import o32le as o32
# Magic ("DDS ") # Magic ("DDS ")
DDS_MAGIC = 0x20534444 DDS_MAGIC = 0x20534444
# DDS flags # DDS flags
DDSD_CAPS = 0x1 class DDSD(IntFlag):
DDSD_HEIGHT = 0x2 CAPS = 0x1
DDSD_WIDTH = 0x4 HEIGHT = 0x2
DDSD_PITCH = 0x8 WIDTH = 0x4
DDSD_PIXELFORMAT = 0x1000 PITCH = 0x8
DDSD_MIPMAPCOUNT = 0x20000 PIXELFORMAT = 0x1000
DDSD_LINEARSIZE = 0x80000 MIPMAPCOUNT = 0x20000
DDSD_DEPTH = 0x800000 LINEARSIZE = 0x80000
DEPTH = 0x800000
# DDS caps # DDS caps
DDSCAPS_COMPLEX = 0x8 class DDSCAPS(IntFlag):
DDSCAPS_TEXTURE = 0x1000 COMPLEX = 0x8
DDSCAPS_MIPMAP = 0x400000 TEXTURE = 0x1000
MIPMAP = 0x400000
class DDSCAPS2(IntFlag):
CUBEMAP = 0x200
CUBEMAP_POSITIVEX = 0x400
CUBEMAP_NEGATIVEX = 0x800
CUBEMAP_POSITIVEY = 0x1000
CUBEMAP_NEGATIVEY = 0x2000
CUBEMAP_POSITIVEZ = 0x4000
CUBEMAP_NEGATIVEZ = 0x8000
VOLUME = 0x200000
DDSCAPS2_CUBEMAP = 0x200
DDSCAPS2_CUBEMAP_POSITIVEX = 0x400
DDSCAPS2_CUBEMAP_NEGATIVEX = 0x800
DDSCAPS2_CUBEMAP_POSITIVEY = 0x1000
DDSCAPS2_CUBEMAP_NEGATIVEY = 0x2000
DDSCAPS2_CUBEMAP_POSITIVEZ = 0x4000
DDSCAPS2_CUBEMAP_NEGATIVEZ = 0x8000
DDSCAPS2_VOLUME = 0x200000
# Pixel Format # Pixel Format
DDPF_ALPHAPIXELS = 0x1 class DDPF(IntFlag):
DDPF_ALPHA = 0x2 ALPHAPIXELS = 0x1
DDPF_FOURCC = 0x4 ALPHA = 0x2
DDPF_PALETTEINDEXED8 = 0x20 FOURCC = 0x4
DDPF_RGB = 0x40 PALETTEINDEXED8 = 0x20
DDPF_LUMINANCE = 0x20000 RGB = 0x40
LUMINANCE = 0x20000
# dds.h
DDS_FOURCC = DDPF_FOURCC
DDS_RGB = DDPF_RGB
DDS_RGBA = DDPF_RGB | DDPF_ALPHAPIXELS
DDS_LUMINANCE = DDPF_LUMINANCE
DDS_LUMINANCEA = DDPF_LUMINANCE | DDPF_ALPHAPIXELS
DDS_ALPHA = DDPF_ALPHA
DDS_PAL8 = DDPF_PALETTEINDEXED8
DDS_HEADER_FLAGS_TEXTURE = DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PIXELFORMAT
DDS_HEADER_FLAGS_MIPMAP = DDSD_MIPMAPCOUNT
DDS_HEADER_FLAGS_VOLUME = DDSD_DEPTH
DDS_HEADER_FLAGS_PITCH = DDSD_PITCH
DDS_HEADER_FLAGS_LINEARSIZE = DDSD_LINEARSIZE
DDS_HEIGHT = DDSD_HEIGHT
DDS_WIDTH = DDSD_WIDTH
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS_TEXTURE
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS_COMPLEX | DDSCAPS_MIPMAP
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS_COMPLEX
DDS_CUBEMAP_POSITIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEX
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEX
DDS_CUBEMAP_POSITIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEY
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEY
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEZ
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEZ
# DXT1
DXT1_FOURCC = 0x31545844
# DXT3
DXT3_FOURCC = 0x33545844
# DXT5
DXT5_FOURCC = 0x35545844
# dxgiformat.h # dxgiformat.h
class DXGI_FORMAT(IntEnum):
UNKNOWN = 0
R32G32B32A32_TYPELESS = 1
R32G32B32A32_FLOAT = 2
R32G32B32A32_UINT = 3
R32G32B32A32_SINT = 4
R32G32B32_TYPELESS = 5
R32G32B32_FLOAT = 6
R32G32B32_UINT = 7
R32G32B32_SINT = 8
R16G16B16A16_TYPELESS = 9
R16G16B16A16_FLOAT = 10
R16G16B16A16_UNORM = 11
R16G16B16A16_UINT = 12
R16G16B16A16_SNORM = 13
R16G16B16A16_SINT = 14
R32G32_TYPELESS = 15
R32G32_FLOAT = 16
R32G32_UINT = 17
R32G32_SINT = 18
R32G8X24_TYPELESS = 19
D32_FLOAT_S8X24_UINT = 20
R32_FLOAT_X8X24_TYPELESS = 21
X32_TYPELESS_G8X24_UINT = 22
R10G10B10A2_TYPELESS = 23
R10G10B10A2_UNORM = 24
R10G10B10A2_UINT = 25
R11G11B10_FLOAT = 26
R8G8B8A8_TYPELESS = 27
R8G8B8A8_UNORM = 28
R8G8B8A8_UNORM_SRGB = 29
R8G8B8A8_UINT = 30
R8G8B8A8_SNORM = 31
R8G8B8A8_SINT = 32
R16G16_TYPELESS = 33
R16G16_FLOAT = 34
R16G16_UNORM = 35
R16G16_UINT = 36
R16G16_SNORM = 37
R16G16_SINT = 38
R32_TYPELESS = 39
D32_FLOAT = 40
R32_FLOAT = 41
R32_UINT = 42
R32_SINT = 43
R24G8_TYPELESS = 44
D24_UNORM_S8_UINT = 45
R24_UNORM_X8_TYPELESS = 46
X24_TYPELESS_G8_UINT = 47
R8G8_TYPELESS = 48
R8G8_UNORM = 49
R8G8_UINT = 50
R8G8_SNORM = 51
R8G8_SINT = 52
R16_TYPELESS = 53
R16_FLOAT = 54
D16_UNORM = 55
R16_UNORM = 56
R16_UINT = 57
R16_SNORM = 58
R16_SINT = 59
R8_TYPELESS = 60
R8_UNORM = 61
R8_UINT = 62
R8_SNORM = 63
R8_SINT = 64
A8_UNORM = 65
R1_UNORM = 66
R9G9B9E5_SHAREDEXP = 67
R8G8_B8G8_UNORM = 68
G8R8_G8B8_UNORM = 69
BC1_TYPELESS = 70
BC1_UNORM = 71
BC1_UNORM_SRGB = 72
BC2_TYPELESS = 73
BC2_UNORM = 74
BC2_UNORM_SRGB = 75
BC3_TYPELESS = 76
BC3_UNORM = 77
BC3_UNORM_SRGB = 78
BC4_TYPELESS = 79
BC4_UNORM = 80
BC4_SNORM = 81
BC5_TYPELESS = 82
BC5_UNORM = 83
BC5_SNORM = 84
B5G6R5_UNORM = 85
B5G5R5A1_UNORM = 86
B8G8R8A8_UNORM = 87
B8G8R8X8_UNORM = 88
R10G10B10_XR_BIAS_A2_UNORM = 89
B8G8R8A8_TYPELESS = 90
B8G8R8A8_UNORM_SRGB = 91
B8G8R8X8_TYPELESS = 92
B8G8R8X8_UNORM_SRGB = 93
BC6H_TYPELESS = 94
BC6H_UF16 = 95
BC6H_SF16 = 96
BC7_TYPELESS = 97
BC7_UNORM = 98
BC7_UNORM_SRGB = 99
AYUV = 100
Y410 = 101
Y416 = 102
NV12 = 103
P010 = 104
P016 = 105
OPAQUE_420 = 106
YUY2 = 107
Y210 = 108
Y216 = 109
NV11 = 110
AI44 = 111
IA44 = 112
P8 = 113
A8P8 = 114
B4G4R4A4_UNORM = 115
P208 = 130
V208 = 131
V408 = 132
SAMPLER_FEEDBACK_MIN_MIP_OPAQUE = 189
SAMPLER_FEEDBACK_MIP_REGION_USED_OPAQUE = 190
DXGI_FORMAT_R8G8B8A8_TYPELESS = 27
DXGI_FORMAT_R8G8B8A8_UNORM = 28 class D3DFMT(IntEnum):
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = 29 UNKNOWN = 0
DXGI_FORMAT_BC5_TYPELESS = 82 R8G8B8 = 20
DXGI_FORMAT_BC5_UNORM = 83 A8R8G8B8 = 21
DXGI_FORMAT_BC5_SNORM = 84 X8R8G8B8 = 22
DXGI_FORMAT_BC6H_UF16 = 95 R5G6B5 = 23
DXGI_FORMAT_BC6H_SF16 = 96 X1R5G5B5 = 24
DXGI_FORMAT_BC7_TYPELESS = 97 A1R5G5B5 = 25
DXGI_FORMAT_BC7_UNORM = 98 A4R4G4B4 = 26
DXGI_FORMAT_BC7_UNORM_SRGB = 99 R3G3B2 = 27
A8 = 28
A8R3G3B2 = 29
X4R4G4B4 = 30
A2B10G10R10 = 31
A8B8G8R8 = 32
X8B8G8R8 = 33
G16R16 = 34
A2R10G10B10 = 35
A16B16G16R16 = 36
A8P8 = 40
P8 = 41
L8 = 50
A8L8 = 51
A4L4 = 52
V8U8 = 60
L6V5U5 = 61
X8L8V8U8 = 62
Q8W8V8U8 = 63
V16U16 = 64
A2W10V10U10 = 67
D16_LOCKABLE = 70
D32 = 71
D15S1 = 73
D24S8 = 75
D24X8 = 77
D24X4S4 = 79
D16 = 80
D32F_LOCKABLE = 82
D24FS8 = 83
D32_LOCKABLE = 84
S8_LOCKABLE = 85
L16 = 81
VERTEXDATA = 100
INDEX16 = 101
INDEX32 = 102
Q16W16V16U16 = 110
R16F = 111
G16R16F = 112
A16B16G16R16F = 113
R32F = 114
G32R32F = 115
A32B32G32R32F = 116
CxV8U8 = 117
A1 = 118
A2B10G10R10_XR_BIAS = 119
BINARYBUFFER = 199
UYVY = i32(b"UYVY")
R8G8_B8G8 = i32(b"RGBG")
YUY2 = i32(b"YUY2")
G8R8_G8B8 = i32(b"GRGB")
DXT1 = i32(b"DXT1")
DXT2 = i32(b"DXT2")
DXT3 = i32(b"DXT3")
DXT4 = i32(b"DXT4")
DXT5 = i32(b"DXT5")
DX10 = i32(b"DX10")
BC4S = i32(b"BC4S")
BC4U = i32(b"BC4U")
BC5S = i32(b"BC5S")
BC5U = i32(b"BC5U")
ATI1 = i32(b"ATI1")
ATI2 = i32(b"ATI2")
MULTI2_ARGB8 = i32(b"MET1")
# Backward compatibility layer
module = sys.modules[__name__]
for item in DDSD:
assert item.name is not None
setattr(module, f"DDSD_{item.name}", item.value)
for item1 in DDSCAPS:
assert item1.name is not None
setattr(module, f"DDSCAPS_{item1.name}", item1.value)
for item2 in DDSCAPS2:
assert item2.name is not None
setattr(module, f"DDSCAPS2_{item2.name}", item2.value)
for item3 in DDPF:
assert item3.name is not None
setattr(module, f"DDPF_{item3.name}", item3.value)
DDS_FOURCC = DDPF.FOURCC
DDS_RGB = DDPF.RGB
DDS_RGBA = DDPF.RGB | DDPF.ALPHAPIXELS
DDS_LUMINANCE = DDPF.LUMINANCE
DDS_LUMINANCEA = DDPF.LUMINANCE | DDPF.ALPHAPIXELS
DDS_ALPHA = DDPF.ALPHA
DDS_PAL8 = DDPF.PALETTEINDEXED8
DDS_HEADER_FLAGS_TEXTURE = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
DDS_HEADER_FLAGS_MIPMAP = DDSD.MIPMAPCOUNT
DDS_HEADER_FLAGS_VOLUME = DDSD.DEPTH
DDS_HEADER_FLAGS_PITCH = DDSD.PITCH
DDS_HEADER_FLAGS_LINEARSIZE = DDSD.LINEARSIZE
DDS_HEIGHT = DDSD.HEIGHT
DDS_WIDTH = DDSD.WIDTH
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS.TEXTURE
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS.COMPLEX | DDSCAPS.MIPMAP
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS.COMPLEX
DDS_CUBEMAP_POSITIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEX
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEX
DDS_CUBEMAP_POSITIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEY
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEY
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEZ
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEZ
DXT1_FOURCC = D3DFMT.DXT1
DXT3_FOURCC = D3DFMT.DXT3
DXT5_FOURCC = D3DFMT.DXT5
DXGI_FORMAT_R8G8B8A8_TYPELESS = DXGI_FORMAT.R8G8B8A8_TYPELESS
DXGI_FORMAT_R8G8B8A8_UNORM = DXGI_FORMAT.R8G8B8A8_UNORM
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = DXGI_FORMAT.R8G8B8A8_UNORM_SRGB
DXGI_FORMAT_BC5_TYPELESS = DXGI_FORMAT.BC5_TYPELESS
DXGI_FORMAT_BC5_UNORM = DXGI_FORMAT.BC5_UNORM
DXGI_FORMAT_BC5_SNORM = DXGI_FORMAT.BC5_SNORM
DXGI_FORMAT_BC6H_UF16 = DXGI_FORMAT.BC6H_UF16
DXGI_FORMAT_BC6H_SF16 = DXGI_FORMAT.BC6H_SF16
DXGI_FORMAT_BC7_TYPELESS = DXGI_FORMAT.BC7_TYPELESS
DXGI_FORMAT_BC7_UNORM = DXGI_FORMAT.BC7_UNORM
DXGI_FORMAT_BC7_UNORM_SRGB = DXGI_FORMAT.BC7_UNORM_SRGB
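The compatibility shim works because `IntFlag`/`IntEnum` members are plain ints, so callers still using the old module-level constants see identical values. A minimal check with an abbreviated copy of the enum above:

from enum import IntFlag

class DDSD(IntFlag):   # abbreviated copy of the enum above
    CAPS = 0x1
    HEIGHT = 0x2

assert DDSD.CAPS == 0x1                  # compares equal to the old constant
assert (DDSD.CAPS | DDSD.HEIGHT) == 0x3  # combines like the old bit masks
assert isinstance(DDSD.CAPS | DDSD.HEIGHT, int)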
class DdsImageFile(ImageFile.ImageFile): class DdsImageFile(ImageFile.ImageFile):
format = "DDS" format = "DDS"
format_description = "DirectDraw Surface" format_description = "DirectDraw Surface"
def _open(self): def _open(self) -> None:
if not _accept(self.fp.read(4)): if not _accept(self.fp.read(4)):
msg = "not a DDS file" msg = "not a DDS file"
raise SyntaxError(msg) raise SyntaxError(msg)
@@ -124,172 +344,281 @@ class DdsImageFile(ImageFile.ImageFile):
if len(header_bytes) != 120: if len(header_bytes) != 120:
msg = f"Incomplete header: {len(header_bytes)} bytes" msg = f"Incomplete header: {len(header_bytes)} bytes"
raise OSError(msg) raise OSError(msg)
header = BytesIO(header_bytes) header = io.BytesIO(header_bytes)
flags, height, width = struct.unpack("<3I", header.read(12)) flags, height, width = struct.unpack("<3I", header.read(12))
self._size = (width, height) self._size = (width, height)
self._mode = "RGBA" extents = (0, 0) + self.size
pitch, depth, mipmaps = struct.unpack("<3I", header.read(12)) pitch, depth, mipmaps = struct.unpack("<3I", header.read(12))
struct.unpack("<11I", header.read(44)) # reserved struct.unpack("<11I", header.read(44)) # reserved
# pixel format # pixel format
pfsize, pfflags = struct.unpack("<2I", header.read(8)) pfsize, pfflags, fourcc, bitcount = struct.unpack("<4I", header.read(16))
fourcc = header.read(4) n = 0
(bitcount,) = struct.unpack("<I", header.read(4)) rawmode = None
masks = struct.unpack("<4I", header.read(16)) if pfflags & DDPF.RGB:
if pfflags & DDPF_LUMINANCE: # Texture contains uncompressed RGB data
# Texture contains uncompressed L or LA data if pfflags & DDPF.ALPHAPIXELS:
if pfflags & DDPF_ALPHAPIXELS: self._mode = "RGBA"
mask_count = 4
else:
self._mode = "RGB"
mask_count = 3
masks = struct.unpack(f"<{mask_count}I", header.read(mask_count * 4))
self.tile = [ImageFile._Tile("dds_rgb", extents, 0, (bitcount, masks))]
return
elif pfflags & DDPF.LUMINANCE:
if bitcount == 8:
self._mode = "L"
elif bitcount == 16 and pfflags & DDPF.ALPHAPIXELS:
self._mode = "LA" self._mode = "LA"
else: else:
self._mode = "L" msg = f"Unsupported bitcount {bitcount} for {pfflags}"
raise OSError(msg)
self.tile = [("raw", (0, 0) + self.size, 0, (self.mode, 0, 1))] elif pfflags & DDPF.PALETTEINDEXED8:
elif pfflags & DDPF_RGB:
# Texture contains uncompressed RGB data
masks = {mask: ["R", "G", "B", "A"][i] for i, mask in enumerate(masks)}
rawmode = ""
if pfflags & DDPF_ALPHAPIXELS:
rawmode += masks[0xFF000000]
else:
self._mode = "RGB"
rawmode += masks[0xFF0000] + masks[0xFF00] + masks[0xFF]
self.tile = [("raw", (0, 0) + self.size, 0, (rawmode[::-1], 0, 1))]
elif pfflags & DDPF_PALETTEINDEXED8:
self._mode = "P" self._mode = "P"
self.palette = ImagePalette.raw("RGBA", self.fp.read(1024)) self.palette = ImagePalette.raw("RGBA", self.fp.read(1024))
self.tile = [("raw", (0, 0) + self.size, 0, "L")] self.palette.mode = "RGBA"
else: elif pfflags & DDPF.FOURCC:
data_start = header_size + 4 offset = header_size + 4
n = 0 if fourcc == D3DFMT.DXT1:
if fourcc == b"DXT1": self._mode = "RGBA"
self.pixel_format = "DXT1" self.pixel_format = "DXT1"
n = 1 n = 1
elif fourcc == b"DXT3": elif fourcc == D3DFMT.DXT3:
self._mode = "RGBA"
self.pixel_format = "DXT3" self.pixel_format = "DXT3"
n = 2 n = 2
elif fourcc == b"DXT5": elif fourcc == D3DFMT.DXT5:
self._mode = "RGBA"
self.pixel_format = "DXT5" self.pixel_format = "DXT5"
n = 3 n = 3
elif fourcc == b"ATI1": elif fourcc in (D3DFMT.BC4U, D3DFMT.ATI1):
self._mode = "L"
self.pixel_format = "BC4" self.pixel_format = "BC4"
n = 4 n = 4
self._mode = "L" elif fourcc == D3DFMT.BC5S:
elif fourcc in (b"ATI2", b"BC5U"):
self.pixel_format = "BC5"
n = 5
self._mode = "RGB" self._mode = "RGB"
elif fourcc == b"BC5S":
self.pixel_format = "BC5S" self.pixel_format = "BC5S"
n = 5 n = 5
elif fourcc in (D3DFMT.BC5U, D3DFMT.ATI2):
self._mode = "RGB" self._mode = "RGB"
elif fourcc == b"DX10": self.pixel_format = "BC5"
data_start += 20 n = 5
elif fourcc == D3DFMT.DX10:
offset += 20
# ignoring flags which pertain to volume textures and cubemaps # ignoring flags which pertain to volume textures and cubemaps
(dxgi_format,) = struct.unpack("<I", self.fp.read(4)) (dxgi_format,) = struct.unpack("<I", self.fp.read(4))
self.fp.read(16) self.fp.read(16)
if dxgi_format in (DXGI_FORMAT_BC5_TYPELESS, DXGI_FORMAT_BC5_UNORM): if dxgi_format in (
DXGI_FORMAT.BC1_UNORM,
DXGI_FORMAT.BC1_TYPELESS,
):
self._mode = "RGBA"
self.pixel_format = "BC1"
n = 1
elif dxgi_format in (DXGI_FORMAT.BC2_TYPELESS, DXGI_FORMAT.BC2_UNORM):
self._mode = "RGBA"
self.pixel_format = "BC2"
n = 2
elif dxgi_format in (DXGI_FORMAT.BC3_TYPELESS, DXGI_FORMAT.BC3_UNORM):
self._mode = "RGBA"
self.pixel_format = "BC3"
n = 3
elif dxgi_format in (DXGI_FORMAT.BC4_TYPELESS, DXGI_FORMAT.BC4_UNORM):
self._mode = "L"
self.pixel_format = "BC4"
n = 4
elif dxgi_format in (DXGI_FORMAT.BC5_TYPELESS, DXGI_FORMAT.BC5_UNORM):
self._mode = "RGB"
self.pixel_format = "BC5" self.pixel_format = "BC5"
n = 5 n = 5
elif dxgi_format == DXGI_FORMAT.BC5_SNORM:
self._mode = "RGB" self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC5_SNORM:
self.pixel_format = "BC5S" self.pixel_format = "BC5S"
n = 5 n = 5
elif dxgi_format == DXGI_FORMAT.BC6H_UF16:
self._mode = "RGB" self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC6H_UF16:
self.pixel_format = "BC6H" self.pixel_format = "BC6H"
n = 6 n = 6
elif dxgi_format == DXGI_FORMAT.BC6H_SF16:
self._mode = "RGB" self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC6H_SF16:
self.pixel_format = "BC6HS" self.pixel_format = "BC6HS"
n = 6 n = 6
self._mode = "RGB"
elif dxgi_format in (DXGI_FORMAT_BC7_TYPELESS, DXGI_FORMAT_BC7_UNORM):
self.pixel_format = "BC7"
n = 7
elif dxgi_format == DXGI_FORMAT_BC7_UNORM_SRGB:
self.pixel_format = "BC7"
self.info["gamma"] = 1 / 2.2
n = 7
elif dxgi_format in ( elif dxgi_format in (
DXGI_FORMAT_R8G8B8A8_TYPELESS, DXGI_FORMAT.BC7_TYPELESS,
DXGI_FORMAT_R8G8B8A8_UNORM, DXGI_FORMAT.BC7_UNORM,
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB, DXGI_FORMAT.BC7_UNORM_SRGB,
): ):
self.tile = [("raw", (0, 0) + self.size, 0, ("RGBA", 0, 1))] self._mode = "RGBA"
if dxgi_format == DXGI_FORMAT_R8G8B8A8_UNORM_SRGB: self.pixel_format = "BC7"
n = 7
if dxgi_format == DXGI_FORMAT.BC7_UNORM_SRGB:
self.info["gamma"] = 1 / 2.2
elif dxgi_format in (
DXGI_FORMAT.R8G8B8A8_TYPELESS,
DXGI_FORMAT.R8G8B8A8_UNORM,
DXGI_FORMAT.R8G8B8A8_UNORM_SRGB,
):
self._mode = "RGBA"
if dxgi_format == DXGI_FORMAT.R8G8B8A8_UNORM_SRGB:
self.info["gamma"] = 1 / 2.2 self.info["gamma"] = 1 / 2.2
return
else: else:
msg = f"Unimplemented DXGI format {dxgi_format}" msg = f"Unimplemented DXGI format {dxgi_format}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
else: else:
msg = f"Unimplemented pixel format {repr(fourcc)}" msg = f"Unimplemented pixel format {repr(fourcc)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
else:
msg = f"Unknown pixel format flags {pfflags}"
raise NotImplementedError(msg)
if n:
self.tile = [ self.tile = [
("bcn", (0, 0) + self.size, data_start, (n, self.pixel_format)) ImageFile._Tile("bcn", extents, offset, (n, self.pixel_format))
] ]
else:
self.tile = [ImageFile._Tile("raw", extents, 0, rawmode or self.mode)]
def load_seek(self, pos): def load_seek(self, pos: int) -> None:
pass pass
def _save(im, fp, filename): class DdsRgbDecoder(ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
bitcount, masks = self.args
# Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
# Calculate how many zeros each mask is padded with
mask_offsets = []
# And the maximum value of each channel without the padding
mask_totals = []
for mask in masks:
offset = 0
if mask != 0:
while mask >> (offset + 1) << (offset + 1) == mask:
offset += 1
mask_offsets.append(offset)
mask_totals.append(mask >> offset)
data = bytearray()
bytecount = bitcount // 8
dest_length = self.state.xsize * self.state.ysize * len(masks)
while len(data) < dest_length:
value = int.from_bytes(self.fd.read(bytecount), "little")
for i, mask in enumerate(masks):
masked_value = value & mask
# Remove the zero padding, and scale it to 8 bits
data += o8(
int(((masked_value >> mask_offsets[i]) / mask_totals[i]) * 255)
)
self.set_as_raw(data)
return -1, 0
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if im.mode not in ("RGB", "RGBA", "L", "LA"): if im.mode not in ("RGB", "RGBA", "L", "LA"):
msg = f"cannot write mode {im.mode} as DDS" msg = f"cannot write mode {im.mode} as DDS"
raise OSError(msg) raise OSError(msg)
rawmode = im.mode flags = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
masks = [0xFF0000, 0xFF00, 0xFF] bitcount = len(im.getbands()) * 8
if im.mode in ("L", "LA"): pixel_format = im.encoderinfo.get("pixel_format")
pixel_flags = DDPF_LUMINANCE args: tuple[int] | str
if pixel_format:
codec_name = "bcn"
flags |= DDSD.LINEARSIZE
pitch = (im.width + 3) * 4
rgba_mask = [0, 0, 0, 0]
pixel_flags = DDPF.FOURCC
if pixel_format == "DXT1":
fourcc = D3DFMT.DXT1
args = (1,)
elif pixel_format == "DXT3":
fourcc = D3DFMT.DXT3
args = (2,)
elif pixel_format == "DXT5":
fourcc = D3DFMT.DXT5
args = (3,)
else:
fourcc = D3DFMT.DX10
if pixel_format == "BC2":
args = (2,)
dxgi_format = DXGI_FORMAT.BC2_TYPELESS
elif pixel_format == "BC3":
args = (3,)
dxgi_format = DXGI_FORMAT.BC3_TYPELESS
elif pixel_format == "BC5":
args = (5,)
dxgi_format = DXGI_FORMAT.BC5_TYPELESS
if im.mode != "RGB":
msg = "only RGB mode can be written as BC5"
raise OSError(msg)
else:
msg = f"cannot write pixel format {pixel_format}"
raise OSError(msg)
else: else:
pixel_flags = DDPF_RGB codec_name = "raw"
rawmode = rawmode[::-1] flags |= DDSD.PITCH
if im.mode in ("LA", "RGBA"): pitch = (im.width * bitcount + 7) // 8
pixel_flags |= DDPF_ALPHAPIXELS
masks.append(0xFF000000)
bitcount = len(masks) * 8 alpha = im.mode[-1] == "A"
while len(masks) < 4: if im.mode[0] == "L":
masks.append(0) pixel_flags = DDPF.LUMINANCE
args = im.mode
if alpha:
rgba_mask = [0x000000FF, 0x000000FF, 0x000000FF]
else:
rgba_mask = [0xFF000000, 0xFF000000, 0xFF000000]
else:
pixel_flags = DDPF.RGB
args = im.mode[::-1]
rgba_mask = [0x00FF0000, 0x0000FF00, 0x000000FF]
if alpha:
r, g, b, a = im.split()
im = Image.merge("RGBA", (a, r, g, b))
if alpha:
pixel_flags |= DDPF.ALPHAPIXELS
rgba_mask.append(0xFF000000 if alpha else 0)
fourcc = D3DFMT.UNKNOWN
fp.write( fp.write(
o32(DDS_MAGIC) o32(DDS_MAGIC)
+ o32(124) # header size + struct.pack(
+ o32( "<7I",
DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PITCH | DDSD_PIXELFORMAT 124, # header size
) # flags flags, # flags
+ o32(im.height) im.height,
+ o32(im.width) im.width,
+ o32((im.width * bitcount + 7) // 8) # pitch pitch,
+ o32(0) # depth 0, # depth
+ o32(0) # mipmaps 0, # mipmaps
+ o32(0) * 11 # reserved )
+ o32(32) # pfsize + struct.pack("11I", *((0,) * 11)) # reserved
+ o32(pixel_flags) # pfflags # pfsize, pfflags, fourcc, bitcount
+ o32(0) # fourcc + struct.pack("<4I", 32, pixel_flags, fourcc, bitcount)
+ o32(bitcount) # bitcount + struct.pack("<4I", *rgba_mask) # dwRGBABitMask
+ b"".join(o32(mask) for mask in masks) # rgbabitmask + struct.pack("<5I", DDSCAPS.TEXTURE, 0, 0, 0, 0)
+ o32(DDSCAPS_TEXTURE) # dwCaps
+ o32(0) # dwCaps2
+ o32(0) # dwCaps3
+ o32(0) # dwCaps4
+ o32(0) # dwReserved2
) )
if im.mode == "RGBA": if fourcc == D3DFMT.DX10:
r, g, b, a = im.split() fp.write(
im = Image.merge("RGBA", (a, r, g, b)) # dxgi_format, 2D resource, misc, array size, straight alpha
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, 0, 1))]) struct.pack("<5I", dxgi_format, 3, 0, 0, 1)
)
ImageFile._save(im, fp, [ImageFile._Tile(codec_name, (0, 0) + im.size, 0, args)])
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:4] == b"DDS " return prefix.startswith(b"DDS ")
Image.register_open(DdsImageFile.format, DdsImageFile, _accept) Image.register_open(DdsImageFile.format, DdsImageFile, _accept)
Image.register_decoder("dds_rgb", DdsRgbDecoder)
Image.register_save(DdsImageFile.format, _save) Image.register_save(DdsImageFile.format, _save)
Image.register_extension(DdsImageFile.format, ".dds") Image.register_extension(DdsImageFile.format, ".dds")
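The padded-mask arithmetic in `DdsRgbDecoder.decode()` can be exercised in isolation. A standalone sketch (not the plugin code) that decodes one 16-bit R5G6B5 pixel the same way:

def split_rgb565(value: int) -> tuple[int, int, int]:
    channels = []
    for mask in (0xF800, 0x07E0, 0x001F):  # R5, G6, B5 masks
        offset = 0
        while (mask >> (offset + 1)) << (offset + 1) == mask:
            offset += 1                    # count the mask's trailing zeros
        total = mask >> offset             # channel maximum without padding
        # Strip the padding, then scale the channel to 8 bits.
        channels.append(int(((value & mask) >> offset) / total * 255))
    return tuple(channels)

assert split_rgb565(0xFFFF) == (255, 255, 255)  # white
assert split_rgb565(0xF800) == (255, 0, 0)      # pure red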
View File
@@ -19,6 +19,7 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import io import io
import os import os
@@ -26,10 +27,10 @@ import re
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
from typing import IO
from . import Image, ImageFile from . import Image, ImageFile
from ._binary import i32le as i32 from ._binary import i32le as i32
from ._deprecate import deprecate
# -------------------------------------------------------------------- # --------------------------------------------------------------------
@@ -37,11 +38,11 @@ from ._deprecate import deprecate
split = re.compile(r"^%%([^:]*):[ \t]*(.*)[ \t]*$") split = re.compile(r"^%%([^:]*):[ \t]*(.*)[ \t]*$")
field = re.compile(r"^%[%!\w]([^:]*)[ \t]*$") field = re.compile(r"^%[%!\w]([^:]*)[ \t]*$")
gs_binary = None gs_binary: str | bool | None = None
gs_windows_binary = None gs_windows_binary = None
def has_ghostscript(): def has_ghostscript() -> bool:
global gs_binary, gs_windows_binary global gs_binary, gs_windows_binary
if gs_binary is None: if gs_binary is None:
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
@@ -64,27 +65,32 @@ def has_ghostscript():
return gs_binary is not False return gs_binary is not False
def Ghostscript(tile, size, fp, scale=1, transparency=False): def Ghostscript(
tile: list[ImageFile._Tile],
size: tuple[int, int],
fp: IO[bytes],
scale: int = 1,
transparency: bool = False,
) -> Image.core.ImagingCore:
"""Render an image using Ghostscript""" """Render an image using Ghostscript"""
global gs_binary global gs_binary
if not has_ghostscript(): if not has_ghostscript():
msg = "Unable to locate Ghostscript on paths" msg = "Unable to locate Ghostscript on paths"
raise OSError(msg) raise OSError(msg)
assert isinstance(gs_binary, str)
# Unpack decoder tile # Unpack decoder tile
decoder, tile, offset, data = tile[0] args = tile[0].args
length, bbox = data assert isinstance(args, tuple)
length, bbox = args
# Hack to support hi-res rendering # Hack to support hi-res rendering
scale = int(scale) or 1 scale = int(scale) or 1
# orig_size = size width = size[0] * scale
# orig_bbox = bbox height = size[1] * scale
size = (size[0] * scale, size[1] * scale)
# resolution is dependent on bbox and size # resolution is dependent on bbox and size
res = ( res_x = 72.0 * width / (bbox[2] - bbox[0])
72.0 * size[0] / (bbox[2] - bbox[0]), res_y = 72.0 * height / (bbox[3] - bbox[1])
72.0 * size[1] / (bbox[3] - bbox[1]),
)
out_fd, outfile = tempfile.mkstemp() out_fd, outfile = tempfile.mkstemp()
os.close(out_fd) os.close(out_fd)
@@ -115,14 +121,20 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
lengthfile -= len(s) lengthfile -= len(s)
f.write(s) f.write(s)
device = "pngalpha" if transparency else "ppmraw" if transparency:
# "RGBA"
device = "pngalpha"
else:
# "pnmraw" automatically chooses between
# PBM ("1"), PGM ("L"), and PPM ("RGB").
device = "pnmraw"
# Build Ghostscript command # Build Ghostscript command
command = [ command = [
gs_binary, gs_binary,
"-q", # quiet mode "-q", # quiet mode
"-g%dx%d" % size, # set output geometry (pixels) f"-g{width:d}x{height:d}", # set output geometry (pixels)
"-r%fx%f" % res, # set input DPI (dots per inch) f"-r{res_x:f}x{res_y:f}", # set input DPI (dots per inch)
"-dBATCH", # exit after processing "-dBATCH", # exit after processing
"-dNOPAUSE", # don't pause between pages "-dNOPAUSE", # don't pause between pages
"-dSAFER", # safe mode "-dSAFER", # safe mode
@@ -145,8 +157,9 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
startupinfo = subprocess.STARTUPINFO() startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
subprocess.check_call(command, startupinfo=startupinfo) subprocess.check_call(command, startupinfo=startupinfo)
out_im = Image.open(outfile) with Image.open(outfile) as out_im:
out_im.load() out_im.load()
return out_im.im.copy()
finally: finally:
try: try:
os.unlink(outfile) os.unlink(outfile)
@@ -155,50 +168,11 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
except OSError: except OSError:
pass pass
im = out_im.im.copy()
out_im.close()
return im
def _accept(prefix: bytes) -> bool:
class PSFile: return prefix.startswith(b"%!PS") or (
""" len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5
Wrapper for bytesio object that treats either CR or LF as end of line. )
This class is no longer used internally, but kept for backwards compatibility.
"""
def __init__(self, fp):
deprecate(
"PSFile",
11,
action="If you need the functionality of this class "
"you will need to implement it yourself.",
)
self.fp = fp
self.char = None
def seek(self, offset, whence=io.SEEK_SET):
self.char = None
self.fp.seek(offset, whence)
def readline(self):
s = [self.char or b""]
self.char = None
c = self.fp.read(1)
while (c not in b"\r\n") and len(c):
s.append(c)
c = self.fp.read(1)
self.char = self.fp.read(1)
# line endings can be 1 or 2 of \r \n, in either order
if self.char in b"\r\n":
self.char = None
return b"".join(s).decode("latin-1")
def _accept(prefix):
return prefix[:4] == b"%!PS" or (len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5)
## ##
@@ -214,14 +188,18 @@ class EpsImageFile(ImageFile.ImageFile):
mode_map = {1: "L", 2: "LAB", 3: "RGB", 4: "CMYK"} mode_map = {1: "L", 2: "LAB", 3: "RGB", 4: "CMYK"}
def _open(self): def _open(self) -> None:
(length, offset) = self._find_offset(self.fp) (length, offset) = self._find_offset(self.fp)
# go to offset - start of "%!PS" # go to offset - start of "%!PS"
self.fp.seek(offset) self.fp.seek(offset)
self._mode = "RGB" self._mode = "RGB"
self._size = None
# When reading header comments, the first comment is used.
# When reading trailer comments, the last comment is used.
bounding_box: list[int] | None = None
imagedata_size: tuple[int, int] | None = None
byte_arr = bytearray(255) byte_arr = bytearray(255)
bytes_mv = memoryview(byte_arr) bytes_mv = memoryview(byte_arr)
@@ -230,7 +208,12 @@ class EpsImageFile(ImageFile.ImageFile):
reading_trailer_comments = False reading_trailer_comments = False
trailer_reached = False trailer_reached = False
def check_required_header_comments(): def check_required_header_comments() -> None:
"""
The EPS specification requires that some headers exist.
This should be checked when the header comments formally end,
when image data starts, or when the file ends, whichever comes first.
"""
if "PS-Adobe" not in self.info: if "PS-Adobe" not in self.info:
msg = 'EPS header missing "%!PS-Adobe" comment' msg = 'EPS header missing "%!PS-Adobe" comment'
raise SyntaxError(msg) raise SyntaxError(msg)
@@ -238,41 +221,39 @@ class EpsImageFile(ImageFile.ImageFile):
msg = 'EPS header missing "%%BoundingBox" comment' msg = 'EPS header missing "%%BoundingBox" comment'
raise SyntaxError(msg) raise SyntaxError(msg)
def _read_comment(s): def read_comment(s: str) -> bool:
nonlocal reading_trailer_comments nonlocal bounding_box, reading_trailer_comments
try: try:
m = split.match(s) m = split.match(s)
except re.error as e: except re.error as e:
msg = "not an EPS file" msg = "not an EPS file"
raise SyntaxError(msg) from e raise SyntaxError(msg) from e
if m: if not m:
k, v = m.group(1, 2) return False
self.info[k] = v
if k == "BoundingBox": k, v = m.group(1, 2)
if v == "(atend)": self.info[k] = v
reading_trailer_comments = True if k == "BoundingBox":
elif not self._size or ( if v == "(atend)":
trailer_reached and reading_trailer_comments reading_trailer_comments = True
): elif not bounding_box or (trailer_reached and reading_trailer_comments):
try: try:
# Note: The DSC spec says that BoundingBox # Note: The DSC spec says that BoundingBox
# fields should be integers, but some drivers # fields should be integers, but some drivers
# put floating point values there anyway. # put floating point values there anyway.
box = [int(float(i)) for i in v.split()] bounding_box = [int(float(i)) for i in v.split()]
self._size = box[2] - box[0], box[3] - box[1] except Exception:
self.tile = [ pass
("eps", (0, 0) + self.size, offset, (length, box)) return True
]
except Exception:
pass
return True
while True: while True:
byte = self.fp.read(1) byte = self.fp.read(1)
if byte == b"": if byte == b"":
# if we didn't read a byte we must be at the end of the file # if we didn't read a byte we must be at the end of the file
if bytes_read == 0: if bytes_read == 0:
if reading_header_comments:
check_required_header_comments()
break break
elif byte in b"\r\n": elif byte in b"\r\n":
# if we read a line ending character, ignore it and parse what # if we read a line ending character, ignore it and parse what
@@ -312,11 +293,11 @@ class EpsImageFile(ImageFile.ImageFile):
continue continue
s = str(bytes_mv[:bytes_read], "latin-1") s = str(bytes_mv[:bytes_read], "latin-1")
if not _read_comment(s): if not read_comment(s):
m = field.match(s) m = field.match(s)
if m: if m:
k = m.group(1) k = m.group(1)
if k[:8] == "PS-Adobe": if k.startswith("PS-Adobe"):
self.info["PS-Adobe"] = k[9:] self.info["PS-Adobe"] = k[9:]
else: else:
self.info[k] = "" self.info[k] = ""
@@ -331,6 +312,12 @@ class EpsImageFile(ImageFile.ImageFile):
# Check for an "ImageData" descriptor # Check for an "ImageData" descriptor
# https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#50577413_pgfId-1035096 # https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#50577413_pgfId-1035096
# If we've already read an "ImageData" descriptor,
# don't read another one.
if imagedata_size:
bytes_read = 0
continue
# Values: # Values:
# columns # columns
# rows # rows
@@ -356,29 +343,39 @@ class EpsImageFile(ImageFile.ImageFile):
else: else:
break break
self._size = columns, rows # Parse the columns and rows after checking the bit depth and mode
return # in case the bit depth and/or mode are invalid.
imagedata_size = columns, rows
elif bytes_mv[:5] == b"%%EOF":
break
elif trailer_reached and reading_trailer_comments: elif trailer_reached and reading_trailer_comments:
# Load EPS trailer # Load EPS trailer
# if this line starts with "%%EOF",
# then we've reached the end of the file
if bytes_mv[:5] == b"%%EOF":
break
s = str(bytes_mv[:bytes_read], "latin-1") s = str(bytes_mv[:bytes_read], "latin-1")
_read_comment(s) read_comment(s)
elif bytes_mv[:9] == b"%%Trailer": elif bytes_mv[:9] == b"%%Trailer":
trailer_reached = True trailer_reached = True
elif bytes_mv[:14] == b"%%BeginBinary:":
bytecount = int(byte_arr[14:bytes_read])
self.fp.seek(bytecount, os.SEEK_CUR)
bytes_read = 0 bytes_read = 0
check_required_header_comments() # A "BoundingBox" is always required,
# even if an "ImageData" descriptor size exists.
if not self._size: if not bounding_box:
msg = "cannot determine EPS bounding box" msg = "cannot determine EPS bounding box"
raise OSError(msg) raise OSError(msg)
def _find_offset(self, fp): # An "ImageData" size takes precedence over the "BoundingBox".
self._size = imagedata_size or (
bounding_box[2] - bounding_box[0],
bounding_box[3] - bounding_box[1],
)
self.tile = [
ImageFile._Tile("eps", (0, 0) + self.size, offset, (length, bounding_box))
]
def _find_offset(self, fp: IO[bytes]) -> tuple[int, int]:
s = fp.read(4) s = fp.read(4)
if s == b"%!PS": if s == b"%!PS":
@@ -401,7 +398,9 @@ class EpsImageFile(ImageFile.ImageFile):
return length, offset return length, offset
def load(self, scale=1, transparency=False): def load(
self, scale: int = 1, transparency: bool = False
) -> Image.core.PixelAccess | None:
# Load EPS via Ghostscript # Load EPS via Ghostscript
if self.tile: if self.tile:
self.im = Ghostscript(self.tile, self.size, self.fp, scale, transparency) self.im = Ghostscript(self.tile, self.size, self.fp, scale, transparency)
@@ -410,7 +409,7 @@ class EpsImageFile(ImageFile.ImageFile):
self.tile = [] self.tile = []
return Image.Image.load(self) return Image.Image.load(self)
def load_seek(self, *args, **kwargs): def load_seek(self, pos: int) -> None:
# we can't incrementally load, so force ImageFile.parser to # we can't incrementally load, so force ImageFile.parser to
# use our custom load method by defining this method. # use our custom load method by defining this method.
pass pass
@@ -419,7 +418,7 @@ class EpsImageFile(ImageFile.ImageFile):
# -------------------------------------------------------------------- # --------------------------------------------------------------------
def _save(im, fp, filename, eps=1): def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes, eps: int = 1) -> None:
"""EPS Writer for the Python Imaging Library.""" """EPS Writer for the Python Imaging Library."""
# make sure image data is available # make sure image data is available
@@ -460,7 +459,7 @@ def _save(im, fp, filename, eps=1):
if hasattr(fp, "flush"): if hasattr(fp, "flush"):
fp.flush() fp.flush()
ImageFile._save(im, fp, [("eps", (0, 0) + im.size, 0, None)]) ImageFile._save(im, fp, [ImageFile._Tile("eps", (0, 0) + im.size)])
fp.write(b"\n%%%%EndBinary\n") fp.write(b"\n%%%%EndBinary\n")
fp.write(b"grestore end\n") fp.write(b"grestore end\n")
View File
@@ -13,6 +13,7 @@
This module provides constants and clear-text names for various This module provides constants and clear-text names for various
well-known EXIF tags. well-known EXIF tags.
""" """
from __future__ import annotations
from enum import IntEnum from enum import IntEnum
@@ -302,38 +303,38 @@ TAGS = {
class GPS(IntEnum): class GPS(IntEnum):
GPSVersionID = 0 GPSVersionID = 0x00
GPSLatitudeRef = 1 GPSLatitudeRef = 0x01
GPSLatitude = 2 GPSLatitude = 0x02
GPSLongitudeRef = 3 GPSLongitudeRef = 0x03
GPSLongitude = 4 GPSLongitude = 0x04
GPSAltitudeRef = 5 GPSAltitudeRef = 0x05
GPSAltitude = 6 GPSAltitude = 0x06
GPSTimeStamp = 7 GPSTimeStamp = 0x07
GPSSatellites = 8 GPSSatellites = 0x08
GPSStatus = 9 GPSStatus = 0x09
GPSMeasureMode = 10 GPSMeasureMode = 0x0A
GPSDOP = 11 GPSDOP = 0x0B
GPSSpeedRef = 12 GPSSpeedRef = 0x0C
GPSSpeed = 13 GPSSpeed = 0x0D
GPSTrackRef = 14 GPSTrackRef = 0x0E
GPSTrack = 15 GPSTrack = 0x0F
GPSImgDirectionRef = 16 GPSImgDirectionRef = 0x10
GPSImgDirection = 17 GPSImgDirection = 0x11
GPSMapDatum = 18 GPSMapDatum = 0x12
GPSDestLatitudeRef = 19 GPSDestLatitudeRef = 0x13
GPSDestLatitude = 20 GPSDestLatitude = 0x14
GPSDestLongitudeRef = 21 GPSDestLongitudeRef = 0x15
GPSDestLongitude = 22 GPSDestLongitude = 0x16
GPSDestBearingRef = 23 GPSDestBearingRef = 0x17
GPSDestBearing = 24 GPSDestBearing = 0x18
GPSDestDistanceRef = 25 GPSDestDistanceRef = 0x19
GPSDestDistance = 26 GPSDestDistance = 0x1A
GPSProcessingMethod = 27 GPSProcessingMethod = 0x1B
GPSAreaInformation = 28 GPSAreaInformation = 0x1C
GPSDateStamp = 29 GPSDateStamp = 0x1D
GPSDifferential = 30 GPSDifferential = 0x1E
GPSHPositioningError = 31 GPSHPositioningError = 0x1F
"""Maps EXIF GPS tags to tag names.""" """Maps EXIF GPS tags to tag names."""
@@ -341,40 +342,41 @@ GPSTAGS = {i.value: i.name for i in GPS}
class Interop(IntEnum): class Interop(IntEnum):
InteropIndex = 1 InteropIndex = 0x0001
InteropVersion = 2 InteropVersion = 0x0002
RelatedImageFileFormat = 4096 RelatedImageFileFormat = 0x1000
RelatedImageWidth = 4097 RelatedImageWidth = 0x1001
RleatedImageHeight = 4098 RelatedImageHeight = 0x1002
class IFD(IntEnum): class IFD(IntEnum):
Exif = 34665 Exif = 0x8769
GPSInfo = 34853 GPSInfo = 0x8825
Makernote = 37500 MakerNote = 0x927C
Interop = 40965 Makernote = 0x927C # Deprecated
Interop = 0xA005
IFD1 = -1 IFD1 = -1
class LightSource(IntEnum): class LightSource(IntEnum):
Unknown = 0 Unknown = 0x00
Daylight = 1 Daylight = 0x01
Fluorescent = 2 Fluorescent = 0x02
Tungsten = 3 Tungsten = 0x03
Flash = 4 Flash = 0x04
Fine = 9 Fine = 0x09
Cloudy = 10 Cloudy = 0x0A
Shade = 11 Shade = 0x0B
DaylightFluorescent = 12 DaylightFluorescent = 0x0C
DayWhiteFluorescent = 13 DayWhiteFluorescent = 0x0D
CoolWhiteFluorescent = 14 CoolWhiteFluorescent = 0x0E
WhiteFluorescent = 15 WhiteFluorescent = 0x0F
StandardLightA = 17 StandardLightA = 0x11
StandardLightB = 18 StandardLightB = 0x12
StandardLightC = 19 StandardLightC = 0x13
D55 = 20 D55 = 0x14
D65 = 21 D65 = 0x15
D75 = 22 D75 = 0x16
D50 = 23 D50 = 0x17
ISO = 24 ISO = 0x18
Other = 255 Other = 0xFF
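The renumbering above is purely cosmetic (decimal to hexadecimal); the tag values are unchanged, which is easy to verify against a Pillow version that ships these enums:

from PIL.ExifTags import GPS, GPSTAGS, IFD

assert IFD.Exif == 0x8769 == 34665     # same value, new spelling
assert GPS.GPSAltitude == 6
assert GPSTAGS[0x02] == "GPSLatitude"  # GPSTAGS maps tag values to names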
View File
@@ -8,30 +8,52 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import gzip
import math import math
from . import Image, ImageFile from . import Image, ImageFile
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:6] == b"SIMPLE" return prefix.startswith(b"SIMPLE")
class FitsImageFile(ImageFile.ImageFile): class FitsImageFile(ImageFile.ImageFile):
format = "FITS" format = "FITS"
format_description = "FITS" format_description = "FITS"
def _open(self): def _open(self) -> None:
headers = {} assert self.fp is not None
headers: dict[bytes, bytes] = {}
header_in_progress = False
decoder_name = ""
while True: while True:
header = self.fp.read(80) header = self.fp.read(80)
if not header: if not header:
msg = "Truncated FITS file" msg = "Truncated FITS file"
raise OSError(msg) raise OSError(msg)
keyword = header[:8].strip() keyword = header[:8].strip()
if keyword == b"END": if keyword in (b"SIMPLE", b"XTENSION"):
header_in_progress = True
elif headers and not header_in_progress:
# This is now a data unit
break break
elif keyword == b"END":
# Seek to the end of the header unit
self.fp.seek(math.ceil(self.fp.tell() / 2880) * 2880)
if not decoder_name:
decoder_name, offset, args = self._parse_headers(headers)
header_in_progress = False
continue
if decoder_name:
# Keep going to read past the headers
continue
value = header[8:].split(b"/")[0].strip() value = header[8:].split(b"/")[0].strip()
if value.startswith(b"="): if value.startswith(b"="):
value = value[1:].strip() value = value[1:].strip()
@@ -40,34 +62,91 @@ class FitsImageFile(ImageFile.ImageFile):
raise SyntaxError(msg) raise SyntaxError(msg)
headers[keyword] = value headers[keyword] = value
naxis = int(headers[b"NAXIS"]) if not decoder_name:
if naxis == 0:
msg = "No image data" msg = "No image data"
raise ValueError(msg) raise ValueError(msg)
elif naxis == 1:
self._size = 1, int(headers[b"NAXIS1"])
else:
self._size = int(headers[b"NAXIS1"]), int(headers[b"NAXIS2"])
number_of_bits = int(headers[b"BITPIX"]) offset += self.fp.tell() - 80
self.tile = [ImageFile._Tile(decoder_name, (0, 0) + self.size, offset, args)]
def _get_size(
self, headers: dict[bytes, bytes], prefix: bytes
) -> tuple[int, int] | None:
naxis = int(headers[prefix + b"NAXIS"])
if naxis == 0:
return None
if naxis == 1:
return 1, int(headers[prefix + b"NAXIS1"])
else:
return int(headers[prefix + b"NAXIS1"]), int(headers[prefix + b"NAXIS2"])
def _parse_headers(
self, headers: dict[bytes, bytes]
) -> tuple[str, int, tuple[str | int, ...]]:
prefix = b""
decoder_name = "raw"
offset = 0
if (
headers.get(b"XTENSION") == b"'BINTABLE'"
and headers.get(b"ZIMAGE") == b"T"
and headers[b"ZCMPTYPE"] == b"'GZIP_1 '"
):
no_prefix_size = self._get_size(headers, prefix) or (0, 0)
number_of_bits = int(headers[b"BITPIX"])
offset = no_prefix_size[0] * no_prefix_size[1] * (number_of_bits // 8)
prefix = b"Z"
decoder_name = "fits_gzip"
size = self._get_size(headers, prefix)
if not size:
return "", 0, ()
self._size = size
number_of_bits = int(headers[prefix + b"BITPIX"])
if number_of_bits == 8: if number_of_bits == 8:
self._mode = "L" self._mode = "L"
elif number_of_bits == 16: elif number_of_bits == 16:
self._mode = "I" self._mode = "I;16"
# rawmode = "I;16S"
elif number_of_bits == 32: elif number_of_bits == 32:
self._mode = "I" self._mode = "I"
elif number_of_bits in (-32, -64): elif number_of_bits in (-32, -64):
self._mode = "F" self._mode = "F"
# rawmode = "F" if number_of_bits == -32 else "F;64F"
offset = math.ceil(self.fp.tell() / 2880) * 2880 args: tuple[str | int, ...]
self.tile = [("raw", (0, 0) + self.size, offset, (self.mode, 0, -1))] if decoder_name == "raw":
args = (self.mode, 0, -1)
else:
args = (number_of_bits,)
return decoder_name, offset, args
class FitsGzipDecoder(ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
value = gzip.decompress(self.fd.read())
rows = []
offset = 0
number_of_bits = min(self.args[0] // 8, 4)
for y in range(self.state.ysize):
row = bytearray()
for x in range(self.state.xsize):
row += value[offset + (4 - number_of_bits) : offset + 4]
offset += 4
rows.append(row)
self.set_as_raw(bytes([pixel for row in rows[::-1] for pixel in row]))
return -1, 0
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Registry # Registry
Image.register_open(FitsImageFile.format, FitsImageFile, _accept) Image.register_open(FitsImageFile.format, FitsImageFile, _accept)
Image.register_decoder("fits_gzip", FitsGzipDecoder)
Image.register_extensions(FitsImageFile.format, [".fit", ".fits"]) Image.register_extensions(FitsImageFile.format, [".fit", ".fits"])
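
The rewritten _open above repeatedly seeks to the next 2880-byte boundary, since FITS pads every header and data unit to 2880-byte blocks. A minimal sketch of that alignment arithmetic, with an illustrative helper name:

import math

def next_block_offset(pos, block=2880):
    # Round a file position up to the next FITS block boundary,
    # mirroring self.fp.seek(math.ceil(self.fp.tell() / 2880) * 2880).
    return math.ceil(pos / block) * block

assert next_block_offset(0) == 0
assert next_block_offset(1) == 2880
assert next_block_offset(2880) == 2880
assert next_block_offset(2881) == 5760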

View File

@@ -14,6 +14,7 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import os import os
@@ -21,14 +22,15 @@ from . import Image, ImageFile, ImagePalette
from ._binary import i16le as i16 from ._binary import i16le as i16
from ._binary import i32le as i32 from ._binary import i32le as i32
from ._binary import o8 from ._binary import o8
from ._util import DeferredError
# #
# decoder # decoder
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return ( return (
len(prefix) >= 6 len(prefix) >= 16
and i16(prefix, 4) in [0xAF11, 0xAF12] and i16(prefix, 4) in [0xAF11, 0xAF12]
and i16(prefix, 14) in [0, 3] # flags and i16(prefix, 14) in [0, 3] # flags
) )
@@ -44,10 +46,16 @@ class FliImageFile(ImageFile.ImageFile):
format_description = "Autodesk FLI/FLC Animation" format_description = "Autodesk FLI/FLC Animation"
_close_exclusive_fp_after_loading = False _close_exclusive_fp_after_loading = False
def _open(self): def _open(self) -> None:
# HEAD # HEAD
assert self.fp is not None
s = self.fp.read(128) s = self.fp.read(128)
if not (_accept(s) and s[20:22] == b"\x00\x00"): if not (
_accept(s)
and s[20:22] == b"\x00" * 2
and s[42:80] == b"\x00" * 38
and s[88:] == b"\x00" * 40
):
msg = "not an FLI/FLC file" msg = "not an FLI/FLC file"
raise SyntaxError(msg) raise SyntaxError(msg)
@@ -75,13 +83,13 @@ class FliImageFile(ImageFile.ImageFile):
if i16(s, 4) == 0xF100: if i16(s, 4) == 0xF100:
# prefix chunk; ignore it # prefix chunk; ignore it
self.__offset = self.__offset + i32(s) self.fp.seek(self.__offset + i32(s))
s = self.fp.read(16) s = self.fp.read(16)
if i16(s, 4) == 0xF1FA: if i16(s, 4) == 0xF1FA:
# look for palette chunk # look for palette chunk
number_of_subchunks = i16(s, 6) number_of_subchunks = i16(s, 6)
chunk_size = None chunk_size: int | None = None
for _ in range(number_of_subchunks): for _ in range(number_of_subchunks):
if chunk_size is not None: if chunk_size is not None:
self.fp.seek(chunk_size - 6, os.SEEK_CUR) self.fp.seek(chunk_size - 6, os.SEEK_CUR)
@@ -94,8 +102,9 @@ class FliImageFile(ImageFile.ImageFile):
if not chunk_size: if not chunk_size:
break break
palette = [o8(r) + o8(g) + o8(b) for (r, g, b) in palette] self.palette = ImagePalette.raw(
self.palette = ImagePalette.raw("RGB", b"".join(palette)) "RGB", b"".join(o8(r) + o8(g) + o8(b) for (r, g, b) in palette)
)
# set things up to decode first frame # set things up to decode first frame
self.__frame = -1 self.__frame = -1
@@ -103,10 +112,11 @@ class FliImageFile(ImageFile.ImageFile):
self.__rewind = self.fp.tell() self.__rewind = self.fp.tell()
self.seek(0) self.seek(0)
def _palette(self, palette, shift): def _palette(self, palette: list[tuple[int, int, int]], shift: int) -> None:
# load palette # load palette
i = 0 i = 0
assert self.fp is not None
for e in range(i16(self.fp.read(2))): for e in range(i16(self.fp.read(2))):
s = self.fp.read(2) s = self.fp.read(2)
i = i + s[0] i = i + s[0]
@@ -121,7 +131,7 @@ class FliImageFile(ImageFile.ImageFile):
palette[i] = (r, g, b) palette[i] = (r, g, b)
i += 1 i += 1
def seek(self, frame): def seek(self, frame: int) -> None:
if not self._seek_check(frame): if not self._seek_check(frame):
return return
if frame < self.__frame: if frame < self.__frame:
@@ -130,7 +140,9 @@ class FliImageFile(ImageFile.ImageFile):
for f in range(self.__frame + 1, frame + 1): for f in range(self.__frame + 1, frame + 1):
self._seek(f) self._seek(f)
def _seek(self, frame): def _seek(self, frame: int) -> None:
if isinstance(self._fp, DeferredError):
raise self._fp.ex
if frame == 0: if frame == 0:
self.__frame = -1 self.__frame = -1
self._fp.seek(self.__rewind) self._fp.seek(self.__rewind)
@@ -150,16 +162,17 @@ class FliImageFile(ImageFile.ImageFile):
s = self.fp.read(4) s = self.fp.read(4)
if not s: if not s:
raise EOFError msg = "missing frame size"
raise EOFError(msg)
framesize = i32(s) framesize = i32(s)
self.decodermaxblock = framesize self.decodermaxblock = framesize
self.tile = [("fli", (0, 0) + self.size, self.__offset, None)] self.tile = [ImageFile._Tile("fli", (0, 0) + self.size, self.__offset)]
self.__offset += framesize self.__offset += framesize
def tell(self): def tell(self) -> int:
return self.__frame return self.__frame
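
The tightened _accept above now demands at least 16 bytes before indexing the magic and flags fields. A self-contained sketch of the same check, with an illustrative function name and a hand-built header stub:

import struct

def looks_like_fli(prefix):
    return (
        len(prefix) >= 16
        and struct.unpack("<H", prefix[4:6])[0] in (0xAF11, 0xAF12)  # magic
        and struct.unpack("<H", prefix[14:16])[0] in (0, 3)  # flags
    )

# size, magic, frames, width, height, depth, flags -- 16 bytes total
stub = struct.pack("<IHHHHHH", 0, 0xAF12, 1, 320, 200, 8, 3)
assert looks_like_fli(stub)
assert not looks_like_fli(b"\x00" * 16)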

View File

@@ -13,16 +13,19 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import os import os
from typing import BinaryIO
from . import Image, _binary from . import Image, _binary
WIDTH = 800 WIDTH = 800
def puti16(fp, values): def puti16(
fp: BinaryIO, values: tuple[int, int, int, int, int, int, int, int, int, int]
) -> None:
"""Write network order (big-endian) 16-bit sequence""" """Write network order (big-endian) 16-bit sequence"""
for v in values: for v in values:
if v < 0: if v < 0:
@@ -33,16 +36,32 @@ def puti16(fp, values):
class FontFile: class FontFile:
"""Base class for raster font file handlers.""" """Base class for raster font file handlers."""
bitmap = None bitmap: Image.Image | None = None
def __init__(self): def __init__(self) -> None:
self.info = {} self.info: dict[bytes, bytes | int] = {}
self.glyph = [None] * 256 self.glyph: list[
tuple[
tuple[int, int],
tuple[int, int, int, int],
tuple[int, int, int, int],
Image.Image,
]
| None
] = [None] * 256
def __getitem__(self, ix): def __getitem__(self, ix: int) -> (
tuple[
tuple[int, int],
tuple[int, int, int, int],
tuple[int, int, int, int],
Image.Image,
]
| None
):
return self.glyph[ix] return self.glyph[ix]
def compile(self): def compile(self) -> None:
"""Create metrics and bitmap""" """Create metrics and bitmap"""
if self.bitmap: if self.bitmap:
@@ -51,7 +70,7 @@ class FontFile:
# create bitmap large enough to hold all data # create bitmap large enough to hold all data
h = w = maxwidth = 0 h = w = maxwidth = 0
lines = 1 lines = 1
for glyph in self: for glyph in self.glyph:
if glyph: if glyph:
d, dst, src, im = glyph d, dst, src, im = glyph
h = max(h, src[3] - src[1]) h = max(h, src[3] - src[1])
@@ -65,20 +84,22 @@ class FontFile:
ysize = lines * h ysize = lines * h
if xsize == 0 and ysize == 0: if xsize == 0 and ysize == 0:
return "" return
self.ysize = h self.ysize = h
# paste glyphs into bitmap # paste glyphs into bitmap
self.bitmap = Image.new("1", (xsize, ysize)) self.bitmap = Image.new("1", (xsize, ysize))
self.metrics = [None] * 256 self.metrics: list[
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]]
| None
] = [None] * 256
x = y = 0 x = y = 0
for i in range(256): for i in range(256):
glyph = self[i] glyph = self[i]
if glyph: if glyph:
d, dst, src, im = glyph d, dst, src, im = glyph
xx = src[2] - src[0] xx = src[2] - src[0]
# yy = src[3] - src[1]
x0, y0 = x, y x0, y0 = x, y
x = x + xx x = x + xx
if x > WIDTH: if x > WIDTH:
@@ -89,12 +110,15 @@ class FontFile:
self.bitmap.paste(im.crop(src), s) self.bitmap.paste(im.crop(src), s)
self.metrics[i] = d, dst, s self.metrics[i] = d, dst, s
def save(self, filename): def save(self, filename: str) -> None:
"""Save font""" """Save font"""
self.compile() self.compile()
# font data # font data
if not self.bitmap:
msg = "No bitmap created"
raise ValueError(msg)
self.bitmap.save(os.path.splitext(filename)[0] + ".pbm", "PNG") self.bitmap.save(os.path.splitext(filename)[0] + ".pbm", "PNG")
# font metrics # font metrics
@@ -105,6 +129,6 @@ class FontFile:
for id in range(256): for id in range(256):
m = self.metrics[id] m = self.metrics[id]
if not m: if not m:
puti16(fp, [0] * 10) puti16(fp, (0,) * 10)
else: else:
puti16(fp, m[0] + m[1] + m[2]) puti16(fp, m[0] + m[1] + m[2])
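
puti16 above writes glyph metrics as network-order (big-endian) 16-bit values; the change passes a (0,) * 10 tuple instead of a list to match the new annotation. A minimal sketch of the same serialization for non-negative values (the negative-value handling shown truncated above is ignored here):

import struct
from io import BytesIO

def put_u16_be(fp, values):
    # One big-endian unsigned 16-bit word per value, like puti16.
    for v in values:
        fp.write(struct.pack(">H", v))

buf = BytesIO()
put_u16_be(buf, (0,) * 10)  # an empty metrics slot
assert buf.getvalue() == b"\x00" * 20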

View File

@@ -14,6 +14,8 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import olefile import olefile
from . import Image, ImageFile from . import Image, ImageFile
@@ -39,8 +41,8 @@ MODES = {
# -------------------------------------------------------------------- # --------------------------------------------------------------------
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:8] == olefile.MAGIC return prefix.startswith(olefile.MAGIC)
## ##
@@ -51,7 +53,7 @@ class FpxImageFile(ImageFile.ImageFile):
format = "FPX" format = "FPX"
format_description = "FlashPix" format_description = "FlashPix"
def _open(self): def _open(self) -> None:
# #
# read the OLE directory and see if this is a likely # read the OLE directory and see if this is a likely
# to be a FlashPix file # to be a FlashPix file
@@ -62,13 +64,14 @@ class FpxImageFile(ImageFile.ImageFile):
msg = "not an FPX file; invalid OLE file" msg = "not an FPX file; invalid OLE file"
raise SyntaxError(msg) from e raise SyntaxError(msg) from e
if self.ole.root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B": root = self.ole.root
if not root or root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B":
msg = "not an FPX file; bad root CLSID" msg = "not an FPX file; bad root CLSID"
raise SyntaxError(msg) raise SyntaxError(msg)
self._open_index(1) self._open_index(1)
def _open_index(self, index=1): def _open_index(self, index: int = 1) -> None:
# #
# get the Image Contents Property Set # get the Image Contents Property Set
@@ -78,12 +81,14 @@ class FpxImageFile(ImageFile.ImageFile):
# size (highest resolution) # size (highest resolution)
assert isinstance(prop[0x1000002], int)
assert isinstance(prop[0x1000003], int)
self._size = prop[0x1000002], prop[0x1000003] self._size = prop[0x1000002], prop[0x1000003]
size = max(self.size) size = max(self.size)
i = 1 i = 1
while size > 64: while size > 64:
size = size / 2 size = size // 2
i += 1 i += 1
self.maxid = i - 1 self.maxid = i - 1
@@ -97,16 +102,14 @@ class FpxImageFile(ImageFile.ImageFile):
s = prop[0x2000002 | id] s = prop[0x2000002 | id]
colors = [] if not isinstance(s, bytes) or (bands := i32(s, 4)) > 4:
bands = i32(s, 4)
if bands > 4:
msg = "Invalid number of bands" msg = "Invalid number of bands"
raise OSError(msg) raise OSError(msg)
for i in range(bands):
# note: for now, we ignore the "uncalibrated" flag
colors.append(i32(s, 8 + i * 4) & 0x7FFFFFFF)
self._mode, self.rawmode = MODES[tuple(colors)] # note: for now, we ignore the "uncalibrated" flag
colors = tuple(i32(s, 8 + i * 4) & 0x7FFFFFFF for i in range(bands))
self._mode, self.rawmode = MODES[colors]
# load JPEG tables, if any # load JPEG tables, if any
self.jpeg = {} self.jpeg = {}
@@ -117,7 +120,7 @@ class FpxImageFile(ImageFile.ImageFile):
self._open_subimage(1, self.maxid) self._open_subimage(1, self.maxid)
def _open_subimage(self, index=1, subimage=0): def _open_subimage(self, index: int = 1, subimage: int = 0) -> None:
# #
# setup tile descriptors for a given subimage # setup tile descriptors for a given subimage
@@ -163,18 +166,18 @@ class FpxImageFile(ImageFile.ImageFile):
if compression == 0: if compression == 0:
self.tile.append( self.tile.append(
( ImageFile._Tile(
"raw", "raw",
(x, y, x1, y1), (x, y, x1, y1),
i32(s, i) + 28, i32(s, i) + 28,
(self.rawmode,), self.rawmode,
) )
) )
elif compression == 1: elif compression == 1:
# FIXME: the fill decoder is not implemented # FIXME: the fill decoder is not implemented
self.tile.append( self.tile.append(
( ImageFile._Tile(
"fill", "fill",
(x, y, x1, y1), (x, y, x1, y1),
i32(s, i) + 28, i32(s, i) + 28,
@@ -202,7 +205,7 @@ class FpxImageFile(ImageFile.ImageFile):
jpegmode = rawmode jpegmode = rawmode
self.tile.append( self.tile.append(
( ImageFile._Tile(
"jpeg", "jpeg",
(x, y, x1, y1), (x, y, x1, y1),
i32(s, i) + 28, i32(s, i) + 28,
@@ -227,19 +230,20 @@ class FpxImageFile(ImageFile.ImageFile):
break # isn't really required break # isn't really required
self.stream = stream self.stream = stream
self._fp = self.fp
self.fp = None self.fp = None
def load(self): def load(self) -> Image.core.PixelAccess | None:
if not self.fp: if not self.fp:
self.fp = self.ole.openstream(self.stream[:2] + ["Subimage 0000 Data"]) self.fp = self.ole.openstream(self.stream[:2] + ["Subimage 0000 Data"])
return ImageFile.ImageFile.load(self) return ImageFile.ImageFile.load(self)
def close(self): def close(self) -> None:
self.ole.close() self.ole.close()
super().close() super().close()
def __exit__(self, *args): def __exit__(self, *args: object) -> None:
self.ole.close() self.ole.close()
super().__exit__() super().__exit__()
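
The _open_index change above swaps / for // so the resolution-halving loop keeps size an int. The same arithmetic in isolation (the function name is illustrative):

def count_halvings(size, floor=64):
    # Number of times the largest dimension halves before reaching the
    # 64-pixel floor, i.e. maxid in _open_index.
    i = 1
    while size > floor:
        size = size // 2
        i += 1
    return i - 1

assert count_halvings(64) == 0
assert count_halvings(65) == 1
assert count_halvings(1024) == 4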

View File

@@ -51,6 +51,8 @@ bytes for that mipmap level.
Note: All data is stored in little-Endian (Intel) byte order. Note: All data is stored in little-Endian (Intel) byte order.
""" """
from __future__ import annotations
import struct import struct
from enum import IntEnum from enum import IntEnum
from io import BytesIO from io import BytesIO
@@ -69,7 +71,7 @@ class FtexImageFile(ImageFile.ImageFile):
format = "FTEX" format = "FTEX"
format_description = "Texture File Format (IW2:EOC)" format_description = "Texture File Format (IW2:EOC)"
def _open(self): def _open(self) -> None:
if not _accept(self.fp.read(4)): if not _accept(self.fp.read(4)):
msg = "not an FTEX file" msg = "not an FTEX file"
raise SyntaxError(msg) raise SyntaxError(msg)
@@ -77,8 +79,6 @@ class FtexImageFile(ImageFile.ImageFile):
self._size = struct.unpack("<2i", self.fp.read(8)) self._size = struct.unpack("<2i", self.fp.read(8))
mipmap_count, format_count = struct.unpack("<2i", self.fp.read(8)) mipmap_count, format_count = struct.unpack("<2i", self.fp.read(8))
self._mode = "RGB"
# Only support single-format files. # Only support single-format files.
# I don't know of any multi-format file. # I don't know of any multi-format file.
assert format_count == 1 assert format_count == 1
@@ -91,9 +91,10 @@ class FtexImageFile(ImageFile.ImageFile):
if format == Format.DXT1: if format == Format.DXT1:
self._mode = "RGBA" self._mode = "RGBA"
self.tile = [("bcn", (0, 0) + self.size, 0, 1)] self.tile = [ImageFile._Tile("bcn", (0, 0) + self.size, 0, (1,))]
elif format == Format.UNCOMPRESSED: elif format == Format.UNCOMPRESSED:
self.tile = [("raw", (0, 0) + self.size, 0, ("RGB", 0, 1))] self._mode = "RGB"
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, "RGB")]
else: else:
msg = f"Invalid texture compression format: {repr(format)}" msg = f"Invalid texture compression format: {repr(format)}"
raise ValueError(msg) raise ValueError(msg)
@@ -101,12 +102,12 @@ class FtexImageFile(ImageFile.ImageFile):
self.fp.close() self.fp.close()
self.fp = BytesIO(data) self.fp = BytesIO(data)
def load_seek(self, pos): def load_seek(self, pos: int) -> None:
pass pass
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:4] == MAGIC return prefix.startswith(MAGIC)
Image.register_open(FtexImageFile.format, FtexImageFile, _accept) Image.register_open(FtexImageFile.format, FtexImageFile, _accept)
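
The header reads in _open above use "<2i", i.e. pairs of little-endian signed 32-bit integers. A standalone sketch of those two reads (any fields between the magic and the size are omitted, and the names are illustrative):

import struct
from io import BytesIO

def read_ftex_dims(fp):
    width, height = struct.unpack("<2i", fp.read(8))
    mipmap_count, format_count = struct.unpack("<2i", fp.read(8))
    return (width, height), mipmap_count, format_count

buf = BytesIO(struct.pack("<4i", 256, 128, 9, 1))
assert read_ftex_dims(buf) == ((256, 128), 9, 1)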

View File

@@ -23,12 +23,13 @@
# Version 2 files are saved by GIMP v2.8 (at least) # Version 2 files are saved by GIMP v2.8 (at least)
# Version 3 files have a format specifier of 18 for 16bit floats in # Version 3 files have a format specifier of 18 for 16bit floats in
# the color depth field. This is currently unsupported by Pillow. # the color depth field. This is currently unsupported by Pillow.
from __future__ import annotations
from . import Image, ImageFile from . import Image, ImageFile
from ._binary import i32be as i32 from ._binary import i32be as i32
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return len(prefix) >= 8 and i32(prefix, 0) >= 20 and i32(prefix, 4) in (1, 2) return len(prefix) >= 8 and i32(prefix, 0) >= 20 and i32(prefix, 4) in (1, 2)
@@ -40,7 +41,7 @@ class GbrImageFile(ImageFile.ImageFile):
format = "GBR" format = "GBR"
format_description = "GIMP brush file" format_description = "GIMP brush file"
def _open(self): def _open(self) -> None:
header_size = i32(self.fp.read(4)) header_size = i32(self.fp.read(4))
if header_size < 20: if header_size < 20:
msg = "not a GIMP brush" msg = "not a GIMP brush"
@@ -53,7 +54,7 @@ class GbrImageFile(ImageFile.ImageFile):
width = i32(self.fp.read(4)) width = i32(self.fp.read(4))
height = i32(self.fp.read(4)) height = i32(self.fp.read(4))
color_depth = i32(self.fp.read(4)) color_depth = i32(self.fp.read(4))
if width <= 0 or height <= 0: if width == 0 or height == 0:
msg = "not a GIMP brush" msg = "not a GIMP brush"
raise SyntaxError(msg) raise SyntaxError(msg)
if color_depth not in (1, 4): if color_depth not in (1, 4):
@@ -70,7 +71,7 @@ class GbrImageFile(ImageFile.ImageFile):
raise SyntaxError(msg) raise SyntaxError(msg)
self.info["spacing"] = i32(self.fp.read(4)) self.info["spacing"] = i32(self.fp.read(4))
comment = self.fp.read(comment_length)[:-1] self.info["comment"] = self.fp.read(comment_length)[:-1]
if color_depth == 1: if color_depth == 1:
self._mode = "L" self._mode = "L"
@@ -79,16 +80,14 @@ class GbrImageFile(ImageFile.ImageFile):
self._size = width, height self._size = width, height
self.info["comment"] = comment
# Image might not be small # Image might not be small
Image._decompression_bomb_check(self.size) Image._decompression_bomb_check(self.size)
# Data is an uncompressed block of w * h * bytes/pixel # Data is an uncompressed block of w * h * bytes/pixel
self._data_size = width * height * color_depth self._data_size = width * height * color_depth
def load(self): def load(self) -> Image.core.PixelAccess | None:
if not self.im: if self._im is None:
self.im = Image.core.new(self.mode, self.size) self.im = Image.core.new(self.mode, self.size)
self.frombytes(self.fp.read(self._data_size)) self.frombytes(self.fp.read(self._data_size))
return Image.Image.load(self) return Image.Image.load(self)
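
GBR headers are parsed above with big-endian 32-bit reads (i32be). A small sketch of that primitive against a hand-built header stub with illustrative values:

import struct
from io import BytesIO

def i32be(data, offset=0):
    # Big-endian unsigned 32-bit, like PIL._binary.i32be.
    return struct.unpack(">I", data[offset : offset + 4])[0]

# header_size, version, width, height, color_depth
buf = BytesIO(struct.pack(">5I", 28, 2, 16, 16, 1))
assert i32be(buf.read(4)) == 28  # header_size must be >= 20
assert i32be(buf.read(4)) == 2   # version 1 or 2 accepted by _open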

View File

@@ -25,11 +25,14 @@
implementation is provided for convenience and demonstrational implementation is provided for convenience and demonstrational
purposes only. purposes only.
""" """
from __future__ import annotations
from typing import IO
from . import ImageFile, ImagePalette, UnidentifiedImageError from . import ImageFile, ImagePalette, UnidentifiedImageError
from ._binary import i16be as i16 from ._binary import i16be as i16
from ._binary import i32be as i32 from ._binary import i32be as i32
from ._typing import StrOrBytesPath
class GdImageFile(ImageFile.ImageFile): class GdImageFile(ImageFile.ImageFile):
@@ -43,15 +46,17 @@ class GdImageFile(ImageFile.ImageFile):
format = "GD" format = "GD"
format_description = "GD uncompressed images" format_description = "GD uncompressed images"
def _open(self): def _open(self) -> None:
# Header # Header
assert self.fp is not None
s = self.fp.read(1037) s = self.fp.read(1037)
if i16(s) not in [65534, 65535]: if i16(s) not in [65534, 65535]:
msg = "Not a valid GD 2.x .gd file" msg = "Not a valid GD 2.x .gd file"
raise SyntaxError(msg) raise SyntaxError(msg)
self._mode = "L" # FIXME: "P" self._mode = "P"
self._size = i16(s, 2), i16(s, 4) self._size = i16(s, 2), i16(s, 4)
true_color = s[6] true_color = s[6]
@@ -63,20 +68,20 @@ class GdImageFile(ImageFile.ImageFile):
self.info["transparency"] = tindex self.info["transparency"] = tindex
self.palette = ImagePalette.raw( self.palette = ImagePalette.raw(
"XBGR", s[7 + true_color_offset + 4 : 7 + true_color_offset + 4 + 256 * 4] "RGBX", s[7 + true_color_offset + 6 : 7 + true_color_offset + 6 + 256 * 4]
) )
self.tile = [ self.tile = [
( ImageFile._Tile(
"raw", "raw",
(0, 0) + self.size, (0, 0) + self.size,
7 + true_color_offset + 4 + 256 * 4, 7 + true_color_offset + 6 + 256 * 4,
("L", 0, 1), "L",
) )
] ]
def open(fp, mode="r"): def open(fp: StrOrBytesPath | IO[bytes], mode: str = "r") -> GdImageFile:
""" """
Load texture from a GD image file. Load texture from a GD image file.
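
The GD signature test above checks a big-endian 16-bit word against 65534/65535, then reads width and height the same way. A sketch of that parse (the header values below are illustrative):

import struct

def i16be(data, offset=0):
    # Big-endian unsigned 16-bit, like PIL._binary.i16be.
    return struct.unpack(">H", data[offset : offset + 2])[0]

header = struct.pack(">3H", 65535, 640, 480)
assert i16be(header, 0) in (65534, 65535)
assert (i16be(header, 2), i16be(header, 4)) == (640, 480)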

View File

@@ -23,17 +23,36 @@
# #
# See the README file for information on usage and redistribution. # See the README file for information on usage and redistribution.
# #
from __future__ import annotations
import itertools import itertools
import math import math
import os import os
import subprocess import subprocess
from enum import IntEnum from enum import IntEnum
from functools import cached_property
from typing import Any, NamedTuple, cast
from . import Image, ImageChops, ImageFile, ImagePalette, ImageSequence from . import (
Image,
ImageChops,
ImageFile,
ImageMath,
ImageOps,
ImagePalette,
ImageSequence,
)
from ._binary import i16le as i16 from ._binary import i16le as i16
from ._binary import o8 from ._binary import o8
from ._binary import o16le as o16 from ._binary import o16le as o16
from ._util import DeferredError
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import IO, Literal
from . import _imaging
from ._typing import Buffer
class LoadingStrategy(IntEnum): class LoadingStrategy(IntEnum):
@@ -51,8 +70,8 @@ LOADING_STRATEGY = LoadingStrategy.RGB_AFTER_FIRST
# Identify/read GIF files # Identify/read GIF files
def _accept(prefix): def _accept(prefix: bytes) -> bool:
return prefix[:6] in [b"GIF87a", b"GIF89a"] return prefix.startswith((b"GIF87a", b"GIF89a"))
## ##
@@ -67,19 +86,19 @@ class GifImageFile(ImageFile.ImageFile):
global_palette = None global_palette = None
def data(self): def data(self) -> bytes | None:
s = self.fp.read(1) s = self.fp.read(1)
if s and s[0]: if s and s[0]:
return self.fp.read(s[0]) return self.fp.read(s[0])
return None return None
def _is_palette_needed(self, p): def _is_palette_needed(self, p: bytes) -> bool:
for i in range(0, len(p), 3): for i in range(0, len(p), 3):
if not (i // 3 == p[i] == p[i + 1] == p[i + 2]): if not (i // 3 == p[i] == p[i + 1] == p[i + 2]):
return True return True
return False return False
def _open(self): def _open(self) -> None:
# Screen # Screen
s = self.fp.read(13) s = self.fp.read(13)
if not _accept(s): if not _accept(s):
@@ -88,7 +107,6 @@ class GifImageFile(ImageFile.ImageFile):
self.info["version"] = s[:6] self.info["version"] = s[:6]
self._size = i16(s, 6), i16(s, 8) self._size = i16(s, 6), i16(s, 8)
self.tile = []
flags = s[10] flags = s[10]
bits = (flags & 7) + 1 bits = (flags & 7) + 1
@@ -103,12 +121,11 @@ class GifImageFile(ImageFile.ImageFile):
self._fp = self.fp # FIXME: hack self._fp = self.fp # FIXME: hack
self.__rewind = self.fp.tell() self.__rewind = self.fp.tell()
self._n_frames = None self._n_frames: int | None = None
self._is_animated = None
self._seek(0) # get ready to read first frame self._seek(0) # get ready to read first frame
@property @property
def n_frames(self): def n_frames(self) -> int:
if self._n_frames is None: if self._n_frames is None:
current = self.tell() current = self.tell()
try: try:
@@ -119,30 +136,29 @@ class GifImageFile(ImageFile.ImageFile):
self.seek(current) self.seek(current)
return self._n_frames return self._n_frames
@property @cached_property
def is_animated(self): def is_animated(self) -> bool:
if self._is_animated is None: if self._n_frames is not None:
if self._n_frames is not None: return self._n_frames != 1
self._is_animated = self._n_frames != 1
else:
current = self.tell()
if current:
self._is_animated = True
else:
try:
self._seek(1, False)
self._is_animated = True
except EOFError:
self._is_animated = False
self.seek(current) current = self.tell()
return self._is_animated if current:
return True
def seek(self, frame): try:
self._seek(1, False)
is_animated = True
except EOFError:
is_animated = False
self.seek(current)
return is_animated
def seek(self, frame: int) -> None:
if not self._seek_check(frame): if not self._seek_check(frame):
return return
if frame < self.__frame: if frame < self.__frame:
self.im = None self._im = None
self._seek(0) self._seek(0)
last_frame = self.__frame last_frame = self.__frame
@@ -154,11 +170,13 @@ class GifImageFile(ImageFile.ImageFile):
msg = "no more images in GIF file" msg = "no more images in GIF file"
raise EOFError(msg) from e raise EOFError(msg) from e
def _seek(self, frame, update_image=True): def _seek(self, frame: int, update_image: bool = True) -> None:
if isinstance(self._fp, DeferredError):
raise self._fp.ex
if frame == 0: if frame == 0:
# rewind # rewind
self.__offset = 0 self.__offset = 0
self.dispose = None self.dispose: _imaging.ImagingCore | None = None
self.__frame = -1 self.__frame = -1
self._fp.seek(self.__rewind) self._fp.seek(self.__rewind)
self.disposal_method = 0 self.disposal_method = 0
@@ -183,11 +201,12 @@ class GifImageFile(ImageFile.ImageFile):
s = self.fp.read(1) s = self.fp.read(1)
if not s or s == b";": if not s or s == b";":
raise EOFError msg = "no more images in GIF file"
raise EOFError(msg)
palette = None palette: ImagePalette.ImagePalette | Literal[False] | None = None
info = {} info: dict[str, Any] = {}
frame_transparency = None frame_transparency = None
interlace = None interlace = None
frame_dispose_extent = None frame_dispose_extent = None
@@ -203,7 +222,7 @@ class GifImageFile(ImageFile.ImageFile):
# #
s = self.fp.read(1) s = self.fp.read(1)
block = self.data() block = self.data()
if s[0] == 249: if s[0] == 249 and block is not None:
# #
# graphic control extension # graphic control extension
# #
@@ -239,14 +258,14 @@ class GifImageFile(ImageFile.ImageFile):
info["comment"] = comment info["comment"] = comment
s = None s = None
continue continue
elif s[0] == 255 and frame == 0: elif s[0] == 255 and frame == 0 and block is not None:
# #
# application extension # application extension
# #
info["extension"] = block, self.fp.tell() info["extension"] = block, self.fp.tell()
if block[:11] == b"NETSCAPE2.0": if block.startswith(b"NETSCAPE2.0"):
block = self.data() block = self.data()
if len(block) >= 3 and block[0] == 1: if block and len(block) >= 3 and block[0] == 1:
self.info["loop"] = i16(block, 1) self.info["loop"] = i16(block, 1)
while self.data(): while self.data():
pass pass
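
The application-extension branch above reads the NETSCAPE2.0 loop count only when the sub-block is present and well formed. A standalone sketch of that payload parse (the helper name is mine):

import struct

def parse_netscape_loop(block):
    # Sub-block id 1 followed by a little-endian 16-bit loop count,
    # as read into self.info["loop"] above.
    if block and len(block) >= 3 and block[0] == 1:
        return struct.unpack("<H", block[1:3])[0]
    return None

assert parse_netscape_loop(b"\x01\x00\x00") == 0  # loop forever
assert parse_netscape_loop(b"\x01\x05\x00") == 5  # five repeats
assert parse_netscape_loop(b"") is None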
@@ -280,15 +299,11 @@ class GifImageFile(ImageFile.ImageFile):
bits = self.fp.read(1)[0] bits = self.fp.read(1)[0]
self.__offset = self.fp.tell() self.__offset = self.fp.tell()
break break
else:
pass
# raise OSError, "illegal GIF tag `%x`" % s[0]
s = None s = None
if interlace is None: if interlace is None:
# self._fp = None msg = "image not found in GIF frame"
raise EOFError raise EOFError(msg)
self.__frame = frame self.__frame = frame
if not update_image: if not update_image:
@@ -310,18 +325,20 @@ class GifImageFile(ImageFile.ImageFile):
else: else:
self._mode = "L" self._mode = "L"
if not palette and self.global_palette: if palette:
self.palette = palette
elif self.global_palette:
from copy import copy from copy import copy
palette = copy(self.global_palette) self.palette = copy(self.global_palette)
self.palette = palette else:
self.palette = None
else: else:
if self.mode == "P": if self.mode == "P":
if ( if (
LOADING_STRATEGY != LoadingStrategy.RGB_AFTER_DIFFERENT_PALETTE_ONLY LOADING_STRATEGY != LoadingStrategy.RGB_AFTER_DIFFERENT_PALETTE_ONLY
or palette or palette
): ):
self.pyaccess = None
if "transparency" in self.info: if "transparency" in self.info:
self.im.putpalettealpha(self.info["transparency"], 0) self.im.putpalettealpha(self.info["transparency"], 0)
self.im = self.im.convert("RGBA", Image.Dither.FLOYDSTEINBERG) self.im = self.im.convert("RGBA", Image.Dither.FLOYDSTEINBERG)
@@ -331,58 +348,63 @@ class GifImageFile(ImageFile.ImageFile):
self._mode = "RGB" self._mode = "RGB"
self.im = self.im.convert("RGB", Image.Dither.FLOYDSTEINBERG) self.im = self.im.convert("RGB", Image.Dither.FLOYDSTEINBERG)
def _rgb(color): def _rgb(color: int) -> tuple[int, int, int]:
if self._frame_palette: if self._frame_palette:
color = tuple(self._frame_palette.palette[color * 3 : color * 3 + 3]) if color * 3 + 3 > len(self._frame_palette.palette):
color = 0
return cast(
tuple[int, int, int],
tuple(self._frame_palette.palette[color * 3 : color * 3 + 3]),
)
else: else:
color = (color, color, color) return (color, color, color)
return color
self.dispose_extent = frame_dispose_extent self.dispose = None
try: self.dispose_extent: tuple[int, int, int, int] | None = frame_dispose_extent
if self.disposal_method < 2: if self.dispose_extent and self.disposal_method >= 2:
# do not dispose or none specified try:
self.dispose = None if self.disposal_method == 2:
elif self.disposal_method == 2: # replace with background colour
# replace with background colour
# only dispose the extent in this frame
x0, y0, x1, y1 = self.dispose_extent
dispose_size = (x1 - x0, y1 - y0)
Image._decompression_bomb_check(dispose_size)
# by convention, attempt to use transparency first
dispose_mode = "P"
color = self.info.get("transparency", frame_transparency)
if color is not None:
if self.mode in ("RGB", "RGBA"):
dispose_mode = "RGBA"
color = _rgb(color) + (0,)
else:
color = self.info.get("background", 0)
if self.mode in ("RGB", "RGBA"):
dispose_mode = "RGB"
color = _rgb(color)
self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
else:
# replace with previous contents
if self.im is not None:
# only dispose the extent in this frame # only dispose the extent in this frame
self.dispose = self._crop(self.im, self.dispose_extent)
elif frame_transparency is not None:
x0, y0, x1, y1 = self.dispose_extent x0, y0, x1, y1 = self.dispose_extent
dispose_size = (x1 - x0, y1 - y0) dispose_size = (x1 - x0, y1 - y0)
Image._decompression_bomb_check(dispose_size) Image._decompression_bomb_check(dispose_size)
# by convention, attempt to use transparency first
dispose_mode = "P" dispose_mode = "P"
color = frame_transparency color = self.info.get("transparency", frame_transparency)
if self.mode in ("RGB", "RGBA"): if color is not None:
dispose_mode = "RGBA" if self.mode in ("RGB", "RGBA"):
color = _rgb(frame_transparency) + (0,) dispose_mode = "RGBA"
color = _rgb(color) + (0,)
else:
color = self.info.get("background", 0)
if self.mode in ("RGB", "RGBA"):
dispose_mode = "RGB"
color = _rgb(color)
self.dispose = Image.core.fill(dispose_mode, dispose_size, color) self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
except AttributeError: else:
pass # replace with previous contents
if self._im is not None:
# only dispose the extent in this frame
self.dispose = self._crop(self.im, self.dispose_extent)
elif frame_transparency is not None:
x0, y0, x1, y1 = self.dispose_extent
dispose_size = (x1 - x0, y1 - y0)
Image._decompression_bomb_check(dispose_size)
dispose_mode = "P"
color = frame_transparency
if self.mode in ("RGB", "RGBA"):
dispose_mode = "RGBA"
color = _rgb(frame_transparency) + (0,)
self.dispose = Image.core.fill(
dispose_mode, dispose_size, color
)
except AttributeError:
pass
if interlace is not None: if interlace is not None:
transparency = -1 transparency = -1
@@ -393,7 +415,7 @@ class GifImageFile(ImageFile.ImageFile):
elif self.mode not in ("RGB", "RGBA"): elif self.mode not in ("RGB", "RGBA"):
transparency = frame_transparency transparency = frame_transparency
self.tile = [ self.tile = [
( ImageFile._Tile(
"gif", "gif",
(x0, y0, x1, y1), (x0, y0, x1, y1),
self.__offset, self.__offset,
@@ -409,7 +431,7 @@ class GifImageFile(ImageFile.ImageFile):
elif k in self.info: elif k in self.info:
del self.info[k] del self.info[k]
def load_prepare(self): def load_prepare(self) -> None:
temp_mode = "P" if self._frame_palette else "L" temp_mode = "P" if self._frame_palette else "L"
self._prev_im = None self._prev_im = None
if self.__frame == 0: if self.__frame == 0:
@@ -421,15 +443,22 @@ class GifImageFile(ImageFile.ImageFile):
self._prev_im = self.im self._prev_im = self.im
if self._frame_palette: if self._frame_palette:
self.im = Image.core.fill("P", self.size, self._frame_transparency or 0) self.im = Image.core.fill("P", self.size, self._frame_transparency or 0)
self.im.putpalette(*self._frame_palette.getdata()) self.im.putpalette("RGB", *self._frame_palette.getdata())
else: else:
self.im = None self._im = None
if not self._prev_im and self._im is not None and self.size != self.im.size:
expanded_im = Image.core.fill(self.im.mode, self.size)
if self._frame_palette:
expanded_im.putpalette("RGB", *self._frame_palette.getdata())
expanded_im.paste(self.im, (0, 0) + self.im.size)
self.im = expanded_im
self._mode = temp_mode self._mode = temp_mode
self._frame_palette = None self._frame_palette = None
super().load_prepare() super().load_prepare()
def load_end(self): def load_end(self) -> None:
if self.__frame == 0: if self.__frame == 0:
if self.mode == "P" and LOADING_STRATEGY == LoadingStrategy.RGB_ALWAYS: if self.mode == "P" and LOADING_STRATEGY == LoadingStrategy.RGB_ALWAYS:
if self._frame_transparency is not None: if self._frame_transparency is not None:
@@ -441,21 +470,37 @@ class GifImageFile(ImageFile.ImageFile):
return return
if not self._prev_im: if not self._prev_im:
return return
if self.size != self._prev_im.size:
if self._frame_transparency is not None:
expanded_im = Image.core.fill("RGBA", self.size)
else:
expanded_im = Image.core.fill("P", self.size)
expanded_im.putpalette("RGB", "RGB", self.im.getpalette())
expanded_im = expanded_im.convert("RGB")
expanded_im.paste(self._prev_im, (0, 0) + self._prev_im.size)
self._prev_im = expanded_im
assert self._prev_im is not None
if self._frame_transparency is not None: if self._frame_transparency is not None:
self.im.putpalettealpha(self._frame_transparency, 0) if self.mode == "L":
frame_im = self.im.convert("RGBA") frame_im = self.im.convert_transparent("LA", self._frame_transparency)
else:
self.im.putpalettealpha(self._frame_transparency, 0)
frame_im = self.im.convert("RGBA")
else: else:
frame_im = self.im.convert("RGB") frame_im = self.im.convert("RGB")
assert self.dispose_extent is not None
frame_im = self._crop(frame_im, self.dispose_extent) frame_im = self._crop(frame_im, self.dispose_extent)
self.im = self._prev_im self.im = self._prev_im
self._mode = self.im.mode self._mode = self.im.mode
if frame_im.mode == "RGBA": if frame_im.mode in ("LA", "RGBA"):
self.im.paste(frame_im, self.dispose_extent, frame_im) self.im.paste(frame_im, self.dispose_extent, frame_im)
else: else:
self.im.paste(frame_im, self.dispose_extent) self.im.paste(frame_im, self.dispose_extent)
def tell(self): def tell(self) -> int:
return self.__frame return self.__frame
@@ -466,7 +511,7 @@ class GifImageFile(ImageFile.ImageFile):
RAWMODE = {"1": "L", "L": "L", "P": "P"} RAWMODE = {"1": "L", "L": "L", "P": "P"}
def _normalize_mode(im): def _normalize_mode(im: Image.Image) -> Image.Image:
""" """
Takes an image (or frame), returns an image in a mode that is appropriate Takes an image (or frame), returns an image in a mode that is appropriate
for saving in a Gif. for saving in a Gif.
@@ -482,6 +527,7 @@ def _normalize_mode(im):
return im return im
if Image.getmodebase(im.mode) == "RGB": if Image.getmodebase(im.mode) == "RGB":
im = im.convert("P", palette=Image.Palette.ADAPTIVE) im = im.convert("P", palette=Image.Palette.ADAPTIVE)
assert im.palette is not None
if im.palette.mode == "RGBA": if im.palette.mode == "RGBA":
for rgba in im.palette.colors: for rgba in im.palette.colors:
if rgba[3] == 0: if rgba[3] == 0:
@@ -491,7 +537,12 @@ def _normalize_mode(im):
return im.convert("L") return im.convert("L")
def _normalize_palette(im, palette, info): _Palette = bytes | bytearray | list[int] | ImagePalette.ImagePalette
def _normalize_palette(
im: Image.Image, palette: _Palette | None, info: dict[str, Any]
) -> Image.Image:
""" """
Normalizes the palette for image. Normalizes the palette for image.
- Sets the palette to the incoming palette, if provided. - Sets the palette to the incoming palette, if provided.
@@ -513,14 +564,18 @@ def _normalize_palette(im, palette, info):
if im.mode == "P": if im.mode == "P":
if not source_palette: if not source_palette:
source_palette = im.im.getpalette("RGB")[:768] im_palette = im.getpalette(None)
assert im_palette is not None
source_palette = bytearray(im_palette)
else: # L-mode else: # L-mode
if not source_palette: if not source_palette:
source_palette = bytearray(i // 3 for i in range(768)) source_palette = bytearray(i // 3 for i in range(768))
im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette) im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette)
assert source_palette is not None
if palette: if palette:
used_palette_colors = [] used_palette_colors: list[int | None] = []
assert im.palette is not None
for i in range(0, len(source_palette), 3): for i in range(0, len(source_palette), 3):
source_color = tuple(source_palette[i : i + 3]) source_color = tuple(source_palette[i : i + 3])
index = im.palette.colors.get(source_color) index = im.palette.colors.get(source_color)
@@ -533,20 +588,38 @@ def _normalize_palette(im, palette, info):
if j not in used_palette_colors: if j not in used_palette_colors:
used_palette_colors[i] = j used_palette_colors[i] = j
break break
im = im.remap_palette(used_palette_colors) dest_map: list[int] = []
for index in used_palette_colors:
assert index is not None
dest_map.append(index)
im = im.remap_palette(dest_map)
else: else:
used_palette_colors = _get_optimize(im, info) optimized_palette_colors = _get_optimize(im, info)
if used_palette_colors is not None: if optimized_palette_colors is not None:
return im.remap_palette(used_palette_colors, source_palette) im = im.remap_palette(optimized_palette_colors, source_palette)
if "transparency" in info:
try:
info["transparency"] = optimized_palette_colors.index(
info["transparency"]
)
except ValueError:
del info["transparency"]
return im
assert im.palette is not None
im.palette.palette = source_palette im.palette.palette = source_palette
return im return im
def _write_single_frame(im, fp, palette): def _write_single_frame(
im: Image.Image,
fp: IO[bytes],
palette: _Palette | None,
) -> None:
im_out = _normalize_mode(im) im_out = _normalize_mode(im)
for k, v in im_out.info.items(): for k, v in im_out.info.items():
im.encoderinfo.setdefault(k, v) if isinstance(k, str):
im.encoderinfo.setdefault(k, v)
im_out = _normalize_palette(im_out, palette, im.encoderinfo) im_out = _normalize_palette(im_out, palette, im.encoderinfo)
for s in _get_global_header(im_out, im.encoderinfo): for s in _get_global_header(im_out, im.encoderinfo):
@@ -559,26 +632,40 @@ def _write_single_frame(im, fp, palette):
_write_local_header(fp, im, (0, 0), flags) _write_local_header(fp, im, (0, 0), flags)
im_out.encoderconfig = (8, get_interlace(im)) im_out.encoderconfig = (8, get_interlace(im))
ImageFile._save(im_out, fp, [("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])]) ImageFile._save(
im_out, fp, [ImageFile._Tile("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])]
)
fp.write(b"\0") # end of image data fp.write(b"\0") # end of image data
def _getbbox(base_im, im_frame): def _getbbox(
if _get_palette_bytes(im_frame) == _get_palette_bytes(base_im): base_im: Image.Image, im_frame: Image.Image
delta = ImageChops.subtract_modulo(im_frame, base_im) ) -> tuple[Image.Image, tuple[int, int, int, int] | None]:
else: palette_bytes = [
delta = ImageChops.subtract_modulo( bytes(im.palette.palette) if im.palette else b"" for im in (base_im, im_frame)
im_frame.convert("RGBA"), base_im.convert("RGBA") ]
) if palette_bytes[0] != palette_bytes[1]:
return delta.getbbox(alpha_only=False) im_frame = im_frame.convert("RGBA")
base_im = base_im.convert("RGBA")
delta = ImageChops.subtract_modulo(im_frame, base_im)
return delta, delta.getbbox(alpha_only=False)
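
The rewritten _getbbox above converts both frames to RGBA when their palettes differ, then diffs them with subtract_modulo. A quick demonstration of the diff-bbox idea, assuming a recent Pillow (which is what this tree builds):

from PIL import Image, ImageChops

base = Image.new("RGB", (8, 8), "black")
frame = base.copy()
frame.putpixel((3, 4), (255, 0, 0))  # one changed pixel

delta = ImageChops.subtract_modulo(frame, base)
assert delta.getbbox(alpha_only=False) == (3, 4, 4, 5)
assert ImageChops.subtract_modulo(base, base).getbbox(alpha_only=False) is None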
def _write_multiple_frames(im, fp, palette): class _Frame(NamedTuple):
im: Image.Image
bbox: tuple[int, int, int, int] | None
encoderinfo: dict[str, Any]
def _write_multiple_frames(
im: Image.Image, fp: IO[bytes], palette: _Palette | None
) -> bool:
duration = im.encoderinfo.get("duration") duration = im.encoderinfo.get("duration")
disposal = im.encoderinfo.get("disposal", im.info.get("disposal")) disposal = im.encoderinfo.get("disposal", im.info.get("disposal"))
im_frames = [] im_frames: list[_Frame] = []
previous_im: Image.Image | None = None
frame_count = 0 frame_count = 0
background_im = None background_im = None
for imSequence in itertools.chain([im], im.encoderinfo.get("append_images", [])): for imSequence in itertools.chain([im], im.encoderinfo.get("append_images", [])):
@@ -589,12 +676,13 @@ def _write_multiple_frames(im, fp, palette):
for k, v in im_frame.info.items(): for k, v in im_frame.info.items():
if k == "transparency": if k == "transparency":
continue continue
im.encoderinfo.setdefault(k, v) if isinstance(k, str):
im.encoderinfo.setdefault(k, v)
encoderinfo = im.encoderinfo.copy() encoderinfo = im.encoderinfo.copy()
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
if "transparency" in im_frame.info: if "transparency" in im_frame.info:
encoderinfo.setdefault("transparency", im_frame.info["transparency"]) encoderinfo.setdefault("transparency", im_frame.info["transparency"])
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
if isinstance(duration, (list, tuple)): if isinstance(duration, (list, tuple)):
encoderinfo["duration"] = duration[frame_count] encoderinfo["duration"] = duration[frame_count]
elif duration is None and "duration" in im_frame.info: elif duration is None and "duration" in im_frame.info:
@@ -603,63 +691,116 @@ def _write_multiple_frames(im, fp, palette):
encoderinfo["disposal"] = disposal[frame_count] encoderinfo["disposal"] = disposal[frame_count]
frame_count += 1 frame_count += 1
if im_frames: diff_frame = None
if im_frames and previous_im:
# delta frame # delta frame
previous = im_frames[-1] delta, bbox = _getbbox(previous_im, im_frame)
bbox = _getbbox(previous["im"], im_frame)
if not bbox: if not bbox:
# This frame is identical to the previous frame # This frame is identical to the previous frame
if encoderinfo.get("duration"): if encoderinfo.get("duration"):
previous["encoderinfo"]["duration"] += encoderinfo["duration"] im_frames[-1].encoderinfo["duration"] += encoderinfo["duration"]
continue continue
if encoderinfo.get("disposal") == 2: if im_frames[-1].encoderinfo.get("disposal") == 2:
if background_im is None: # To appear correctly in viewers using a convention,
color = im.encoderinfo.get( # only consider transparency, and not background color
"transparency", im.info.get("transparency", (0, 0, 0)) color = im.encoderinfo.get(
) "transparency", im.info.get("transparency")
background = _get_background(im_frame, color) )
background_im = Image.new("P", im_frame.size, background) if color is not None:
background_im.putpalette(im_frames[0]["im"].palette) if background_im is None:
bbox = _getbbox(background_im, im_frame) background = _get_background(im_frame, color)
background_im = Image.new("P", im_frame.size, background)
first_palette = im_frames[0].im.palette
assert first_palette is not None
background_im.putpalette(first_palette, first_palette.mode)
bbox = _getbbox(background_im, im_frame)[1]
else:
bbox = (0, 0) + im_frame.size
elif encoderinfo.get("optimize") and im_frame.mode != "1":
if "transparency" not in encoderinfo:
assert im_frame.palette is not None
try:
encoderinfo["transparency"] = (
im_frame.palette._new_color_index(im_frame)
)
except ValueError:
pass
if "transparency" in encoderinfo:
# When the delta is zero, fill the image with transparency
diff_frame = im_frame.copy()
fill = Image.new("P", delta.size, encoderinfo["transparency"])
if delta.mode == "RGBA":
r, g, b, a = delta.split()
mask = ImageMath.lambda_eval(
lambda args: args["convert"](
args["max"](
args["max"](
args["max"](args["r"], args["g"]), args["b"]
),
args["a"],
)
* 255,
"1",
),
r=r,
g=g,
b=b,
a=a,
)
else:
if delta.mode == "P":
# Convert to L without considering palette
delta_l = Image.new("L", delta.size)
delta_l.putdata(delta.getdata())
delta = delta_l
mask = ImageMath.lambda_eval(
lambda args: args["convert"](args["im"] * 255, "1"),
im=delta,
)
diff_frame.paste(fill, mask=ImageOps.invert(mask))
else: else:
bbox = None bbox = None
im_frames.append({"im": im_frame, "bbox": bbox, "encoderinfo": encoderinfo}) previous_im = im_frame
im_frames.append(_Frame(diff_frame or im_frame, bbox, encoderinfo))
if len(im_frames) > 1: if len(im_frames) == 1:
for frame_data in im_frames: if "duration" in im.encoderinfo:
im_frame = frame_data["im"] # Since multiple frames will not be written, use the combined duration
if not frame_data["bbox"]: im.encoderinfo["duration"] = im_frames[0].encoderinfo["duration"]
# global header return False
for s in _get_global_header(im_frame, frame_data["encoderinfo"]):
fp.write(s)
offset = (0, 0)
else:
# compress difference
if not palette:
frame_data["encoderinfo"]["include_color_table"] = True
im_frame = im_frame.crop(frame_data["bbox"]) for frame_data in im_frames:
offset = frame_data["bbox"][:2] im_frame = frame_data.im
_write_frame_data(fp, im_frame, offset, frame_data["encoderinfo"]) if not frame_data.bbox:
return True # global header
elif "duration" in im.encoderinfo and isinstance( for s in _get_global_header(im_frame, frame_data.encoderinfo):
im.encoderinfo["duration"], (list, tuple) fp.write(s)
): offset = (0, 0)
# Since multiple frames will not be written, add together the frame durations else:
im.encoderinfo["duration"] = sum(im.encoderinfo["duration"]) # compress difference
if not palette:
frame_data.encoderinfo["include_color_table"] = True
if frame_data.bbox != (0, 0) + im_frame.size:
im_frame = im_frame.crop(frame_data.bbox)
offset = frame_data.bbox[:2]
_write_frame_data(fp, im_frame, offset, frame_data.encoderinfo)
return True
def _save_all(im, fp, filename): def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, save_all=True) _save(im, fp, filename, save_all=True)
def _save(im, fp, filename, save_all=False): def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
) -> None:
# header # header
if "palette" in im.encoderinfo or "palette" in im.info: if "palette" in im.encoderinfo or "palette" in im.info:
palette = im.encoderinfo.get("palette", im.info.get("palette")) palette = im.encoderinfo.get("palette", im.info.get("palette"))
else: else:
palette = None palette = None
im.encoderinfo["optimize"] = im.encoderinfo.get("optimize", True) im.encoderinfo.setdefault("optimize", True)
if not save_all or not _write_multiple_frames(im, fp, palette): if not save_all or not _write_multiple_frames(im, fp, palette):
_write_single_frame(im, fp, palette) _write_single_frame(im, fp, palette)
@@ -670,7 +811,7 @@ def _save(im, fp, filename, save_all=False):
fp.flush() fp.flush()
def get_interlace(im): def get_interlace(im: Image.Image) -> int:
interlace = im.encoderinfo.get("interlace", 1) interlace = im.encoderinfo.get("interlace", 1)
# workaround for @PIL153 # workaround for @PIL153
@@ -680,23 +821,13 @@ def get_interlace(im):
return interlace return interlace
def _write_local_header(fp, im, offset, flags): def _write_local_header(
transparent_color_exists = False fp: IO[bytes], im: Image.Image, offset: tuple[int, int], flags: int
) -> None:
try: try:
transparency = int(im.encoderinfo["transparency"]) transparency = im.encoderinfo["transparency"]
except (KeyError, ValueError): except KeyError:
pass transparency = None
else:
# optimize the block away if transparent color is not used
transparent_color_exists = True
used_palette_colors = _get_optimize(im, im.encoderinfo)
if used_palette_colors is not None:
# adjust the transparency index after optimize
try:
transparency = used_palette_colors.index(transparency)
except ValueError:
transparent_color_exists = False
if "duration" in im.encoderinfo: if "duration" in im.encoderinfo:
duration = int(im.encoderinfo["duration"] / 10) duration = int(im.encoderinfo["duration"] / 10)
@@ -705,11 +836,9 @@ def _write_local_header(fp, im, offset, flags):
disposal = int(im.encoderinfo.get("disposal", 0)) disposal = int(im.encoderinfo.get("disposal", 0))
if transparent_color_exists or duration != 0 or disposal: if transparency is not None or duration != 0 or disposal:
packed_flag = 1 if transparent_color_exists else 0 packed_flag = 1 if transparency is not None else 0
packed_flag |= disposal << 2 packed_flag |= disposal << 2
if not transparent_color_exists:
transparency = 0
fp.write( fp.write(
b"!" b"!"
@@ -717,7 +846,7 @@ def _write_local_header(fp, im, offset, flags):
+ o8(4) # length + o8(4) # length
+ o8(packed_flag) # packed fields + o8(packed_flag) # packed fields
+ o16(duration) # duration + o16(duration) # duration
+ o8(transparency) # transparency index + o8(transparency or 0) # transparency index
+ o8(0) + o8(0)
) )
@@ -742,7 +871,7 @@ def _write_local_header(fp, im, offset, flags):
fp.write(o8(8)) # bits fp.write(o8(8)) # bits
def _save_netpbm(im, fp, filename): def _save_netpbm(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
# Unused by default. # Unused by default.
# To use, uncomment the register_save call at the end of the file. # To use, uncomment the register_save call at the end of the file.
# #
@@ -773,6 +902,7 @@ def _save_netpbm(im, fp, filename):
) )
# Allow ppmquant to receive SIGPIPE if ppmtogif exits # Allow ppmquant to receive SIGPIPE if ppmtogif exits
assert quant_proc.stdout is not None
quant_proc.stdout.close() quant_proc.stdout.close()
retcode = quant_proc.wait() retcode = quant_proc.wait()
@@ -794,7 +924,7 @@ def _save_netpbm(im, fp, filename):
_FORCE_OPTIMIZE = False _FORCE_OPTIMIZE = False
def _get_optimize(im, info): def _get_optimize(im: Image.Image, info: dict[str, Any]) -> list[int] | None:
""" """
Palette optimization is a potentially expensive operation. Palette optimization is a potentially expensive operation.
@@ -805,7 +935,7 @@ def _get_optimize(im, info):
:param info: encoderinfo :param info: encoderinfo
:returns: list of indexes of palette entries in use, or None :returns: list of indexes of palette entries in use, or None
""" """
if im.mode in ("P", "L") and info and info.get("optimize", 0): if im.mode in ("P", "L") and info and info.get("optimize"):
# Potentially expensive operation. # Potentially expensive operation.
# The palette saves 3 bytes per color not used, but palette # The palette saves 3 bytes per color not used, but palette
@@ -827,6 +957,7 @@ def _get_optimize(im, info):
if optimise or max(used_palette_colors) >= len(used_palette_colors): if optimise or max(used_palette_colors) >= len(used_palette_colors):
return used_palette_colors return used_palette_colors
assert im.palette is not None
num_palette_colors = len(im.palette.palette) // Image.getmodebands( num_palette_colors = len(im.palette.palette) // Image.getmodebands(
im.palette.mode im.palette.mode
) )
@@ -838,9 +969,10 @@ def _get_optimize(im, info):
and current_palette_size > 2 and current_palette_size > 2
): ):
return used_palette_colors return used_palette_colors
return None
def _get_color_table_size(palette_bytes): def _get_color_table_size(palette_bytes: bytes) -> int:
# calculate the palette size for the header # calculate the palette size for the header
if not palette_bytes: if not palette_bytes:
return 0 return 0
@@ -850,7 +982,7 @@ def _get_color_table_size(palette_bytes):
return math.ceil(math.log(len(palette_bytes) // 3, 2)) - 1 return math.ceil(math.log(len(palette_bytes) // 3, 2)) - 1
-def _get_header_palette(palette_bytes):
+def _get_header_palette(palette_bytes: bytes) -> bytes:
     """
     Returns the palette, null padded to the next power of 2 (*3) bytes
     suitable for direct inclusion in the GIF header
@@ -868,23 +1000,33 @@ def _get_header_palette(palette_bytes):
     return palette_bytes


-def _get_palette_bytes(im):
+def _get_palette_bytes(im: Image.Image) -> bytes:
     """
     Gets the palette for inclusion in the gif header

     :param im: Image object
     :returns: Bytes, len<=768 suitable for inclusion in gif header
     """
-    return im.palette.palette if im.palette else b""
+    if not im.palette:
+        return b""
+
+    palette = bytes(im.palette.palette)
+    if im.palette.mode == "RGBA":
+        palette = b"".join(palette[i * 4 : i * 4 + 3] for i in range(len(palette) // 3))
+    return palette


-def _get_background(im, info_background):
+def _get_background(
+    im: Image.Image,
+    info_background: int | tuple[int, int, int] | tuple[int, int, int, int] | None,
+) -> int:
     background = 0
     if info_background:
         if isinstance(info_background, tuple):
             # WebPImagePlugin stores an RGBA value in info["background"]
             # So it must be converted to the same format as GifImagePlugin's
             # info["background"] - a global color table index
+            assert im.palette is not None
             try:
                 background = im.palette.getcolor(info_background, im)
             except ValueError as e:
@@ -901,7 +1043,7 @@ def _get_background(im, info_background):
     return background


-def _get_global_header(im, info):
+def _get_global_header(im: Image.Image, info: dict[str, Any]) -> list[bytes]:
     """Return a list of strings representing a GIF header"""

     # Header Block
@@ -963,7 +1105,12 @@ def _get_global_header(im, info):
     return header


-def _write_frame_data(fp, im_frame, offset, params):
+def _write_frame_data(
+    fp: IO[bytes],
+    im_frame: Image.Image,
+    offset: tuple[int, int],
+    params: dict[str, Any],
+) -> None:
     try:
         im_frame.encoderinfo = params
@@ -971,7 +1118,9 @@ def _write_frame_data(fp, im_frame, offset, params):
         _write_local_header(fp, im_frame, offset, 0)

         ImageFile._save(
-            im_frame, fp, [("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])]
+            im_frame,
+            fp,
+            [ImageFile._Tile("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])],
         )

         fp.write(b"\0")  # end of image data
@@ -983,7 +1132,9 @@ def _write_frame_data(fp, im_frame, offset, params):
 # Legacy GIF utilities


-def getheader(im, palette=None, info=None):
+def getheader(
+    im: Image.Image, palette: _Palette | None = None, info: dict[str, Any] | None = None
+) -> tuple[list[bytes], list[int] | None]:
     """
     Legacy Method to get Gif data from image.
@@ -995,11 +1146,11 @@ def getheader(im, palette=None, info=None):
     :returns: tuple of(list of header items, optimized palette)
     """
-    used_palette_colors = _get_optimize(im, info)
-
     if info is None:
         info = {}

+    used_palette_colors = _get_optimize(im, info)
+
     if "background" not in info and "background" in im.info:
         info["background"] = im.info["background"]
@@ -1011,7 +1162,9 @@ def getheader(im, palette=None, info=None):
     return header, used_palette_colors


-def getdata(im, offset=(0, 0), **params):
+def getdata(
+    im: Image.Image, offset: tuple[int, int] = (0, 0), **params: Any
+) -> list[bytes]:
     """
     Legacy Method
@@ -1028,12 +1181,14 @@ def getdata(im, offset=(0, 0), **params):
     :returns: List of bytes containing GIF encoded frame data
     """
+    from io import BytesIO
+
-    class Collector:
+    class Collector(BytesIO):
         data = []

-        def write(self, data):
+        def write(self, data: Buffer) -> int:
             self.data.append(data)
+            return len(data)

     im.load()  # make sure raster data is available
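
The two helpers at the end of this hunk, getheader() and getdata(), are still public API, so the typed signatures are load-bearing. A minimal sketch of driving them by hand to assemble a one-frame GIF, assuming Pillow's public GifImagePlugin module (the output path "frame.gif" is hypothetical):

    from PIL import Image, GifImagePlugin

    im = Image.new("P", (16, 16))  # paletted input, as RAWMODE requires
    header, used_palette_colors = GifImagePlugin.getheader(im)
    frame_data = GifImagePlugin.getdata(im, offset=(0, 0))

    with open("frame.gif", "wb") as fp:
        for block in header:      # logical screen descriptor + global palette
            fp.write(block)
        for block in frame_data:  # local header + LZW-compressed raster
            fp.write(block)
        fp.write(b";")            # GIF trailer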

View File

@@ -18,17 +18,22 @@ Stuff to translate curve segments to palette values (derived from
 the corresponding code in GIMP, written by Federico Mena Quintero.
 See the GIMP distribution for more information.)
 """
+from __future__ import annotations

 from math import log, pi, sin, sqrt

 from ._binary import o8

+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from typing import IO
+
 EPSILON = 1e-10
 """""" # Enable auto-doc for data member


-def linear(middle, pos):
+def linear(middle: float, pos: float) -> float:
     if pos <= middle:
         if middle < EPSILON:
             return 0.0
@@ -43,19 +48,19 @@ def linear(middle, pos):
         return 0.5 + 0.5 * pos / middle


-def curved(middle, pos):
+def curved(middle: float, pos: float) -> float:
     return pos ** (log(0.5) / log(max(middle, EPSILON)))


-def sine(middle, pos):
+def sine(middle: float, pos: float) -> float:
     return (sin((-pi / 2.0) + pi * linear(middle, pos)) + 1.0) / 2.0


-def sphere_increasing(middle, pos):
+def sphere_increasing(middle: float, pos: float) -> float:
     return sqrt(1.0 - (linear(middle, pos) - 1.0) ** 2)


-def sphere_decreasing(middle, pos):
+def sphere_decreasing(middle: float, pos: float) -> float:
     return 1.0 - sqrt(1.0 - linear(middle, pos) ** 2)


@@ -64,9 +69,22 @@ SEGMENTS = [linear, curved, sine, sphere_increasing, sphere_decreasing]
 class GradientFile:
-    gradient = None
+    gradient: (
+        list[
+            tuple[
+                float,
+                float,
+                float,
+                list[float],
+                list[float],
+                Callable[[float, float], float],
+            ]
+        ]
+        | None
+    ) = None

-    def getpalette(self, entries=256):
+    def getpalette(self, entries: int = 256) -> tuple[bytes, str]:
+        assert self.gradient is not None
         palette = []

         ix = 0
@@ -101,8 +119,8 @@ class GradientFile:
 class GimpGradientFile(GradientFile):
     """File handler for GIMP's gradient format."""

-    def __init__(self, fp):
-        if fp.readline()[:13] != b"GIMP Gradient":
+    def __init__(self, fp: IO[bytes]) -> None:
+        if not fp.readline().startswith(b"GIMP Gradient"):
             msg = "not a GIMP gradient file"
             raise SyntaxError(msg)
@@ -114,7 +132,7 @@ class GimpGradientFile(GradientFile):
         count = int(line)

-        gradient = []
+        self.gradient = []

         for i in range(count):
             s = fp.readline().split()
@@ -132,6 +150,4 @@ class GimpGradientFile(GradientFile):
                 msg = "cannot handle HSV colour space"
                 raise OSError(msg)

-            gradient.append((x0, x1, xm, rgb0, rgb1, segment))
-
-        self.gradient = gradient
+            self.gradient.append((x0, x1, xm, rgb0, rgb1, segment))
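
Taken together, getpalette() samples the segment functions above into a flat byte string. A minimal sketch of the round trip, assuming a local GIMP gradient file named "sky.ggr" (hypothetical path):

    from PIL.GimpGradientFile import GimpGradientFile

    with open("sky.ggr", "rb") as fp:
        grad = GimpGradientFile(fp)  # parses segments into self.gradient

    palette, rawmode = grad.getpalette(entries=256)
    print(rawmode, len(palette))  # "RGB", 768 bytes (256 entries x 3 channels)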

View File

@@ -13,10 +13,14 @@
 #
 # See the README file for information on usage and redistribution.
 #
+from __future__ import annotations

 import re
+from io import BytesIO

-from ._binary import o8
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import IO


 class GimpPaletteFile:
@@ -24,14 +28,18 @@ class GimpPaletteFile:

     rawmode = "RGB"

-    def __init__(self, fp):
-        self.palette = [o8(i) * 3 for i in range(256)]
-
-        if fp.readline()[:12] != b"GIMP Palette":
+    def _read(self, fp: IO[bytes], limit: bool = True) -> None:
+        if not fp.readline().startswith(b"GIMP Palette"):
             msg = "not a GIMP palette file"
             raise SyntaxError(msg)

-        for i in range(256):
+        palette: list[int] = []
+        i = 0
+        while True:
+            if limit and i == 256 + 3:
+                break
+            i += 1
             s = fp.readline()
             if not s:
                 break
@@ -39,18 +47,29 @@ class GimpPaletteFile:

             # skip fields and comment lines
             if re.match(rb"\w+:|#", s):
                 continue
-            if len(s) > 100:
+            if limit and len(s) > 100:
                 msg = "bad palette file"
                 raise SyntaxError(msg)

-            v = tuple(map(int, s.split()[:3]))
-            if len(v) != 3:
+            v = s.split(maxsplit=3)
+            if len(v) < 3:
                 msg = "bad palette entry"
                 raise ValueError(msg)

-            self.palette[i] = o8(v[0]) + o8(v[1]) + o8(v[2])
+            palette += (int(v[i]) for i in range(3))
+            if limit and len(palette) == 768:
+                break

-        self.palette = b"".join(self.palette)
+        self.palette = bytes(palette)

-    def getpalette(self):
+    def __init__(self, fp: IO[bytes]) -> None:
+        self._read(fp)
+
+    @classmethod
+    def frombytes(cls, data: bytes) -> GimpPaletteFile:
+        self = cls.__new__(cls)
+        self._read(BytesIO(data), False)
+        return self
+
+    def getpalette(self) -> tuple[bytes, str]:
         return self.palette, self.rawmode
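
The new frombytes() constructor makes the parser usable on in-memory data and, by passing limit=False, lifts the 256-entry cap that __init__ keeps for file input. A minimal sketch against the methods shown above (the palette bytes are made up for illustration):

    from PIL.GimpPaletteFile import GimpPaletteFile

    data = (
        b"GIMP Palette\n"
        b"# two entries\n"
        b"  0   0   0 black\n"
        b"255 255 255 white\n"
    )
    pal = GimpPaletteFile.frombytes(data)
    palette, rawmode = pal.getpalette()
    print(rawmode, len(palette))  # "RGB", 6 bytes -> two RGB triples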

View File

@@ -8,13 +8,17 @@
 #
 # See the README file for information on usage and redistribution.
 #
+from __future__ import annotations

+import os
+from typing import IO

 from . import Image, ImageFile

 _handler = None


-def register_handler(handler):
+def register_handler(handler: ImageFile.StubHandler | None) -> None:
     """
     Install application-specific GRIB image handler.
@@ -28,22 +32,20 @@ def register_handler(handler):
 # Image adapter


-def _accept(prefix):
-    return prefix[:4] == b"GRIB" and prefix[7] == 1
+def _accept(prefix: bytes) -> bool:
+    return len(prefix) >= 8 and prefix.startswith(b"GRIB") and prefix[7] == 1


 class GribStubImageFile(ImageFile.StubImageFile):
     format = "GRIB"
     format_description = "GRIB"

-    def _open(self):
-        offset = self.fp.tell()
-
+    def _open(self) -> None:
         if not _accept(self.fp.read(8)):
             msg = "Not a GRIB file"
             raise SyntaxError(msg)

-        self.fp.seek(offset)
+        self.fp.seek(-8, os.SEEK_CUR)

         # make something up
         self._mode = "F"
@@ -53,11 +55,11 @@ class GribStubImageFile(ImageFile.StubImageFile):
         if loader:
             loader.open(self)

-    def _load(self):
+    def _load(self) -> ImageFile.StubHandler | None:
         return _handler


-def _save(im, fp, filename):
+def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
     if _handler is None or not hasattr(_handler, "save"):
         msg = "GRIB save handler not installed"
         raise OSError(msg)
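
Since this plugin is only a stub, nothing actually decodes GRIB until a handler is registered. A minimal sketch of wiring one in; MyGribHandler and its decoding strategy are hypothetical, only register_handler() and ImageFile.StubHandler come from the code above:

    from PIL import Image, ImageFile, GribStubImagePlugin

    class MyGribHandler(ImageFile.StubHandler):
        def open(self, im: ImageFile.StubImageFile) -> None:
            pass  # e.g. read GRIB metadata and adjust im's size/mode

        def load(self, im: ImageFile.StubImageFile) -> Image.Image:
            raise NotImplementedError  # decode the raster and return it

    GribStubImagePlugin.register_handler(MyGribHandler())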

Some files were not shown because too many files have changed in this diff.