""" Zero Trust Middleware Automatically applies Zero Trust principles to all requests """ import logging from django.http import JsonResponse from django.utils.deprecation import MiddlewareMixin from django.contrib.auth import get_user_model from django.utils import timezone from ..services.zero_trust import zero_trust_service logger = logging.getLogger(__name__) User = get_user_model() class ZeroTrustMiddleware(MiddlewareMixin): """ Middleware that applies Zero Trust principles to all requests """ def __init__(self, get_response): self.get_response = get_response super().__init__(get_response) def process_request(self, request): """Process incoming request for Zero Trust assessment""" # Skip Zero Trust for certain paths skip_paths = [ '/admin/', '/static/', '/media/', '/api/auth/login/', '/api/auth/logout/', '/health/', '/metrics/', ] if any(request.path.startswith(path) for path in skip_paths): return None # Only apply to authenticated users if not request.user.is_authenticated: return None # Skip for API endpoints that don't require Zero Trust if request.path.startswith('/api/') and self._is_public_endpoint(request.path): return None try: # Collect request context request_context = self._collect_request_context(request) # Perform Zero Trust assessment assessment_result = zero_trust_service.assess_access_request( request.user, request_context ) # Handle assessment result if not assessment_result.get('access_granted', False): return self._handle_access_denied(request, assessment_result) # Store assessment result in request for use in views request.zero_trust_assessment = assessment_result # Update behavior profile zero_trust_service.update_behavior_profile(request.user, request_context) except Exception as e: logger.error(f"Zero Trust middleware error: {e}") # In case of error, allow access but log the issue request.zero_trust_assessment = { 'access_granted': True, 'reason': 'Assessment error - defaulting to allow', 'risk_level': 'UNKNOWN', 'error': str(e) } return None def _collect_request_context(self, request) -> dict: """Collect context data from request""" # Get client IP ip_address = self._get_client_ip(request) # Get device ID from headers or session device_id = ( request.headers.get('X-Device-ID') or request.session.get('device_id') or self._generate_device_fingerprint(request) ) return { 'ip_address': ip_address, 'user_agent': request.META.get('HTTP_USER_AGENT', ''), 'device_id': device_id, 'timestamp': timezone.now(), 'request_method': request.method, 'request_path': request.path, 'referer': request.META.get('HTTP_REFERER', ''), 'accept_language': request.META.get('HTTP_ACCEPT_LANGUAGE', ''), 'session_id': request.session.session_key, } def _get_client_ip(self, request) -> str: """Get client IP address from request""" x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') if x_forwarded_for: ip = x_forwarded_for.split(',')[0].strip() else: ip = request.META.get('REMOTE_ADDR') return ip def _generate_device_fingerprint(self, request) -> str: """Generate a device fingerprint for tracking""" import hashlib # Create fingerprint from user agent and IP fingerprint_data = f"{request.META.get('HTTP_USER_AGENT', '')}{self._get_client_ip(request)}" return hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16] def _is_public_endpoint(self, path: str) -> bool: """Check if endpoint is public and doesn't require Zero Trust""" public_endpoints = [ '/api/auth/', '/api/health/', '/api/status/', '/api/docs/', ] return any(path.startswith(endpoint) for endpoint in public_endpoints) def _handle_access_denied(self, request, assessment_result: dict) -> JsonResponse: """Handle access denied based on Zero Trust assessment""" required_actions = assessment_result.get('required_actions', []) risk_level = assessment_result.get('risk_level', 'UNKNOWN') reason = assessment_result.get('reason', 'Access denied') # Determine response based on required actions if 'MANUAL_REVIEW' in required_actions: return JsonResponse({ 'error': 'Access requires manual review', 'reason': reason, 'risk_level': risk_level, 'required_actions': required_actions, 'support_contact': 'security@company.com' }, status=423) # 423 Locked elif 'STEP_UP_AUTH' in required_actions: return JsonResponse({ 'error': 'Additional authentication required', 'reason': reason, 'risk_level': risk_level, 'required_actions': required_actions, 'auth_url': '/api/auth/step-up/' }, status=401) # 401 Unauthorized elif 'DEVICE_REGISTRATION' in required_actions: return JsonResponse({ 'error': 'Device registration required', 'reason': reason, 'risk_level': risk_level, 'required_actions': required_actions, 'registration_url': '/api/security/register-device/' }, status=403) # 403 Forbidden elif 'ADDITIONAL_MFA' in required_actions: return JsonResponse({ 'error': 'Additional MFA required', 'reason': reason, 'risk_level': risk_level, 'required_actions': required_actions, 'mfa_url': '/api/auth/mfa/' }, status=401) # 401 Unauthorized else: # Generic access denied return JsonResponse({ 'error': 'Access denied', 'reason': reason, 'risk_level': risk_level, 'required_actions': required_actions }, status=403) # 403 Forbidden class DeviceRegistrationMiddleware(MiddlewareMixin): """ Middleware to handle device registration for Zero Trust """ def process_request(self, request): """Process device registration requests""" if request.path == '/api/security/register-device/' and request.method == 'POST': return self._handle_device_registration(request) return None def _handle_device_registration(self, request): """Handle device registration""" try: if not request.user.is_authenticated: return JsonResponse({'error': 'Authentication required'}, status=401) device_data = request.json if hasattr(request, 'json') else {} # Register device device_posture = zero_trust_service.register_device(request.user, device_data) return JsonResponse({ 'success': True, 'device_id': device_posture.device_id, 'trust_level': device_posture.trust_level, 'risk_score': device_posture.risk_score, 'is_compliant': device_posture.is_compliant, 'message': 'Device registered successfully' }) except Exception as e: logger.error(f"Device registration failed: {e}") return JsonResponse({ 'error': 'Device registration failed', 'details': str(e) }, status=400) class RiskBasedRateLimitMiddleware(MiddlewareMixin): """ Middleware for risk-based rate limiting """ def __init__(self, get_response): self.get_response = get_response super().__init__(get_response) self.request_counts = {} # In production, use Redis or database def process_request(self, request): """Apply risk-based rate limiting""" if not request.user.is_authenticated: return None # Get user's risk level from Zero Trust assessment assessment = getattr(request, 'zero_trust_assessment', None) if not assessment: return None risk_level = assessment.get('risk_level', 'LOW') user_id = str(request.user.id) current_time = timezone.now() # Define rate limits based on risk level rate_limits = { 'LOW': {'requests': 1000, 'window': 3600}, # 1000 requests per hour 'MEDIUM': {'requests': 500, 'window': 3600}, # 500 requests per hour 'HIGH': {'requests': 100, 'window': 3600}, # 100 requests per hour 'CRITICAL': {'requests': 10, 'window': 3600}, # 10 requests per hour } limit = rate_limits.get(risk_level, rate_limits['LOW']) # Check rate limit if self._is_rate_limited(user_id, limit, current_time): return JsonResponse({ 'error': 'Rate limit exceeded', 'reason': f'Too many requests for risk level {risk_level}', 'retry_after': limit['window'] }, status=429) # 429 Too Many Requests return None def _is_rate_limited(self, user_id: str, limit: dict, current_time) -> bool: """Check if user has exceeded rate limit""" # Simplified rate limiting (in production, use proper rate limiting library) window_start = current_time.timestamp() - limit['window'] # Clean old entries if user_id in self.request_counts: self.request_counts[user_id] = [ timestamp for timestamp in self.request_counts[user_id] if timestamp > window_start ] else: self.request_counts[user_id] = [] # Check if limit exceeded if len(self.request_counts[user_id]) >= limit['requests']: return True # Add current request self.request_counts[user_id].append(current_time.timestamp()) return False