490 lines
24 KiB
Python
490 lines
24 KiB
Python
from fastapi import FastAPI, Request, HTTPException, Depends, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.exceptions import RequestValidationError
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
from jose.exceptions import JWTError
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import sys
|
|
import secrets
|
|
import os
|
|
import re
|
|
import logging
|
|
from .shared.config.settings import settings
|
|
from .shared.config.logging_config import setup_logging, get_logger
|
|
from .shared.config.database import engine, Base, get_db
|
|
from . import models
|
|
from sqlalchemy.orm import Session
|
|
# Configure logging first; `logger` is used by everything else in this module.
logger = setup_logging()
logger.info(
    f'Starting {settings.APP_NAME} v{settings.APP_VERSION} '
    f'in {settings.ENVIRONMENT} mode'
)
|
|
from .shared.middleware.error_handler import validation_exception_handler, integrity_error_handler, jwt_error_handler, http_exception_handler, general_exception_handler
|
|
from .shared.middleware.request_id import RequestIDMiddleware
|
|
from .security.middleware.security import SecurityHeadersMiddleware
|
|
from .shared.middleware.timeout import TimeoutMiddleware
|
|
from .shared.middleware.cookie_consent import CookieConsentMiddleware
|
|
from .security.middleware.csrf import CSRFProtectionMiddleware
|
|
from .shared.middleware.request_size_limit import RequestSizeLimitMiddleware
|
|
from .security.middleware.admin_ip_whitelist import AdminIPWhitelistMiddleware
|
|
if settings.is_development:
    # Development: create every table registered on Base.metadata.
    logger.info('Creating database tables (development mode)')
    Base.metadata.create_all(bind=engine)
else:
    # Other environments: only make sure the content tables below exist.
    # checkfirst=True skips any table that is already present.
    try:
        from .content.models.cookie_policy import CookiePolicy
        from .content.models.cookie_integration_config import CookieIntegrationConfig
        from .content.models.page_content import PageContent

        logger.info('Ensuring required tables exist')
        for _model_cls in (CookiePolicy, CookieIntegrationConfig, PageContent):
            _model_cls.__table__.create(bind=engine, checkfirst=True)
    except Exception as e:
        # Best-effort: startup continues, but the traceback is recorded so the
        # underlying cause can be diagnosed (previously only str(e) was logged).
        logger.error(f'Failed to ensure required tables exist: {e}', exc_info=True)
|
|
|
|
# Interactive API docs (Swagger UI at docs_url, ReDoc at redoc_url) and the
# OpenAPI schema are only exposed outside production.
_api_docs_enabled = not settings.is_production
app = FastAPI(
    title=settings.APP_NAME,
    description='Enterprise-grade Hotel Booking API',
    version=settings.APP_VERSION,
    docs_url='/api/docs' if _api_docs_enabled else None,
    redoc_url='/api/redoc' if _api_docs_enabled else None,
    openapi_url='/api/openapi.json' if _api_docs_enabled else None,
)
|
|
# Core middleware stack. Starlette wraps each newly added middleware around the
# ones added before it, so the ORDER of these calls determines execution order.
app.add_middleware(RequestIDMiddleware)
app.add_middleware(CookieConsentMiddleware)

# API versioning middleware; default_version='v1' is presumably applied when a
# request does not specify a version - confirm in api_versioning.
from .shared.middleware.api_versioning import APIVersioningMiddleware

app.add_middleware(APIVersioningMiddleware, default_version='v1')

# A non-positive REQUEST_TIMEOUT leaves the timeout middleware out entirely.
if settings.REQUEST_TIMEOUT > 0:
    app.add_middleware(TimeoutMiddleware)

app.add_middleware(SecurityHeadersMiddleware)

app.add_middleware(RequestSizeLimitMiddleware, max_size=settings.MAX_REQUEST_BODY_SIZE)
_max_body_mb = settings.MAX_REQUEST_BODY_SIZE // 1024 // 1024
logger.info(f'Request size limiting enabled: {_max_body_mb}MB max body size')

if settings.CSRF_PROTECTION_ENABLED:
    app.add_middleware(CSRFProtectionMiddleware)
    logger.info('CSRF protection enabled')

if settings.IP_WHITELIST_ENABLED:
    app.add_middleware(AdminIPWhitelistMiddleware)
    _whitelist_entries = len(settings.ADMIN_IP_WHITELIST)
    logger.info(f'Admin IP whitelisting enabled with {_whitelist_entries} IP(s)/CIDR range(s)')
|
|
if settings.RATE_LIMIT_ENABLED:
    # Limits differ per role; the limiter is published on app.state and
    # RateLimitExceeded is turned into a response by slowapi's handler.
    # NOTE(review): no SlowAPIMiddleware is added here, so limits presumably
    # apply only to routes decorated with the limiter - confirm.
    from .security.middleware.role_based_rate_limit import create_role_based_limiter

    limiter = create_role_based_limiter()
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    logger.info(
        'Role-based rate limiting enabled: '
        f'Admin={settings.RATE_LIMIT_ADMIN_PER_MINUTE}/min, '
        f'Staff={settings.RATE_LIMIT_STAFF_PER_MINUTE}/min, '
        f'Accountant={settings.RATE_LIMIT_ACCOUNTANT_PER_MINUTE}/min, '
        f'Customer={settings.RATE_LIMIT_CUSTOMER_PER_MINUTE}/min, '
        f'Default={settings.RATE_LIMIT_PER_MINUTE}/min'
    )
|
|
|
|
# CORS middleware is deliberately added LAST. Starlette runs middleware in the
# reverse order of registration (last added = outermost = first to see a
# request), so CORS wraps everything above and answers OPTIONS preflights early.
_CORS_ALLOWED_METHODS = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS']
_CORS_ALLOWED_HEADERS = ['Content-Type', 'Authorization', 'X-XSRF-TOKEN', 'X-Requested-With', 'X-Request-ID']

if settings.is_development:
    # Development: any http://localhost or http://127.0.0.1 origin on any port,
    # still with explicit method/header allow-lists rather than wildcards.
    app.add_middleware(
        CORSMiddleware,
        allow_origin_regex='http://(localhost|127\\.0\\.0\\.1)(:\\d+)?',
        allow_credentials=True,
        allow_methods=list(_CORS_ALLOWED_METHODS),
        allow_headers=list(_CORS_ALLOWED_HEADERS),
    )
    logger.info('CORS configured for development (allowing localhost with explicit methods/headers)')
else:
    # Every non-development environment uses the explicit CORS_ORIGINS list
    # (the log messages below refer to it as "production").
    _cors_origins = settings.CORS_ORIGINS
    if _cors_origins:
        # Audit trail: the count is always logged, the origins only at DEBUG.
        logger.info(f'CORS configured for production with {len(_cors_origins)} allowed origin(s)')
        if logger.isEnabledFor(logging.DEBUG):
            _origin_list = ', '.join(_cors_origins)
            logger.debug(f'Allowed CORS origins: {_origin_list}')
    else:
        logger.warning('CORS_ORIGINS is empty in production. This may block legitimate requests.')
        logger.warning('Please set CORS_ORIGINS environment variable with allowed origins.')

    # SECURITY: explicit headers instead of a wildcard.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_cors_origins or [],
        allow_credentials=True,
        allow_methods=list(_CORS_ALLOWED_METHODS),
        allow_headers=[*_CORS_ALLOWED_HEADERS, 'Accept', 'Accept-Language'],
    )
|
|
# Upload directory, relative to the parent of the directory containing this file.
uploads_dir = Path(__file__).parent.parent / settings.UPLOAD_DIR
# parents=True: a nested UPLOAD_DIR (e.g. 'data/uploads') whose parent does not
# exist yet would otherwise raise FileNotFoundError at import time.
uploads_dir.mkdir(parents=True, exist_ok=True)

from starlette.exceptions import HTTPException as StarletteHTTPException

# Exception handlers. Starlette selects the handler registered for the closest
# class in the raised exception's MRO.
app.add_exception_handler(HTTPException, http_exception_handler)
# Routing errors (404 for unknown paths, 405 for wrong methods) are raised as
# Starlette's HTTPException, the base class of FastAPI's. Without this
# registration they bypass http_exception_handler and use FastAPI's default format.
# NOTE(review): assumes http_exception_handler only reads attributes both classes
# provide (status_code, detail, headers) - confirm.
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(IntegrityError, integrity_error_handler)
app.add_exception_handler(JWTError, jwt_error_handler)
# Catch-all for exceptions not matched by a more specific handler above.
app.add_exception_handler(Exception, general_exception_handler)
|
|
|
|
@app.get('/health', tags=['health'])
@app.get('/api/health', tags=['health'])
def health_check(db: Session = Depends(get_db)):
    """
    Public health check endpoint.

    Returns minimal information for security - no sensitive details exposed:
    an overall status, a UTC timestamp, and a coarse 'ok'/'error' per check.
    Service name, version, environment, disk and memory details are
    deliberately omitted; those belong on an internal/admin endpoint.

    Declared with plain ``def`` rather than ``async def`` because the database
    probe goes through the synchronous SQLAlchemy Session. FastAPI runs sync
    endpoints in its threadpool, so a slow database no longer stalls the event
    loop for every other request.

    Returns:
        JSONResponse with 200 and status 'healthy' when ``SELECT 1`` succeeds,
        otherwise 503 with status 'unhealthy'. Error details are logged
        server-side only, never returned.
    """
    from datetime import timezone

    from sqlalchemy import text

    checks = {'api': 'ok', 'database': 'unknown'}

    try:
        db.execute(text('SELECT 1'))
    except OperationalError as e:
        # SECURITY: details go to the log, not to the public response.
        logger.error(f'Database health check failed: {str(e)}')
        checks['database'] = 'error'
    except Exception as e:
        logger.error(f'Health check failed: {str(e)}')
        checks['database'] = 'error'
    else:
        checks['database'] = 'ok'

    healthy = all(value != 'error' for value in checks.values())
    health_status = {
        'status': 'healthy' if healthy else 'unhealthy',
        # Timezone-aware UTC; datetime.utcnow() is deprecated and naive.
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'checks': checks,
    }
    status_code = status.HTTP_200_OK if healthy else status.HTTP_503_SERVICE_UNAVAILABLE
    return JSONResponse(status_code=status_code, content=health_status)
|
|
|
|
# Imported at module level because the role check is used as a dependency
# default below, which is evaluated when the function is defined. The sibling
# security middleware modules are imported the same way ('.security.middleware').
from .security.middleware.auth import authorize_roles


@app.get('/metrics', tags=['monitoring'])
async def metrics(current_user=Depends(authorize_roles('admin', 'staff'))):
    """
    Protected metrics endpoint - requires admin or staff authentication.

    SECURITY: Prevents information disclosure to unauthorized users. The role
    check is declared as a FastAPI dependency, so FastAPI resolves it (and it
    can reject the request) before this body runs. Calling the checker by hand
    inside the body would skip dependency injection and enforce nothing.
    NOTE(review): assumes authorize_roles(*roles) returns a dependency callable
    that raises for other roles - confirm in security/middleware/auth.py.

    Returns:
        dict with status, service name, version, environment and a UTC timestamp.
    """
    from datetime import timezone

    return {
        'status': 'success',
        'service': settings.APP_NAME,
        'version': settings.APP_VERSION,
        'environment': settings.ENVIRONMENT,
        'timestamp': datetime.now(timezone.utc).isoformat(),
    }
|
|
|
|
# Uploaded files are served by the explicit routes below so CORS headers can be
# attached for allowed origins. The StaticFiles mount further down lives on a
# different path ('/uploads-static'), so the two do not overlap.
from fastapi.responses import FileResponse

# Development origins allowed to read uploads: http://localhost or
# http://127.0.0.1 with an optional port. Used with fullmatch() so the WHOLE
# origin must match; re.match() would also accept 'http://localhost.evil.com'.
_DEV_UPLOAD_ORIGIN_RE = re.compile(r'http://(localhost|127\.0\.0\.1)(:\d+)?')

# Explicit content types for common image formats. Other extensions pass
# media_type=None, which lets FileResponse guess from the file name.
_UPLOAD_MEDIA_TYPES = {
    '.webp': 'image/webp',
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.gif': 'image/gif',
    '.ico': 'image/x-icon',
}


def _upload_cors_headers(origin):
    """Return the CORS headers for ``origin``, or an empty dict if it is not allowed.

    In development only localhost/127.0.0.1 origins qualify; in every other
    environment the origin must be listed in settings.CORS_ORIGINS. ``origin``
    may be None when the request carries no Origin header.
    """
    if not origin:
        return {}
    if settings.is_development:
        allowed = _DEV_UPLOAD_ORIGIN_RE.fullmatch(origin) is not None
    else:
        allowed = origin in (settings.CORS_ORIGINS or [])
    if not allowed:
        return {}
    return {
        'Access-Control-Allow-Origin': origin,
        'Access-Control-Allow-Credentials': 'true',
        'Access-Control-Allow-Methods': 'GET, HEAD, OPTIONS',
        'Access-Control-Allow-Headers': '*',
    }


@app.options('/uploads/{file_path:path}')
async def serve_upload_file_options(file_path: str, request: Request):
    """Handle CORS preflight for upload files.

    Allowed origins get the shared upload CORS headers plus a one-hour
    Access-Control-Max-Age; other origins get an empty 200 response.
    NOTE(review): CORSMiddleware is the outermost middleware and answers real
    preflights (OPTIONS carrying Access-Control-Request-Method) itself, so this
    handler presumably only sees other OPTIONS requests - confirm.
    """
    headers = _upload_cors_headers(request.headers.get('origin'))
    if headers:
        headers['Access-Control-Max-Age'] = '3600'
    return JSONResponse(content={}, headers=headers)


@app.get('/uploads/{file_path:path}')
@app.head('/uploads/{file_path:path}')
async def serve_upload_file(file_path: str, request: Request):
    """Serve an uploaded file with proper CORS headers.

    Raises:
        HTTPException(403): the requested path resolves outside the upload directory.
        HTTPException(404): the path cannot be resolved or is not a regular file.
    """
    try:
        uploads_root = uploads_dir.resolve()
        resolved_path = (uploads_dir / file_path).resolve()
    except (ValueError, OSError):
        # e.g. an embedded NUL byte in the requested path
        raise HTTPException(status_code=404, detail="File not found") from None

    # SECURITY: prevent directory traversal by comparing whole path components.
    # A string-prefix test would accept sibling directories such as
    # '<uploads>_private/...' reached via '../'.
    try:
        resolved_path.relative_to(uploads_root)
    except ValueError:
        raise HTTPException(status_code=403, detail="Access denied") from None

    if not resolved_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    media_type = _UPLOAD_MEDIA_TYPES.get(resolved_path.suffix.lower())
    # Serve the path that was validated above, not the unresolved one.
    response = FileResponse(str(resolved_path), media_type=media_type)

    cors_headers = _upload_cors_headers(request.headers.get('origin'))
    if cors_headers:
        cors_headers['Access-Control-Expose-Headers'] = '*'
    # CORS headers depend on the request Origin, so shared caches must not
    # reuse one origin's response for another.
    cors_headers['Vary'] = 'Origin'
    for header_name, header_value in cors_headers.items():
        response.headers[header_name] = header_value

    return response
|
|
|
|
# Additional StaticFiles mount of the upload directory under '/uploads-static'.
# It is a separate path from the '/uploads/...' routes above, so neither one
# shadows the other. StaticFiles is the class imported at the top of this
# module (fastapi.staticfiles re-exports Starlette's StaticFiles).
app.mount('/uploads-static', StaticFiles(directory=str(uploads_dir)), name='uploads-static')
|
|
|
|
# Import all route modules from feature-based structure
|
|
from .auth.routes import auth_routes, user_routes
|
|
from .rooms.routes import room_routes, advanced_room_routes, rate_plan_routes
|
|
from .bookings.routes import booking_routes, group_booking_routes
|
|
from .bookings.routes.upsell_routes import router as upsell_routes
|
|
from .payments.routes import payment_routes, invoice_routes, financial_routes, audit_trail_routes
|
|
from .payments.routes.approval_routes import router as financial_approval_routes
|
|
from .payments.routes.gl_routes import router as gl_routes
|
|
from .payments.routes.reconciliation_routes import router as reconciliation_routes
|
|
from .payments.routes.accountant_security_routes import router as accountant_security_routes
|
|
from .hotel_services.routes import service_routes, service_booking_routes, inventory_routes, guest_request_routes, staff_shift_routes
|
|
from .content.routes import (
|
|
banner_routes, page_content_routes, home_routes, about_routes,
|
|
contact_routes, contact_content_routes, footer_routes, privacy_routes,
|
|
admin_privacy_routes, terms_routes, refunds_routes, cancellation_routes,
|
|
accessibility_routes, faq_routes, blog_routes
|
|
)
|
|
from .reviews.routes import review_routes, favorite_routes
|
|
from .loyalty.routes import promotion_routes, loyalty_routes, package_routes
|
|
from .guest_management.routes import guest_profile_routes
|
|
from .guest_management.routes.complaint_routes import router as complaint_routes
|
|
from .notifications.routes import chat_routes, notification_routes, email_campaign_routes
|
|
from .analytics.routes import analytics_routes, report_routes, audit_routes
|
|
from .security.routes import security_routes, compliance_routes
|
|
from .system.routes import system_settings_routes, workflow_routes, task_routes, approval_routes, backup_routes
|
|
from .ai.routes import ai_assistant_routes
|
|
from .compliance.routes import gdpr_routes
|
|
from .compliance.routes.gdpr_admin_routes import router as gdpr_admin_routes
|
|
from .integrations.routes import webhook_routes, api_key_routes
|
|
from .auth.routes import session_routes
|
|
|
|
# Every feature router is mounted under the single '/api' prefix; API versioning,
# if needed, is handled separately (see APIVersioningMiddleware). The tuple order
# is the registration order, which matters when two routers declare overlapping
# paths (the first registered match wins).
api_prefix = '/api'

_feature_routers = (
    auth_routes.router,
    room_routes.router,
    booking_routes.router,
    group_booking_routes.router,
    upsell_routes,
    payment_routes.router,
    invoice_routes.router,
    financial_routes.router,
    audit_trail_routes.router,
    financial_approval_routes,
    gl_routes,
    reconciliation_routes,
    accountant_security_routes,
    banner_routes.router,
    favorite_routes.router,
    service_routes.router,
    service_booking_routes.router,
    promotion_routes.router,
    report_routes.router,
    review_routes.router,
    user_routes.router,
    audit_routes.router,
    admin_privacy_routes.router,
    system_settings_routes.router,
    contact_routes.router,
    home_routes.router,
    about_routes.router,
    contact_content_routes.router,
    footer_routes.router,
    privacy_routes.router,
    terms_routes.router,
    refunds_routes.router,
    cancellation_routes.router,
    accessibility_routes.router,
    faq_routes.router,
    chat_routes.router,
    loyalty_routes.router,
    guest_profile_routes.router,
    complaint_routes,
    analytics_routes.router,
    workflow_routes.router,
    task_routes.router,
    notification_routes.router,
    advanced_room_routes.router,
    inventory_routes.router,
    guest_request_routes.router,
    staff_shift_routes.router,
    rate_plan_routes.router,
    package_routes.router,
    security_routes.router,
    compliance_routes.router,
    email_campaign_routes.router,
    page_content_routes.router,
    blog_routes.router,
    ai_assistant_routes.router,
    approval_routes.router,
    gdpr_routes.router,
    gdpr_admin_routes,
    webhook_routes.router,
    api_key_routes.router,
    session_routes.router,
    backup_routes.router,
)

for _feature_router in _feature_routers:
    app.include_router(_feature_router, prefix=api_prefix)

logger.info('All routes registered successfully')
|
|
|
|
def ensure_jwt_secret():
    """
    Check that JWT_SECRET is configured and log the outcome.

    SECURITY: JWT_SECRET must be set explicitly via environment variable; no
    default is acceptable. The primary validation lives in settings.py - this
    check is kept for backward compatibility and for clear startup logging.

    Raises:
        ValueError: in production, when JWT_SECRET is missing or blank.
    """
    secret = settings.JWT_SECRET

    if secret and secret.strip() != '':
        # Secret present: in production, advise when it is shorter than 64 chars.
        if len(secret) < 64 and settings.is_production:
            logger.warning(
                f'JWT_SECRET is only {len(secret)} characters. '
                'Recommend using at least 64 characters for production security.'
            )
        logger.info('✓ JWT secret is configured')
        return

    if not settings.is_production:
        # Outside production a missing secret is only warned about.
        logger.warning(
            'JWT_SECRET is not configured. Authentication will fail. '
            'Set JWT_SECRET environment variable before starting the application.'
        )
        return

    # Production without a secret: settings validation should already have
    # caught this, so fail loudly if it did not.
    error_msg = (
        'CRITICAL SECURITY ERROR: JWT_SECRET is not configured. '
        'Please set JWT_SECRET environment variable before starting the application.'
    )
    logger.error(error_msg)
    raise ValueError(error_msg)
|
|
|
|
@app.on_event('startup')
async def startup_event():
    """
    Application startup hook.

    Validates security configuration (JWT secret, encryption key), starts the
    AI training scheduler on a best-effort basis, then logs a short summary.

    Raises:
        ValueError: propagated from ensure_jwt_secret(), or from the encryption
            key validation when running in production.
    """
    ensure_jwt_secret()

    # An invalid encryption key is always logged, but only aborts startup
    # (fail fast) in production.
    try:
        settings.validate_encryption_key()
    except ValueError as exc:
        logger.error(str(exc))
        if settings.is_production:
            raise

    # Best-effort: a scheduler failure is logged with its traceback but must
    # not prevent the API from starting.
    try:
        from .ai.services.ai_training_scheduler import get_training_scheduler

        get_training_scheduler().start()
        logger.info('AI Training Scheduler started - automatic self-learning enabled')
    except Exception as exc:
        logger.error(f'Failed to start AI Training Scheduler: {exc}', exc_info=True)

    for summary_line in (
        f'{settings.APP_NAME} started successfully',
        f'Environment: {settings.ENVIRONMENT}',
        f'Debug mode: {settings.DEBUG}',
        f'API version: {settings.API_V1_PREFIX}',
    ):
        logger.info(summary_line)
|
|
|
|
@app.on_event('shutdown')
async def shutdown_event():
    """
    Application shutdown hook.

    Stops the AI training scheduler (errors are logged with traceback, not
    raised) and then logs the shutdown.
    """
    try:
        from .ai.services.ai_training_scheduler import get_training_scheduler

        get_training_scheduler().stop()
        logger.info('AI Training Scheduler stopped')
    except Exception as exc:
        logger.error(f'Error stopping AI Training Scheduler: {exc}', exc_info=True)

    logger.info(f'{settings.APP_NAME} shutting down gracefully')
|
|
if __name__ == '__main__':
    # os, sys and Path come from the module-level imports at the top of the file.
    import signal

    import uvicorn

    def signal_handler(sig, frame):
        """Log and exit with status 0 when SIGINT (Ctrl+C) or SIGTERM arrives."""
        logger.info('\nReceived interrupt signal (Ctrl+C). Shutting down gracefully...')
        sys.exit(0)

    # NOTE(review): uvicorn installs its own signal handling while serving,
    # which may take precedence over these handlers - confirm desired behavior.
    for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_shutdown_signal, signal_handler)

    src_dir = str(Path(__file__).parent.parent / 'src')

    # Hot reload: always on in development, otherwise opt in with ENABLE_RELOAD=true.
    reload_requested = os.getenv('ENABLE_RELOAD', 'false').lower() == 'true'
    use_reload = settings.is_development or reload_requested
    if use_reload:
        logger.info('Hot reload enabled - server will restart on code changes')
    logger.info('Press Ctrl+C to stop the server')

    uvicorn.run(
        'src.main:app',
        host=settings.HOST,
        port=settings.PORT,
        reload=use_reload,
        log_level=settings.LOG_LEVEL.lower(),
        reload_dirs=[src_dir] if use_reload else None,
        reload_excludes=[
            '*.log', '*.pyc', '*.pyo', '*.pyd',
            '__pycache__', '**/__pycache__/**',
            '*.db', '*.sqlite', '*.sqlite3',
            'venv/**', '.venv/**',
        ],
        reload_delay=0.5,
    )