196 lines
7.5 KiB
Python
196 lines
7.5 KiB
Python
from fastapi import Request, HTTPException, status
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
from sqlalchemy.orm import Session
|
|
from typing import List
|
|
from ...shared.config.database import get_db
|
|
from ...shared.config.logging_config import get_logger
|
|
from ...shared.config.settings import settings
|
|
from ..models.security_event import IPWhitelist, IPBlacklist, SecurityEvent, SecurityEventType, SecurityEventSeverity
|
|
from datetime import datetime
|
|
import ipaddress
|
|
|
|
# Module-level logger; IPWhitelistMiddleware uses it for its warnings and errors.
logger = get_logger(__name__)
|
|
|
|
class IPWhitelistMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce IP whitelisting and blacklisting.

    Every request whose path is not in ``exempt_paths`` has its client IP
    resolved and checked against active ``IPBlacklist`` entries. When
    ``whitelist_only`` is enabled, the IP must also match an active
    ``IPWhitelist`` entry. Rejected requests receive a 403 JSON response, and
    a ``SecurityEvent`` row is written on a best-effort basis.

    NOTE(review): the database lookups are synchronous SQLAlchemy calls made
    from ``async`` methods, so they block the event loop while running.
    Consider a thread pool or a cached copy of the lists.
    """

    # Paths that bypass IP checks when no ``exempt_paths`` are supplied.
    DEFAULT_EXEMPT_PATHS = frozenset({'/health', '/api/health', '/metrics'})

    # Lower-cased header names whose values are never persisted in security
    # event details, because they carry credentials or session tokens.
    _REDACTED_HEADERS = frozenset({
        'authorization',
        'proxy-authorization',
        'cookie',
        'x-api-key',
        'x-auth-token',
    })

    def __init__(
        self,
        app,
        enabled: bool = True,
        whitelist_only: bool = False,
        trusted_proxies: "List[str] | None" = None,
        exempt_paths: "List[str] | None" = None,
    ):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            enabled: When False, every request is passed through unchecked.
            whitelist_only: When True, only whitelisted IPs are allowed.
            trusted_proxies: IPs or CIDR ranges of reverse proxies whose
                ``X-Forwarded-For`` / ``X-Real-IP`` headers may be believed.
                ``None`` (the default) keeps the legacy behaviour of trusting
                those headers from any peer. That lets a client spoof its IP,
                so set this whenever the app runs behind a proxy.
            exempt_paths: Request paths that skip IP checks entirely.
                Defaults to ``DEFAULT_EXEMPT_PATHS``.

        Raises:
            ValueError: If an entry of ``trusted_proxies`` is not a valid IP
                address or network.
        """
        super().__init__(app)
        self.enabled = enabled
        self.whitelist_only = whitelist_only  # If True, only whitelisted IPs allowed
        self.exempt_paths = (
            frozenset(exempt_paths)
            if exempt_paths is not None
            else self.DEFAULT_EXEMPT_PATHS
        )
        self._trusted_proxy_networks = (
            None
            if trusted_proxies is None
            else [ipaddress.ip_network(proxy, strict=False) for proxy in trusted_proxies]
        )

    async def dispatch(self, request: Request, call_next):
        """Reject blacklisted or non-whitelisted clients; otherwise continue the chain."""
        if not self.enabled:
            return await call_next(request)

        # Skip IP check for health checks and public endpoints
        if request.url.path in self.exempt_paths:
            return await call_next(request)

        client_ip = self._get_client_ip(request)

        if not client_ip:
            logger.warning("Could not determine client IP address")
            if self.whitelist_only:
                # An unknown address can never be shown to be whitelisted, so
                # whitelist-only mode must fail closed rather than let it through.
                await self._log_security_event(
                    request,
                    SecurityEventType.permission_denied,
                    SecurityEventSeverity.medium,
                    "Blocked request with undeterminable client IP"
                )
                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"status": "error", "message": "Access denied. IP not whitelisted."}
                )
            return await call_next(request)

        # Check blacklist first
        if await self._is_blacklisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.ip_blocked,
                SecurityEventSeverity.high,
                f"Blocked request from blacklisted IP: {client_ip}"
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"status": "error", "message": "Access denied"}
            )

        # Check whitelist if whitelist_only mode is enabled
        if self.whitelist_only and not await self._is_whitelisted(client_ip):
            await self._log_security_event(
                request,
                SecurityEventType.permission_denied,
                SecurityEventSeverity.medium,
                f"Blocked request from non-whitelisted IP: {client_ip}"
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"status": "error", "message": "Access denied. IP not whitelisted."}
            )

        return await call_next(request)

    def _get_client_ip(self, request: Request) -> "str | None":
        """Extract the client IP address from ``request``, or None if unknown.

        Forwarding headers are consulted in this order:
        ``X-Forwarded-For``, then ``X-Real-IP``, then the direct peer address.
        When ``trusted_proxies`` is configured, forwarding headers are
        honoured only if the direct peer is a trusted proxy. In that case the
        right-most ``X-Forwarded-For`` hop that is not itself a trusted proxy
        is taken as the client.
        """
        peer = request.client.host if request.client else None
        trusted = self._trusted_proxy_networks

        if trusted is not None and not self._is_trusted_proxy(peer):
            # The peer is not a known proxy, so its forwarding headers are
            # client-controlled and must be ignored.
            return peer or None

        # Drop empty hops so that e.g. "X-Forwarded-For: ," cannot yield an
        # empty string, which would otherwise skip all checks as "unknown IP".
        forwarded_for = request.headers.get("X-Forwarded-For", "")
        hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
        if hops:
            if trusted is None:
                # Legacy behaviour: the left-most hop (client-supplied, spoofable).
                return hops[0]
            # Each proxy appends the address it received from, so walk
            # right-to-left past our own proxies.
            for hop in reversed(hops):
                if not self._is_trusted_proxy(hop):
                    return hop
            return hops[0]

        real_ip = request.headers.get("X-Real-IP", "").strip()
        if real_ip:
            return real_ip

        # Fallback to direct client IP
        return peer or None

    def _is_trusted_proxy(self, host) -> bool:
        """Return True if ``host`` parses as an IP inside a configured trusted-proxy network."""
        addr = self._parse_ip(host)
        return addr is not None and any(
            addr in network for network in self._trusted_proxy_networks or ()
        )

    @staticmethod
    def _parse_ip(value):
        """Parse ``value`` into an ``ipaddress`` address object, or return None if invalid/empty."""
        if not value:
            return None
        try:
            return ipaddress.ip_address(value)
        except ValueError:
            return None

    @staticmethod
    def _entry_matches(entry_value, ip_str, client_addr) -> bool:
        """Return True if a stored list entry (single IP or CIDR) covers the client.

        Args:
            entry_value: The ``ip_address`` column of a whitelist/blacklist row.
            ip_str: The client IP exactly as resolved from the request.
            client_addr: ``ip_str`` parsed by ``_parse_ip`` (None if unparseable).
        """
        if entry_value == ip_str:
            return True
        if client_addr is None:
            return False
        try:
            if '/' in entry_value:  # CIDR notation
                return client_addr in ipaddress.ip_network(entry_value, strict=False)
            # Compare parsed addresses so equivalent spellings (e.g. IPv6 case
            # or zero-compression differences) still match.
            return client_addr == ipaddress.ip_address(entry_value)
        except (TypeError, ValueError):
            # Malformed or missing entry: skip it rather than fail the whole check.
            return False

    @staticmethod
    def _open_session():
        """Return ``(generator, session)`` obtained from the ``get_db()`` generator."""
        db_gen = get_db()
        return db_gen, next(db_gen)

    @staticmethod
    def _close_session(db_gen, db) -> None:
        """Close ``db`` and then finalise the ``get_db()`` generator that produced it."""
        try:
            db.close()
        finally:
            # get_db() is driven with next(), presumably a generator. Closing it
            # runs its post-yield cleanup now instead of at garbage collection.
            db_gen.close()

    async def _is_blacklisted(self, ip_address: str) -> bool:
        """Return True if ``ip_address`` matches an active, unexpired blacklist entry.

        An entry matches on exact string, on parsed-address equality, or by
        CIDR containment. A matching entry whose ``blocked_until`` is in the
        past is deactivated and does not block; other entries are still
        checked, so a broader active range continues to apply. Database
        errors are logged and treated as "not blacklisted" (fail open).
        """
        try:
            db_gen, db = self._open_session()
            try:
                entries = db.query(IPBlacklist).filter(
                    IPBlacklist.is_active == True  # noqa: E712 - SQLAlchemy column expression
                ).all()
                client_addr = self._parse_ip(ip_address)
                now = datetime.utcnow()
                blocked = False
                expired_found = False

                for entry in entries:
                    if not self._entry_matches(entry.ip_address, ip_address, client_addr):
                        continue
                    if entry.blocked_until and entry.blocked_until < now:
                        # Temporary block has expired: deactivate it and keep looking.
                        entry.is_active = False
                        expired_found = True
                        continue
                    blocked = True
                    break

                if expired_found:
                    try:
                        db.commit()
                    except Exception as e:
                        # Failing to persist the clean-up must not unblock the IP.
                        db.rollback()
                        logger.error(f"Error deactivating expired IP blacklist entries: {e}")

                return blocked
            finally:
                self._close_session(db_gen, db)
        except Exception as e:
            logger.error(f"Error checking IP blacklist: {e}")
            return False

    async def _is_whitelisted(self, ip_address: str) -> bool:
        """Return True if ``ip_address`` matches an active whitelist entry.

        An entry matches on exact string, on parsed-address equality, or by
        CIDR containment. Database errors are logged and treated as "not
        whitelisted" (fail closed).
        """
        try:
            db_gen, db = self._open_session()
            try:
                entries = db.query(IPWhitelist).filter(
                    IPWhitelist.is_active == True  # noqa: E712 - SQLAlchemy column expression
                ).all()
                client_addr = self._parse_ip(ip_address)
                return any(
                    self._entry_matches(entry.ip_address, ip_address, client_addr)
                    for entry in entries
                )
            finally:
                self._close_session(db_gen, db)
        except Exception as e:
            logger.error(f"Error checking IP whitelist: {e}")
            return False

    async def _log_security_event(
        self,
        request: Request,
        event_type: SecurityEventType,
        severity: SecurityEventSeverity,
        description: str
    ):
        """Persist a ``SecurityEvent`` describing ``request`` (best effort).

        Values of headers listed in ``_REDACTED_HEADERS`` are replaced with
        ``"[REDACTED]"`` before storage, so credentials never land in the log
        table. Any error is logged and swallowed so logging never breaks the
        response.
        """
        try:
            db_gen, db = self._open_session()
            try:
                client_ip = self._get_client_ip(request)
                headers = {
                    name: '[REDACTED]' if name.lower() in self._REDACTED_HEADERS else value
                    for name, value in dict(request.headers).items()
                }
                event = SecurityEvent(
                    event_type=event_type,
                    severity=severity,
                    ip_address=client_ip,
                    user_agent=request.headers.get("User-Agent"),
                    request_path=str(request.url.path),
                    request_method=request.method,
                    description=description,
                    details={
                        "url": str(request.url),
                        "headers": headers
                    }
                )
                db.add(event)
                db.commit()
            finally:
                self._close_session(db_gen, db)
        except Exception as e:
            logger.error(f"Error logging security event: {e}")
|
|
|