290 lines
11 KiB
Python
290 lines
11 KiB
Python
"""
|
|
Zero Trust Middleware
|
|
Automatically applies Zero Trust principles to all requests
|
|
"""
|
|
import logging
|
|
from django.http import JsonResponse
|
|
from django.utils.deprecation import MiddlewareMixin
|
|
from django.contrib.auth import get_user_model
|
|
from django.utils import timezone
|
|
|
|
from ..services.zero_trust import zero_trust_service
|
|
|
|
logger = logging.getLogger(__name__)
|
|
User = get_user_model()
|
|
|
|
|
|
class ZeroTrustMiddleware(MiddlewareMixin):
    """
    Middleware that applies Zero Trust principles to all requests.

    For every authenticated request outside the skipped/public path prefixes
    it asks ``zero_trust_service`` for an access assessment. A denied
    assessment short-circuits the request with a JSON error response;
    otherwise the assessment is attached to the request as
    ``request.zero_trust_assessment`` for later middleware and views.
    """

    # Path prefixes that are never assessed. Kept as tuples so a single
    # ``str.startswith`` call can test every prefix.
    SKIP_PATH_PREFIXES = (
        '/admin/',
        '/static/',
        '/media/',
        '/api/auth/login/',
        '/api/auth/logout/',
        '/health/',
        '/metrics/',
    )

    # API prefixes that are exempt from assessment.
    PUBLIC_API_PREFIXES = (
        '/api/auth/',
        '/api/health/',
        '/api/status/',
        '/api/docs/',
    )

    # Denial responses in precedence order: the first action found in the
    # assessment's ``required_actions`` decides the response.
    # Each row is (action, error message, HTTP status, extra key, extra value).
    _DENIAL_RESPONSES = (
        ('MANUAL_REVIEW', 'Access requires manual review', 423,
         'support_contact', 'security@company.com'),
        ('STEP_UP_AUTH', 'Additional authentication required', 401,
         'auth_url', '/api/auth/step-up/'),
        ('DEVICE_REGISTRATION', 'Device registration required', 403,
         'registration_url', '/api/security/register-device/'),
        ('ADDITIONAL_MFA', 'Additional MFA required', 401,
         'mfa_url', '/api/auth/mfa/'),
    )

    def __init__(self, get_response):
        self.get_response = get_response
        super().__init__(get_response)

    def process_request(self, request):
        """
        Process incoming request for Zero Trust assessment.

        Returns ``None`` to let the request continue, or a ``JsonResponse``
        when the assessment denies access. If the assessment itself raises,
        the error is logged and the request is allowed with a placeholder
        assessment (deliberate fail-open).
        """
        # Skip Zero Trust for certain paths
        if request.path.startswith(self.SKIP_PATH_PREFIXES):
            return None

        # Only apply to authenticated users
        if not request.user.is_authenticated:
            return None

        # Skip for API endpoints that don't require Zero Trust
        if request.path.startswith('/api/') and self._is_public_endpoint(request.path):
            return None

        try:
            request_context = self._collect_request_context(request)
            assessment_result = zero_trust_service.assess_access_request(
                request.user,
                request_context,
            )

            if not assessment_result.get('access_granted', False):
                return self._handle_access_denied(request, assessment_result)

            # Store assessment result in request for use in views
            request.zero_trust_assessment = assessment_result
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Zero Trust middleware error: %s", e)
            # In case of error, allow access but log the issue
            request.zero_trust_assessment = {
                'access_granted': True,
                'reason': 'Assessment error - defaulting to allow',
                'risk_level': 'UNKNOWN',
                'error': str(e),
            }
            return None

        # Profile bookkeeping is best-effort and isolated in its own try:
        # a failure here must not replace the assessment that has already
        # granted access with the error placeholder above.
        try:
            zero_trust_service.update_behavior_profile(request.user, request_context)
        except Exception:
            logger.exception("Zero Trust behavior profile update failed")

        return None

    def _collect_request_context(self, request) -> dict:
        """
        Collect context data from request.

        The device id is taken from the ``X-Device-ID`` header, then from the
        session's ``device_id`` entry, and finally from a generated
        fingerprint; the first non-empty value wins.
        """
        ip_address = self._get_client_ip(request)

        device_id = (
            request.headers.get('X-Device-ID')
            or request.session.get('device_id')
            or self._generate_device_fingerprint(request)
        )

        return {
            'ip_address': ip_address,
            'user_agent': request.META.get('HTTP_USER_AGENT', ''),
            'device_id': device_id,
            'timestamp': timezone.now(),
            'request_method': request.method,
            'request_path': request.path,
            'referer': request.META.get('HTTP_REFERER', ''),
            'accept_language': request.META.get('HTTP_ACCEPT_LANGUAGE', ''),
            'session_id': request.session.session_key,
        }

    def _get_client_ip(self, request) -> str:
        """
        Get client IP address from request.

        Returns the first entry of ``X-Forwarded-For`` when that header is
        present, otherwise ``REMOTE_ADDR`` (which is ``None`` when the server
        did not set it).
        """
        # NOTE(review): X-Forwarded-For is sent by the client unless a trusted
        # proxy rewrites it - confirm the deployment sanitises this header
        # before the value is relied on for risk decisions.
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

    def _generate_device_fingerprint(self, request) -> str:
        """
        Generate a device fingerprint for tracking.

        The fingerprint is the first 16 hex characters of the SHA-256 digest
        of the user agent string concatenated with the client IP.
        """
        import hashlib

        fingerprint_data = f"{request.META.get('HTTP_USER_AGENT', '')}{self._get_client_ip(request)}"
        return hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16]

    def _is_public_endpoint(self, path: str) -> bool:
        """Check if endpoint is public and doesn't require Zero Trust."""
        return path.startswith(self.PUBLIC_API_PREFIXES)

    def _handle_access_denied(self, request, assessment_result: dict) -> JsonResponse:
        """
        Handle access denied based on Zero Trust assessment.

        Builds the JSON denial response. The first action from
        ``_DENIAL_RESPONSES`` found in the assessment's ``required_actions``
        selects the error message, the status code and one extra hint field;
        when none matches, a generic 403 is returned.
        """
        required_actions = assessment_result.get('required_actions', [])
        risk_level = assessment_result.get('risk_level', 'UNKNOWN')
        reason = assessment_result.get('reason', 'Access denied')

        # Generic access denied; overridden below when a known action matches.
        payload = {
            'error': 'Access denied',
            'reason': reason,
            'risk_level': risk_level,
            'required_actions': required_actions,
        }
        status = 403

        for action, error, code, extra_key, extra_value in self._DENIAL_RESPONSES:
            if action in required_actions:
                payload['error'] = error
                payload[extra_key] = extra_value
                status = code
                break

        return JsonResponse(payload, status=status)
|
|
|
|
|
|
class DeviceRegistrationMiddleware(MiddlewareMixin):
    """
    Middleware to handle device registration for Zero Trust.

    Intercepts ``POST /api/security/register-device/`` and answers it
    directly; every other request passes through untouched.
    """

    def process_request(self, request):
        """Process device registration requests; return ``None`` for all others."""
        # NOTE(review): a response returned from process_request stops further
        # middleware/view processing for this request - confirm that CSRF
        # protection for this POST is enforced elsewhere.
        if request.path == '/api/security/register-device/' and request.method == 'POST':
            return self._handle_device_registration(request)
        return None

    def _handle_device_registration(self, request):
        """
        Handle device registration.

        Reads the device description from the JSON request body, passes it to
        ``zero_trust_service.register_device`` and returns the resulting
        posture as JSON. Responds with 401 for unauthenticated users and 400
        for an unreadable body or any registration failure.
        """
        import json

        try:
            if not request.user.is_authenticated:
                return JsonResponse({'error': 'Authentication required'}, status=401)

            if hasattr(request, 'json'):
                # Honour a payload that an earlier layer already attached.
                device_data = request.json
            else:
                # Django's HttpRequest has no ``json`` attribute, so the body
                # has to be parsed here; an empty body means "no device data".
                raw_body = request.body
                try:
                    device_data = json.loads(raw_body) if raw_body else {}
                except ValueError as exc:
                    # Covers both malformed JSON and undecodable bytes.
                    return JsonResponse({
                        'error': 'Device registration failed',
                        'details': f'Invalid JSON body: {exc}',
                    }, status=400)
                if not isinstance(device_data, dict):
                    return JsonResponse({
                        'error': 'Device registration failed',
                        'details': 'Request body must be a JSON object',
                    }, status=400)

            # Register device
            device_posture = zero_trust_service.register_device(request.user, device_data)

            return JsonResponse({
                'success': True,
                'device_id': device_posture.device_id,
                'trust_level': device_posture.trust_level,
                'risk_score': device_posture.risk_score,
                'is_compliant': device_posture.is_compliant,
                'message': 'Device registered successfully',
            })

        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Device registration failed: %s", e)
            return JsonResponse({
                'error': 'Device registration failed',
                'details': str(e),
            }, status=400)
|
|
|
|
|
|
class RiskBasedRateLimitMiddleware(MiddlewareMixin):
    """
    Middleware for risk-based rate limiting.

    Throttles authenticated users over a one-hour window, with the allowance
    chosen from the ``risk_level`` of ``request.zero_trust_assessment``.
    Request history lives in an in-memory dict on this middleware instance.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        super().__init__(get_response)
        # Maps user id (str) -> list of timestamps of accepted requests.
        # In production, use Redis or database
        self.request_counts = dict()

    def process_request(self, request):
        """
        Apply risk-based rate limiting.

        Returns ``None`` when the request may proceed (including when the user
        is anonymous or no assessment is attached to the request), or a 429
        ``JsonResponse`` once the user's allowance is used up.
        """
        # The assessment is only consulted for authenticated users.
        assessment = (
            getattr(request, 'zero_trust_assessment', None)
            if request.user.is_authenticated
            else None
        )
        if not assessment:
            return None

        risk_level = assessment.get('risk_level', 'LOW')
        user_id = str(request.user.id)
        current_time = timezone.now()

        # Requests allowed per hour for each risk level; unrecognised levels
        # get the 'LOW' allowance.
        window_seconds = 3600
        hourly_allowance = {
            'LOW': 1000,
            'MEDIUM': 500,
            'HIGH': 100,
            'CRITICAL': 10,
        }
        limit = {
            'requests': hourly_allowance.get(risk_level, hourly_allowance['LOW']),
            'window': window_seconds,
        }

        if not self._is_rate_limited(user_id, limit, current_time):
            return None

        # 429 Too Many Requests
        return JsonResponse(
            {
                'error': 'Rate limit exceeded',
                'reason': f'Too many requests for risk level {risk_level}',
                'retry_after': limit['window'],
            },
            status=429,
        )

    def _is_rate_limited(self, user_id: str, limit: dict, current_time) -> bool:
        """
        Check if user has exceeded rate limit.

        Drops the user's recorded timestamps that fall outside the window,
        then either reports the user as limited or records the current
        request. Requests that are rejected are not recorded.
        """
        # Simplified rate limiting (in production, use proper rate limiting library)
        now_ts = current_time.timestamp()
        cutoff = now_ts - limit['window']

        # Keep only the timestamps still inside the window; unknown users
        # start with an empty history.
        recent = [ts for ts in self.request_counts.get(user_id, []) if ts > cutoff]
        self.request_counts[user_id] = recent

        if len(recent) < limit['requests']:
            recent.append(now_ts)
            return False
        return True
|