Files
Hotel-Booking/Backend/src/security/middleware/ip_whitelist.py
Iliyan Angelov 39fcfff811 update
2025-11-30 22:43:09 +02:00

196 lines
7.5 KiB
Python

import ipaddress
from contextlib import contextmanager
from datetime import datetime
from typing import Iterable, List, Optional

from fastapi import Request, HTTPException, status
from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ..models.security_event import (
    IPBlacklist,
    IPWhitelist,
    SecurityEvent,
    SecurityEventSeverity,
    SecurityEventType,
)
# Module-level logger obtained from the project's shared logging configuration.
logger = get_logger(__name__)
class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce IP whitelisting and blacklisting.

    Every request whose path is not in ``exempt_paths`` is checked against the
    active ``IPBlacklist`` entries first; a match is rejected with HTTP 403.
    When ``whitelist_only`` is True the client must additionally match an
    active ``IPWhitelist`` entry. Entries may be single addresses or CIDR
    networks. Rejections are recorded as ``SecurityEvent`` rows.

    ``X-Forwarded-For`` / ``X-Real-IP`` are only honoured when the direct peer
    is one of ``trusted_proxies``. Trusting them from arbitrary peers would let
    any client pick its own address and bypass both lists.

    Args:
        app: The ASGI application to wrap.
        enabled: When False, every request is passed straight through.
        whitelist_only: When True, only whitelisted IPs are allowed.
        trusted_proxies: IPs/CIDRs of reverse proxies whose forwarding headers
            are trusted. ``None`` (default) trusts the loopback and private
            ranges in ``DEFAULT_TRUSTED_PROXIES``; pass ``()`` to ignore
            forwarding headers entirely.
        exempt_paths: URL paths that skip IP checks. ``None`` keeps
            ``DEFAULT_EXEMPT_PATHS``.
    """

    # Paths that bypass IP filtering unless overridden via ``exempt_paths``.
    DEFAULT_EXEMPT_PATHS = ("/health", "/api/health", "/metrics")
    # Proxy ranges trusted by default: loopback plus RFC 1918 / unique-local
    # networks, where a reverse proxy or container gateway usually sits.
    DEFAULT_TRUSTED_PROXIES = (
        "127.0.0.0/8",
        "::1/128",
        "10.0.0.0/8",
        "172.16.0.0/12",
        "192.168.0.0/16",
        "fc00::/7",
    )
    # Lower-case header names whose values must never be persisted in events.
    SENSITIVE_HEADERS = frozenset({
        "authorization",
        "proxy-authorization",
        "cookie",
        "set-cookie",
        "x-api-key",
        "x-csrf-token",
    })

    def __init__(
        self,
        app,
        enabled: bool = True,
        whitelist_only: bool = False,
        trusted_proxies: Optional[Iterable[str]] = None,
        exempt_paths: Optional[Iterable[str]] = None,
    ):
        super().__init__(app)
        self.enabled = enabled
        self.whitelist_only = whitelist_only  # If True, only whitelisted IPs allowed
        proxies = self.DEFAULT_TRUSTED_PROXIES if trusted_proxies is None else trusted_proxies
        if isinstance(proxies, str):
            # Tolerate a single CIDR string instead of a list of them.
            proxies = (proxies,)
        # Parsed once at startup; invalid configuration raises ValueError here.
        self.trusted_proxies = tuple(
            ipaddress.ip_network(proxy, strict=False) for proxy in proxies
        )
        self.exempt_paths = frozenset(
            self.DEFAULT_EXEMPT_PATHS if exempt_paths is None else exempt_paths
        )

    async def dispatch(self, request: Request, call_next):
        """Reject blacklisted clients and, in whitelist-only mode, unlisted ones."""
        if not self.enabled or request.url.path in self.exempt_paths:
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        if not client_ip:
            logger.warning("Could not determine client IP address")
            if not self.whitelist_only:
                return await call_next(request)
            # An unknown client can never be proven whitelisted, so
            # whitelist-only mode must fail closed instead of letting it pass.
            await self._log_security_event(
                request,
                SecurityEventType.permission_denied,
                SecurityEventSeverity.medium,
                "Blocked request with undeterminable client IP",
            )
            return self._forbidden("Access denied. IP not whitelisted.")

        # Check blacklist first
        if await self._is_blacklisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.ip_blocked,
                SecurityEventSeverity.high,
                f"Blocked request from blacklisted IP: {client_ip}",
            )
            return self._forbidden("Access denied")

        # Check whitelist if whitelist_only mode is enabled
        if self.whitelist_only and not await self._is_whitelisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.permission_denied,
                SecurityEventSeverity.medium,
                f"Blocked request from non-whitelisted IP: {client_ip}",
            )
            return self._forbidden("Access denied. IP not whitelisted.")

        return await call_next(request)

    @staticmethod
    def _forbidden(message: str) -> JSONResponse:
        """Build the 403 JSON response used for every rejection."""
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"status": "error", "message": message},
        )

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """Return the client IP, or None if the connection has no peer address.

        Forwarding headers are consulted only when the direct peer is a
        trusted proxy. ``X-Forwarded-For`` is read right-to-left: each trusted
        proxy appends the address it received the request from, so the first
        untrusted hop is the real client. Everything to its left was supplied
        by that client and may be forged.
        """
        peer = request.client.host if request.client else None
        if peer is None or not self._is_trusted_proxy(peer):
            return peer

        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
            for hop in reversed(hops):
                if not self._is_trusted_proxy(hop):
                    return hop
            if hops:
                # Every hop is a trusted proxy: the leftmost is the origin.
                return hops[0]

        real_ip = request.headers.get("X-Real-IP", "").strip()
        if real_ip:
            return real_ip

        # Fallback to direct client IP
        return peer

    @staticmethod
    def _parse_ip(value: Optional[str]):
        """Parse *value* as an IPv4/IPv6 address, returning None if invalid."""
        try:
            return ipaddress.ip_address(value)
        except ValueError:
            return None

    def _is_trusted_proxy(self, host: str) -> bool:
        """Return True if *host* is an address inside one of the trusted proxy networks."""
        addr = self._parse_ip(host)
        return addr is not None and any(addr in network for network in self.trusted_proxies)

    @staticmethod
    def _entry_matches(entry_value: Optional[str], ip_address: str, client_addr) -> bool:
        """Return True if a stored list entry (single IP or CIDR) covers the client.

        ``client_addr`` is the pre-parsed form of ``ip_address``, or None when
        the client string is not a valid IP (then only literal equality can match).
        """
        if not entry_value:
            return False
        if entry_value == ip_address:
            return True
        if client_addr is None:
            return False
        try:
            # A bare address parses as a /32 (or /128) network, so this covers
            # both exact (normalised) and CIDR entries. Mixed IPv4/IPv6 gives False.
            return client_addr in ipaddress.ip_network(entry_value, strict=False)
        except ValueError:
            # Malformed stored entry: ignore it rather than fail the request.
            return False

    @staticmethod
    @contextmanager
    def _db_session():
        """Yield a session from ``get_db()`` and close both it and its generator."""
        db_gen = get_db()
        db = next(db_gen)
        try:
            yield db
        finally:
            db.close()
            # Finish the dependency generator now instead of whenever it is GC'd.
            db_gen.close()

    async def _is_blacklisted(self, ip_address: str) -> bool:
        """Check if IP address is blacklisted.

        Database errors are logged and treated as "not blacklisted" (fail-open).
        """
        try:
            return await run_in_threadpool(self._check_blacklist, ip_address)
        except Exception as e:
            logger.error(f"Error checking IP blacklist: {str(e)}")
            return False

    def _check_blacklist(self, ip_address: str) -> bool:
        """Synchronous blacklist lookup; runs in a worker thread.

        Expired temporary blocks never match and are deactivated as they are
        found. The scan continues past them, so an active CIDR block still
        applies even if an exact-match entry for the same IP has expired.
        """
        client_addr = self._parse_ip(ip_address)
        # NOTE(review): assumes ``blocked_until`` is stored as naive UTC, as the
        # original comparison against ``datetime.utcnow()`` implied — confirm.
        now = datetime.utcnow()
        with self._db_session() as db:
            # ``== True`` (not ``is True``) so SQLAlchemy builds a SQL expression.
            entries = db.query(IPBlacklist).filter(
                IPBlacklist.is_active == True  # noqa: E712
            ).all()
            blocked = False
            deactivated = False
            for entry in entries:
                if not self._entry_matches(entry.ip_address, ip_address, client_addr):
                    continue
                if entry.blocked_until and entry.blocked_until < now:
                    # Block expired, deactivate it
                    entry.is_active = False
                    deactivated = True
                    continue
                blocked = True
                break
            if deactivated:
                db.commit()
            return blocked

    async def _is_whitelisted(self, ip_address: str) -> bool:
        """Check if IP address is whitelisted.

        Database errors are logged and treated as "not whitelisted", so
        whitelist-only mode fails closed.
        """
        try:
            return await run_in_threadpool(self._check_whitelist, ip_address)
        except Exception as e:
            logger.error(f"Error checking IP whitelist: {str(e)}")
            return False

    def _check_whitelist(self, ip_address: str) -> bool:
        """Synchronous whitelist lookup over active entries; runs in a worker thread."""
        client_addr = self._parse_ip(ip_address)
        with self._db_session() as db:
            entries = db.query(IPWhitelist).filter(
                IPWhitelist.is_active == True  # noqa: E712
            ).all()
            return any(
                self._entry_matches(entry.ip_address, ip_address, client_addr)
                for entry in entries
            )

    @classmethod
    def _redact_headers(cls, headers) -> dict:
        """Return a plain dict of *headers* with credential-bearing values masked."""
        return {
            name: ("[REDACTED]" if name.lower() in cls.SENSITIVE_HEADERS else value)
            for name, value in headers.items()
        }

    async def _log_security_event(
        self,
        request: Request,
        event_type: SecurityEventType,
        severity: SecurityEventSeverity,
        description: str
    ):
        """Persist a ``SecurityEvent`` for *request*; best-effort, never raises."""
        try:
            fields = {
                "event_type": event_type,
                "severity": severity,
                "ip_address": self._get_client_ip(request),
                "user_agent": request.headers.get("User-Agent"),
                "request_path": str(request.url.path),
                "request_method": request.method,
                "description": description,
                "details": {
                    "url": str(request.url),
                    # Bearer tokens / session cookies must not land in the audit table.
                    "headers": self._redact_headers(request.headers),
                },
            }
            await run_in_threadpool(self._write_security_event, fields)
        except Exception as e:
            logger.error(f"Error logging security event: {str(e)}")

    def _write_security_event(self, fields: dict) -> None:
        """Insert one ``SecurityEvent`` row built from *fields*; runs in a worker thread."""
        with self._db_session() as db:
            db.add(SecurityEvent(**fields))
            db.commit()