280 lines
10 KiB
Python
280 lines
10 KiB
Python
from fastapi import FastAPI, Request, HTTPException, Depends, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.exceptions import RequestValidationError
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
from jose.exceptions import JWTError
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import sys
|
|
|
|
# Import configuration and logging FIRST
|
|
from .config.settings import settings
|
|
from .config.logging_config import setup_logging, get_logger
|
|
from .config.database import engine, Base, get_db
|
|
from . import models # noqa: F401 - ensure models are imported so tables are created
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Configure logging first so that every message emitted during the rest of this
# module's import goes through the configured logger.
logger = setup_logging()
logger.info(
    f"Starting {settings.APP_NAME} v{settings.APP_VERSION} in {settings.ENVIRONMENT} mode"
)
|
|
|
|
# Import middleware
|
|
from .middleware.error_handler import (
|
|
validation_exception_handler,
|
|
integrity_error_handler,
|
|
jwt_error_handler,
|
|
http_exception_handler,
|
|
general_exception_handler
|
|
)
|
|
from .middleware.request_id import RequestIDMiddleware
|
|
from .middleware.security import SecurityHeadersMiddleware
|
|
from .middleware.timeout import TimeoutMiddleware
|
|
from .middleware.cookie_consent import CookieConsentMiddleware
|
|
|
|
# Schema bootstrap. Development creates every table registered on Base.metadata;
# other environments are expected to rely on migrations.
if not settings.is_development:
    # Ensure new cookie-related tables exist even if full migrations haven't been
    # run yet. This is best-effort: a failure is logged and startup continues.
    try:
        from .models.cookie_policy import CookiePolicy
        from .models.cookie_integration_config import CookieIntegrationConfig

        logger.info("Ensuring cookie-related tables exist")
        for _cookie_model in (CookiePolicy, CookieIntegrationConfig):
            # checkfirst=True: leave the table alone if it already exists
            _cookie_model.__table__.create(bind=engine, checkfirst=True)
    except Exception as e:
        logger.error(f"Failed to ensure cookie tables exist: {e}")
else:
    logger.info("Creating database tables (development mode)")
    Base.metadata.create_all(bind=engine)
|
|
|
|
from .routes import auth_routes
|
|
from .routes import privacy_routes
|
|
|
|
# FastAPI application instance. Interactive docs (Swagger UI, ReDoc) and the
# OpenAPI schema are served everywhere except production.
_docs_enabled = not settings.is_production

app = FastAPI(
    title=settings.APP_NAME,
    description="Enterprise-grade Hotel Booking API",
    version=settings.APP_VERSION,
    docs_url="/api/docs" if _docs_enabled else None,
    redoc_url="/api/redoc" if _docs_enabled else None,
    openapi_url="/api/openapi.json" if _docs_enabled else None,
)
|
|
|
|
# Middleware registration (order matters!).
# Starlette's add_middleware() places each new middleware OUTSIDE the ones already
# registered: the LAST middleware added is the OUTERMOST layer, sees the request
# first and the response last. Register innermost-first so a request flows through:
#   RequestID -> SecurityHeaders -> Timeout -> CookieConsent -> route handler
# (middleware added later in this module, e.g. CORS, wraps all of these.)

# Innermost: make cookie consent available on request.state right before routing.
app.add_middleware(CookieConsentMiddleware)

# Time-box request handling when a timeout is configured.
if settings.REQUEST_TIMEOUT > 0:
    app.add_middleware(TimeoutMiddleware)

# Wrap the timeout layer so any response it produces (presumably a timeout error
# response — confirm in middleware/timeout.py) still passes through security headers.
app.add_middleware(SecurityHeadersMiddleware)

# Outermost of this group: the request ID is added before any layer above runs,
# so it is available to all of them and to the route handler.
app.add_middleware(RequestIDMiddleware)
|
|
|
|
# Rate limiting (slowapi), keyed by the client address as seen by the app.
# NOTE(review): behind a reverse proxy / load balancer, get_remote_address only sees
# the real client if proxy headers are trusted; otherwise all users share one bucket.
if settings.RATE_LIMIT_ENABLED:
    from slowapi.middleware import SlowAPIMiddleware

    limiter = Limiter(
        key_func=get_remote_address,
        default_limits=[f"{settings.RATE_LIMIT_PER_MINUTE}/minute"],
    )
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    # default_limits are only applied to routes without an explicit
    # @limiter.limit(...) decorator when SlowAPIMiddleware is installed. Without
    # it, the global limit logged below would never actually be enforced.
    app.add_middleware(SlowAPIMiddleware)
    logger.info(f"Rate limiting enabled: {settings.RATE_LIMIT_PER_MINUTE} requests/minute")
|
|
|
|
# CORS configuration. Both environments allow credentials and any request header;
# they differ only in which origins and HTTP methods are accepted.
if settings.is_development:
    # For development, use regex to allow any localhost port
    _cors_origin_options = {
        "allow_origin_regex": r"http://(localhost|127\.0\.0\.1)(:\d+)?",
        "allow_methods": ["*"],
    }
    _cors_log_message = "CORS configured for development (allowing localhost)"
else:
    # Production: only the explicitly configured origins and an explicit method list
    _cors_origin_options = {
        "allow_origins": settings.CORS_ORIGINS,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    }
    _cors_log_message = (
        f"CORS configured for production with {len(settings.CORS_ORIGINS)} allowed origins"
    )

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    **_cors_origin_options,
)
logger.info(_cors_log_message)
|
|
|
|
# Serve user uploads as static files under /uploads.
# The directory is resolved two levels up from this module's file; if UPLOAD_DIR is
# an absolute path, pathlib's "/" operator uses it as-is.
uploads_dir = Path(__file__).parent.parent / settings.UPLOAD_DIR
# parents=True so a nested UPLOAD_DIR (e.g. "media/uploads") whose parent does not
# exist yet doesn't abort startup with FileNotFoundError.
uploads_dir.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(uploads_dir)), name="uploads")
|
|
|
|
# Exception handlers.
# Starlette's router raises starlette.exceptions.HTTPException (e.g. 404 for an
# unknown path, 405 for a wrong method). fastapi.HTTPException is a *subclass* of it.
# Registering the handler on the Starlette base class therefore covers both kinds;
# registering it only on fastapi.HTTPException leaves routing errors on Starlette's
# default handler with a different response shape.
from starlette.exceptions import HTTPException as StarletteHTTPException

app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(IntegrityError, integrity_error_handler)
app.add_exception_handler(JWTError, jwt_error_handler)
app.add_exception_handler(Exception, general_exception_handler)
|
|
|
|
# Enhanced Health check with database connectivity
@app.get("/health", tags=["health"])
async def health_check(db: Session = Depends(get_db)):
    """
    Enhanced health check endpoint with database connectivity test.

    Runs ``SELECT 1`` on the injected session. On success it returns the status
    payload (HTTP 200) with ``checks.database == "ok"``. On failure it returns the
    same payload marked ``"unhealthy"`` with ``checks.database == "error"`` and
    HTTP 503.

    The probe runs through ``run_in_threadpool`` because ``Session.execute`` is a
    blocking call. Calling it directly inside this ``async def`` would stall the
    event loop, and with it every other in-flight request, for as long as the
    database takes to answer or time out.

    The raw exception text is only added to the response outside production.
    Driver errors can contain host, database or user names, and this route
    declares no authentication dependency.
    """
    from fastapi.concurrency import run_in_threadpool
    from sqlalchemy import text

    health_status = {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "service": settings.APP_NAME,
        "version": settings.APP_VERSION,
        "environment": settings.ENVIRONMENT,
        "checks": {
            "api": "ok",
            "database": "unknown",
        },
    }

    try:
        await run_in_threadpool(db.execute, text("SELECT 1"))
    except Exception as e:
        # OperationalError: the database itself is unreachable or failing.
        # Anything else is an unexpected failure of the probe. Both yield a 503.
        if isinstance(e, OperationalError):
            logger.error(f"Database health check failed: {str(e)}")
        else:
            logger.error(f"Health check failed: {str(e)}")

        health_status["status"] = "unhealthy"
        health_status["checks"]["database"] = "error"
        if not settings.is_production:
            health_status["error"] = str(e)
        return JSONResponse(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            content=health_status,
        )

    health_status["checks"]["database"] = "ok"
    return health_status
|
|
|
|
|
|
# Metrics endpoint (basic)
@app.get("/metrics", tags=["monitoring"])
async def metrics():
    """
    Return basic service metadata for monitoring.

    Placeholder for a richer exporter (e.g. Prometheus or similar). Currently it
    reports the service identity, environment and the current UTC time.
    """
    payload = dict(
        status="success",
        service=settings.APP_NAME,
        version=settings.APP_VERSION,
        environment=settings.ENVIRONMENT,
    )
    payload["timestamp"] = datetime.utcnow().isoformat()
    return payload
|
|
|
|
# API routes with versioning. Each router is mounted first under the legacy "/api"
# prefix (kept for backward compatibility), then under the versioned v1 prefix.
# Registration order is preserved: all legacy mounts precede all v1 mounts.
for _api_prefix in ("/api", settings.API_V1_PREFIX):
    for _core_router in (auth_routes.router, privacy_routes.router):
        app.include_router(_core_router, prefix=_api_prefix)
|
|
|
|
# Import and include the remaining feature routers
from .routes import (
    room_routes, booking_routes, payment_routes, invoice_routes, banner_routes,
    favorite_routes, service_routes, promotion_routes, report_routes,
    review_routes, user_routes, audit_routes, admin_privacy_routes,
    system_settings_routes,
)

# Registration order within each prefix matters for route matching; keep it stable.
_feature_routers = (
    room_routes.router,
    booking_routes.router,
    payment_routes.router,
    invoice_routes.router,
    banner_routes.router,
    favorite_routes.router,
    service_routes.router,
    promotion_routes.router,
    report_routes.router,
    review_routes.router,
    user_routes.router,
    audit_routes.router,
    admin_privacy_routes.router,
    system_settings_routes.router,
)

# Every feature router is mounted under the legacy "/api" prefix (backward
# compatibility) first, then under the versioned v1 prefix.
for _api_prefix in ("/api", settings.API_V1_PREFIX):
    for _feature_router in _feature_routers:
        app.include_router(_feature_router, prefix=_api_prefix)

logger.info("All routes registered successfully")
|
|
|
|
# Startup event
@app.on_event("startup")
async def startup_event():
    """Log the service identity and key runtime settings once the app has started."""
    startup_lines = (
        f"{settings.APP_NAME} started successfully",
        f"Environment: {settings.ENVIRONMENT}",
        f"Debug mode: {settings.DEBUG}",
        f"API version: {settings.API_V1_PREFIX}",
    )
    for line in startup_lines:
        logger.info(line)
|
|
|
|
# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    """Log that the application is stopping."""
    app_name = settings.APP_NAME
    logger.info(f"{app_name} shutting down gracefully")
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Reload-related options are only passed in development. With reload=False
    # uvicorn ignores reload_dirs/reload_excludes and logs a warning that the
    # configuration will not reload.
    reload_options = {}
    if settings.is_development:
        # Only watch the src directory to avoid watching logs, uploads, etc.
        # (Path is imported at module level.)
        src_dir = str(Path(__file__).parent.parent / "src")
        reload_options = {
            "reload_dirs": [src_dir],
            "reload_excludes": [
                "*.log",
                "*.pyc",
                "*.pyo",
                "*.pyd",
                "__pycache__",
                "**/__pycache__/**",
                "*.db",
                "*.sqlite",
                "*.sqlite3",
            ],
            "reload_delay": 0.5,  # Increase delay to reduce false positives
        }

    uvicorn.run(
        "src.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.is_development,
        log_level=settings.LOG_LEVEL.lower(),
        **reload_options,
    )
|
|
|