from contextlib import contextmanager
from datetime import datetime, timezone
import ipaddress
from typing import Iterable, Iterator, List, Optional

from fastapi import Request, HTTPException, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ..models.security_event import IPWhitelist, IPBlacklist, SecurityEvent, SecurityEventType, SecurityEventSeverity

logger = get_logger(__name__)


@contextmanager
def _db_session() -> Iterator[Session]:
    """Yield a session from the ``get_db`` dependency generator and clean it up.

    ``get_db()`` is a generator dependency; closing the generator (in addition
    to the session) lets any cleanup in its ``finally`` block run right away
    instead of whenever the suspended generator is garbage-collected.
    """
    db_gen = get_db()
    db = next(db_gen)
    try:
        yield db
    finally:
        try:
            db.close()
        finally:
            db_gen.close()


class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce IP whitelisting and blacklisting.

    Every request (except those whose path is in ``exempt_paths``) is checked
    against the active ``IPBlacklist`` entries; when ``whitelist_only`` is set
    it must additionally match an active ``IPWhitelist`` entry. Entries may be
    single addresses or CIDR ranges (any value containing ``/``).
    """

    _DEFAULT_EXEMPT_PATHS = ('/health', '/api/health', '/metrics')

    # Header names whose values must never be persisted in security events.
    _REDACTED_HEADERS = frozenset({
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
        "x-api-key",
        "x-auth-token",
        "x-csrf-token",
    })

    def __init__(
        self,
        app,
        enabled: bool = True,
        whitelist_only: bool = False,
        trusted_proxies: Optional[Iterable[str]] = None,
        exempt_paths: Optional[Iterable[str]] = None,
    ):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            enabled: When False every request is passed through unchecked.
            whitelist_only: If True, only whitelisted IPs are allowed.
            trusted_proxies: Addresses/CIDR ranges of reverse proxies whose
                ``X-Forwarded-For`` / ``X-Real-IP`` headers may be believed.
                ``None`` keeps the legacy behaviour of trusting those headers
                from *any* peer. That setting is spoofable: any client can
                claim an arbitrary IP. Configure this in production.
            exempt_paths: Request paths that skip the IP check entirely.
                Defaults to the health/metrics endpoints.

        Raises:
            ValueError: If an entry in ``trusted_proxies`` is not a valid
                address or network (fail fast on misconfiguration).
        """
        super().__init__(app)
        self.enabled = enabled
        self.whitelist_only = whitelist_only  # If True, only whitelisted IPs allowed
        self.trusted_networks = (
            None
            if trusted_proxies is None
            else tuple(ipaddress.ip_network(p.strip(), strict=False) for p in trusted_proxies)
        )
        self.exempt_paths = frozenset(
            self._DEFAULT_EXEMPT_PATHS if exempt_paths is None else exempt_paths
        )

    async def dispatch(self, request: Request, call_next):
        """Reject blacklisted (and, in whitelist-only mode, non-whitelisted) clients."""
        if not self.enabled or request.url.path in self.exempt_paths:
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        if not client_ip:
            if self.whitelist_only:
                # An unidentifiable client can never be proven whitelisted;
                # letting it through would bypass whitelist-only mode.
                logger.warning("Could not determine client IP address; denying request")
                await self._log_security_event(
                    request,
                    SecurityEventType.permission_denied,
                    SecurityEventSeverity.medium,
                    "Blocked request with undeterminable client IP",
                )
                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"status": "error", "message": "Access denied. IP not whitelisted."},
                )
            logger.warning("Could not determine client IP address")
            return await call_next(request)

        # Check blacklist first
        if await self._is_blacklisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.ip_blocked,
                SecurityEventSeverity.high,
                f"Blocked request from blacklisted IP: {client_ip}",
                client_ip=client_ip,
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"status": "error", "message": "Access denied"},
            )

        # Check whitelist if whitelist_only mode is enabled
        if self.whitelist_only and not await self._is_whitelisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.permission_denied,
                SecurityEventSeverity.medium,
                f"Blocked request from non-whitelisted IP: {client_ip}",
                client_ip=client_ip,
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"status": "error", "message": "Access denied. IP not whitelisted."},
            )

        return await call_next(request)

    # ------------------------------------------------------------------ #
    # Client IP resolution
    # ------------------------------------------------------------------ #

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """Return the normalized client IP for ``request``, or None if unknown.

        Legacy mode (``trusted_proxies`` is None): the first
        ``X-Forwarded-For`` entry, else ``X-Real-IP``, else the socket peer.

        Trusted-proxy mode: forwarded headers are honoured only when the
        socket peer is a trusted proxy. ``X-Forwarded-For`` is walked from
        the right, and the first hop that is not itself a trusted proxy is
        returned.

        A malformed value in a header that is being honoured yields None.
        The code does not guess at a different address in that case.
        """
        peer = request.client.host if request.client else None

        if self.trusted_networks is None:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # X-Forwarded-For can contain multiple IPs, take the first one
                return self._normalize_ip(forwarded_for.split(",")[0])
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return self._normalize_ip(real_ip)
            return self._normalize_ip(peer)

        peer_ip = self._normalize_ip(peer)
        if peer_ip is None or not self._is_trusted_proxy(peer_ip):
            # Direct client (or unknown peer): its headers are not trustworthy.
            return peer_ip

        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            hop_ip = None
            for hop in reversed(forwarded_for.split(",")):
                hop_ip = self._normalize_ip(hop)
                if hop_ip is None:
                    return None
                if not self._is_trusted_proxy(hop_ip):
                    return hop_ip
            # Every hop is a trusted proxy; the left-most one is the origin.
            return hop_ip

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return self._normalize_ip(real_ip)
        return peer_ip

    @staticmethod
    def _normalize_ip(value: Optional[str]) -> Optional[str]:
        """Parse ``value`` as an IP address and return its canonical string.

        Accepts optional ports (``1.2.3.4:80``, ``[::1]:80``) and unwraps
        IPv4-mapped IPv6 addresses (``::ffff:1.2.3.4`` -> ``1.2.3.4``) so
        they match IPv4 list entries. Returns None for empty/invalid input.
        """
        if not value:
            return None
        candidate = value.strip()
        if candidate.startswith("["):
            end = candidate.find("]")
            if end == -1:
                return None
            candidate = candidate[1:end]
        elif candidate.count(":") == 1:
            # IPv4 with a port; bare IPv6 always has two or more colons.
            candidate = candidate.split(":", 1)[0]
        try:
            ip = ipaddress.ip_address(candidate)
        except ValueError:
            return None
        if ip.version == 6 and ip.ipv4_mapped is not None:
            ip = ip.ipv4_mapped
        return str(ip)

    def _is_trusted_proxy(self, ip: str) -> bool:
        """Return True if ``ip`` lies in one of the configured trusted networks."""
        addr = ipaddress.ip_address(ip)
        return any(addr in network for network in self.trusted_networks or ())

    # ------------------------------------------------------------------ #
    # List matching helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _entry_matches(entry_value: Optional[str], ip: str) -> bool:
        """Return True if a stored list value (address or CIDR) covers ``ip``."""
        if not entry_value:
            return False
        try:
            addr = ipaddress.ip_address(ip)
            if '/' in entry_value:
                # CIDR notation
                return addr in ipaddress.ip_network(entry_value.strip(), strict=False)
            return addr == ipaddress.ip_address(entry_value.strip())
        except ValueError:
            # Unparseable stored value: fall back to a literal comparison.
            return entry_value == ip

    @staticmethod
    def _block_expired(blocked_until: Optional[datetime]) -> bool:
        """Return True if a temporary block's ``blocked_until`` is in the past."""
        if not blocked_until:
            return False
        now = datetime.now(timezone.utc)
        if blocked_until.tzinfo is None:
            # Naive timestamps are compared as UTC, matching the previous
            # datetime.utcnow() comparison.
            now = now.replace(tzinfo=None)
        return blocked_until < now

    def _check_blacklist(self, ip: str) -> bool:
        """Synchronously check ``ip`` against active blacklist entries.

        Loads only the exact-match row(s) and CIDR rows. Expired temporary
        blocks, both exact and CIDR, are deactivated and skipped; they do not
        end the search. An IP inside an active range therefore stays blocked
        even if its own exact entry has expired.
        """
        with _db_session() as db:
            candidates = db.query(IPBlacklist).filter(
                IPBlacklist.is_active == True,  # noqa: E712 - SQL expression
                or_(IPBlacklist.ip_address == ip, IPBlacklist.ip_address.contains('/')),
            ).all()

            blocked = False
            deactivated = False
            for entry in candidates:
                if not self._entry_matches(entry.ip_address, ip):
                    continue
                if self._block_expired(entry.blocked_until):
                    # Block expired, deactivate it
                    entry.is_active = False
                    deactivated = True
                    continue
                blocked = True
                break

            if deactivated:
                db.commit()
            return blocked

    def _check_whitelist(self, ip: str) -> bool:
        """Synchronously check ``ip`` against active whitelist entries (exact or CIDR)."""
        with _db_session() as db:
            candidates = db.query(IPWhitelist).filter(
                IPWhitelist.is_active == True,  # noqa: E712 - SQL expression
                or_(IPWhitelist.ip_address == ip, IPWhitelist.ip_address.contains('/')),
            ).all()
            return any(self._entry_matches(entry.ip_address, ip) for entry in candidates)

    async def _is_blacklisted(self, ip_address: str) -> bool:
        """Check if IP address is blacklisted.

        The blocking DB work runs in a threadpool so the event loop is not
        stalled. On error this fails open (returns False), as before.
        """
        try:
            return await run_in_threadpool(self._check_blacklist, ip_address)
        except Exception as e:
            logger.error(f"Error checking IP blacklist: {str(e)}")
            return False

    async def _is_whitelisted(self, ip_address: str) -> bool:
        """Check if IP address is whitelisted.

        The blocking DB work runs in a threadpool. On error this fails closed
        (returns False), as before.
        """
        try:
            return await run_in_threadpool(self._check_whitelist, ip_address)
        except Exception as e:
            logger.error(f"Error checking IP whitelist: {str(e)}")
            return False

    # ------------------------------------------------------------------ #
    # Security event logging
    # ------------------------------------------------------------------ #

    def _redacted_headers(self, request: Request) -> dict:
        """Return the request headers with credential-bearing values masked."""
        return {
            name: ("[REDACTED]" if name.lower() in self._REDACTED_HEADERS else value)
            for name, value in request.headers.items()
        }

    def _write_security_event(
        self,
        request: Request,
        event_type: SecurityEventType,
        severity: SecurityEventSeverity,
        description: str,
        client_ip: Optional[str],
    ) -> None:
        """Synchronously persist one ``SecurityEvent`` row."""
        with _db_session() as db:
            event = SecurityEvent(
                event_type=event_type,
                severity=severity,
                ip_address=client_ip,
                user_agent=request.headers.get("User-Agent"),
                request_path=str(request.url.path),
                request_method=request.method,
                description=description,
                details={
                    "url": str(request.url),
                    "headers": self._redacted_headers(request),
                },
            )
            db.add(event)
            db.commit()

    async def _log_security_event(
        self,
        request: Request,
        event_type: SecurityEventType,
        severity: SecurityEventSeverity,
        description: str,
        client_ip: Optional[str] = None,
    ):
        """Log security event (best-effort; errors are logged, never raised).

        ``client_ip`` may be passed to avoid re-resolving it; when omitted it
        is derived from the request.
        """
        try:
            if client_ip is None:
                client_ip = self._get_client_ip(request)
            await run_in_threadpool(
                self._write_security_event,
                request,
                event_type,
                severity,
                description,
                client_ip,
            )
        except Exception as e:
            logger.error(f"Error logging security event: {str(e)}")