413 lines
19 KiB
Python
413 lines
19 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, Cookie, Response, Request, UploadFile, File
|
|
from fastapi.responses import JSONResponse
|
|
from sqlalchemy.orm import Session
|
|
from pathlib import Path
|
|
import aiofiles
|
|
import uuid
|
|
import os
|
|
from ..config.database import get_db
|
|
from ..services.auth_service import auth_service
|
|
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
|
|
from ..middleware.auth import get_current_user
|
|
from ..models.user import User
|
|
from ..services.audit_service import audit_service
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
|
|
# Router for every authentication-related endpoint; all paths live under /auth.
router = APIRouter(prefix='/auth', tags=['auth'])

# Rate-limit specifications ("<count>/<period>" strings) meant for the auth
# endpoints. NOTE(review): nothing visible in this module applies them via a
# decorator — confirm where (if anywhere) they are enforced.
AUTH_RATE_LIMIT: str = "5/minute"  # 5 attempts per minute per IP
PASSWORD_RESET_LIMIT: str = "3/hour"  # 3 password reset requests per hour per IP
LOGIN_RATE_LIMIT: str = "10/minute"  # 10 login attempts per minute per IP

# Cached Limiter instance; starts as None and is refreshed from app.state by
# get_limiter().
limiter: 'Limiter | None' = None
|
|
|
|
def get_limiter(request: Request) -> Limiter:
    """Return the rate limiter, refreshing the module-level cache from app state.

    When ``request.app.state`` has a ``limiter`` attribute, that object is stored
    in the module global ``limiter`` and returned. Otherwise the previously
    cached value (``None`` until something sets it) is returned unchanged.
    """
    global limiter
    app_state = request.app.state
    if not hasattr(app_state, 'limiter'):
        return limiter
    limiter = app_state.limiter
    return limiter
|
|
|
|
def get_base_url(request: 'Request') -> str:
    """Return the public base URL used to build absolute links.

    A non-empty ``SERVER_URL`` environment variable wins. Otherwise the URL is
    built from the request's ``Host`` header (falling back to
    ``localhost:8000``) with an ``http`` scheme.

    Any trailing slash is stripped so callers can append paths that begin
    with ``/`` without producing ``//``.
    """
    configured = os.getenv('SERVER_URL')
    if configured:
        return configured.rstrip('/')
    # Look the header up outside the f-string: reusing the same quote type
    # inside an f-string replacement field is a SyntaxError before Python 3.12.
    host = request.headers.get('host', 'localhost:8000')
    return f'http://{host}'
|
|
|
|
def normalize_image_url(image_url: str, base_url: str) -> str:
    """Turn a stored image path into an absolute URL.

    Empty/falsy values and URLs that already carry an ``http://`` or
    ``https://`` scheme are returned untouched. Relative paths are appended to
    *base_url*, inserting a ``/`` separator only when the path lacks one.
    """
    if not image_url or image_url.startswith(('http://', 'https://')):
        return image_url
    separator = '' if image_url.startswith('/') else '/'
    return f'{base_url}{separator}{image_url}'
|
|
|
|
@router.post('/register', status_code=status.HTTP_201_CREATED)
async def register(
    request: Request,
    register_request: RegisterRequest,
    response: Response,
    db: Session=Depends(get_db)
):
    """Create an account and sign the new user in through httpOnly cookies.

    On success:
      * the ``token`` and ``refreshToken`` values from ``auth_service.register``
        are written to the ``accessToken`` and ``refreshToken`` cookies
        (7-day lifetime);
      * a ``user_registered`` audit entry is recorded;
      * only the user payload is returned in the body.

    A ``ValueError`` from the service is audited as
    ``user_registration_failed`` and returned as a 400 JSON error.
    """
    # Rate limiting is handled by middleware; these values feed the audit log.
    ip_address = request.client.host if request.client else None
    agent = request.headers.get('User-Agent')
    req_id = getattr(request.state, 'request_id', None)

    try:
        outcome = await auth_service.register(db=db, name=register_request.name, email=register_request.email, password=register_request.password, phone=register_request.phone)
        from ..config.settings import settings

        # Shared cookie attributes: httpOnly always; HTTPS-only and
        # SameSite=strict in production, SameSite=lax in development for
        # cross-origin support.
        cookie_attrs = {
            'httponly': True,
            'secure': settings.is_production,
            'samesite': 'strict' if settings.is_production else 'lax',
            'max_age': 7 * 24 * 60 * 60,  # 7 days for registration
            'path': '/',
        }
        for cookie_name, token_field in (('accessToken', 'token'), ('refreshToken', 'refreshToken')):
            response.set_cookie(key=cookie_name, value=outcome[token_field], **cookie_attrs)

        await audit_service.log_action(
            db=db,
            action='user_registered',
            resource_type='user',
            user_id=outcome['user']['id'],
            ip_address=ip_address,
            user_agent=agent,
            request_id=req_id,
            details={'email': register_request.email, 'name': register_request.name},
            status='success'
        )

        # Tokens travel only in the httpOnly cookies, never in the body.
        return {'status': 'success', 'message': 'Registration successful', 'data': {'user': outcome['user']}}
    except ValueError as exc:
        reason = str(exc)
        await audit_service.log_action(
            db=db,
            action='user_registration_failed',
            resource_type='user',
            ip_address=ip_address,
            user_agent=agent,
            request_id=req_id,
            details={'email': register_request.email, 'name': register_request.name},
            status='failed',
            error_message=reason
        )
        return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={'status': 'error', 'message': reason})
|
|
|
|
@router.post('/login')
async def login(
    request: Request,
    login_request: LoginRequest,
    response: Response,
    db: Session=Depends(get_db)
):
    """Authenticate a user and issue auth cookies.

    Flow:
      * If ``auth_service.login`` reports ``requires_mfa``, a
        ``login_mfa_required`` audit entry is recorded and
        ``{'requires_mfa': True, 'user_id': ...}`` is returned without cookies.
      * Otherwise the ``token``/``refreshToken`` values are stored in httpOnly
        cookies, lasting 7 days with ``rememberMe`` and 1 day without. The
        login is audited, and only the user payload is returned.
      * A ``ValueError`` is audited as ``login_failed``. It is returned as JSON
        with 401 for "Invalid email or password" / "Invalid MFA token"
        messages, and 400 otherwise.
    """
    client_ip = request.client.host if request.client else None
    user_agent = request.headers.get('User-Agent')
    request_id = getattr(request.state, 'request_id', None)

    try:
        result = await auth_service.login(db=db, email=login_request.email, password=login_request.password, remember_me=login_request.rememberMe or False, mfa_token=login_request.mfaToken)

        if result.get('requires_mfa'):
            # The service already resolved the account and hands back its id.
            # Reuse it instead of re-querying the users table with a second,
            # separately maintained email normalisation.
            mfa_user_id = result.get('user_id')
            if mfa_user_id is not None:
                await audit_service.log_action(
                    db=db,
                    action='login_mfa_required',
                    resource_type='authentication',
                    user_id=mfa_user_id,
                    ip_address=client_ip,
                    user_agent=user_agent,
                    request_id=request_id,
                    details={'email': login_request.email},
                    status='success'
                )
            return {'status': 'success', 'requires_mfa': True, 'user_id': result['user_id']}

        from ..config.settings import settings

        # Cookie lifetime: 7 days with "remember me", otherwise 1 day.
        max_age = 7 * 24 * 60 * 60 if login_request.rememberMe else 1 * 24 * 60 * 60
        # HTTPS-only cookies and SameSite=strict in production; 'lax' in
        # development for cross-origin support.
        samesite_value = 'strict' if settings.is_production else 'lax'
        for cookie_name, token_key in (('accessToken', 'token'), ('refreshToken', 'refreshToken')):
            response.set_cookie(
                key=cookie_name,
                value=result[token_key],
                httponly=True,
                secure=settings.is_production,
                samesite=samesite_value,
                max_age=max_age,
                path='/'
            )

        await audit_service.log_action(
            db=db,
            action='login_success',
            resource_type='authentication',
            user_id=result['user']['id'],
            ip_address=client_ip,
            user_agent=user_agent,
            request_id=request_id,
            details={'email': login_request.email, 'remember_me': login_request.rememberMe},
            status='success'
        )

        # Tokens travel only in the httpOnly cookies, never in the body.
        return {'status': 'success', 'data': {'user': result['user']}}
    except ValueError as e:
        error_message = str(e)
        status_code = status.HTTP_401_UNAUTHORIZED if 'Invalid email or password' in error_message or 'Invalid MFA token' in error_message else status.HTTP_400_BAD_REQUEST

        await audit_service.log_action(
            db=db,
            action='login_failed',
            resource_type='authentication',
            ip_address=client_ip,
            user_agent=user_agent,
            request_id=request_id,
            details={'email': login_request.email},
            status='failed',
            error_message=error_message
        )

        return JSONResponse(status_code=status_code, content={'status': 'error', 'message': error_message})
|
|
|
|
@router.post('/refresh-token', response_model=TokenResponse)
async def refresh_token(
    request: Request,
    response: Response,
    refreshToken: str=Cookie(None),
    db: Session=Depends(get_db)
):
    """Exchange the ``refreshToken`` cookie for a new ``accessToken`` cookie.

    The new access token is set as an httpOnly cookie with a 7-day max age.
    The body carries only the (optional) user payload from the service.

    Raises:
        HTTPException: 401 when the cookie is missing or the service rejects
            the token with ``ValueError``.
    """
    if not refreshToken:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Refresh token not found')
    try:
        refreshed = await auth_service.refresh_access_token(db, refreshToken)
        from ..config.settings import settings

        in_production = settings.is_production
        response.set_cookie(
            key='accessToken',
            value=refreshed['token'],
            httponly=True,
            secure=in_production,
            # 'lax' in development for cross-origin support, 'strict' in production.
            samesite='strict' if in_production else 'lax',
            max_age=7 * 24 * 60 * 60,  # 7 days
            path='/'
        )
        return {'status': 'success', 'data': {'user': refreshed.get('user')}}
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc))
|
|
|
|
@router.post('/logout', response_model=MessageResponse)
async def logout(
    request: Request,
    response: Response,
    refreshToken: str=Cookie(None),
    current_user: User=Depends(get_current_user),
    db: Session=Depends(get_db)
):
    """Sign the current user out.

    If a ``refreshToken`` cookie is present, it is passed to
    ``auth_service.logout``. Both auth cookies are then cleared, and a
    ``logout`` audit entry is written.
    """
    ip_address = request.client.host if request.client else None
    agent = request.headers.get('User-Agent')
    req_id = getattr(request.state, 'request_id', None)

    if refreshToken:
        await auth_service.logout(db, refreshToken)

    from ..config.settings import settings

    # Delete with the same path/secure/samesite attributes used when setting.
    removal_attrs = {
        'path': '/',
        'secure': settings.is_production,
        'samesite': 'strict' if settings.is_production else 'lax',
    }
    for cookie_name in ('refreshToken', 'accessToken'):
        response.delete_cookie(key=cookie_name, **removal_attrs)

    await audit_service.log_action(
        db=db,
        action='logout',
        resource_type='authentication',
        user_id=current_user.id,
        ip_address=ip_address,
        user_agent=agent,
        request_id=req_id,
        details={'email': current_user.email},
        status='success'
    )

    return {'status': 'success', 'message': 'Logout successful'}
|
|
|
|
@router.get('/profile')
async def get_profile(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Return the authenticated user's profile from ``auth_service.get_profile``.

    Raises:
        HTTPException: 404 when the service's ``ValueError`` mentions
            "User not found", 400 for any other ``ValueError``.
    """
    try:
        profile = await auth_service.get_profile(db, current_user.id)
    except ValueError as exc:
        reason = str(exc)
        code = status.HTTP_404_NOT_FOUND if 'User not found' in reason else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=code, detail=reason)
    return {'status': 'success', 'data': {'user': profile}}
|
|
|
|
@router.put('/profile')
async def update_profile(profile_data: UpdateProfileRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Apply profile changes for the authenticated user.

    The request fields are forwarded to ``auth_service.update_profile``; the
    updated user returned by the service is echoed back.

    Raises:
        HTTPException: propagated unchanged when raised downstream; 404 when a
            ``ValueError`` message contains "not found" (case-insensitive);
            400 for other ``ValueError``s; 500 for anything else, after the
            session has been rolled back.
    """
    try:
        user = await auth_service.update_profile(
            db=db,
            user_id=current_user.id,
            full_name=profile_data.full_name,
            email=profile_data.email,
            phone_number=profile_data.phone_number,
            password=profile_data.password,
            current_password=profile_data.currentPassword,
            currency=profile_data.currency
        )
        return {'status': 'success', 'message': 'Profile updated successfully', 'data': {'user': user}}
    except HTTPException:
        # Keep intentional HTTP errors intact instead of masking them as 500s,
        # matching the MFA-init and avatar-upload handlers.
        raise
    except ValueError as e:
        error_message = str(e)
        status_code = status.HTTP_404_NOT_FOUND if 'not found' in error_message.lower() else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=status_code, detail=error_message) from e
    except Exception as e:
        # A failed flush/commit leaves the SQLAlchemy session unusable until it
        # is rolled back; do it here so the session is clean for later use.
        db.rollback()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'An error occurred: {str(e)}') from e
|
|
|
|
@router.post('/forgot-password', response_model=MessageResponse)
async def forgot_password(request: ForgotPasswordRequest, db: Session=Depends(get_db)):
    """Start the password-reset flow for ``request.email``.

    The response message is taken verbatim from
    ``auth_service.forgot_password``; no exceptions are translated here.
    """
    outcome = await auth_service.forgot_password(db, request.email)
    message = outcome['message']
    return {'status': 'success', 'message': message}
|
|
|
|
@router.post('/reset-password', response_model=MessageResponse)
async def reset_password(request: ResetPasswordRequest, db: Session=Depends(get_db)):
    """Set a new password using a reset token.

    Raises:
        HTTPException: 404 when the service's ``ValueError`` mentions
            "User not found", 400 for any other ``ValueError``.
    """
    try:
        outcome = await auth_service.reset_password(db=db, token=request.token, password=request.password)
    except ValueError as exc:
        code = status.HTTP_404_NOT_FOUND if 'User not found' in str(exc) else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=code, detail=str(exc))
    return {'status': 'success', 'message': outcome['message']}
|
|
from ..services.mfa_service import mfa_service
|
|
from ..config.settings import settings
|
|
|
|
@router.get('/mfa/init')
async def init_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Generate a new MFA secret and QR code for the current user.

    NOTE(review): the secret is returned to the client and is not persisted
    here; /mfa/enable receives it back in the request body — confirm the
    service validates it appropriately.

    Raises:
        HTTPException: 400 if MFA is already enabled; 500 wrapping any other
            unexpected error.
    """
    try:
        if current_user.mfa_enabled:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='MFA is already enabled')
        new_secret = mfa_service.generate_secret()
        display_name = getattr(settings, 'APP_NAME', 'Hotel Booking')
        qr_image = mfa_service.generate_qr_code(new_secret, current_user.email, display_name)
        payload = {'secret': new_secret, 'qr_code': qr_image}
        return {'status': 'success', 'data': payload}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error initializing MFA: {str(exc)}')
|
|
|
|
@router.post('/mfa/enable')
async def enable_mfa(request: EnableMFARequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Enable MFA for the current user using a secret and verification token.

    Returns the backup codes produced by ``mfa_service.enable_mfa``.

    Raises:
        HTTPException: propagated unchanged when raised by the service; 400
            when the service raises ``ValueError`` or reports failure; 500 for
            any other unexpected error.
    """
    try:
        success, backup_codes = mfa_service.enable_mfa(db=db, user_id=current_user.id, secret=request.secret, verification_token=request.verification_token)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error enabling MFA: {str(e)}') from e

    # The service reports a success flag alongside the codes; never tell the
    # client MFA is on when the service says it is not.
    if not success:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Failed to enable MFA')
    return {'status': 'success', 'message': 'MFA enabled successfully', 'data': {'backup_codes': backup_codes}}
|
|
|
|
@router.post('/mfa/disable')
async def disable_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Disable MFA for the authenticated user.

    NOTE(review): no password or MFA code is required beyond the session
    itself — confirm this is intended.

    Raises:
        HTTPException: 400 for ``ValueError`` from the service, 500 wrapping
            any other error.
    """
    try:
        mfa_service.disable_mfa(db=db, user_id=current_user.id)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error disabling MFA: {str(exc)}')
    return {'status': 'success', 'message': 'MFA disabled successfully'}
|
|
|
|
@router.get('/mfa/status', response_model=MFAStatusResponse)
async def get_mfa_status(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Return the MFA status object produced by ``mfa_service.get_mfa_status``.

    Raises:
        HTTPException: 400 for ``ValueError`` from the service, 500 wrapping
            any other error.
    """
    try:
        mfa_state = mfa_service.get_mfa_status(db=db, user_id=current_user.id)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error getting MFA status: {str(exc)}')
    return mfa_state
|
|
|
|
@router.post('/mfa/regenerate-backup-codes')
async def regenerate_backup_codes(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Issue a fresh set of MFA backup codes for the current user.

    Raises:
        HTTPException: 400 for ``ValueError`` from the service, 500 wrapping
            any other error.
    """
    try:
        fresh_codes = mfa_service.regenerate_backup_codes(db=db, user_id=current_user.id)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc))
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error regenerating backup codes: {str(exc)}')
    return {'status': 'success', 'message': 'Backup codes regenerated successfully', 'data': {'backup_codes': fresh_codes}}
|
|
|
|
@router.post('/avatar/upload')
async def upload_avatar(request: Request, image: UploadFile=File(...), current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
    """Replace the current user's avatar with an uploaded image.

    Steps:
      1. Validate the upload with ``validate_uploaded_image`` (2 MB cap).
      2. Write it under ``uploads/avatars`` with a server-generated filename.
      3. Store ``/uploads/avatars/<file>`` on ``current_user.avatar`` and commit.
      4. Remove the previous avatar file, but only after the commit succeeds
         and only if it lives inside the avatar directory.

    Raises:
        HTTPException: re-raised unchanged (e.g. from validation). Any other
            error becomes a 500 after the session is rolled back; the newly
            written file is removed if the DB update never committed.
    """
    # Suffixes we are willing to store and serve back. The client filename is
    # untrusted, so its suffix is kept only when it is one of these; otherwise
    # one is derived from the content type (default '.png').
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
    suffix_by_content_type = {
        'image/jpeg': '.jpg',
        'image/png': '.png',
        'image/gif': '.gif',
        'image/webp': '.webp',
    }
    project_root = Path(__file__).parent.parent.parent
    upload_dir = project_root / 'uploads' / 'avatars'
    new_file = None
    committed = False
    try:
        # Comprehensive validation (MIME type, size, magic bytes, integrity).
        from ..utils.file_validation import validate_uploaded_image
        max_avatar_size = 2 * 1024 * 1024  # 2MB for avatars
        content = await validate_uploaded_image(image, max_avatar_size)

        upload_dir.mkdir(parents=True, exist_ok=True)
        previous_avatar = current_user.avatar

        ext = Path(image.filename or '').suffix.lower()
        if ext not in allowed_suffixes:
            ext = suffix_by_content_type.get(image.content_type, '.png')
        filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
        new_file = upload_dir / filename
        async with aiofiles.open(new_file, 'wb') as f:
            await f.write(content)

        image_url = f'/uploads/avatars/{filename}'
        current_user.avatar = image_url
        db.commit()
        committed = True
        db.refresh(current_user)

        # Delete the old file only now that the new avatar is committed, so a
        # failed upload never leaves the user pointing at a removed file. The
        # containment check stops a stored value from unlinking arbitrary paths.
        if previous_avatar:
            try:
                old_path = (project_root / previous_avatar.lstrip('/')).resolve()
                if upload_dir.resolve() in old_path.parents and old_path.is_file():
                    old_path.unlink()
            except (OSError, RuntimeError):
                pass  # best effort: a stale avatar file is harmless

        base_url = get_base_url(request)
        full_url = normalize_image_url(image_url, base_url)
        return {'success': True, 'status': 'success', 'message': 'Avatar uploaded successfully', 'data': {'avatar_url': image_url, 'full_url': full_url, 'user': {'id': current_user.id, 'name': current_user.full_name, 'email': current_user.email, 'phone': current_user.phone, 'avatar': image_url, 'role': current_user.role.name if current_user.role else 'customer'}}}
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        # Don't leave an orphaned file if the DB update never landed.
        if new_file is not None and not committed:
            try:
                new_file.unlink(missing_ok=True)
            except OSError:
                pass
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error uploading avatar: {str(e)}') from e