This commit is contained in:
Iliyan Angelov
2025-11-30 22:43:09 +02:00
parent 24b40450dd
commit 39fcfff811
1610 changed files with 5442 additions and 1383 deletions

View File

@@ -0,0 +1,195 @@
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from sqlalchemy.orm import Session
from typing import List
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ..models.security_event import IPWhitelist, IPBlacklist, SecurityEvent, SecurityEventType, SecurityEventSeverity
from datetime import datetime
import ipaddress
logger = get_logger(__name__)
class IPWhitelistMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce IP whitelisting and blacklisting"""
def __init__(self, app, enabled: bool = True, whitelist_only: bool = False):
super().__init__(app)
self.enabled = enabled
self.whitelist_only = whitelist_only # If True, only whitelisted IPs allowed
async def dispatch(self, request: Request, call_next):
if not self.enabled:
return await call_next(request)
# Skip IP check for health checks and public endpoints
if request.url.path in ['/health', '/api/health', '/metrics']:
return await call_next(request)
client_ip = self._get_client_ip(request)
if not client_ip:
logger.warning("Could not determine client IP address")
return await call_next(request)
# Check blacklist first
if await self._is_blacklisted(client_ip):
await self._log_security_event(
request,
SecurityEventType.ip_blocked,
SecurityEventSeverity.high,
f"Blocked request from blacklisted IP: {client_ip}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied"}
)
# Check whitelist if whitelist_only mode is enabled
if self.whitelist_only:
if not await self._is_whitelisted(client_ip):
await self._log_security_event(
request,
SecurityEventType.permission_denied,
SecurityEventSeverity.medium,
f"Blocked request from non-whitelisted IP: {client_ip}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied. IP not whitelisted."}
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request"""
# Check for forwarded IP (when behind proxy/load balancer)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For can contain multiple IPs, take the first one
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# Fallback to direct client IP
if request.client:
return request.client.host
return None
async def _is_blacklisted(self, ip_address: str) -> bool:
"""Check if IP address is blacklisted"""
try:
db_gen = get_db()
db = next(db_gen)
try:
# Check exact match
blacklist_entry = db.query(IPBlacklist).filter(
IPBlacklist.ip_address == ip_address,
IPBlacklist.is_active == True
).first()
if blacklist_entry:
# Check if temporary block has expired
if blacklist_entry.blocked_until and blacklist_entry.blocked_until < datetime.utcnow():
# Block expired, deactivate it
blacklist_entry.is_active = False
db.commit()
return False
return True
# Check CIDR ranges (if needed)
# This is a simplified version - you might want to cache this
all_blacklist = db.query(IPBlacklist).filter(
IPBlacklist.is_active == True
).all()
for entry in all_blacklist:
try:
if '/' in entry.ip_address: # CIDR notation
network = ipaddress.ip_network(entry.ip_address, strict=False)
if ipaddress.ip_address(ip_address) in network:
return True
except (ValueError, ipaddress.AddressValueError):
continue
return False
finally:
db.close()
except Exception as e:
logger.error(f"Error checking IP blacklist: {str(e)}")
return False
async def _is_whitelisted(self, ip_address: str) -> bool:
"""Check if IP address is whitelisted"""
try:
db_gen = get_db()
db = next(db_gen)
try:
# Check exact match
whitelist_entry = db.query(IPWhitelist).filter(
IPWhitelist.ip_address == ip_address,
IPWhitelist.is_active == True
).first()
if whitelist_entry:
return True
# Check CIDR ranges
all_whitelist = db.query(IPWhitelist).filter(
IPWhitelist.is_active == True
).all()
for entry in all_whitelist:
try:
if '/' in entry.ip_address: # CIDR notation
network = ipaddress.ip_network(entry.ip_address, strict=False)
if ipaddress.ip_address(ip_address) in network:
return True
except (ValueError, ipaddress.AddressValueError):
continue
return False
finally:
db.close()
except Exception as e:
logger.error(f"Error checking IP whitelist: {str(e)}")
return False
async def _log_security_event(
self,
request: Request,
event_type: SecurityEventType,
severity: SecurityEventSeverity,
description: str
):
"""Log security event"""
try:
db_gen = get_db()
db = next(db_gen)
try:
client_ip = self._get_client_ip(request)
event = SecurityEvent(
event_type=event_type,
severity=severity,
ip_address=client_ip,
user_agent=request.headers.get("User-Agent"),
request_path=str(request.url.path),
request_method=request.method,
description=description,
details={
"url": str(request.url),
"headers": dict(request.headers)
}
)
db.add(event)
db.commit()
finally:
db.close()
except Exception as e:
logger.error(f"Error logging security event: {str(e)}")