update
This commit is contained in:
195
Backend/src/security/middleware/ip_whitelist.py
Normal file
195
Backend/src/security/middleware/ip_whitelist.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from ...shared.config.database import get_db
|
||||
from ...shared.config.logging_config import get_logger
|
||||
from ...shared.config.settings import settings
|
||||
from ..models.security_event import IPWhitelist, IPBlacklist, SecurityEvent, SecurityEventType, SecurityEventSeverity
|
||||
from datetime import datetime
|
||||
import ipaddress
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class IPWhitelistMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to enforce IP whitelisting and blacklisting"""
|
||||
|
||||
def __init__(self, app, enabled: bool = True, whitelist_only: bool = False):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled
|
||||
self.whitelist_only = whitelist_only # If True, only whitelisted IPs allowed
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip IP check for health checks and public endpoints
|
||||
if request.url.path in ['/health', '/api/health', '/metrics']:
|
||||
return await call_next(request)
|
||||
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
if not client_ip:
|
||||
logger.warning("Could not determine client IP address")
|
||||
return await call_next(request)
|
||||
|
||||
# Check blacklist first
|
||||
if await self._is_blacklisted(client_ip):
|
||||
await self._log_security_event(
|
||||
request,
|
||||
SecurityEventType.ip_blocked,
|
||||
SecurityEventSeverity.high,
|
||||
f"Blocked request from blacklisted IP: {client_ip}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"status": "error", "message": "Access denied"}
|
||||
)
|
||||
|
||||
# Check whitelist if whitelist_only mode is enabled
|
||||
if self.whitelist_only:
|
||||
if not await self._is_whitelisted(client_ip):
|
||||
await self._log_security_event(
|
||||
request,
|
||||
SecurityEventType.permission_denied,
|
||||
SecurityEventSeverity.medium,
|
||||
f"Blocked request from non-whitelisted IP: {client_ip}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"status": "error", "message": "Access denied. IP not whitelisted."}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract client IP address from request"""
|
||||
# Check for forwarded IP (when behind proxy/load balancer)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For can contain multiple IPs, take the first one
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# Fallback to direct client IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return None
|
||||
|
||||
async def _is_blacklisted(self, ip_address: str) -> bool:
|
||||
"""Check if IP address is blacklisted"""
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
# Check exact match
|
||||
blacklist_entry = db.query(IPBlacklist).filter(
|
||||
IPBlacklist.ip_address == ip_address,
|
||||
IPBlacklist.is_active == True
|
||||
).first()
|
||||
|
||||
if blacklist_entry:
|
||||
# Check if temporary block has expired
|
||||
if blacklist_entry.blocked_until and blacklist_entry.blocked_until < datetime.utcnow():
|
||||
# Block expired, deactivate it
|
||||
blacklist_entry.is_active = False
|
||||
db.commit()
|
||||
return False
|
||||
return True
|
||||
|
||||
# Check CIDR ranges (if needed)
|
||||
# This is a simplified version - you might want to cache this
|
||||
all_blacklist = db.query(IPBlacklist).filter(
|
||||
IPBlacklist.is_active == True
|
||||
).all()
|
||||
|
||||
for entry in all_blacklist:
|
||||
try:
|
||||
if '/' in entry.ip_address: # CIDR notation
|
||||
network = ipaddress.ip_network(entry.ip_address, strict=False)
|
||||
if ipaddress.ip_address(ip_address) in network:
|
||||
return True
|
||||
except (ValueError, ipaddress.AddressValueError):
|
||||
continue
|
||||
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking IP blacklist: {str(e)}")
|
||||
return False
|
||||
|
||||
async def _is_whitelisted(self, ip_address: str) -> bool:
|
||||
"""Check if IP address is whitelisted"""
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
# Check exact match
|
||||
whitelist_entry = db.query(IPWhitelist).filter(
|
||||
IPWhitelist.ip_address == ip_address,
|
||||
IPWhitelist.is_active == True
|
||||
).first()
|
||||
|
||||
if whitelist_entry:
|
||||
return True
|
||||
|
||||
# Check CIDR ranges
|
||||
all_whitelist = db.query(IPWhitelist).filter(
|
||||
IPWhitelist.is_active == True
|
||||
).all()
|
||||
|
||||
for entry in all_whitelist:
|
||||
try:
|
||||
if '/' in entry.ip_address: # CIDR notation
|
||||
network = ipaddress.ip_network(entry.ip_address, strict=False)
|
||||
if ipaddress.ip_address(ip_address) in network:
|
||||
return True
|
||||
except (ValueError, ipaddress.AddressValueError):
|
||||
continue
|
||||
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking IP whitelist: {str(e)}")
|
||||
return False
|
||||
|
||||
async def _log_security_event(
|
||||
self,
|
||||
request: Request,
|
||||
event_type: SecurityEventType,
|
||||
severity: SecurityEventSeverity,
|
||||
description: str
|
||||
):
|
||||
"""Log security event"""
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
client_ip = self._get_client_ip(request)
|
||||
event = SecurityEvent(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
ip_address=client_ip,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
request_path=str(request.url.path),
|
||||
request_method=request.method,
|
||||
description=description,
|
||||
details={
|
||||
"url": str(request.url),
|
||||
"headers": dict(request.headers)
|
||||
}
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging security event: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user