update
This commit is contained in:
0
Backend/src/auth/__init__.py
Normal file
0
Backend/src/auth/__init__.py
Normal file
0
Backend/src/auth/models/__init__.py
Normal file
0
Backend/src/auth/models/__init__.py
Normal file
15
Backend/src/auth/models/password_reset_token.py
Normal file
15
Backend/src/auth/models/password_reset_token.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class PasswordResetToken(Base):
|
||||
__tablename__ = 'password_reset_tokens'
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
token = Column(String(255), unique=True, nullable=False, index=True)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
used = Column(Boolean, default=False, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
user = relationship('User')
|
||||
13
Backend/src/auth/models/refresh_token.py
Normal file
13
Backend/src/auth/models/refresh_token.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = 'refresh_tokens'
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
token = Column(String(500), unique=True, nullable=False, index=True)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
user = relationship('User', back_populates='refresh_tokens')
|
||||
13
Backend/src/auth/models/role.py
Normal file
13
Backend/src/auth/models/role.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class Role(Base):
|
||||
__tablename__ = 'roles'
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
name = Column(String(50), unique=True, nullable=False, index=True)
|
||||
description = Column(String(255), nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
users = relationship('User', back_populates='role')
|
||||
53
Backend/src/auth/models/user.py
Normal file
53
Backend/src/auth/models/user.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, Text, ForeignKey, DateTime, Numeric
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from ...shared.config.database import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
role_id = Column(Integer, ForeignKey('roles.id'), nullable=False)
|
||||
email = Column(String(100), unique=True, nullable=False, index=True)
|
||||
password = Column(String(255), nullable=False)
|
||||
full_name = Column(String(100), nullable=False)
|
||||
phone = Column(String(20), nullable=True)
|
||||
address = Column(Text, nullable=True)
|
||||
avatar = Column(String(255), nullable=True)
|
||||
currency = Column(String(3), nullable=False, default='VND')
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
mfa_enabled = Column(Boolean, nullable=False, default=False)
|
||||
mfa_secret = Column(String(255), nullable=True)
|
||||
mfa_backup_codes = Column(Text, nullable=True)
|
||||
|
||||
# Account lockout fields
|
||||
failed_login_attempts = Column(Integer, nullable=False, default=0)
|
||||
locked_until = Column(DateTime, nullable=True)
|
||||
|
||||
# Guest Profile & CRM fields
|
||||
is_vip = Column(Boolean, nullable=False, default=False)
|
||||
lifetime_value = Column(Numeric(10, 2), nullable=True, default=0) # Total revenue from guest
|
||||
satisfaction_score = Column(Numeric(3, 2), nullable=True) # Average satisfaction score (0-5)
|
||||
last_visit_date = Column(DateTime, nullable=True) # Last booking check-in date
|
||||
total_visits = Column(Integer, nullable=False, default=0) # Total number of bookings
|
||||
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
role = relationship('Role', back_populates='users')
|
||||
bookings = relationship('Booking', back_populates='user')
|
||||
refresh_tokens = relationship('RefreshToken', back_populates='user', cascade='all, delete-orphan')
|
||||
checkins_processed = relationship('CheckInCheckOut', foreign_keys='CheckInCheckOut.checkin_by', back_populates='checked_in_by')
|
||||
checkouts_processed = relationship('CheckInCheckOut', foreign_keys='CheckInCheckOut.checkout_by', back_populates='checked_out_by')
|
||||
reviews = relationship('Review', back_populates='user')
|
||||
favorites = relationship('Favorite', back_populates='user', cascade='all, delete-orphan')
|
||||
service_bookings = relationship('ServiceBooking', back_populates='user')
|
||||
visitor_chats = relationship('Chat', foreign_keys='Chat.visitor_id', back_populates='visitor')
|
||||
staff_chats = relationship('Chat', foreign_keys='Chat.staff_id', back_populates='staff')
|
||||
loyalty = relationship('UserLoyalty', back_populates='user', uselist=False, cascade='all, delete-orphan')
|
||||
referrals = relationship('Referral', foreign_keys='Referral.referred_user_id', back_populates='referred_user')
|
||||
|
||||
# Guest Profile & CRM relationships
|
||||
guest_preferences = relationship('GuestPreference', back_populates='user', uselist=False, cascade='all, delete-orphan')
|
||||
guest_notes = relationship('GuestNote', foreign_keys='GuestNote.user_id', back_populates='user', cascade='all, delete-orphan')
|
||||
guest_tags = relationship('GuestTag', secondary='guest_tag_associations', back_populates='users')
|
||||
guest_communications = relationship('GuestCommunication', foreign_keys='GuestCommunication.user_id', back_populates='user', cascade='all, delete-orphan')
|
||||
guest_segments = relationship('GuestSegment', secondary='guest_segment_associations', back_populates='users')
|
||||
0
Backend/src/auth/routes/__init__.py
Normal file
0
Backend/src/auth/routes/__init__.py
Normal file
413
Backend/src/auth/routes/auth_routes.py
Normal file
413
Backend/src/auth/routes/auth_routes.py
Normal file
@@ -0,0 +1,413 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Cookie, Response, Request, UploadFile, File
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
import aiofiles
|
||||
import uuid
|
||||
import os
|
||||
from ...shared.config.database import get_db
|
||||
from ..services.auth_service import auth_service
|
||||
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
|
||||
from ...security.middleware.auth import get_current_user
|
||||
from ..models.user import User
|
||||
from ...analytics.services.audit_service import audit_service
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
router = APIRouter(prefix='/auth', tags=['auth'])
|
||||
|
||||
# Stricter rate limits for authentication endpoints
|
||||
AUTH_RATE_LIMIT = "5/minute" # 5 attempts per minute per IP
|
||||
PASSWORD_RESET_LIMIT = "3/hour" # 3 password reset requests per hour per IP
|
||||
LOGIN_RATE_LIMIT = "10/minute" # 10 login attempts per minute per IP
|
||||
|
||||
# Initialize limiter - will be set from app state
|
||||
limiter = None
|
||||
|
||||
def get_limiter(request: Request) -> Limiter:
|
||||
"""Get limiter instance from app state."""
|
||||
global limiter
|
||||
if hasattr(request.app.state, 'limiter'):
|
||||
limiter = request.app.state.limiter
|
||||
return limiter
|
||||
|
||||
def get_base_url(request: Request) -> str:
|
||||
return os.getenv('SERVER_URL') or f'http://{request.headers.get('host', 'localhost:8000')}'
|
||||
|
||||
def normalize_image_url(image_url: str, base_url: str) -> str:
|
||||
if not image_url:
|
||||
return image_url
|
||||
if image_url.startswith('http://') or image_url.startswith('https://'):
|
||||
return image_url
|
||||
if image_url.startswith('/'):
|
||||
return f'{base_url}{image_url}'
|
||||
return f'{base_url}/{image_url}'
|
||||
|
||||
@router.post('/register', status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
request: Request,
|
||||
register_request: RegisterRequest,
|
||||
response: Response,
|
||||
db: Session=Depends(get_db)
|
||||
):
|
||||
# Rate limiting is handled by middleware, but we can add additional checks here if needed
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
|
||||
try:
|
||||
result = await auth_service.register(db=db, name=register_request.name, email=register_request.email, password=register_request.password, phone=register_request.phone)
|
||||
from ...shared.config.settings import settings
|
||||
max_age = 7 * 24 * 60 * 60 # 7 days for registration
|
||||
# Use secure cookies in production (HTTPS required)
|
||||
# Set access token in httpOnly cookie for security
|
||||
# Use 'lax' in development for cross-origin support, 'strict' in production
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
response.set_cookie(
|
||||
key='accessToken',
|
||||
value=result['token'],
|
||||
httponly=True,
|
||||
secure=settings.is_production, # Secure flag enabled in production
|
||||
samesite=samesite_value,
|
||||
max_age=max_age,
|
||||
path='/'
|
||||
)
|
||||
# Set refresh token in httpOnly cookie
|
||||
response.set_cookie(
|
||||
key='refreshToken',
|
||||
value=result['refreshToken'],
|
||||
httponly=True,
|
||||
secure=settings.is_production, # Secure flag enabled in production
|
||||
samesite=samesite_value,
|
||||
max_age=max_age,
|
||||
path='/'
|
||||
)
|
||||
|
||||
# Log successful registration
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='user_registered',
|
||||
resource_type='user',
|
||||
user_id=result['user']['id'],
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': register_request.email, 'name': register_request.name},
|
||||
status='success'
|
||||
)
|
||||
|
||||
# Return user data but NOT the token (it's in httpOnly cookie now)
|
||||
return {'status': 'success', 'message': 'Registration successful', 'data': {'user': result['user']}}
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
# Log failed registration attempt
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='user_registration_failed',
|
||||
resource_type='user',
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': register_request.email, 'name': register_request.name},
|
||||
status='failed',
|
||||
error_message=error_message
|
||||
)
|
||||
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={'status': 'error', 'message': error_message})
|
||||
|
||||
@router.post('/login')
|
||||
async def login(
|
||||
request: Request,
|
||||
login_request: LoginRequest,
|
||||
response: Response,
|
||||
db: Session=Depends(get_db)
|
||||
):
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
|
||||
try:
|
||||
result = await auth_service.login(db=db, email=login_request.email, password=login_request.password, remember_me=login_request.rememberMe or False, mfa_token=login_request.mfaToken)
|
||||
if result.get('requires_mfa'):
|
||||
# Log MFA required
|
||||
user = db.query(User).filter(User.email == login_request.email.lower().strip()).first()
|
||||
if user:
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='login_mfa_required',
|
||||
resource_type='authentication',
|
||||
user_id=user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': login_request.email},
|
||||
status='success'
|
||||
)
|
||||
return {'status': 'success', 'requires_mfa': True, 'user_id': result['user_id']}
|
||||
from ...shared.config.settings import settings
|
||||
max_age = 7 * 24 * 60 * 60 if login_request.rememberMe else 1 * 24 * 60 * 60
|
||||
# Use secure cookies in production (HTTPS required)
|
||||
# Set access token in httpOnly cookie for security
|
||||
# Use 'lax' in development for cross-origin support, 'strict' in production
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
response.set_cookie(
|
||||
key='accessToken',
|
||||
value=result['token'],
|
||||
httponly=True,
|
||||
secure=settings.is_production, # Secure flag enabled in production
|
||||
samesite=samesite_value,
|
||||
max_age=max_age,
|
||||
path='/'
|
||||
)
|
||||
# Set refresh token in httpOnly cookie
|
||||
response.set_cookie(
|
||||
key='refreshToken',
|
||||
value=result['refreshToken'],
|
||||
httponly=True,
|
||||
secure=settings.is_production, # Secure flag enabled in production
|
||||
samesite=samesite_value,
|
||||
max_age=max_age,
|
||||
path='/'
|
||||
)
|
||||
|
||||
# Log successful login
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='login_success',
|
||||
resource_type='authentication',
|
||||
user_id=result['user']['id'],
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': login_request.email, 'remember_me': login_request.rememberMe},
|
||||
status='success'
|
||||
)
|
||||
|
||||
# Return user data but NOT the token (it's in httpOnly cookie now)
|
||||
return {'status': 'success', 'data': {'user': result['user']}}
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
status_code = status.HTTP_401_UNAUTHORIZED if 'Invalid email or password' in error_message or 'Invalid MFA token' in error_message else status.HTTP_400_BAD_REQUEST
|
||||
|
||||
# Log failed login attempt
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='login_failed',
|
||||
resource_type='authentication',
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': login_request.email},
|
||||
status='failed',
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=status_code, content={'status': 'error', 'message': error_message})
|
||||
|
||||
@router.post('/refresh-token', response_model=TokenResponse)
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
refreshToken: str=Cookie(None),
|
||||
db: Session=Depends(get_db)
|
||||
):
|
||||
if not refreshToken:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Refresh token not found')
|
||||
try:
|
||||
result = await auth_service.refresh_access_token(db, refreshToken)
|
||||
from ...shared.config.settings import settings
|
||||
# Set new access token in httpOnly cookie
|
||||
# Use 'lax' in development for cross-origin support, 'strict' in production
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
max_age = 7 * 24 * 60 * 60 # 7 days
|
||||
response.set_cookie(
|
||||
key='accessToken',
|
||||
value=result['token'],
|
||||
httponly=True,
|
||||
secure=settings.is_production,
|
||||
samesite=samesite_value,
|
||||
max_age=max_age,
|
||||
path='/'
|
||||
)
|
||||
# Return user data but NOT the token (it's in httpOnly cookie now)
|
||||
return {'status': 'success', 'data': {'user': result.get('user')}}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
|
||||
|
||||
@router.post('/logout', response_model=MessageResponse)
|
||||
async def logout(
|
||||
request: Request,
|
||||
response: Response,
|
||||
refreshToken: str=Cookie(None),
|
||||
current_user: User=Depends(get_current_user),
|
||||
db: Session=Depends(get_db)
|
||||
):
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
|
||||
if refreshToken:
|
||||
await auth_service.logout(db, refreshToken)
|
||||
|
||||
# Delete both access and refresh token cookies
|
||||
from ...shared.config.settings import settings
|
||||
# Use 'lax' in development for cross-origin support, 'strict' in production
|
||||
samesite_value = 'strict' if settings.is_production else 'lax'
|
||||
response.delete_cookie(key='refreshToken', path='/', secure=settings.is_production, samesite=samesite_value)
|
||||
response.delete_cookie(key='accessToken', path='/', secure=settings.is_production, samesite=samesite_value)
|
||||
|
||||
# Log logout
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='logout',
|
||||
resource_type='authentication',
|
||||
user_id=current_user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={'email': current_user.email},
|
||||
status='success'
|
||||
)
|
||||
|
||||
return {'status': 'success', 'message': 'Logout successful'}
|
||||
|
||||
@router.get('/profile')
|
||||
async def get_profile(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
user = await auth_service.get_profile(db, current_user.id)
|
||||
return {'status': 'success', 'data': {'user': user}}
|
||||
except ValueError as e:
|
||||
if 'User not found' in str(e):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
@router.put('/profile')
|
||||
async def update_profile(profile_data: UpdateProfileRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
user = await auth_service.update_profile(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
full_name=profile_data.full_name,
|
||||
email=profile_data.email,
|
||||
phone_number=profile_data.phone_number,
|
||||
password=profile_data.password,
|
||||
current_password=profile_data.currentPassword,
|
||||
currency=profile_data.currency
|
||||
)
|
||||
return {'status': 'success', 'message': 'Profile updated successfully', 'data': {'user': user}}
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if 'not found' in error_message.lower():
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
raise HTTPException(status_code=status_code, detail=error_message)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'An error occurred: {str(e)}')
|
||||
|
||||
@router.post('/forgot-password', response_model=MessageResponse)
|
||||
async def forgot_password(request: ForgotPasswordRequest, db: Session=Depends(get_db)):
|
||||
result = await auth_service.forgot_password(db, request.email)
|
||||
return {'status': 'success', 'message': result['message']}
|
||||
|
||||
@router.post('/reset-password', response_model=MessageResponse)
|
||||
async def reset_password(request: ResetPasswordRequest, db: Session=Depends(get_db)):
|
||||
try:
|
||||
result = await auth_service.reset_password(db=db, token=request.token, password=request.password)
|
||||
return {'status': 'success', 'message': result['message']}
|
||||
except ValueError as e:
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if 'User not found' in str(e):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
raise HTTPException(status_code=status_code, detail=str(e))
|
||||
from ..services.mfa_service import mfa_service
|
||||
from ...shared.config.settings import settings
|
||||
|
||||
@router.get('/mfa/init')
|
||||
async def init_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
if current_user.mfa_enabled:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='MFA is already enabled')
|
||||
secret = mfa_service.generate_secret()
|
||||
app_name = getattr(settings, 'APP_NAME', 'Hotel Booking')
|
||||
qr_code = mfa_service.generate_qr_code(secret, current_user.email, app_name)
|
||||
return {'status': 'success', 'data': {'secret': secret, 'qr_code': qr_code}}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error initializing MFA: {str(e)}')
|
||||
|
||||
@router.post('/mfa/enable')
|
||||
async def enable_mfa(request: EnableMFARequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
success, backup_codes = mfa_service.enable_mfa(db=db, user_id=current_user.id, secret=request.secret, verification_token=request.verification_token)
|
||||
return {'status': 'success', 'message': 'MFA enabled successfully', 'data': {'backup_codes': backup_codes}}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error enabling MFA: {str(e)}')
|
||||
|
||||
@router.post('/mfa/disable')
|
||||
async def disable_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
mfa_service.disable_mfa(db=db, user_id=current_user.id)
|
||||
return {'status': 'success', 'message': 'MFA disabled successfully'}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error disabling MFA: {str(e)}')
|
||||
|
||||
@router.get('/mfa/status', response_model=MFAStatusResponse)
|
||||
async def get_mfa_status(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
status_data = mfa_service.get_mfa_status(db=db, user_id=current_user.id)
|
||||
return status_data
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error getting MFA status: {str(e)}')
|
||||
|
||||
@router.post('/mfa/regenerate-backup-codes')
|
||||
async def regenerate_backup_codes(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
backup_codes = mfa_service.regenerate_backup_codes(db=db, user_id=current_user.id)
|
||||
return {'status': 'success', 'message': 'Backup codes regenerated successfully', 'data': {'backup_codes': backup_codes}}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error regenerating backup codes: {str(e)}')
|
||||
|
||||
@router.post('/avatar/upload')
|
||||
async def upload_avatar(request: Request, image: UploadFile=File(...), current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
try:
|
||||
# Use comprehensive file validation (magic bytes + size)
|
||||
from ...shared.utils.file_validation import validate_uploaded_image
|
||||
max_avatar_size = 2 * 1024 * 1024 # 2MB for avatars
|
||||
|
||||
# Validate file completely (MIME type, size, magic bytes, integrity)
|
||||
content = await validate_uploaded_image(image, max_avatar_size)
|
||||
upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'avatars'
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
if current_user.avatar:
|
||||
old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
|
||||
if old_avatar_path.exists() and old_avatar_path.is_file():
|
||||
try:
|
||||
old_avatar_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
ext = Path(image.filename).suffix or '.png'
|
||||
filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
|
||||
file_path = upload_dir / filename
|
||||
async with aiofiles.open(file_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
image_url = f'/uploads/avatars/{filename}'
|
||||
current_user.avatar = image_url
|
||||
db.commit()
|
||||
db.refresh(current_user)
|
||||
base_url = get_base_url(request)
|
||||
full_url = normalize_image_url(image_url, base_url)
|
||||
return {'success': True, 'status': 'success', 'message': 'Avatar uploaded successfully', 'data': {'avatar_url': image_url, 'full_url': full_url, 'user': {'id': current_user.id, 'name': current_user.full_name, 'email': current_user.email, 'phone': current_user.phone, 'avatar': image_url, 'role': current_user.role.name if current_user.role else 'customer'}}}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error uploading avatar: {str(e)}')
|
||||
164
Backend/src/auth/routes/user_routes.py
Normal file
164
Backend/src/auth/routes/user_routes.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
from typing import Optional
|
||||
import bcrypt
|
||||
from ...shared.config.database import get_db
|
||||
from ...security.middleware.auth import get_current_user, authorize_roles
|
||||
from ..models.user import User
|
||||
from ..models.role import Role
|
||||
from ...bookings.models.booking import Booking, BookingStatus
|
||||
from ...shared.utils.role_helpers import can_manage_users
|
||||
from ...shared.utils.response_helpers import success_response
|
||||
from ...analytics.services.audit_service import audit_service
|
||||
from ..schemas.user import CreateUserRequest, UpdateUserRequest
|
||||
router = APIRouter(prefix='/users', tags=['users'])
|
||||
|
||||
@router.get('/', dependencies=[Depends(authorize_roles('admin'))])
|
||||
async def get_users(search: Optional[str]=Query(None), role: Optional[str]=Query(None), status_filter: Optional[str]=Query(None, alias='status'), page: int=Query(1, ge=1), limit: int=Query(10, ge=1, le=100), current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
|
||||
try:
|
||||
query = db.query(User)
|
||||
if search:
|
||||
query = query.filter(or_(User.full_name.like(f'%{search}%'), User.email.like(f'%{search}%'), User.phone.like(f'%{search}%')))
|
||||
if role:
|
||||
role_map = {'admin': 1, 'staff': 2, 'customer': 3, 'accountant': 4}
|
||||
if role in role_map:
|
||||
query = query.filter(User.role_id == role_map[role])
|
||||
if status_filter:
|
||||
is_active = status_filter == 'active'
|
||||
query = query.filter(User.is_active == is_active)
|
||||
total = query.count()
|
||||
offset = (page - 1) * limit
|
||||
users = query.order_by(User.created_at.desc()).offset(offset).limit(limit).all()
|
||||
result = []
|
||||
for user in users:
|
||||
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'address': user.address, 'avatar': user.avatar, 'currency': getattr(user, 'currency', 'VND'), 'is_active': user.is_active, 'status': 'active' if user.is_active else 'inactive', 'role_id': user.role_id, 'role': user.role.name if user.role else 'customer', 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None}
|
||||
result.append(user_dict)
|
||||
return success_response(data={'users': result, 'pagination': {'total': total, 'page': page, 'limit': limit, 'totalPages': (total + limit - 1) // limit}})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get('/{id}', dependencies=[Depends(authorize_roles('admin'))])
|
||||
async def get_user_by_id(id: int, current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
|
||||
try:
|
||||
user = db.query(User).filter(User.id == id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail='User not found')
|
||||
bookings = db.query(Booking).filter(Booking.user_id == id).order_by(Booking.created_at.desc()).limit(5).all()
|
||||
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'address': user.address, 'avatar': user.avatar, 'currency': getattr(user, 'currency', 'VND'), 'is_active': user.is_active, 'status': 'active' if user.is_active else 'inactive', 'role_id': user.role_id, 'role': user.role.name if user.role else 'customer', 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None, 'bookings': [{'id': b.id, 'booking_number': b.booking_number, 'status': b.status.value if isinstance(b.status, BookingStatus) else b.status, 'created_at': b.created_at.isoformat() if b.created_at else None} for b in bookings]}
|
||||
return success_response(data={'user': user_dict})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post('/', dependencies=[Depends(authorize_roles('admin'))])
|
||||
async def create_user(
|
||||
request: Request,
|
||||
user_data: CreateUserRequest,
|
||||
current_user: User=Depends(authorize_roles('admin')),
|
||||
db: Session=Depends(get_db)
|
||||
):
|
||||
"""Create a user with validated input using Pydantic schema."""
|
||||
client_ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('User-Agent')
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
|
||||
try:
|
||||
email = user_data.email
|
||||
password = user_data.password
|
||||
full_name = user_data.full_name
|
||||
phone_number = user_data.phone_number
|
||||
role_id = user_data.role_id or 3 # Default to customer role
|
||||
existing = db.query(User).filter(User.email == email).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail='Email already exists')
|
||||
password_bytes = password.encode('utf-8')
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password_bytes, salt).decode('utf-8')
|
||||
user = User(email=email, password=hashed_password, full_name=full_name, phone=phone_number, role_id=role_id, is_active=True)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
# Log admin action - user creation
|
||||
await audit_service.log_action(
|
||||
db=db,
|
||||
action='admin_user_created',
|
||||
resource_type='user',
|
||||
user_id=current_user.id,
|
||||
resource_id=user.id,
|
||||
ip_address=client_ip,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
details={
|
||||
'created_user_email': user.email,
|
||||
'created_user_name': user.full_name,
|
||||
'role_id': user.role_id,
|
||||
'is_active': user.is_active
|
||||
},
|
||||
status='success'
|
||||
)
|
||||
|
||||
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'currency': getattr(user, 'currency', 'VND'), 'role_id': user.role_id, 'is_active': user.is_active}
|
||||
return success_response(data={'user': user_dict}, message='User created successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.put('/{id}')
|
||||
async def update_user(id: int, user_data: UpdateUserRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
|
||||
"""Update a user with validated input using Pydantic schema."""
|
||||
try:
|
||||
if not can_manage_users(current_user, db) and current_user.id != id:
|
||||
raise HTTPException(status_code=403, detail='Forbidden')
|
||||
user = db.query(User).filter(User.id == id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail='User not found')
|
||||
|
||||
# Check email uniqueness if being updated
|
||||
if user_data.email and user_data.email != user.email:
|
||||
existing = db.query(User).filter(User.email == user_data.email).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail='Email already exists')
|
||||
|
||||
# Update fields if provided
|
||||
if user_data.full_name is not None:
|
||||
user.full_name = user_data.full_name
|
||||
if user_data.email is not None and can_manage_users(current_user, db):
|
||||
user.email = user_data.email
|
||||
if user_data.phone_number is not None:
|
||||
user.phone = user_data.phone_number
|
||||
if user_data.role_id is not None and can_manage_users(current_user, db):
|
||||
user.role_id = user_data.role_id
|
||||
if user_data.is_active is not None and can_manage_users(current_user, db):
|
||||
user.is_active = user_data.is_active
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'currency': getattr(user, 'currency', 'VND'), 'role_id': user.role_id, 'is_active': user.is_active}
|
||||
return success_response(data={'user': user_dict}, message='User updated successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete('/{id}', dependencies=[Depends(authorize_roles('admin'))])
|
||||
async def delete_user(id: int, current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
|
||||
try:
|
||||
user = db.query(User).filter(User.id == id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail='User not found')
|
||||
active_bookings = db.query(Booking).filter(Booking.user_id == id, Booking.status.in_([BookingStatus.pending, BookingStatus.confirmed, BookingStatus.checked_in])).count()
|
||||
if active_bookings > 0:
|
||||
raise HTTPException(status_code=400, detail='Cannot delete user with active bookings')
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return success_response(message='User deleted successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
0
Backend/src/auth/schemas/__init__.py
Normal file
0
Backend/src/auth/schemas/__init__.py
Normal file
135
Backend/src/auth/schemas/auth.py
Normal file
135
Backend/src/auth/schemas/auth.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from pydantic import BaseModel, EmailStr, Field, validator
|
||||
from typing import Optional
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
name: str = Field(..., min_length=2, max_length=50)
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
phone: Optional[str] = None
|
||||
|
||||
@validator('password')
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any((c.isupper() for c in v)):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not any((c.islower() for c in v)):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not any((c.isdigit() for c in v)):
|
||||
raise ValueError('Password must contain at least one number')
|
||||
return v
|
||||
|
||||
@validator('phone')
|
||||
def validate_phone(cls, v):
|
||||
if v:
|
||||
cleaned = ''.join(c for c in v if c.isdigit())
|
||||
if len(cleaned) < 5:
|
||||
raise ValueError('Phone number must contain at least 5 digits')
|
||||
return v
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
rememberMe: Optional[bool] = False
|
||||
mfaToken: Optional[str] = None
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refreshToken: Optional[str] = None
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
token: str
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
@validator('password')
|
||||
def validate_password(cls, v):
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any((c.isupper() for c in v)):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not any((c.islower() for c in v)):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not any((c.isdigit() for c in v)):
|
||||
raise ValueError('Password must contain at least one number')
|
||||
return v
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
email: str
|
||||
phone: Optional[str]
|
||||
role: str
|
||||
createdAt: Optional[str]
|
||||
updatedAt: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
user: UserResponse
|
||||
token: str
|
||||
refreshToken: Optional[str] = None
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
|
||||
class MFAInitResponse(BaseModel):
|
||||
secret: str
|
||||
qr_code: str
|
||||
|
||||
class EnableMFARequest(BaseModel):
|
||||
secret: str
|
||||
verification_token: str
|
||||
|
||||
class VerifyMFARequest(BaseModel):
|
||||
token: str
|
||||
is_backup_code: Optional[bool] = False
|
||||
|
||||
class MFAStatusResponse(BaseModel):
|
||||
mfa_enabled: bool
|
||||
backup_codes_count: int
|
||||
|
||||
class UpdateProfileRequest(BaseModel):
|
||||
full_name: Optional[str] = Field(None, min_length=2, max_length=100, description='Full name')
|
||||
email: Optional[EmailStr] = Field(None, description='Email address')
|
||||
phone_number: Optional[str] = Field(None, min_length=5, max_length=20, description='Phone number')
|
||||
password: Optional[str] = Field(None, min_length=8, description='New password')
|
||||
currentPassword: Optional[str] = Field(None, alias='current_password', description='Current password (required when changing password)')
|
||||
currency: Optional[str] = Field(None, min_length=3, max_length=3, description='Currency code (ISO 4217, e.g., USD, EUR, VND)')
|
||||
|
||||
@validator('password')
|
||||
def validate_password(cls, v):
|
||||
if v is not None:
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any((c.isupper() for c in v)):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not any((c.islower() for c in v)):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not any((c.isdigit() for c in v)):
|
||||
raise ValueError('Password must contain at least one number')
|
||||
return v
|
||||
|
||||
@validator('phone_number')
|
||||
def validate_phone(cls, v):
|
||||
if v is not None:
|
||||
cleaned = ''.join(c for c in v if c.isdigit())
|
||||
if len(cleaned) < 5:
|
||||
raise ValueError('Phone number must contain at least 5 digits')
|
||||
return v
|
||||
|
||||
@validator('currency')
|
||||
def validate_currency(cls, v):
|
||||
if v is not None:
|
||||
if len(v) != 3 or not v.isalpha():
|
||||
raise ValueError('Currency must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)')
|
||||
return v.upper() if v else v
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
38
Backend/src/auth/schemas/user.py
Normal file
38
Backend/src/auth/schemas/user.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Pydantic schemas for user-related requests and responses.
|
||||
"""
|
||||
from pydantic import BaseModel, Field, EmailStr, field_validator
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
"""Schema for creating a user."""
|
||||
full_name: str = Field(..., min_length=2, max_length=100, description="Full name")
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., min_length=8, description="Password")
|
||||
phone_number: Optional[str] = Field(None, max_length=20, description="Phone number")
|
||||
role_id: Optional[int] = Field(None, gt=0, description="Role ID")
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
"""Validate password strength."""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not any(c.islower() for c in v):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError('Password must contain at least one number')
|
||||
return v
|
||||
|
||||
|
||||
class UpdateUserRequest(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
full_name: Optional[str] = Field(None, min_length=2, max_length=100)
|
||||
email: Optional[EmailStr] = None
|
||||
phone_number: Optional[str] = Field(None, max_length=20)
|
||||
role_id: Optional[int] = Field(None, gt=0)
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
0
Backend/src/auth/services/__init__.py
Normal file
0
Backend/src/auth/services/__init__.py
Normal file
500
Backend/src/auth/services/auth_service.py
Normal file
500
Backend/src/auth/services/auth_service.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from jose import jwt
|
||||
import bcrypt
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
import hashlib
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from ..models.user import User
|
||||
from ..models.refresh_token import RefreshToken
|
||||
from ..models.password_reset_token import PasswordResetToken
|
||||
from ..models.role import Role
|
||||
from ...shared.utils.mailer import send_email
|
||||
from ...shared.utils.email_templates import (
|
||||
welcome_email_template,
|
||||
password_reset_email_template,
|
||||
password_changed_email_template
|
||||
)
|
||||
from ...shared.config.settings import settings
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuthService:
|
||||
def __init__(self):
|
||||
# Security: Fail fast if JWT_SECRET is not configured - never use default values
|
||||
self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET")
|
||||
if not self.jwt_secret:
|
||||
error_msg = (
|
||||
'CRITICAL: JWT_SECRET is not configured. '
|
||||
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
# In development, generate a secure secret but warn
|
||||
import secrets
|
||||
self.jwt_secret = secrets.token_urlsafe(64)
|
||||
logger.warning(
|
||||
f'JWT_SECRET not configured. Auto-generated secret for development. '
|
||||
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
|
||||
)
|
||||
|
||||
# Validate JWT secret strength
|
||||
if len(self.jwt_secret) < 32:
|
||||
error_msg = 'JWT_SECRET must be at least 32 characters long for security.'
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Refresh secret should be different from access secret
|
||||
self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET")
|
||||
if not self.jwt_refresh_secret:
|
||||
# Use a derived secret if not explicitly set, but different from access secret
|
||||
import hashlib
|
||||
self.jwt_refresh_secret = hashlib.sha256((self.jwt_secret + "-refresh").encode()).hexdigest()
|
||||
if not settings.is_production:
|
||||
logger.info('JWT_REFRESH_SECRET not set, using derived secret')
|
||||
|
||||
self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h")
|
||||
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")
|
||||
|
||||
def generate_tokens(self, user_id: int) -> dict:
|
||||
access_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
refresh_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
self.jwt_refresh_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
return {"accessToken": access_token, "refreshToken": refresh_token}
|
||||
|
||||
def verify_access_token(self, token: str) -> dict:
|
||||
return jwt.decode(token, self.jwt_secret, algorithms=["HS256"])
|
||||
|
||||
def verify_refresh_token(self, token: str) -> dict:
|
||||
return jwt.decode(token, self.jwt_refresh_secret, algorithms=["HS256"])
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
|
||||
password_bytes = password.encode('utf-8')
|
||||
|
||||
salt = bcrypt.gensalt()
|
||||
hashed = bcrypt.hashpw(password_bytes, salt)
|
||||
return hashed.decode('utf-8')
|
||||
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
try:
|
||||
password_bytes = plain_password.encode('utf-8')
|
||||
hashed_bytes = hashed_password.encode('utf-8')
|
||||
return bcrypt.checkpw(password_bytes, hashed_bytes)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def format_user_response(self, user: User) -> dict:
|
||||
return {
|
||||
"id": user.id,
|
||||
"name": user.full_name,
|
||||
"email": user.email,
|
||||
"phone": user.phone,
|
||||
"avatar": user.avatar,
|
||||
"currency": getattr(user, 'currency', 'VND'),
|
||||
"role": user.role.name if user.role else "customer",
|
||||
"createdAt": user.created_at.isoformat() if user.created_at else None,
|
||||
"updatedAt": user.updated_at.isoformat() if user.updated_at else None,
|
||||
}
|
||||
|
||||
async def register(self, db: Session, name: str, email: str, password: str, phone: Optional[str] = None) -> dict:
|
||||
# Validate password strength
|
||||
from ...shared.utils.password_validation import validate_password_strength
|
||||
is_valid, errors = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
error_message = 'Password does not meet requirements: ' + '; '.join(errors)
|
||||
raise ValueError(error_message)
|
||||
|
||||
existing_user = db.query(User).filter(User.email == email).first()
|
||||
if existing_user:
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
hashed_password = self.hash_password(password)
|
||||
|
||||
user = User(
|
||||
full_name=name,
|
||||
email=email,
|
||||
password=hashed_password,
|
||||
phone=phone,
|
||||
role_id=3
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
tokens = self.generate_tokens(user.id)
|
||||
|
||||
expires_at = datetime.utcnow() + timedelta(days=7)
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
|
||||
email_html = welcome_email_template(user.full_name, user.email, client_url)
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Welcome to Hotel Booking",
|
||||
html=email_html
|
||||
)
|
||||
logger.info(f"Welcome email sent successfully to {user.email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send welcome email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"user": self.format_user_response(user),
|
||||
"token": tokens["accessToken"],
|
||||
"refreshToken": tokens["refreshToken"]
|
||||
}
|
||||
|
||||
async def login(self, db: Session, email: str, password: str, remember_me: bool = False, mfa_token: str = None) -> dict:
|
||||
|
||||
email = email.lower().strip() if email else ""
|
||||
if not email:
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
if not user:
|
||||
logger.warning(f"Login attempt with non-existent email: {email}")
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
if not user.is_active:
|
||||
logger.warning(f"Login attempt for inactive user: {email}")
|
||||
raise ValueError("Account is disabled. Please contact support.")
|
||||
|
||||
# Check if account is locked (reset if lockout expired)
|
||||
if user.locked_until:
|
||||
if user.locked_until > datetime.utcnow():
|
||||
remaining_minutes = int((user.locked_until - datetime.utcnow()).total_seconds() / 60)
|
||||
logger.warning(f"Login attempt for locked account: {email} (locked until {user.locked_until})")
|
||||
raise ValueError(f"Account is temporarily locked due to multiple failed login attempts. Please try again in {remaining_minutes} minute(s).")
|
||||
else:
|
||||
# Lockout expired, reset it
|
||||
user.locked_until = None
|
||||
user.failed_login_attempts = 0
|
||||
db.commit()
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
password_valid = self.verify_password(password, user.password)
|
||||
|
||||
# Handle failed login attempt
|
||||
if not password_valid:
|
||||
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
|
||||
max_attempts = settings.MAX_LOGIN_ATTEMPTS
|
||||
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
|
||||
# Lock account if max attempts reached
|
||||
if user.failed_login_attempts >= max_attempts:
|
||||
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
|
||||
logger.warning(f"Account locked due to {user.failed_login_attempts} failed login attempts: {email}")
|
||||
db.commit()
|
||||
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
|
||||
else:
|
||||
remaining_attempts = max_attempts - user.failed_login_attempts
|
||||
logger.warning(f"Login attempt with invalid password for user: {email} ({user.failed_login_attempts}/{max_attempts} failed attempts)")
|
||||
db.commit()
|
||||
raise ValueError(f"Invalid email or password. {remaining_attempts} attempt(s) remaining before account lockout.")
|
||||
|
||||
if user.mfa_enabled:
|
||||
if not mfa_token:
|
||||
|
||||
return {
|
||||
"requires_mfa": True,
|
||||
"user_id": user.id
|
||||
}
|
||||
|
||||
|
||||
from ..services.mfa_service import mfa_service
|
||||
is_backup_code = len(mfa_token) == 8
|
||||
if not mfa_service.verify_mfa(db, user.id, mfa_token, is_backup_code):
|
||||
# Increment failed attempts on MFA failure
|
||||
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
|
||||
max_attempts = settings.MAX_LOGIN_ATTEMPTS
|
||||
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
|
||||
if user.failed_login_attempts >= max_attempts:
|
||||
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
|
||||
logger.warning(f"Account locked due to {user.failed_login_attempts} failed attempts (MFA failure): {email}")
|
||||
db.commit()
|
||||
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
|
||||
else:
|
||||
remaining_attempts = max_attempts - user.failed_login_attempts
|
||||
db.commit()
|
||||
raise ValueError(f"Invalid MFA token. {remaining_attempts} attempt(s) remaining before account lockout.")
|
||||
|
||||
# Reset failed login attempts and unlock account on successful login
|
||||
if user.failed_login_attempts > 0 or user.locked_until:
|
||||
user.failed_login_attempts = 0
|
||||
user.locked_until = None
|
||||
db.commit()
|
||||
|
||||
tokens = self.generate_tokens(user.id)
|
||||
|
||||
expiry_days = 7 if remember_me else 1
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiry_days)
|
||||
|
||||
try:
|
||||
db.query(RefreshToken).filter(
|
||||
RefreshToken.user_id == user.id
|
||||
).delete()
|
||||
db.flush()
|
||||
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error saving refresh token for user {user.id}: {str(e)}", exc_info=True)
|
||||
try:
|
||||
db.query(RefreshToken).filter(
|
||||
RefreshToken.token == tokens["refreshToken"]
|
||||
).delete()
|
||||
db.flush()
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
except Exception as retry_error:
|
||||
db.rollback()
|
||||
logger.error(f"Retry failed for refresh token: {str(retry_error)}", exc_info=True)
|
||||
raise ValueError("Failed to create session. Please try again.")
|
||||
|
||||
return {
|
||||
"user": self.format_user_response(user),
|
||||
"token": tokens["accessToken"],
|
||||
"refreshToken": tokens["refreshToken"]
|
||||
}
|
||||
|
||||
async def refresh_access_token(self, db: Session, refresh_token_str: str) -> dict:
|
||||
if not refresh_token_str:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
decoded = self.verify_refresh_token(refresh_token_str)
|
||||
|
||||
stored_token = db.query(RefreshToken).filter(
|
||||
RefreshToken.token == refresh_token_str,
|
||||
RefreshToken.user_id == decoded["userId"]
|
||||
).first()
|
||||
|
||||
if not stored_token:
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
if datetime.utcnow() > stored_token.expires_at:
|
||||
db.delete(stored_token)
|
||||
db.commit()
|
||||
raise ValueError("Refresh token expired")
|
||||
|
||||
access_token = jwt.encode(
|
||||
{"userId": decoded["userId"]},
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
return {"token": access_token}
|
||||
|
||||
async def logout(self, db: Session, refresh_token_str: str) -> bool:
|
||||
if refresh_token_str:
|
||||
db.query(RefreshToken).filter(RefreshToken.token == refresh_token_str).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
async def get_profile(self, db: Session, user_id: int) -> dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
return self.format_user_response(user)
|
||||
|
||||
async def update_profile(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: int,
|
||||
full_name: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
phone_number: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
current_password: Optional[str] = None,
|
||||
currency: Optional[str] = None
|
||||
) -> dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if password:
|
||||
if not current_password:
|
||||
raise ValueError("Current password is required to change password")
|
||||
if not self.verify_password(current_password, user.password):
|
||||
raise ValueError("Current password is incorrect")
|
||||
|
||||
# Validate new password strength
|
||||
from ...shared.utils.password_validation import validate_password_strength
|
||||
is_valid, errors = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
error_message = 'New password does not meet requirements: ' + '; '.join(errors)
|
||||
raise ValueError(error_message)
|
||||
|
||||
user.password = self.hash_password(password)
|
||||
|
||||
if full_name is not None:
|
||||
user.full_name = full_name
|
||||
if email is not None:
|
||||
|
||||
existing_user = db.query(User).filter(
|
||||
User.email == email,
|
||||
User.id != user_id
|
||||
).first()
|
||||
if existing_user:
|
||||
raise ValueError("Email already registered")
|
||||
user.email = email
|
||||
if phone_number is not None:
|
||||
user.phone = phone_number
|
||||
if currency is not None:
|
||||
|
||||
if len(currency) == 3 and currency.isalpha():
|
||||
user.currency = currency.upper()
|
||||
else:
|
||||
raise ValueError("Invalid currency code. Must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)")
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
return self.format_user_response(user)
|
||||
|
||||
def generate_reset_token(self) -> tuple:
|
||||
reset_token = secrets.token_hex(32)
|
||||
hashed_token = hashlib.sha256(reset_token.encode()).hexdigest()
|
||||
return reset_token, hashed_token
|
||||
|
||||
async def forgot_password(self, db: Session, email: str) -> dict:
|
||||
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "If email exists, reset link has been sent"
|
||||
}
|
||||
|
||||
reset_token, hashed_token = self.generate_reset_token()
|
||||
|
||||
db.query(PasswordResetToken).filter(PasswordResetToken.user_id == user.id).delete()
|
||||
|
||||
expires_at = datetime.utcnow() + timedelta(hours=1)
|
||||
reset_token_obj = PasswordResetToken(
|
||||
user_id=user.id,
|
||||
token=hashed_token,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(reset_token_obj)
|
||||
db.commit()
|
||||
|
||||
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
|
||||
reset_url = f"{client_url}/reset-password/{reset_token}"
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to send password reset email to {user.email}")
|
||||
logger.info(f"Reset URL: {reset_url}")
|
||||
email_html = password_reset_email_template(reset_url)
|
||||
|
||||
plain_text = f"Please click the following link to reset your password: {reset_url}\n\nIf you did not request this, please ignore this email."
|
||||
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Reset password - Hotel Booking",
|
||||
html=email_html,
|
||||
text=plain_text
|
||||
)
|
||||
logger.info(f"Password reset email sent successfully to {user.email} with reset URL: {reset_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Password reset link has been sent to your email"
|
||||
}
|
||||
|
||||
async def reset_password(self, db: Session, token: str, password: str) -> dict:
|
||||
if not token or not password:
|
||||
raise ValueError("Token and password are required")
|
||||
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
reset_token = db.query(PasswordResetToken).filter(
|
||||
PasswordResetToken.token == hashed_token,
|
||||
PasswordResetToken.expires_at > datetime.utcnow(),
|
||||
PasswordResetToken.used == False
|
||||
).first()
|
||||
|
||||
if not reset_token:
|
||||
raise ValueError("Invalid or expired reset token")
|
||||
|
||||
user = db.query(User).filter(User.id == reset_token.user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if self.verify_password(password, user.password):
|
||||
raise ValueError("New password must be different from the old password")
|
||||
|
||||
hashed_password = self.hash_password(password)
|
||||
|
||||
user.password = hashed_password
|
||||
db.commit()
|
||||
|
||||
reset_token.used = True
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to send password changed confirmation email to {user.email}")
|
||||
email_html = password_changed_email_template(user.email)
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Password Changed",
|
||||
html=email_html
|
||||
)
|
||||
logger.info(f"Password changed confirmation email sent successfully to {user.email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password changed confirmation email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Password has been reset successfully"
|
||||
}
|
||||
|
||||
auth_service = AuthService()
|
||||
|
||||
127
Backend/src/auth/services/mfa_service.py
Normal file
127
Backend/src/auth/services/mfa_service.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import pyotp
|
||||
import qrcode
|
||||
import secrets
|
||||
import hashlib
|
||||
import json
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.user import User
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MFAService:
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
return pyotp.random_base32()
|
||||
|
||||
@staticmethod
|
||||
def generate_qr_code(secret: str, email: str, app_name: str='Hotel Booking') -> str:
|
||||
totp_uri = pyotp.totp.TOTP(secret).provisioning_uri(name=email, issuer_name=app_name)
|
||||
qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=10, border=4)
|
||||
qr.add_data(totp_uri)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_data = base64.b64encode(buffer.getvalue()).decode()
|
||||
return f'data:image/png;base64,{img_data}'
|
||||
|
||||
@staticmethod
|
||||
def generate_backup_codes(count: int=10) -> List[str]:
|
||||
codes = []
|
||||
for _ in range(count):
|
||||
code = secrets.token_urlsafe(6).upper()[:8]
|
||||
codes.append(code)
|
||||
return codes
|
||||
|
||||
@staticmethod
|
||||
def hash_backup_code(code: str) -> str:
|
||||
return hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def verify_backup_code(code: str, hashed_codes: List[str]) -> bool:
|
||||
code_hash = MFAService.hash_backup_code(code)
|
||||
return code_hash in hashed_codes
|
||||
|
||||
@staticmethod
|
||||
def verify_totp(token: str, secret: str) -> bool:
|
||||
try:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(token, valid_window=1)
|
||||
except Exception as e:
|
||||
logger.error(f'Error verifying TOTP: {str(e)}')
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def enable_mfa(db: Session, user_id: int, secret: str, verification_token: str) -> Tuple[bool, List[str]]:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not MFAService.verify_totp(verification_token, secret):
|
||||
raise ValueError('Invalid verification token')
|
||||
backup_codes = MFAService.generate_backup_codes()
|
||||
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
|
||||
user.mfa_enabled = True
|
||||
user.mfa_secret = secret
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes)
|
||||
db.commit()
|
||||
return (True, backup_codes)
|
||||
|
||||
@staticmethod
|
||||
def disable_mfa(db: Session, user_id: int) -> bool:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
user.mfa_enabled = False
|
||||
user.mfa_secret = None
|
||||
user.mfa_backup_codes = None
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def verify_mfa(db: Session, user_id: int, token: str, is_backup_code: bool=False) -> bool:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not user.mfa_enabled or not user.mfa_secret:
|
||||
raise ValueError('MFA is not enabled for this user')
|
||||
if is_backup_code:
|
||||
if not user.mfa_backup_codes:
|
||||
return False
|
||||
hashed_codes = json.loads(user.mfa_backup_codes)
|
||||
if not MFAService.verify_backup_code(token, hashed_codes):
|
||||
return False
|
||||
code_hash = MFAService.hash_backup_code(token)
|
||||
hashed_codes.remove(code_hash)
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes) if hashed_codes else None
|
||||
db.commit()
|
||||
return True
|
||||
else:
|
||||
return MFAService.verify_totp(token, user.mfa_secret)
|
||||
|
||||
@staticmethod
|
||||
def regenerate_backup_codes(db: Session, user_id: int) -> List[str]:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not user.mfa_enabled:
|
||||
raise ValueError('MFA is not enabled for this user')
|
||||
backup_codes = MFAService.generate_backup_codes()
|
||||
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes)
|
||||
db.commit()
|
||||
return backup_codes
|
||||
|
||||
@staticmethod
|
||||
def get_mfa_status(db: Session, user_id: int) -> Dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
backup_codes_count = 0
|
||||
if user.mfa_backup_codes:
|
||||
backup_codes_count = len(json.loads(user.mfa_backup_codes))
|
||||
return {'mfa_enabled': user.mfa_enabled, 'backup_codes_count': backup_codes_count}
|
||||
mfa_service = MFAService()
|
||||
209
Backend/src/auth/services/oauth_service.py
Normal file
209
Backend/src/auth/services/oauth_service.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
import httpx
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
|
||||
from ...security.models.security_event import OAuthProvider, OAuthToken
|
||||
from ..models.user import User
|
||||
from ...shared.config.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth 2.0 / OpenID Connect authentication"""
|
||||
|
||||
@staticmethod
|
||||
def get_authorization_url(db: Session, provider_name: str, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||
"""Generate OAuth authorization URL"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
if not state:
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
params = {
|
||||
'client_id': provider.client_id,
|
||||
'redirect_uri': redirect_uri,
|
||||
'response_type': 'code',
|
||||
'scope': provider.scopes or 'openid profile email',
|
||||
'state': state,
|
||||
}
|
||||
|
||||
return f"{provider.authorization_url}?{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
async def exchange_code_for_token(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
code: str,
|
||||
redirect_uri: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
provider.token_url,
|
||||
data={
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
'redirect_uri': redirect_uri,
|
||||
'client_id': provider.client_id,
|
||||
'client_secret': provider.client_secret,
|
||||
},
|
||||
headers={'Accept': 'application/json'}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth token exchange failed: {response.text}")
|
||||
raise ValueError("Failed to exchange authorization code for token")
|
||||
|
||||
token_data = response.json()
|
||||
return token_data
|
||||
|
||||
@staticmethod
|
||||
async def get_user_info(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
access_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get user information from OAuth provider"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
provider.userinfo_url,
|
||||
headers={
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get user info: {response.text}")
|
||||
raise ValueError("Failed to get user information from OAuth provider")
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def save_oauth_token(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
provider_id: int,
|
||||
provider_user_id: str,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str] = None,
|
||||
expires_in: Optional[int] = None,
|
||||
scopes: Optional[str] = None
|
||||
) -> OAuthToken:
|
||||
"""Save or update OAuth token for user"""
|
||||
expires_at = None
|
||||
if expires_in:
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# Check if token already exists
|
||||
existing_token = db.query(OAuthToken).filter(
|
||||
OAuthToken.user_id == user_id,
|
||||
OAuthToken.provider_id == provider_id
|
||||
).first()
|
||||
|
||||
if existing_token:
|
||||
existing_token.access_token = access_token
|
||||
existing_token.refresh_token = refresh_token
|
||||
existing_token.expires_at = expires_at
|
||||
existing_token.scopes = scopes
|
||||
existing_token.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(existing_token)
|
||||
return existing_token
|
||||
else:
|
||||
new_token = OAuthToken(
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
provider_user_id=provider_user_id,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
scopes=scopes
|
||||
)
|
||||
db.add(new_token)
|
||||
db.commit()
|
||||
db.refresh(new_token)
|
||||
return new_token
|
||||
|
||||
@staticmethod
|
||||
def find_or_create_user_from_oauth(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
user_info: Dict[str, Any]
|
||||
) -> User:
|
||||
"""Find existing user or create new user from OAuth user info"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found")
|
||||
|
||||
# Try to find user by OAuth token
|
||||
provider_user_id = user_info.get('sub') or user_info.get('id')
|
||||
oauth_token = db.query(OAuthToken).filter(
|
||||
OAuthToken.provider_id == provider.id,
|
||||
OAuthToken.provider_user_id == str(provider_user_id)
|
||||
).first()
|
||||
|
||||
if oauth_token:
|
||||
return oauth_token.user
|
||||
|
||||
# Try to find user by email
|
||||
email = user_info.get('email')
|
||||
if email:
|
||||
user = db.query(User).filter(User.email == email.lower()).first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Create new user
|
||||
from ..models.role import Role
|
||||
customer_role = db.query(Role).filter(Role.name == 'customer').first()
|
||||
if not customer_role:
|
||||
raise ValueError("Customer role not found")
|
||||
|
||||
name = user_info.get('name') or user_info.get('given_name', '') + ' ' + user_info.get('family_name', '')
|
||||
if not name.strip():
|
||||
name = email.split('@')[0] if email else 'User'
|
||||
|
||||
new_user = User(
|
||||
email=email.lower() if email else f"{provider_user_id}@{provider_name}.oauth",
|
||||
full_name=name.strip(),
|
||||
role_id=customer_role.id,
|
||||
is_active=True,
|
||||
email_verified=True # OAuth providers verify emails
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
|
||||
return new_user
|
||||
|
||||
oauth_service = OAuthService()
|
||||
|
||||
Reference in New Issue
Block a user