This commit is contained in:
Iliyan Angelov
2025-11-30 22:43:09 +02:00
parent 24b40450dd
commit 39fcfff811
1610 changed files with 5442 additions and 1383 deletions

View File

View File

@@ -0,0 +1,118 @@
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from typing import List
import ipaddress
from ...shared.config.settings import settings
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class AdminIPWhitelistMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce IP whitelisting for admin endpoints.
Only applies to routes starting with /api/admin/ or containing 'admin' in path.
"""
def __init__(self, app, enabled: bool = None, whitelist: List[str] = None):
super().__init__(app)
self.enabled = enabled if enabled is not None else settings.IP_WHITELIST_ENABLED
self.whitelist = whitelist if whitelist is not None else settings.ADMIN_IP_WHITELIST
# Pre-compile IP networks for faster lookup
self._compiled_networks = []
if self.enabled and self.whitelist:
for ip_or_cidr in self.whitelist:
try:
if '/' in ip_or_cidr:
# CIDR notation
network = ipaddress.ip_network(ip_or_cidr, strict=False)
self._compiled_networks.append(network)
else:
# Single IP address
ip = ipaddress.ip_address(ip_or_cidr)
# Convert to /32 network for consistent handling
self._compiled_networks.append(ipaddress.ip_network(f'{ip}/32', strict=False))
except (ValueError, ipaddress.AddressValueError) as e:
logger.warning(f'Invalid IP/CIDR in admin whitelist: {ip_or_cidr} - {str(e)}')
if self.enabled:
logger.info(f'Admin IP whitelisting enabled with {len(self._compiled_networks)} allowed IP(s)/CIDR range(s)')
def _is_admin_route(self, path: str) -> bool:
"""Check if the path is an admin route"""
return '/admin/' in path.lower() or path.lower().startswith('/api/admin')
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request"""
# Check for forwarded IP (when behind proxy/load balancer)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For can contain multiple IPs, take the first one (original client)
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# Fallback to direct client IP
if request.client:
return request.client.host
return None
def _is_ip_allowed(self, ip_address: str) -> bool:
"""Check if IP address is in whitelist"""
if not self._compiled_networks:
# Empty whitelist means deny all (security-first approach)
return False
try:
client_ip = ipaddress.ip_address(ip_address)
for network in self._compiled_networks:
if client_ip in network:
return True
return False
except (ValueError, ipaddress.AddressValueError):
logger.warning(f'Invalid IP address format: {ip_address}')
return False
async def dispatch(self, request: Request, call_next):
# Skip if not enabled
if not self.enabled:
return await call_next(request)
# Skip OPTIONS requests (CORS preflight) - let CORS middleware handle them
if request.method == 'OPTIONS':
return await call_next(request)
# Only apply to admin routes
if not self._is_admin_route(request.url.path):
return await call_next(request)
# Skip IP check for health checks and public endpoints
if request.url.path in ['/health', '/api/health', '/metrics']:
return await call_next(request)
client_ip = self._get_client_ip(request)
if not client_ip:
logger.warning("Could not determine client IP address for admin route")
# Deny by default if IP cannot be determined (security-first)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied: Unable to verify IP address"}
)
# Check whitelist
if not self._is_ip_allowed(client_ip):
logger.warning(
f"Admin route access denied for IP: {client_ip} from path: {request.url.path}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied. IP address not whitelisted."}
)
# IP is whitelisted, continue
return await call_next(request)

View File

@@ -0,0 +1,133 @@
from fastapi import Depends, HTTPException, status, Request, Cookie
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from sqlalchemy.orm import Session
from typing import Optional
import os
from ...shared.config.database import get_db
from ...shared.config.settings import settings
from ...auth.models.user import User
from ...auth.models.role import Role
security = HTTPBearer(auto_error=False)
def get_jwt_secret() -> str:
"""
Get JWT secret securely, fail if not configured.
Never use hardcoded fallback secrets.
"""
default_secret = 'dev-secret-key-change-in-production-12345'
jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv('JWT_SECRET', None)
# Fail fast if secret is not configured or using default value
if not jwt_secret or jwt_secret == default_secret:
if settings.is_production:
raise ValueError(
'CRITICAL: JWT_SECRET is not properly configured in production. '
'Please set JWT_SECRET environment variable to a secure random string.'
)
# In development, warn but allow (startup validation should catch this)
import warnings
warnings.warn(
f'JWT_SECRET not configured. Using settings value but this is insecure. '
f'Set JWT_SECRET environment variable.',
UserWarning
)
jwt_secret = getattr(settings, 'JWT_SECRET', None)
if not jwt_secret:
raise ValueError('JWT_SECRET must be configured')
return jwt_secret
def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
access_token: Optional[str] = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
) -> User:
"""
Get current user from either Authorization header or httpOnly cookie.
Prefers Authorization header for backward compatibility, falls back to cookie.
"""
# Try to get token from Authorization header first
token = None
if credentials:
token = credentials.credentials
# Fall back to cookie if no header token
if not token and access_token:
token = access_token
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'}
)
if not token:
raise credentials_exception
try:
jwt_secret = get_jwt_secret()
payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
user_id: int = payload.get('userId')
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
except ValueError as e:
# JWT secret configuration error - should not happen in production
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Server configuration error')
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise credentials_exception
return user
def authorize_roles(*allowed_roles: str):
def role_checker(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)) -> User:
role = db.query(Role).filter(Role.id == current_user.role_id).first()
if not role:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='User role not found')
user_role_name = role.name
if user_role_name not in allowed_roles:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='You do not have permission to access this resource')
return current_user
return role_checker
def get_current_user_optional(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
access_token: Optional[str] = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
) -> Optional[User]:
"""
Get current user optionally from either Authorization header or httpOnly cookie.
Returns None if no valid token is found.
"""
# Try to get token from Authorization header first
token = None
if credentials:
token = credentials.credentials
# Fall back to cookie if no header token
if not token and access_token:
token = access_token
if not token:
return None
try:
jwt_secret = get_jwt_secret()
payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
user_id: int = payload.get('userId')
if user_id is None:
return None
except (JWTError, ValueError):
return None
user = db.query(User).filter(User.id == user_id).first()
return user
def verify_token(token: str) -> dict:
jwt_secret = get_jwt_secret()
payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
return payload

View File

@@ -0,0 +1,158 @@
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional
import secrets
import hmac
import hashlib
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
logger = get_logger(__name__)
# Safe HTTP methods that don't require CSRF protection
SAFE_METHODS = {'GET', 'HEAD', 'OPTIONS'}
class CSRFProtectionMiddleware(BaseHTTPMiddleware):
"""
CSRF Protection Middleware
Validates CSRF tokens for state-changing requests (POST, PUT, DELETE, PATCH).
Uses Double Submit Cookie pattern for stateless CSRF protection.
"""
CSRF_TOKEN_COOKIE_NAME = 'XSRF-TOKEN'
CSRF_TOKEN_HEADER_NAME = 'X-XSRF-TOKEN'
CSRF_SECRET_LENGTH = 32
async def dispatch(self, request: Request, call_next):
path = request.url.path
# Skip CSRF protection for certain endpoints that don't need it
# (e.g., public APIs, webhooks with their own validation)
is_exempt = self._is_exempt_path(path)
# Get or generate CSRF token (always generate for all requests to ensure cookie is set)
csrf_token = request.cookies.get(self.CSRF_TOKEN_COOKIE_NAME)
if not csrf_token:
csrf_token = self._generate_token()
# Skip CSRF validation for safe methods (GET, HEAD, OPTIONS) and exempt paths
if request.method in SAFE_METHODS or is_exempt:
response = await call_next(request)
else:
# For state-changing requests, validate the token
if request.method in {'POST', 'PUT', 'DELETE', 'PATCH'}:
header_token = request.headers.get(self.CSRF_TOKEN_HEADER_NAME)
if not header_token:
logger.warning(f"CSRF token missing in header for {request.method} {path}")
# Create error response with CSRF cookie set so frontend can retry
from fastapi.responses import JSONResponse
error_response = JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "CSRF token missing. Please include X-XSRF-TOKEN header."}
)
# Set cookie even on error so client can get the token and retry
if not request.cookies.get(self.CSRF_TOKEN_COOKIE_NAME):
error_response.set_cookie(
key=self.CSRF_TOKEN_COOKIE_NAME,
value=csrf_token,
httponly=False,
secure=settings.is_production,
samesite='lax', # Changed to 'lax' for better cross-origin support
max_age=86400 * 7,
path='/'
)
return error_response
# Validate token using constant-time comparison
if not self._verify_token(csrf_token, header_token):
logger.warning(f"CSRF token validation failed for {request.method} {path}")
# Create error response with CSRF cookie set so frontend can retry
from fastapi.responses import JSONResponse
error_response = JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Invalid CSRF token. Please refresh the page and try again."}
)
# Set cookie even on error so client can get the token and retry
if not request.cookies.get(self.CSRF_TOKEN_COOKIE_NAME):
error_response.set_cookie(
key=self.CSRF_TOKEN_COOKIE_NAME,
value=csrf_token,
httponly=False,
secure=settings.is_production,
samesite='lax', # Changed to 'lax' for better cross-origin support
max_age=86400 * 7,
path='/'
)
return error_response
# Process request
response = await call_next(request)
# Always set CSRF token cookie if not present (ensures client always has it)
# This allows frontend to read it from cookies for subsequent requests
if not request.cookies.get(self.CSRF_TOKEN_COOKIE_NAME):
# Set secure cookie with SameSite protection
response.set_cookie(
key=self.CSRF_TOKEN_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be accessible to JavaScript for header submission
secure=settings.is_production, # HTTPS only in production
samesite='lax', # Changed to 'lax' for better cross-origin support
max_age=86400 * 7, # 7 days
path='/'
)
return response
def _is_exempt_path(self, path: str) -> bool:
"""
Check if path is exempt from CSRF protection.
Exempt paths:
- Authentication endpoints (login, register, logout, refresh token)
- Webhook endpoints (they have their own signature validation)
- Health check endpoints
- Static file endpoints
"""
exempt_patterns = [
'/api/auth/', # All authentication endpoints
'/api/webhooks/',
'/api/stripe/webhook',
'/api/payments/stripe/webhook',
'/api/paypal/webhook',
'/health',
'/api/health',
'/static/',
'/docs',
'/redoc',
'/openapi.json'
]
return any(path.startswith(pattern) for pattern in exempt_patterns)
def _generate_token(self) -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(self.CSRF_SECRET_LENGTH)
def _verify_token(self, cookie_token: str, header_token: str) -> bool:
"""
Verify CSRF token using constant-time comparison.
Uses the Double Submit Cookie pattern - the token in the cookie
must match the token in the header.
"""
if not cookie_token or not header_token:
return False
# Constant-time comparison to prevent timing attacks
return hmac.compare_digest(cookie_token, header_token)
def get_csrf_token(request: Request) -> Optional[str]:
"""Helper function to get CSRF token from request cookies."""
return request.cookies.get(CSRFProtectionMiddleware.CSRF_TOKEN_COOKIE_NAME)

View File

@@ -0,0 +1,195 @@
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from sqlalchemy.orm import Session
from typing import List
from ...shared.config.database import get_db
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
from ..models.security_event import IPWhitelist, IPBlacklist, SecurityEvent, SecurityEventType, SecurityEventSeverity
from datetime import datetime
import ipaddress
logger = get_logger(__name__)
class IPWhitelistMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce IP whitelisting and blacklisting"""
def __init__(self, app, enabled: bool = True, whitelist_only: bool = False):
super().__init__(app)
self.enabled = enabled
self.whitelist_only = whitelist_only # If True, only whitelisted IPs allowed
async def dispatch(self, request: Request, call_next):
if not self.enabled:
return await call_next(request)
# Skip IP check for health checks and public endpoints
if request.url.path in ['/health', '/api/health', '/metrics']:
return await call_next(request)
client_ip = self._get_client_ip(request)
if not client_ip:
logger.warning("Could not determine client IP address")
return await call_next(request)
# Check blacklist first
if await self._is_blacklisted(client_ip):
await self._log_security_event(
request,
SecurityEventType.ip_blocked,
SecurityEventSeverity.high,
f"Blocked request from blacklisted IP: {client_ip}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied"}
)
# Check whitelist if whitelist_only mode is enabled
if self.whitelist_only:
if not await self._is_whitelisted(client_ip):
await self._log_security_event(
request,
SecurityEventType.permission_denied,
SecurityEventSeverity.medium,
f"Blocked request from non-whitelisted IP: {client_ip}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"status": "error", "message": "Access denied. IP not whitelisted."}
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP address from request"""
# Check for forwarded IP (when behind proxy/load balancer)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For can contain multiple IPs, take the first one
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# Fallback to direct client IP
if request.client:
return request.client.host
return None
async def _is_blacklisted(self, ip_address: str) -> bool:
"""Check if IP address is blacklisted"""
try:
db_gen = get_db()
db = next(db_gen)
try:
# Check exact match
blacklist_entry = db.query(IPBlacklist).filter(
IPBlacklist.ip_address == ip_address,
IPBlacklist.is_active == True
).first()
if blacklist_entry:
# Check if temporary block has expired
if blacklist_entry.blocked_until and blacklist_entry.blocked_until < datetime.utcnow():
# Block expired, deactivate it
blacklist_entry.is_active = False
db.commit()
return False
return True
# Check CIDR ranges (if needed)
# This is a simplified version - you might want to cache this
all_blacklist = db.query(IPBlacklist).filter(
IPBlacklist.is_active == True
).all()
for entry in all_blacklist:
try:
if '/' in entry.ip_address: # CIDR notation
network = ipaddress.ip_network(entry.ip_address, strict=False)
if ipaddress.ip_address(ip_address) in network:
return True
except (ValueError, ipaddress.AddressValueError):
continue
return False
finally:
db.close()
except Exception as e:
logger.error(f"Error checking IP blacklist: {str(e)}")
return False
async def _is_whitelisted(self, ip_address: str) -> bool:
"""Check if IP address is whitelisted"""
try:
db_gen = get_db()
db = next(db_gen)
try:
# Check exact match
whitelist_entry = db.query(IPWhitelist).filter(
IPWhitelist.ip_address == ip_address,
IPWhitelist.is_active == True
).first()
if whitelist_entry:
return True
# Check CIDR ranges
all_whitelist = db.query(IPWhitelist).filter(
IPWhitelist.is_active == True
).all()
for entry in all_whitelist:
try:
if '/' in entry.ip_address: # CIDR notation
network = ipaddress.ip_network(entry.ip_address, strict=False)
if ipaddress.ip_address(ip_address) in network:
return True
except (ValueError, ipaddress.AddressValueError):
continue
return False
finally:
db.close()
except Exception as e:
logger.error(f"Error checking IP whitelist: {str(e)}")
return False
async def _log_security_event(
self,
request: Request,
event_type: SecurityEventType,
severity: SecurityEventSeverity,
description: str
):
"""Log security event"""
try:
db_gen = get_db()
db = next(db_gen)
try:
client_ip = self._get_client_ip(request)
event = SecurityEvent(
event_type=event_type,
severity=severity,
ip_address=client_ip,
user_agent=request.headers.get("User-Agent"),
request_path=str(request.url.path),
request_method=request.method,
description=description,
details={
"url": str(request.url),
"headers": dict(request.headers)
}
)
db.add(event)
db.commit()
finally:
db.close()
except Exception as e:
logger.error(f"Error logging security event: {str(e)}")

View File

@@ -0,0 +1,44 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from ...shared.config.logging_config import get_logger
from ...shared.config.settings import settings
logger = get_logger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'}
security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
if settings.is_production:
# Enhanced CSP with stricter directives
# Using 'strict-dynamic' for better security with nonce-based scripts
# Note: For React/Vite, consider implementing nonce-based CSP in the future
# Current policy balances security with framework requirements
security_headers['Content-Security-Policy'] = (
"default-src 'self'; "
"script-src 'self' 'strict-dynamic' https://js.stripe.com; "
"script-src-elem 'self' 'unsafe-inline' https://js.stripe.com; " # Allow inline scripts for Vite/React
"style-src 'self' 'unsafe-inline'; " # Required for React/Vite
"img-src 'self' data: https:; "
"font-src 'self' data:; "
"connect-src 'self' https: https://js.stripe.com https://hooks.stripe.com; "
"frame-src 'self' https://js.stripe.com https://hooks.stripe.com; "
"base-uri 'self'; "
"form-action 'self'; "
"frame-ancestors 'none'; "
"object-src 'none'; "
"upgrade-insecure-requests; "
"block-all-mixed-content" # Block mixed HTTP/HTTPS content
)
# HSTS with preload directive (only add preload if domain is ready for it)
# Preload requires manual submission to hstspreload.org
# Include preload directive only if explicitly enabled
hsts_directive = 'max-age=31536000; includeSubDomains'
if getattr(settings, 'HSTS_PRELOAD_ENABLED', False):
hsts_directive += '; preload'
security_headers['Strict-Transport-Security'] = hsts_directive
for header, value in security_headers.items():
response.headers[header] = value
return response

View File

View File

@@ -0,0 +1,87 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Enum, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class DataSubjectRequestType(str, enum.Enum):
access = 'access' # Right to access
rectification = 'rectification' # Right to rectification
erasure = 'erasure' # Right to erasure (right to be forgotten)
portability = 'portability' # Right to data portability
restriction = 'restriction' # Right to restriction of processing
objection = 'objection' # Right to object
class DataSubjectRequestStatus(str, enum.Enum):
pending = 'pending'
in_progress = 'in_progress'
completed = 'completed'
rejected = 'rejected'
cancelled = 'cancelled'
class DataSubjectRequest(Base):
__tablename__ = 'data_subject_requests'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True)
email = Column(String(255), nullable=False, index=True)
request_type = Column(Enum(DataSubjectRequestType), nullable=False, index=True)
status = Column(Enum(DataSubjectRequestStatus), nullable=False, default=DataSubjectRequestStatus.pending, index=True)
# Request details
description = Column(Text, nullable=True)
verification_token = Column(String(100), nullable=True, unique=True, index=True)
verified = Column(Boolean, nullable=False, default=False)
verified_at = Column(DateTime, nullable=True)
# Processing
assigned_to = Column(Integer, ForeignKey('users.id'), nullable=True)
notes = Column(Text, nullable=True)
response_data = Column(JSON, nullable=True) # For access requests, store the data
# Completion
completed_at = Column(DateTime, nullable=True)
completed_by = Column(Integer, ForeignKey('users.id'), nullable=True)
# Metadata
ip_address = Column(String(45), nullable=True)
user_agent = Column(String(500), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = Column(Integer, ForeignKey('users.id'), nullable=True)
assignee = relationship('User', foreign_keys=[assigned_to])
completer = relationship('User', foreign_keys=[completed_by])
class DataRetentionPolicy(Base):
__tablename__ = 'data_retention_policies'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
data_type = Column(String(100), nullable=False, unique=True) # e.g., 'user_data', 'booking_data', 'payment_data'
retention_days = Column(Integer, nullable=False) # Days to retain data
auto_delete = Column(Boolean, nullable=False, default=False)
description = Column(Text, nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
class ConsentRecord(Base):
__tablename__ = 'consent_records'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
consent_type = Column(String(100), nullable=False, index=True) # 'marketing', 'analytics', 'cookies', etc.
granted = Column(Boolean, nullable=False, default=False)
granted_at = Column(DateTime, nullable=True)
revoked_at = Column(DateTime, nullable=True)
ip_address = Column(String(45), nullable=True)
user_agent = Column(String(500), nullable=True)
version = Column(String(50), nullable=True) # Policy version when consent was given
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship('User')

View File

@@ -0,0 +1,135 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Enum, Boolean
from sqlalchemy.orm import relationship
from datetime import datetime
import enum
from ...shared.config.database import Base
class SecurityEventType(str, enum.Enum):
login_attempt = 'login_attempt'
login_success = 'login_success'
login_failure = 'login_failure'
logout = 'logout'
password_change = 'password_change'
password_reset = 'password_reset'
account_locked = 'account_locked'
account_unlocked = 'account_unlocked'
permission_denied = 'permission_denied'
suspicious_activity = 'suspicious_activity'
data_access = 'data_access'
data_modification = 'data_modification'
data_deletion = 'data_deletion'
api_access = 'api_access'
ip_blocked = 'ip_blocked'
rate_limit_exceeded = 'rate_limit_exceeded'
oauth_login = 'oauth_login'
sso_login = 'sso_login'
class SecurityEventSeverity(str, enum.Enum):
low = 'low'
medium = 'medium'
high = 'high'
critical = 'critical'
class SecurityEvent(Base):
__tablename__ = 'security_events'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=True, index=True)
event_type = Column(Enum(SecurityEventType), nullable=False, index=True)
severity = Column(Enum(SecurityEventSeverity), nullable=False, default=SecurityEventSeverity.medium, index=True)
# Request details
ip_address = Column(String(45), nullable=True, index=True)
user_agent = Column(String(500), nullable=True)
request_path = Column(String(500), nullable=True)
request_method = Column(String(10), nullable=True)
request_id = Column(String(36), nullable=True, index=True)
# Event details
description = Column(Text, nullable=True)
details = Column(JSON, nullable=True)
extra_data = Column(JSON, nullable=True) # Additional metadata (metadata is reserved by SQLAlchemy)
# Status
resolved = Column(Boolean, nullable=False, default=False)
resolved_at = Column(DateTime, nullable=True)
resolved_by = Column(Integer, ForeignKey('users.id'), nullable=True)
resolution_notes = Column(Text, nullable=True)
# Location (if available)
country = Column(String(100), nullable=True)
city = Column(String(100), nullable=True)
latitude = Column(String(20), nullable=True)
longitude = Column(String(20), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
# Relationships
user = relationship('User', foreign_keys=[user_id])
resolver = relationship('User', foreign_keys=[resolved_by])
class IPWhitelist(Base):
__tablename__ = 'ip_whitelist'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
ip_address = Column(String(45), nullable=False, unique=True, index=True)
description = Column(String(255), nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
created_by = Column(Integer, ForeignKey('users.id'), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
creator = relationship('User', foreign_keys=[created_by])
class IPBlacklist(Base):
__tablename__ = 'ip_blacklist'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
ip_address = Column(String(45), nullable=False, unique=True, index=True)
reason = Column(Text, nullable=True)
is_active = Column(Boolean, nullable=False, default=True)
blocked_until = Column(DateTime, nullable=True) # Temporary block
created_by = Column(Integer, ForeignKey('users.id'), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
creator = relationship('User', foreign_keys=[created_by])
class OAuthProvider(Base):
__tablename__ = 'oauth_providers'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
name = Column(String(50), nullable=False, unique=True) # google, microsoft, github, etc.
display_name = Column(String(100), nullable=False)
client_id = Column(String(500), nullable=False)
client_secret = Column(String(500), nullable=False) # Should be encrypted
authorization_url = Column(String(500), nullable=False)
token_url = Column(String(500), nullable=False)
userinfo_url = Column(String(500), nullable=False)
scopes = Column(String(500), nullable=True) # space-separated scopes
is_active = Column(Boolean, nullable=False, default=True)
is_sso_enabled = Column(Boolean, nullable=False, default=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
class OAuthToken(Base):
__tablename__ = 'oauth_tokens'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
provider_id = Column(Integer, ForeignKey('oauth_providers.id'), nullable=False)
provider_user_id = Column(String(255), nullable=False) # User ID from OAuth provider
access_token = Column(Text, nullable=False) # Should be encrypted
refresh_token = Column(Text, nullable=True) # Should be encrypted
token_type = Column(String(50), nullable=True, default='Bearer')
expires_at = Column(DateTime, nullable=True)
scopes = Column(String(500), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
user = relationship('User')
provider = relationship('OAuthProvider')

View File

View File

@@ -0,0 +1,744 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from sqlalchemy.orm import Session
from typing import Optional, List
from datetime import datetime, timedelta
from pydantic import BaseModel, EmailStr
import logging
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
from ...shared.config.database import get_db
from ...security.middleware.auth import get_current_user, authorize_roles
from ...auth.models.user import User
from ..models.security_event import (
SecurityEvent,
SecurityEventType,
SecurityEventSeverity,
IPWhitelist,
IPBlacklist
)
from ..services.security_monitoring_service import security_monitoring_service
from ..services.gdpr_service import gdpr_service
from ..services.encryption_service import encryption_service
from ..services.security_scan_service import security_scan_service
# OAuth service is optional - only import if httpx is available
try:
from ...auth.services.oauth_service import oauth_service
OAUTH_AVAILABLE = True
except ImportError:
OAUTH_AVAILABLE = False
oauth_service = None
router = APIRouter(prefix="/security", tags=["Security"])
# Security Events
class SecurityEventResponse(BaseModel):
id: int
user_id: Optional[int]
event_type: str
severity: str
ip_address: Optional[str]
description: Optional[str]
created_at: datetime
class Config:
from_attributes = True
@router.get("/events", response_model=List[SecurityEventResponse])
async def get_security_events(
user_id: Optional[int] = Query(None),
event_type: Optional[str] = Query(None),
severity: Optional[str] = Query(None),
ip_address: Optional[str] = Query(None),
resolved: Optional[bool] = Query(None),
days: int = Query(7, ge=1, le=90),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get security events"""
event_type_enum = None
if event_type:
try:
event_type_enum = SecurityEventType(event_type)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid event type")
severity_enum = None
if severity:
try:
severity_enum = SecurityEventSeverity(severity)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid severity")
start_date = datetime.utcnow() - timedelta(days=days)
events = security_monitoring_service.get_security_events(
db=db,
user_id=user_id,
event_type=event_type_enum,
severity=severity_enum,
ip_address=ip_address,
resolved=resolved,
start_date=start_date,
limit=limit,
offset=offset
)
return events
@router.get("/events/stats")
async def get_security_stats(
days: int = Query(7, ge=1, le=90),
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get security statistics"""
stats = security_monitoring_service.get_security_stats(db=db, days=days)
return stats
@router.post("/events/{event_id}/resolve")
async def resolve_security_event(
event_id: int,
resolution_notes: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Mark a security event as resolved"""
try:
event = security_monitoring_service.resolve_event(
db=db,
event_id=event_id,
resolved_by=current_user.id,
resolution_notes=resolution_notes
)
return {"status": "success", "message": "Event resolved", "event_id": event.id}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# IP Whitelist/Blacklist
class IPWhitelistCreate(BaseModel):
ip_address: str
description: Optional[str] = None
class IPBlacklistCreate(BaseModel):
ip_address: str
reason: Optional[str] = None
blocked_until: Optional[datetime] = None
@router.post("/ip/whitelist")
async def add_ip_to_whitelist(
data: IPWhitelistCreate,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Add IP address to whitelist"""
# Check if already exists
existing = db.query(IPWhitelist).filter(
IPWhitelist.ip_address == data.ip_address
).first()
if existing:
existing.is_active = True
existing.description = data.description
db.commit()
return {"status": "success", "message": "IP whitelist updated"}
whitelist = IPWhitelist(
ip_address=data.ip_address,
description=data.description,
created_by=current_user.id
)
db.add(whitelist)
db.commit()
return {"status": "success", "message": "IP added to whitelist"}
@router.delete("/ip/whitelist/{ip_address}")
async def remove_ip_from_whitelist(
ip_address: str,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Remove IP address from whitelist"""
whitelist = db.query(IPWhitelist).filter(
IPWhitelist.ip_address == ip_address
).first()
if not whitelist:
raise HTTPException(status_code=404, detail="IP not found in whitelist")
whitelist.is_active = False
db.commit()
return {"status": "success", "message": "IP removed from whitelist"}
@router.get("/ip/whitelist")
async def get_whitelisted_ips(
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get all whitelisted IPs"""
whitelist = db.query(IPWhitelist).filter(
IPWhitelist.is_active == True
).all()
return [{"id": w.id, "ip_address": w.ip_address, "description": w.description} for w in whitelist]
@router.post("/ip/blacklist")
async def add_ip_to_blacklist(
data: IPBlacklistCreate,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Add IP address to blacklist"""
existing = db.query(IPBlacklist).filter(
IPBlacklist.ip_address == data.ip_address
).first()
if existing:
existing.is_active = True
existing.reason = data.reason
existing.blocked_until = data.blocked_until
db.commit()
return {"status": "success", "message": "IP blacklist updated"}
blacklist = IPBlacklist(
ip_address=data.ip_address,
reason=data.reason,
blocked_until=data.blocked_until,
created_by=current_user.id
)
db.add(blacklist)
db.commit()
return {"status": "success", "message": "IP added to blacklist"}
@router.delete("/ip/blacklist/{ip_address}")
async def remove_ip_from_blacklist(
ip_address: str,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Remove IP address from blacklist"""
blacklist = db.query(IPBlacklist).filter(
IPBlacklist.ip_address == ip_address
).first()
if not blacklist:
raise HTTPException(status_code=404, detail="IP not found in blacklist")
blacklist.is_active = False
db.commit()
return {"status": "success", "message": "IP removed from blacklist"}
@router.get("/ip/blacklist")
async def get_blacklisted_ips(
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get all blacklisted IPs"""
blacklist = db.query(IPBlacklist).filter(
IPBlacklist.is_active == True
).all()
return [{"id": b.id, "ip_address": b.ip_address, "reason": b.reason, "blocked_until": b.blocked_until} for b in blacklist]
# OAuth Provider Management
class OAuthProviderCreate(BaseModel):
name: str
display_name: str
client_id: str
client_secret: str
authorization_url: str
token_url: str
userinfo_url: str
scopes: Optional[str] = None
is_active: bool = True
is_sso_enabled: bool = False
class OAuthProviderUpdate(BaseModel):
display_name: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
authorization_url: Optional[str] = None
token_url: Optional[str] = None
userinfo_url: Optional[str] = None
scopes: Optional[str] = None
is_active: Optional[bool] = None
is_sso_enabled: Optional[bool] = None
@router.get("/oauth/providers")
async def get_oauth_providers(
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get all OAuth providers"""
from ..models.security_event import OAuthProvider
providers = db.query(OAuthProvider).all()
return [{
"id": p.id,
"name": p.name,
"display_name": p.display_name,
"is_active": p.is_active,
"is_sso_enabled": p.is_sso_enabled,
"created_at": p.created_at.isoformat() if p.created_at else None
} for p in providers]
@router.post("/oauth/providers")
async def create_oauth_provider(
data: OAuthProviderCreate,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Create a new OAuth provider"""
from ..models.security_event import OAuthProvider
from ..services.encryption_service import encryption_service
# Encrypt client secret
encrypted_secret = encryption_service.encrypt(data.client_secret)
provider = OAuthProvider(
name=data.name,
display_name=data.display_name,
client_id=data.client_id,
client_secret=encrypted_secret,
authorization_url=data.authorization_url,
token_url=data.token_url,
userinfo_url=data.userinfo_url,
scopes=data.scopes,
is_active=data.is_active,
is_sso_enabled=data.is_sso_enabled
)
db.add(provider)
db.commit()
db.refresh(provider)
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"is_active": provider.is_active,
"is_sso_enabled": provider.is_sso_enabled
}
@router.put("/oauth/providers/{provider_id}")
async def update_oauth_provider(
provider_id: int,
data: OAuthProviderUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Update an OAuth provider"""
from ..models.security_event import OAuthProvider
from ..services.encryption_service import encryption_service
provider = db.query(OAuthProvider).filter(OAuthProvider.id == provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="OAuth provider not found")
if data.display_name is not None:
provider.display_name = data.display_name
if data.client_id is not None:
provider.client_id = data.client_id
if data.client_secret is not None:
provider.client_secret = encryption_service.encrypt(data.client_secret)
if data.authorization_url is not None:
provider.authorization_url = data.authorization_url
if data.token_url is not None:
provider.token_url = data.token_url
if data.userinfo_url is not None:
provider.userinfo_url = data.userinfo_url
if data.scopes is not None:
provider.scopes = data.scopes
if data.is_active is not None:
provider.is_active = data.is_active
if data.is_sso_enabled is not None:
provider.is_sso_enabled = data.is_sso_enabled
db.commit()
db.refresh(provider)
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"is_active": provider.is_active,
"is_sso_enabled": provider.is_sso_enabled
}
@router.delete("/oauth/providers/{provider_id}")
async def delete_oauth_provider(
provider_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Delete an OAuth provider"""
from ..models.security_event import OAuthProvider
provider = db.query(OAuthProvider).filter(OAuthProvider.id == provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="OAuth provider not found")
db.delete(provider)
db.commit()
return {"status": "success", "message": "OAuth provider deleted"}
# GDPR Request Management
@router.get("/gdpr/requests")
async def get_gdpr_requests(
status: Optional[str] = Query(None),
request_type: Optional[str] = Query(None),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get all GDPR requests"""
from ..models.gdpr_compliance import DataSubjectRequest, DataSubjectRequestStatus, DataSubjectRequestType
query = db.query(DataSubjectRequest)
if status:
try:
status_enum = DataSubjectRequestStatus(status)
query = query.filter(DataSubjectRequest.status == status_enum)
except ValueError:
pass
if request_type:
try:
type_enum = DataSubjectRequestType(request_type)
query = query.filter(DataSubjectRequest.request_type == type_enum)
except ValueError:
pass
requests = query.order_by(DataSubjectRequest.created_at.desc()).offset(offset).limit(limit).all()
return [{
"id": r.id,
"user_id": r.user_id,
"email": r.email,
"request_type": r.request_type.value,
"status": r.status.value,
"description": r.description,
"verified": r.verified,
"verified_at": r.verified_at.isoformat() if r.verified_at else None,
"assigned_to": r.assigned_to,
"completed_at": r.completed_at.isoformat() if r.completed_at else None,
"created_at": r.created_at.isoformat() if r.created_at else None
} for r in requests]
@router.get("/gdpr/requests/{request_id}")
async def get_gdpr_request(
request_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Get a specific GDPR request"""
from ..models.gdpr_compliance import DataSubjectRequest
request = db.query(DataSubjectRequest).filter(DataSubjectRequest.id == request_id).first()
if not request:
raise HTTPException(status_code=404, detail="GDPR request not found")
return {
"id": request.id,
"user_id": request.user_id,
"email": request.email,
"request_type": request.request_type.value,
"status": request.status.value,
"description": request.description,
"verified": request.verified,
"verified_at": request.verified_at.isoformat() if request.verified_at else None,
"assigned_to": request.assigned_to,
"notes": request.notes,
"response_data": request.response_data,
"completed_at": request.completed_at.isoformat() if request.completed_at else None,
"created_at": request.created_at.isoformat() if request.created_at else None
}
@router.post("/gdpr/requests/{request_id}/assign")
async def assign_gdpr_request(
request_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Assign a GDPR request to the current admin"""
from ..models.gdpr_compliance import DataSubjectRequest
request = db.query(DataSubjectRequest).filter(DataSubjectRequest.id == request_id).first()
if not request:
raise HTTPException(status_code=404, detail="GDPR request not found")
request.assigned_to = current_user.id
db.commit()
return {"status": "success", "message": "Request assigned"}
@router.post("/gdpr/requests/{request_id}/complete")
async def complete_gdpr_request(
request_id: int,
notes: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Mark a GDPR request as completed"""
from ..models.gdpr_compliance import DataSubjectRequest, DataSubjectRequestStatus
request = db.query(DataSubjectRequest).filter(DataSubjectRequest.id == request_id).first()
if not request:
raise HTTPException(status_code=404, detail="GDPR request not found")
request.status = DataSubjectRequestStatus.completed
request.completed_at = datetime.utcnow()
request.completed_by = current_user.id
if notes:
request.notes = notes
db.commit()
return {"status": "success", "message": "Request completed"}
# OAuth Routes
@router.get("/oauth/{provider_name}/authorize")
async def oauth_authorize(
provider_name: str,
redirect_uri: str = Query(...),
state: Optional[str] = None,
db: Session = Depends(get_db)
):
"""Get OAuth authorization URL"""
if not OAUTH_AVAILABLE:
raise HTTPException(status_code=503, detail="OAuth service is not available. Please install httpx: pip install httpx")
try:
auth_url = oauth_service.get_authorization_url(
db=db,
provider_name=provider_name,
redirect_uri=redirect_uri,
state=state
)
return {"authorization_url": auth_url}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/oauth/{provider_name}/callback")
async def oauth_callback(
provider_name: str,
code: str = Query(...),
redirect_uri: str = Query(...),
state: Optional[str] = None,
db: Session = Depends(get_db)
):
"""Handle OAuth callback"""
if not OAUTH_AVAILABLE:
raise HTTPException(status_code=503, detail="OAuth service is not available. Please install httpx: pip install httpx")
try:
# Exchange code for token
token_data = await oauth_service.exchange_code_for_token(
db=db,
provider_name=provider_name,
code=code,
redirect_uri=redirect_uri
)
# Get user info
user_info = await oauth_service.get_user_info(
db=db,
provider_name=provider_name,
access_token=token_data['access_token']
)
# Find or create user
user = oauth_service.find_or_create_user_from_oauth(
db=db,
provider_name=provider_name,
user_info=user_info
)
# Save OAuth token
from ..models.security_event import OAuthProvider
provider = db.query(OAuthProvider).filter(
OAuthProvider.name == provider_name
).first()
oauth_service.save_oauth_token(
db=db,
user_id=user.id,
provider_id=provider.id,
provider_user_id=user_info.get('sub') or user_info.get('id'),
access_token=token_data['access_token'],
refresh_token=token_data.get('refresh_token'),
expires_in=token_data.get('expires_in'),
scopes=token_data.get('scope')
)
# Generate JWT tokens for the user
from ...auth.services.auth_service import auth_service
tokens = auth_service.generate_tokens(user.id)
return {
"status": "success",
"token": tokens["accessToken"],
"refreshToken": tokens["refreshToken"],
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name
}
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
# GDPR Routes
class DataSubjectRequestCreate(BaseModel):
email: EmailStr
request_type: str
description: Optional[str] = None
@router.post("/gdpr/request")
async def create_data_subject_request(
data: DataSubjectRequestCreate,
request: Request,
db: Session = Depends(get_db)
):
"""Create a GDPR data subject request"""
try:
request_type = DataSubjectRequestType(data.request_type)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid request type")
try:
gdpr_request = gdpr_service.create_data_subject_request(
db=db,
email=data.email,
request_type=request_type,
description=data.description,
ip_address=request.client.host if request.client else None,
user_agent=request.headers.get("User-Agent")
)
return {
"status": "success",
"message": "Request created. Please check your email for verification.",
"verification_token": gdpr_request.verification_token
}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/gdpr/verify/{verification_token}")
async def verify_data_subject_request(
verification_token: str,
db: Session = Depends(get_db)
):
"""Verify a data subject request"""
verified = gdpr_service.verify_request(db=db, verification_token=verification_token)
if not verified:
raise HTTPException(status_code=404, detail="Invalid verification token")
return {"status": "success", "message": "Request verified"}
@router.get("/gdpr/data/{user_id}")
async def get_user_data(
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get user data (GDPR access request)"""
# Users can only access their own data, unless admin
if current_user.id != user_id:
# Check if user is admin
from ...auth.models.role import Role
role = db.query(Role).filter(Role.id == current_user.role_id).first()
if not role or role.name != "admin":
raise HTTPException(status_code=403, detail="Access denied")
try:
data = gdpr_service.get_user_data(db=db, user_id=user_id)
return {"status": "success", "data": data}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/gdpr/data/{user_id}")
async def delete_user_data(
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Delete user data (GDPR erasure request)"""
success = gdpr_service.delete_user_data(db=db, user_id=user_id)
if not success:
raise HTTPException(status_code=404, detail="User not found")
return {"status": "success", "message": "User data deleted"}
@router.get("/gdpr/export/{user_id}")
async def export_user_data(
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Export user data (GDPR portability request)"""
if current_user.id != user_id:
# Check if user is admin
from ...auth.models.role import Role
role = db.query(Role).filter(Role.id == current_user.role_id).first()
if not role or role.name != "admin":
raise HTTPException(status_code=403, detail="Access denied")
try:
data = gdpr_service.export_user_data(db=db, user_id=user_id)
return {"status": "success", "data": data}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
# Security Scanning
@router.post("/scan/run")
async def run_security_scan(
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Run a manual security scan"""
try:
results = security_scan_service.run_full_scan(db=db)
return {"status": "success", "results": results}
except Exception as e:
import traceback
error_details = traceback.format_exc()
logger.error(f"Security scan failed: {str(e)}\n{error_details}")
raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}")
@router.post("/scan/schedule")
async def schedule_security_scan(
interval_hours: int = Query(24, ge=1, le=168), # 1 hour to 1 week
db: Session = Depends(get_db),
current_user: User = Depends(authorize_roles("admin"))
):
"""Schedule automatic security scans"""
try:
schedule = security_scan_service.schedule_scan(db=db, interval_hours=interval_hours)
return {"status": "success", "schedule": schedule}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to schedule scan: {str(e)}")

View File

View File

@@ -0,0 +1,89 @@
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
import base64
import os
from typing import Optional
import logging
from ...shared.config.settings import settings
logger = logging.getLogger(__name__)
class EncryptionService:
"""Service for data encryption at rest"""
def __init__(self, encryption_key: Optional[str] = None):
"""
Initialize encryption service
Args:
encryption_key: Base64-encoded encryption key. If not provided, will use ENCRYPTION_KEY from settings or env var
"""
if encryption_key:
self.key = encryption_key.encode()
else:
# Try to get key from settings first (loads from .env), then fall back to os.getenv
key_str = getattr(settings, 'ENCRYPTION_KEY', None) or os.getenv('ENCRYPTION_KEY')
if not key_str or key_str.strip() == '':
# Generate a key if not provided (for development only)
logger.warning("ENCRYPTION_KEY not set. Generating temporary key. This should be set in production!")
key = Fernet.generate_key()
self.key = key
else:
self.key = key_str.encode()
try:
self.cipher = Fernet(self.key)
except Exception as e:
logger.error(f"Failed to initialize encryption: {str(e)}")
raise
def encrypt(self, data: str) -> str:
"""Encrypt a string"""
try:
if not data:
return data
encrypted = self.cipher.encrypt(data.encode())
return base64.urlsafe_b64encode(encrypted).decode()
except Exception as e:
logger.error(f"Encryption failed: {str(e)}")
raise
def decrypt(self, encrypted_data: str) -> str:
"""Decrypt a string"""
try:
if not encrypted_data:
return encrypted_data
decoded = base64.urlsafe_b64decode(encrypted_data.encode())
decrypted = self.cipher.decrypt(decoded)
return decrypted.decode()
except Exception as e:
logger.error(f"Decryption failed: {str(e)}")
raise
def encrypt_dict(self, data: dict) -> dict:
"""Encrypt sensitive fields in a dictionary"""
encrypted = {}
sensitive_fields = ['password', 'token', 'secret', 'key', 'api_key', 'access_token', 'refresh_token']
for key, value in data.items():
if any(sensitive in key.lower() for sensitive in sensitive_fields):
if isinstance(value, str):
encrypted[key] = self.encrypt(value)
else:
encrypted[key] = value
else:
encrypted[key] = value
return encrypted
@staticmethod
def generate_key() -> str:
"""Generate a new encryption key"""
key = Fernet.generate_key()
return key.decode()
# Global instance
encryption_service = EncryptionService()

View File

@@ -0,0 +1,215 @@
from sqlalchemy.orm import Session
from typing import Optional, Dict, Any, List
from datetime import datetime
import secrets
import logging
from ..models.gdpr_compliance import (
DataSubjectRequest,
DataSubjectRequestType,
DataSubjectRequestStatus,
DataRetentionPolicy,
ConsentRecord
)
from ...auth.models.user import User
from ...bookings.models.booking import Booking
from ...payments.models.payment import Payment
from ...reviews.models.review import Review
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class GDPRService:
"""Service for GDPR compliance operations"""
@staticmethod
def create_data_subject_request(
db: Session,
email: str,
request_type: DataSubjectRequestType,
description: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> DataSubjectRequest:
"""Create a new data subject request"""
# Find user by email
user = db.query(User).filter(User.email == email.lower()).first()
# Generate verification token
verification_token = secrets.token_urlsafe(32)
request = DataSubjectRequest(
user_id=user.id if user else None,
email=email.lower(),
request_type=request_type,
status=DataSubjectRequestStatus.pending,
description=description,
verification_token=verification_token,
ip_address=ip_address,
user_agent=user_agent
)
db.add(request)
db.commit()
db.refresh(request)
logger.info(f"Data subject request created: {request_type.value} for {email}")
return request
@staticmethod
def verify_request(db: Session, verification_token: str) -> bool:
"""Verify a data subject request"""
request = db.query(DataSubjectRequest).filter(
DataSubjectRequest.verification_token == verification_token
).first()
if not request:
return False
request.verified = True
request.verified_at = datetime.utcnow()
db.commit()
return True
@staticmethod
def get_user_data(db: Session, user_id: int) -> Dict[str, Any]:
"""Get all data for a user (for access request)"""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError("User not found")
# Collect all user data
data = {
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"phone": user.phone,
"created_at": user.created_at.isoformat() if user.created_at else None,
},
"bookings": [],
"payments": [],
"reviews": [],
}
# Get bookings
bookings = db.query(Booking).filter(Booking.user_id == user_id).all()
for booking in bookings:
data["bookings"].append({
"id": booking.id,
"booking_number": booking.booking_number,
"check_in_date": booking.check_in_date.isoformat() if booking.check_in_date else None,
"check_out_date": booking.check_out_date.isoformat() if booking.check_out_date else None,
"total_price": float(booking.total_price) if booking.total_price else None,
"status": booking.status.value if hasattr(booking.status, 'value') else booking.status,
})
# Get payments
payments = db.query(Payment).filter(Payment.booking.has(user_id=user_id)).all()
for payment in payments:
data["payments"].append({
"id": payment.id,
"amount": float(payment.amount) if payment.amount else None,
"payment_method": payment.payment_method,
"payment_status": payment.payment_status,
"payment_date": payment.payment_date.isoformat() if payment.payment_date else None,
})
# Get reviews
reviews = db.query(Review).filter(Review.user_id == user_id).all()
for review in reviews:
data["reviews"].append({
"id": review.id,
"rating": review.rating,
"comment": review.comment,
"created_at": review.created_at.isoformat() if review.created_at else None,
})
return data
@staticmethod
def delete_user_data(db: Session, user_id: int) -> bool:
"""Delete all user data (for erasure request)"""
try:
user = db.query(User).filter(User.id == user_id).first()
if not user:
return False
# Anonymize user data instead of deleting (for audit trail)
user.email = f"deleted_{user.id}@deleted.local"
user.full_name = "Deleted User"
user.phone = None
user.password = "deleted" # Invalidate password
# Delete related data
# Note: In production, you might want to soft-delete or anonymize instead
db.query(Booking).filter(Booking.user_id == user_id).delete()
db.query(Review).filter(Review.user_id == user_id).delete()
db.commit()
logger.info(f"User data deleted/anonymized for user {user_id}")
return True
except Exception as e:
logger.error(f"Error deleting user data: {str(e)}")
db.rollback()
return False
@staticmethod
def export_user_data(db: Session, user_id: int) -> Dict[str, Any]:
"""Export user data in portable format (for portability request)"""
return GDPRService.get_user_data(db, user_id)
@staticmethod
def record_consent(
db: Session,
user_id: int,
consent_type: str,
granted: bool,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
version: Optional[str] = None
) -> ConsentRecord:
"""Record user consent"""
# Revoke previous consent if granting new one
if granted:
previous = db.query(ConsentRecord).filter(
ConsentRecord.user_id == user_id,
ConsentRecord.consent_type == consent_type,
ConsentRecord.revoked_at.is_(None)
).first()
if previous:
previous.revoked_at = datetime.utcnow()
consent = ConsentRecord(
user_id=user_id,
consent_type=consent_type,
granted=granted,
granted_at=datetime.utcnow() if granted else None,
revoked_at=datetime.utcnow() if not granted else None,
ip_address=ip_address,
user_agent=user_agent,
version=version
)
db.add(consent)
db.commit()
db.refresh(consent)
return consent
@staticmethod
def check_consent(db: Session, user_id: int, consent_type: str) -> bool:
"""Check if user has granted consent"""
consent = db.query(ConsentRecord).filter(
ConsentRecord.user_id == user_id,
ConsentRecord.consent_type == consent_type,
ConsentRecord.granted == True,
ConsentRecord.revoked_at.is_(None)
).order_by(ConsentRecord.granted_at.desc()).first()
return consent is not None
gdpr_service = GDPRService()

View File

@@ -0,0 +1,71 @@
from sqlalchemy.orm import Session
from ...content.models.cookie_policy import CookiePolicy
from ...content.models.cookie_integration_config import CookieIntegrationConfig
from ...auth.models.user import User
from ...content.schemas.admin_privacy import CookieIntegrationSettings, CookiePolicySettings, PublicPrivacyConfig
class PrivacyAdminService:
@staticmethod
def get_or_create_policy(db: Session) -> CookiePolicy:
policy = db.query(CookiePolicy).first()
if policy:
return policy
policy = CookiePolicy()
db.add(policy)
db.commit()
db.refresh(policy)
return policy
@staticmethod
def get_policy_settings(db: Session) -> CookiePolicySettings:
policy = PrivacyAdminService.get_or_create_policy(db)
return CookiePolicySettings(analytics_enabled=policy.analytics_enabled, marketing_enabled=policy.marketing_enabled, preferences_enabled=policy.preferences_enabled)
@staticmethod
def update_policy(db: Session, settings: CookiePolicySettings, updated_by: User | None) -> CookiePolicy:
policy = PrivacyAdminService.get_or_create_policy(db)
policy.analytics_enabled = settings.analytics_enabled
policy.marketing_enabled = settings.marketing_enabled
policy.preferences_enabled = settings.preferences_enabled
if updated_by:
policy.updated_by_id = updated_by.id
db.add(policy)
db.commit()
db.refresh(policy)
return policy
@staticmethod
def get_or_create_integrations(db: Session) -> CookieIntegrationConfig:
config = db.query(CookieIntegrationConfig).first()
if config:
return config
config = CookieIntegrationConfig()
db.add(config)
db.commit()
db.refresh(config)
return config
@staticmethod
def get_integration_settings(db: Session) -> CookieIntegrationSettings:
cfg = PrivacyAdminService.get_or_create_integrations(db)
return CookieIntegrationSettings(ga_measurement_id=cfg.ga_measurement_id, fb_pixel_id=cfg.fb_pixel_id)
@staticmethod
def update_integrations(db: Session, settings: CookieIntegrationSettings, updated_by: User | None) -> CookieIntegrationConfig:
cfg = PrivacyAdminService.get_or_create_integrations(db)
cfg.ga_measurement_id = settings.ga_measurement_id
cfg.fb_pixel_id = settings.fb_pixel_id
if updated_by:
cfg.updated_by_id = updated_by.id
db.add(cfg)
db.commit()
db.refresh(cfg)
return cfg
@staticmethod
def get_public_privacy_config(db: Session) -> PublicPrivacyConfig:
policy = PrivacyAdminService.get_policy_settings(db)
integrations = PrivacyAdminService.get_integration_settings(db)
return PublicPrivacyConfig(policy=policy, integrations=integrations)
privacy_admin_service = PrivacyAdminService()

View File

@@ -0,0 +1,189 @@
from sqlalchemy.orm import Session
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from ..models.security_event import SecurityEvent, SecurityEventType, SecurityEventSeverity
from ...auth.models.user import User
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class SecurityMonitoringService:
"""Service for monitoring and analyzing security events"""
@staticmethod
def log_security_event(
db: Session,
event_type: SecurityEventType,
severity: SecurityEventSeverity,
user_id: Optional[int] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_path: Optional[str] = None,
request_method: Optional[str] = None,
request_id: Optional[str] = None,
description: Optional[str] = None,
details: Optional[Dict[str, Any]] = None
) -> SecurityEvent:
"""Log a security event"""
try:
event = SecurityEvent(
user_id=user_id,
event_type=event_type,
severity=severity,
ip_address=ip_address,
user_agent=user_agent,
request_path=request_path,
request_method=request_method,
request_id=request_id,
description=description,
details=details
)
db.add(event)
db.commit()
db.refresh(event)
# Check for suspicious patterns
SecurityMonitoringService._check_suspicious_patterns(db, event)
return event
except Exception as e:
logger.error(f"Error logging security event: {str(e)}")
db.rollback()
raise
@staticmethod
def _check_suspicious_patterns(db: Session, event: SecurityEvent):
"""Check for suspicious activity patterns"""
# Multiple failed login attempts from same IP
if event.event_type == SecurityEventType.login_failure:
recent_failures = db.query(SecurityEvent).filter(
SecurityEvent.event_type == SecurityEventType.login_failure,
SecurityEvent.ip_address == event.ip_address,
SecurityEvent.created_at >= datetime.utcnow() - timedelta(minutes=15)
).count()
if recent_failures >= 5:
# Log suspicious activity
SecurityMonitoringService.log_security_event(
db,
SecurityEventType.suspicious_activity,
SecurityEventSeverity.high,
ip_address=event.ip_address,
description=f"Multiple failed login attempts ({recent_failures}) from IP {event.ip_address}",
details={"failure_count": recent_failures}
)
# Multiple permission denied from same user
if event.event_type == SecurityEventType.permission_denied and event.user_id:
recent_denials = db.query(SecurityEvent).filter(
SecurityEvent.event_type == SecurityEventType.permission_denied,
SecurityEvent.user_id == event.user_id,
SecurityEvent.created_at >= datetime.utcnow() - timedelta(hours=1)
).count()
if recent_denials >= 10:
SecurityMonitoringService.log_security_event(
db,
SecurityEventType.suspicious_activity,
SecurityEventSeverity.medium,
user_id=event.user_id,
description=f"User {event.user_id} has {recent_denials} permission denials in the last hour",
details={"denial_count": recent_denials}
)
@staticmethod
def get_security_events(
db: Session,
user_id: Optional[int] = None,
event_type: Optional[SecurityEventType] = None,
severity: Optional[SecurityEventSeverity] = None,
ip_address: Optional[str] = None,
resolved: Optional[bool] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
offset: int = 0
) -> List[SecurityEvent]:
"""Get security events with filters"""
query = db.query(SecurityEvent)
if user_id:
query = query.filter(SecurityEvent.user_id == user_id)
if event_type:
query = query.filter(SecurityEvent.event_type == event_type)
if severity:
query = query.filter(SecurityEvent.severity == severity)
if ip_address:
query = query.filter(SecurityEvent.ip_address == ip_address)
if resolved is not None:
query = query.filter(SecurityEvent.resolved == resolved)
if start_date:
query = query.filter(SecurityEvent.created_at >= start_date)
if end_date:
query = query.filter(SecurityEvent.created_at <= end_date)
return query.order_by(SecurityEvent.created_at.desc()).offset(offset).limit(limit).all()
@staticmethod
def get_security_stats(
db: Session,
days: int = 7
) -> Dict[str, Any]:
"""Get security statistics for the last N days"""
start_date = datetime.utcnow() - timedelta(days=days)
total_events = db.query(SecurityEvent).filter(
SecurityEvent.created_at >= start_date
).count()
by_type = {}
by_severity = {}
events = db.query(SecurityEvent).filter(
SecurityEvent.created_at >= start_date
).all()
for event in events:
event_type = event.event_type.value
severity = event.severity.value
by_type[event_type] = by_type.get(event_type, 0) + 1
by_severity[severity] = by_severity.get(severity, 0) + 1
unresolved_critical = db.query(SecurityEvent).filter(
SecurityEvent.severity == SecurityEventSeverity.critical,
SecurityEvent.resolved == False,
SecurityEvent.created_at >= start_date
).count()
return {
"total_events": total_events,
"by_type": by_type,
"by_severity": by_severity,
"unresolved_critical": unresolved_critical,
"period_days": days
}
@staticmethod
def resolve_event(
db: Session,
event_id: int,
resolved_by: int,
resolution_notes: Optional[str] = None
) -> SecurityEvent:
"""Mark a security event as resolved"""
event = db.query(SecurityEvent).filter(SecurityEvent.id == event_id).first()
if not event:
raise ValueError("Security event not found")
event.resolved = True
event.resolved_at = datetime.utcnow()
event.resolved_by = resolved_by
event.resolution_notes = resolution_notes
db.commit()
db.refresh(event)
return event
security_monitoring_service = SecurityMonitoringService()

View File

@@ -0,0 +1,314 @@
from sqlalchemy.orm import Session
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
import logging
from ..models.security_event import SecurityEvent, SecurityEventType, SecurityEventSeverity
from ...auth.models.user import User
from ...bookings.models.booking import Booking
from ...payments.models.payment import Payment
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class SecurityScanService:
"""Service for automated security scanning"""
@staticmethod
def run_full_scan(db: Session) -> Dict[str, Any]:
"""Run a full security scan"""
results = {
"scan_id": f"scan_{datetime.utcnow().isoformat()}",
"started_at": datetime.utcnow().isoformat(),
"checks": [],
"total_issues": 0,
"critical_issues": 0,
"high_issues": 0,
"medium_issues": 0,
"low_issues": 0
}
# Run all security checks
checks = [
SecurityScanService._check_weak_passwords(db),
SecurityScanService._check_inactive_users(db),
SecurityScanService._check_failed_login_attempts(db),
SecurityScanService._check_suspicious_activity(db),
SecurityScanService._check_unresolved_security_events(db),
SecurityScanService._check_expired_tokens(db),
SecurityScanService._check_unusual_payment_patterns(db),
SecurityScanService._check_data_retention_compliance(db),
]
for check in checks:
if check:
results["checks"].append(check)
results["total_issues"] += check.get("issue_count", 0)
severity = check.get("severity", "low")
if severity == "critical":
results["critical_issues"] += check.get("issue_count", 0)
elif severity == "high":
results["high_issues"] += check.get("issue_count", 0)
elif severity == "medium":
results["medium_issues"] += check.get("issue_count", 0)
else:
results["low_issues"] += check.get("issue_count", 0)
completed_at = datetime.utcnow()
results["completed_at"] = completed_at.isoformat()
# Parse start time - handle ISO format with/without microseconds
start_str = results["started_at"]
try:
# Remove timezone info and microseconds for simpler parsing
if '.' in start_str:
start_str = start_str.split('.')[0]
if 'Z' in start_str:
start_str = start_str.replace('Z', '')
if '+' in start_str:
start_str = start_str.split('+')[0]
started_at = datetime.fromisoformat(start_str)
except Exception:
# Fallback: use current time if parsing fails
started_at = completed_at
results["duration_seconds"] = (completed_at - started_at).total_seconds()
# Log critical and high issues as security events
for check in results["checks"]:
if check.get("severity") in ["critical", "high"] and check.get("issue_count", 0) > 0:
SecurityScanService._log_scan_finding(db, check)
return results
@staticmethod
def _check_weak_passwords(db: Session) -> Optional[Dict[str, Any]]:
"""Check for users with weak passwords"""
# This is a placeholder - in production, you'd check password strength
# For now, we'll check for users without password changes in a long time
cutoff_date = datetime.utcnow() - timedelta(days=365)
users = db.query(User).filter(
User.created_at < cutoff_date,
User.is_active == True
).all()
if len(users) > 10: # Threshold
return {
"check_name": "Weak Passwords",
"check_type": "password_security",
"severity": "medium",
"status": "failed",
"issue_count": len(users),
"description": f"{len(users)} users have not changed passwords in over a year",
"recommendation": "Enforce password rotation policy",
"affected_items": [{"user_id": u.id, "email": u.email} for u in users[:10]]
}
return None
@staticmethod
def _check_inactive_users(db: Session) -> Optional[Dict[str, Any]]:
"""Check for inactive users that should be deactivated"""
cutoff_date = datetime.utcnow() - timedelta(days=180)
# Users who haven't logged in for 6 months
inactive_users = db.query(User).filter(
User.is_active == True
).all()
# This is simplified - in production, track last login
if len(inactive_users) > 50:
return {
"check_name": "Inactive Users",
"check_type": "user_management",
"severity": "low",
"status": "warning",
"issue_count": len(inactive_users),
"description": f"Found {len(inactive_users)} potentially inactive users",
"recommendation": "Review and deactivate inactive accounts",
"affected_items": []
}
return None
@staticmethod
def _check_failed_login_attempts(db: Session) -> Optional[Dict[str, Any]]:
"""Check for excessive failed login attempts"""
from ..models.security_event import SecurityEvent, SecurityEventType
recent_failures = db.query(SecurityEvent).filter(
SecurityEvent.event_type == SecurityEventType.login_failure,
SecurityEvent.created_at >= datetime.utcnow() - timedelta(hours=24)
).count()
if recent_failures > 50:
return {
"check_name": "Excessive Failed Logins",
"check_type": "authentication",
"severity": "high",
"status": "failed",
"issue_count": recent_failures,
"description": f"{recent_failures} failed login attempts in the last 24 hours",
"recommendation": "Review failed login attempts and consider IP blocking",
"affected_items": []
}
return None
@staticmethod
def _check_suspicious_activity(db: Session) -> Optional[Dict[str, Any]]:
"""Check for suspicious activity patterns"""
from ..models.security_event import SecurityEvent, SecurityEventType, SecurityEventSeverity
suspicious_events = db.query(SecurityEvent).filter(
SecurityEvent.event_type == SecurityEventType.suspicious_activity,
SecurityEvent.resolved == False,
SecurityEvent.created_at >= datetime.utcnow() - timedelta(days=7)
).count()
if suspicious_events > 0:
return {
"check_name": "Unresolved Suspicious Activity",
"check_type": "threat_detection",
"severity": "critical" if suspicious_events > 5 else "high",
"status": "failed",
"issue_count": suspicious_events,
"description": f"{suspicious_events} unresolved suspicious activity events in the last 7 days",
"recommendation": "Review and resolve suspicious activity events immediately",
"affected_items": []
}
return None
@staticmethod
def _check_unresolved_security_events(db: Session) -> Optional[Dict[str, Any]]:
"""Check for unresolved critical security events"""
from ..models.security_event import SecurityEvent, SecurityEventSeverity
unresolved_critical = db.query(SecurityEvent).filter(
SecurityEvent.severity == SecurityEventSeverity.critical,
SecurityEvent.resolved == False,
SecurityEvent.created_at >= datetime.utcnow() - timedelta(days=7)
).count()
if unresolved_critical > 0:
return {
"check_name": "Unresolved Critical Events",
"check_type": "incident_management",
"severity": "critical",
"status": "failed",
"issue_count": unresolved_critical,
"description": f"{unresolved_critical} unresolved critical security events",
"recommendation": "Resolve critical security events immediately",
"affected_items": []
}
return None
@staticmethod
def _check_expired_tokens(db: Session) -> Optional[Dict[str, Any]]:
"""Check for expired tokens that should be cleaned up"""
from ...auth.models.refresh_token import RefreshToken
expired_tokens = db.query(RefreshToken).filter(
RefreshToken.expires_at < datetime.utcnow()
).count()
if expired_tokens > 1000:
return {
"check_name": "Expired Tokens",
"check_type": "token_management",
"severity": "low",
"status": "warning",
"issue_count": expired_tokens,
"description": f"{expired_tokens} expired tokens found in database",
"recommendation": "Clean up expired tokens to improve database performance",
"affected_items": []
}
return None
@staticmethod
def _check_unusual_payment_patterns(db: Session) -> Optional[Dict[str, Any]]:
"""Check for unusual payment patterns that might indicate fraud"""
from ...payments.models.payment import PaymentStatus
# Check for multiple failed payments from same IP
recent_payments = db.query(Payment).filter(
Payment.payment_date >= datetime.utcnow() - timedelta(hours=24)
).all()
# Simplified check - in production, use more sophisticated fraud detection
failed_payments = [p for p in recent_payments if p.payment_status == PaymentStatus.failed]
if len(failed_payments) > 20:
return {
"check_name": "Unusual Payment Patterns",
"check_type": "fraud_detection",
"severity": "medium",
"status": "warning",
"issue_count": len(failed_payments),
"description": f"{len(failed_payments)} failed payments in the last 24 hours",
"recommendation": "Review failed payment patterns for potential fraud",
"affected_items": []
}
return None
@staticmethod
def _check_data_retention_compliance(db: Session) -> Optional[Dict[str, Any]]:
"""Check data retention policy compliance"""
from ..models.gdpr_compliance import DataRetentionPolicy
policies = db.query(DataRetentionPolicy).filter(
DataRetentionPolicy.is_active == True,
DataRetentionPolicy.auto_delete == True
).all()
# Check if there's data that should have been deleted
issues = []
for policy in policies:
# This is simplified - in production, check actual data age
if policy.retention_days < 30: # Very short retention
issues.append({
"policy": policy.data_type,
"retention_days": policy.retention_days
})
if issues:
return {
"check_name": "Data Retention Compliance",
"check_type": "gdpr_compliance",
"severity": "high",
"status": "warning",
"issue_count": len(issues),
"description": f"Found {len(issues)} data retention policies that may need review",
"recommendation": "Review data retention policies for GDPR compliance",
"affected_items": issues
}
return None
@staticmethod
def _log_scan_finding(db: Session, check: Dict[str, Any]):
"""Log scan findings as security events"""
try:
event = SecurityEvent(
event_type=SecurityEventType.suspicious_activity,
severity=SecurityEventSeverity(check["severity"]),
description=f"Security Scan: {check['check_name']} - {check['description']}",
details={
"check_type": check.get("check_type"),
"issue_count": check.get("issue_count"),
"recommendation": check.get("recommendation"),
"affected_items": check.get("affected_items", [])
}
)
db.add(event)
db.commit()
except Exception as e:
logger.error(f"Error logging scan finding: {str(e)}")
db.rollback()
@staticmethod
def schedule_scan(db: Session, interval_hours: int = 24) -> Dict[str, Any]:
"""Schedule automatic security scans"""
# In production, use a task scheduler like Celery or APScheduler
# For now, this is a placeholder that returns scan configuration
return {
"scheduled": True,
"interval_hours": interval_hours,
"next_scan": (datetime.utcnow() + timedelta(hours=interval_hours)).isoformat(),
"message": "Scan scheduled. In production, use a task scheduler to run scans automatically."
}
security_scan_service = SecurityScanService()