This commit is contained in:
Iliyan Angelov
2025-11-30 22:43:09 +02:00
parent 24b40450dd
commit 39fcfff811
1610 changed files with 5442 additions and 1383 deletions

View File

View File

View File

@@ -0,0 +1,33 @@
from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
from .settings import settings
from .logging_config import get_logger
logger = get_logger(__name__)
DATABASE_URL = settings.database_url
engine = create_engine(DATABASE_URL, poolclass=QueuePool, pool_pre_ping=True, pool_recycle=3600, pool_size=10, max_overflow=20, echo=settings.is_development, future=True, connect_args={'charset': 'utf8mb4', 'connect_timeout': 10})
@event.listens_for(engine, 'connect')
def set_sqlite_pragma(dbapi_conn, connection_record):
logger.debug('New database connection established')
@event.listens_for(engine, 'checkout')
def receive_checkout(dbapi_conn, connection_record, connection_proxy):
logger.debug('Connection checked out from pool')
@event.listens_for(engine, 'checkin')
def receive_checkin(dbapi_conn, connection_record):
logger.debug('Connection returned to pool')
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
except Exception:
db.rollback()
raise
finally:
db.close()

View File

@@ -0,0 +1,36 @@
import logging
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
from .settings import settings
def setup_logging(log_level: Optional[str]=None, log_file: Optional[str]=None, enable_file_logging: bool=True) -> logging.Logger:
level = log_level or settings.LOG_LEVEL
log_file_path = log_file or settings.LOG_FILE
numeric_level = getattr(logging, level.upper(), logging.INFO)
if enable_file_logging and log_file_path:
log_path = Path(log_file_path)
log_path.parent.mkdir(parents=True, exist_ok=True)
detailed_formatter = logging.Formatter(fmt='%(asctime)s | %(levelname)-8s | %(name)s | %(funcName)s:%(lineno)d | %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
simple_formatter = logging.Formatter(fmt='%(asctime)s | %(levelname)-8s | %(message)s', datefmt='%H:%M:%S')
root_logger = logging.getLogger()
root_logger.setLevel(numeric_level)
root_logger.handlers.clear()
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(numeric_level)
console_handler.setFormatter(simple_formatter if settings.is_development else detailed_formatter)
root_logger.addHandler(console_handler)
if enable_file_logging and log_file_path and (not settings.is_development):
file_handler = RotatingFileHandler(log_file_path, maxBytes=settings.LOG_MAX_BYTES, backupCount=settings.LOG_BACKUP_COUNT, encoding='utf-8')
file_handler.setLevel(numeric_level)
file_handler.setFormatter(detailed_formatter)
root_logger.addHandler(file_handler)
logging.getLogger('uvicorn').setLevel(logging.INFO)
logging.getLogger('uvicorn.access').setLevel(logging.INFO if settings.is_development else logging.WARNING)
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('slowapi').setLevel(logging.WARNING)
return root_logger
def get_logger(name: str) -> logging.Logger:
return logging.getLogger(name)

View File

@@ -0,0 +1,137 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
from typing import List
import os
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=False, extra='ignore')
APP_NAME: str = Field(default='Hotel Booking API', description='Application name')
APP_VERSION: str = Field(default='1.0.0', description='Application version')
ENVIRONMENT: str = Field(default='development', description='Environment: development, staging, production')
DEBUG: bool = Field(default=False, description='Debug mode')
API_V1_PREFIX: str = Field(default='/api/v1', description='API v1 prefix')
HOST: str = Field(default='0.0.0.0', description='Server host')
PORT: int = Field(default=8000, description='Server port')
DB_USER: str = Field(default='root', description='Database user')
DB_PASS: str = Field(default='', description='Database password')
DB_NAME: str = Field(default='hotel_db', description='Database name')
DB_HOST: str = Field(default='localhost', description='Database host')
DB_PORT: str = Field(default='3306', description='Database port')
JWT_SECRET: str = Field(default='dev-secret-key-change-in-production-12345', description='JWT secret key')
JWT_ALGORITHM: str = Field(default='HS256', description='JWT algorithm')
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30, description='JWT access token expiration in minutes')
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=3, description='JWT refresh token expiration in days (reduced from 7 for better security)')
MAX_LOGIN_ATTEMPTS: int = Field(default=5, description='Maximum failed login attempts before account lockout')
ACCOUNT_LOCKOUT_DURATION_MINUTES: int = Field(default=30, description='Account lockout duration in minutes after max failed attempts')
ENCRYPTION_KEY: str = Field(default='', description='Base64-encoded encryption key for data encryption at rest')
CLIENT_URL: str = Field(default='http://localhost:5173', description='Frontend client URL')
CORS_ORIGINS: List[str] = Field(default_factory=lambda: ['http://localhost:5173', 'http://localhost:3000', 'http://127.0.0.1:5173'], description='Allowed CORS origins')
RATE_LIMIT_ENABLED: bool = Field(default=True, description='Enable rate limiting')
RATE_LIMIT_PER_MINUTE: int = Field(default=60, description='Requests per minute per IP')
CSRF_PROTECTION_ENABLED: bool = Field(default=True, description='Enable CSRF protection')
HSTS_PRELOAD_ENABLED: bool = Field(default=False, description='Enable HSTS preload directive (requires domain submission to hstspreload.org)')
LOG_LEVEL: str = Field(default='INFO', description='Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL')
LOG_FILE: str = Field(default='logs/app.log', description='Log file path')
LOG_MAX_BYTES: int = Field(default=10485760, description='Max log file size (10MB)')
LOG_BACKUP_COUNT: int = Field(default=5, description='Number of backup log files')
SMTP_HOST: str = Field(default='smtp.gmail.com', description='SMTP host')
SMTP_PORT: int = Field(default=587, description='SMTP port')
SMTP_USER: str = Field(default='', description='SMTP username')
SMTP_PASSWORD: str = Field(default='', description='SMTP password')
SMTP_FROM_EMAIL: str = Field(default='', description='From email address')
SMTP_FROM_NAME: str = Field(default='Hotel Booking', description='From name')
UPLOAD_DIR: str = Field(default='uploads', description='Upload directory')
MAX_UPLOAD_SIZE: int = Field(default=5242880, description='Max upload size in bytes (5MB)')
MAX_REQUEST_BODY_SIZE: int = Field(default=10485760, description='Max request body size in bytes (10MB)')
ALLOWED_EXTENSIONS: List[str] = Field(default_factory=lambda: ['jpg', 'jpeg', 'png', 'gif', 'webp'], description='Allowed file extensions')
REDIS_ENABLED: bool = Field(default=False, description='Enable Redis caching')
REDIS_HOST: str = Field(default='localhost', description='Redis host')
REDIS_PORT: int = Field(default=6379, description='Redis port')
REDIS_DB: int = Field(default=0, description='Redis database number')
REDIS_PASSWORD: str = Field(default='', description='Redis password')
REQUEST_TIMEOUT: int = Field(default=30, description='Request timeout in seconds')
HEALTH_CHECK_INTERVAL: int = Field(default=30, description='Health check interval in seconds')
STRIPE_SECRET_KEY: str = Field(default='', description='Stripe secret key')
STRIPE_PUBLISHABLE_KEY: str = Field(default='', description='Stripe publishable key')
STRIPE_WEBHOOK_SECRET: str = Field(default='', description='Stripe webhook secret')
PAYPAL_CLIENT_ID: str = Field(default='', description='PayPal client ID')
PAYPAL_CLIENT_SECRET: str = Field(default='', description='PayPal client secret')
PAYPAL_MODE: str = Field(default='sandbox', description='PayPal mode: sandbox or live')
BORICA_TERMINAL_ID: str = Field(default='', description='Borica Terminal ID')
BORICA_MERCHANT_ID: str = Field(default='', description='Borica Merchant ID')
BORICA_PRIVATE_KEY_PATH: str = Field(default='', description='Borica private key file path')
BORICA_CERTIFICATE_PATH: str = Field(default='', description='Borica certificate file path')
BORICA_GATEWAY_URL: str = Field(default='https://3dsgate-dev.borica.bg/cgi-bin/cgi_link', description='Borica gateway URL (test or production)')
BORICA_MODE: str = Field(default='test', description='Borica mode: test or production')
@property
def database_url(self) -> str:
"""Generate database URL with proper credential escaping to prevent injection."""
from urllib.parse import quote_plus
# Properly escape credentials to handle special characters
user = quote_plus(self.DB_USER)
password = quote_plus(self.DB_PASS)
host = quote_plus(self.DB_HOST)
port = str(self.DB_PORT)
name = quote_plus(self.DB_NAME)
return f'mysql+pymysql://{user}:{password}@{host}:{port}/{name}'
@property
def is_production(self) -> bool:
return self.ENVIRONMENT.lower() == 'production'
@property
def is_development(self) -> bool:
return self.ENVIRONMENT.lower() == 'development'
@property
def redis_url(self) -> str:
if self.REDIS_PASSWORD:
return f'redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
return f'redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
IP_WHITELIST_ENABLED: bool = Field(default=False, description='Enable IP whitelisting for admin endpoints')
ADMIN_IP_WHITELIST: List[str] = Field(default_factory=list, description='List of allowed IP addresses/CIDR ranges for admin endpoints')
def validate_encryption_key(self) -> None:
"""
Validate encryption key is properly configured.
Raises ValueError if key is missing or invalid in production.
"""
if not self.ENCRYPTION_KEY:
if self.is_production:
raise ValueError(
'CRITICAL: ENCRYPTION_KEY is not configured in production. '
'Please set ENCRYPTION_KEY environment variable to a base64-encoded 32-byte key.'
)
else:
# In development, warn but don't fail
import logging
logger = logging.getLogger(__name__)
logger.warning(
'ENCRYPTION_KEY is not configured. Encryption operations may fail. '
'Please set ENCRYPTION_KEY environment variable.'
)
return
# Validate base64 encoding and key length (32 bytes = 44 base64 chars)
try:
import base64
decoded = base64.b64decode(self.ENCRYPTION_KEY)
if len(decoded) != 32:
raise ValueError(
f'ENCRYPTION_KEY must be a base64-encoded 32-byte key. '
f'Received {len(decoded)} bytes after decoding.'
)
except Exception as e:
if self.is_production:
raise ValueError(
f'Invalid ENCRYPTION_KEY format: {str(e)}. '
'Must be a valid base64-encoded 32-byte key.'
)
else:
import logging
logger = logging.getLogger(__name__)
logger.warning(f'Invalid ENCRYPTION_KEY format: {str(e)}')
settings = Settings()

View File

@@ -0,0 +1,52 @@
import json
from typing import Callable, Awaitable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from ...content.schemas.privacy import CookieConsent, CookieCategoryPreferences
from ...shared.config.settings import settings
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
COOKIE_CONSENT_COOKIE_NAME = 'cookieConsent'
def _parse_consent_cookie(raw_value: str | None) -> CookieConsent:
if not raw_value:
return CookieConsent()
try:
data = json.loads(raw_value)
return CookieConsent(**data)
except Exception as exc:
logger.warning(f'Failed to parse cookie consent cookie: {exc}')
return CookieConsent()
class CookieConsentMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
raw_cookie = request.cookies.get(COOKIE_CONSENT_COOKIE_NAME)
consent = _parse_consent_cookie(raw_cookie)
consent.categories.necessary = True
request.state.cookie_consent = consent
response = await call_next(request)
if COOKIE_CONSENT_COOKIE_NAME not in request.cookies:
try:
response.set_cookie(key=COOKIE_CONSENT_COOKIE_NAME, value=consent.model_dump_json(), httponly=True, secure=settings.is_production, samesite='lax', max_age=365 * 24 * 60 * 60, path='/')
except Exception as exc:
logger.warning(f'Failed to set default cookie consent cookie: {exc}')
return response
def is_analytics_allowed(request: Request) -> bool:
consent: CookieConsent | None = getattr(request.state, 'cookie_consent', None)
if not consent:
return False
return consent.categories.analytics
def is_marketing_allowed(request: Request) -> bool:
consent: CookieConsent | None = getattr(request.state, 'cookie_consent', None)
if not consent:
return False
return consent.categories.marketing
def is_preferences_allowed(request: Request) -> bool:
consent: CookieConsent | None = getattr(request.state, 'cookie_consent', None)
if not consent:
return False
return consent.categories.preferences

View File

@@ -0,0 +1,135 @@
from fastapi import Request, status, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from sqlalchemy.exc import IntegrityError
from jose.exceptions import JWTError
import traceback
import os
from ...shared.utils.response_helpers import error_response
from ...shared.config.settings import settings
def _add_cors_headers(response: JSONResponse, request: Request) -> JSONResponse:
"""Add CORS headers to response for cross-origin requests."""
origin = request.headers.get('Origin')
if origin:
# Check if origin is allowed (development or production)
if settings.is_development:
# Allow localhost origins in development
if origin.startswith('http://localhost') or origin.startswith('http://127.0.0.1'):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, PATCH, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
else:
# In production, check against CORS_ORIGINS
if origin in settings.CORS_ORIGINS:
response.headers['Access-Control-Allow-Origin'] = origin
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, PATCH, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
return response
async def validation_exception_handler(request: Request, exc: RequestValidationError):
errors = []
for error in exc.errors():
field = '.'.join((str(loc) for loc in error['loc'] if loc != 'body'))
errors.append({'field': field, 'message': error['msg']})
first_error = errors[0]['message'] if errors else 'Validation error'
request_id = getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None
response_content = error_response(
message=first_error,
errors=errors,
request_id=request_id
)
response = JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=response_content)
# Add CORS headers to error responses
return _add_cors_headers(response, request)
async def integrity_error_handler(request: Request, exc: IntegrityError):
error_msg = str(exc.orig) if hasattr(exc, 'orig') else str(exc)
request_id = getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None
if 'Duplicate entry' in error_msg or 'UNIQUE constraint' in error_msg:
response_content = error_response(
message='Duplicate entry',
errors=[{'message': 'This record already exists'}],
request_id=request_id
)
response = JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=response_content)
else:
response_content = error_response(
message='Database integrity error',
request_id=request_id
)
response = JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=response_content)
# Add CORS headers to error responses
return _add_cors_headers(response, request)
async def jwt_error_handler(request: Request, exc: JWTError):
request_id = getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None
response_content = error_response(
message='Invalid token',
request_id=request_id
)
response = JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=response_content)
# Add CORS headers to error responses
return _add_cors_headers(response, request)
async def http_exception_handler(request: Request, exc: HTTPException):
request_id = getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None
if isinstance(exc.detail, dict):
response_content = exc.detail.copy()
if request_id and 'request_id' not in response_content:
response_content['request_id'] = request_id
# Ensure it has standard error response format
if 'status' not in response_content:
response_content['status'] = 'error'
if 'success' not in response_content:
response_content['success'] = False
response = JSONResponse(status_code=exc.status_code, content=response_content)
else:
response_content = error_response(
message=str(exc.detail) if exc.detail else 'An error occurred',
request_id=request_id
)
response = JSONResponse(status_code=exc.status_code, content=response_content)
# Add CORS headers to error responses
return _add_cors_headers(response, request)
async def general_exception_handler(request: Request, exc: Exception):
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
request_id = getattr(request.state, 'request_id', None)
logger.error(f'Unhandled exception: {type(exc).__name__}: {str(exc)}', extra={'request_id': request_id, 'path': request.url.path, 'method': request.method, 'exception_type': type(exc).__name__}, exc_info=True)
if isinstance(exc, Exception) and hasattr(exc, 'status_code'):
status_code = exc.status_code
if hasattr(exc, 'detail'):
detail = exc.detail
if isinstance(detail, dict):
response = JSONResponse(status_code=status_code, content=detail)
return _add_cors_headers(response, request)
message = str(detail) if detail else 'An error occurred'
else:
message = str(exc) if str(exc) else 'Internal server error'
else:
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
message = str(exc) if str(exc) else 'Internal server error'
response_content = error_response(
message=message,
request_id=request_id
)
# NEVER include stack traces in production responses
# Always log stack traces server-side only for debugging
if settings.is_development:
# Only include stack traces in development mode
# Double-check environment to prevent accidental exposure
env_check = os.getenv('ENVIRONMENT', 'development').lower()
if env_check == 'development':
response_content['stack'] = traceback.format_exc()
else:
# Log warning if development flag is set but environment says otherwise
logger.warning(f'is_development=True but ENVIRONMENT={env_check}. Not including stack trace in response.')
# Stack traces are always logged server-side via exc_info=True above
response = JSONResponse(status_code=status_code, content=response_content)
# Add CORS headers to error responses
return _add_cors_headers(response, request)

View File

@@ -0,0 +1,21 @@
import uuid
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class RequestIDMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_id = request.headers.get('X-Request-ID') or str(uuid.uuid4())
request.state.request_id = request_id
logger.info(f'Request started: {request.method} {request.url.path}', extra={'request_id': request_id, 'method': request.method, 'path': request.url.path, 'client_ip': request.client.host if request.client else None})
try:
response = await call_next(request)
response.headers['X-Request-ID'] = request_id
logger.info(f'Request completed: {request.method} {request.url.path} - {response.status_code}', extra={'request_id': request_id, 'method': request.method, 'path': request.url.path, 'status_code': response.status_code})
return response
except Exception as e:
logger.error(f'Request failed: {request.method} {request.url.path} - {str(e)}', extra={'request_id': request_id, 'method': request.method, 'path': request.url.path, 'error': str(e)}, exc_info=True)
raise

View File

@@ -0,0 +1,53 @@
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
logger = get_logger(__name__)
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce maximum request body size limits.
Prevents DoS attacks by rejecting requests that exceed the configured
maximum body size before they are processed.
"""
def __init__(self, app, max_size: int = None):
super().__init__(app)
self.max_size = max_size or settings.MAX_REQUEST_BODY_SIZE
async def dispatch(self, request: Request, call_next):
# Skip size check for methods that don't have bodies
if request.method in ['GET', 'HEAD', 'OPTIONS', 'DELETE']:
return await call_next(request)
# Check Content-Length header if available
content_length = request.headers.get('content-length')
if content_length:
try:
size = int(content_length)
if size > self.max_size:
logger.warning(
f"Request body size {size} bytes exceeds maximum {self.max_size} bytes "
f"from {request.client.host if request.client else 'unknown'}"
)
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
'status': 'error',
'message': f'Request body too large. Maximum size: {self.max_size // 1024 // 1024}MB'
}
)
except (ValueError, TypeError):
# Invalid content-length header, let it pass and let FastAPI handle it
pass
# For streaming requests without Content-Length, we need to check the body
# This is handled by limiting the body read size
response = await call_next(request)
return response

View File

@@ -0,0 +1,16 @@
import asyncio
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
logger = get_logger(__name__)
class TimeoutMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
try:
response = await asyncio.wait_for(call_next(request), timeout=settings.REQUEST_TIMEOUT)
return response
except asyncio.TimeoutError:
logger.warning(f'Request timeout: {request.method} {request.url.path}', extra={'request_id': getattr(request.state, 'request_id', None), 'method': request.method, 'path': request.url.path, 'timeout': settings.REQUEST_TIMEOUT})
raise HTTPException(status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail={'status': 'error', 'message': 'Request timeout. Please try again.'})

View File

View File

@@ -0,0 +1,24 @@
"""
Utility functions for currency handling
"""
CURRENCY_SYMBOLS = {
'USD': '$',
'EUR': '',
'GBP': '£',
'JPY': '¥',
'CNY': '¥',
'KRW': '',
'SGD': 'S$',
'THB': '฿',
'AUD': 'A$',
'CAD': 'C$',
'VND': '',
'INR': '',
'CHF': 'CHF',
'NZD': 'NZ$'
}
def get_currency_symbol(currency: str) -> str:
"""Get currency symbol for a given currency code"""
return CURRENCY_SYMBOLS.get(currency.upper(), currency)

View File

@@ -0,0 +1,42 @@
from typing import Dict
from decimal import Decimal
EXCHANGE_RATES: Dict[str, Decimal] = {'VND': Decimal('1.0'), 'USD': Decimal('0.000041'), 'EUR': Decimal('0.000038'), 'GBP': Decimal('0.000033'), 'JPY': Decimal('0.0061'), 'CNY': Decimal('0.00029'), 'KRW': Decimal('0.055'), 'SGD': Decimal('0.000055'), 'THB': Decimal('0.0015'), 'AUD': Decimal('0.000062'), 'CAD': Decimal('0.000056')}
SUPPORTED_CURRENCIES = list(EXCHANGE_RATES.keys())
class CurrencyService:
@staticmethod
def get_supported_currencies() -> list:
return SUPPORTED_CURRENCIES
@staticmethod
def convert_amount(amount: float, from_currency: str, to_currency: str) -> float:
from_currency = from_currency.upper()
to_currency = to_currency.upper()
if from_currency == to_currency:
return amount
if from_currency not in EXCHANGE_RATES:
raise ValueError(f'Unsupported source currency: {from_currency}')
if to_currency not in EXCHANGE_RATES:
raise ValueError(f'Unsupported target currency: {to_currency}')
amount_vnd = Decimal(str(amount)) / EXCHANGE_RATES[from_currency]
converted_amount = amount_vnd * EXCHANGE_RATES[to_currency]
return float(converted_amount)
@staticmethod
def get_exchange_rate(from_currency: str, to_currency: str) -> float:
from_currency = from_currency.upper()
to_currency = to_currency.upper()
if from_currency == to_currency:
return 1.0
if from_currency not in EXCHANGE_RATES:
raise ValueError(f'Unsupported source currency: {from_currency}')
if to_currency not in EXCHANGE_RATES:
raise ValueError(f'Unsupported target currency: {to_currency}')
rate = EXCHANGE_RATES[to_currency] / EXCHANGE_RATES[from_currency]
return float(rate)
@staticmethod
def format_currency_code(currency: str) -> str:
return currency.upper() if currency else 'VND'
currency_service = CurrencyService()

View File

@@ -0,0 +1,459 @@
from datetime import datetime
from typing import Optional
from ...shared.config.database import SessionLocal
from ...system.models.system_settings import SystemSettings
def _get_company_settings():
try:
db = SessionLocal()
try:
settings = {}
setting_keys = [
"company_name",
"company_tagline",
"company_logo_url",
"company_phone",
"company_email",
"company_address",
]
for key in setting_keys:
setting = db.query(SystemSettings).filter(
SystemSettings.key == key
).first()
if setting and setting.value:
settings[key] = setting.value
else:
settings[key] = None
return settings
finally:
db.close()
except Exception:
return {
"company_name": None,
"company_tagline": None,
"company_logo_url": None,
"company_phone": None,
"company_email": None,
"company_address": None,
}
def get_base_template(content: str, title: str = "Hotel Booking", client_url: str = "http://localhost:5173") -> str:
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
company_tagline = company_settings.get("company_tagline") or "Excellence Redefined"
company_logo_url = company_settings.get("company_logo_url")
company_phone = company_settings.get("company_phone")
company_email = company_settings.get("company_email")
company_address = company_settings.get("company_address")
logo_html = ""
if company_logo_url:
if not company_logo_url.startswith('http'):
server_url = client_url.replace('://localhost:5173', '').replace('://localhost:3000', '')
if not server_url.startswith('http'):
server_url = f"http://{server_url}" if ':' not in server_url.split('//')[-1] else server_url
full_logo_url = f"{server_url}{company_logo_url}" if company_logo_url.startswith('/') else f"{server_url}/{company_logo_url}"
else:
full_logo_url = company_logo_url
logo_html = f'<img src="{full_logo_url}" alt="{company_name}" style="max-width: 200px; height: auto; margin-bottom: 20px;" />'
else:
logo_html = f'<div style="font-size: 32px; font-weight: bold; color: #d4af37; margin-bottom: 20px;">{company_name}</div>'
footer_contact = ""
if company_phone or company_email or company_address:
footer_contact = '<div style="margin-top: 30px; padding-top: 20px; border-top: 1px solid #e5e7eb; color: #6b7280; font-size: 14px;">'
if company_phone:
footer_contact += f'<p style="margin: 5px 0;">Phone: {company_phone}</p>'
if company_email:
footer_contact += f'<p style="margin: 5px 0;">Email: <a href="mailto:{company_email}" style="color: #d4af37; text-decoration: none;">{company_email}</a></p>'
if company_address:
formatted_address = company_address.replace('\n', '<br>')
footer_contact += f'<p style="margin: 5px 0;">Address: {formatted_address}</p>'
footer_contact += '</div>'
return f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{title}</title>
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f3f4f6;">
<table role="presentation" style="width: 100%; border-collapse: collapse;">
<tr>
<td style="padding: 40px 20px;">
<table role="presentation" style="max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
<tr>
<td style="padding: 40px; text-align: center; background: linear-gradient(135deg, #1a1a1a 0%, #0f0f0f 100%); border-radius: 10px 10px 0 0;">
{logo_html}
<p style="color: #d4af37; font-size: 12px; letter-spacing: 2px; margin: 10px 0 0 0; text-transform: uppercase;">{company_tagline}</p>
</td>
</tr>
<tr>
<td style="padding: 40px;">
{content}
</td>
</tr>
<tr>
<td style="padding: 40px; background-color: #f9fafb; border-radius: 0 0 10px 10px; text-align: center;">
<p style="color: #6b7280; font-size: 12px; margin: 0 0 10px 0;"{datetime.now().year} {company_name}. All rights reserved.</p>
{footer_contact}
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""
def welcome_email_template(name: str, email: str, client_url: str) -> str:
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h1 style="color: #1a1a1a; font-size: 28px; margin-bottom: 20px;">Welcome to {company_name}!</h1>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Dear {name},
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Thank you for registering with {company_name}. We're excited to have you as part of our community!
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 30px;">
You can now explore our luxurious rooms and make bookings. If you have any questions, feel free to contact us.
</p>
<div style="text-align: center; margin-top: 30px;">
<a href="{client_url}/rooms" style="display: inline-block; background: #d4af37; color: #ffffff; padding: 12px 30px; text-decoration: none; border-radius: 6px; font-weight: bold;">
Explore Our Rooms
</a>
</div>
</div>
"""
return get_base_template(content, f"Welcome to {company_name}", client_url)
def password_reset_email_template(reset_url: str) -> str:
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h1 style="color: #1a1a1a; font-size: 28px; margin-bottom: 20px;">Password Reset Request</h1>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Click the button below to create a new password.
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
If you did not request this, please ignore this email and your password will remain unchanged.
</p>
<div style="text-align: center; margin: 30px 0;">
<a href="{reset_url}" style="display: inline-block; background: #d4af37; color: #ffffff; padding: 12px 30px; text-decoration: none; border-radius: 6px; font-weight: bold;">
Reset Password
</a>
</div>
<p style="font-size: 14px; color: #6b7280; line-height: 1.6; margin-top: 30px;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{reset_url}" style="color: #d4af37; word-break: break-all;">{reset_url}</a>
</p>
<p style="font-size: 14px; color: #6b7280; line-height: 1.6; margin-top: 20px;">
This link will expire in 1 hour for security reasons.
</p>
</div>
"""
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
return get_base_template(content, f"Password Reset - {company_name}", reset_url.split('/reset-password')[0] if '/reset-password' in reset_url else "http://localhost:5173")
def password_changed_email_template(email: str) -> str:
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h1 style="color: #1a1a1a; font-size: 28px; margin-bottom: 20px;">Password Changed Successfully</h1>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Your password has been successfully changed.
</p>
<div style="background: #d1fae5; border-left: 4px solid #10b981; padding: 15px; margin: 20px 0; border-radius: 4px;">
<p style="margin: 0; color: #065f46; font-size: 14px;">
<strong>Security Notice:</strong> If you did not make this change, please contact us immediately to secure your account.
</p>
</div>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-top: 30px;">
For your security, we recommend:
</p>
<ul style="font-size: 16px; color: #374151; line-height: 1.8; margin: 20px 0;">
<li>Using a strong, unique password</li>
<li>Not sharing your password with anyone</li>
<li>Logging out when using shared devices</li>
</ul>
</div>
"""
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
return get_base_template(content, f"Password Changed - {company_name}", "http://localhost:5173")
def booking_confirmation_email_template(
booking_number: str,
guest_name: str,
room_number: str,
room_type: str,
check_in: str,
check_out: str,
num_guests: int,
total_price: float,
requires_deposit: bool,
deposit_amount: Optional[float] = None,
amount_paid: Optional[float] = None,
payment_type: Optional[str] = None,
original_price: Optional[float] = None,
discount_amount: Optional[float] = None,
promotion_code: Optional[str] = None,
client_url: str = "http://localhost:5173",
currency_symbol: str = "$"
) -> str:
deposit_info = ""
if requires_deposit and deposit_amount and amount_paid is None:
deposit_info = f"""
<div style="background: #fef3c7; border-left: 4px solid #f59e0b; padding: 15px; margin: 20px 0; border-radius: 4px;">
<p style="margin: 0; color: #92400e; font-size: 14px;">
<strong>Deposit Required:</strong> A deposit of {currency_symbol}{deposit_amount:.2f} is required to confirm your booking.
</p>
</div>
"""
payment_breakdown = ""
if amount_paid is not None:
remaining_due = total_price - amount_paid
payment_type_label = "Deposit Payment" if payment_type == "deposit" else "Full Payment"
payment_breakdown = f"""
<div style="background: #f9fafb; padding: 20px; border-radius: 8px; margin: 20px 0;">
<h3 style="color: #1a1a1a; font-size: 18px; margin-bottom: 15px;">Payment Summary</h3>
<div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
<span style="color: #374151;">{payment_type_label}:</span>
<span style="color: #1a1a1a; font-weight: bold;">{currency_symbol}{amount_paid:.2f}</span>
</div>
"""
if remaining_due > 0:
payment_breakdown += f"""
<div style="display: flex; justify-content: space-between; margin-top: 10px; padding-top: 10px; border-top: 1px solid #e5e7eb;">
<span style="color: #374151;">Remaining Due:</span>
<span style="color: #dc2626; font-weight: bold;">{currency_symbol}{remaining_due:.2f}</span>
</div>
"""
else:
payment_breakdown += f"""
<div style="display: flex; justify-content: space-between; margin-top: 10px; padding-top: 10px; border-top: 1px solid #e5e7eb;">
<span style="color: #059669; font-weight: bold;">Payment Complete</span>
</div>
"""
payment_breakdown += "</div>"
discount_info = ""
if original_price and discount_amount and discount_amount > 0:
discount_info = f"""
<div style="background: #d1fae5; border-left: 4px solid #10b981; padding: 15px; margin: 20px 0; border-radius: 4px;">
<p style="margin: 0; color: #065f46; font-size: 14px;">
<strong>Discount Applied:</strong> {currency_symbol}{discount_amount:.2f} off
{f'(Promo Code: {promotion_code})' if promotion_code else ''}
</p>
</div>
"""
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<div style="background: linear-gradient(135deg, #d1fae5 0%, #a7f3d0 100%); padding: 30px; border-radius: 10px; text-align: center; margin-bottom: 30px;">
<div style="font-size: 48px; color: #10b981; margin-bottom: 10px;">✓</div>
<h1 style="color: #10b981; margin: 0; font-size: 28px;">Booking Confirmed!</h1>
</div>
<div style="background: #ffffff; padding: 30px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Dear {guest_name},
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Your booking has been confirmed! We're excited to welcome you to our hotel.
</p>
<div style="background: #f9fafb; padding: 20px; border-radius: 8px; margin: 30px 0; border-left: 4px solid #d4af37;">
<h3 style="color: #1a1a1a; font-size: 18px; margin-bottom: 15px;">Booking Details</h3>
<p style="margin: 8px 0; color: #374151;"><strong>Booking Number:</strong> {booking_number}</p>
<p style="margin: 8px 0; color: #374151;"><strong>Room:</strong> {room_type} (Room {room_number})</p>
<p style="margin: 8px 0; color: #374151;"><strong>Check-in:</strong> {check_in}</p>
<p style="margin: 8px 0; color: #374151;"><strong>Check-out:</strong> {check_out}</p>
<p style="margin: 8px 0; color: #374151;"><strong>Guests:</strong> {num_guests}</p>
<p style="margin: 8px 0; color: #374151;"><strong>Total Price:</strong> {currency_symbol}{total_price:.2f}</p>
</div>
{discount_info}
{deposit_info}
{payment_breakdown}
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-top: 30px;">
If you have any questions or need to make changes to your booking, please don't hesitate to contact us.
</p>
<div style="text-align: center; margin-top: 30px;">
<a href="{client_url}/bookings/{booking_number}" style="display: inline-block; background: #d4af37; color: #ffffff; padding: 12px 30px; text-decoration: none; border-radius: 6px; font-weight: bold;">
View Booking Details
</a>
</div>
</div>
</div>
"""
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
return get_base_template(content, f"Booking Confirmation - {company_name}", client_url)
def payment_confirmation_email_template(
booking_number: str,
guest_name: str,
amount: float,
payment_method: str,
transaction_id: Optional[str] = None,
payment_type: Optional[str] = None,
total_price: Optional[float] = None,
client_url: str = "http://localhost:5173",
currency_symbol: str = "$"
) -> str:
transaction_info = ""
if transaction_id:
transaction_info = f"""
<div style="background: #f9fafb; padding: 15px; border-radius: 8px; margin: 15px 0;">
<p style="margin: 0; color: #374151; font-size: 14px;">
<strong>Transaction ID:</strong> {transaction_id}
</p>
</div>
"""
payment_type_info = ""
if payment_type:
payment_type_label = "Deposit Payment (20%)" if payment_type == "deposit" else "Full Payment"
payment_type_info = f"""
<div style="background: #eff6ff; border-left: 4px solid #3b82f6; padding: 15px; margin: 15px 0; border-radius: 4px;">
<p style="margin: 0; color: #1e40af; font-size: 14px;">
<strong>Payment Type:</strong> {payment_type_label}
</p>
</div>
"""
total_price_info = ""
remaining_due_info = ""
if total_price is not None:
total_price_info = f"""
<div style="display: flex; justify-content: space-between; margin: 15px 0; padding: 10px 0; border-top: 1px solid #e5e7eb;">
<span style="color: #374151; font-weight: bold;">Total Booking Amount:</span>
<span style="color: #1a1a1a; font-weight: bold; font-size: 18px;">{currency_symbol}{total_price:.2f}</span>
</div>
"""
if payment_type == "deposit" and total_price > amount:
remaining_due = total_price - amount
remaining_due_info = f"""
<div style="background: #fef3c7; border-left: 4px solid #f59e0b; padding: 15px; margin: 15px 0; border-radius: 4px;">
<p style="margin: 0; color: #92400e; font-size: 14px;">
<strong>Remaining Balance:</strong> {currency_symbol}{remaining_due:.2f} (due at check-in)
</p>
</div>
"""
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<div style="background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%); padding: 30px; border-radius: 10px; text-align: center; margin-bottom: 30px;">
<div style="font-size: 48px; color: #3b82f6; margin-bottom: 10px;">💳</div>
<h1 style="color: #3b82f6; margin: 0; font-size: 28px;">Payment Confirmed!</h1>
</div>
<div style="background: #ffffff; padding: 30px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Dear {guest_name},
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
We have successfully received your payment. Thank you!
</p>
<div style="background: #f9fafb; padding: 20px; border-radius: 8px; margin: 30px 0; border-left: 4px solid #3b82f6;">
<h3 style="color: #1a1a1a; font-size: 18px; margin-bottom: 15px;">Payment Details</h3>
<p style="margin: 8px 0; color: #374151;"><strong>Booking Number:</strong> {booking_number}</p>
<p style="margin: 8px 0; color: #374151;"><strong>Amount Paid:</strong> <span style="color: #10b981; font-weight: bold; font-size: 18px;">{currency_symbol}{amount:.2f}</span></p>
<p style="margin: 8px 0; color: #374151;"><strong>Payment Method:</strong> {payment_method.title()}</p>
{transaction_info}
{payment_type_info}
{total_price_info}
{remaining_due_info}
</div>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-top: 30px;">
Your payment has been processed successfully. You will receive a separate booking confirmation email with all the details of your reservation.
</p>
<div style="text-align: center; margin-top: 30px;">
<a href="{client_url}/bookings/{booking_number}" style="display: inline-block; background: #d4af37; color: #ffffff; padding: 12px 30px; text-decoration: none; border-radius: 6px; font-weight: bold;">
View Booking Details
</a>
</div>
</div>
</div>
"""
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
return get_base_template(content, f"Payment Confirmation - {company_name}", client_url)
def booking_status_changed_email_template(
booking_number: str,
guest_name: str,
status: str,
client_url: str = "http://localhost:5173"
) -> str:
status_colors = {
"confirmed": ("#10b981", "Confirmed", "", "#d1fae5", "#a7f3d0"),
"cancelled": ("#ef4444", "Cancelled", "", "#fee2e2", "#fecaca"),
"checked_in": ("#3b82f6", "Checked In", "", "#dbeafe", "#bfdbfe"),
"checked_out": ("#8b5cf6", "Checked Out", "", "#ede9fe", "#ddd6fe"),
}
color, status_text, icon, bg_start, bg_end = status_colors.get(status.lower(), ("#6b7280", "Updated", "", "#f3f4f6", "#e5e7eb"))
content = f"""
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<div style="background: linear-gradient(135deg, {bg_start} 0%, {bg_end} 100%); padding: 30px; border-radius: 10px; text-align: center; margin-bottom: 30px;">
<div style="font-size: 48px; color: {color}; margin-bottom: 10px;">{icon}</div>
<h1 style="color: {color}; margin: 0; font-size: 28px;">Booking {status_text}</h1>
</div>
<div style="background: #ffffff; padding: 30px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Dear {guest_name},
</p>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 20px;">
Your booking <strong style="color: {color};">{booking_number}</strong> has been <strong>{status_text.lower()}</strong>.
</p>
<div style="background: #f9fafb; padding: 20px; border-radius: 8px; margin: 30px 0; border-left: 4px solid {color};">
<p style="margin: 0; font-size: 14px; color: #6b7280;">
<strong>Booking Number:</strong> {booking_number}<br>
<strong>Status:</strong> <span style="color: {color}; font-weight: bold;">{status_text}</span>
</p>
</div>
<p style="font-size: 16px; color: #374151; line-height: 1.6; margin-bottom: 30px;">
If you have any questions or need assistance, please don't hesitate to contact us.
</p>
<div style="text-align: center; margin-top: 30px;">
<a href="{client_url}/bookings" style="display: inline-block; background: {color}; color: #ffffff; padding: 12px 30px; text-decoration: none; border-radius: 6px; font-weight: bold;">
View Booking Details
</a>
</div>
</div>
</div>
"""
company_settings = _get_company_settings()
company_name = company_settings.get("company_name") or "Hotel Booking"
return get_base_template(content, f"Booking {status_text} - {company_name}", client_url)

View File

@@ -0,0 +1,93 @@
"""
Standardized error handling utilities for route handlers.
This module provides decorators and helpers for consistent error handling
across all route handlers.
"""
from functools import wraps
from fastapi import HTTPException
from sqlalchemy.orm import Session
from typing import Callable, Any
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
def handle_db_errors(func: Callable) -> Callable:
"""
Decorator to handle database errors consistently.
Automatically rolls back transactions on errors.
Usage:
@router.post('/')
@handle_db_errors
async def create_item(..., db: Session = Depends(get_db)):
db.add(item)
db.commit()
return item
Note: This decorator expects 'db' to be in kwargs or as a positional argument.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
# Find db session in kwargs or args
db = kwargs.get('db')
if not db:
# Try to find in args (if db is positional)
for arg in args:
if isinstance(arg, Session):
db = arg
break
try:
return await func(*args, **kwargs)
except HTTPException:
# Re-raise HTTPExceptions as-is (they're intentional)
raise
except ValueError as e:
# ValueErrors are usually validation errors
if db:
db.rollback()
logger.warning(f'ValueError in {func.__name__}: {str(e)}')
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# All other exceptions are server errors
if db:
db.rollback()
logger.error(f'Error in {func.__name__}: {type(e).__name__}: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
return wrapper
def safe_db_operation(operation: Callable, db: Session, error_message: str = "Database operation failed") -> Any:
"""
Safely execute a database operation with automatic rollback on error.
Usage:
result = safe_db_operation(
lambda: db.add(item) or db.commit(),
db,
"Failed to create item"
)
Args:
operation: Callable that performs the database operation
db: Database session
error_message: Custom error message to use on failure
Returns:
Result of the operation
Raises:
HTTPException: If the operation fails
"""
try:
return operation()
except HTTPException:
raise
except Exception as e:
db.rollback()
logger.error(f'{error_message}: {type(e).__name__}: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=error_message)

View File

@@ -0,0 +1,148 @@
"""
File validation utilities for secure file uploads.
Validates file types using magic bytes (file signatures) to prevent spoofing.
"""
from PIL import Image
import io
from typing import Tuple, Optional
from fastapi import UploadFile, HTTPException, status
# Magic bytes for common image formats
IMAGE_MAGIC_BYTES = {
b'\xFF\xD8\xFF': 'image/jpeg',
b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A': 'image/png',
b'GIF87a': 'image/gif',
b'GIF89a': 'image/gif',
b'RIFF': 'image/webp', # WebP files start with RIFF, need deeper check
b'\x00\x00\x01\x00': 'image/x-icon',
b'\x00\x00\x02\x00': 'image/x-icon',
}
ALLOWED_IMAGE_TYPES = {'image/jpeg', 'image/png', 'image/gif', 'image/webp'}
def validate_image_file_signature(file_content: bytes, filename: str) -> Tuple[bool, str]:
"""
Validate file type using magic bytes (file signature).
This prevents MIME type spoofing attacks.
Args:
file_content: The file content as bytes
filename: The filename (for extension checking)
Returns:
Tuple of (is_valid, error_message)
"""
if not file_content:
return False, "File is empty"
# Check magic bytes for image types
file_start = file_content[:16] # Check first 16 bytes
detected_type = None
# Check for JPEG
if file_content.startswith(b'\xFF\xD8\xFF'):
detected_type = 'image/jpeg'
# Check for PNG
elif file_content.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
detected_type = 'image/png'
# Check for GIF
elif file_content.startswith(b'GIF87a') or file_content.startswith(b'GIF89a'):
detected_type = 'image/gif'
# Check for WebP (RIFF header with WEBP in bytes 8-11)
elif file_content.startswith(b'RIFF') and len(file_content) > 12:
if file_content[8:12] == b'WEBP':
detected_type = 'image/webp'
# Check for ICO
elif file_content.startswith(b'\x00\x00\x01\x00') or file_content.startswith(b'\x00\x00\x02\x00'):
detected_type = 'image/x-icon'
# If magic bytes don't match known image types, try PIL verification
if not detected_type:
try:
# Try to open with PIL to verify it's a valid image
img = Image.open(io.BytesIO(file_content))
img.verify()
# Get format from PIL
img_format = img.format.lower() if img.format else None
if img_format == 'jpeg':
detected_type = 'image/jpeg'
elif img_format == 'png':
detected_type = 'image/png'
elif img_format == 'gif':
detected_type = 'image/gif'
elif img_format == 'webp':
detected_type = 'image/webp'
else:
return False, f"Unsupported image format: {img_format}"
except Exception:
return False, "File is not a valid image or is corrupted"
# Verify detected type is in allowed list
if detected_type not in ALLOWED_IMAGE_TYPES and detected_type != 'image/x-icon':
return False, f"File type {detected_type} is not allowed. Allowed types: {', '.join(ALLOWED_IMAGE_TYPES)}"
return True, detected_type
async def validate_uploaded_image(file: UploadFile, max_size: int) -> bytes:
"""
Validate an uploaded image file completely.
Args:
file: FastAPI UploadFile object
max_size: Maximum file size in bytes
Returns:
File content as bytes
Raises:
HTTPException if validation fails
"""
# Check MIME type first (quick check)
if not file.content_type or not file.content_type.startswith('image/'):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f'File must be an image. Received MIME type: {file.content_type}'
)
# Read file content
content = await file.read()
# Validate file size
if len(content) > max_size:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f'File size ({len(content)} bytes) exceeds maximum allowed size ({max_size} bytes / {max_size // 1024 // 1024}MB)'
)
# Validate file signature (magic bytes)
is_valid, result = validate_image_file_signature(content, file.filename or '')
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f'Invalid file type: {result}. File signature validation failed. Please upload a valid image file.'
)
# Additional PIL validation to ensure image is not corrupted
try:
img = Image.open(io.BytesIO(content))
# Verify image integrity
img.verify()
# Re-open for further processing (verify() closes the image)
img = Image.open(io.BytesIO(content))
# Check image dimensions to prevent decompression bombs
if img.width > 10000 or img.height > 10000:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Image dimensions too large. Maximum dimensions: 10000x10000 pixels'
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f'Invalid or corrupted image file: {str(e)}'
)
return content

View File

@@ -0,0 +1,99 @@
"""
HTML sanitization utilities for backend content storage.
Prevents XSS attacks by sanitizing HTML before storing in database.
"""
import bleach
from typing import Optional
# Allowed HTML tags for rich content
ALLOWED_TAGS = [
'p', 'br', 'strong', 'em', 'u', 'b', 'i', 'span', 'div',
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
'ul', 'ol', 'li',
'a', 'blockquote', 'pre', 'code',
'table', 'thead', 'tbody', 'tr', 'th', 'td',
'img', 'hr', 'section', 'article'
]
# Allowed HTML attributes
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'target', 'rel'],
'img': ['src', 'alt', 'title', 'width', 'height', 'class'],
'div': ['class', 'id', 'style'],
'span': ['class', 'id', 'style'],
'p': ['class', 'id', 'style'],
'h1': ['class', 'id'],
'h2': ['class', 'id'],
'h3': ['class', 'id'],
'h4': ['class', 'id'],
'h5': ['class', 'id'],
'h6': ['class', 'id'],
'table': ['class', 'id'],
'tr': ['class', 'id'],
'th': ['class', 'id', 'colspan', 'rowspan'],
'td': ['class', 'id', 'colspan', 'rowspan'],
}
# Allowed URL schemes
ALLOWED_SCHEMES = ['http', 'https', 'mailto', 'tel']
def sanitize_html(html_content: Optional[str]) -> str:
"""
Sanitize HTML content to prevent XSS attacks.
Args:
html_content: HTML string to sanitize (can be None)
Returns:
Sanitized HTML string safe for storage
"""
if not html_content:
return ''
# Clean HTML content
cleaned = bleach.clean(
html_content,
tags=ALLOWED_TAGS,
attributes=ALLOWED_ATTRIBUTES,
protocols=ALLOWED_SCHEMES,
strip=True, # Strip disallowed tags instead of escaping
strip_comments=True, # Remove HTML comments
)
# Additional link sanitization - ensure external links have rel="noopener"
if '<a' in cleaned:
import re
# Add rel="noopener noreferrer" to external links
def add_rel(match):
tag = match.group(0)
if 'href=' in tag and ('http://' in tag or 'https://' in tag):
if 'rel=' not in tag:
# Insert rel attribute before closing >
return tag[:-1] + ' rel="noopener noreferrer">'
elif 'noopener' not in tag and 'noreferrer' not in tag:
# Add to existing rel attribute
tag = tag.replace('rel="', 'rel="noopener noreferrer ')
tag = tag.replace("rel='", "rel='noopener noreferrer ")
return tag
return tag
cleaned = re.sub(r'<a[^>]*>', add_rel, cleaned)
return cleaned
def sanitize_text_for_html(text: Optional[str]) -> str:
"""
Escape text content to be safely included in HTML.
Use this for plain text that should be displayed as-is.
Args:
text: Plain text string to escape
Returns:
HTML-escaped string
"""
if not text:
return ''
return bleach.clean(text, tags=[], strip=True)

View File

@@ -0,0 +1,91 @@
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import os
import logging
from ...shared.config.settings import settings
from ...shared.config.database import SessionLocal
from ...system.models.system_settings import SystemSettings
logger = logging.getLogger(__name__)
def _get_smtp_settings_from_db():
try:
db = SessionLocal()
try:
smtp_settings = {}
setting_keys = ['smtp_host', 'smtp_port', 'smtp_user', 'smtp_password', 'smtp_from_email', 'smtp_from_name', 'smtp_use_tls']
for key in setting_keys:
setting = db.query(SystemSettings).filter(SystemSettings.key == key).first()
if setting and setting.value:
smtp_settings[key] = setting.value
if smtp_settings.get('smtp_host') and smtp_settings.get('smtp_user') and smtp_settings.get('smtp_password'):
return smtp_settings
return None
finally:
db.close()
except Exception as e:
logger.debug(f'Could not fetch SMTP settings from database: {str(e)}')
return None
async def send_email(to: str, subject: str, html: str=None, text: str=None):
try:
db_smtp_settings = _get_smtp_settings_from_db()
if db_smtp_settings:
mail_host = db_smtp_settings.get('smtp_host')
mail_user = db_smtp_settings.get('smtp_user')
mail_pass = db_smtp_settings.get('smtp_password')
mail_port = int(db_smtp_settings.get('smtp_port', '587'))
mail_use_tls = db_smtp_settings.get('smtp_use_tls', 'true').lower() == 'true'
from_address = db_smtp_settings.get('smtp_from_email')
from_name = db_smtp_settings.get('smtp_from_name', 'Hotel Booking')
logger.info('Using SMTP settings from system_settings database')
else:
mail_host = settings.SMTP_HOST or os.getenv('MAIL_HOST')
mail_user = settings.SMTP_USER or os.getenv('MAIL_USER')
mail_pass = settings.SMTP_PASSWORD or os.getenv('MAIL_PASS')
mail_port = settings.SMTP_PORT or int(os.getenv('MAIL_PORT', '587'))
mail_secure = os.getenv('MAIL_SECURE', 'false').lower() == 'true'
mail_use_tls = mail_secure
client_url = settings.CLIENT_URL or os.getenv('CLIENT_URL', 'http://localhost:5173')
from_address = settings.SMTP_FROM_EMAIL or os.getenv('MAIL_FROM')
if not from_address:
domain = client_url.replace('https://', '').replace('http://', '').split('/')[0]
from_address = f'no-reply@{domain}'
from_name = settings.SMTP_FROM_NAME or 'Hotel Booking'
logger.info('Using SMTP settings from config/environment variables')
from_header = f'{from_name} <{from_address}>'
if not (mail_host and mail_user and mail_pass):
error_msg = 'SMTP mailer not configured. Set SMTP_HOST, SMTP_USER and SMTP_PASSWORD in .env file.'
logger.error(error_msg)
raise ValueError(error_msg)
message = MIMEMultipart('alternative')
message['From'] = from_header
message['To'] = to
message['Subject'] = subject
if text:
message.attach(MIMEText(text, 'plain'))
if html:
message.attach(MIMEText(html, 'html'))
if not text and (not html):
message.attach(MIMEText('', 'plain'))
if mail_port == 465 or mail_use_tls:
use_tls = True
start_tls = False
elif mail_port == 587:
use_tls = False
start_tls = True
else:
use_tls = False
start_tls = False
logger.info(f'Attempting to send email to {to} via {mail_host}:{mail_port} (use_tls: {use_tls}, start_tls: {start_tls})')
smtp_client = aiosmtplib.SMTP(hostname=mail_host, port=mail_port, use_tls=use_tls, start_tls=start_tls, username=mail_user, password=mail_pass)
try:
await smtp_client.connect()
await smtp_client.send_message(message)
logger.info(f'Email sent successfully to {to}')
finally:
await smtp_client.quit()
except Exception as e:
error_msg = f'Failed to send email to {to}: {type(e).__name__}: {str(e)}'
logger.error(error_msg, exc_info=True)
raise

View File

@@ -0,0 +1,59 @@
"""
Password validation utilities for enforcing password strength requirements.
"""
import re
from typing import Tuple, List
# Password strength requirements
MIN_PASSWORD_LENGTH = 8
REQUIRE_UPPERCASE = True
REQUIRE_LOWERCASE = True
REQUIRE_NUMBER = True
REQUIRE_SPECIAL = True
def validate_password_strength(password: str) -> Tuple[bool, List[str]]:
"""
Validate password meets strength requirements.
Args:
password: The password to validate
Returns:
Tuple of (is_valid, list_of_errors)
"""
errors = []
if not password:
return False, ['Password is required']
# Check minimum length
if len(password) < MIN_PASSWORD_LENGTH:
errors.append(f'Password must be at least {MIN_PASSWORD_LENGTH} characters long')
# Check for uppercase letter
if REQUIRE_UPPERCASE and not re.search(r'[A-Z]', password):
errors.append('Password must contain at least one uppercase letter')
# Check for lowercase letter
if REQUIRE_LOWERCASE and not re.search(r'[a-z]', password):
errors.append('Password must contain at least one lowercase letter')
# Check for number
if REQUIRE_NUMBER and not re.search(r'\d', password):
errors.append('Password must contain at least one number')
# Check for special character
if REQUIRE_SPECIAL and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
errors.append('Password must contain at least one special character (!@#$%^&*(),.?":{}|<>)')
# Check for common weak passwords
common_passwords = [
'password', '12345678', 'qwerty', 'abc123', 'password123',
'admin', 'letmein', 'welcome', 'monkey', '1234567890'
]
if password.lower() in common_passwords:
errors.append('Password is too common. Please choose a stronger password')
is_valid = len(errors) == 0
return is_valid, errors

View File

@@ -0,0 +1,21 @@
"""
Utility functions for request handling
"""
from typing import Optional
from fastapi import Request
def get_request_id(request: Optional[Request] = None) -> Optional[str]:
"""
Extract request_id from request state.
Args:
request: FastAPI Request object
Returns:
Request ID string or None
"""
if not request:
return None
return getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None

View File

@@ -0,0 +1,86 @@
"""
Utility functions for standardizing API responses
"""
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
def success_response(
data: Any = None,
message: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""
Create a standardized success response.
Returns both 'success' (boolean) and 'status' (string) for backward compatibility.
"""
response: Dict[str, Any] = {
'success': True,
'status': 'success'
}
if data is not None:
response['data'] = data
if message:
response['message'] = message
# Add any additional fields
response.update(kwargs)
return response
def error_response(
message: str,
errors: Optional[list] = None,
request_id: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""
Create a standardized error response.
"""
response: Dict[str, Any] = {
'success': False,
'status': 'error',
'message': message
}
if errors:
response['errors'] = errors
if request_id:
response['request_id'] = request_id
response.update(kwargs)
return response
def raise_http_exception(
status_code: int,
message: str,
errors: Optional[list] = None,
request: Optional[Request] = None,
**kwargs
) -> None:
"""
Raise an HTTPException with standardized error response format.
Args:
status_code: HTTP status code
message: Error message
errors: Optional list of error details
request: Optional Request object to extract request_id
**kwargs: Additional fields to include in response
"""
request_id = None
if request:
request_id = getattr(request.state, 'request_id', None) if hasattr(request, 'state') else None
detail = error_response(
message=message,
errors=errors,
request_id=request_id,
**kwargs
)
raise HTTPException(status_code=status_code, detail=detail)

View File

@@ -0,0 +1,47 @@
"""
Utility functions for role-based access control
"""
from sqlalchemy.orm import Session
from ...auth.models.user import User
from ...auth.models.role import Role
def get_user_role_name(user: User, db: Session) -> str:
"""Get the role name for a user"""
role = db.query(Role).filter(Role.id == user.role_id).first()
return role.name if role else 'customer'
def is_admin(user: User, db: Session) -> bool:
"""Check if user is admin"""
return get_user_role_name(user, db) == 'admin'
def is_staff(user: User, db: Session) -> bool:
"""Check if user is staff"""
return get_user_role_name(user, db) == 'staff'
def is_accountant(user: User, db: Session) -> bool:
"""Check if user is accountant"""
return get_user_role_name(user, db) == 'accountant'
def is_customer(user: User, db: Session) -> bool:
"""Check if user is customer"""
return get_user_role_name(user, db) == 'customer'
def can_access_all_payments(user: User, db: Session) -> bool:
"""Check if user can see all payments (admin or accountant)"""
role_name = get_user_role_name(user, db)
return role_name in ['admin', 'accountant']
def can_access_all_invoices(user: User, db: Session) -> bool:
"""Check if user can see all invoices (admin or accountant)"""
role_name = get_user_role_name(user, db)
return role_name in ['admin', 'accountant']
def can_create_invoices(user: User, db: Session) -> bool:
"""Check if user can create invoices (admin, staff, or accountant)"""
role_name = get_user_role_name(user, db)
return role_name in ['admin', 'staff', 'accountant']
def can_manage_users(user: User, db: Session) -> bool:
"""Check if user can manage users (admin only)"""
return is_admin(user, db)

View File

@@ -0,0 +1,43 @@
"""
Database transaction context manager for consistent transaction handling.
This module provides a context manager that automatically handles
database commits and rollbacks, ensuring data consistency.
"""
from contextlib import contextmanager
from sqlalchemy.orm import Session
from typing import Generator
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
@contextmanager
def transaction(db: Session) -> Generator[Session, None, None]:
"""
Context manager for database transactions.
Automatically commits on success, rolls back on error.
Usage:
with transaction(db):
db.add(booking)
# Auto-commits on success, rolls back on error
Args:
db: SQLAlchemy database session
Yields:
The database session
Raises:
Any exception that occurs during the transaction
"""
try:
yield db
db.commit()
logger.debug('Transaction committed successfully')
except Exception as e:
db.rollback()
logger.error(f'Transaction rolled back due to error: {type(e).__name__}: {str(e)}', exc_info=True)
raise