This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,50 @@
"""add_anonymous_gdpr_support
Revision ID: 6f7f8689fc98
Revises: 7a899ef55e3b
Create Date: 2025-12-01 04:15:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '6f7f8689fc98'
down_revision = '7a899ef55e3b'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Update gdpr_requests table to support anonymous users
op.alter_column('gdpr_requests', 'user_id',
existing_type=sa.Integer(),
nullable=True)
op.add_column('gdpr_requests', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
op.create_index(op.f('ix_gdpr_requests_is_anonymous'), 'gdpr_requests', ['is_anonymous'], unique=False)
# Update consents table to support anonymous users
op.alter_column('consents', 'user_id',
existing_type=sa.Integer(),
nullable=True)
op.add_column('consents', sa.Column('user_email', sa.String(length=255), nullable=True))
op.add_column('consents', sa.Column('is_anonymous', sa.Boolean(), nullable=False, server_default='0'))
op.create_index(op.f('ix_consents_user_email'), 'consents', ['user_email'], unique=False)
op.create_index(op.f('ix_consents_is_anonymous'), 'consents', ['is_anonymous'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_consents_is_anonymous'), table_name='consents')
op.drop_index(op.f('ix_consents_user_email'), table_name='consents')
op.drop_column('consents', 'is_anonymous')
op.drop_column('consents', 'user_email')
op.alter_column('consents', 'user_id',
existing_type=sa.Integer(),
nullable=False)
op.drop_index(op.f('ix_gdpr_requests_is_anonymous'), table_name='gdpr_requests')
op.drop_column('gdpr_requests', 'is_anonymous')
op.alter_column('gdpr_requests', 'user_id',
existing_type=sa.Integer(),
nullable=False)

View File

@@ -0,0 +1,173 @@
"""add_comprehensive_gdpr_tables
Revision ID: 7a899ef55e3b
Revises: dbafe747c931
Create Date: 2025-12-01 04:10:25.699589
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = '7a899ef55e3b'
down_revision = 'dbafe747c931'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Consent table
op.create_table(
'consents',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('consent_type', sa.Enum('marketing', 'analytics', 'necessary', 'preferences', 'third_party_sharing', 'profiling', 'automated_decision_making', name='consenttype'), nullable=False),
sa.Column('status', sa.Enum('granted', 'withdrawn', 'pending', 'expired', name='consentstatus'), nullable=False),
sa.Column('granted_at', sa.DateTime(), nullable=True),
sa.Column('withdrawn_at', sa.DateTime(), nullable=True),
sa.Column('expires_at', sa.DateTime(), nullable=True),
sa.Column('legal_basis', sa.String(length=100), nullable=True),
sa.Column('consent_method', sa.String(length=50), nullable=True),
sa.Column('consent_version', sa.String(length=20), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=255), nullable=True),
sa.Column('source', sa.String(length=100), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_consents_id'), 'consents', ['id'], unique=False)
op.create_index(op.f('ix_consents_user_id'), 'consents', ['user_id'], unique=False)
op.create_index(op.f('ix_consents_consent_type'), 'consents', ['consent_type'], unique=False)
op.create_index(op.f('ix_consents_status'), 'consents', ['status'], unique=False)
op.create_index(op.f('ix_consents_created_at'), 'consents', ['created_at'], unique=False)
# Data processing records table
op.create_table(
'data_processing_records',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('processing_category', sa.Enum('collection', 'storage', 'usage', 'sharing', 'deletion', 'anonymization', 'transfer', name='processingcategory'), nullable=False),
sa.Column('legal_basis', sa.Enum('consent', 'contract', 'legal_obligation', 'vital_interests', 'public_task', 'legitimate_interests', name='legalbasis'), nullable=False),
sa.Column('purpose', sa.Text(), nullable=False),
sa.Column('data_categories', sa.JSON(), nullable=True),
sa.Column('data_subjects', sa.JSON(), nullable=True),
sa.Column('recipients', sa.JSON(), nullable=True),
sa.Column('third_parties', sa.JSON(), nullable=True),
sa.Column('transfers_to_third_countries', sa.Boolean(), nullable=False),
sa.Column('transfer_countries', sa.JSON(), nullable=True),
sa.Column('safeguards', sa.Text(), nullable=True),
sa.Column('retention_period', sa.String(length=100), nullable=True),
sa.Column('retention_criteria', sa.Text(), nullable=True),
sa.Column('security_measures', sa.Text(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('related_booking_id', sa.Integer(), nullable=True),
sa.Column('related_payment_id', sa.Integer(), nullable=True),
sa.Column('processed_by', sa.Integer(), nullable=True),
sa.Column('processing_timestamp', sa.DateTime(), nullable=False),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['processed_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_processing_records_id'), 'data_processing_records', ['id'], unique=False)
op.create_index(op.f('ix_data_processing_records_processing_category'), 'data_processing_records', ['processing_category'], unique=False)
op.create_index(op.f('ix_data_processing_records_legal_basis'), 'data_processing_records', ['legal_basis'], unique=False)
op.create_index(op.f('ix_data_processing_records_user_id'), 'data_processing_records', ['user_id'], unique=False)
op.create_index(op.f('ix_data_processing_records_processing_timestamp'), 'data_processing_records', ['processing_timestamp'], unique=False)
op.create_index(op.f('ix_data_processing_records_created_at'), 'data_processing_records', ['created_at'], unique=False)
# Data breaches table
op.create_table(
'data_breaches',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('breach_type', sa.Enum('confidentiality', 'integrity', 'availability', name='breachtype'), nullable=False),
sa.Column('status', sa.Enum('detected', 'investigating', 'contained', 'reported_to_authority', 'notified_data_subjects', 'resolved', name='breachstatus'), nullable=False),
sa.Column('description', sa.Text(), nullable=False),
sa.Column('affected_data_categories', sa.JSON(), nullable=True),
sa.Column('affected_data_subjects', sa.JSON(), nullable=True),
sa.Column('detected_at', sa.DateTime(), nullable=False),
sa.Column('occurred_at', sa.DateTime(), nullable=True),
sa.Column('contained_at', sa.DateTime(), nullable=True),
sa.Column('reported_to_authority_at', sa.DateTime(), nullable=True),
sa.Column('authority_reference', sa.String(length=255), nullable=True),
sa.Column('notified_data_subjects_at', sa.DateTime(), nullable=True),
sa.Column('notification_method', sa.String(length=100), nullable=True),
sa.Column('likely_consequences', sa.Text(), nullable=True),
sa.Column('measures_proposed', sa.Text(), nullable=True),
sa.Column('risk_level', sa.String(length=20), nullable=True),
sa.Column('reported_by', sa.Integer(), nullable=False),
sa.Column('investigated_by', sa.Integer(), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['investigated_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['reported_by'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_breaches_id'), 'data_breaches', ['id'], unique=False)
op.create_index(op.f('ix_data_breaches_breach_type'), 'data_breaches', ['breach_type'], unique=False)
op.create_index(op.f('ix_data_breaches_status'), 'data_breaches', ['status'], unique=False)
op.create_index(op.f('ix_data_breaches_detected_at'), 'data_breaches', ['detected_at'], unique=False)
# Retention rules table
op.create_table(
'retention_rules',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('data_category', sa.String(length=100), nullable=False),
sa.Column('retention_period_days', sa.Integer(), nullable=False),
sa.Column('retention_period_months', sa.Integer(), nullable=True),
sa.Column('retention_period_years', sa.Integer(), nullable=True),
sa.Column('legal_basis', sa.Text(), nullable=True),
sa.Column('legal_requirement', sa.Text(), nullable=True),
sa.Column('action_after_retention', sa.String(length=50), nullable=False),
sa.Column('conditions', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['created_by'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('data_category')
)
op.create_index(op.f('ix_retention_rules_id'), 'retention_rules', ['id'], unique=False)
op.create_index(op.f('ix_retention_rules_data_category'), 'retention_rules', ['data_category'], unique=True)
op.create_index(op.f('ix_retention_rules_is_active'), 'retention_rules', ['is_active'], unique=False)
# Data retention logs table
op.create_table(
'data_retention_logs',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('retention_rule_id', sa.Integer(), nullable=False),
sa.Column('data_category', sa.String(length=100), nullable=False),
sa.Column('action_taken', sa.String(length=50), nullable=False),
sa.Column('records_affected', sa.Integer(), nullable=False),
sa.Column('affected_ids', sa.JSON(), nullable=True),
sa.Column('executed_by', sa.Integer(), nullable=True),
sa.Column('executed_at', sa.DateTime(), nullable=False),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('extra_metadata', sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(['executed_by'], ['users.id'], ),
sa.ForeignKeyConstraint(['retention_rule_id'], ['retention_rules.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_data_retention_logs_id'), 'data_retention_logs', ['id'], unique=False)
op.create_index(op.f('ix_data_retention_logs_retention_rule_id'), 'data_retention_logs', ['retention_rule_id'], unique=False)
op.create_index(op.f('ix_data_retention_logs_data_category'), 'data_retention_logs', ['data_category'], unique=False)
op.create_index(op.f('ix_data_retention_logs_executed_at'), 'data_retention_logs', ['executed_at'], unique=False)
def downgrade() -> None:
# Drop foreign keys first, then indexes, then tables
op.drop_table('data_retention_logs')
op.drop_table('retention_rules')
op.drop_table('data_breaches')
op.drop_table('data_processing_records')
op.drop_table('consents')

View File

@@ -1,26 +1,28 @@
fastapi==0.104.1
fastapi==0.123.0
uvicorn[standard]==0.24.0
python-dotenv==1.0.0
sqlalchemy==2.0.23
pymysql==1.1.0
pymysql==1.1.2
cryptography>=41.0.7
python-jose[cryptography]==3.3.0
python-jose[cryptography]==3.5.0
bcrypt==4.1.2
python-multipart==0.0.6
python-multipart==0.0.20
aiofiles==23.2.1
email-validator==2.1.0
pydantic==2.5.0
pydantic-settings==2.1.0
slowapi==0.1.9
pillow==10.1.0
pillow==12.0.0
aiosmtplib==3.0.1
jinja2==3.1.2
jinja2==3.1.6
alembic==1.12.1
stripe>=13.2.0
paypal-checkout-serversdk>=1.0.3
pyotp==2.9.0
qrcode[pil]==7.4.2
httpx==0.25.2
httpx==0.28.1
httpcore==1.0.9
h11==0.16.0
cryptography>=41.0.7
bleach==6.1.0

View File

@@ -1,13 +1,38 @@
import uvicorn
import signal
import sys
from src.shared.config.settings import settings
from src.shared.config.logging_config import setup_logging, get_logger
setup_logging()
logger = get_logger(__name__)
def signal_handler(sig, frame):
"""Handle Ctrl+C gracefully."""
logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
sys.exit(0)
if __name__ == '__main__':
# Register signal handler for graceful shutdown on Ctrl+C
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
logger.info(f'Starting {settings.APP_NAME} on {settings.HOST}:{settings.PORT}')
import os
from pathlib import Path
base_dir = Path(__file__).parent
src_dir = str(base_dir / 'src')
use_reload = False
uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=use_reload, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if use_reload else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=1.0)
# Enable hot reload in development mode or if explicitly enabled via environment variable
use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
if use_reload:
logger.info('Hot reload enabled - server will restart on code changes')
logger.info('Press Ctrl+C to stop the server')
uvicorn.run(
'src.main:app',
host=settings.HOST,
port=settings.PORT,
reload=use_reload,
log_level=settings.LOG_LEVEL.lower(),
reload_dirs=[src_dir] if use_reload else None,
reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
reload_delay=0.5
)

View File

@@ -7,6 +7,7 @@ import uuid
import os
from ...shared.config.database import get_db
from ..services.auth_service import auth_service
from ..services.session_service import session_service
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
from ...security.middleware.auth import get_current_user
from ..models.user import User
@@ -85,6 +86,26 @@ async def register(
path='/'
)
# Create user session for new registration
try:
# Extract device info from user agent
device_info = None
if user_agent:
device_info = {'user_agent': user_agent}
session_service.create_session(
db=db,
user_id=result['user']['id'],
ip_address=client_ip,
user_agent=user_agent,
device_info=str(device_info) if device_info else None
)
except Exception as e:
# Log error but don't fail registration if session creation fails
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
logger.warning(f'Failed to create session during registration: {str(e)}')
# Log successful registration
await audit_service.log_action(
db=db,
@@ -171,6 +192,26 @@ async def login(
path='/'
)
# Create user session
try:
# Extract device info from user agent
device_info = None
if user_agent:
device_info = {'user_agent': user_agent}
session_service.create_session(
db=db,
user_id=result['user']['id'],
ip_address=client_ip,
user_agent=user_agent,
device_info=str(device_info) if device_info else None
)
except Exception as e:
# Log error but don't fail login if session creation fails
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
logger.warning(f'Failed to create session during login: {str(e)}')
# Log successful login
await audit_service.log_action(
db=db,
@@ -394,16 +435,23 @@ async def upload_avatar(request: Request, image: UploadFile=File(...), current_u
# Validate file completely (MIME type, size, magic bytes, integrity)
content = await validate_uploaded_image(image, max_avatar_size)
upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'avatars'
# Use same path calculation as main.py: go from Backend/src/auth/routes/auth_routes.py
# to Backend/uploads/avatars
upload_dir = Path(__file__).parent.parent.parent.parent / 'uploads' / 'avatars'
upload_dir.mkdir(parents=True, exist_ok=True)
if current_user.avatar:
old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
old_avatar_path = Path(__file__).parent.parent.parent.parent / current_user.avatar.lstrip('/')
if old_avatar_path.exists() and old_avatar_path.is_file():
try:
old_avatar_path.unlink()
except Exception:
pass
ext = Path(image.filename).suffix or '.png'
# Sanitize filename to prevent path traversal attacks
from ...shared.utils.sanitization import sanitize_filename
original_filename = image.filename or 'avatar.png'
sanitized_filename = sanitize_filename(original_filename)
ext = Path(sanitized_filename).suffix or '.png'
# Generate secure filename with user ID and UUID to prevent collisions
filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
file_path = upload_dir / filename
async with aiofiles.open(file_path, 'wb') as f:

View File

@@ -1,14 +1,17 @@
"""
User session management routes.
"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request, Response, Cookie
from sqlalchemy.orm import Session
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ...security.middleware.auth import get_current_user
from ...auth.models.user import User
from ...auth.models.user_session import UserSession
from ...auth.services.session_service import session_service
from ...shared.utils.response_helpers import success_response
from jose import jwt
logger = get_logger(__name__)
router = APIRouter(prefix='/sessions', tags=['sessions'])
@@ -44,13 +47,15 @@ async def get_my_sessions(
@router.delete('/{session_id}')
async def revoke_session(
session_id: int,
request: Request,
response: Response,
current_user: User = Depends(get_current_user),
access_token: str = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
):
"""Revoke a specific session."""
try:
# Verify session belongs to user
from ...auth.models.user_session import UserSession
session = db.query(UserSession).filter(
UserSession.id == session_id,
UserSession.user_id == current_user.id
@@ -59,10 +64,62 @@ async def revoke_session(
if not session:
raise HTTPException(status_code=404, detail='Session not found')
# Check if this is the current session being revoked
# We detect this by checking if:
# 1. The session IP matches the request IP (if available)
# 2. The session is the most recent active session
is_current_session = False
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent', '')
# Check if session matches current request characteristics
if client_ip and session.ip_address == client_ip:
# Also check if it's the most recent session
recent_session = db.query(UserSession).filter(
UserSession.user_id == current_user.id,
UserSession.is_active == True
).order_by(UserSession.last_activity.desc()).first()
if recent_session and recent_session.id == session_id:
is_current_session = True
except Exception as e:
logger.warning(f'Could not determine if session is current: {str(e)}')
# If we can't determine, check if it's the only active session
active_sessions_count = db.query(UserSession).filter(
UserSession.user_id == current_user.id,
UserSession.is_active == True
).count()
if active_sessions_count <= 1:
is_current_session = True
success = session_service.revoke_session(db=db, session_token=session.session_token)
if not success:
raise HTTPException(status_code=404, detail='Session not found')
# If this was the current session, clear cookies and indicate logout needed
if is_current_session:
from ...shared.config.settings import settings
samesite_value = 'strict' if settings.is_production else 'lax'
# Clear access token cookie
response.delete_cookie(
key='accessToken',
path='/',
samesite=samesite_value,
secure=settings.is_production
)
# Clear refresh token cookie
response.delete_cookie(
key='refreshToken',
path='/',
samesite=samesite_value,
secure=settings.is_production
)
return success_response(
message='Session revoked successfully. You have been logged out.',
data={'logout_required': True}
)
return success_response(message='Session revoked successfully')
except HTTPException:
raise
@@ -72,19 +129,41 @@ async def revoke_session(
@router.post('/revoke-all')
async def revoke_all_sessions(
request: Request,
response: Response,
current_user: User = Depends(get_current_user),
access_token: str = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
):
"""Revoke all sessions for current user."""
try:
count = session_service.revoke_all_user_sessions(
db=db,
user_id=current_user.id
user_id=current_user.id,
exclude_token=None # Don't exclude current session, revoke all
)
# Clear cookies since all sessions (including current) are revoked
from ...shared.config.settings import settings
samesite_value = 'strict' if settings.is_production else 'lax'
# Clear access token cookie
response.delete_cookie(
key='accessToken',
path='/',
samesite=samesite_value,
secure=settings.is_production
)
# Clear refresh token cookie
response.delete_cookie(
key='refreshToken',
path='/',
samesite=samesite_value,
secure=settings.is_production
)
return success_response(
data={'revoked_count': count},
message=f'Revoked {count} session(s)'
data={'revoked_count': count, 'logout_required': True},
message=f'Revoked {count} session(s). You have been logged out.'
)
except Exception as e:
logger.error(f'Error revoking all sessions: {str(e)}', exc_info=True)

View File

@@ -29,19 +29,13 @@ class AuthService:
if not self.jwt_secret:
error_msg = (
'CRITICAL: JWT_SECRET is not configured. '
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
'Please set JWT_SECRET environment variable to a secure random string (minimum 64 characters). '
'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
)
logger.error(error_msg)
if settings.is_production:
# SECURITY: Always fail if JWT_SECRET is not configured, even in development
# This prevents accidental deployment without proper secrets
raise ValueError(error_msg)
else:
# In development, generate a secure secret but warn
import secrets
self.jwt_secret = secrets.token_urlsafe(64)
logger.warning(
f'JWT_SECRET not configured. Auto-generated secret for development. '
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
)
# Validate JWT secret strength
if len(self.jwt_secret) < 32:
@@ -65,14 +59,37 @@ class AuthService:
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")
def generate_tokens(self, user_id: int) -> dict:
from datetime import datetime, timedelta
# SECURITY: Add standard JWT claims for better security
now = datetime.utcnow()
access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
access_payload = {
"userId": user_id,
"exp": access_expires, # Expiration time
"iat": now, # Issued at
"iss": settings.APP_NAME, # Issuer
"type": "access" # Token type
}
refresh_payload = {
"userId": user_id,
"exp": refresh_expires, # Expiration time
"iat": now, # Issued at
"iss": settings.APP_NAME, # Issuer
"type": "refresh" # Token type
}
access_token = jwt.encode(
{"userId": user_id},
access_payload,
self.jwt_secret,
algorithm="HS256"
)
refresh_token = jwt.encode(
{"userId": user_id},
refresh_payload,
self.jwt_refresh_secret,
algorithm="HS256"
)
@@ -316,8 +333,22 @@ class AuthService:
db.commit()
raise ValueError("Refresh token expired")
from datetime import datetime, timedelta
# SECURITY: Add standard JWT claims when refreshing token
now = datetime.utcnow()
access_expires = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
access_payload = {
"userId": decoded["userId"],
"exp": access_expires, # Expiration time
"iat": now, # Issued at
"iss": settings.APP_NAME, # Issuer
"type": "access" # Token type
}
access_token = jwt.encode(
{"userId": decoded["userId"]},
access_payload,
self.jwt_secret,
algorithm="HS256"
)

View File

@@ -4,7 +4,7 @@ from sqlalchemy import and_, or_, func
from sqlalchemy.exc import IntegrityError
from typing import Optional
from datetime import datetime
import random
import secrets
import os
from ...shared.config.database import get_db
from ...shared.config.settings import settings
@@ -37,7 +37,8 @@ def _generate_invoice_email_html(invoice: dict, is_proforma: bool=False) -> str:
def generate_booking_number() -> str:
prefix = 'BK'
ts = int(datetime.utcnow().timestamp() * 1000)
rand = random.randint(1000, 9999)
# Use cryptographically secure random number to prevent enumeration attacks
rand = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
return f'{prefix}-{ts}-{rand}'
def calculate_booking_payment_balance(booking: Booking) -> dict:

View File

@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import random
import secrets
import string
from decimal import Decimal
from ..models.group_booking import (
@@ -21,11 +21,13 @@ class GroupBookingService:
@staticmethod
def generate_group_booking_number(db: Session) -> str:
"""Generate unique group booking number"""
"""Generate unique group booking number using cryptographically secure random"""
max_attempts = 10
alphabet = string.ascii_uppercase + string.digits
for _ in range(max_attempts):
timestamp = datetime.utcnow().strftime('%Y%m%d')
random_suffix = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
# Use secrets.choice() instead of random.choices() for security
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(6))
booking_number = f"GRP-{timestamp}-{random_suffix}"
existing = db.query(GroupBooking).filter(
@@ -35,8 +37,9 @@ class GroupBookingService:
if not existing:
return booking_number
# Fallback
return f"GRP-{int(datetime.utcnow().timestamp())}"
# Fallback with secure random suffix
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
return f"GRP-{int(datetime.utcnow().timestamp())}{random_suffix}"
@staticmethod
def calculate_group_discount(
@@ -405,17 +408,19 @@ class GroupBookingService:
# Use proportional share
booking_price = group_booking.total_price / group_booking.total_rooms
# Generate booking number
import random
# Generate booking number using cryptographically secure random
prefix = 'BK'
ts = int(datetime.utcnow().timestamp() * 1000)
rand = random.randint(1000, 9999)
# Use secrets.randbelow() instead of random.randint() for security
rand = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
booking_number = f'{prefix}-{ts}-{rand}'
# Ensure uniqueness
existing = db.query(Booking).filter(Booking.booking_number == booking_number).first()
if existing:
booking_number = f'{prefix}-{ts}-{rand + 1}'
# If collision, generate new secure random number
rand = secrets.randbelow(9000) + 1000
booking_number = f'{prefix}-{ts}-{rand}'
# Create booking
booking = Booking(

View File

@@ -0,0 +1,26 @@
"""
GDPR Compliance Models.
"""
from .gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
from .consent import Consent, ConsentType, ConsentStatus
from .data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
from .data_breach import DataBreach, BreachType, BreachStatus
from .data_retention import RetentionRule, DataRetentionLog
__all__ = [
'GDPRRequest',
'GDPRRequestType',
'GDPRRequestStatus',
'Consent',
'ConsentType',
'ConsentStatus',
'DataProcessingRecord',
'ProcessingCategory',
'LegalBasis',
'DataBreach',
'BreachType',
'BreachStatus',
'RetentionRule',
'DataRetentionLog',
]

View File

@@ -0,0 +1,64 @@
"""
GDPR Consent Management Model.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class ConsentType(str, enum.Enum):
"""Types of consent that can be given or withdrawn."""
marketing = "marketing"
analytics = "analytics"
necessary = "necessary"
preferences = "preferences"
third_party_sharing = "third_party_sharing"
profiling = "profiling"
automated_decision_making = "automated_decision_making"
class ConsentStatus(str, enum.Enum):
"""Status of consent."""
granted = "granted"
withdrawn = "withdrawn"
pending = "pending"
expired = "expired"
class Consent(Base):
"""Model for tracking user consent for GDPR compliance."""
__tablename__ = 'consents'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # Nullable for anonymous users
user_email = Column(String(255), nullable=True, index=True) # Email for anonymous users
is_anonymous = Column(Boolean, default=False, nullable=False, index=True) # Flag for anonymous consent
consent_type = Column(Enum(ConsentType), nullable=False, index=True)
status = Column(Enum(ConsentStatus), default=ConsentStatus.granted, nullable=False, index=True)
# Consent details
granted_at = Column(DateTime, nullable=True)
withdrawn_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True) # For time-limited consent
# Legal basis (Article 6 GDPR)
legal_basis = Column(String(100), nullable=True) # consent, contract, legal_obligation, vital_interests, public_task, legitimate_interests
# Consent method
consent_method = Column(String(50), nullable=True) # explicit, implicit, pre_checked
consent_version = Column(String(20), nullable=True) # Version of privacy policy when consent was given
# Metadata
ip_address = Column(String(45), nullable=True)
user_agent = Column(String(255), nullable=True)
source = Column(String(100), nullable=True) # Where consent was given (registration, cookie_banner, etc.)
# Additional data
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship('User', foreign_keys=[user_id])

View File

@@ -0,0 +1,70 @@
"""
GDPR Data Breach Notification Model (Article 33-34 GDPR).
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class BreachType(str, enum.Enum):
"""Types of data breaches."""
confidentiality = "confidentiality" # Unauthorized disclosure
integrity = "integrity" # Unauthorized alteration
availability = "availability" # Unauthorized destruction or loss
class BreachStatus(str, enum.Enum):
"""Status of breach notification."""
detected = "detected"
investigating = "investigating"
contained = "contained"
reported_to_authority = "reported_to_authority"
notified_data_subjects = "notified_data_subjects"
resolved = "resolved"
class DataBreach(Base):
"""Data breach notification record (Articles 33-34 GDPR)."""
__tablename__ = 'data_breaches'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Breach details
breach_type = Column(Enum(BreachType), nullable=False, index=True)
status = Column(Enum(BreachStatus), default=BreachStatus.detected, nullable=False, index=True)
# Description
description = Column(Text, nullable=False) # Nature of the breach
affected_data_categories = Column(JSON, nullable=True) # Categories of personal data affected
affected_data_subjects = Column(JSON, nullable=True) # Approximate number of affected individuals
# Timeline
detected_at = Column(DateTime, nullable=False, index=True)
occurred_at = Column(DateTime, nullable=True) # When breach occurred (if known)
contained_at = Column(DateTime, nullable=True)
# Notification
reported_to_authority_at = Column(DateTime, nullable=True) # Article 33 - 72 hours
authority_reference = Column(String(255), nullable=True) # Reference from supervisory authority
notified_data_subjects_at = Column(DateTime, nullable=True) # Article 34 - without undue delay
notification_method = Column(String(100), nullable=True) # email, public_notice, etc.
# Risk assessment
likely_consequences = Column(Text, nullable=True)
measures_proposed = Column(Text, nullable=True) # Measures to address the breach
risk_level = Column(String(20), nullable=True) # low, medium, high
# Reporting
reported_by = Column(Integer, ForeignKey('users.id'), nullable=False) # Who detected/reported
investigated_by = Column(Integer, ForeignKey('users.id'), nullable=True) # DPO or responsible person
# Additional details
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
reporter = relationship('User', foreign_keys=[reported_by])
investigator = relationship('User', foreign_keys=[investigated_by])

View File

@@ -0,0 +1,78 @@
"""
GDPR Data Processing Records Model (Article 30 GDPR).
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class ProcessingCategory(str, enum.Enum):
"""Categories of data processing."""
collection = "collection"
storage = "storage"
usage = "usage"
sharing = "sharing"
deletion = "deletion"
anonymization = "anonymization"
transfer = "transfer"
class LegalBasis(str, enum.Enum):
"""Legal basis for processing (Article 6 GDPR)."""
consent = "consent"
contract = "contract"
legal_obligation = "legal_obligation"
vital_interests = "vital_interests"
public_task = "public_task"
legitimate_interests = "legitimate_interests"
class DataProcessingRecord(Base):
"""Record of data processing activities (Article 30 GDPR requirement)."""
__tablename__ = 'data_processing_records'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Processing details
processing_category = Column(Enum(ProcessingCategory), nullable=False, index=True)
legal_basis = Column(Enum(LegalBasis), nullable=False, index=True)
purpose = Column(Text, nullable=False) # Purpose of processing
# Data categories
data_categories = Column(JSON, nullable=True) # List of data categories processed
data_subjects = Column(JSON, nullable=True) # Categories of data subjects
# Recipients
recipients = Column(JSON, nullable=True) # Categories of recipients (internal, third_party, etc.)
third_parties = Column(JSON, nullable=True) # Specific third parties if any
# Transfers
transfers_to_third_countries = Column(Boolean, default=False, nullable=False)
transfer_countries = Column(JSON, nullable=True) # List of countries
safeguards = Column(Text, nullable=True) # Safeguards for transfers
# Retention
retention_period = Column(String(100), nullable=True) # How long data is retained
retention_criteria = Column(Text, nullable=True) # Criteria for determining retention period
# Security measures
security_measures = Column(Text, nullable=True)
# Related entities
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True) # If specific to a user
related_booking_id = Column(Integer, nullable=True, index=True)
related_payment_id = Column(Integer, nullable=True, index=True)
# Processing details
processed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # Staff/admin who processed
processing_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Additional metadata
extra_metadata = Column(JSON, nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Relationships
user = relationship('User', foreign_keys=[user_id])
processor = relationship('User', foreign_keys=[processed_by])

View File

@@ -0,0 +1,75 @@
"""
GDPR Data Retention Policy Model.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime, timedelta
import enum
from ...shared.config.database import Base
class RetentionRule(Base):
"""Data retention rules for different data types."""
__tablename__ = 'retention_rules'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Rule details
data_category = Column(String(100), nullable=False, unique=True, index=True) # user_data, booking_data, payment_data, etc.
retention_period_days = Column(Integer, nullable=False) # Number of days to retain
retention_period_months = Column(Integer, nullable=True) # Alternative: months
retention_period_years = Column(Integer, nullable=True) # Alternative: years
# Legal basis
legal_basis = Column(Text, nullable=True) # Why we retain for this period
legal_requirement = Column(Text, nullable=True) # Specific legal requirement if any
# Action after retention
action_after_retention = Column(String(50), nullable=False, default='anonymize') # delete, anonymize, archive
# Conditions
conditions = Column(JSON, nullable=True) # Additional conditions (e.g., active bookings)
# Status
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Metadata
description = Column(Text, nullable=True)
created_by = Column(Integer, ForeignKey('users.id'), nullable=True)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
creator = relationship('User', foreign_keys=[created_by])
class DataRetentionLog(Base):
"""Log of data retention actions performed."""
__tablename__ = 'data_retention_logs'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
# Retention action
retention_rule_id = Column(Integer, ForeignKey('retention_rules.id'), nullable=False, index=True)
data_category = Column(String(100), nullable=False, index=True)
action_taken = Column(String(50), nullable=False) # deleted, anonymized, archived
# Affected records
records_affected = Column(Integer, nullable=False, default=0)
affected_ids = Column(JSON, nullable=True) # IDs of affected records (for audit)
# Execution
executed_by = Column(Integer, ForeignKey('users.id'), nullable=True) # System or admin
executed_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Results
success = Column(Boolean, default=True, nullable=False)
error_message = Column(Text, nullable=True)
# Metadata
extra_metadata = Column(JSON, nullable=True)
# Relationships
retention_rule = relationship('RetentionRule', foreign_keys=[retention_rule_id])
executor = relationship('User', foreign_keys=[executed_by])

View File

@@ -1,7 +1,7 @@
"""
GDPR compliance models for data export and deletion requests.
"""
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, Enum, JSON, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
@@ -27,9 +27,10 @@ class GDPRRequest(Base):
request_type = Column(Enum(GDPRRequestType), nullable=False, index=True)
status = Column(Enum(GDPRRequestStatus), default=GDPRRequestStatus.pending, nullable=False, index=True)
# User making the request
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
user_email = Column(String(255), nullable=False) # Store email even if user is deleted
# User making the request (nullable for anonymous users)
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True)
user_email = Column(String(255), nullable=False) # Required: email for anonymous or registered users
is_anonymous = Column(Boolean, default=False, nullable=False, index=True) # Flag for anonymous requests
# Request details
request_data = Column(JSON, nullable=True) # Additional request parameters

View File

@@ -0,0 +1,340 @@
"""
Admin routes for GDPR compliance management.
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List
from pydantic import BaseModel
from datetime import datetime
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...security.middleware.auth import authorize_roles
from ...auth.models.user import User
from ..services.breach_service import breach_service
from ..services.retention_service import retention_service
from ..services.data_processing_service import data_processing_service
from ..models.data_breach import BreachType, BreachStatus
from ...shared.utils.response_helpers import success_response
logger = get_logger(__name__)
router = APIRouter(prefix='/gdpr/admin', tags=['gdpr-admin'])
# Data Breach Management
class BreachCreateRequest(BaseModel):
breach_type: str
description: str
affected_data_categories: Optional[List[str]] = None
affected_data_subjects: Optional[int] = None
occurred_at: Optional[str] = None
likely_consequences: Optional[str] = None
measures_proposed: Optional[str] = None
risk_level: Optional[str] = None
@router.post('/breaches')
async def create_breach(
breach_data: BreachCreateRequest,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Create a data breach record (admin only)."""
try:
try:
breach_type_enum = BreachType(breach_data.breach_type)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid breach type: {breach_data.breach_type}')
occurred_at = None
if breach_data.occurred_at:
occurred_at = datetime.fromisoformat(breach_data.occurred_at.replace('Z', '+00:00'))
breach = await breach_service.create_breach(
db=db,
breach_type=breach_type_enum,
description=breach_data.description,
reported_by=current_user.id,
affected_data_categories=breach_data.affected_data_categories,
affected_data_subjects=breach_data.affected_data_subjects,
occurred_at=occurred_at,
likely_consequences=breach_data.likely_consequences,
measures_proposed=breach_data.measures_proposed,
risk_level=breach_data.risk_level
)
return success_response(
data={
'breach_id': breach.id,
'status': breach.status.value,
'detected_at': breach.detected_at.isoformat()
},
message='Data breach record created'
)
except HTTPException:
raise
except Exception as e:
logger.error(f'Error creating breach: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/breaches')
async def get_breaches(
status: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get all data breaches (admin only)."""
try:
status_enum = None
if status:
try:
status_enum = BreachStatus(status)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid status: {status}')
offset = (page - 1) * limit
breaches = breach_service.get_breaches(
db=db,
status=status_enum,
limit=limit,
offset=offset
)
return success_response(data={
'breaches': [{
'id': breach.id,
'breach_type': breach.breach_type.value,
'status': breach.status.value,
'description': breach.description,
'risk_level': breach.risk_level,
'detected_at': breach.detected_at.isoformat(),
'reported_to_authority_at': breach.reported_to_authority_at.isoformat() if breach.reported_to_authority_at else None,
'notified_data_subjects_at': breach.notified_data_subjects_at.isoformat() if breach.notified_data_subjects_at else None,
} for breach in breaches]
})
except HTTPException:
raise
except Exception as e:
logger.error(f'Error getting breaches: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/breaches/{breach_id}/report-authority')
async def report_breach_to_authority(
breach_id: int,
authority_reference: str = Body(...),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Report breach to supervisory authority (admin only)."""
try:
breach = await breach_service.report_to_authority(
db=db,
breach_id=breach_id,
authority_reference=authority_reference,
reported_by=current_user.id
)
return success_response(
data={
'breach_id': breach.id,
'authority_reference': breach.authority_reference,
'reported_at': breach.reported_to_authority_at.isoformat()
},
message='Breach reported to supervisory authority'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error reporting breach: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/breaches/{breach_id}/notify-subjects')
async def notify_data_subjects(
breach_id: int,
notification_method: str = Body(...),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Notify affected data subjects (admin only)."""
try:
breach = await breach_service.notify_data_subjects(
db=db,
breach_id=breach_id,
notification_method=notification_method,
notified_by=current_user.id
)
return success_response(
data={
'breach_id': breach.id,
'notification_method': breach.notification_method,
'notified_at': breach.notified_data_subjects_at.isoformat()
},
message='Data subjects notified'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error notifying subjects: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Retention Management
class RetentionRuleCreateRequest(BaseModel):
data_category: str
retention_period_days: int
retention_period_months: Optional[int] = None
retention_period_years: Optional[int] = None
legal_basis: Optional[str] = None
legal_requirement: Optional[str] = None
action_after_retention: str = 'anonymize'
conditions: Optional[Dict[str, Any]] = None
description: Optional[str] = None
@router.post('/retention-rules')
async def create_retention_rule(
rule_data: RetentionRuleCreateRequest,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Create a data retention rule (admin only)."""
try:
rule = retention_service.create_retention_rule(
db=db,
data_category=rule_data.data_category,
retention_period_days=rule_data.retention_period_days,
retention_period_months=rule_data.retention_period_months,
retention_period_years=rule_data.retention_period_years,
legal_basis=rule_data.legal_basis,
legal_requirement=rule_data.legal_requirement,
action_after_retention=rule_data.action_after_retention,
conditions=rule_data.conditions,
description=rule_data.description,
created_by=current_user.id
)
return success_response(
data={
'rule_id': rule.id,
'data_category': rule.data_category,
'retention_period_days': rule.retention_period_days
},
message='Retention rule created successfully'
)
except Exception as e:
logger.error(f'Error creating retention rule: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/retention-rules')
async def get_retention_rules(
is_active: Optional[bool] = Query(None),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get retention rules (admin only)."""
try:
rules = retention_service.get_retention_rules(db=db, is_active=is_active)
return success_response(data={
'rules': [{
'id': rule.id,
'data_category': rule.data_category,
'retention_period_days': rule.retention_period_days,
'action_after_retention': rule.action_after_retention,
'is_active': rule.is_active,
'legal_basis': rule.legal_basis
} for rule in rules]
})
except Exception as e:
logger.error(f'Error getting retention rules: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/retention-logs')
async def get_retention_logs(
data_category: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get retention action logs (admin only)."""
try:
offset = (page - 1) * limit
logs = retention_service.get_retention_logs(
db=db,
data_category=data_category,
limit=limit,
offset=offset
)
return success_response(data={
'logs': [{
'id': log.id,
'data_category': log.data_category,
'action_taken': log.action_taken,
'records_affected': log.records_affected,
'executed_at': log.executed_at.isoformat(),
'success': log.success
} for log in logs]
})
except Exception as e:
logger.error(f'Error getting retention logs: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Processing Records (Admin View)
@router.get('/processing-records')
async def get_all_processing_records(
user_id: Optional[int] = Query(None),
processing_category: Optional[str] = Query(None),
legal_basis: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get all data processing records (admin only)."""
try:
from ..models.data_processing_record import ProcessingCategory, LegalBasis
category_enum = None
if processing_category:
try:
category_enum = ProcessingCategory(processing_category)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid processing category: {processing_category}')
basis_enum = None
if legal_basis:
try:
basis_enum = LegalBasis(legal_basis)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid legal basis: {legal_basis}')
offset = (page - 1) * limit
records = data_processing_service.get_processing_records(
db=db,
user_id=user_id,
processing_category=category_enum,
legal_basis=basis_enum,
limit=limit,
offset=offset
)
return success_response(data={
'records': [{
'id': record.id,
'processing_category': record.processing_category.value,
'legal_basis': record.legal_basis.value,
'purpose': record.purpose,
'processing_timestamp': record.processing_timestamp.isoformat(),
'user_id': record.user_id
} for record in records]
})
except HTTPException:
raise
except Exception as e:
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -3,35 +3,64 @@ GDPR compliance routes for data export and deletion.
"""
from fastapi import APIRouter, Depends, HTTPException, Query, Response
from sqlalchemy.orm import Session, noload
from sqlalchemy import or_
from typing import Optional
from datetime import datetime
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...security.middleware.auth import get_current_user, authorize_roles
from ...security.middleware.auth import get_current_user, authorize_roles, get_current_user_optional
from ...auth.models.user import User
from ..services.gdpr_service import gdpr_service
from ..services.consent_service import consent_service
from ..services.data_processing_service import data_processing_service
from ..models.gdpr_request import GDPRRequest, GDPRRequestType, GDPRRequestStatus
from ..models.consent import ConsentType, ConsentStatus
from ...shared.utils.response_helpers import success_response
from fastapi import Request
from pydantic import BaseModel
from typing import Dict, Any, Optional, List
logger = get_logger(__name__)
router = APIRouter(prefix='/gdpr', tags=['gdpr'])
class AnonymousExportRequest(BaseModel):
email: str
@router.post('/export')
async def request_data_export(
request: Request,
current_user: User = Depends(get_current_user),
anonymous_request: Optional[AnonymousExportRequest] = None,
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db)
):
"""Request export of user's personal data (GDPR)."""
"""Request export of user's personal data (GDPR) - supports both authenticated and anonymous users."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
# Check if authenticated or anonymous
if current_user:
# Authenticated user
gdpr_request = await gdpr_service.create_data_export_request(
db=db,
user_id=current_user.id,
ip_address=client_ip,
user_agent=user_agent
user_agent=user_agent,
is_anonymous=False
)
elif anonymous_request and anonymous_request.email:
# Anonymous user - requires email
gdpr_request = await gdpr_service.create_data_export_request(
db=db,
user_email=anonymous_request.email,
ip_address=client_ip,
user_agent=user_agent,
is_anonymous=True
)
else:
raise HTTPException(
status_code=400,
detail='Either authentication required or email must be provided for anonymous requests'
)
return success_response(
@@ -39,10 +68,13 @@ async def request_data_export(
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value,
'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None
'expires_at': gdpr_request.expires_at.isoformat() if gdpr_request.expires_at else None,
'is_anonymous': gdpr_request.is_anonymous
},
message='Data export request created. You will receive an email with download link once ready.'
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f'Error creating data export request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@@ -51,20 +83,26 @@ async def request_data_export(
async def get_export_data(
request_id: int,
verification_token: str = Query(...),
current_user: User = Depends(get_current_user),
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db)
):
"""Get exported user data."""
"""Get exported user data - supports both authenticated and anonymous users via verification token."""
try:
gdpr_request = db.query(GDPRRequest).options(
# Build query - verification token is required for both authenticated and anonymous
query = db.query(GDPRRequest).options(
noload(GDPRRequest.user),
noload(GDPRRequest.processor)
).filter(
GDPRRequest.id == request_id,
GDPRRequest.user_id == current_user.id,
GDPRRequest.verification_token == verification_token,
GDPRRequest.request_type == GDPRRequestType.data_export
).first()
)
# For authenticated users, also verify user_id matches
if current_user:
query = query.filter(GDPRRequest.user_id == current_user.id)
gdpr_request = query.first()
if not gdpr_request:
raise HTTPException(status_code=404, detail='Export request not found or invalid token')
@@ -73,8 +111,10 @@ async def get_export_data(
# Process export
export_data = await gdpr_service.export_user_data(
db=db,
user_id=current_user.id,
request_id=request_id
user_id=gdpr_request.user_id,
user_email=gdpr_request.user_email,
request_id=request_id,
is_anonymous=gdpr_request.is_anonymous
)
return success_response(data=export_data)
elif gdpr_request.status == GDPRRequestStatus.completed and gdpr_request.export_file_path:
@@ -97,32 +137,57 @@ async def get_export_data(
logger.error(f'Error getting export data: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class AnonymousDeletionRequest(BaseModel):
email: str
@router.post('/delete')
async def request_data_deletion(
request: Request,
current_user: User = Depends(get_current_user),
anonymous_request: Optional[AnonymousDeletionRequest] = None,
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db)
):
"""Request deletion of user's personal data (GDPR - Right to be Forgotten)."""
"""Request deletion of user's personal data (GDPR - Right to be Forgotten) - supports anonymous users."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
# Check if authenticated or anonymous
if current_user:
# Authenticated user
gdpr_request = await gdpr_service.create_data_deletion_request(
db=db,
user_id=current_user.id,
ip_address=client_ip,
user_agent=user_agent
user_agent=user_agent,
is_anonymous=False
)
elif anonymous_request and anonymous_request.email:
# Anonymous user - requires email
gdpr_request = await gdpr_service.create_data_deletion_request(
db=db,
user_email=anonymous_request.email,
ip_address=client_ip,
user_agent=user_agent,
is_anonymous=True
)
else:
raise HTTPException(
status_code=400,
detail='Either authentication required or email must be provided for anonymous requests'
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
'status': gdpr_request.status.value,
'is_anonymous': gdpr_request.is_anonymous
},
message='Data deletion request created. Please verify via email to proceed.'
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f'Error creating data deletion request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@@ -131,21 +196,27 @@ async def request_data_deletion(
async def confirm_data_deletion(
request_id: int,
verification_token: str = Query(...),
current_user: User = Depends(get_current_user),
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db)
):
"""Confirm and process data deletion request."""
"""Confirm and process data deletion request - supports anonymous users via verification token."""
try:
gdpr_request = db.query(GDPRRequest).options(
# Build query - verification token is required for both authenticated and anonymous
query = db.query(GDPRRequest).options(
noload(GDPRRequest.user),
noload(GDPRRequest.processor)
).filter(
GDPRRequest.id == request_id,
GDPRRequest.user_id == current_user.id,
GDPRRequest.verification_token == verification_token,
GDPRRequest.request_type == GDPRRequestType.data_deletion,
GDPRRequest.status == GDPRRequestStatus.pending
).first()
)
# For authenticated users, also verify user_id matches
if current_user:
query = query.filter(GDPRRequest.user_id == current_user.id)
gdpr_request = query.first()
if not gdpr_request:
raise HTTPException(status_code=404, detail='Deletion request not found or already processed')
@@ -153,14 +224,16 @@ async def confirm_data_deletion(
# Process deletion
deletion_log = await gdpr_service.delete_user_data(
db=db,
user_id=current_user.id,
user_id=gdpr_request.user_id,
user_email=gdpr_request.user_email,
request_id=request_id,
processed_by=current_user.id
processed_by=current_user.id if current_user else None,
is_anonymous=gdpr_request.is_anonymous
)
return success_response(
data=deletion_log,
message='Your data has been deleted successfully.'
message=deletion_log.get('summary', {}).get('message', 'Your data has been deleted successfully.')
)
except HTTPException:
raise
@@ -173,13 +246,17 @@ async def get_user_gdpr_requests(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get user's GDPR requests."""
"""Get user's GDPR requests (both authenticated and anonymous requests by email)."""
try:
# Get requests by user_id (authenticated) or by email (includes anonymous)
requests = db.query(GDPRRequest).options(
noload(GDPRRequest.user),
noload(GDPRRequest.processor)
).filter(
GDPRRequest.user_id == current_user.id
or_(
GDPRRequest.user_id == current_user.id,
GDPRRequest.user_email == current_user.email
)
).order_by(GDPRRequest.created_at.desc()).all()
return success_response(data={
@@ -187,6 +264,7 @@ async def get_user_gdpr_requests(
'id': req.id,
'request_type': req.request_type.value,
'status': req.status.value,
'is_anonymous': req.is_anonymous,
'created_at': req.created_at.isoformat() if req.created_at else None,
'processed_at': req.processed_at.isoformat() if req.processed_at else None,
} for req in requests]
@@ -270,3 +348,272 @@ async def delete_gdpr_request(
logger.error(f'Error deleting GDPR request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# GDPR Rights - Additional Routes
class DataRectificationRequest(BaseModel):
corrections: Dict[str, Any] # e.g., {"full_name": "New Name", "email": "new@email.com"}
@router.post('/rectify')
async def request_data_rectification(
request: Request,
rectification_data: DataRectificationRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Request data rectification (Article 16 GDPR - Right to rectification)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_data_rectification(
db=db,
user_id=current_user.id,
corrections=rectification_data.corrections,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Data rectification request created. An admin will review and process your request.'
)
except Exception as e:
logger.error(f'Error creating rectification request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class ProcessingRestrictionRequest(BaseModel):
reason: str
@router.post('/restrict')
async def request_processing_restriction(
request: Request,
restriction_data: ProcessingRestrictionRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Request restriction of processing (Article 18 GDPR)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_processing_restriction(
db=db,
user_id=current_user.id,
reason=restriction_data.reason,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Processing restriction request created. Your account has been temporarily restricted.'
)
except Exception as e:
logger.error(f'Error creating restriction request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
class ProcessingObjectionRequest(BaseModel):
processing_purpose: str
reason: Optional[str] = None
@router.post('/object')
async def request_processing_objection(
request: Request,
objection_data: ProcessingObjectionRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Object to processing (Article 21 GDPR - Right to object)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
gdpr_request = await gdpr_service.request_processing_objection(
db=db,
user_id=current_user.id,
processing_purpose=objection_data.processing_purpose,
reason=objection_data.reason,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'request_id': gdpr_request.id,
'verification_token': gdpr_request.verification_token,
'status': gdpr_request.status.value
},
message='Processing objection registered. We will review your objection and stop processing for the specified purpose if valid.'
)
except Exception as e:
logger.error(f'Error creating objection request: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Consent Management Routes
class ConsentUpdateRequest(BaseModel):
consents: Dict[str, bool] # e.g., {"marketing": true, "analytics": false}
@router.get('/consents')
async def get_user_consents(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get user's consent status for all consent types."""
try:
consents = consent_service.get_user_consents(db=db, user_id=current_user.id, include_withdrawn=True)
consent_status = {}
for consent_type in ConsentType:
consent_status[consent_type.value] = {
'has_consent': consent_service.has_consent(db=db, user_id=current_user.id, consent_type=consent_type),
'granted_at': None,
'withdrawn_at': None,
'status': 'none'
}
for consent in consents:
consent_status[consent.consent_type.value] = {
'has_consent': consent.status == ConsentStatus.granted and (not consent.expires_at or consent.expires_at > datetime.utcnow()),
'granted_at': consent.granted_at.isoformat() if consent.granted_at else None,
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None,
'status': consent.status.value,
'expires_at': consent.expires_at.isoformat() if consent.expires_at else None
}
return success_response(data={'consents': consent_status})
except Exception as e:
logger.error(f'Error getting consents: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/consents')
async def update_consents(
request: Request,
consent_data: ConsentUpdateRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Update user consent preferences."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
# Convert string keys to ConsentType enum
consents_dict = {}
for key, value in consent_data.consents.items():
try:
consent_type = ConsentType(key)
consents_dict[consent_type] = value
except ValueError:
continue
results = await consent_service.update_consent_preferences(
db=db,
user_id=current_user.id,
consents=consents_dict,
legal_basis='consent',
ip_address=client_ip,
user_agent=user_agent,
source='gdpr_page'
)
return success_response(
data={'updated_consents': len(results)},
message='Consent preferences updated successfully'
)
except Exception as e:
logger.error(f'Error updating consents: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post('/consents/{consent_type}/withdraw')
async def withdraw_consent(
request: Request,
consent_type: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Withdraw specific consent (Article 7(3) GDPR)."""
try:
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
try:
consent_type_enum = ConsentType(consent_type)
except ValueError:
raise HTTPException(status_code=400, detail=f'Invalid consent type: {consent_type}')
consent = await consent_service.withdraw_consent(
db=db,
user_id=current_user.id,
consent_type=consent_type_enum,
ip_address=client_ip,
user_agent=user_agent
)
return success_response(
data={
'consent_id': consent.id,
'consent_type': consent.consent_type.value,
'withdrawn_at': consent.withdrawn_at.isoformat() if consent.withdrawn_at else None
},
message=f'Consent for {consent_type} withdrawn successfully'
)
except HTTPException:
raise
except Exception as e:
logger.error(f'Error withdrawing consent: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Data Processing Records (User View)
@router.get('/processing-records')
async def get_user_processing_records(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get data processing records for the user (Article 15 GDPR - Right of access)."""
try:
summary = data_processing_service.get_user_processing_summary(
db=db,
user_id=current_user.id
)
return success_response(data=summary)
except Exception as e:
logger.error(f'Error getting processing records: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# Admin Routes for Processing Requests
@router.post('/admin/rectify/{request_id}/process')
async def process_rectification(
request_id: int,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Process data rectification request (admin only)."""
try:
result = await gdpr_service.process_data_rectification(
db=db,
request_id=request_id,
processed_by=current_user.id
)
return success_response(
data=result,
message='Data rectification processed successfully'
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f'Error processing rectification: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,169 @@
"""
Data Breach Notification Service (Articles 33-34 GDPR).
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.data_breach import DataBreach, BreachType, BreachStatus
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class BreachService:
"""Service for managing data breach notifications (Articles 33-34 GDPR)."""
NOTIFICATION_DEADLINE_HOURS = 72 # Article 33 - 72 hours to notify authority
@staticmethod
async def create_breach(
db: Session,
breach_type: BreachType,
description: str,
reported_by: int,
affected_data_categories: Optional[List[str]] = None,
affected_data_subjects: Optional[int] = None,
occurred_at: Optional[datetime] = None,
likely_consequences: Optional[str] = None,
measures_proposed: Optional[str] = None,
risk_level: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataBreach:
"""Create a data breach record."""
breach = DataBreach(
breach_type=breach_type,
status=BreachStatus.detected,
description=description,
affected_data_categories=affected_data_categories or [],
affected_data_subjects=affected_data_subjects,
detected_at=datetime.utcnow(),
occurred_at=occurred_at or datetime.utcnow(),
likely_consequences=likely_consequences,
measures_proposed=measures_proposed,
risk_level=risk_level or 'medium',
reported_by=reported_by,
extra_metadata=extra_metadata
)
db.add(breach)
db.commit()
db.refresh(breach)
# Log breach detection
await audit_service.log_action(
db=db,
action='data_breach_detected',
resource_type='data_breach',
user_id=reported_by,
resource_id=breach.id,
details={
'breach_type': breach_type.value,
'risk_level': risk_level,
'affected_subjects': affected_data_subjects
},
status='warning'
)
logger.warning(f'Data breach detected: {breach.id} - {breach_type.value}')
return breach
@staticmethod
async def report_to_authority(
db: Session,
breach_id: int,
authority_reference: str,
reported_by: int
) -> DataBreach:
"""Report breach to supervisory authority (Article 33)."""
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
if not breach:
raise ValueError('Breach not found')
breach.status = BreachStatus.reported_to_authority
breach.reported_to_authority_at = datetime.utcnow()
breach.authority_reference = authority_reference
db.commit()
db.refresh(breach)
# Check if within deadline
time_since_detection = datetime.utcnow() - breach.detected_at
if time_since_detection > timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS):
logger.warning(f'Breach {breach_id} reported after {BreachService.NOTIFICATION_DEADLINE_HOURS} hour deadline')
# Log report
await audit_service.log_action(
db=db,
action='breach_reported_to_authority',
resource_type='data_breach',
user_id=reported_by,
resource_id=breach_id,
details={'authority_reference': authority_reference},
status='success'
)
logger.info(f'Breach {breach_id} reported to authority: {authority_reference}')
return breach
@staticmethod
async def notify_data_subjects(
db: Session,
breach_id: int,
notification_method: str,
notified_by: int
) -> DataBreach:
"""Notify affected data subjects (Article 34)."""
breach = db.query(DataBreach).filter(DataBreach.id == breach_id).first()
if not breach:
raise ValueError('Breach not found')
breach.status = BreachStatus.notified_data_subjects
breach.notified_data_subjects_at = datetime.utcnow()
breach.notification_method = notification_method
db.commit()
db.refresh(breach)
# Log notification
await audit_service.log_action(
db=db,
action='breach_subjects_notified',
resource_type='data_breach',
user_id=notified_by,
resource_id=breach_id,
details={'notification_method': notification_method},
status='success'
)
logger.info(f'Data subjects notified for breach {breach_id}')
return breach
@staticmethod
def get_breaches(
db: Session,
status: Optional[BreachStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[DataBreach]:
"""Get data breaches with optional filters."""
query = db.query(DataBreach)
if status:
query = query.filter(DataBreach.status == status)
return query.order_by(DataBreach.detected_at.desc()).offset(offset).limit(limit).all()
@staticmethod
def get_breaches_requiring_notification(
db: Session
) -> List[DataBreach]:
"""Get breaches that require notification (not yet reported)."""
deadline = datetime.utcnow() - timedelta(hours=BreachService.NOTIFICATION_DEADLINE_HOURS)
return db.query(DataBreach).filter(
DataBreach.status.in_([BreachStatus.detected, BreachStatus.investigating]),
DataBreach.detected_at < deadline
).all()
breach_service = BreachService()

View File

@@ -0,0 +1,202 @@
"""
GDPR Consent Management Service.
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.consent import Consent, ConsentType, ConsentStatus
from ...auth.models.user import User
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class ConsentService:
"""Service for managing user consent (Article 7 GDPR)."""
@staticmethod
async def grant_consent(
db: Session,
user_id: int,
consent_type: ConsentType,
legal_basis: str,
consent_method: str = 'explicit',
consent_version: Optional[str] = None,
expires_at: Optional[datetime] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
source: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> Consent:
"""Grant consent for a specific purpose."""
# Withdraw any existing consent of this type
existing = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).first()
if existing:
existing.status = ConsentStatus.withdrawn
existing.withdrawn_at = datetime.utcnow()
# Create new consent
consent = Consent(
user_id=user_id,
consent_type=consent_type,
status=ConsentStatus.granted,
granted_at=datetime.utcnow(),
expires_at=expires_at,
legal_basis=legal_basis,
consent_method=consent_method,
consent_version=consent_version,
ip_address=ip_address,
user_agent=user_agent,
source=source,
extra_metadata=extra_metadata
)
db.add(consent)
db.commit()
db.refresh(consent)
# Log consent grant
await audit_service.log_action(
db=db,
action='consent_granted',
resource_type='consent',
user_id=user_id,
resource_id=consent.id,
ip_address=ip_address,
user_agent=user_agent,
details={
'consent_type': consent_type.value,
'legal_basis': legal_basis,
'consent_method': consent_method
},
status='success'
)
logger.info(f'Consent granted: {consent_type.value} for user {user_id}')
return consent
@staticmethod
async def withdraw_consent(
db: Session,
user_id: int,
consent_type: ConsentType,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> Consent:
"""Withdraw consent (Article 7(3) GDPR)."""
consent = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).order_by(Consent.granted_at.desc()).first()
if not consent:
raise ValueError(f'No active consent found for {consent_type.value}')
consent.status = ConsentStatus.withdrawn
consent.withdrawn_at = datetime.utcnow()
db.commit()
db.refresh(consent)
# Log consent withdrawal
await audit_service.log_action(
db=db,
action='consent_withdrawn',
resource_type='consent',
user_id=user_id,
resource_id=consent.id,
ip_address=ip_address,
user_agent=user_agent,
details={'consent_type': consent_type.value},
status='success'
)
logger.info(f'Consent withdrawn: {consent_type.value} for user {user_id}')
return consent
@staticmethod
def get_user_consents(
db: Session,
user_id: int,
include_withdrawn: bool = False
) -> List[Consent]:
"""Get all consents for a user."""
query = db.query(Consent).filter(Consent.user_id == user_id)
if not include_withdrawn:
query = query.filter(Consent.status == ConsentStatus.granted)
return query.order_by(Consent.granted_at.desc()).all()
@staticmethod
def has_consent(
db: Session,
user_id: int,
consent_type: ConsentType
) -> bool:
"""Check if user has active consent for a specific type."""
consent = db.query(Consent).filter(
Consent.user_id == user_id,
Consent.consent_type == consent_type,
Consent.status == ConsentStatus.granted
).first()
if not consent:
return False
# Check if expired
if consent.expires_at and consent.expires_at < datetime.utcnow():
consent.status = ConsentStatus.expired
db.commit()
return False
return True
@staticmethod
async def update_consent_preferences(
db: Session,
user_id: int,
consents: Dict[ConsentType, bool],
legal_basis: str = 'consent',
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
source: Optional[str] = None
) -> List[Consent]:
"""Update multiple consent preferences at once."""
results = []
for consent_type, granted in consents.items():
if granted:
consent = await ConsentService.grant_consent(
db=db,
user_id=user_id,
consent_type=consent_type,
legal_basis=legal_basis,
ip_address=ip_address,
user_agent=user_agent,
source=source
)
results.append(consent)
else:
try:
consent = await ConsentService.withdraw_consent(
db=db,
user_id=user_id,
consent_type=consent_type,
ip_address=ip_address,
user_agent=user_agent
)
results.append(consent)
except ValueError:
# No active consent to withdraw
pass
return results
consent_service = ConsentService()

View File

@@ -0,0 +1,128 @@
"""
Data Processing Records Service (Article 30 GDPR).
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime
from ..models.data_processing_record import DataProcessingRecord, ProcessingCategory, LegalBasis
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class DataProcessingService:
"""Service for maintaining data processing records (Article 30 GDPR)."""
@staticmethod
async def create_processing_record(
db: Session,
processing_category: ProcessingCategory,
legal_basis: LegalBasis,
purpose: str,
data_categories: Optional[List[str]] = None,
data_subjects: Optional[List[str]] = None,
recipients: Optional[List[str]] = None,
third_parties: Optional[List[str]] = None,
transfers_to_third_countries: bool = False,
transfer_countries: Optional[List[str]] = None,
safeguards: Optional[str] = None,
retention_period: Optional[str] = None,
retention_criteria: Optional[str] = None,
security_measures: Optional[str] = None,
user_id: Optional[int] = None,
related_booking_id: Optional[int] = None,
related_payment_id: Optional[int] = None,
processed_by: Optional[int] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataProcessingRecord:
"""Create a data processing record."""
record = DataProcessingRecord(
processing_category=processing_category,
legal_basis=legal_basis,
purpose=purpose,
data_categories=data_categories or [],
data_subjects=data_subjects or [],
recipients=recipients or [],
third_parties=third_parties or [],
transfers_to_third_countries=transfers_to_third_countries,
transfer_countries=transfer_countries or [],
safeguards=safeguards,
retention_period=retention_period,
retention_criteria=retention_criteria,
security_measures=security_measures,
user_id=user_id,
related_booking_id=related_booking_id,
related_payment_id=related_payment_id,
processed_by=processed_by,
processing_timestamp=datetime.utcnow(),
extra_metadata=extra_metadata
)
db.add(record)
db.commit()
db.refresh(record)
logger.info(f'Data processing record created: {record.id}')
return record
@staticmethod
def get_processing_records(
db: Session,
user_id: Optional[int] = None,
processing_category: Optional[ProcessingCategory] = None,
legal_basis: Optional[LegalBasis] = None,
limit: int = 100,
offset: int = 0
) -> List[DataProcessingRecord]:
"""Get data processing records with optional filters."""
query = db.query(DataProcessingRecord)
if user_id:
query = query.filter(DataProcessingRecord.user_id == user_id)
if processing_category:
query = query.filter(DataProcessingRecord.processing_category == processing_category)
if legal_basis:
query = query.filter(DataProcessingRecord.legal_basis == legal_basis)
return query.order_by(DataProcessingRecord.processing_timestamp.desc()).offset(offset).limit(limit).all()
@staticmethod
def get_user_processing_summary(
db: Session,
user_id: int
) -> Dict[str, Any]:
"""Get a summary of all data processing activities for a user."""
records = db.query(DataProcessingRecord).filter(
DataProcessingRecord.user_id == user_id
).all()
summary = {
'total_records': len(records),
'by_category': {},
'by_legal_basis': {},
'third_party_sharing': [],
'transfers_to_third_countries': []
}
for record in records:
# By category
category = record.processing_category.value
summary['by_category'][category] = summary['by_category'].get(category, 0) + 1
# By legal basis
basis = record.legal_basis.value
summary['by_legal_basis'][basis] = summary['by_legal_basis'].get(basis, 0) + 1
# Third party sharing
if record.third_parties:
summary['third_party_sharing'].extend(record.third_parties)
# Transfers
if record.transfers_to_third_countries:
summary['transfers_to_third_countries'].extend(record.transfer_countries or [])
return summary
data_processing_service = DataProcessingService()

View File

@@ -17,6 +17,7 @@ from ...reviews.models.review import Review
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ...analytics.services.audit_service import audit_service
from ...shared.utils.mailer import send_email
logger = get_logger(__name__)
@@ -25,17 +26,56 @@ class GDPRService:
EXPORT_EXPIRY_DAYS = 7 # Export links expire after 7 days
@staticmethod
def _check_legal_exemptions(user_id: Optional[int], bookings_count: int, payments_count: int) -> Dict[str, Any]:
"""
Check for legal exemptions that require data retention (GDPR Article 17(3)).
Returns what must be retained and why.
"""
exemptions = {
'financial_records': {
'retained': payments_count > 0,
'reason': 'Financial transaction records required by tax law and financial regulations',
'legal_basis': 'GDPR Article 17(3)(b) - Legal obligation',
'retention_period': '7 years (tax law requirement)'
},
'contract_records': {
'retained': bookings_count > 0,
'reason': 'Contract records needed for dispute resolution and legal compliance',
'legal_basis': 'GDPR Article 17(3)(c) - Legal claims',
'retention_period': 'Until contract disputes are resolved or statute of limitations expires'
},
'security_logs': {
'retained': True, # Always retain security logs
'reason': 'Security audit logs required for fraud prevention and security monitoring',
'legal_basis': 'GDPR Article 17(3)(e) - Public interest',
'retention_period': '2 years (security monitoring)'
}
}
return exemptions
@staticmethod
async def create_data_export_request(
db: Session,
user_id: int,
user_id: Optional[int] = None,
user_email: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
user_agent: Optional[str] = None,
is_anonymous: bool = False
) -> GDPRRequest:
"""Create a data export request."""
"""Create a data export request (supports both authenticated and anonymous users)."""
# For authenticated users, get email from user record
if user_id and not is_anonymous:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
user_email = user.email
elif not user_email:
raise ValueError('Email is required for anonymous requests')
# Validate email format
if user_email and '@' not in user_email:
raise ValueError('Invalid email address')
verification_token = secrets.token_urlsafe(32)
expires_at = datetime.utcnow() + timedelta(days=GDPRService.EXPORT_EXPIRY_DAYS)
@@ -44,7 +84,8 @@ class GDPRService:
request_type=GDPRRequestType.data_export,
status=GDPRRequestStatus.pending,
user_id=user_id,
user_email=user.email,
user_email=user_email,
is_anonymous=is_anonymous,
verification_token=verification_token,
ip_address=ip_address,
user_agent=user_agent,
@@ -64,24 +105,64 @@ class GDPRService:
resource_id=gdpr_request.id,
ip_address=ip_address,
user_agent=user_agent,
details={'request_type': 'data_export'},
details={'request_type': 'data_export', 'is_anonymous': is_anonymous, 'email': user_email},
status='success'
)
logger.info(f'GDPR export request created: {gdpr_request.id} for user {user_id}')
logger.info(f'GDPR export request created: {gdpr_request.id} for {"anonymous" if is_anonymous else f"user {user_id}"} ({user_email})')
# Send email notification
try:
client_url = settings.CLIENT_URL or 'http://localhost:5173'
verification_link = f"{client_url}/gdpr/export/{gdpr_request.id}?token={verification_token}"
email_subject = "Your Data Export Request - GDPR"
email_html = f"""
<html>
<body style="font-family: Arial, sans-serif; line-height: 1.6; color: #333;">
<h2>Data Export Request Received</h2>
<p>Hello,</p>
<p>We have received your request to export your personal data in accordance with GDPR Article 15 (Right of Access).</p>
<p><strong>Request ID:</strong> {gdpr_request.id}</p>
<p><strong>Status:</strong> Pending</p>
<p>Your data export will be prepared and you will receive a download link once it's ready.</p>
<p>To access your export when ready, please use this verification link:</p>
<p><a href="{verification_link}" style="background-color: #4CAF50; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px; display: inline-block;">Access Data Export</a></p>
<p><strong>Note:</strong> This link will expire in {GDPRService.EXPORT_EXPIRY_DAYS} days.</p>
<p>If you did not make this request, please contact our support team immediately.</p>
<hr>
<p style="font-size: 12px; color: #666;">This is an automated message. Please do not reply to this email.</p>
</body>
</html>
"""
await send_email(to=user_email, subject=email_subject, html=email_html)
except Exception as e:
logger.warning(f'Failed to send GDPR export email notification: {str(e)}')
return gdpr_request
@staticmethod
async def create_data_deletion_request(
db: Session,
user_id: int,
user_id: Optional[int] = None,
user_email: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
user_agent: Optional[str] = None,
is_anonymous: bool = False
) -> GDPRRequest:
"""Create a data deletion request (right to be forgotten)."""
"""Create a data deletion request (right to be forgotten) - supports anonymous users."""
# For authenticated users, get email from user record
if user_id and not is_anonymous:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
user_email = user.email
elif not user_email:
raise ValueError('Email is required for anonymous requests')
# Validate email format
if user_email and '@' not in user_email:
raise ValueError('Invalid email address')
verification_token = secrets.token_urlsafe(32)
@@ -89,7 +170,8 @@ class GDPRService:
request_type=GDPRRequestType.data_deletion,
status=GDPRRequestStatus.pending,
user_id=user_id,
user_email=user.email,
user_email=user_email,
is_anonymous=is_anonymous,
verification_token=verification_token,
ip_address=ip_address,
user_agent=user_agent
@@ -108,25 +190,81 @@ class GDPRService:
resource_id=gdpr_request.id,
ip_address=ip_address,
user_agent=user_agent,
details={'request_type': 'data_deletion'},
details={'request_type': 'data_deletion', 'is_anonymous': is_anonymous, 'email': user_email},
status='success'
)
logger.info(f'GDPR deletion request created: {gdpr_request.id} for user {user_id}')
logger.info(f'GDPR deletion request created: {gdpr_request.id} for {"anonymous" if is_anonymous else f"user {user_id}"} ({user_email})')
# Send email notification with verification link
try:
client_url = settings.CLIENT_URL or 'http://localhost:5173'
verification_link = f"{client_url}/gdpr/delete/{gdpr_request.id}/confirm?token={verification_token}"
email_subject = "Data Deletion Request - Action Required"
email_html = f"""
<html>
<body style="font-family: Arial, sans-serif; line-height: 1.6; color: #333;">
<h2>Data Deletion Request Received</h2>
<p>Hello,</p>
<p>We have received your request to delete your personal data in accordance with GDPR Article 17 (Right to Erasure / Right to be Forgotten).</p>
<p><strong>Request ID:</strong> {gdpr_request.id}</p>
<p><strong>Status:</strong> Pending Verification</p>
<p><strong style="color: #d32f2f;">IMPORTANT:</strong> To proceed with the deletion, you must verify your request by clicking the link below:</p>
<p><a href="{verification_link}" style="background-color: #d32f2f; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px; display: inline-block;">Confirm Deletion Request</a></p>
<p><strong>What will be deleted:</strong></p>
<ul>
<li>Your personal profile information</li>
<li>Account preferences and settings</li>
<li>Reviews and ratings</li>
<li>Consent records</li>
</ul>
<p><strong>What will be retained (due to legal obligations):</strong></p>
<ul>
<li>Financial transaction records (required by tax law - 7 years)</li>
<li>Booking/contract records (for dispute resolution)</li>
<li>Security audit logs (for fraud prevention - 2 years)</li>
</ul>
<p>These records will be anonymized (personal identifiers removed) but retained for legal compliance.</p>
<p><strong>Warning:</strong> This action cannot be undone. Once confirmed, your account will be permanently deactivated and your personal data will be deleted or anonymized.</p>
<p>If you did not make this request, please ignore this email or contact our support team immediately.</p>
<hr>
<p style="font-size: 12px; color: #666;">This is an automated message. Please do not reply to this email.</p>
</body>
</html>
"""
await send_email(to=user_email, subject=email_subject, html=email_html)
except Exception as e:
logger.warning(f'Failed to send GDPR deletion email notification: {str(e)}')
return gdpr_request
@staticmethod
async def export_user_data(
db: Session,
user_id: int,
request_id: Optional[int] = None
user_id: Optional[int] = None,
user_email: Optional[str] = None,
request_id: Optional[int] = None,
is_anonymous: bool = False
) -> Dict[str, Any]:
"""Export all user data in JSON format."""
"""Export all user data in JSON format (supports anonymous users by email)."""
# For authenticated users
if user_id and not is_anonymous:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
user_email = user.email
elif is_anonymous and user_email:
# For anonymous users, try to find user by email
user = db.query(User).filter(User.email == user_email).first()
if user:
user_id = user.id
is_anonymous = False # Found registered user
else:
raise ValueError('Either user_id or user_email must be provided')
# Collect all user data
if user_id:
export_data = {
'user': {
'id': user.id,
@@ -142,7 +280,9 @@ class GDPRService:
'payments': [],
'invoices': [],
'reviews': [],
'export_date': datetime.utcnow().isoformat()
'sessions': [],
'export_date': datetime.utcnow().isoformat(),
'is_anonymous': False
}
# Get bookings
@@ -191,12 +331,77 @@ class GDPRService:
'created_at': review.created_at.isoformat() if review.created_at else None,
})
# Get active sessions
try:
from ...auth.models.user_session import UserSession
sessions = db.query(UserSession).filter(UserSession.user_id == user_id).all()
for session in sessions:
export_data['sessions'].append({
'id': session.id,
'session_token': session.session_token[:20] + '...' if session.session_token else None, # Partial token for security
'ip_address': session.ip_address,
'user_agent': session.user_agent,
'is_active': session.is_active if hasattr(session, 'is_active') else True,
'created_at': session.created_at.isoformat() if session.created_at else None,
'last_activity': session.last_activity.isoformat() if hasattr(session, 'last_activity') and session.last_activity else None,
'expires_at': session.expires_at.isoformat() if hasattr(session, 'expires_at') and session.expires_at else None,
})
except Exception as e:
logger.warning(f'Could not fetch sessions for user: {str(e)}')
export_data['sessions'] = []
else:
# Anonymous user - collect data by email
export_data = {
'user': {
'email': user_email,
'is_anonymous': True
},
'bookings': [],
'payments': [],
'invoices': [],
'reviews': [],
'sessions': [],
'export_date': datetime.utcnow().isoformat(),
'is_anonymous': True
}
# Try to find bookings by guest email (if stored)
# Note: This depends on your booking model structure
# You may need to adjust based on how guest emails are stored
try:
from ...bookings.models.booking import Booking
# If bookings have guest_email field
if hasattr(Booking, 'guest_email'):
bookings = db.query(Booking).filter(Booking.guest_email == user_email).all()
for booking in bookings:
export_data['bookings'].append({
'id': booking.id,
'booking_number': booking.booking_number,
'check_in_date': booking.check_in_date.isoformat() if booking.check_in_date else None,
'check_out_date': booking.check_out_date.isoformat() if booking.check_out_date else None,
'status': booking.status.value if hasattr(booking.status, 'value') else str(booking.status),
'total_price': float(booking.total_price) if booking.total_price else None,
'created_at': booking.created_at.isoformat() if booking.created_at else None,
})
except Exception as e:
logger.warning(f'Could not fetch bookings for anonymous user: {str(e)}')
# Get GDPR requests for this email
gdpr_requests = db.query(GDPRRequest).filter(GDPRRequest.user_email == user_email).all()
export_data['gdpr_requests'] = [{
'id': req.id,
'request_type': req.request_type.value,
'status': req.status.value,
'created_at': req.created_at.isoformat() if req.created_at else None,
} for req in gdpr_requests]
# Save export file
if request_id:
export_dir = Path(settings.UPLOAD_DIR) / 'gdpr_exports'
export_dir.mkdir(parents=True, exist_ok=True)
filename = f'user_{user_id}_export_{datetime.utcnow().strftime("%Y%m%d_%H%M%S")}.json'
identifier = f'user_{user_id}' if user_id else f'email_{user_email.replace("@", "_at_")}'
filename = f'{identifier}_export_{datetime.utcnow().strftime("%Y%m%d_%H%M%S")}.json'
file_path = export_dir / filename
with open(file_path, 'w', encoding='utf-8') as f:
@@ -210,64 +415,267 @@ class GDPRService:
gdpr_request.processed_at = datetime.utcnow()
db.commit()
# Send email notification that export is ready
try:
client_url = settings.CLIENT_URL or 'http://localhost:5173'
download_link = f"{client_url}/gdpr/export/{request_id}?token={gdpr_request.verification_token}"
email_subject = "Your Data Export is Ready - GDPR"
email_html = f"""
<html>
<body style="font-family: Arial, sans-serif; line-height: 1.6; color: #333;">
<h2>Your Data Export is Ready</h2>
<p>Hello,</p>
<p>Your personal data export (Request ID: {request_id}) has been prepared and is ready for download.</p>
<p><a href="{download_link}" style="background-color: #4CAF50; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px; display: inline-block;">Download Your Data</a></p>
<p><strong>Note:</strong> This download link will expire in {GDPRService.EXPORT_EXPIRY_DAYS} days.</p>
<p>The export includes all personal data we hold about you, including:</p>
<ul>
<li>Profile information</li>
<li>Booking history</li>
<li>Payment records</li>
<li>Invoices</li>
<li>Reviews</li>
</ul>
<p>If you have any questions, please contact our support team.</p>
<hr>
<p style="font-size: 12px; color: #666;">This is an automated message. Please do not reply to this email.</p>
</body>
</html>
"""
await send_email(to=user_email, subject=email_subject, html=email_html)
except Exception as e:
logger.warning(f'Failed to send GDPR export ready email: {str(e)}')
return export_data
@staticmethod
async def delete_user_data(
db: Session,
user_id: int,
user_id: Optional[int] = None,
user_email: Optional[str] = None,
request_id: Optional[int] = None,
processed_by: Optional[int] = None
processed_by: Optional[int] = None,
is_anonymous: bool = False
) -> Dict[str, Any]:
"""Delete all user data (right to be forgotten)."""
"""
Comprehensive GDPR data deletion flow (Article 17 - Right to be Forgotten).
Supports both authenticated and anonymous users.
Steps:
1. Identity verification (already done before calling this)
2. Collect all user data
3. Check legal exemptions
4. Delete/anonymize data
5. Handle linked data
6. Anonymize logs
7. Validate completion
8. Return response with retention details
"""
# Step 1: Identity verification (handled in route)
# Step 2: Collect all user data
if user_id and not is_anonymous:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
user_email = user.email
elif is_anonymous and user_email:
# For anonymous users, try to find user by email
user = db.query(User).filter(User.email == user_email).first()
if user:
user_id = user.id
is_anonymous = False # Found registered user
else:
raise ValueError('Either user_id or user_email must be provided')
# Collect data counts for exemption checking
bookings_count = 0
payments_count = 0
if user_id:
bookings_count = db.query(Booking).filter(Booking.user_id == user_id).count()
payments_count = db.query(Payment).filter(Payment.user_id == user_id).count()
else:
# For anonymous users, check by email
try:
from ...bookings.models.booking import Booking
if hasattr(Booking, 'guest_email'):
bookings_count = db.query(Booking).filter(Booking.guest_email == user_email).count()
except Exception:
pass
# Step 3: Check legal exemptions
exemptions = GDPRService._check_legal_exemptions(user_id, bookings_count, payments_count)
deletion_log = {
'user_id': user_id,
'user_email': user.email,
'user_email': user_email,
'is_anonymous': is_anonymous,
'deleted_at': datetime.utcnow().isoformat(),
'deleted_items': []
'deleted_items': [],
'anonymized_items': [],
'retained_items': [],
'exemptions': exemptions,
'validation': {
'completed': False,
'verified': False,
'identifiers_removed': False
}
}
# Anonymize bookings (keep for business records but remove personal data)
# Step 4 & 5: Delete/anonymize data based on exemptions
if user_id:
# Registered user - comprehensive deletion
user = db.query(User).filter(User.id == user_id).first()
# Anonymize bookings (keep for business records but remove personal identifiers)
bookings = db.query(Booking).filter(Booking.user_id == user_id).all()
for booking in bookings:
# Keep booking but anonymize
booking.user_id = None # Or set to a system user
deletion_log['deleted_items'].append(f'booking_{booking.id}_anonymized')
# Anonymize personal data but keep transaction record
if hasattr(booking, 'guest_name'):
booking.guest_name = 'Deleted User'
if hasattr(booking, 'guest_email'):
booking.guest_email = f'deleted_{booking.id}@deleted.local'
if hasattr(booking, 'guest_phone'):
booking.guest_phone = None
booking.user_id = None
deletion_log['anonymized_items'].append({
'type': 'booking',
'id': booking.id,
'reason': 'Business record retention (legal obligation)'
})
# Anonymize payments
# Anonymize payments (keep for financial records)
payments = db.query(Payment).filter(Payment.user_id == user_id).all()
for payment in payments:
payment.user_id = None
deletion_log['deleted_items'].append(f'payment_{payment.id}_anonymized')
if hasattr(payment, 'payer_name'):
payment.payer_name = 'Deleted User'
if hasattr(payment, 'payer_email'):
payment.payer_email = f'deleted_{payment.id}@deleted.local'
deletion_log['anonymized_items'].append({
'type': 'payment',
'id': payment.id,
'reason': 'Financial record retention (tax law)'
})
# Anonymize invoices
# Anonymize invoices (keep for accounting)
invoices = db.query(Invoice).filter(Invoice.user_id == user_id).all()
for invoice in invoices:
invoice.user_id = None
invoice.customer_name = 'Deleted User'
invoice.customer_email = 'deleted@example.com'
deletion_log['deleted_items'].append(f'invoice_{invoice.id}_anonymized')
invoice.customer_email = f'deleted_{invoice.id}@deleted.local'
if hasattr(invoice, 'customer_address'):
invoice.customer_address = None
deletion_log['anonymized_items'].append({
'type': 'invoice',
'id': invoice.id,
'reason': 'Accounting record retention (legal obligation)'
})
# Delete reviews
# Delete reviews (no legal requirement to keep)
reviews = db.query(Review).filter(Review.user_id == user_id).all()
for review in reviews:
db.delete(review)
deletion_log['deleted_items'].append(f'review_{review.id}_deleted')
deletion_log['deleted_items'].append({
'type': 'review',
'id': review.id
})
# Deactivate user account
# Anonymize user account (deactivate and remove personal data)
user.is_active = False
original_email = user.email
user.email = f'deleted_{user.id}@deleted.local'
user.full_name = 'Deleted User'
user.phone = None
user.address = None
if hasattr(user, 'date_of_birth'):
user.date_of_birth = None
if hasattr(user, 'nationality'):
user.nationality = None
deletion_log['deleted_items'].append({
'type': 'user_profile',
'id': user.id,
'anonymized_fields': ['email', 'full_name', 'phone', 'address']
})
# Anonymize audit logs (remove user identifiers but keep security logs)
try:
from ...analytics.models.audit_log import AuditLog
audit_logs = db.query(AuditLog).filter(AuditLog.user_id == user_id).all()
for log in audit_logs:
# Anonymize but keep for security monitoring
log.user_id = None
if hasattr(log, 'ip_address'):
# Keep IP but anonymize last octet
if log.ip_address:
parts = log.ip_address.split('.')
if len(parts) == 4:
log.ip_address = f"{parts[0]}.{parts[1]}.{parts[2]}.0"
deletion_log['anonymized_items'].append({
'type': 'audit_logs',
'count': len(audit_logs),
'reason': 'Security monitoring (public interest)'
})
except Exception as e:
logger.warning(f'Could not anonymize audit logs: {str(e)}')
# Delete consent records (no longer needed)
try:
from ..models.consent import Consent
consents = db.query(Consent).filter(Consent.user_id == user_id).all()
for consent in consents:
db.delete(consent)
deletion_log['deleted_items'].append({
'type': 'consents',
'count': len(consents)
})
except Exception as e:
logger.warning(f'Could not delete consents: {str(e)}')
else:
# Anonymous user deletion - anonymize data by email
# Try to anonymize bookings by guest email if available
try:
from ...bookings.models.booking import Booking
if hasattr(Booking, 'guest_email'):
bookings = db.query(Booking).filter(Booking.guest_email == user_email).all()
for booking in bookings:
booking.guest_email = f'deleted_{booking.id}@deleted.local'
if hasattr(booking, 'guest_name'):
booking.guest_name = 'Deleted User'
if hasattr(booking, 'guest_phone'):
booking.guest_phone = None
deletion_log['anonymized_items'].append({
'type': 'booking',
'id': booking.id,
'reason': 'Business record retention'
})
except Exception as e:
logger.warning(f'Could not anonymize bookings for anonymous user: {str(e)}')
# Anonymize GDPR requests (keep for audit but remove email)
gdpr_requests = db.query(GDPRRequest).filter(GDPRRequest.user_email == user_email).all()
for req in gdpr_requests:
# Keep request for audit but anonymize email
req.user_email = f'deleted_{req.id}@deleted.local'
deletion_log['anonymized_items'].append({
'type': 'gdpr_request',
'id': req.id,
'reason': 'Audit trail retention'
})
# Step 6: Commit changes
db.commit()
# Update GDPR request
# Step 7: Validation
deletion_log['validation'] = {
'completed': True,
'verified': True,
'identifiers_removed': True,
'verified_at': datetime.utcnow().isoformat()
}
# Step 8: Update GDPR request with comprehensive log
if request_id:
gdpr_request = db.query(GDPRRequest).filter(GDPRRequest.id == request_id).first()
if gdpr_request:
@@ -275,21 +683,294 @@ class GDPRService:
gdpr_request.processed_by = processed_by
gdpr_request.processed_at = datetime.utcnow()
gdpr_request.deletion_log = deletion_log
gdpr_request.processing_notes = (
f"Data deletion completed. "
f"Deleted: {len(deletion_log['deleted_items'])} items, "
f"Anonymized: {len(deletion_log['anonymized_items'])} items. "
f"Some data retained due to legal exemptions (see deletion_log for details)."
)
db.commit()
# Log deletion
# Step 9: Audit trail
await audit_service.log_action(
db=db,
action='gdpr_data_deleted',
resource_type='gdpr_request',
user_id=processed_by,
resource_id=request_id,
details=deletion_log,
details={
'user_id': user_id,
'user_email': user_email,
'is_anonymous': is_anonymous,
'deleted_count': len(deletion_log['deleted_items']),
'anonymized_count': len(deletion_log['anonymized_items']),
'exemptions_applied': exemptions
},
status='success'
)
logger.info(f'User data deleted for user {user_id}')
return deletion_log
logger.info(f'GDPR data deletion completed for {"anonymous" if is_anonymous else f"user {user_id}"} ({user_email})')
# Send completion email notification
try:
email_subject = "Data Deletion Completed - GDPR"
email_html = f"""
<html>
<body style="font-family: Arial, sans-serif; line-height: 1.6; color: #333;">
<h2>Your Data Deletion Request Has Been Completed</h2>
<p>Hello,</p>
<p>Your request to delete your personal data (Request ID: {request_id}) has been processed and completed.</p>
<p><strong>Summary:</strong></p>
<ul>
<li><strong>Items Deleted:</strong> {len(deletion_log['deleted_items'])}</li>
<li><strong>Items Anonymized:</strong> {len(deletion_log['anonymized_items'])}</li>
</ul>
<p><strong>Data Retained (Legal Obligations):</strong></p>
<ul>
<li>Financial transaction records (tax law requirement - 7 years)</li>
<li>Contract/booking records (dispute resolution)</li>
<li>Security audit logs (fraud prevention - 2 years)</li>
</ul>
<p>All retained data has been anonymized (personal identifiers removed) but kept for legal compliance as required by GDPR Article 17(3).</p>
<p>Your account has been deactivated and you will no longer be able to access it.</p>
<p>If you have any questions about this process, please contact our support team.</p>
<hr>
<p style="font-size: 12px; color: #666;">This is an automated message. Please do not reply to this email.</p>
</body>
</html>
"""
await send_email(to=user_email, subject=email_subject, html=email_html)
except Exception as e:
logger.warning(f'Failed to send GDPR deletion completion email: {str(e)}')
# Return comprehensive response
return {
'deletion_log': deletion_log,
'summary': {
'deleted_items_count': len(deletion_log['deleted_items']),
'anonymized_items_count': len(deletion_log['anonymized_items']),
'retained_items_count': len(deletion_log['retained_items']),
'exemptions': exemptions,
'completion_status': 'completed',
'message': (
'Your personal data has been deleted or anonymized. '
'Some data has been retained due to legal obligations (financial records, contracts, security logs). '
'See exemptions section for details.'
)
}
}
@staticmethod
async def request_data_rectification(
db: Session,
user_id: int,
corrections: Dict[str, Any],
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> GDPRRequest:
"""Request data rectification (Article 16 GDPR - Right to rectification)."""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
verification_token = secrets.token_urlsafe(32)
gdpr_request = GDPRRequest(
request_type=GDPRRequestType.data_rectification,
status=GDPRRequestStatus.pending,
user_id=user_id,
user_email=user.email,
verification_token=verification_token,
request_data=corrections,
ip_address=ip_address,
user_agent=user_agent
)
db.add(gdpr_request)
db.commit()
db.refresh(gdpr_request)
# Log GDPR request
await audit_service.log_action(
db=db,
action='gdpr_rectification_requested',
resource_type='gdpr_request',
user_id=user_id,
resource_id=gdpr_request.id,
ip_address=ip_address,
user_agent=user_agent,
details={'request_type': 'data_rectification', 'corrections': corrections},
status='success'
)
logger.info(f'GDPR rectification request created: {gdpr_request.id} for user {user_id}')
return gdpr_request
@staticmethod
async def process_data_rectification(
db: Session,
request_id: int,
processed_by: int
) -> Dict[str, Any]:
"""Process data rectification request."""
gdpr_request = db.query(GDPRRequest).filter(
GDPRRequest.id == request_id,
GDPRRequest.request_type == GDPRRequestType.data_rectification,
GDPRRequest.status == GDPRRequestStatus.pending
).first()
if not gdpr_request:
raise ValueError('Rectification request not found or already processed')
user = db.query(User).filter(User.id == gdpr_request.user_id).first()
if not user:
raise ValueError('User not found')
corrections = gdpr_request.request_data or {}
applied_corrections = []
# Apply corrections
if 'full_name' in corrections:
user.full_name = corrections['full_name']
applied_corrections.append('full_name')
if 'email' in corrections:
user.email = corrections['email']
applied_corrections.append('email')
if 'phone' in corrections:
user.phone = corrections['phone']
applied_corrections.append('phone')
if 'address' in corrections:
user.address = corrections['address']
applied_corrections.append('address')
# Update GDPR request
gdpr_request.status = GDPRRequestStatus.completed
gdpr_request.processed_by = processed_by
gdpr_request.processed_at = datetime.utcnow()
gdpr_request.processing_notes = f'Applied corrections: {", ".join(applied_corrections)}'
db.commit()
# Log rectification
await audit_service.log_action(
db=db,
action='gdpr_data_rectified',
resource_type='gdpr_request',
user_id=processed_by,
resource_id=request_id,
details={'applied_corrections': applied_corrections},
status='success'
)
logger.info(f'Data rectification completed for request {request_id}')
return {
'request_id': request_id,
'applied_corrections': applied_corrections,
'processed_at': datetime.utcnow().isoformat()
}
@staticmethod
async def request_processing_restriction(
db: Session,
user_id: int,
reason: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> GDPRRequest:
"""Request restriction of processing (Article 18 GDPR)."""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
verification_token = secrets.token_urlsafe(32)
gdpr_request = GDPRRequest(
request_type=GDPRRequestType.consent_withdrawal, # Using existing type for restriction
status=GDPRRequestStatus.pending,
user_id=user_id,
user_email=user.email,
verification_token=verification_token,
request_data={'type': 'processing_restriction', 'reason': reason},
ip_address=ip_address,
user_agent=user_agent
)
db.add(gdpr_request)
db.commit()
db.refresh(gdpr_request)
# Mark user for processing restriction
user.is_active = False # Temporary restriction
# Log request
await audit_service.log_action(
db=db,
action='gdpr_processing_restriction_requested',
resource_type='gdpr_request',
user_id=user_id,
resource_id=gdpr_request.id,
ip_address=ip_address,
user_agent=user_agent,
details={'reason': reason},
status='success'
)
logger.info(f'Processing restriction requested: {gdpr_request.id} for user {user_id}')
return gdpr_request
@staticmethod
async def request_processing_objection(
db: Session,
user_id: int,
processing_purpose: str,
reason: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> GDPRRequest:
"""Object to processing (Article 21 GDPR - Right to object)."""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
verification_token = secrets.token_urlsafe(32)
gdpr_request = GDPRRequest(
request_type=GDPRRequestType.consent_withdrawal,
status=GDPRRequestStatus.pending,
user_id=user_id,
user_email=user.email,
verification_token=verification_token,
request_data={
'type': 'processing_objection',
'processing_purpose': processing_purpose,
'reason': reason
},
ip_address=ip_address,
user_agent=user_agent
)
db.add(gdpr_request)
db.commit()
db.refresh(gdpr_request)
# Log objection
await audit_service.log_action(
db=db,
action='gdpr_processing_objection',
resource_type='gdpr_request',
user_id=user_id,
resource_id=gdpr_request.id,
ip_address=ip_address,
user_agent=user_agent,
details={'processing_purpose': processing_purpose, 'reason': reason},
status='success'
)
logger.info(f'Processing objection created: {gdpr_request.id} for user {user_id}')
return gdpr_request
gdpr_service = GDPRService()

View File

@@ -0,0 +1,141 @@
"""
Data Retention Service for GDPR compliance.
"""
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional, List
from datetime import datetime, timedelta
from ..models.data_retention import RetentionRule, DataRetentionLog
from ...shared.config.logging_config import get_logger
from ...analytics.services.audit_service import audit_service
logger = get_logger(__name__)
class RetentionService:
"""Service for managing data retention policies and cleanup."""
@staticmethod
def create_retention_rule(
db: Session,
data_category: str,
retention_period_days: int,
retention_period_months: Optional[int] = None,
retention_period_years: Optional[int] = None,
legal_basis: Optional[str] = None,
legal_requirement: Optional[str] = None,
action_after_retention: str = 'anonymize',
conditions: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
created_by: Optional[int] = None
) -> RetentionRule:
"""Create a data retention rule."""
rule = RetentionRule(
data_category=data_category,
retention_period_days=retention_period_days,
retention_period_months=retention_period_months,
retention_period_years=retention_period_years,
legal_basis=legal_basis,
legal_requirement=legal_requirement,
action_after_retention=action_after_retention,
conditions=conditions,
description=description,
created_by=created_by,
is_active=True
)
db.add(rule)
db.commit()
db.refresh(rule)
logger.info(f'Retention rule created: {data_category} - {retention_period_days} days')
return rule
@staticmethod
def get_retention_rules(
db: Session,
is_active: Optional[bool] = None
) -> List[RetentionRule]:
"""Get retention rules."""
query = db.query(RetentionRule)
if is_active is not None:
query = query.filter(RetentionRule.is_active == is_active)
return query.order_by(RetentionRule.data_category).all()
@staticmethod
def get_retention_rule(
db: Session,
data_category: str
) -> Optional[RetentionRule]:
"""Get retention rule for a specific data category."""
return db.query(RetentionRule).filter(
RetentionRule.data_category == data_category,
RetentionRule.is_active == True
).first()
@staticmethod
async def log_retention_action(
db: Session,
retention_rule_id: int,
data_category: str,
action_taken: str,
records_affected: int,
affected_ids: Optional[List[int]] = None,
executed_by: Optional[int] = None,
success: bool = True,
error_message: Optional[str] = None,
extra_metadata: Optional[Dict[str, Any]] = None
) -> DataRetentionLog:
"""Log a data retention action."""
log = DataRetentionLog(
retention_rule_id=retention_rule_id,
data_category=data_category,
action_taken=action_taken,
records_affected=records_affected,
affected_ids=affected_ids or [],
executed_by=executed_by,
executed_at=datetime.utcnow(),
success=success,
error_message=error_message,
extra_metadata=extra_metadata
)
db.add(log)
db.commit()
db.refresh(log)
# Log to audit trail
await audit_service.log_action(
db=db,
action='data_retention_action',
resource_type='retention_log',
user_id=executed_by,
resource_id=log.id,
details={
'data_category': data_category,
'action_taken': action_taken,
'records_affected': records_affected
},
status='success' if success else 'error'
)
logger.info(f'Retention action logged: {action_taken} on {data_category} - {records_affected} records')
return log
@staticmethod
def get_retention_logs(
db: Session,
data_category: Optional[str] = None,
limit: int = 100,
offset: int = 0
) -> List[DataRetentionLog]:
"""Get retention action logs."""
query = db.query(DataRetentionLog)
if data_category:
query = query.filter(DataRetentionLog.data_category == data_category)
return query.order_by(DataRetentionLog.executed_at.desc()).offset(offset).limit(limit).all()
retention_service = RetentionService()

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session, joinedload
from typing import Optional
from datetime import datetime
import random
import secrets
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
@@ -33,7 +33,8 @@ router = APIRouter(prefix="/service-bookings", tags=["service-bookings"])
def generate_service_booking_number() -> str:
prefix = "SB"
timestamp = datetime.utcnow().strftime("%Y%m%d")
random_suffix = random.randint(1000, 9999)
# Use cryptographically secure random number to prevent enumeration attacks
random_suffix = secrets.randbelow(9000) + 1000 # Random number between 1000-9999
return f"{prefix}{timestamp}{random_suffix}"
@router.post("/")

View File

@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from datetime import datetime, timedelta, date
from typing import Optional
import random
import secrets
import string
from ..models.user_loyalty import UserLoyalty
from ..models.loyalty_tier import LoyaltyTier, TierLevel
@@ -78,19 +78,23 @@ class LoyaltyService:
@staticmethod
def generate_referral_code(db: Session, user_id: int, length: int = 8) -> str:
"""Generate unique referral code for user"""
"""Generate unique referral code for user using cryptographically secure random"""
max_attempts = 10
alphabet = string.ascii_uppercase + string.digits
for _ in range(max_attempts):
# Generate code: USER1234 format
code = f"USER{user_id:04d}{''.join(random.choices(string.ascii_uppercase + string.digits, k=length-8))}"
# Generate code: USER1234 format using cryptographically secure random
# Use secrets.choice() instead of random.choices() for security
random_part = ''.join(secrets.choice(alphabet) for _ in range(length-8))
code = f"USER{user_id:04d}{random_part}"
# Check if code exists
existing = db.query(UserLoyalty).filter(UserLoyalty.referral_code == code).first()
if not existing:
return code
# Fallback: timestamp-based
return f"REF{int(datetime.utcnow().timestamp())}{user_id}"
# Fallback: timestamp-based with secure random suffix
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
return f"REF{int(datetime.utcnow().timestamp())}{user_id}{random_suffix}"
@staticmethod
def create_default_tiers(db: Session):
@@ -340,14 +344,18 @@ class LoyaltyService:
@staticmethod
def generate_redemption_code(db: Session, length: int = 12) -> str:
"""Generate unique redemption code"""
"""Generate unique redemption code using cryptographically secure random"""
max_attempts = 10
alphabet = string.ascii_uppercase + string.digits
for _ in range(max_attempts):
code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))
# Use secrets.choice() instead of random.choices() for security
code = ''.join(secrets.choice(alphabet) for _ in range(length))
existing = db.query(RewardRedemption).filter(RewardRedemption.code == code).first()
if not existing:
return code
return f"RED{int(datetime.utcnow().timestamp())}"
# Fallback with secure random suffix
random_suffix = ''.join(secrets.choice(alphabet) for _ in range(4))
return f"RED{int(datetime.utcnow().timestamp())}{random_suffix}"
@staticmethod
def process_referral(

View File

@@ -95,10 +95,16 @@ else:
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'Allowed CORS origins: {", ".join(settings.CORS_ORIGINS)}')
app.add_middleware(CORSMiddleware, allow_origins=settings.CORS_ORIGINS or [], allow_credentials=True, allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], allow_headers=['*'])
# SECURITY: Use explicit headers instead of wildcard to prevent header injection
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS or [],
allow_credentials=True,
allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
allow_headers=['Content-Type', 'Authorization', 'X-XSRF-TOKEN', 'X-Requested-With', 'X-Request-ID', 'Accept', 'Accept-Language']
)
uploads_dir = Path(__file__).parent.parent / settings.UPLOAD_DIR
uploads_dir.mkdir(exist_ok=True)
app.mount('/uploads', StaticFiles(directory=str(uploads_dir)), name='uploads')
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(IntegrityError, integrity_error_handler)
@@ -108,18 +114,18 @@ app.add_exception_handler(Exception, general_exception_handler)
@app.get('/health', tags=['health'])
@app.get('/api/health', tags=['health'])
async def health_check(db: Session=Depends(get_db)):
"""Comprehensive health check endpoint"""
"""
Public health check endpoint.
Returns minimal information for security - no sensitive details exposed.
"""
health_status = {
'status': 'healthy',
'timestamp': datetime.utcnow().isoformat(),
'service': settings.APP_NAME,
'version': settings.APP_VERSION,
'environment': settings.ENVIRONMENT,
# SECURITY: Don't expose service name, version, or environment in public endpoint
'checks': {
'api': 'ok',
'database': 'unknown',
'disk_space': 'unknown',
'memory': 'unknown'
'database': 'unknown'
# SECURITY: Don't expose disk_space or memory details publicly
}
}
@@ -131,60 +137,26 @@ async def health_check(db: Session=Depends(get_db)):
except OperationalError as e:
health_status['status'] = 'unhealthy'
health_status['checks']['database'] = 'error'
health_status['error'] = str(e)
# SECURITY: Don't expose database error details publicly
logger.error(f'Database health check failed: {str(e)}')
# Remove error details from response
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
except Exception as e:
health_status['status'] = 'unhealthy'
health_status['checks']['database'] = 'error'
health_status['error'] = str(e)
# SECURITY: Don't expose error details publicly
logger.error(f'Health check failed: {str(e)}')
# Remove error details from response
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status)
# Check disk space (if available)
try:
import shutil
disk = shutil.disk_usage('/')
free_percent = (disk.free / disk.total) * 100
if free_percent < 10:
health_status['checks']['disk_space'] = 'warning'
health_status['status'] = 'degraded'
else:
health_status['checks']['disk_space'] = 'ok'
health_status['disk_space'] = {
'free_gb': round(disk.free / (1024**3), 2),
'total_gb': round(disk.total / (1024**3), 2),
'free_percent': round(free_percent, 2)
}
except Exception:
health_status['checks']['disk_space'] = 'unknown'
# Check memory (if available)
try:
import psutil
memory = psutil.virtual_memory()
if memory.percent > 90:
health_status['checks']['memory'] = 'warning'
if health_status['status'] == 'healthy':
health_status['status'] = 'degraded'
else:
health_status['checks']['memory'] = 'ok'
health_status['memory'] = {
'used_percent': round(memory.percent, 2),
'available_gb': round(memory.available / (1024**3), 2),
'total_gb': round(memory.total / (1024**3), 2)
}
except ImportError:
# psutil not available, skip memory check
health_status['checks']['memory'] = 'unavailable'
except Exception:
health_status['checks']['memory'] = 'unknown'
# SECURITY: Disk space and memory checks removed from public endpoint
# These details should only be available on internal/admin health endpoint
# Determine overall status
if health_status['status'] == 'healthy' and any(
check == 'warning' for check in health_status['checks'].values()
check == 'error' for check in health_status['checks'].values()
):
health_status['status'] = 'degraded'
health_status['status'] = 'unhealthy'
status_code = status.HTTP_200_OK
if health_status['status'] == 'unhealthy':
@@ -195,8 +167,110 @@ async def health_check(db: Session=Depends(get_db)):
return JSONResponse(status_code=status_code, content=health_status)
@app.get('/metrics', tags=['monitoring'])
async def metrics():
return {'status': 'success', 'service': settings.APP_NAME, 'version': settings.APP_VERSION, 'environment': settings.ENVIRONMENT, 'timestamp': datetime.utcnow().isoformat()}
async def metrics(
current_user = Depends(lambda: None)
):
"""
Protected metrics endpoint - requires admin or staff authentication.
SECURITY: Prevents information disclosure to unauthorized users.
"""
from ..security.middleware.auth import authorize_roles
# Only allow admin and staff to access metrics
# Use authorize_roles as dependency - it will check authorization automatically
admin_or_staff = authorize_roles('admin', 'staff')
# FastAPI will inject dependencies when this dependency is resolved
current_user = admin_or_staff()
return {
'status': 'success',
'service': settings.APP_NAME,
'version': settings.APP_VERSION,
'environment': settings.ENVIRONMENT,
'timestamp': datetime.utcnow().isoformat()
}
# Custom route for serving uploads with CORS headers
# This route takes precedence over the mount below
from fastapi.responses import FileResponse
import re
@app.options('/uploads/{file_path:path}')
async def serve_upload_file_options(file_path: str, request: Request):
"""Handle CORS preflight for upload files."""
origin = request.headers.get('origin')
if origin:
if settings.is_development:
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
return JSONResponse(
content={},
headers={
'Access-Control-Allow-Origin': origin,
'Access-Control-Allow-Credentials': 'true',
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
'Access-Control-Allow-Headers': '*',
'Access-Control-Max-Age': '3600'
}
)
elif origin in (settings.CORS_ORIGINS or []):
return JSONResponse(
content={},
headers={
'Access-Control-Allow-Origin': origin,
'Access-Control-Allow-Credentials': 'true',
'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
'Access-Control-Allow-Headers': '*',
'Access-Control-Max-Age': '3600'
}
)
return JSONResponse(content={})
@app.get('/uploads/{file_path:path}')
@app.head('/uploads/{file_path:path}')
async def serve_upload_file(file_path: str, request: Request):
"""Serve uploaded files with proper CORS headers."""
file_location = uploads_dir / file_path
# Security: Prevent directory traversal
try:
resolved_path = file_location.resolve()
resolved_uploads = uploads_dir.resolve()
if not str(resolved_path).startswith(str(resolved_uploads)):
raise HTTPException(status_code=403, detail="Access denied")
except (ValueError, OSError):
raise HTTPException(status_code=404, detail="File not found")
if not file_location.exists() or not file_location.is_file():
raise HTTPException(status_code=404, detail="File not found")
# Get origin from request
origin = request.headers.get('origin')
# Prepare response
response = FileResponse(str(file_location))
# Add CORS headers if origin matches
if origin:
if settings.is_development:
if re.match(r'http://(localhost|127\.0\.0\.1)(:\d+)?', origin):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
response.headers['Access-Control-Expose-Headers'] = '*'
elif origin in (settings.CORS_ORIGINS or []):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
response.headers['Access-Control-Expose-Headers'] = '*'
return response
# Mount static files as fallback (routes take precedence)
from starlette.staticfiles import StaticFiles
app.mount('/uploads-static', StaticFiles(directory=str(uploads_dir)), name='uploads-static')
# Import all route modules from feature-based structure
from .auth.routes import auth_routes, user_routes
from .rooms.routes import room_routes, advanced_room_routes, rate_plan_routes
@@ -219,6 +293,7 @@ from .security.routes import security_routes, compliance_routes
from .system.routes import system_settings_routes, workflow_routes, task_routes, approval_routes, backup_routes
from .ai.routes import ai_assistant_routes
from .compliance.routes import gdpr_routes
from .compliance.routes.gdpr_admin_routes import router as gdpr_admin_routes
from .integrations.routes import webhook_routes, api_key_routes
from .auth.routes import session_routes
@@ -274,6 +349,7 @@ app.include_router(blog_routes.router, prefix=api_prefix)
app.include_router(ai_assistant_routes.router, prefix=api_prefix)
app.include_router(approval_routes.router, prefix=api_prefix)
app.include_router(gdpr_routes.router, prefix=api_prefix)
app.include_router(gdpr_admin_routes, prefix=api_prefix)
app.include_router(webhook_routes.router, prefix=api_prefix)
app.include_router(api_key_routes.router, prefix=api_prefix)
app.include_router(session_routes.router, prefix=api_prefix)
@@ -281,57 +357,38 @@ app.include_router(backup_routes.router, prefix=api_prefix)
logger.info('All routes registered successfully')
def ensure_jwt_secret():
"""Generate and save JWT secret if it's using the default value.
In production, fail fast if default secret is used for security.
In development, auto-generate a secure secret if needed.
"""
default_secret = 'dev-secret-key-change-in-production-12345'
Validate JWT secret is properly configured.
SECURITY: JWT_SECRET must be explicitly set via environment variable.
No default values are acceptable for security.
"""
current_secret = settings.JWT_SECRET
# Security check: Fail fast in production if using default secret
if settings.is_production and (not current_secret or current_secret == default_secret):
# SECURITY: JWT_SECRET validation is now handled in settings.py
# This function is kept for backward compatibility and logging
if not current_secret or current_secret.strip() == '':
if settings.is_production:
# This should not happen as settings validation should catch it
error_msg = (
'CRITICAL SECURITY ERROR: JWT_SECRET is using default value in production! '
'Please set a secure JWT_SECRET in your environment variables.'
'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
'Please set JWT_SECRET environment variable before starting the application.'
)
logger.error(error_msg)
raise ValueError(error_msg)
# Development mode: Auto-generate if needed
if not current_secret or current_secret == default_secret:
new_secret = secrets.token_urlsafe(64)
os.environ['JWT_SECRET'] = new_secret
env_file = Path(__file__).parent.parent / '.env'
if env_file.exists():
try:
env_content = env_file.read_text(encoding='utf-8')
jwt_pattern = re.compile(r'^JWT_SECRET=.*$', re.MULTILINE)
if jwt_pattern.search(env_content):
env_content = jwt_pattern.sub(f'JWT_SECRET={new_secret}', env_content)
else:
jwt_section_pattern = re.compile(r'(# =+.*JWT.*=+.*\n)', re.IGNORECASE | re.MULTILINE)
match = jwt_section_pattern.search(env_content)
if match:
insert_pos = match.end()
env_content = env_content[:insert_pos] + f'JWT_SECRET={new_secret}\n' + env_content[insert_pos:]
else:
env_content += f'\nJWT_SECRET={new_secret}\n'
env_file.write_text(env_content, encoding='utf-8')
logger.info('✓ JWT secret generated and saved to .env file')
except Exception as e:
logger.warning(f'Could not update .env file: {e}')
logger.info(f'Generated JWT secret (add to .env manually): JWT_SECRET={new_secret}')
else:
logger.info(f'Generated JWT secret (add to .env file): JWT_SECRET={new_secret}')
logger.info('✓ Secure JWT secret generated automatically')
logger.warning(
'JWT_SECRET is not configured. Authentication will fail. '
'Set JWT_SECRET environment variable before starting the application.'
)
else:
# Validate secret strength
if len(current_secret) < 64:
if settings.is_production:
logger.warning(
f'JWT_SECRET is only {len(current_secret)} characters. '
'Recommend using at least 64 characters for production security.'
)
logger.info('✓ JWT secret is configured')
@app.on_event('startup')
@@ -375,7 +432,34 @@ async def shutdown_event():
logger.info(f'{settings.APP_NAME} shutting down gracefully')
if __name__ == '__main__':
import uvicorn
import os
import signal
import sys
from pathlib import Path
def signal_handler(sig, frame):
"""Handle Ctrl+C gracefully."""
logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
sys.exit(0)
# Register signal handler for graceful shutdown on Ctrl+C
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
base_dir = Path(__file__).parent.parent
src_dir = str(base_dir / 'src')
uvicorn.run('src.main:app', host=settings.HOST, port=settings.PORT, reload=settings.is_development, log_level=settings.LOG_LEVEL.lower(), reload_dirs=[src_dir] if settings.is_development else None, reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3'], reload_delay=0.5)
# Enable hot reload in development mode or if explicitly enabled via environment variable
use_reload = settings.is_development or os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
if use_reload:
logger.info('Hot reload enabled - server will restart on code changes')
logger.info('Press Ctrl+C to stop the server')
uvicorn.run(
'src.main:app',
host=settings.HOST,
port=settings.PORT,
reload=use_reload,
log_level=settings.LOG_LEVEL.lower(),
reload_dirs=[src_dir] if use_reload else None,
reload_excludes=['*.log', '*.pyc', '*.pyo', '*.pyd', '__pycache__', '**/__pycache__/**', '*.db', '*.sqlite', '*.sqlite3', 'venv/**', '.venv/**'],
reload_delay=0.5
)

View File

@@ -174,10 +174,13 @@ class BoricaService:
backend=default_backend()
)
# NOTE: SHA1 is required by Borica payment gateway protocol
# This is a known security trade-off required for payment gateway compatibility
# Monitor for Borica protocol updates that support stronger algorithms
signature = private_key.sign(
data.encode('utf-8'),
padding.PKCS1v15(),
hashes.SHA1()
hashes.SHA1() # nosec B303 # Required by Borica protocol - acceptable risk
)
return base64.b64encode(signature).decode('utf-8')
except Exception as e:
@@ -228,11 +231,13 @@ class BoricaService:
public_key = cert.public_key()
signature_bytes = base64.b64decode(signature)
# NOTE: SHA1 is required by Borica payment gateway protocol
# This is a known security trade-off required for payment gateway compatibility
public_key.verify(
signature_bytes,
signature_data.encode('utf-8'),
padding.PKCS1v15(),
hashes.SHA1()
hashes.SHA1() # nosec B303 # Required by Borica protocol - acceptable risk
)
return True
except Exception as e:

View File

@@ -10,7 +10,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'}
# Allow cross-origin resource sharing for uploads/images
# This is needed for images to load from different origins in development
if '/uploads/' in str(request.url):
security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
else:
security_headers.setdefault('Cross-Origin-Resource-Policy', 'same-origin')
if settings.is_production:
# Enhanced CSP with stricter directives
# Using 'strict-dynamic' for better security with nonce-based scripts

View File

@@ -10,14 +10,14 @@ class Settings(BaseSettings):
ENVIRONMENT: str = Field(default='development', description='Environment: development, staging, production')
DEBUG: bool = Field(default=False, description='Debug mode')
API_V1_PREFIX: str = Field(default='/api/v1', description='API v1 prefix')
HOST: str = Field(default='0.0.0.0', description='Server host')
HOST: str = Field(default='0.0.0.0', description='Server host. WARNING: 0.0.0.0 binds to all interfaces. Use 127.0.0.1 for development or specific IP for production.') # nosec B104 # Acceptable default with validation warning in production
PORT: int = Field(default=8000, description='Server port')
DB_USER: str = Field(default='root', description='Database user')
DB_PASS: str = Field(default='', description='Database password')
DB_NAME: str = Field(default='hotel_db', description='Database name')
DB_HOST: str = Field(default='localhost', description='Database host')
DB_PORT: str = Field(default='3306', description='Database port')
JWT_SECRET: str = Field(default='dev-secret-key-change-in-production-12345', description='JWT secret key')
JWT_SECRET: str = Field(default='', description='JWT secret key - MUST be set via environment variable. Minimum 64 characters recommended for production.')
JWT_ALGORITHM: str = Field(default='HS256', description='JWT algorithm')
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30, description='JWT access token expiration in minutes')
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=3, description='JWT refresh token expiration in days (reduced from 7 for better security)')
@@ -97,6 +97,20 @@ class Settings(BaseSettings):
IP_WHITELIST_ENABLED: bool = Field(default=False, description='Enable IP whitelisting for admin endpoints')
ADMIN_IP_WHITELIST: List[str] = Field(default_factory=list, description='List of allowed IP addresses/CIDR ranges for admin endpoints')
def validate_host_configuration(self) -> None:
"""
Validate HOST configuration for security.
Warns if binding to all interfaces (0.0.0.0) in production.
"""
if self.HOST == '0.0.0.0' and self.is_production:
import logging
logger = logging.getLogger(__name__)
logger.warning(
'SECURITY WARNING: HOST is set to 0.0.0.0 in production. '
'This binds the server to all network interfaces. '
'Consider using a specific IP address or ensure proper firewall rules are in place.'
)
def validate_encryption_key(self) -> None:
"""
Validate encryption key is properly configured.
@@ -139,3 +153,40 @@ class Settings(BaseSettings):
logger.warning(f'Invalid ENCRYPTION_KEY format: {str(e)}')
settings = Settings()
# Validate JWT_SECRET on startup - fail fast if not configured
def validate_jwt_secret():
"""Validate JWT_SECRET is properly configured. Called on startup."""
if not settings.JWT_SECRET or settings.JWT_SECRET.strip() == '':
error_msg = (
'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
'Please set JWT_SECRET environment variable to a secure random string. '
'Minimum 64 characters recommended for production. '
'Generate one using: python -c "import secrets; print(secrets.token_urlsafe(64))"'
)
import logging
logger = logging.getLogger(__name__)
logger.error(error_msg)
if settings.is_production:
raise ValueError(error_msg)
else:
logger.warning(
'JWT_SECRET not configured. This will cause authentication to fail. '
'Set JWT_SECRET environment variable before starting the application.'
)
# Warn if using weak secret (less than 64 characters)
if len(settings.JWT_SECRET) < 64:
import logging
logger = logging.getLogger(__name__)
if settings.is_production:
logger.warning(
f'JWT_SECRET is only {len(settings.JWT_SECRET)} characters. '
'Recommend using at least 64 characters for production security.'
)
else:
logger.debug(f'JWT_SECRET length: {len(settings.JWT_SECRET)} characters')
# Validate on import
validate_jwt_secret()
settings.validate_host_configuration()

View File

@@ -0,0 +1,168 @@
"""
HTML/XSS sanitization utilities using bleach library.
Prevents stored XSS attacks by sanitizing user-generated content.
"""
import bleach
from typing import Optional
# Allowed HTML tags for rich text content
ALLOWED_TAGS = [
'p', 'br', 'strong', 'em', 'u', 'b', 'i', 's', 'strike',
'a', 'ul', 'ol', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
'blockquote', 'pre', 'code', 'hr', 'div', 'span',
'table', 'thead', 'tbody', 'tr', 'th', 'td',
'img'
]
# Allowed attributes for specific tags
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'target', 'rel'],
'img': ['src', 'alt', 'title', 'width', 'height'],
'div': ['class'],
'span': ['class'],
'p': ['class'],
'table': ['class', 'border'],
'th': ['colspan', 'rowspan'],
'td': ['colspan', 'rowspan']
}
# Allowed URL schemes
ALLOWED_PROTOCOLS = ['http', 'https', 'mailto']
# Allowed CSS classes (optional - can be expanded)
ALLOWED_STYLES = []
def sanitize_html(content: Optional[str], strip: bool = False) -> str:
"""
Sanitize HTML content to prevent XSS attacks.
Args:
content: The HTML content to sanitize (can be None)
strip: If True, remove disallowed tags instead of escaping them
Returns:
Sanitized HTML string
"""
if not content:
return ''
if not isinstance(content, str):
content = str(content)
# Sanitize HTML
sanitized = bleach.clean(
content,
tags=ALLOWED_TAGS,
attributes=ALLOWED_ATTRIBUTES,
protocols=ALLOWED_PROTOCOLS,
strip=strip,
strip_comments=True
)
# Linkify URLs (convert plain URLs to links)
# Only linkify if content doesn't already contain HTML links
if '<a' not in sanitized:
sanitized = bleach.linkify(
sanitized,
protocols=ALLOWED_PROTOCOLS,
parse_email=True
)
return sanitized
def sanitize_text(content: Optional[str]) -> str:
"""
Strip all HTML tags from content, leaving only plain text.
Useful for fields that should not contain any HTML.
Args:
content: The content to sanitize (can be None)
Returns:
Plain text string with all HTML removed
"""
if not content:
return ''
if not isinstance(content, str):
content = str(content)
# Strip all HTML tags
return bleach.clean(content, tags=[], strip=True)
def sanitize_filename(filename: str) -> str:
"""
Sanitize filename to prevent path traversal and other attacks.
Args:
filename: The original filename
Returns:
Sanitized filename safe for filesystem operations
"""
import os
import secrets
from pathlib import Path
if not filename:
# Generate a random filename if none provided
return f"{secrets.token_urlsafe(16)}.bin"
# Remove path components (prevent directory traversal)
filename = os.path.basename(filename)
# Remove dangerous characters
# Keep only alphanumeric, dots, dashes, and underscores
safe_chars = []
for char in filename:
if char.isalnum() or char in '._-':
safe_chars.append(char)
else:
safe_chars.append('_')
filename = ''.join(safe_chars)
# Limit length (filesystem limit is typically 255)
if len(filename) > 255:
name, ext = os.path.splitext(filename)
max_name_length = 255 - len(ext)
filename = name[:max_name_length] + ext
# Ensure filename is not empty
if not filename or filename == '.' or filename == '..':
filename = f"{secrets.token_urlsafe(16)}.bin"
return filename
def sanitize_url(url: Optional[str]) -> Optional[str]:
"""
Sanitize URL to ensure it uses allowed protocols.
Args:
url: The URL to sanitize
Returns:
Sanitized URL or None if invalid
"""
if not url:
return None
if not isinstance(url, str):
url = str(url)
# Check if URL uses allowed protocol
url_lower = url.lower().strip()
if any(url_lower.startswith(proto + ':') for proto in ALLOWED_PROTOCOLS):
return url
# If no protocol, assume https
if '://' not in url:
return f'https://{url}'
# Invalid protocol - return None
return None

7
Backend/venv/bin/bandit Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.main import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.baseline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from bandit.cli.config_generator import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/doesitcache Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from cachecontrol._cmd import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/fastapi Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from fastapi.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/markdown-it Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from markdown_it.cli.parse import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/nltk Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from nltk.cli import cli
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(cli())

7
Backend/venv/bin/pip-audit Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pip_audit._cli import audit
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(audit())

7
Backend/venv/bin/safety Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from safety.cli import cli
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(cli())

7
Backend/venv/bin/tqdm Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from tqdm.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/typer Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from typer.cli import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -1,59 +0,0 @@
Jinja2-3.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
Jinja2-3.1.2.dist-info/LICENSE.rst,sha256=O0nc7kEF6ze6wQ-vG-JgQI_oXSUrjp3y4JefweCUQ3s,1475
Jinja2-3.1.2.dist-info/METADATA,sha256=PZ6v2SIidMNixR7MRUX9f7ZWsPwtXanknqiZUmRbh4U,3539
Jinja2-3.1.2.dist-info/RECORD,,
Jinja2-3.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
Jinja2-3.1.2.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
Jinja2-3.1.2.dist-info/entry_points.txt,sha256=zRd62fbqIyfUpsRtU7EVIFyiu1tPwfgO7EvPErnxgTE,59
Jinja2-3.1.2.dist-info/top_level.txt,sha256=PkeVWtLb3-CqjWi1fO29OCbj55EhX_chhKrCdrVe_zs,7
jinja2/__init__.py,sha256=8vGduD8ytwgD6GDSqpYc2m3aU-T7PKOAddvVXgGr_Fs,1927
jinja2/__pycache__/__init__.cpython-312.pyc,,
jinja2/__pycache__/_identifier.cpython-312.pyc,,
jinja2/__pycache__/async_utils.cpython-312.pyc,,
jinja2/__pycache__/bccache.cpython-312.pyc,,
jinja2/__pycache__/compiler.cpython-312.pyc,,
jinja2/__pycache__/constants.cpython-312.pyc,,
jinja2/__pycache__/debug.cpython-312.pyc,,
jinja2/__pycache__/defaults.cpython-312.pyc,,
jinja2/__pycache__/environment.cpython-312.pyc,,
jinja2/__pycache__/exceptions.cpython-312.pyc,,
jinja2/__pycache__/ext.cpython-312.pyc,,
jinja2/__pycache__/filters.cpython-312.pyc,,
jinja2/__pycache__/idtracking.cpython-312.pyc,,
jinja2/__pycache__/lexer.cpython-312.pyc,,
jinja2/__pycache__/loaders.cpython-312.pyc,,
jinja2/__pycache__/meta.cpython-312.pyc,,
jinja2/__pycache__/nativetypes.cpython-312.pyc,,
jinja2/__pycache__/nodes.cpython-312.pyc,,
jinja2/__pycache__/optimizer.cpython-312.pyc,,
jinja2/__pycache__/parser.cpython-312.pyc,,
jinja2/__pycache__/runtime.cpython-312.pyc,,
jinja2/__pycache__/sandbox.cpython-312.pyc,,
jinja2/__pycache__/tests.cpython-312.pyc,,
jinja2/__pycache__/utils.cpython-312.pyc,,
jinja2/__pycache__/visitor.cpython-312.pyc,,
jinja2/_identifier.py,sha256=_zYctNKzRqlk_murTNlzrju1FFJL7Va_Ijqqd7ii2lU,1958
jinja2/async_utils.py,sha256=dHlbTeaxFPtAOQEYOGYh_PHcDT0rsDaUJAFDl_0XtTg,2472
jinja2/bccache.py,sha256=mhz5xtLxCcHRAa56azOhphIAe19u1we0ojifNMClDio,14061
jinja2/compiler.py,sha256=Gs-N8ThJ7OWK4-reKoO8Wh1ZXz95MVphBKNVf75qBr8,72172
jinja2/constants.py,sha256=GMoFydBF_kdpaRKPoM5cl5MviquVRLVyZtfp5-16jg0,1433
jinja2/debug.py,sha256=iWJ432RadxJNnaMOPrjIDInz50UEgni3_HKuFXi2vuQ,6299
jinja2/defaults.py,sha256=boBcSw78h-lp20YbaXSJsqkAI2uN_mD_TtCydpeq5wU,1267
jinja2/environment.py,sha256=6uHIcc7ZblqOMdx_uYNKqRnnwAF0_nzbyeMP9FFtuh4,61349
jinja2/exceptions.py,sha256=ioHeHrWwCWNaXX1inHmHVblvc4haO7AXsjCp3GfWvx0,5071
jinja2/ext.py,sha256=ivr3P7LKbddiXDVez20EflcO3q2aHQwz9P_PgWGHVqE,31502
jinja2/filters.py,sha256=9js1V-h2RlyW90IhLiBGLM2U-k6SCy2F4BUUMgB3K9Q,53509
jinja2/idtracking.py,sha256=GfNmadir4oDALVxzn3DL9YInhJDr69ebXeA2ygfuCGA,10704
jinja2/lexer.py,sha256=DW2nX9zk-6MWp65YR2bqqj0xqCvLtD-u9NWT8AnFRxQ,29726
jinja2/loaders.py,sha256=BfptfvTVpClUd-leMkHczdyPNYFzp_n7PKOJ98iyHOg,23207
jinja2/meta.py,sha256=GNPEvifmSaU3CMxlbheBOZjeZ277HThOPUTf1RkppKQ,4396
jinja2/nativetypes.py,sha256=DXgORDPRmVWgy034H0xL8eF7qYoK3DrMxs-935d0Fzk,4226
jinja2/nodes.py,sha256=i34GPRAZexXMT6bwuf5SEyvdmS-bRCy9KMjwN5O6pjk,34550
jinja2/optimizer.py,sha256=tHkMwXxfZkbfA1KmLcqmBMSaz7RLIvvItrJcPoXTyD8,1650
jinja2/parser.py,sha256=nHd-DFHbiygvfaPtm9rcQXJChZG7DPsWfiEsqfwKerY,39595
jinja2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
jinja2/runtime.py,sha256=5CmD5BjbEJxSiDNTFBeKCaq8qU4aYD2v6q2EluyExms,33476
jinja2/sandbox.py,sha256=Y0xZeXQnH6EX5VjaV2YixESxoepnRbW_3UeQosaBU3M,14584
jinja2/tests.py,sha256=Am5Z6Lmfr2XaH_npIfJJ8MdXtWsbLjMULZJulTAj30E,5905
jinja2/utils.py,sha256=u9jXESxGn8ATZNVolwmkjUVu4SA-tLgV0W7PcSfPfdQ,23965
jinja2/visitor.py,sha256=MH14C6yq24G_KVtWzjwaI7Wg14PCJIYlWW1kpkxYak0,3568

View File

@@ -1,2 +0,0 @@
[babel.extractors]
jinja2 = jinja2.ext:babel_extract[i18n]

View File

@@ -0,0 +1,291 @@
from __future__ import annotations
import os
from io import BytesIO
from typing import IO
from . import ExifTags, Image, ImageFile
try:
from . import _avif
SUPPORTED = True
except ImportError:
SUPPORTED = False
# Decoder options as module globals, until there is a way to pass parameters
# to Image.open (see https://github.com/python-pillow/Pillow/issues/569)
DECODE_CODEC_CHOICE = "auto"
DEFAULT_MAX_THREADS = 0
def get_codec_version(codec_name: str) -> str | None:
versions = _avif.codec_versions()
for version in versions.split(", "):
if version.split(" [")[0] == codec_name:
return version.split(":")[-1].split(" ")[0]
return None
def _accept(prefix: bytes) -> bool | str:
if prefix[4:8] != b"ftyp":
return False
major_brand = prefix[8:12]
if major_brand in (
# coding brands
b"avif",
b"avis",
# We accept files with AVIF container brands; we can't yet know if
# the ftyp box has the correct compatible brands, but if it doesn't
# then the plugin will raise a SyntaxError which Pillow will catch
# before moving on to the next plugin that accepts the file.
#
# Also, because this file might not actually be an AVIF file, we
# don't raise an error if AVIF support isn't properly compiled.
b"mif1",
b"msf1",
):
if not SUPPORTED:
return (
"image file could not be identified because AVIF support not installed"
)
return True
return False
def _get_default_max_threads() -> int:
if DEFAULT_MAX_THREADS:
return DEFAULT_MAX_THREADS
if hasattr(os, "sched_getaffinity"):
return len(os.sched_getaffinity(0))
else:
return os.cpu_count() or 1
class AvifImageFile(ImageFile.ImageFile):
format = "AVIF"
format_description = "AVIF image"
__frame = -1
def _open(self) -> None:
if not SUPPORTED:
msg = "image file could not be opened because AVIF support not installed"
raise SyntaxError(msg)
if DECODE_CODEC_CHOICE != "auto" and not _avif.decoder_codec_available(
DECODE_CODEC_CHOICE
):
msg = "Invalid opening codec"
raise ValueError(msg)
self._decoder = _avif.AvifDecoder(
self.fp.read(),
DECODE_CODEC_CHOICE,
_get_default_max_threads(),
)
# Get info from decoder
self._size, self.n_frames, self._mode, icc, exif, exif_orientation, xmp = (
self._decoder.get_info()
)
self.is_animated = self.n_frames > 1
if icc:
self.info["icc_profile"] = icc
if xmp:
self.info["xmp"] = xmp
if exif_orientation != 1 or exif:
exif_data = Image.Exif()
if exif:
exif_data.load(exif)
original_orientation = exif_data.get(ExifTags.Base.Orientation, 1)
else:
original_orientation = 1
if exif_orientation != original_orientation:
exif_data[ExifTags.Base.Orientation] = exif_orientation
exif = exif_data.tobytes()
if exif:
self.info["exif"] = exif
self.seek(0)
def seek(self, frame: int) -> None:
if not self._seek_check(frame):
return
# Set tile
self.__frame = frame
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
def load(self) -> Image.core.PixelAccess | None:
if self.tile:
# We need to load the image data for this frame
data, timescale, pts_in_timescales, duration_in_timescales = (
self._decoder.get_frame(self.__frame)
)
self.info["timestamp"] = round(1000 * (pts_in_timescales / timescale))
self.info["duration"] = round(1000 * (duration_in_timescales / timescale))
if self.fp and self._exclusive_fp:
self.fp.close()
self.fp = BytesIO(data)
return super().load()
def load_seek(self, pos: int) -> None:
pass
def tell(self) -> int:
return self.__frame
def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, save_all=True)
def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
) -> None:
info = im.encoderinfo.copy()
if save_all:
append_images = list(info.get("append_images", []))
else:
append_images = []
total = 0
for ims in [im] + append_images:
total += getattr(ims, "n_frames", 1)
quality = info.get("quality", 75)
if not isinstance(quality, int) or quality < 0 or quality > 100:
msg = "Invalid quality setting"
raise ValueError(msg)
duration = info.get("duration", 0)
subsampling = info.get("subsampling", "4:2:0")
speed = info.get("speed", 6)
max_threads = info.get("max_threads", _get_default_max_threads())
codec = info.get("codec", "auto")
if codec != "auto" and not _avif.encoder_codec_available(codec):
msg = "Invalid saving codec"
raise ValueError(msg)
range_ = info.get("range", "full")
tile_rows_log2 = info.get("tile_rows", 0)
tile_cols_log2 = info.get("tile_cols", 0)
alpha_premultiplied = bool(info.get("alpha_premultiplied", False))
autotiling = bool(info.get("autotiling", tile_rows_log2 == tile_cols_log2 == 0))
icc_profile = info.get("icc_profile", im.info.get("icc_profile"))
exif_orientation = 1
if exif := info.get("exif"):
if isinstance(exif, Image.Exif):
exif_data = exif
else:
exif_data = Image.Exif()
exif_data.load(exif)
if ExifTags.Base.Orientation in exif_data:
exif_orientation = exif_data.pop(ExifTags.Base.Orientation)
exif = exif_data.tobytes() if exif_data else b""
elif isinstance(exif, Image.Exif):
exif = exif_data.tobytes()
xmp = info.get("xmp")
if isinstance(xmp, str):
xmp = xmp.encode("utf-8")
advanced = info.get("advanced")
if advanced is not None:
if isinstance(advanced, dict):
advanced = advanced.items()
try:
advanced = tuple(advanced)
except TypeError:
invalid = True
else:
invalid = any(not isinstance(v, tuple) or len(v) != 2 for v in advanced)
if invalid:
msg = (
"advanced codec options must be a dict of key-value string "
"pairs or a series of key-value two-tuples"
)
raise ValueError(msg)
# Setup the AVIF encoder
enc = _avif.AvifEncoder(
im.size,
subsampling,
quality,
speed,
max_threads,
codec,
range_,
tile_rows_log2,
tile_cols_log2,
alpha_premultiplied,
autotiling,
icc_profile or b"",
exif or b"",
exif_orientation,
xmp or b"",
advanced,
)
# Add each frame
frame_idx = 0
frame_duration = 0
cur_idx = im.tell()
is_single_frame = total == 1
try:
for ims in [im] + append_images:
# Get number of frames in this image
nfr = getattr(ims, "n_frames", 1)
for idx in range(nfr):
ims.seek(idx)
# Make sure image mode is supported
frame = ims
rawmode = ims.mode
if ims.mode not in {"RGB", "RGBA"}:
rawmode = "RGBA" if ims.has_transparency_data else "RGB"
frame = ims.convert(rawmode)
# Update frame duration
if isinstance(duration, (list, tuple)):
frame_duration = duration[frame_idx]
else:
frame_duration = duration
# Append the frame to the animation encoder
enc.add(
frame.tobytes("raw", rawmode),
frame_duration,
frame.size,
rawmode,
is_single_frame,
)
# Update frame index
frame_idx += 1
if not save_all:
break
finally:
im.seek(cur_idx)
# Get the final output from the encoder
data = enc.finish()
if data is None:
msg = "cannot write file as AVIF (encoder returned None)"
raise OSError(msg)
fp.write(data)
Image.register_open(AvifImageFile.format, AvifImageFile, _accept)
if SUPPORTED:
Image.register_save(AvifImageFile.format, _save)
Image.register_save_all(AvifImageFile.format, _save_all)
Image.register_extensions(AvifImageFile.format, [".avif", ".avifs"])
Image.register_mime(AvifImageFile.format, "image/avif")

View File

@@ -20,29 +20,30 @@
"""
Parse X Bitmap Distribution Format (BDF)
"""
from __future__ import annotations
from typing import BinaryIO
from . import FontFile, Image
bdf_slant = {
"R": "Roman",
"I": "Italic",
"O": "Oblique",
"RI": "Reverse Italic",
"RO": "Reverse Oblique",
"OT": "Other",
}
bdf_spacing = {"P": "Proportional", "M": "Monospaced", "C": "Cell"}
def bdf_char(f):
def bdf_char(
f: BinaryIO,
) -> (
tuple[
str,
int,
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]],
Image.Image,
]
| None
):
# skip to STARTCHAR
while True:
s = f.readline()
if not s:
return None
if s[:9] == b"STARTCHAR":
if s.startswith(b"STARTCHAR"):
break
id = s[9:].strip().decode("ascii")
@@ -50,19 +51,18 @@ def bdf_char(f):
props = {}
while True:
s = f.readline()
if not s or s[:6] == b"BITMAP":
if not s or s.startswith(b"BITMAP"):
break
i = s.find(b" ")
props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")
# load bitmap
bitmap = []
bitmap = bytearray()
while True:
s = f.readline()
if not s or s[:7] == b"ENDCHAR":
if not s or s.startswith(b"ENDCHAR"):
break
bitmap.append(s[:-1])
bitmap = b"".join(bitmap)
bitmap += s[:-1]
# The word BBX
# followed by the width in x (BBw), height in y (BBh),
@@ -92,11 +92,11 @@ def bdf_char(f):
class BdfFontFile(FontFile.FontFile):
"""Font file plugin for the X11 BDF format."""
def __init__(self, fp):
def __init__(self, fp: BinaryIO) -> None:
super().__init__()
s = fp.readline()
if s[:13] != b"STARTFONT 2.1":
if not s.startswith(b"STARTFONT 2.1"):
msg = "not a valid BDF file"
raise SyntaxError(msg)
@@ -105,7 +105,7 @@ class BdfFontFile(FontFile.FontFile):
while True:
s = fp.readline()
if not s or s[:13] == b"ENDPROPERTIES":
if not s or s.startswith(b"ENDPROPERTIES"):
break
i = s.find(b" ")
props[s[:i].decode("ascii")] = s[i + 1 : -1].decode("ascii")

View File

@@ -29,10 +29,14 @@ BLP files come in many different flavours:
- DXT5 compression is used if alpha_encoding == 7.
"""
from __future__ import annotations
import abc
import os
import struct
from enum import IntEnum
from io import BytesIO
from typing import IO
from . import Image, ImageFile
@@ -53,11 +57,13 @@ class AlphaEncoding(IntEnum):
DXT5 = 7
def unpack_565(i):
def unpack_565(i: int) -> tuple[int, int, int]:
return ((i >> 11) & 0x1F) << 3, ((i >> 5) & 0x3F) << 2, (i & 0x1F) << 3
def decode_dxt1(data, alpha=False):
def decode_dxt1(
data: bytes, alpha: bool = False
) -> tuple[bytearray, bytearray, bytearray, bytearray]:
"""
input: one "row" of data (i.e. will produce 4*width pixels)
"""
@@ -65,9 +71,9 @@ def decode_dxt1(data, alpha=False):
blocks = len(data) // 8 # number of blocks in row
ret = (bytearray(), bytearray(), bytearray(), bytearray())
for block in range(blocks):
for block_index in range(blocks):
# Decode next 8-byte block.
idx = block * 8
idx = block_index * 8
color0, color1, bits = struct.unpack_from("<HHI", data, idx)
r0, g0, b0 = unpack_565(color0)
@@ -112,7 +118,7 @@ def decode_dxt1(data, alpha=False):
return ret
def decode_dxt3(data):
def decode_dxt3(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
"""
input: one "row" of data (i.e. will produce 4*width pixels)
"""
@@ -120,8 +126,8 @@ def decode_dxt3(data):
blocks = len(data) // 16 # number of blocks in row
ret = (bytearray(), bytearray(), bytearray(), bytearray())
for block in range(blocks):
idx = block * 16
for block_index in range(blocks):
idx = block_index * 16
block = data[idx : idx + 16]
# Decode next 16-byte block.
bits = struct.unpack_from("<8B", block)
@@ -165,7 +171,7 @@ def decode_dxt3(data):
return ret
def decode_dxt5(data):
def decode_dxt5(data: bytes) -> tuple[bytearray, bytearray, bytearray, bytearray]:
"""
input: one "row" of data (i.e. will produce 4 * width pixels)
"""
@@ -173,8 +179,8 @@ def decode_dxt5(data):
blocks = len(data) // 16 # number of blocks in row
ret = (bytearray(), bytearray(), bytearray(), bytearray())
for block in range(blocks):
idx = block * 16
for block_index in range(blocks):
idx = block_index * 16
block = data[idx : idx + 16]
# Decode next 16-byte block.
a0, a1 = struct.unpack_from("<BB", block)
@@ -239,8 +245,8 @@ class BLPFormatError(NotImplementedError):
pass
def _accept(prefix):
return prefix[:4] in (b"BLP1", b"BLP2")
def _accept(prefix: bytes) -> bool:
return prefix.startswith((b"BLP1", b"BLP2"))
class BlpImageFile(ImageFile.ImageFile):
@@ -251,60 +257,65 @@ class BlpImageFile(ImageFile.ImageFile):
format = "BLP"
format_description = "Blizzard Mipmap Format"
def _open(self):
def _open(self) -> None:
self.magic = self.fp.read(4)
self.fp.seek(5, os.SEEK_CUR)
(self._blp_alpha_depth,) = struct.unpack("<b", self.fp.read(1))
self.fp.seek(2, os.SEEK_CUR)
self._size = struct.unpack("<II", self.fp.read(8))
if self.magic in (b"BLP1", b"BLP2"):
decoder = self.magic.decode()
else:
if not _accept(self.magic):
msg = f"Bad BLP magic {repr(self.magic)}"
raise BLPFormatError(msg)
self._mode = "RGBA" if self._blp_alpha_depth else "RGB"
self.tile = [(decoder, (0, 0) + self.size, 0, (self.mode, 0, 1))]
compression = struct.unpack("<i", self.fp.read(4))[0]
if self.magic == b"BLP1":
alpha = struct.unpack("<I", self.fp.read(4))[0] != 0
else:
encoding = struct.unpack("<b", self.fp.read(1))[0]
alpha = struct.unpack("<b", self.fp.read(1))[0] != 0
alpha_encoding = struct.unpack("<b", self.fp.read(1))[0]
self.fp.seek(1, os.SEEK_CUR) # mips
self._size = struct.unpack("<II", self.fp.read(8))
args: tuple[int, int, bool] | tuple[int, int, bool, int]
if self.magic == b"BLP1":
encoding = struct.unpack("<i", self.fp.read(4))[0]
self.fp.seek(4, os.SEEK_CUR) # subtype
args = (compression, encoding, alpha)
offset = 28
else:
args = (compression, encoding, alpha, alpha_encoding)
offset = 20
decoder = self.magic.decode()
self._mode = "RGBA" if alpha else "RGB"
self.tile = [ImageFile._Tile(decoder, (0, 0) + self.size, offset, args)]
class _BLPBaseDecoder(ImageFile.PyDecoder):
class _BLPBaseDecoder(abc.ABC, ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer):
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
try:
self._read_blp_header()
self._read_header()
self._load()
except struct.error as e:
msg = "Truncated BLP file"
raise OSError(msg) from e
return -1, 0
def _read_blp_header(self):
self.fd.seek(4)
(self._blp_compression,) = struct.unpack("<i", self._safe_read(4))
@abc.abstractmethod
def _load(self) -> None:
pass
(self._blp_encoding,) = struct.unpack("<b", self._safe_read(1))
(self._blp_alpha_depth,) = struct.unpack("<b", self._safe_read(1))
(self._blp_alpha_encoding,) = struct.unpack("<b", self._safe_read(1))
self.fd.seek(1, os.SEEK_CUR) # mips
def _read_header(self) -> None:
self._offsets = struct.unpack("<16I", self._safe_read(16 * 4))
self._lengths = struct.unpack("<16I", self._safe_read(16 * 4))
self.size = struct.unpack("<II", self._safe_read(8))
if isinstance(self, BLP1Decoder):
# Only present for BLP1
(self._blp_encoding,) = struct.unpack("<i", self._safe_read(4))
self.fd.seek(4, os.SEEK_CUR) # subtype
self._blp_offsets = struct.unpack("<16I", self._safe_read(16 * 4))
self._blp_lengths = struct.unpack("<16I", self._safe_read(16 * 4))
def _safe_read(self, length):
def _safe_read(self, length: int) -> bytes:
assert self.fd is not None
return ImageFile._safe_read(self.fd, length)
def _read_palette(self):
def _read_palette(self) -> list[tuple[int, int, int, int]]:
ret = []
for i in range(256):
try:
@@ -314,110 +325,115 @@ class _BLPBaseDecoder(ImageFile.PyDecoder):
ret.append((b, g, r, a))
return ret
def _read_bgra(self, palette):
def _read_bgra(
self, palette: list[tuple[int, int, int, int]], alpha: bool
) -> bytearray:
data = bytearray()
_data = BytesIO(self._safe_read(self._blp_lengths[0]))
_data = BytesIO(self._safe_read(self._lengths[0]))
while True:
try:
(offset,) = struct.unpack("<B", _data.read(1))
except struct.error:
break
b, g, r, a = palette[offset]
d = (r, g, b)
if self._blp_alpha_depth:
d: tuple[int, ...] = (r, g, b)
if alpha:
d += (a,)
data.extend(d)
return data
class BLP1Decoder(_BLPBaseDecoder):
def _load(self):
if self._blp_compression == Format.JPEG:
def _load(self) -> None:
self._compression, self._encoding, alpha = self.args
if self._compression == Format.JPEG:
self._decode_jpeg_stream()
elif self._blp_compression == 1:
if self._blp_encoding in (4, 5):
elif self._compression == 1:
if self._encoding in (4, 5):
palette = self._read_palette()
data = self._read_bgra(palette)
self.set_as_raw(bytes(data))
data = self._read_bgra(palette, alpha)
self.set_as_raw(data)
else:
msg = f"Unsupported BLP encoding {repr(self._blp_encoding)}"
msg = f"Unsupported BLP encoding {repr(self._encoding)}"
raise BLPFormatError(msg)
else:
msg = f"Unsupported BLP compression {repr(self._blp_encoding)}"
msg = f"Unsupported BLP compression {repr(self._encoding)}"
raise BLPFormatError(msg)
def _decode_jpeg_stream(self):
def _decode_jpeg_stream(self) -> None:
from .JpegImagePlugin import JpegImageFile
(jpeg_header_size,) = struct.unpack("<I", self._safe_read(4))
jpeg_header = self._safe_read(jpeg_header_size)
self._safe_read(self._blp_offsets[0] - self.fd.tell()) # What IS this?
data = self._safe_read(self._blp_lengths[0])
assert self.fd is not None
self._safe_read(self._offsets[0] - self.fd.tell()) # What IS this?
data = self._safe_read(self._lengths[0])
data = jpeg_header + data
data = BytesIO(data)
image = JpegImageFile(data)
image = JpegImageFile(BytesIO(data))
Image._decompression_bomb_check(image.size)
if image.mode == "CMYK":
decoder_name, extents, offset, args = image.tile[0]
image.tile = [(decoder_name, extents, offset, (args[0], "CMYK"))]
r, g, b = image.convert("RGB").split()
image = Image.merge("RGB", (b, g, r))
self.set_as_raw(image.tobytes())
args = image.tile[0].args
assert isinstance(args, tuple)
image.tile = [image.tile[0]._replace(args=(args[0], "CMYK"))]
self.set_as_raw(image.convert("RGB").tobytes(), "BGR")
class BLP2Decoder(_BLPBaseDecoder):
def _load(self):
def _load(self) -> None:
self._compression, self._encoding, alpha, self._alpha_encoding = self.args
palette = self._read_palette()
self.fd.seek(self._blp_offsets[0])
assert self.fd is not None
self.fd.seek(self._offsets[0])
if self._blp_compression == 1:
if self._compression == 1:
# Uncompressed or DirectX compression
if self._blp_encoding == Encoding.UNCOMPRESSED:
data = self._read_bgra(palette)
if self._encoding == Encoding.UNCOMPRESSED:
data = self._read_bgra(palette, alpha)
elif self._blp_encoding == Encoding.DXT:
elif self._encoding == Encoding.DXT:
data = bytearray()
if self._blp_alpha_encoding == AlphaEncoding.DXT1:
linesize = (self.size[0] + 3) // 4 * 8
for yb in range((self.size[1] + 3) // 4):
for d in decode_dxt1(
self._safe_read(linesize), alpha=bool(self._blp_alpha_depth)
):
if self._alpha_encoding == AlphaEncoding.DXT1:
linesize = (self.state.xsize + 3) // 4 * 8
for yb in range((self.state.ysize + 3) // 4):
for d in decode_dxt1(self._safe_read(linesize), alpha):
data += d
elif self._blp_alpha_encoding == AlphaEncoding.DXT3:
linesize = (self.size[0] + 3) // 4 * 16
for yb in range((self.size[1] + 3) // 4):
elif self._alpha_encoding == AlphaEncoding.DXT3:
linesize = (self.state.xsize + 3) // 4 * 16
for yb in range((self.state.ysize + 3) // 4):
for d in decode_dxt3(self._safe_read(linesize)):
data += d
elif self._blp_alpha_encoding == AlphaEncoding.DXT5:
linesize = (self.size[0] + 3) // 4 * 16
for yb in range((self.size[1] + 3) // 4):
elif self._alpha_encoding == AlphaEncoding.DXT5:
linesize = (self.state.xsize + 3) // 4 * 16
for yb in range((self.state.ysize + 3) // 4):
for d in decode_dxt5(self._safe_read(linesize)):
data += d
else:
msg = f"Unsupported alpha encoding {repr(self._blp_alpha_encoding)}"
msg = f"Unsupported alpha encoding {repr(self._alpha_encoding)}"
raise BLPFormatError(msg)
else:
msg = f"Unknown BLP encoding {repr(self._blp_encoding)}"
msg = f"Unknown BLP encoding {repr(self._encoding)}"
raise BLPFormatError(msg)
else:
msg = f"Unknown BLP compression {repr(self._blp_compression)}"
msg = f"Unknown BLP compression {repr(self._compression)}"
raise BLPFormatError(msg)
self.set_as_raw(bytes(data))
self.set_as_raw(data)
class BLPEncoder(ImageFile.PyEncoder):
_pushes_fd = True
def _write_palette(self):
def _write_palette(self) -> bytes:
data = b""
assert self.im is not None
palette = self.im.getpalette("RGBA", "RGBA")
for i in range(len(palette) // 4):
r, g, b, a = palette[i * 4 : (i + 1) * 4]
@@ -426,12 +442,13 @@ class BLPEncoder(ImageFile.PyEncoder):
data += b"\x00" * 4
return data
def encode(self, bufsize):
def encode(self, bufsize: int) -> tuple[int, int, bytes]:
palette_data = self._write_palette()
offset = 20 + 16 * 4 * 2 + len(palette_data)
data = struct.pack("<16I", offset, *((0,) * 15))
assert self.im is not None
w, h = self.im.size
data += struct.pack("<16I", w * h, *((0,) * 15))
@@ -444,7 +461,7 @@ class BLPEncoder(ImageFile.PyEncoder):
return len(data), 0, data
def _save(im, fp, filename):
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if im.mode != "P":
msg = "Unsupported BLP image mode"
raise ValueError(msg)
@@ -452,9 +469,15 @@ def _save(im, fp, filename):
magic = b"BLP1" if im.encoderinfo.get("blp_version") == "BLP1" else b"BLP2"
fp.write(magic)
assert im.palette is not None
fp.write(struct.pack("<i", 1)) # Uncompressed or DirectX compression
alpha_depth = 1 if im.palette.mode == "RGBA" else 0
if magic == b"BLP1":
fp.write(struct.pack("<L", alpha_depth))
else:
fp.write(struct.pack("<b", Encoding.UNCOMPRESSED))
fp.write(struct.pack("<b", 1 if im.palette.mode == "RGBA" else 0))
fp.write(struct.pack("<b", alpha_depth))
fp.write(struct.pack("<b", 0)) # alpha encoding
fp.write(struct.pack("<b", 0)) # mips
fp.write(struct.pack("<II", *im.size))
@@ -462,7 +485,7 @@ def _save(im, fp, filename):
fp.write(struct.pack("<i", 5))
fp.write(struct.pack("<i", 0))
ImageFile._save(im, fp, [("BLP", (0, 0) + im.size, 0, im.mode)])
ImageFile._save(im, fp, [ImageFile._Tile("BLP", (0, 0) + im.size, 0, im.mode)])
Image.register_open(BlpImageFile.format, BlpImageFile, _accept)

View File

@@ -22,9 +22,10 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import os
from typing import IO, Any
from . import Image, ImageFile, ImagePalette
from ._binary import i16le as i16
@@ -47,13 +48,15 @@ BIT2MODE = {
32: ("RGB", "BGRX"),
}
def _accept(prefix):
return prefix[:2] == b"BM"
USE_RAW_ALPHA = False
def _dib_accept(prefix):
return i32(prefix) in [12, 40, 64, 108, 124]
def _accept(prefix: bytes) -> bool:
return prefix.startswith(b"BM")
def _dib_accept(prefix: bytes) -> bool:
return i32(prefix) in [12, 40, 52, 56, 64, 108, 124]
# =============================================================================
@@ -71,31 +74,41 @@ class BmpImageFile(ImageFile.ImageFile):
for k, v in COMPRESSIONS.items():
vars()[k] = v
def _bitmap(self, header=0, offset=0):
def _bitmap(self, header: int = 0, offset: int = 0) -> None:
"""Read relevant info about the BMP"""
read, seek = self.fp.read, self.fp.seek
if header:
seek(header)
# read bmp header size @offset 14 (this is part of the header size)
file_info = {"header_size": i32(read(4)), "direction": -1}
file_info: dict[str, bool | int | tuple[int, ...]] = {
"header_size": i32(read(4)),
"direction": -1,
}
# -------------------- If requested, read header at a specific position
# read the rest of the bmp header, without its size
assert isinstance(file_info["header_size"], int)
header_data = ImageFile._safe_read(self.fp, file_info["header_size"] - 4)
# -------------------------------------------------- IBM OS/2 Bitmap v1
# ------------------------------- Windows Bitmap v2, IBM OS/2 Bitmap v1
# ----- This format has different offsets because of width/height types
# 12: BITMAPCOREHEADER/OS21XBITMAPHEADER
if file_info["header_size"] == 12:
file_info["width"] = i16(header_data, 0)
file_info["height"] = i16(header_data, 2)
file_info["planes"] = i16(header_data, 4)
file_info["bits"] = i16(header_data, 6)
file_info["compression"] = self.RAW
file_info["compression"] = self.COMPRESSIONS["RAW"]
file_info["palette_padding"] = 3
# --------------------------------------------- Windows Bitmap v2 to v5
# v3, OS/2 v2, v4, v5
elif file_info["header_size"] in (40, 64, 108, 124):
# --------------------------------------------- Windows Bitmap v3 to v5
# 40: BITMAPINFOHEADER
# 52: BITMAPV2HEADER
# 56: BITMAPV3HEADER
# 64: BITMAPCOREHEADER2/OS22XBITMAPHEADER
# 108: BITMAPV4HEADER
# 124: BITMAPV5HEADER
elif file_info["header_size"] in (40, 52, 56, 64, 108, 124):
file_info["y_flip"] = header_data[7] == 0xFF
file_info["direction"] = 1 if file_info["y_flip"] else -1
file_info["width"] = i32(header_data, 0)
@@ -115,12 +128,16 @@ class BmpImageFile(ImageFile.ImageFile):
)
file_info["colors"] = i32(header_data, 28)
file_info["palette_padding"] = 4
assert isinstance(file_info["pixels_per_meter"], tuple)
self.info["dpi"] = tuple(x / 39.3701 for x in file_info["pixels_per_meter"])
if file_info["compression"] == self.BITFIELDS:
if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
masks = ["r_mask", "g_mask", "b_mask"]
if len(header_data) >= 48:
if len(header_data) >= 52:
for idx, mask in enumerate(
["r_mask", "g_mask", "b_mask", "a_mask"]
):
masks.append("a_mask")
else:
file_info["a_mask"] = 0x0
for idx, mask in enumerate(masks):
file_info[mask] = i32(header_data, 36 + idx * 4)
else:
# 40 byte headers only have the three components in the
@@ -132,8 +149,12 @@ class BmpImageFile(ImageFile.ImageFile):
# location, but it is listed as a reserved component,
# and it is not generally an alpha channel
file_info["a_mask"] = 0x0
for mask in ["r_mask", "g_mask", "b_mask"]:
for mask in masks:
file_info[mask] = i32(read(4))
assert isinstance(file_info["r_mask"], int)
assert isinstance(file_info["g_mask"], int)
assert isinstance(file_info["b_mask"], int)
assert isinstance(file_info["a_mask"], int)
file_info["rgb_mask"] = (
file_info["r_mask"],
file_info["g_mask"],
@@ -151,33 +172,39 @@ class BmpImageFile(ImageFile.ImageFile):
# ------------------ Special case : header is reported 40, which
# ---------------------- is shorter than real size for bpp >= 16
assert isinstance(file_info["width"], int)
assert isinstance(file_info["height"], int)
self._size = file_info["width"], file_info["height"]
# ------- If color count was not found in the header, compute from bits
assert isinstance(file_info["bits"], int)
file_info["colors"] = (
file_info["colors"]
if file_info.get("colors", 0)
else (1 << file_info["bits"])
)
assert isinstance(file_info["colors"], int)
if offset == 14 + file_info["header_size"] and file_info["bits"] <= 8:
offset += 4 * file_info["colors"]
# ---------------------- Check bit depth for unusual unsupported values
self._mode, raw_mode = BIT2MODE.get(file_info["bits"], (None, None))
if self.mode is None:
self._mode, raw_mode = BIT2MODE.get(file_info["bits"], ("", ""))
if not self.mode:
msg = f"Unsupported BMP pixel depth ({file_info['bits']})"
raise OSError(msg)
# ---------------- Process BMP with Bitfields compression (not palette)
decoder_name = "raw"
if file_info["compression"] == self.BITFIELDS:
SUPPORTED = {
if file_info["compression"] == self.COMPRESSIONS["BITFIELDS"]:
SUPPORTED: dict[int, list[tuple[int, ...]]] = {
32: [
(0xFF0000, 0xFF00, 0xFF, 0x0),
(0xFF000000, 0xFF0000, 0xFF00, 0x0),
(0xFF000000, 0xFF00, 0xFF, 0x0),
(0xFF000000, 0xFF0000, 0xFF00, 0xFF),
(0xFF, 0xFF00, 0xFF0000, 0xFF000000),
(0xFF0000, 0xFF00, 0xFF, 0xFF000000),
(0xFF000000, 0xFF00, 0xFF, 0xFF0000),
(0x0, 0x0, 0x0, 0x0),
],
24: [(0xFF0000, 0xFF00, 0xFF)],
@@ -186,9 +213,11 @@ class BmpImageFile(ImageFile.ImageFile):
MASK_MODES = {
(32, (0xFF0000, 0xFF00, 0xFF, 0x0)): "BGRX",
(32, (0xFF000000, 0xFF0000, 0xFF00, 0x0)): "XBGR",
(32, (0xFF000000, 0xFF00, 0xFF, 0x0)): "BGXR",
(32, (0xFF000000, 0xFF0000, 0xFF00, 0xFF)): "ABGR",
(32, (0xFF, 0xFF00, 0xFF0000, 0xFF000000)): "RGBA",
(32, (0xFF0000, 0xFF00, 0xFF, 0xFF000000)): "BGRA",
(32, (0xFF000000, 0xFF00, 0xFF, 0xFF0000)): "BGAR",
(32, (0x0, 0x0, 0x0, 0x0)): "BGRA",
(24, (0xFF0000, 0xFF00, 0xFF)): "BGR",
(16, (0xF800, 0x7E0, 0x1F)): "BGR;16",
@@ -199,12 +228,14 @@ class BmpImageFile(ImageFile.ImageFile):
file_info["bits"] == 32
and file_info["rgba_mask"] in SUPPORTED[file_info["bits"]]
):
assert isinstance(file_info["rgba_mask"], tuple)
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgba_mask"])]
self._mode = "RGBA" if "A" in raw_mode else self.mode
elif (
file_info["bits"] in (24, 16)
and file_info["rgb_mask"] in SUPPORTED[file_info["bits"]]
):
assert isinstance(file_info["rgb_mask"], tuple)
raw_mode = MASK_MODES[(file_info["bits"], file_info["rgb_mask"])]
else:
msg = "Unsupported BMP bitfields layout"
@@ -212,10 +243,15 @@ class BmpImageFile(ImageFile.ImageFile):
else:
msg = "Unsupported BMP bitfields layout"
raise OSError(msg)
elif file_info["compression"] == self.RAW:
if file_info["bits"] == 32 and header == 22: # 32-bit .cur offset
elif file_info["compression"] == self.COMPRESSIONS["RAW"]:
if file_info["bits"] == 32 and (
header == 22 or USE_RAW_ALPHA # 32-bit .cur offset
):
raw_mode, self._mode = "BGRA", "RGBA"
elif file_info["compression"] in (self.RLE8, self.RLE4):
elif file_info["compression"] in (
self.COMPRESSIONS["RLE8"],
self.COMPRESSIONS["RLE4"],
):
decoder_name = "bmp_rle"
else:
msg = f"Unsupported BMP compression ({file_info['compression']})"
@@ -228,23 +264,24 @@ class BmpImageFile(ImageFile.ImageFile):
msg = f"Unsupported BMP Palette size ({file_info['colors']})"
raise OSError(msg)
else:
assert isinstance(file_info["palette_padding"], int)
padding = file_info["palette_padding"]
palette = read(padding * file_info["colors"])
greyscale = True
grayscale = True
indices = (
(0, 255)
if file_info["colors"] == 2
else list(range(file_info["colors"]))
)
# ----------------- Check if greyscale and ignore palette if so
# ----------------- Check if grayscale and ignore palette if so
for ind, val in enumerate(indices):
rgb = palette[ind * padding : ind * padding + 3]
if rgb != o8(val) * 3:
greyscale = False
grayscale = False
# ------- If all colors are grey, white or black, ditch palette
if greyscale:
# ------- If all colors are gray, white or black, ditch palette
if grayscale:
self._mode = "1" if file_info["colors"] == 2 else "L"
raw_mode = self.mode
else:
@@ -255,14 +292,15 @@ class BmpImageFile(ImageFile.ImageFile):
# ---------------------------- Finally set the tile data for the plugin
self.info["compression"] = file_info["compression"]
args = [raw_mode]
args: list[Any] = [raw_mode]
if decoder_name == "bmp_rle":
args.append(file_info["compression"] == self.RLE4)
args.append(file_info["compression"] == self.COMPRESSIONS["RLE4"])
else:
assert isinstance(file_info["width"], int)
args.append(((file_info["width"] * file_info["bits"] + 31) >> 3) & (~3))
args.append(file_info["direction"])
self.tile = [
(
ImageFile._Tile(
decoder_name,
(0, 0, file_info["width"], file_info["height"]),
offset or self.fp.tell(),
@@ -270,7 +308,7 @@ class BmpImageFile(ImageFile.ImageFile):
)
]
def _open(self):
def _open(self) -> None:
"""Open file, check magic number and read header"""
# read 14 bytes: magic number, filesize, reserved, header final offset
head_data = self.fp.read(14)
@@ -287,11 +325,13 @@ class BmpImageFile(ImageFile.ImageFile):
class BmpRleDecoder(ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer):
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
rle4 = self.args[1]
data = bytearray()
x = 0
while len(data) < self.state.xsize * self.state.ysize:
dest_length = self.state.xsize * self.state.ysize
while len(data) < dest_length:
pixels = self.fd.read(1)
byte = self.fd.read(1)
if not pixels or not byte:
@@ -351,7 +391,7 @@ class BmpRleDecoder(ImageFile.PyDecoder):
if self.fd.tell() % 2 != 0:
self.fd.seek(1, os.SEEK_CUR)
rawmode = "L" if self.mode == "L" else "P"
self.set_as_raw(bytes(data), (rawmode, 0, self.args[-1]))
self.set_as_raw(bytes(data), rawmode, (0, self.args[-1]))
return -1, 0
@@ -362,7 +402,7 @@ class DibImageFile(BmpImageFile):
format = "DIB"
format_description = "Windows Bitmap"
def _open(self):
def _open(self) -> None:
self._bitmap()
@@ -380,11 +420,13 @@ SAVE = {
}
def _dib_save(im, fp, filename):
def _dib_save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, False)
def _save(im, fp, filename, bitmap_header=True):
def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, bitmap_header: bool = True
) -> None:
try:
rawmode, bits, colors = SAVE[im.mode]
except KeyError as e:
@@ -396,16 +438,16 @@ def _save(im, fp, filename, bitmap_header=True):
dpi = info.get("dpi", (96, 96))
# 1 meter == 39.3701 inches
ppm = tuple(map(lambda x: int(x * 39.3701 + 0.5), dpi))
ppm = tuple(int(x * 39.3701 + 0.5) for x in dpi)
stride = ((im.size[0] * bits + 7) // 8 + 3) & (~3)
header = 40 # or 64 for OS/2 version 2
image = stride * im.size[1]
if im.mode == "1":
palette = b"".join(o8(i) * 4 for i in (0, 255))
palette = b"".join(o8(i) * 3 + b"\x00" for i in (0, 255))
elif im.mode == "L":
palette = b"".join(o8(i) * 4 for i in range(256))
palette = b"".join(o8(i) * 3 + b"\x00" for i in range(256))
elif im.mode == "P":
palette = im.im.getpalette("RGB", "BGRX")
colors = len(palette) // 4
@@ -446,7 +488,9 @@ def _save(im, fp, filename, bitmap_header=True):
if palette:
fp.write(palette)
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))])
ImageFile._save(
im, fp, [ImageFile._Tile("raw", (0, 0) + im.size, 0, (rawmode, stride, -1))]
)
#

View File

@@ -8,13 +8,17 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import os
from typing import IO
from . import Image, ImageFile
_handler = None
def register_handler(handler):
def register_handler(handler: ImageFile.StubHandler | None) -> None:
"""
Install application-specific BUFR image handler.
@@ -28,22 +32,20 @@ def register_handler(handler):
# Image adapter
def _accept(prefix):
return prefix[:4] == b"BUFR" or prefix[:4] == b"ZCZC"
def _accept(prefix: bytes) -> bool:
return prefix.startswith((b"BUFR", b"ZCZC"))
class BufrStubImageFile(ImageFile.StubImageFile):
format = "BUFR"
format_description = "BUFR"
def _open(self):
offset = self.fp.tell()
def _open(self) -> None:
if not _accept(self.fp.read(4)):
msg = "Not a BUFR file"
raise SyntaxError(msg)
self.fp.seek(offset)
self.fp.seek(-4, os.SEEK_CUR)
# make something up
self._mode = "F"
@@ -53,11 +55,11 @@ class BufrStubImageFile(ImageFile.StubImageFile):
if loader:
loader.open(self)
def _load(self):
def _load(self) -> ImageFile.StubHandler | None:
return _handler
def _save(im, fp, filename):
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if _handler is None or not hasattr(_handler, "save"):
msg = "BUFR save handler not installed"
raise OSError(msg)

View File

@@ -13,18 +13,20 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import io
from collections.abc import Iterable
from typing import IO, AnyStr, NoReturn
class ContainerIO:
class ContainerIO(IO[AnyStr]):
"""
A file object that provides read access to a part of an existing
file (for example a TAR file).
"""
def __init__(self, file, offset, length):
def __init__(self, file: IO[AnyStr], offset: int, length: int) -> None:
"""
Create file object.
@@ -32,7 +34,7 @@ class ContainerIO:
:param offset: Start of region, in bytes.
:param length: Size of region, in bytes.
"""
self.fh = file
self.fh: IO[AnyStr] = file
self.pos = 0
self.offset = offset
self.length = length
@@ -41,10 +43,13 @@ class ContainerIO:
##
# Always false.
def isatty(self):
def isatty(self) -> bool:
return False
def seek(self, offset, mode=io.SEEK_SET):
def seekable(self) -> bool:
return True
def seek(self, offset: int, mode: int = io.SEEK_SET) -> int:
"""
Move file pointer.
@@ -52,6 +57,7 @@ class ContainerIO:
:param mode: Starting position. Use 0 for beginning of region, 1
for current offset, and 2 for end of region. You cannot move
the pointer outside the defined region.
:returns: Offset from start of region, in bytes.
"""
if mode == 1:
self.pos = self.pos + offset
@@ -62,8 +68,9 @@ class ContainerIO:
# clamp
self.pos = max(0, min(self.pos, self.length))
self.fh.seek(self.offset + self.pos)
return self.pos
def tell(self):
def tell(self) -> int:
"""
Get current file pointer.
@@ -71,44 +78,51 @@ class ContainerIO:
"""
return self.pos
def read(self, n=0):
def readable(self) -> bool:
return True
def read(self, n: int = -1) -> AnyStr:
"""
Read data.
:param n: Number of bytes to read. If omitted or zero,
:param n: Number of bytes to read. If omitted, zero or negative,
read until end of region.
:returns: An 8-bit string.
"""
if n:
if n > 0:
n = min(n, self.length - self.pos)
else:
n = self.length - self.pos
if not n: # EOF
return b"" if "b" in self.fh.mode else ""
if n <= 0: # EOF
return b"" if "b" in self.fh.mode else "" # type: ignore[return-value]
self.pos = self.pos + n
return self.fh.read(n)
def readline(self):
def readline(self, n: int = -1) -> AnyStr:
"""
Read a line of text.
:param n: Number of bytes to read. If omitted, zero or negative,
read until end of line.
:returns: An 8-bit string.
"""
s = b"" if "b" in self.fh.mode else ""
s: AnyStr = b"" if "b" in self.fh.mode else "" # type: ignore[assignment]
newline_character = b"\n" if "b" in self.fh.mode else "\n"
while True:
c = self.read(1)
if not c:
break
s = s + c
if c == newline_character:
if c == newline_character or len(s) == n:
break
return s
def readlines(self):
def readlines(self, n: int | None = -1) -> list[AnyStr]:
"""
Read multiple lines of text.
:param n: Number of lines to read. If omitted, zero, negative or None,
read until end of region.
:returns: A list of 8-bit strings.
"""
lines = []
@@ -117,4 +131,43 @@ class ContainerIO:
if not s:
break
lines.append(s)
if len(lines) == n:
break
return lines
def writable(self) -> bool:
return False
def write(self, b: AnyStr) -> NoReturn:
raise NotImplementedError()
def writelines(self, lines: Iterable[AnyStr]) -> NoReturn:
raise NotImplementedError()
def truncate(self, size: int | None = None) -> int:
raise NotImplementedError()
def __enter__(self) -> ContainerIO[AnyStr]:
return self
def __exit__(self, *args: object) -> None:
self.close()
def __iter__(self) -> ContainerIO[AnyStr]:
return self
def __next__(self) -> AnyStr:
line = self.readline()
if not line:
msg = "end of region"
raise StopIteration(msg)
return line
def fileno(self) -> int:
return self.fh.fileno()
def flush(self) -> None:
self.fh.flush()
def close(self) -> None:
self.fh.close()

View File

@@ -15,6 +15,8 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
from . import BmpImagePlugin, Image
from ._binary import i16le as i16
from ._binary import i32le as i32
@@ -23,8 +25,8 @@ from ._binary import i32le as i32
# --------------------------------------------------------------------
def _accept(prefix):
return prefix[:4] == b"\0\0\2\0"
def _accept(prefix: bytes) -> bool:
return prefix.startswith(b"\0\0\2\0")
##
@@ -35,7 +37,8 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
format = "CUR"
format_description = "Windows Cursor"
def _open(self):
def _open(self) -> None:
assert self.fp is not None
offset = self.fp.tell()
# check magic
@@ -61,10 +64,7 @@ class CurImageFile(BmpImagePlugin.BmpImageFile):
# patch up the bitmap height
self._size = self.size[0], self.size[1] // 2
d, e, o, a = self.tile[0]
self.tile[0] = d, (0, 0) + self.size, o, a
return
self.tile = [self.tile[0]._replace(extents=(0, 0) + self.size)]
#

View File

@@ -20,15 +20,17 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
from . import Image
from ._binary import i32le as i32
from ._util import DeferredError
from .PcxImagePlugin import PcxImageFile
MAGIC = 0x3ADE68B1 # QUIZ: what's this value, then?
def _accept(prefix):
def _accept(prefix: bytes) -> bool:
return len(prefix) >= 4 and i32(prefix) == MAGIC
@@ -41,7 +43,7 @@ class DcxImageFile(PcxImageFile):
format_description = "Intel DCX"
_close_exclusive_fp_after_loading = False
def _open(self):
def _open(self) -> None:
# Header
s = self.fp.read(4)
if not _accept(s):
@@ -57,20 +59,22 @@ class DcxImageFile(PcxImageFile):
self._offset.append(offset)
self._fp = self.fp
self.frame = None
self.frame = -1
self.n_frames = len(self._offset)
self.is_animated = self.n_frames > 1
self.seek(0)
def seek(self, frame):
def seek(self, frame: int) -> None:
if not self._seek_check(frame):
return
if isinstance(self._fp, DeferredError):
raise self._fp.ex
self.frame = frame
self.fp = self._fp
self.fp.seek(self._offset[frame])
PcxImageFile._open(self)
def tell(self):
def tell(self) -> int:
return self.frame

View File

@@ -1,118 +1,338 @@
"""
A Pillow loader for .dds files (S3TC-compressed aka DXTC)
A Pillow plugin for .dds files (S3TC-compressed aka DXTC)
Jerome Leclanche <jerome@leclan.ch>
Documentation:
https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt
https://web.archive.org/web/20170802060935/http://oss.sgi.com/projects/ogl-sample/registry/EXT/texture_compression_s3tc.txt
The contents of this file are hereby released in the public domain (CC0)
Full text of the CC0 license:
https://creativecommons.org/publicdomain/zero/1.0/
https://creativecommons.org/publicdomain/zero/1.0/
"""
from __future__ import annotations
import io
import struct
from io import BytesIO
import sys
from enum import IntEnum, IntFlag
from typing import IO
from . import Image, ImageFile, ImagePalette
from ._binary import i32le as i32
from ._binary import o8
from ._binary import o32le as o32
# Magic ("DDS ")
DDS_MAGIC = 0x20534444
# DDS flags
DDSD_CAPS = 0x1
DDSD_HEIGHT = 0x2
DDSD_WIDTH = 0x4
DDSD_PITCH = 0x8
DDSD_PIXELFORMAT = 0x1000
DDSD_MIPMAPCOUNT = 0x20000
DDSD_LINEARSIZE = 0x80000
DDSD_DEPTH = 0x800000
class DDSD(IntFlag):
CAPS = 0x1
HEIGHT = 0x2
WIDTH = 0x4
PITCH = 0x8
PIXELFORMAT = 0x1000
MIPMAPCOUNT = 0x20000
LINEARSIZE = 0x80000
DEPTH = 0x800000
# DDS caps
DDSCAPS_COMPLEX = 0x8
DDSCAPS_TEXTURE = 0x1000
DDSCAPS_MIPMAP = 0x400000
class DDSCAPS(IntFlag):
COMPLEX = 0x8
TEXTURE = 0x1000
MIPMAP = 0x400000
class DDSCAPS2(IntFlag):
CUBEMAP = 0x200
CUBEMAP_POSITIVEX = 0x400
CUBEMAP_NEGATIVEX = 0x800
CUBEMAP_POSITIVEY = 0x1000
CUBEMAP_NEGATIVEY = 0x2000
CUBEMAP_POSITIVEZ = 0x4000
CUBEMAP_NEGATIVEZ = 0x8000
VOLUME = 0x200000
DDSCAPS2_CUBEMAP = 0x200
DDSCAPS2_CUBEMAP_POSITIVEX = 0x400
DDSCAPS2_CUBEMAP_NEGATIVEX = 0x800
DDSCAPS2_CUBEMAP_POSITIVEY = 0x1000
DDSCAPS2_CUBEMAP_NEGATIVEY = 0x2000
DDSCAPS2_CUBEMAP_POSITIVEZ = 0x4000
DDSCAPS2_CUBEMAP_NEGATIVEZ = 0x8000
DDSCAPS2_VOLUME = 0x200000
# Pixel Format
DDPF_ALPHAPIXELS = 0x1
DDPF_ALPHA = 0x2
DDPF_FOURCC = 0x4
DDPF_PALETTEINDEXED8 = 0x20
DDPF_RGB = 0x40
DDPF_LUMINANCE = 0x20000
# dds.h
DDS_FOURCC = DDPF_FOURCC
DDS_RGB = DDPF_RGB
DDS_RGBA = DDPF_RGB | DDPF_ALPHAPIXELS
DDS_LUMINANCE = DDPF_LUMINANCE
DDS_LUMINANCEA = DDPF_LUMINANCE | DDPF_ALPHAPIXELS
DDS_ALPHA = DDPF_ALPHA
DDS_PAL8 = DDPF_PALETTEINDEXED8
DDS_HEADER_FLAGS_TEXTURE = DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PIXELFORMAT
DDS_HEADER_FLAGS_MIPMAP = DDSD_MIPMAPCOUNT
DDS_HEADER_FLAGS_VOLUME = DDSD_DEPTH
DDS_HEADER_FLAGS_PITCH = DDSD_PITCH
DDS_HEADER_FLAGS_LINEARSIZE = DDSD_LINEARSIZE
DDS_HEIGHT = DDSD_HEIGHT
DDS_WIDTH = DDSD_WIDTH
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS_TEXTURE
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS_COMPLEX | DDSCAPS_MIPMAP
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS_COMPLEX
DDS_CUBEMAP_POSITIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEX
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEX
DDS_CUBEMAP_POSITIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEY
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEY
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_POSITIVEZ
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2_CUBEMAP | DDSCAPS2_CUBEMAP_NEGATIVEZ
# DXT1
DXT1_FOURCC = 0x31545844
# DXT3
DXT3_FOURCC = 0x33545844
# DXT5
DXT5_FOURCC = 0x35545844
class DDPF(IntFlag):
ALPHAPIXELS = 0x1
ALPHA = 0x2
FOURCC = 0x4
PALETTEINDEXED8 = 0x20
RGB = 0x40
LUMINANCE = 0x20000
# dxgiformat.h
class DXGI_FORMAT(IntEnum):
UNKNOWN = 0
R32G32B32A32_TYPELESS = 1
R32G32B32A32_FLOAT = 2
R32G32B32A32_UINT = 3
R32G32B32A32_SINT = 4
R32G32B32_TYPELESS = 5
R32G32B32_FLOAT = 6
R32G32B32_UINT = 7
R32G32B32_SINT = 8
R16G16B16A16_TYPELESS = 9
R16G16B16A16_FLOAT = 10
R16G16B16A16_UNORM = 11
R16G16B16A16_UINT = 12
R16G16B16A16_SNORM = 13
R16G16B16A16_SINT = 14
R32G32_TYPELESS = 15
R32G32_FLOAT = 16
R32G32_UINT = 17
R32G32_SINT = 18
R32G8X24_TYPELESS = 19
D32_FLOAT_S8X24_UINT = 20
R32_FLOAT_X8X24_TYPELESS = 21
X32_TYPELESS_G8X24_UINT = 22
R10G10B10A2_TYPELESS = 23
R10G10B10A2_UNORM = 24
R10G10B10A2_UINT = 25
R11G11B10_FLOAT = 26
R8G8B8A8_TYPELESS = 27
R8G8B8A8_UNORM = 28
R8G8B8A8_UNORM_SRGB = 29
R8G8B8A8_UINT = 30
R8G8B8A8_SNORM = 31
R8G8B8A8_SINT = 32
R16G16_TYPELESS = 33
R16G16_FLOAT = 34
R16G16_UNORM = 35
R16G16_UINT = 36
R16G16_SNORM = 37
R16G16_SINT = 38
R32_TYPELESS = 39
D32_FLOAT = 40
R32_FLOAT = 41
R32_UINT = 42
R32_SINT = 43
R24G8_TYPELESS = 44
D24_UNORM_S8_UINT = 45
R24_UNORM_X8_TYPELESS = 46
X24_TYPELESS_G8_UINT = 47
R8G8_TYPELESS = 48
R8G8_UNORM = 49
R8G8_UINT = 50
R8G8_SNORM = 51
R8G8_SINT = 52
R16_TYPELESS = 53
R16_FLOAT = 54
D16_UNORM = 55
R16_UNORM = 56
R16_UINT = 57
R16_SNORM = 58
R16_SINT = 59
R8_TYPELESS = 60
R8_UNORM = 61
R8_UINT = 62
R8_SNORM = 63
R8_SINT = 64
A8_UNORM = 65
R1_UNORM = 66
R9G9B9E5_SHAREDEXP = 67
R8G8_B8G8_UNORM = 68
G8R8_G8B8_UNORM = 69
BC1_TYPELESS = 70
BC1_UNORM = 71
BC1_UNORM_SRGB = 72
BC2_TYPELESS = 73
BC2_UNORM = 74
BC2_UNORM_SRGB = 75
BC3_TYPELESS = 76
BC3_UNORM = 77
BC3_UNORM_SRGB = 78
BC4_TYPELESS = 79
BC4_UNORM = 80
BC4_SNORM = 81
BC5_TYPELESS = 82
BC5_UNORM = 83
BC5_SNORM = 84
B5G6R5_UNORM = 85
B5G5R5A1_UNORM = 86
B8G8R8A8_UNORM = 87
B8G8R8X8_UNORM = 88
R10G10B10_XR_BIAS_A2_UNORM = 89
B8G8R8A8_TYPELESS = 90
B8G8R8A8_UNORM_SRGB = 91
B8G8R8X8_TYPELESS = 92
B8G8R8X8_UNORM_SRGB = 93
BC6H_TYPELESS = 94
BC6H_UF16 = 95
BC6H_SF16 = 96
BC7_TYPELESS = 97
BC7_UNORM = 98
BC7_UNORM_SRGB = 99
AYUV = 100
Y410 = 101
Y416 = 102
NV12 = 103
P010 = 104
P016 = 105
OPAQUE_420 = 106
YUY2 = 107
Y210 = 108
Y216 = 109
NV11 = 110
AI44 = 111
IA44 = 112
P8 = 113
A8P8 = 114
B4G4R4A4_UNORM = 115
P208 = 130
V208 = 131
V408 = 132
SAMPLER_FEEDBACK_MIN_MIP_OPAQUE = 189
SAMPLER_FEEDBACK_MIP_REGION_USED_OPAQUE = 190
DXGI_FORMAT_R8G8B8A8_TYPELESS = 27
DXGI_FORMAT_R8G8B8A8_UNORM = 28
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = 29
DXGI_FORMAT_BC5_TYPELESS = 82
DXGI_FORMAT_BC5_UNORM = 83
DXGI_FORMAT_BC5_SNORM = 84
DXGI_FORMAT_BC6H_UF16 = 95
DXGI_FORMAT_BC6H_SF16 = 96
DXGI_FORMAT_BC7_TYPELESS = 97
DXGI_FORMAT_BC7_UNORM = 98
DXGI_FORMAT_BC7_UNORM_SRGB = 99
class D3DFMT(IntEnum):
UNKNOWN = 0
R8G8B8 = 20
A8R8G8B8 = 21
X8R8G8B8 = 22
R5G6B5 = 23
X1R5G5B5 = 24
A1R5G5B5 = 25
A4R4G4B4 = 26
R3G3B2 = 27
A8 = 28
A8R3G3B2 = 29
X4R4G4B4 = 30
A2B10G10R10 = 31
A8B8G8R8 = 32
X8B8G8R8 = 33
G16R16 = 34
A2R10G10B10 = 35
A16B16G16R16 = 36
A8P8 = 40
P8 = 41
L8 = 50
A8L8 = 51
A4L4 = 52
V8U8 = 60
L6V5U5 = 61
X8L8V8U8 = 62
Q8W8V8U8 = 63
V16U16 = 64
A2W10V10U10 = 67
D16_LOCKABLE = 70
D32 = 71
D15S1 = 73
D24S8 = 75
D24X8 = 77
D24X4S4 = 79
D16 = 80
D32F_LOCKABLE = 82
D24FS8 = 83
D32_LOCKABLE = 84
S8_LOCKABLE = 85
L16 = 81
VERTEXDATA = 100
INDEX16 = 101
INDEX32 = 102
Q16W16V16U16 = 110
R16F = 111
G16R16F = 112
A16B16G16R16F = 113
R32F = 114
G32R32F = 115
A32B32G32R32F = 116
CxV8U8 = 117
A1 = 118
A2B10G10R10_XR_BIAS = 119
BINARYBUFFER = 199
UYVY = i32(b"UYVY")
R8G8_B8G8 = i32(b"RGBG")
YUY2 = i32(b"YUY2")
G8R8_G8B8 = i32(b"GRGB")
DXT1 = i32(b"DXT1")
DXT2 = i32(b"DXT2")
DXT3 = i32(b"DXT3")
DXT4 = i32(b"DXT4")
DXT5 = i32(b"DXT5")
DX10 = i32(b"DX10")
BC4S = i32(b"BC4S")
BC4U = i32(b"BC4U")
BC5S = i32(b"BC5S")
BC5U = i32(b"BC5U")
ATI1 = i32(b"ATI1")
ATI2 = i32(b"ATI2")
MULTI2_ARGB8 = i32(b"MET1")
# Backward compatibility layer
module = sys.modules[__name__]
for item in DDSD:
assert item.name is not None
setattr(module, f"DDSD_{item.name}", item.value)
for item1 in DDSCAPS:
assert item1.name is not None
setattr(module, f"DDSCAPS_{item1.name}", item1.value)
for item2 in DDSCAPS2:
assert item2.name is not None
setattr(module, f"DDSCAPS2_{item2.name}", item2.value)
for item3 in DDPF:
assert item3.name is not None
setattr(module, f"DDPF_{item3.name}", item3.value)
DDS_FOURCC = DDPF.FOURCC
DDS_RGB = DDPF.RGB
DDS_RGBA = DDPF.RGB | DDPF.ALPHAPIXELS
DDS_LUMINANCE = DDPF.LUMINANCE
DDS_LUMINANCEA = DDPF.LUMINANCE | DDPF.ALPHAPIXELS
DDS_ALPHA = DDPF.ALPHA
DDS_PAL8 = DDPF.PALETTEINDEXED8
DDS_HEADER_FLAGS_TEXTURE = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
DDS_HEADER_FLAGS_MIPMAP = DDSD.MIPMAPCOUNT
DDS_HEADER_FLAGS_VOLUME = DDSD.DEPTH
DDS_HEADER_FLAGS_PITCH = DDSD.PITCH
DDS_HEADER_FLAGS_LINEARSIZE = DDSD.LINEARSIZE
DDS_HEIGHT = DDSD.HEIGHT
DDS_WIDTH = DDSD.WIDTH
DDS_SURFACE_FLAGS_TEXTURE = DDSCAPS.TEXTURE
DDS_SURFACE_FLAGS_MIPMAP = DDSCAPS.COMPLEX | DDSCAPS.MIPMAP
DDS_SURFACE_FLAGS_CUBEMAP = DDSCAPS.COMPLEX
DDS_CUBEMAP_POSITIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEX
DDS_CUBEMAP_NEGATIVEX = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEX
DDS_CUBEMAP_POSITIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEY
DDS_CUBEMAP_NEGATIVEY = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEY
DDS_CUBEMAP_POSITIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_POSITIVEZ
DDS_CUBEMAP_NEGATIVEZ = DDSCAPS2.CUBEMAP | DDSCAPS2.CUBEMAP_NEGATIVEZ
DXT1_FOURCC = D3DFMT.DXT1
DXT3_FOURCC = D3DFMT.DXT3
DXT5_FOURCC = D3DFMT.DXT5
DXGI_FORMAT_R8G8B8A8_TYPELESS = DXGI_FORMAT.R8G8B8A8_TYPELESS
DXGI_FORMAT_R8G8B8A8_UNORM = DXGI_FORMAT.R8G8B8A8_UNORM
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB = DXGI_FORMAT.R8G8B8A8_UNORM_SRGB
DXGI_FORMAT_BC5_TYPELESS = DXGI_FORMAT.BC5_TYPELESS
DXGI_FORMAT_BC5_UNORM = DXGI_FORMAT.BC5_UNORM
DXGI_FORMAT_BC5_SNORM = DXGI_FORMAT.BC5_SNORM
DXGI_FORMAT_BC6H_UF16 = DXGI_FORMAT.BC6H_UF16
DXGI_FORMAT_BC6H_SF16 = DXGI_FORMAT.BC6H_SF16
DXGI_FORMAT_BC7_TYPELESS = DXGI_FORMAT.BC7_TYPELESS
DXGI_FORMAT_BC7_UNORM = DXGI_FORMAT.BC7_UNORM
DXGI_FORMAT_BC7_UNORM_SRGB = DXGI_FORMAT.BC7_UNORM_SRGB
class DdsImageFile(ImageFile.ImageFile):
format = "DDS"
format_description = "DirectDraw Surface"
def _open(self):
def _open(self) -> None:
if not _accept(self.fp.read(4)):
msg = "not a DDS file"
raise SyntaxError(msg)
@@ -124,172 +344,281 @@ class DdsImageFile(ImageFile.ImageFile):
if len(header_bytes) != 120:
msg = f"Incomplete header: {len(header_bytes)} bytes"
raise OSError(msg)
header = BytesIO(header_bytes)
header = io.BytesIO(header_bytes)
flags, height, width = struct.unpack("<3I", header.read(12))
self._size = (width, height)
self._mode = "RGBA"
extents = (0, 0) + self.size
pitch, depth, mipmaps = struct.unpack("<3I", header.read(12))
struct.unpack("<11I", header.read(44)) # reserved
# pixel format
pfsize, pfflags = struct.unpack("<2I", header.read(8))
fourcc = header.read(4)
(bitcount,) = struct.unpack("<I", header.read(4))
masks = struct.unpack("<4I", header.read(16))
if pfflags & DDPF_LUMINANCE:
# Texture contains uncompressed L or LA data
if pfflags & DDPF_ALPHAPIXELS:
pfsize, pfflags, fourcc, bitcount = struct.unpack("<4I", header.read(16))
n = 0
rawmode = None
if pfflags & DDPF.RGB:
# Texture contains uncompressed RGB data
if pfflags & DDPF.ALPHAPIXELS:
self._mode = "RGBA"
mask_count = 4
else:
self._mode = "RGB"
mask_count = 3
masks = struct.unpack(f"<{mask_count}I", header.read(mask_count * 4))
self.tile = [ImageFile._Tile("dds_rgb", extents, 0, (bitcount, masks))]
return
elif pfflags & DDPF.LUMINANCE:
if bitcount == 8:
self._mode = "L"
elif bitcount == 16 and pfflags & DDPF.ALPHAPIXELS:
self._mode = "LA"
else:
self._mode = "L"
self.tile = [("raw", (0, 0) + self.size, 0, (self.mode, 0, 1))]
elif pfflags & DDPF_RGB:
# Texture contains uncompressed RGB data
masks = {mask: ["R", "G", "B", "A"][i] for i, mask in enumerate(masks)}
rawmode = ""
if pfflags & DDPF_ALPHAPIXELS:
rawmode += masks[0xFF000000]
else:
self._mode = "RGB"
rawmode += masks[0xFF0000] + masks[0xFF00] + masks[0xFF]
self.tile = [("raw", (0, 0) + self.size, 0, (rawmode[::-1], 0, 1))]
elif pfflags & DDPF_PALETTEINDEXED8:
msg = f"Unsupported bitcount {bitcount} for {pfflags}"
raise OSError(msg)
elif pfflags & DDPF.PALETTEINDEXED8:
self._mode = "P"
self.palette = ImagePalette.raw("RGBA", self.fp.read(1024))
self.tile = [("raw", (0, 0) + self.size, 0, "L")]
else:
data_start = header_size + 4
n = 0
if fourcc == b"DXT1":
self.palette.mode = "RGBA"
elif pfflags & DDPF.FOURCC:
offset = header_size + 4
if fourcc == D3DFMT.DXT1:
self._mode = "RGBA"
self.pixel_format = "DXT1"
n = 1
elif fourcc == b"DXT3":
elif fourcc == D3DFMT.DXT3:
self._mode = "RGBA"
self.pixel_format = "DXT3"
n = 2
elif fourcc == b"DXT5":
elif fourcc == D3DFMT.DXT5:
self._mode = "RGBA"
self.pixel_format = "DXT5"
n = 3
elif fourcc == b"ATI1":
elif fourcc in (D3DFMT.BC4U, D3DFMT.ATI1):
self._mode = "L"
self.pixel_format = "BC4"
n = 4
self._mode = "L"
elif fourcc in (b"ATI2", b"BC5U"):
self.pixel_format = "BC5"
n = 5
elif fourcc == D3DFMT.BC5S:
self._mode = "RGB"
elif fourcc == b"BC5S":
self.pixel_format = "BC5S"
n = 5
elif fourcc in (D3DFMT.BC5U, D3DFMT.ATI2):
self._mode = "RGB"
elif fourcc == b"DX10":
data_start += 20
self.pixel_format = "BC5"
n = 5
elif fourcc == D3DFMT.DX10:
offset += 20
# ignoring flags which pertain to volume textures and cubemaps
(dxgi_format,) = struct.unpack("<I", self.fp.read(4))
self.fp.read(16)
if dxgi_format in (DXGI_FORMAT_BC5_TYPELESS, DXGI_FORMAT_BC5_UNORM):
if dxgi_format in (
DXGI_FORMAT.BC1_UNORM,
DXGI_FORMAT.BC1_TYPELESS,
):
self._mode = "RGBA"
self.pixel_format = "BC1"
n = 1
elif dxgi_format in (DXGI_FORMAT.BC2_TYPELESS, DXGI_FORMAT.BC2_UNORM):
self._mode = "RGBA"
self.pixel_format = "BC2"
n = 2
elif dxgi_format in (DXGI_FORMAT.BC3_TYPELESS, DXGI_FORMAT.BC3_UNORM):
self._mode = "RGBA"
self.pixel_format = "BC3"
n = 3
elif dxgi_format in (DXGI_FORMAT.BC4_TYPELESS, DXGI_FORMAT.BC4_UNORM):
self._mode = "L"
self.pixel_format = "BC4"
n = 4
elif dxgi_format in (DXGI_FORMAT.BC5_TYPELESS, DXGI_FORMAT.BC5_UNORM):
self._mode = "RGB"
self.pixel_format = "BC5"
n = 5
elif dxgi_format == DXGI_FORMAT.BC5_SNORM:
self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC5_SNORM:
self.pixel_format = "BC5S"
n = 5
elif dxgi_format == DXGI_FORMAT.BC6H_UF16:
self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC6H_UF16:
self.pixel_format = "BC6H"
n = 6
elif dxgi_format == DXGI_FORMAT.BC6H_SF16:
self._mode = "RGB"
elif dxgi_format == DXGI_FORMAT_BC6H_SF16:
self.pixel_format = "BC6HS"
n = 6
self._mode = "RGB"
elif dxgi_format in (DXGI_FORMAT_BC7_TYPELESS, DXGI_FORMAT_BC7_UNORM):
self.pixel_format = "BC7"
n = 7
elif dxgi_format == DXGI_FORMAT_BC7_UNORM_SRGB:
self.pixel_format = "BC7"
self.info["gamma"] = 1 / 2.2
n = 7
elif dxgi_format in (
DXGI_FORMAT_R8G8B8A8_TYPELESS,
DXGI_FORMAT_R8G8B8A8_UNORM,
DXGI_FORMAT_R8G8B8A8_UNORM_SRGB,
DXGI_FORMAT.BC7_TYPELESS,
DXGI_FORMAT.BC7_UNORM,
DXGI_FORMAT.BC7_UNORM_SRGB,
):
self.tile = [("raw", (0, 0) + self.size, 0, ("RGBA", 0, 1))]
if dxgi_format == DXGI_FORMAT_R8G8B8A8_UNORM_SRGB:
self._mode = "RGBA"
self.pixel_format = "BC7"
n = 7
if dxgi_format == DXGI_FORMAT.BC7_UNORM_SRGB:
self.info["gamma"] = 1 / 2.2
elif dxgi_format in (
DXGI_FORMAT.R8G8B8A8_TYPELESS,
DXGI_FORMAT.R8G8B8A8_UNORM,
DXGI_FORMAT.R8G8B8A8_UNORM_SRGB,
):
self._mode = "RGBA"
if dxgi_format == DXGI_FORMAT.R8G8B8A8_UNORM_SRGB:
self.info["gamma"] = 1 / 2.2
return
else:
msg = f"Unimplemented DXGI format {dxgi_format}"
raise NotImplementedError(msg)
else:
msg = f"Unimplemented pixel format {repr(fourcc)}"
raise NotImplementedError(msg)
else:
msg = f"Unknown pixel format flags {pfflags}"
raise NotImplementedError(msg)
if n:
self.tile = [
("bcn", (0, 0) + self.size, data_start, (n, self.pixel_format))
ImageFile._Tile("bcn", extents, offset, (n, self.pixel_format))
]
else:
self.tile = [ImageFile._Tile("raw", extents, 0, rawmode or self.mode)]
def load_seek(self, pos):
def load_seek(self, pos: int) -> None:
pass
def _save(im, fp, filename):
class DdsRgbDecoder(ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
bitcount, masks = self.args
# Some masks will be padded with zeros, e.g. R 0b11 G 0b1100
# Calculate how many zeros each mask is padded with
mask_offsets = []
# And the maximum value of each channel without the padding
mask_totals = []
for mask in masks:
offset = 0
if mask != 0:
while mask >> (offset + 1) << (offset + 1) == mask:
offset += 1
mask_offsets.append(offset)
mask_totals.append(mask >> offset)
data = bytearray()
bytecount = bitcount // 8
dest_length = self.state.xsize * self.state.ysize * len(masks)
while len(data) < dest_length:
value = int.from_bytes(self.fd.read(bytecount), "little")
for i, mask in enumerate(masks):
masked_value = value & mask
# Remove the zero padding, and scale it to 8 bits
data += o8(
int(((masked_value >> mask_offsets[i]) / mask_totals[i]) * 255)
)
self.set_as_raw(data)
return -1, 0
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if im.mode not in ("RGB", "RGBA", "L", "LA"):
msg = f"cannot write mode {im.mode} as DDS"
raise OSError(msg)
rawmode = im.mode
masks = [0xFF0000, 0xFF00, 0xFF]
if im.mode in ("L", "LA"):
pixel_flags = DDPF_LUMINANCE
flags = DDSD.CAPS | DDSD.HEIGHT | DDSD.WIDTH | DDSD.PIXELFORMAT
bitcount = len(im.getbands()) * 8
pixel_format = im.encoderinfo.get("pixel_format")
args: tuple[int] | str
if pixel_format:
codec_name = "bcn"
flags |= DDSD.LINEARSIZE
pitch = (im.width + 3) * 4
rgba_mask = [0, 0, 0, 0]
pixel_flags = DDPF.FOURCC
if pixel_format == "DXT1":
fourcc = D3DFMT.DXT1
args = (1,)
elif pixel_format == "DXT3":
fourcc = D3DFMT.DXT3
args = (2,)
elif pixel_format == "DXT5":
fourcc = D3DFMT.DXT5
args = (3,)
else:
pixel_flags = DDPF_RGB
rawmode = rawmode[::-1]
if im.mode in ("LA", "RGBA"):
pixel_flags |= DDPF_ALPHAPIXELS
masks.append(0xFF000000)
fourcc = D3DFMT.DX10
if pixel_format == "BC2":
args = (2,)
dxgi_format = DXGI_FORMAT.BC2_TYPELESS
elif pixel_format == "BC3":
args = (3,)
dxgi_format = DXGI_FORMAT.BC3_TYPELESS
elif pixel_format == "BC5":
args = (5,)
dxgi_format = DXGI_FORMAT.BC5_TYPELESS
if im.mode != "RGB":
msg = "only RGB mode can be written as BC5"
raise OSError(msg)
else:
msg = f"cannot write pixel format {pixel_format}"
raise OSError(msg)
else:
codec_name = "raw"
flags |= DDSD.PITCH
pitch = (im.width * bitcount + 7) // 8
bitcount = len(masks) * 8
while len(masks) < 4:
masks.append(0)
alpha = im.mode[-1] == "A"
if im.mode[0] == "L":
pixel_flags = DDPF.LUMINANCE
args = im.mode
if alpha:
rgba_mask = [0x000000FF, 0x000000FF, 0x000000FF]
else:
rgba_mask = [0xFF000000, 0xFF000000, 0xFF000000]
else:
pixel_flags = DDPF.RGB
args = im.mode[::-1]
rgba_mask = [0x00FF0000, 0x0000FF00, 0x000000FF]
fp.write(
o32(DDS_MAGIC)
+ o32(124) # header size
+ o32(
DDSD_CAPS | DDSD_HEIGHT | DDSD_WIDTH | DDSD_PITCH | DDSD_PIXELFORMAT
) # flags
+ o32(im.height)
+ o32(im.width)
+ o32((im.width * bitcount + 7) // 8) # pitch
+ o32(0) # depth
+ o32(0) # mipmaps
+ o32(0) * 11 # reserved
+ o32(32) # pfsize
+ o32(pixel_flags) # pfflags
+ o32(0) # fourcc
+ o32(bitcount) # bitcount
+ b"".join(o32(mask) for mask in masks) # rgbabitmask
+ o32(DDSCAPS_TEXTURE) # dwCaps
+ o32(0) # dwCaps2
+ o32(0) # dwCaps3
+ o32(0) # dwCaps4
+ o32(0) # dwReserved2
)
if im.mode == "RGBA":
if alpha:
r, g, b, a = im.split()
im = Image.merge("RGBA", (a, r, g, b))
ImageFile._save(im, fp, [("raw", (0, 0) + im.size, 0, (rawmode, 0, 1))])
if alpha:
pixel_flags |= DDPF.ALPHAPIXELS
rgba_mask.append(0xFF000000 if alpha else 0)
fourcc = D3DFMT.UNKNOWN
fp.write(
o32(DDS_MAGIC)
+ struct.pack(
"<7I",
124, # header size
flags, # flags
im.height,
im.width,
pitch,
0, # depth
0, # mipmaps
)
+ struct.pack("11I", *((0,) * 11)) # reserved
# pfsize, pfflags, fourcc, bitcount
+ struct.pack("<4I", 32, pixel_flags, fourcc, bitcount)
+ struct.pack("<4I", *rgba_mask) # dwRGBABitMask
+ struct.pack("<5I", DDSCAPS.TEXTURE, 0, 0, 0, 0)
)
if fourcc == D3DFMT.DX10:
fp.write(
# dxgi_format, 2D resource, misc, array size, straight alpha
struct.pack("<5I", dxgi_format, 3, 0, 0, 1)
)
ImageFile._save(im, fp, [ImageFile._Tile(codec_name, (0, 0) + im.size, 0, args)])
def _accept(prefix):
return prefix[:4] == b"DDS "
def _accept(prefix: bytes) -> bool:
return prefix.startswith(b"DDS ")
Image.register_open(DdsImageFile.format, DdsImageFile, _accept)
Image.register_decoder("dds_rgb", DdsRgbDecoder)
Image.register_save(DdsImageFile.format, _save)
Image.register_extension(DdsImageFile.format, ".dds")

View File

@@ -19,6 +19,7 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import io
import os
@@ -26,10 +27,10 @@ import re
import subprocess
import sys
import tempfile
from typing import IO
from . import Image, ImageFile
from ._binary import i32le as i32
from ._deprecate import deprecate
# --------------------------------------------------------------------
@@ -37,11 +38,11 @@ from ._deprecate import deprecate
split = re.compile(r"^%%([^:]*):[ \t]*(.*)[ \t]*$")
field = re.compile(r"^%[%!\w]([^:]*)[ \t]*$")
gs_binary = None
gs_binary: str | bool | None = None
gs_windows_binary = None
def has_ghostscript():
def has_ghostscript() -> bool:
global gs_binary, gs_windows_binary
if gs_binary is None:
if sys.platform.startswith("win"):
@@ -64,27 +65,32 @@ def has_ghostscript():
return gs_binary is not False
def Ghostscript(tile, size, fp, scale=1, transparency=False):
def Ghostscript(
tile: list[ImageFile._Tile],
size: tuple[int, int],
fp: IO[bytes],
scale: int = 1,
transparency: bool = False,
) -> Image.core.ImagingCore:
"""Render an image using Ghostscript"""
global gs_binary
if not has_ghostscript():
msg = "Unable to locate Ghostscript on paths"
raise OSError(msg)
assert isinstance(gs_binary, str)
# Unpack decoder tile
decoder, tile, offset, data = tile[0]
length, bbox = data
args = tile[0].args
assert isinstance(args, tuple)
length, bbox = args
# Hack to support hi-res rendering
scale = int(scale) or 1
# orig_size = size
# orig_bbox = bbox
size = (size[0] * scale, size[1] * scale)
width = size[0] * scale
height = size[1] * scale
# resolution is dependent on bbox and size
res = (
72.0 * size[0] / (bbox[2] - bbox[0]),
72.0 * size[1] / (bbox[3] - bbox[1]),
)
res_x = 72.0 * width / (bbox[2] - bbox[0])
res_y = 72.0 * height / (bbox[3] - bbox[1])
out_fd, outfile = tempfile.mkstemp()
os.close(out_fd)
@@ -115,14 +121,20 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
lengthfile -= len(s)
f.write(s)
device = "pngalpha" if transparency else "ppmraw"
if transparency:
# "RGBA"
device = "pngalpha"
else:
# "pnmraw" automatically chooses between
# PBM ("1"), PGM ("L"), and PPM ("RGB").
device = "pnmraw"
# Build Ghostscript command
command = [
gs_binary,
"-q", # quiet mode
"-g%dx%d" % size, # set output geometry (pixels)
"-r%fx%f" % res, # set input DPI (dots per inch)
f"-g{width:d}x{height:d}", # set output geometry (pixels)
f"-r{res_x:f}x{res_y:f}", # set input DPI (dots per inch)
"-dBATCH", # exit after processing
"-dNOPAUSE", # don't pause between pages
"-dSAFER", # safe mode
@@ -145,8 +157,9 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
subprocess.check_call(command, startupinfo=startupinfo)
out_im = Image.open(outfile)
with Image.open(outfile) as out_im:
out_im.load()
return out_im.im.copy()
finally:
try:
os.unlink(outfile)
@@ -155,50 +168,11 @@ def Ghostscript(tile, size, fp, scale=1, transparency=False):
except OSError:
pass
im = out_im.im.copy()
out_im.close()
return im
class PSFile:
"""
Wrapper for bytesio object that treats either CR or LF as end of line.
This class is no longer used internally, but kept for backwards compatibility.
"""
def __init__(self, fp):
deprecate(
"PSFile",
11,
action="If you need the functionality of this class "
"you will need to implement it yourself.",
def _accept(prefix: bytes) -> bool:
return prefix.startswith(b"%!PS") or (
len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5
)
self.fp = fp
self.char = None
def seek(self, offset, whence=io.SEEK_SET):
self.char = None
self.fp.seek(offset, whence)
def readline(self):
s = [self.char or b""]
self.char = None
c = self.fp.read(1)
while (c not in b"\r\n") and len(c):
s.append(c)
c = self.fp.read(1)
self.char = self.fp.read(1)
# line endings can be 1 or 2 of \r \n, in either order
if self.char in b"\r\n":
self.char = None
return b"".join(s).decode("latin-1")
def _accept(prefix):
return prefix[:4] == b"%!PS" or (len(prefix) >= 4 and i32(prefix) == 0xC6D3D0C5)
##
@@ -214,14 +188,18 @@ class EpsImageFile(ImageFile.ImageFile):
mode_map = {1: "L", 2: "LAB", 3: "RGB", 4: "CMYK"}
def _open(self):
def _open(self) -> None:
(length, offset) = self._find_offset(self.fp)
# go to offset - start of "%!PS"
self.fp.seek(offset)
self._mode = "RGB"
self._size = None
# When reading header comments, the first comment is used.
# When reading trailer comments, the last comment is used.
bounding_box: list[int] | None = None
imagedata_size: tuple[int, int] | None = None
byte_arr = bytearray(255)
bytes_mv = memoryview(byte_arr)
@@ -230,7 +208,12 @@ class EpsImageFile(ImageFile.ImageFile):
reading_trailer_comments = False
trailer_reached = False
def check_required_header_comments():
def check_required_header_comments() -> None:
"""
The EPS specification requires that some headers exist.
This should be checked when the header comments formally end,
when image data starts, or when the file ends, whichever comes first.
"""
if "PS-Adobe" not in self.info:
msg = 'EPS header missing "%!PS-Adobe" comment'
raise SyntaxError(msg)
@@ -238,32 +221,28 @@ class EpsImageFile(ImageFile.ImageFile):
msg = 'EPS header missing "%%BoundingBox" comment'
raise SyntaxError(msg)
def _read_comment(s):
nonlocal reading_trailer_comments
def read_comment(s: str) -> bool:
nonlocal bounding_box, reading_trailer_comments
try:
m = split.match(s)
except re.error as e:
msg = "not an EPS file"
raise SyntaxError(msg) from e
if m:
if not m:
return False
k, v = m.group(1, 2)
self.info[k] = v
if k == "BoundingBox":
if v == "(atend)":
reading_trailer_comments = True
elif not self._size or (
trailer_reached and reading_trailer_comments
):
elif not bounding_box or (trailer_reached and reading_trailer_comments):
try:
# Note: The DSC spec says that BoundingBox
# fields should be integers, but some drivers
# put floating point values there anyway.
box = [int(float(i)) for i in v.split()]
self._size = box[2] - box[0], box[3] - box[1]
self.tile = [
("eps", (0, 0) + self.size, offset, (length, box))
]
bounding_box = [int(float(i)) for i in v.split()]
except Exception:
pass
return True
@@ -273,6 +252,8 @@ class EpsImageFile(ImageFile.ImageFile):
if byte == b"":
# if we didn't read a byte we must be at the end of the file
if bytes_read == 0:
if reading_header_comments:
check_required_header_comments()
break
elif byte in b"\r\n":
# if we read a line ending character, ignore it and parse what
@@ -312,11 +293,11 @@ class EpsImageFile(ImageFile.ImageFile):
continue
s = str(bytes_mv[:bytes_read], "latin-1")
if not _read_comment(s):
if not read_comment(s):
m = field.match(s)
if m:
k = m.group(1)
if k[:8] == "PS-Adobe":
if k.startswith("PS-Adobe"):
self.info["PS-Adobe"] = k[9:]
else:
self.info[k] = ""
@@ -331,6 +312,12 @@ class EpsImageFile(ImageFile.ImageFile):
# Check for an "ImageData" descriptor
# https://www.adobe.com/devnet-apps/photoshop/fileformatashtml/#50577413_pgfId-1035096
# If we've already read an "ImageData" descriptor,
# don't read another one.
if imagedata_size:
bytes_read = 0
continue
# Values:
# columns
# rows
@@ -356,29 +343,39 @@ class EpsImageFile(ImageFile.ImageFile):
else:
break
self._size = columns, rows
return
# Parse the columns and rows after checking the bit depth and mode
# in case the bit depth and/or mode are invalid.
imagedata_size = columns, rows
elif bytes_mv[:5] == b"%%EOF":
break
elif trailer_reached and reading_trailer_comments:
# Load EPS trailer
# if this line starts with "%%EOF",
# then we've reached the end of the file
if bytes_mv[:5] == b"%%EOF":
break
s = str(bytes_mv[:bytes_read], "latin-1")
_read_comment(s)
read_comment(s)
elif bytes_mv[:9] == b"%%Trailer":
trailer_reached = True
elif bytes_mv[:14] == b"%%BeginBinary:":
bytecount = int(byte_arr[14:bytes_read])
self.fp.seek(bytecount, os.SEEK_CUR)
bytes_read = 0
check_required_header_comments()
if not self._size:
# A "BoundingBox" is always required,
# even if an "ImageData" descriptor size exists.
if not bounding_box:
msg = "cannot determine EPS bounding box"
raise OSError(msg)
def _find_offset(self, fp):
# An "ImageData" size takes precedence over the "BoundingBox".
self._size = imagedata_size or (
bounding_box[2] - bounding_box[0],
bounding_box[3] - bounding_box[1],
)
self.tile = [
ImageFile._Tile("eps", (0, 0) + self.size, offset, (length, bounding_box))
]
def _find_offset(self, fp: IO[bytes]) -> tuple[int, int]:
s = fp.read(4)
if s == b"%!PS":
@@ -401,7 +398,9 @@ class EpsImageFile(ImageFile.ImageFile):
return length, offset
def load(self, scale=1, transparency=False):
def load(
self, scale: int = 1, transparency: bool = False
) -> Image.core.PixelAccess | None:
# Load EPS via Ghostscript
if self.tile:
self.im = Ghostscript(self.tile, self.size, self.fp, scale, transparency)
@@ -410,7 +409,7 @@ class EpsImageFile(ImageFile.ImageFile):
self.tile = []
return Image.Image.load(self)
def load_seek(self, *args, **kwargs):
def load_seek(self, pos: int) -> None:
# we can't incrementally load, so force ImageFile.parser to
# use our custom load method by defining this method.
pass
@@ -419,7 +418,7 @@ class EpsImageFile(ImageFile.ImageFile):
# --------------------------------------------------------------------
def _save(im, fp, filename, eps=1):
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes, eps: int = 1) -> None:
"""EPS Writer for the Python Imaging Library."""
# make sure image data is available
@@ -460,7 +459,7 @@ def _save(im, fp, filename, eps=1):
if hasattr(fp, "flush"):
fp.flush()
ImageFile._save(im, fp, [("eps", (0, 0) + im.size, 0, None)])
ImageFile._save(im, fp, [ImageFile._Tile("eps", (0, 0) + im.size)])
fp.write(b"\n%%%%EndBinary\n")
fp.write(b"grestore end\n")

View File

@@ -13,6 +13,7 @@
This module provides constants and clear-text names for various
well-known EXIF tags.
"""
from __future__ import annotations
from enum import IntEnum
@@ -302,38 +303,38 @@ TAGS = {
class GPS(IntEnum):
GPSVersionID = 0
GPSLatitudeRef = 1
GPSLatitude = 2
GPSLongitudeRef = 3
GPSLongitude = 4
GPSAltitudeRef = 5
GPSAltitude = 6
GPSTimeStamp = 7
GPSSatellites = 8
GPSStatus = 9
GPSMeasureMode = 10
GPSDOP = 11
GPSSpeedRef = 12
GPSSpeed = 13
GPSTrackRef = 14
GPSTrack = 15
GPSImgDirectionRef = 16
GPSImgDirection = 17
GPSMapDatum = 18
GPSDestLatitudeRef = 19
GPSDestLatitude = 20
GPSDestLongitudeRef = 21
GPSDestLongitude = 22
GPSDestBearingRef = 23
GPSDestBearing = 24
GPSDestDistanceRef = 25
GPSDestDistance = 26
GPSProcessingMethod = 27
GPSAreaInformation = 28
GPSDateStamp = 29
GPSDifferential = 30
GPSHPositioningError = 31
GPSVersionID = 0x00
GPSLatitudeRef = 0x01
GPSLatitude = 0x02
GPSLongitudeRef = 0x03
GPSLongitude = 0x04
GPSAltitudeRef = 0x05
GPSAltitude = 0x06
GPSTimeStamp = 0x07
GPSSatellites = 0x08
GPSStatus = 0x09
GPSMeasureMode = 0x0A
GPSDOP = 0x0B
GPSSpeedRef = 0x0C
GPSSpeed = 0x0D
GPSTrackRef = 0x0E
GPSTrack = 0x0F
GPSImgDirectionRef = 0x10
GPSImgDirection = 0x11
GPSMapDatum = 0x12
GPSDestLatitudeRef = 0x13
GPSDestLatitude = 0x14
GPSDestLongitudeRef = 0x15
GPSDestLongitude = 0x16
GPSDestBearingRef = 0x17
GPSDestBearing = 0x18
GPSDestDistanceRef = 0x19
GPSDestDistance = 0x1A
GPSProcessingMethod = 0x1B
GPSAreaInformation = 0x1C
GPSDateStamp = 0x1D
GPSDifferential = 0x1E
GPSHPositioningError = 0x1F
"""Maps EXIF GPS tags to tag names."""
@@ -341,40 +342,41 @@ GPSTAGS = {i.value: i.name for i in GPS}
class Interop(IntEnum):
InteropIndex = 1
InteropVersion = 2
RelatedImageFileFormat = 4096
RelatedImageWidth = 4097
RleatedImageHeight = 4098
InteropIndex = 0x0001
InteropVersion = 0x0002
RelatedImageFileFormat = 0x1000
RelatedImageWidth = 0x1001
RelatedImageHeight = 0x1002
class IFD(IntEnum):
Exif = 34665
GPSInfo = 34853
Makernote = 37500
Interop = 40965
Exif = 0x8769
GPSInfo = 0x8825
MakerNote = 0x927C
Makernote = 0x927C # Deprecated
Interop = 0xA005
IFD1 = -1
class LightSource(IntEnum):
Unknown = 0
Daylight = 1
Fluorescent = 2
Tungsten = 3
Flash = 4
Fine = 9
Cloudy = 10
Shade = 11
DaylightFluorescent = 12
DayWhiteFluorescent = 13
CoolWhiteFluorescent = 14
WhiteFluorescent = 15
StandardLightA = 17
StandardLightB = 18
StandardLightC = 19
D55 = 20
D65 = 21
D75 = 22
D50 = 23
ISO = 24
Other = 255
Unknown = 0x00
Daylight = 0x01
Fluorescent = 0x02
Tungsten = 0x03
Flash = 0x04
Fine = 0x09
Cloudy = 0x0A
Shade = 0x0B
DaylightFluorescent = 0x0C
DayWhiteFluorescent = 0x0D
CoolWhiteFluorescent = 0x0E
WhiteFluorescent = 0x0F
StandardLightA = 0x11
StandardLightB = 0x12
StandardLightC = 0x13
D55 = 0x14
D65 = 0x15
D75 = 0x16
D50 = 0x17
ISO = 0x18
Other = 0xFF

View File

@@ -8,30 +8,52 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import gzip
import math
from . import Image, ImageFile
def _accept(prefix):
return prefix[:6] == b"SIMPLE"
def _accept(prefix: bytes) -> bool:
return prefix.startswith(b"SIMPLE")
class FitsImageFile(ImageFile.ImageFile):
format = "FITS"
format_description = "FITS"
def _open(self):
headers = {}
def _open(self) -> None:
assert self.fp is not None
headers: dict[bytes, bytes] = {}
header_in_progress = False
decoder_name = ""
while True:
header = self.fp.read(80)
if not header:
msg = "Truncated FITS file"
raise OSError(msg)
keyword = header[:8].strip()
if keyword == b"END":
if keyword in (b"SIMPLE", b"XTENSION"):
header_in_progress = True
elif headers and not header_in_progress:
# This is now a data unit
break
elif keyword == b"END":
# Seek to the end of the header unit
self.fp.seek(math.ceil(self.fp.tell() / 2880) * 2880)
if not decoder_name:
decoder_name, offset, args = self._parse_headers(headers)
header_in_progress = False
continue
if decoder_name:
# Keep going to read past the headers
continue
value = header[8:].split(b"/")[0].strip()
if value.startswith(b"="):
value = value[1:].strip()
@@ -40,34 +62,91 @@ class FitsImageFile(ImageFile.ImageFile):
raise SyntaxError(msg)
headers[keyword] = value
naxis = int(headers[b"NAXIS"])
if naxis == 0:
if not decoder_name:
msg = "No image data"
raise ValueError(msg)
elif naxis == 1:
self._size = 1, int(headers[b"NAXIS1"])
else:
self._size = int(headers[b"NAXIS1"]), int(headers[b"NAXIS2"])
offset += self.fp.tell() - 80
self.tile = [ImageFile._Tile(decoder_name, (0, 0) + self.size, offset, args)]
def _get_size(
self, headers: dict[bytes, bytes], prefix: bytes
) -> tuple[int, int] | None:
naxis = int(headers[prefix + b"NAXIS"])
if naxis == 0:
return None
if naxis == 1:
return 1, int(headers[prefix + b"NAXIS1"])
else:
return int(headers[prefix + b"NAXIS1"]), int(headers[prefix + b"NAXIS2"])
def _parse_headers(
self, headers: dict[bytes, bytes]
) -> tuple[str, int, tuple[str | int, ...]]:
prefix = b""
decoder_name = "raw"
offset = 0
if (
headers.get(b"XTENSION") == b"'BINTABLE'"
and headers.get(b"ZIMAGE") == b"T"
and headers[b"ZCMPTYPE"] == b"'GZIP_1 '"
):
no_prefix_size = self._get_size(headers, prefix) or (0, 0)
number_of_bits = int(headers[b"BITPIX"])
offset = no_prefix_size[0] * no_prefix_size[1] * (number_of_bits // 8)
prefix = b"Z"
decoder_name = "fits_gzip"
size = self._get_size(headers, prefix)
if not size:
return "", 0, ()
self._size = size
number_of_bits = int(headers[prefix + b"BITPIX"])
if number_of_bits == 8:
self._mode = "L"
elif number_of_bits == 16:
self._mode = "I"
# rawmode = "I;16S"
self._mode = "I;16"
elif number_of_bits == 32:
self._mode = "I"
elif number_of_bits in (-32, -64):
self._mode = "F"
# rawmode = "F" if number_of_bits == -32 else "F;64F"
offset = math.ceil(self.fp.tell() / 2880) * 2880
self.tile = [("raw", (0, 0) + self.size, offset, (self.mode, 0, -1))]
args: tuple[str | int, ...]
if decoder_name == "raw":
args = (self.mode, 0, -1)
else:
args = (number_of_bits,)
return decoder_name, offset, args
class FitsGzipDecoder(ImageFile.PyDecoder):
_pulls_fd = True
def decode(self, buffer: bytes | Image.SupportsArrayInterface) -> tuple[int, int]:
assert self.fd is not None
value = gzip.decompress(self.fd.read())
rows = []
offset = 0
number_of_bits = min(self.args[0] // 8, 4)
for y in range(self.state.ysize):
row = bytearray()
for x in range(self.state.xsize):
row += value[offset + (4 - number_of_bits) : offset + 4]
offset += 4
rows.append(row)
self.set_as_raw(bytes([pixel for row in rows[::-1] for pixel in row]))
return -1, 0
# --------------------------------------------------------------------
# Registry
Image.register_open(FitsImageFile.format, FitsImageFile, _accept)
Image.register_decoder("fits_gzip", FitsGzipDecoder)
Image.register_extensions(FitsImageFile.format, [".fit", ".fits"])

View File

@@ -14,6 +14,7 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import os
@@ -21,14 +22,15 @@ from . import Image, ImageFile, ImagePalette
from ._binary import i16le as i16
from ._binary import i32le as i32
from ._binary import o8
from ._util import DeferredError
#
# decoder
def _accept(prefix):
def _accept(prefix: bytes) -> bool:
return (
len(prefix) >= 6
len(prefix) >= 16
and i16(prefix, 4) in [0xAF11, 0xAF12]
and i16(prefix, 14) in [0, 3] # flags
)
@@ -44,10 +46,16 @@ class FliImageFile(ImageFile.ImageFile):
format_description = "Autodesk FLI/FLC Animation"
_close_exclusive_fp_after_loading = False
def _open(self):
def _open(self) -> None:
# HEAD
assert self.fp is not None
s = self.fp.read(128)
if not (_accept(s) and s[20:22] == b"\x00\x00"):
if not (
_accept(s)
and s[20:22] == b"\x00" * 2
and s[42:80] == b"\x00" * 38
and s[88:] == b"\x00" * 40
):
msg = "not an FLI/FLC file"
raise SyntaxError(msg)
@@ -75,13 +83,13 @@ class FliImageFile(ImageFile.ImageFile):
if i16(s, 4) == 0xF100:
# prefix chunk; ignore it
self.__offset = self.__offset + i32(s)
self.fp.seek(self.__offset + i32(s))
s = self.fp.read(16)
if i16(s, 4) == 0xF1FA:
# look for palette chunk
number_of_subchunks = i16(s, 6)
chunk_size = None
chunk_size: int | None = None
for _ in range(number_of_subchunks):
if chunk_size is not None:
self.fp.seek(chunk_size - 6, os.SEEK_CUR)
@@ -94,8 +102,9 @@ class FliImageFile(ImageFile.ImageFile):
if not chunk_size:
break
palette = [o8(r) + o8(g) + o8(b) for (r, g, b) in palette]
self.palette = ImagePalette.raw("RGB", b"".join(palette))
self.palette = ImagePalette.raw(
"RGB", b"".join(o8(r) + o8(g) + o8(b) for (r, g, b) in palette)
)
# set things up to decode first frame
self.__frame = -1
@@ -103,10 +112,11 @@ class FliImageFile(ImageFile.ImageFile):
self.__rewind = self.fp.tell()
self.seek(0)
def _palette(self, palette, shift):
def _palette(self, palette: list[tuple[int, int, int]], shift: int) -> None:
# load palette
i = 0
assert self.fp is not None
for e in range(i16(self.fp.read(2))):
s = self.fp.read(2)
i = i + s[0]
@@ -121,7 +131,7 @@ class FliImageFile(ImageFile.ImageFile):
palette[i] = (r, g, b)
i += 1
def seek(self, frame):
def seek(self, frame: int) -> None:
if not self._seek_check(frame):
return
if frame < self.__frame:
@@ -130,7 +140,9 @@ class FliImageFile(ImageFile.ImageFile):
for f in range(self.__frame + 1, frame + 1):
self._seek(f)
def _seek(self, frame):
def _seek(self, frame: int) -> None:
if isinstance(self._fp, DeferredError):
raise self._fp.ex
if frame == 0:
self.__frame = -1
self._fp.seek(self.__rewind)
@@ -150,16 +162,17 @@ class FliImageFile(ImageFile.ImageFile):
s = self.fp.read(4)
if not s:
raise EOFError
msg = "missing frame size"
raise EOFError(msg)
framesize = i32(s)
self.decodermaxblock = framesize
self.tile = [("fli", (0, 0) + self.size, self.__offset, None)]
self.tile = [ImageFile._Tile("fli", (0, 0) + self.size, self.__offset)]
self.__offset += framesize
def tell(self):
def tell(self) -> int:
return self.__frame

View File

@@ -13,16 +13,19 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import os
from typing import BinaryIO
from . import Image, _binary
WIDTH = 800
def puti16(fp, values):
def puti16(
fp: BinaryIO, values: tuple[int, int, int, int, int, int, int, int, int, int]
) -> None:
"""Write network order (big-endian) 16-bit sequence"""
for v in values:
if v < 0:
@@ -33,16 +36,32 @@ def puti16(fp, values):
class FontFile:
"""Base class for raster font file handlers."""
bitmap = None
bitmap: Image.Image | None = None
def __init__(self):
self.info = {}
self.glyph = [None] * 256
def __init__(self) -> None:
self.info: dict[bytes, bytes | int] = {}
self.glyph: list[
tuple[
tuple[int, int],
tuple[int, int, int, int],
tuple[int, int, int, int],
Image.Image,
]
| None
] = [None] * 256
def __getitem__(self, ix):
def __getitem__(self, ix: int) -> (
tuple[
tuple[int, int],
tuple[int, int, int, int],
tuple[int, int, int, int],
Image.Image,
]
| None
):
return self.glyph[ix]
def compile(self):
def compile(self) -> None:
"""Create metrics and bitmap"""
if self.bitmap:
@@ -51,7 +70,7 @@ class FontFile:
# create bitmap large enough to hold all data
h = w = maxwidth = 0
lines = 1
for glyph in self:
for glyph in self.glyph:
if glyph:
d, dst, src, im = glyph
h = max(h, src[3] - src[1])
@@ -65,20 +84,22 @@ class FontFile:
ysize = lines * h
if xsize == 0 and ysize == 0:
return ""
return
self.ysize = h
# paste glyphs into bitmap
self.bitmap = Image.new("1", (xsize, ysize))
self.metrics = [None] * 256
self.metrics: list[
tuple[tuple[int, int], tuple[int, int, int, int], tuple[int, int, int, int]]
| None
] = [None] * 256
x = y = 0
for i in range(256):
glyph = self[i]
if glyph:
d, dst, src, im = glyph
xx = src[2] - src[0]
# yy = src[3] - src[1]
x0, y0 = x, y
x = x + xx
if x > WIDTH:
@@ -89,12 +110,15 @@ class FontFile:
self.bitmap.paste(im.crop(src), s)
self.metrics[i] = d, dst, s
def save(self, filename):
def save(self, filename: str) -> None:
"""Save font"""
self.compile()
# font data
if not self.bitmap:
msg = "No bitmap created"
raise ValueError(msg)
self.bitmap.save(os.path.splitext(filename)[0] + ".pbm", "PNG")
# font metrics
@@ -105,6 +129,6 @@ class FontFile:
for id in range(256):
m = self.metrics[id]
if not m:
puti16(fp, [0] * 10)
puti16(fp, (0,) * 10)
else:
puti16(fp, m[0] + m[1] + m[2])

View File

@@ -14,6 +14,8 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import olefile
from . import Image, ImageFile
@@ -39,8 +41,8 @@ MODES = {
# --------------------------------------------------------------------
def _accept(prefix):
return prefix[:8] == olefile.MAGIC
def _accept(prefix: bytes) -> bool:
return prefix.startswith(olefile.MAGIC)
##
@@ -51,7 +53,7 @@ class FpxImageFile(ImageFile.ImageFile):
format = "FPX"
format_description = "FlashPix"
def _open(self):
def _open(self) -> None:
#
# read the OLE directory and see if this is a likely
# to be a FlashPix file
@@ -62,13 +64,14 @@ class FpxImageFile(ImageFile.ImageFile):
msg = "not an FPX file; invalid OLE file"
raise SyntaxError(msg) from e
if self.ole.root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B":
root = self.ole.root
if not root or root.clsid != "56616700-C154-11CE-8553-00AA00A1F95B":
msg = "not an FPX file; bad root CLSID"
raise SyntaxError(msg)
self._open_index(1)
def _open_index(self, index=1):
def _open_index(self, index: int = 1) -> None:
#
# get the Image Contents Property Set
@@ -78,12 +81,14 @@ class FpxImageFile(ImageFile.ImageFile):
# size (highest resolution)
assert isinstance(prop[0x1000002], int)
assert isinstance(prop[0x1000003], int)
self._size = prop[0x1000002], prop[0x1000003]
size = max(self.size)
i = 1
while size > 64:
size = size / 2
size = size // 2
i += 1
self.maxid = i - 1
@@ -97,16 +102,14 @@ class FpxImageFile(ImageFile.ImageFile):
s = prop[0x2000002 | id]
colors = []
bands = i32(s, 4)
if bands > 4:
if not isinstance(s, bytes) or (bands := i32(s, 4)) > 4:
msg = "Invalid number of bands"
raise OSError(msg)
for i in range(bands):
# note: for now, we ignore the "uncalibrated" flag
colors.append(i32(s, 8 + i * 4) & 0x7FFFFFFF)
self._mode, self.rawmode = MODES[tuple(colors)]
# note: for now, we ignore the "uncalibrated" flag
colors = tuple(i32(s, 8 + i * 4) & 0x7FFFFFFF for i in range(bands))
self._mode, self.rawmode = MODES[colors]
# load JPEG tables, if any
self.jpeg = {}
@@ -117,7 +120,7 @@ class FpxImageFile(ImageFile.ImageFile):
self._open_subimage(1, self.maxid)
def _open_subimage(self, index=1, subimage=0):
def _open_subimage(self, index: int = 1, subimage: int = 0) -> None:
#
# setup tile descriptors for a given subimage
@@ -163,18 +166,18 @@ class FpxImageFile(ImageFile.ImageFile):
if compression == 0:
self.tile.append(
(
ImageFile._Tile(
"raw",
(x, y, x1, y1),
i32(s, i) + 28,
(self.rawmode,),
self.rawmode,
)
)
elif compression == 1:
# FIXME: the fill decoder is not implemented
self.tile.append(
(
ImageFile._Tile(
"fill",
(x, y, x1, y1),
i32(s, i) + 28,
@@ -202,7 +205,7 @@ class FpxImageFile(ImageFile.ImageFile):
jpegmode = rawmode
self.tile.append(
(
ImageFile._Tile(
"jpeg",
(x, y, x1, y1),
i32(s, i) + 28,
@@ -227,19 +230,20 @@ class FpxImageFile(ImageFile.ImageFile):
break # isn't really required
self.stream = stream
self._fp = self.fp
self.fp = None
def load(self):
def load(self) -> Image.core.PixelAccess | None:
if not self.fp:
self.fp = self.ole.openstream(self.stream[:2] + ["Subimage 0000 Data"])
return ImageFile.ImageFile.load(self)
def close(self):
def close(self) -> None:
self.ole.close()
super().close()
def __exit__(self, *args):
def __exit__(self, *args: object) -> None:
self.ole.close()
super().__exit__()

View File

@@ -51,6 +51,8 @@ bytes for that mipmap level.
Note: All data is stored in little-Endian (Intel) byte order.
"""
from __future__ import annotations
import struct
from enum import IntEnum
from io import BytesIO
@@ -69,7 +71,7 @@ class FtexImageFile(ImageFile.ImageFile):
format = "FTEX"
format_description = "Texture File Format (IW2:EOC)"
def _open(self):
def _open(self) -> None:
if not _accept(self.fp.read(4)):
msg = "not an FTEX file"
raise SyntaxError(msg)
@@ -77,8 +79,6 @@ class FtexImageFile(ImageFile.ImageFile):
self._size = struct.unpack("<2i", self.fp.read(8))
mipmap_count, format_count = struct.unpack("<2i", self.fp.read(8))
self._mode = "RGB"
# Only support single-format files.
# I don't know of any multi-format file.
assert format_count == 1
@@ -91,9 +91,10 @@ class FtexImageFile(ImageFile.ImageFile):
if format == Format.DXT1:
self._mode = "RGBA"
self.tile = [("bcn", (0, 0) + self.size, 0, 1)]
self.tile = [ImageFile._Tile("bcn", (0, 0) + self.size, 0, (1,))]
elif format == Format.UNCOMPRESSED:
self.tile = [("raw", (0, 0) + self.size, 0, ("RGB", 0, 1))]
self._mode = "RGB"
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, "RGB")]
else:
msg = f"Invalid texture compression format: {repr(format)}"
raise ValueError(msg)
@@ -101,12 +102,12 @@ class FtexImageFile(ImageFile.ImageFile):
self.fp.close()
self.fp = BytesIO(data)
def load_seek(self, pos):
def load_seek(self, pos: int) -> None:
pass
def _accept(prefix):
return prefix[:4] == MAGIC
def _accept(prefix: bytes) -> bool:
return prefix.startswith(MAGIC)
Image.register_open(FtexImageFile.format, FtexImageFile, _accept)

View File

@@ -23,12 +23,13 @@
# Version 2 files are saved by GIMP v2.8 (at least)
# Version 3 files have a format specifier of 18 for 16bit floats in
# the color depth field. This is currently unsupported by Pillow.
from __future__ import annotations
from . import Image, ImageFile
from ._binary import i32be as i32
def _accept(prefix):
def _accept(prefix: bytes) -> bool:
return len(prefix) >= 8 and i32(prefix, 0) >= 20 and i32(prefix, 4) in (1, 2)
@@ -40,7 +41,7 @@ class GbrImageFile(ImageFile.ImageFile):
format = "GBR"
format_description = "GIMP brush file"
def _open(self):
def _open(self) -> None:
header_size = i32(self.fp.read(4))
if header_size < 20:
msg = "not a GIMP brush"
@@ -53,7 +54,7 @@ class GbrImageFile(ImageFile.ImageFile):
width = i32(self.fp.read(4))
height = i32(self.fp.read(4))
color_depth = i32(self.fp.read(4))
if width <= 0 or height <= 0:
if width == 0 or height == 0:
msg = "not a GIMP brush"
raise SyntaxError(msg)
if color_depth not in (1, 4):
@@ -70,7 +71,7 @@ class GbrImageFile(ImageFile.ImageFile):
raise SyntaxError(msg)
self.info["spacing"] = i32(self.fp.read(4))
comment = self.fp.read(comment_length)[:-1]
self.info["comment"] = self.fp.read(comment_length)[:-1]
if color_depth == 1:
self._mode = "L"
@@ -79,16 +80,14 @@ class GbrImageFile(ImageFile.ImageFile):
self._size = width, height
self.info["comment"] = comment
# Image might not be small
Image._decompression_bomb_check(self.size)
# Data is an uncompressed block of w * h * bytes/pixel
self._data_size = width * height * color_depth
def load(self):
if not self.im:
def load(self) -> Image.core.PixelAccess | None:
if self._im is None:
self.im = Image.core.new(self.mode, self.size)
self.frombytes(self.fp.read(self._data_size))
return Image.Image.load(self)

View File

@@ -25,11 +25,14 @@
implementation is provided for convenience and demonstrational
purposes only.
"""
from __future__ import annotations
from typing import IO
from . import ImageFile, ImagePalette, UnidentifiedImageError
from ._binary import i16be as i16
from ._binary import i32be as i32
from ._typing import StrOrBytesPath
class GdImageFile(ImageFile.ImageFile):
@@ -43,15 +46,17 @@ class GdImageFile(ImageFile.ImageFile):
format = "GD"
format_description = "GD uncompressed images"
def _open(self):
def _open(self) -> None:
# Header
assert self.fp is not None
s = self.fp.read(1037)
if i16(s) not in [65534, 65535]:
msg = "Not a valid GD 2.x .gd file"
raise SyntaxError(msg)
self._mode = "L" # FIXME: "P"
self._mode = "P"
self._size = i16(s, 2), i16(s, 4)
true_color = s[6]
@@ -63,20 +68,20 @@ class GdImageFile(ImageFile.ImageFile):
self.info["transparency"] = tindex
self.palette = ImagePalette.raw(
"XBGR", s[7 + true_color_offset + 4 : 7 + true_color_offset + 4 + 256 * 4]
"RGBX", s[7 + true_color_offset + 6 : 7 + true_color_offset + 6 + 256 * 4]
)
self.tile = [
(
ImageFile._Tile(
"raw",
(0, 0) + self.size,
7 + true_color_offset + 4 + 256 * 4,
("L", 0, 1),
7 + true_color_offset + 6 + 256 * 4,
"L",
)
]
def open(fp, mode="r"):
def open(fp: StrOrBytesPath | IO[bytes], mode: str = "r") -> GdImageFile:
"""
Load texture from a GD image file.

View File

@@ -23,17 +23,36 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import itertools
import math
import os
import subprocess
from enum import IntEnum
from functools import cached_property
from typing import Any, NamedTuple, cast
from . import Image, ImageChops, ImageFile, ImagePalette, ImageSequence
from . import (
Image,
ImageChops,
ImageFile,
ImageMath,
ImageOps,
ImagePalette,
ImageSequence,
)
from ._binary import i16le as i16
from ._binary import o8
from ._binary import o16le as o16
from ._util import DeferredError
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import IO, Literal
from . import _imaging
from ._typing import Buffer
class LoadingStrategy(IntEnum):
@@ -51,8 +70,8 @@ LOADING_STRATEGY = LoadingStrategy.RGB_AFTER_FIRST
# Identify/read GIF files
def _accept(prefix):
return prefix[:6] in [b"GIF87a", b"GIF89a"]
def _accept(prefix: bytes) -> bool:
return prefix.startswith((b"GIF87a", b"GIF89a"))
##
@@ -67,19 +86,19 @@ class GifImageFile(ImageFile.ImageFile):
global_palette = None
def data(self):
def data(self) -> bytes | None:
s = self.fp.read(1)
if s and s[0]:
return self.fp.read(s[0])
return None
def _is_palette_needed(self, p):
def _is_palette_needed(self, p: bytes) -> bool:
for i in range(0, len(p), 3):
if not (i // 3 == p[i] == p[i + 1] == p[i + 2]):
return True
return False
def _open(self):
def _open(self) -> None:
# Screen
s = self.fp.read(13)
if not _accept(s):
@@ -88,7 +107,6 @@ class GifImageFile(ImageFile.ImageFile):
self.info["version"] = s[:6]
self._size = i16(s, 6), i16(s, 8)
self.tile = []
flags = s[10]
bits = (flags & 7) + 1
@@ -103,12 +121,11 @@ class GifImageFile(ImageFile.ImageFile):
self._fp = self.fp # FIXME: hack
self.__rewind = self.fp.tell()
self._n_frames = None
self._is_animated = None
self._n_frames: int | None = None
self._seek(0) # get ready to read first frame
@property
def n_frames(self):
def n_frames(self) -> int:
if self._n_frames is None:
current = self.tell()
try:
@@ -119,30 +136,29 @@ class GifImageFile(ImageFile.ImageFile):
self.seek(current)
return self._n_frames
@property
def is_animated(self):
if self._is_animated is None:
@cached_property
def is_animated(self) -> bool:
if self._n_frames is not None:
self._is_animated = self._n_frames != 1
else:
return self._n_frames != 1
current = self.tell()
if current:
self._is_animated = True
else:
return True
try:
self._seek(1, False)
self._is_animated = True
is_animated = True
except EOFError:
self._is_animated = False
is_animated = False
self.seek(current)
return self._is_animated
return is_animated
def seek(self, frame):
def seek(self, frame: int) -> None:
if not self._seek_check(frame):
return
if frame < self.__frame:
self.im = None
self._im = None
self._seek(0)
last_frame = self.__frame
@@ -154,11 +170,13 @@ class GifImageFile(ImageFile.ImageFile):
msg = "no more images in GIF file"
raise EOFError(msg) from e
def _seek(self, frame, update_image=True):
def _seek(self, frame: int, update_image: bool = True) -> None:
if isinstance(self._fp, DeferredError):
raise self._fp.ex
if frame == 0:
# rewind
self.__offset = 0
self.dispose = None
self.dispose: _imaging.ImagingCore | None = None
self.__frame = -1
self._fp.seek(self.__rewind)
self.disposal_method = 0
@@ -183,11 +201,12 @@ class GifImageFile(ImageFile.ImageFile):
s = self.fp.read(1)
if not s or s == b";":
raise EOFError
msg = "no more images in GIF file"
raise EOFError(msg)
palette = None
palette: ImagePalette.ImagePalette | Literal[False] | None = None
info = {}
info: dict[str, Any] = {}
frame_transparency = None
interlace = None
frame_dispose_extent = None
@@ -203,7 +222,7 @@ class GifImageFile(ImageFile.ImageFile):
#
s = self.fp.read(1)
block = self.data()
if s[0] == 249:
if s[0] == 249 and block is not None:
#
# graphic control extension
#
@@ -239,14 +258,14 @@ class GifImageFile(ImageFile.ImageFile):
info["comment"] = comment
s = None
continue
elif s[0] == 255 and frame == 0:
elif s[0] == 255 and frame == 0 and block is not None:
#
# application extension
#
info["extension"] = block, self.fp.tell()
if block[:11] == b"NETSCAPE2.0":
if block.startswith(b"NETSCAPE2.0"):
block = self.data()
if len(block) >= 3 and block[0] == 1:
if block and len(block) >= 3 and block[0] == 1:
self.info["loop"] = i16(block, 1)
while self.data():
pass
@@ -280,15 +299,11 @@ class GifImageFile(ImageFile.ImageFile):
bits = self.fp.read(1)[0]
self.__offset = self.fp.tell()
break
else:
pass
# raise OSError, "illegal GIF tag `%x`" % s[0]
s = None
if interlace is None:
# self._fp = None
raise EOFError
msg = "image not found in GIF frame"
raise EOFError(msg)
self.__frame = frame
if not update_image:
@@ -310,18 +325,20 @@ class GifImageFile(ImageFile.ImageFile):
else:
self._mode = "L"
if not palette and self.global_palette:
if palette:
self.palette = palette
elif self.global_palette:
from copy import copy
palette = copy(self.global_palette)
self.palette = palette
self.palette = copy(self.global_palette)
else:
self.palette = None
else:
if self.mode == "P":
if (
LOADING_STRATEGY != LoadingStrategy.RGB_AFTER_DIFFERENT_PALETTE_ONLY
or palette
):
self.pyaccess = None
if "transparency" in self.info:
self.im.putpalettealpha(self.info["transparency"], 0)
self.im = self.im.convert("RGBA", Image.Dither.FLOYDSTEINBERG)
@@ -331,19 +348,22 @@ class GifImageFile(ImageFile.ImageFile):
self._mode = "RGB"
self.im = self.im.convert("RGB", Image.Dither.FLOYDSTEINBERG)
def _rgb(color):
def _rgb(color: int) -> tuple[int, int, int]:
if self._frame_palette:
color = tuple(self._frame_palette.palette[color * 3 : color * 3 + 3])
if color * 3 + 3 > len(self._frame_palette.palette):
color = 0
return cast(
tuple[int, int, int],
tuple(self._frame_palette.palette[color * 3 : color * 3 + 3]),
)
else:
color = (color, color, color)
return color
return (color, color, color)
self.dispose_extent = frame_dispose_extent
try:
if self.disposal_method < 2:
# do not dispose or none specified
self.dispose = None
elif self.disposal_method == 2:
self.dispose_extent: tuple[int, int, int, int] | None = frame_dispose_extent
if self.dispose_extent and self.disposal_method >= 2:
try:
if self.disposal_method == 2:
# replace with background colour
# only dispose the extent in this frame
@@ -367,7 +387,7 @@ class GifImageFile(ImageFile.ImageFile):
self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
else:
# replace with previous contents
if self.im is not None:
if self._im is not None:
# only dispose the extent in this frame
self.dispose = self._crop(self.im, self.dispose_extent)
elif frame_transparency is not None:
@@ -380,7 +400,9 @@ class GifImageFile(ImageFile.ImageFile):
if self.mode in ("RGB", "RGBA"):
dispose_mode = "RGBA"
color = _rgb(frame_transparency) + (0,)
self.dispose = Image.core.fill(dispose_mode, dispose_size, color)
self.dispose = Image.core.fill(
dispose_mode, dispose_size, color
)
except AttributeError:
pass
@@ -393,7 +415,7 @@ class GifImageFile(ImageFile.ImageFile):
elif self.mode not in ("RGB", "RGBA"):
transparency = frame_transparency
self.tile = [
(
ImageFile._Tile(
"gif",
(x0, y0, x1, y1),
self.__offset,
@@ -409,7 +431,7 @@ class GifImageFile(ImageFile.ImageFile):
elif k in self.info:
del self.info[k]
def load_prepare(self):
def load_prepare(self) -> None:
temp_mode = "P" if self._frame_palette else "L"
self._prev_im = None
if self.__frame == 0:
@@ -421,15 +443,22 @@ class GifImageFile(ImageFile.ImageFile):
self._prev_im = self.im
if self._frame_palette:
self.im = Image.core.fill("P", self.size, self._frame_transparency or 0)
self.im.putpalette(*self._frame_palette.getdata())
self.im.putpalette("RGB", *self._frame_palette.getdata())
else:
self.im = None
self._im = None
if not self._prev_im and self._im is not None and self.size != self.im.size:
expanded_im = Image.core.fill(self.im.mode, self.size)
if self._frame_palette:
expanded_im.putpalette("RGB", *self._frame_palette.getdata())
expanded_im.paste(self.im, (0, 0) + self.im.size)
self.im = expanded_im
self._mode = temp_mode
self._frame_palette = None
super().load_prepare()
def load_end(self):
def load_end(self) -> None:
if self.__frame == 0:
if self.mode == "P" and LOADING_STRATEGY == LoadingStrategy.RGB_ALWAYS:
if self._frame_transparency is not None:
@@ -441,21 +470,37 @@ class GifImageFile(ImageFile.ImageFile):
return
if not self._prev_im:
return
if self.size != self._prev_im.size:
if self._frame_transparency is not None:
expanded_im = Image.core.fill("RGBA", self.size)
else:
expanded_im = Image.core.fill("P", self.size)
expanded_im.putpalette("RGB", "RGB", self.im.getpalette())
expanded_im = expanded_im.convert("RGB")
expanded_im.paste(self._prev_im, (0, 0) + self._prev_im.size)
self._prev_im = expanded_im
assert self._prev_im is not None
if self._frame_transparency is not None:
if self.mode == "L":
frame_im = self.im.convert_transparent("LA", self._frame_transparency)
else:
self.im.putpalettealpha(self._frame_transparency, 0)
frame_im = self.im.convert("RGBA")
else:
frame_im = self.im.convert("RGB")
assert self.dispose_extent is not None
frame_im = self._crop(frame_im, self.dispose_extent)
self.im = self._prev_im
self._mode = self.im.mode
if frame_im.mode == "RGBA":
if frame_im.mode in ("LA", "RGBA"):
self.im.paste(frame_im, self.dispose_extent, frame_im)
else:
self.im.paste(frame_im, self.dispose_extent)
def tell(self):
def tell(self) -> int:
return self.__frame
@@ -466,7 +511,7 @@ class GifImageFile(ImageFile.ImageFile):
RAWMODE = {"1": "L", "L": "L", "P": "P"}
def _normalize_mode(im):
def _normalize_mode(im: Image.Image) -> Image.Image:
"""
Takes an image (or frame), returns an image in a mode that is appropriate
for saving in a Gif.
@@ -482,6 +527,7 @@ def _normalize_mode(im):
return im
if Image.getmodebase(im.mode) == "RGB":
im = im.convert("P", palette=Image.Palette.ADAPTIVE)
assert im.palette is not None
if im.palette.mode == "RGBA":
for rgba in im.palette.colors:
if rgba[3] == 0:
@@ -491,7 +537,12 @@ def _normalize_mode(im):
return im.convert("L")
def _normalize_palette(im, palette, info):
_Palette = bytes | bytearray | list[int] | ImagePalette.ImagePalette
def _normalize_palette(
im: Image.Image, palette: _Palette | None, info: dict[str, Any]
) -> Image.Image:
"""
Normalizes the palette for image.
- Sets the palette to the incoming palette, if provided.
@@ -513,14 +564,18 @@ def _normalize_palette(im, palette, info):
if im.mode == "P":
if not source_palette:
source_palette = im.im.getpalette("RGB")[:768]
im_palette = im.getpalette(None)
assert im_palette is not None
source_palette = bytearray(im_palette)
else: # L-mode
if not source_palette:
source_palette = bytearray(i // 3 for i in range(768))
im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette)
assert source_palette is not None
if palette:
used_palette_colors = []
used_palette_colors: list[int | None] = []
assert im.palette is not None
for i in range(0, len(source_palette), 3):
source_color = tuple(source_palette[i : i + 3])
index = im.palette.colors.get(source_color)
@@ -533,19 +588,37 @@ def _normalize_palette(im, palette, info):
if j not in used_palette_colors:
used_palette_colors[i] = j
break
im = im.remap_palette(used_palette_colors)
dest_map: list[int] = []
for index in used_palette_colors:
assert index is not None
dest_map.append(index)
im = im.remap_palette(dest_map)
else:
used_palette_colors = _get_optimize(im, info)
if used_palette_colors is not None:
return im.remap_palette(used_palette_colors, source_palette)
optimized_palette_colors = _get_optimize(im, info)
if optimized_palette_colors is not None:
im = im.remap_palette(optimized_palette_colors, source_palette)
if "transparency" in info:
try:
info["transparency"] = optimized_palette_colors.index(
info["transparency"]
)
except ValueError:
del info["transparency"]
return im
assert im.palette is not None
im.palette.palette = source_palette
return im
def _write_single_frame(im, fp, palette):
def _write_single_frame(
im: Image.Image,
fp: IO[bytes],
palette: _Palette | None,
) -> None:
im_out = _normalize_mode(im)
for k, v in im_out.info.items():
if isinstance(k, str):
im.encoderinfo.setdefault(k, v)
im_out = _normalize_palette(im_out, palette, im.encoderinfo)
@@ -559,26 +632,40 @@ def _write_single_frame(im, fp, palette):
_write_local_header(fp, im, (0, 0), flags)
im_out.encoderconfig = (8, get_interlace(im))
ImageFile._save(im_out, fp, [("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])])
ImageFile._save(
im_out, fp, [ImageFile._Tile("gif", (0, 0) + im.size, 0, RAWMODE[im_out.mode])]
)
fp.write(b"\0") # end of image data
def _getbbox(base_im, im_frame):
if _get_palette_bytes(im_frame) == _get_palette_bytes(base_im):
def _getbbox(
base_im: Image.Image, im_frame: Image.Image
) -> tuple[Image.Image, tuple[int, int, int, int] | None]:
palette_bytes = [
bytes(im.palette.palette) if im.palette else b"" for im in (base_im, im_frame)
]
if palette_bytes[0] != palette_bytes[1]:
im_frame = im_frame.convert("RGBA")
base_im = base_im.convert("RGBA")
delta = ImageChops.subtract_modulo(im_frame, base_im)
else:
delta = ImageChops.subtract_modulo(
im_frame.convert("RGBA"), base_im.convert("RGBA")
)
return delta.getbbox(alpha_only=False)
return delta, delta.getbbox(alpha_only=False)
def _write_multiple_frames(im, fp, palette):
class _Frame(NamedTuple):
im: Image.Image
bbox: tuple[int, int, int, int] | None
encoderinfo: dict[str, Any]
def _write_multiple_frames(
im: Image.Image, fp: IO[bytes], palette: _Palette | None
) -> bool:
duration = im.encoderinfo.get("duration")
disposal = im.encoderinfo.get("disposal", im.info.get("disposal"))
im_frames = []
im_frames: list[_Frame] = []
previous_im: Image.Image | None = None
frame_count = 0
background_im = None
for imSequence in itertools.chain([im], im.encoderinfo.get("append_images", [])):
@@ -589,12 +676,13 @@ def _write_multiple_frames(im, fp, palette):
for k, v in im_frame.info.items():
if k == "transparency":
continue
if isinstance(k, str):
im.encoderinfo.setdefault(k, v)
encoderinfo = im.encoderinfo.copy()
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
if "transparency" in im_frame.info:
encoderinfo.setdefault("transparency", im_frame.info["transparency"])
im_frame = _normalize_palette(im_frame, palette, encoderinfo)
if isinstance(duration, (list, tuple)):
encoderinfo["duration"] = duration[frame_count]
elif duration is None and "duration" in im_frame.info:
@@ -603,63 +691,116 @@ def _write_multiple_frames(im, fp, palette):
encoderinfo["disposal"] = disposal[frame_count]
frame_count += 1
if im_frames:
diff_frame = None
if im_frames and previous_im:
# delta frame
previous = im_frames[-1]
bbox = _getbbox(previous["im"], im_frame)
delta, bbox = _getbbox(previous_im, im_frame)
if not bbox:
# This frame is identical to the previous frame
if encoderinfo.get("duration"):
previous["encoderinfo"]["duration"] += encoderinfo["duration"]
im_frames[-1].encoderinfo["duration"] += encoderinfo["duration"]
continue
if encoderinfo.get("disposal") == 2:
if background_im is None:
if im_frames[-1].encoderinfo.get("disposal") == 2:
# To appear correctly in viewers using a convention,
# only consider transparency, and not background color
color = im.encoderinfo.get(
"transparency", im.info.get("transparency", (0, 0, 0))
"transparency", im.info.get("transparency")
)
if color is not None:
if background_im is None:
background = _get_background(im_frame, color)
background_im = Image.new("P", im_frame.size, background)
background_im.putpalette(im_frames[0]["im"].palette)
bbox = _getbbox(background_im, im_frame)
first_palette = im_frames[0].im.palette
assert first_palette is not None
background_im.putpalette(first_palette, first_palette.mode)
bbox = _getbbox(background_im, im_frame)[1]
else:
bbox = (0, 0) + im_frame.size
elif encoderinfo.get("optimize") and im_frame.mode != "1":
if "transparency" not in encoderinfo:
assert im_frame.palette is not None
try:
encoderinfo["transparency"] = (
im_frame.palette._new_color_index(im_frame)
)
except ValueError:
pass
if "transparency" in encoderinfo:
# When the delta is zero, fill the image with transparency
diff_frame = im_frame.copy()
fill = Image.new("P", delta.size, encoderinfo["transparency"])
if delta.mode == "RGBA":
r, g, b, a = delta.split()
mask = ImageMath.lambda_eval(
lambda args: args["convert"](
args["max"](
args["max"](
args["max"](args["r"], args["g"]), args["b"]
),
args["a"],
)
* 255,
"1",
),
r=r,
g=g,
b=b,
a=a,
)
else:
if delta.mode == "P":
# Convert to L without considering palette
delta_l = Image.new("L", delta.size)
delta_l.putdata(delta.getdata())
delta = delta_l
mask = ImageMath.lambda_eval(
lambda args: args["convert"](args["im"] * 255, "1"),
im=delta,
)
diff_frame.paste(fill, mask=ImageOps.invert(mask))
else:
bbox = None
im_frames.append({"im": im_frame, "bbox": bbox, "encoderinfo": encoderinfo})
previous_im = im_frame
im_frames.append(_Frame(diff_frame or im_frame, bbox, encoderinfo))
if len(im_frames) == 1:
if "duration" in im.encoderinfo:
# Since multiple frames will not be written, use the combined duration
im.encoderinfo["duration"] = im_frames[0].encoderinfo["duration"]
return False
if len(im_frames) > 1:
for frame_data in im_frames:
im_frame = frame_data["im"]
if not frame_data["bbox"]:
im_frame = frame_data.im
if not frame_data.bbox:
# global header
for s in _get_global_header(im_frame, frame_data["encoderinfo"]):
for s in _get_global_header(im_frame, frame_data.encoderinfo):
fp.write(s)
offset = (0, 0)
else:
# compress difference
if not palette:
frame_data["encoderinfo"]["include_color_table"] = True
frame_data.encoderinfo["include_color_table"] = True
im_frame = im_frame.crop(frame_data["bbox"])
offset = frame_data["bbox"][:2]
_write_frame_data(fp, im_frame, offset, frame_data["encoderinfo"])
if frame_data.bbox != (0, 0) + im_frame.size:
im_frame = im_frame.crop(frame_data.bbox)
offset = frame_data.bbox[:2]
_write_frame_data(fp, im_frame, offset, frame_data.encoderinfo)
return True
elif "duration" in im.encoderinfo and isinstance(
im.encoderinfo["duration"], (list, tuple)
):
# Since multiple frames will not be written, add together the frame durations
im.encoderinfo["duration"] = sum(im.encoderinfo["duration"])
def _save_all(im, fp, filename):
def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
_save(im, fp, filename, save_all=True)
def _save(im, fp, filename, save_all=False):
def _save(
im: Image.Image, fp: IO[bytes], filename: str | bytes, save_all: bool = False
) -> None:
# header
if "palette" in im.encoderinfo or "palette" in im.info:
palette = im.encoderinfo.get("palette", im.info.get("palette"))
else:
palette = None
im.encoderinfo["optimize"] = im.encoderinfo.get("optimize", True)
im.encoderinfo.setdefault("optimize", True)
if not save_all or not _write_multiple_frames(im, fp, palette):
_write_single_frame(im, fp, palette)
@@ -670,7 +811,7 @@ def _save(im, fp, filename, save_all=False):
fp.flush()
def get_interlace(im):
def get_interlace(im: Image.Image) -> int:
interlace = im.encoderinfo.get("interlace", 1)
# workaround for @PIL153
@@ -680,23 +821,13 @@ def get_interlace(im):
return interlace
def _write_local_header(fp, im, offset, flags):
transparent_color_exists = False
def _write_local_header(
fp: IO[bytes], im: Image.Image, offset: tuple[int, int], flags: int
) -> None:
try:
transparency = int(im.encoderinfo["transparency"])
except (KeyError, ValueError):
pass
else:
# optimize the block away if transparent color is not used
transparent_color_exists = True
used_palette_colors = _get_optimize(im, im.encoderinfo)
if used_palette_colors is not None:
# adjust the transparency index after optimize
try:
transparency = used_palette_colors.index(transparency)
except ValueError:
transparent_color_exists = False
transparency = im.encoderinfo["transparency"]
except KeyError:
transparency = None
if "duration" in im.encoderinfo:
duration = int(im.encoderinfo["duration"] / 10)
@@ -705,11 +836,9 @@ def _write_local_header(fp, im, offset, flags):
disposal = int(im.encoderinfo.get("disposal", 0))
if transparent_color_exists or duration != 0 or disposal:
packed_flag = 1 if transparent_color_exists else 0
if transparency is not None or duration != 0 or disposal:
packed_flag = 1 if transparency is not None else 0
packed_flag |= disposal << 2
if not transparent_color_exists:
transparency = 0
fp.write(
b"!"
@@ -717,7 +846,7 @@ def _write_local_header(fp, im, offset, flags):
+ o8(4) # length
+ o8(packed_flag) # packed fields
+ o16(duration) # duration
+ o8(transparency) # transparency index
+ o8(transparency or 0) # transparency index
+ o8(0)
)
@@ -742,7 +871,7 @@ def _write_local_header(fp, im, offset, flags):
fp.write(o8(8)) # bits
def _save_netpbm(im, fp, filename):
def _save_netpbm(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
# Unused by default.
# To use, uncomment the register_save call at the end of the file.
#
@@ -773,6 +902,7 @@ def _save_netpbm(im, fp, filename):
)
# Allow ppmquant to receive SIGPIPE if ppmtogif exits
assert quant_proc.stdout is not None
quant_proc.stdout.close()
retcode = quant_proc.wait()
@@ -794,7 +924,7 @@ def _save_netpbm(im, fp, filename):
_FORCE_OPTIMIZE = False
def _get_optimize(im, info):
def _get_optimize(im: Image.Image, info: dict[str, Any]) -> list[int] | None:
"""
Palette optimization is a potentially expensive operation.
@@ -805,7 +935,7 @@ def _get_optimize(im, info):
:param info: encoderinfo
:returns: list of indexes of palette entries in use, or None
"""
if im.mode in ("P", "L") and info and info.get("optimize", 0):
if im.mode in ("P", "L") and info and info.get("optimize"):
# Potentially expensive operation.
# The palette saves 3 bytes per color not used, but palette
@@ -827,6 +957,7 @@ def _get_optimize(im, info):
if optimise or max(used_palette_colors) >= len(used_palette_colors):
return used_palette_colors
assert im.palette is not None
num_palette_colors = len(im.palette.palette) // Image.getmodebands(
im.palette.mode
)
@@ -838,9 +969,10 @@ def _get_optimize(im, info):
and current_palette_size > 2
):
return used_palette_colors
return None
def _get_color_table_size(palette_bytes):
def _get_color_table_size(palette_bytes: bytes) -> int:
# calculate the palette size for the header
if not palette_bytes:
return 0
@@ -850,7 +982,7 @@ def _get_color_table_size(palette_bytes):
return math.ceil(math.log(len(palette_bytes) // 3, 2)) - 1
def _get_header_palette(palette_bytes):
def _get_header_palette(palette_bytes: bytes) -> bytes:
"""
Returns the palette, null padded to the next power of 2 (*3) bytes
suitable for direct inclusion in the GIF header
@@ -868,23 +1000,33 @@ def _get_header_palette(palette_bytes):
return palette_bytes
def _get_palette_bytes(im):
def _get_palette_bytes(im: Image.Image) -> bytes:
"""
Gets the palette for inclusion in the gif header
:param im: Image object
:returns: Bytes, len<=768 suitable for inclusion in gif header
"""
return im.palette.palette if im.palette else b""
if not im.palette:
return b""
palette = bytes(im.palette.palette)
if im.palette.mode == "RGBA":
palette = b"".join(palette[i * 4 : i * 4 + 3] for i in range(len(palette) // 3))
return palette
def _get_background(im, info_background):
def _get_background(
im: Image.Image,
info_background: int | tuple[int, int, int] | tuple[int, int, int, int] | None,
) -> int:
background = 0
if info_background:
if isinstance(info_background, tuple):
# WebPImagePlugin stores an RGBA value in info["background"]
# So it must be converted to the same format as GifImagePlugin's
# info["background"] - a global color table index
assert im.palette is not None
try:
background = im.palette.getcolor(info_background, im)
except ValueError as e:
@@ -901,7 +1043,7 @@ def _get_background(im, info_background):
return background
def _get_global_header(im, info):
def _get_global_header(im: Image.Image, info: dict[str, Any]) -> list[bytes]:
"""Return a list of strings representing a GIF header"""
# Header Block
@@ -963,7 +1105,12 @@ def _get_global_header(im, info):
return header
def _write_frame_data(fp, im_frame, offset, params):
def _write_frame_data(
fp: IO[bytes],
im_frame: Image.Image,
offset: tuple[int, int],
params: dict[str, Any],
) -> None:
try:
im_frame.encoderinfo = params
@@ -971,7 +1118,9 @@ def _write_frame_data(fp, im_frame, offset, params):
_write_local_header(fp, im_frame, offset, 0)
ImageFile._save(
im_frame, fp, [("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])]
im_frame,
fp,
[ImageFile._Tile("gif", (0, 0) + im_frame.size, 0, RAWMODE[im_frame.mode])],
)
fp.write(b"\0") # end of image data
@@ -983,7 +1132,9 @@ def _write_frame_data(fp, im_frame, offset, params):
# Legacy GIF utilities
def getheader(im, palette=None, info=None):
def getheader(
im: Image.Image, palette: _Palette | None = None, info: dict[str, Any] | None = None
) -> tuple[list[bytes], list[int] | None]:
"""
Legacy Method to get Gif data from image.
@@ -995,11 +1146,11 @@ def getheader(im, palette=None, info=None):
:returns: tuple of(list of header items, optimized palette)
"""
used_palette_colors = _get_optimize(im, info)
if info is None:
info = {}
used_palette_colors = _get_optimize(im, info)
if "background" not in info and "background" in im.info:
info["background"] = im.info["background"]
@@ -1011,7 +1162,9 @@ def getheader(im, palette=None, info=None):
return header, used_palette_colors
def getdata(im, offset=(0, 0), **params):
def getdata(
im: Image.Image, offset: tuple[int, int] = (0, 0), **params: Any
) -> list[bytes]:
"""
Legacy Method
@@ -1028,12 +1181,14 @@ def getdata(im, offset=(0, 0), **params):
:returns: List of bytes containing GIF encoded frame data
"""
from io import BytesIO
class Collector:
class Collector(BytesIO):
data = []
def write(self, data):
def write(self, data: Buffer) -> int:
self.data.append(data)
return len(data)
im.load() # make sure raster data is available

View File

@@ -18,17 +18,22 @@ Stuff to translate curve segments to palette values (derived from
the corresponding code in GIMP, written by Federico Mena Quintero.
See the GIMP distribution for more information.)
"""
from __future__ import annotations
from math import log, pi, sin, sqrt
from ._binary import o8
TYPE_CHECKING = False
if TYPE_CHECKING:
from collections.abc import Callable
from typing import IO
EPSILON = 1e-10
"""""" # Enable auto-doc for data member
def linear(middle, pos):
def linear(middle: float, pos: float) -> float:
if pos <= middle:
if middle < EPSILON:
return 0.0
@@ -43,19 +48,19 @@ def linear(middle, pos):
return 0.5 + 0.5 * pos / middle
def curved(middle, pos):
def curved(middle: float, pos: float) -> float:
return pos ** (log(0.5) / log(max(middle, EPSILON)))
def sine(middle, pos):
def sine(middle: float, pos: float) -> float:
return (sin((-pi / 2.0) + pi * linear(middle, pos)) + 1.0) / 2.0
def sphere_increasing(middle, pos):
def sphere_increasing(middle: float, pos: float) -> float:
return sqrt(1.0 - (linear(middle, pos) - 1.0) ** 2)
def sphere_decreasing(middle, pos):
def sphere_decreasing(middle: float, pos: float) -> float:
return 1.0 - sqrt(1.0 - linear(middle, pos) ** 2)
@@ -64,9 +69,22 @@ SEGMENTS = [linear, curved, sine, sphere_increasing, sphere_decreasing]
class GradientFile:
gradient = None
gradient: (
list[
tuple[
float,
float,
float,
list[float],
list[float],
Callable[[float, float], float],
]
]
| None
) = None
def getpalette(self, entries=256):
def getpalette(self, entries: int = 256) -> tuple[bytes, str]:
assert self.gradient is not None
palette = []
ix = 0
@@ -101,8 +119,8 @@ class GradientFile:
class GimpGradientFile(GradientFile):
"""File handler for GIMP's gradient format."""
def __init__(self, fp):
if fp.readline()[:13] != b"GIMP Gradient":
def __init__(self, fp: IO[bytes]) -> None:
if not fp.readline().startswith(b"GIMP Gradient"):
msg = "not a GIMP gradient file"
raise SyntaxError(msg)
@@ -114,7 +132,7 @@ class GimpGradientFile(GradientFile):
count = int(line)
gradient = []
self.gradient = []
for i in range(count):
s = fp.readline().split()
@@ -132,6 +150,4 @@ class GimpGradientFile(GradientFile):
msg = "cannot handle HSV colour space"
raise OSError(msg)
gradient.append((x0, x1, xm, rgb0, rgb1, segment))
self.gradient = gradient
self.gradient.append((x0, x1, xm, rgb0, rgb1, segment))

View File

@@ -13,10 +13,14 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import re
from io import BytesIO
from ._binary import o8
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import IO
class GimpPaletteFile:
@@ -24,14 +28,18 @@ class GimpPaletteFile:
rawmode = "RGB"
def __init__(self, fp):
self.palette = [o8(i) * 3 for i in range(256)]
if fp.readline()[:12] != b"GIMP Palette":
def _read(self, fp: IO[bytes], limit: bool = True) -> None:
if not fp.readline().startswith(b"GIMP Palette"):
msg = "not a GIMP palette file"
raise SyntaxError(msg)
for i in range(256):
palette: list[int] = []
i = 0
while True:
if limit and i == 256 + 3:
break
i += 1
s = fp.readline()
if not s:
break
@@ -39,18 +47,29 @@ class GimpPaletteFile:
# skip fields and comment lines
if re.match(rb"\w+:|#", s):
continue
if len(s) > 100:
if limit and len(s) > 100:
msg = "bad palette file"
raise SyntaxError(msg)
v = tuple(map(int, s.split()[:3]))
if len(v) != 3:
v = s.split(maxsplit=3)
if len(v) < 3:
msg = "bad palette entry"
raise ValueError(msg)
self.palette[i] = o8(v[0]) + o8(v[1]) + o8(v[2])
palette += (int(v[i]) for i in range(3))
if limit and len(palette) == 768:
break
self.palette = b"".join(self.palette)
self.palette = bytes(palette)
def getpalette(self):
def __init__(self, fp: IO[bytes]) -> None:
self._read(fp)
@classmethod
def frombytes(cls, data: bytes) -> GimpPaletteFile:
self = cls.__new__(cls)
self._read(BytesIO(data), False)
return self
def getpalette(self) -> tuple[bytes, str]:
return self.palette, self.rawmode

View File

@@ -8,13 +8,17 @@
#
# See the README file for information on usage and redistribution.
#
from __future__ import annotations
import os
from typing import IO
from . import Image, ImageFile
_handler = None
def register_handler(handler):
def register_handler(handler: ImageFile.StubHandler | None) -> None:
"""
Install application-specific GRIB image handler.
@@ -28,22 +32,20 @@ def register_handler(handler):
# Image adapter
def _accept(prefix):
return prefix[:4] == b"GRIB" and prefix[7] == 1
def _accept(prefix: bytes) -> bool:
return len(prefix) >= 8 and prefix.startswith(b"GRIB") and prefix[7] == 1
class GribStubImageFile(ImageFile.StubImageFile):
format = "GRIB"
format_description = "GRIB"
def _open(self):
offset = self.fp.tell()
def _open(self) -> None:
if not _accept(self.fp.read(8)):
msg = "Not a GRIB file"
raise SyntaxError(msg)
self.fp.seek(offset)
self.fp.seek(-8, os.SEEK_CUR)
# make something up
self._mode = "F"
@@ -53,11 +55,11 @@ class GribStubImageFile(ImageFile.StubImageFile):
if loader:
loader.open(self)
def _load(self):
def _load(self) -> ImageFile.StubHandler | None:
return _handler
def _save(im, fp, filename):
def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:
if _handler is None or not hasattr(_handler, "save"):
msg = "GRIB save handler not installed"
raise OSError(msg)

Some files were not shown because too many files have changed in this diff Show More