updates
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Cookie, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Cookie, Response, Request, UploadFile, File
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
import aiofiles
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from ..config.database import get_db
|
||||
from ..services.auth_service import auth_service
|
||||
@@ -12,7 +16,11 @@ from ..schemas.auth import (
|
||||
ResetPasswordRequest,
|
||||
AuthResponse,
|
||||
TokenResponse,
|
||||
MessageResponse
|
||||
MessageResponse,
|
||||
MFAInitResponse,
|
||||
EnableMFARequest,
|
||||
VerifyMFARequest,
|
||||
MFAStatusResponse
|
||||
)
|
||||
from ..middleware.auth import get_current_user
|
||||
from ..models.user import User
|
||||
@@ -20,6 +28,22 @@ from ..models.user import User
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def get_base_url(request: Request) -> str:
|
||||
"""Get base URL for image normalization"""
|
||||
return os.getenv("SERVER_URL") or f"http://{request.headers.get('host', 'localhost:8000')}"
|
||||
|
||||
|
||||
def normalize_image_url(image_url: str, base_url: str) -> str:
|
||||
"""Normalize image URL to absolute URL"""
|
||||
if not image_url:
|
||||
return image_url
|
||||
if image_url.startswith('http://') or image_url.startswith('https://'):
|
||||
return image_url
|
||||
if image_url.startswith('/'):
|
||||
return f"{base_url}{image_url}"
|
||||
return f"{base_url}/{image_url}"
|
||||
|
||||
|
||||
@router.post("/register", status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
request: RegisterRequest,
|
||||
@@ -79,9 +103,18 @@ async def login(
|
||||
db=db,
|
||||
email=request.email,
|
||||
password=request.password,
|
||||
remember_me=request.rememberMe or False
|
||||
remember_me=request.rememberMe or False,
|
||||
mfa_token=request.mfaToken
|
||||
)
|
||||
|
||||
# Check if MFA is required
|
||||
if result.get("requires_mfa"):
|
||||
return {
|
||||
"status": "success",
|
||||
"requires_mfa": True,
|
||||
"user_id": result["user_id"]
|
||||
}
|
||||
|
||||
# Set refresh token as HttpOnly cookie
|
||||
max_age = 7 * 24 * 60 * 60 if request.rememberMe else 1 * 24 * 60 * 60
|
||||
response.set_cookie(
|
||||
@@ -104,7 +137,7 @@ async def login(
|
||||
}
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
status_code = status.HTTP_401_UNAUTHORIZED if "Invalid email or password" in error_message else status.HTTP_400_BAD_REQUEST
|
||||
status_code = status.HTTP_401_UNAUTHORIZED if "Invalid email or password" in error_message or "Invalid MFA token" in error_message else status.HTTP_400_BAD_REQUEST
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
@@ -260,3 +293,229 @@ async def reset_password(
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
# MFA Routes
|
||||
from ..services.mfa_service import mfa_service
|
||||
from ..config.settings import settings
|
||||
|
||||
|
||||
@router.get("/mfa/init")
|
||||
async def init_mfa(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Initialize MFA setup - generate secret and QR code"""
|
||||
try:
|
||||
if current_user.mfa_enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MFA is already enabled"
|
||||
)
|
||||
|
||||
secret = mfa_service.generate_secret()
|
||||
app_name = getattr(settings, 'APP_NAME', 'Hotel Booking')
|
||||
qr_code = mfa_service.generate_qr_code(secret, current_user.email, app_name)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"data": {
|
||||
"secret": secret,
|
||||
"qr_code": qr_code
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error initializing MFA: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/enable")
|
||||
async def enable_mfa(
|
||||
request: EnableMFARequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Enable MFA after verifying token"""
|
||||
try:
|
||||
success, backup_codes = mfa_service.enable_mfa(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
secret=request.secret,
|
||||
verification_token=request.verification_token
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "MFA enabled successfully",
|
||||
"data": {
|
||||
"backup_codes": backup_codes
|
||||
}
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error enabling MFA: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/disable")
|
||||
async def disable_mfa(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Disable MFA"""
|
||||
try:
|
||||
mfa_service.disable_mfa(db=db, user_id=current_user.id)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "MFA disabled successfully"
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error disabling MFA: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mfa/status", response_model=MFAStatusResponse)
|
||||
async def get_mfa_status(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get MFA status for current user"""
|
||||
try:
|
||||
status_data = mfa_service.get_mfa_status(db=db, user_id=current_user.id)
|
||||
return status_data
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting MFA status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/regenerate-backup-codes")
|
||||
async def regenerate_backup_codes(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Regenerate backup codes for MFA"""
|
||||
try:
|
||||
backup_codes = mfa_service.regenerate_backup_codes(db=db, user_id=current_user.id)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Backup codes regenerated successfully",
|
||||
"data": {
|
||||
"backup_codes": backup_codes
|
||||
}
|
||||
}
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error regenerating backup codes: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/avatar/upload")
|
||||
async def upload_avatar(
|
||||
request: Request,
|
||||
image: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Upload user avatar"""
|
||||
try:
|
||||
# Validate file type
|
||||
if not image.content_type or not image.content_type.startswith('image/'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="File must be an image"
|
||||
)
|
||||
|
||||
# Validate file size (max 2MB)
|
||||
content = await image.read()
|
||||
if len(content) > 2 * 1024 * 1024: # 2MB
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Avatar file size must be less than 2MB"
|
||||
)
|
||||
|
||||
# Create uploads directory
|
||||
upload_dir = Path(__file__).parent.parent.parent / "uploads" / "avatars"
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Delete old avatar if exists
|
||||
if current_user.avatar:
|
||||
old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
|
||||
if old_avatar_path.exists() and old_avatar_path.is_file():
|
||||
try:
|
||||
old_avatar_path.unlink()
|
||||
except Exception:
|
||||
pass # Ignore deletion errors
|
||||
|
||||
# Generate filename
|
||||
ext = Path(image.filename).suffix or '.png'
|
||||
filename = f"avatar-{current_user.id}-{uuid.uuid4()}{ext}"
|
||||
file_path = upload_dir / filename
|
||||
|
||||
# Save file
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
|
||||
# Update user avatar
|
||||
image_url = f"/uploads/avatars/{filename}"
|
||||
current_user.avatar = image_url
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
|
||||
# Return the image URL
|
||||
base_url = get_base_url(request)
|
||||
full_url = normalize_image_url(image_url, base_url)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Avatar uploaded successfully",
|
||||
"data": {
|
||||
"avatar_url": image_url,
|
||||
"full_url": full_url,
|
||||
"user": {
|
||||
"id": current_user.id,
|
||||
"name": current_user.full_name,
|
||||
"email": current_user.email,
|
||||
"phone": current_user.phone,
|
||||
"avatar": image_url,
|
||||
"role": current_user.role.name if current_user.role else "customer"
|
||||
}
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error uploading avatar: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user