This commit is contained in:
Iliyan Angelov
2025-11-30 22:43:09 +02:00
parent 24b40450dd
commit 39fcfff811
1610 changed files with 5442 additions and 1383 deletions

View File

View File

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
from sqlalchemy.orm import relationship
from datetime import datetime
from ...shared.config.database import Base
class PasswordResetToken(Base):
__tablename__ = 'password_reset_tokens'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
token = Column(String(255), unique=True, nullable=False, index=True)
expires_at = Column(DateTime, nullable=False)
used = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
user = relationship('User')

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from datetime import datetime
from ...shared.config.database import Base
class RefreshToken(Base):
__tablename__ = 'refresh_tokens'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
token = Column(String(500), unique=True, nullable=False, index=True)
expires_at = Column(DateTime, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
user = relationship('User', back_populates='refresh_tokens')

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import relationship
from datetime import datetime
from ...shared.config.database import Base
class Role(Base):
__tablename__ = 'roles'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
name = Column(String(50), unique=True, nullable=False, index=True)
description = Column(String(255), nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
users = relationship('User', back_populates='role')

View File

@@ -0,0 +1,53 @@
from sqlalchemy import Column, Integer, String, Boolean, Text, ForeignKey, DateTime, Numeric
from sqlalchemy.orm import relationship
from datetime import datetime
from ...shared.config.database import Base
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
role_id = Column(Integer, ForeignKey('roles.id'), nullable=False)
email = Column(String(100), unique=True, nullable=False, index=True)
password = Column(String(255), nullable=False)
full_name = Column(String(100), nullable=False)
phone = Column(String(20), nullable=True)
address = Column(Text, nullable=True)
avatar = Column(String(255), nullable=True)
currency = Column(String(3), nullable=False, default='VND')
is_active = Column(Boolean, nullable=False, default=True)
mfa_enabled = Column(Boolean, nullable=False, default=False)
mfa_secret = Column(String(255), nullable=True)
mfa_backup_codes = Column(Text, nullable=True)
# Account lockout fields
failed_login_attempts = Column(Integer, nullable=False, default=0)
locked_until = Column(DateTime, nullable=True)
# Guest Profile & CRM fields
is_vip = Column(Boolean, nullable=False, default=False)
lifetime_value = Column(Numeric(10, 2), nullable=True, default=0) # Total revenue from guest
satisfaction_score = Column(Numeric(3, 2), nullable=True) # Average satisfaction score (0-5)
last_visit_date = Column(DateTime, nullable=True) # Last booking check-in date
total_visits = Column(Integer, nullable=False, default=0) # Total number of bookings
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
role = relationship('Role', back_populates='users')
bookings = relationship('Booking', back_populates='user')
refresh_tokens = relationship('RefreshToken', back_populates='user', cascade='all, delete-orphan')
checkins_processed = relationship('CheckInCheckOut', foreign_keys='CheckInCheckOut.checkin_by', back_populates='checked_in_by')
checkouts_processed = relationship('CheckInCheckOut', foreign_keys='CheckInCheckOut.checkout_by', back_populates='checked_out_by')
reviews = relationship('Review', back_populates='user')
favorites = relationship('Favorite', back_populates='user', cascade='all, delete-orphan')
service_bookings = relationship('ServiceBooking', back_populates='user')
visitor_chats = relationship('Chat', foreign_keys='Chat.visitor_id', back_populates='visitor')
staff_chats = relationship('Chat', foreign_keys='Chat.staff_id', back_populates='staff')
loyalty = relationship('UserLoyalty', back_populates='user', uselist=False, cascade='all, delete-orphan')
referrals = relationship('Referral', foreign_keys='Referral.referred_user_id', back_populates='referred_user')
# Guest Profile & CRM relationships
guest_preferences = relationship('GuestPreference', back_populates='user', uselist=False, cascade='all, delete-orphan')
guest_notes = relationship('GuestNote', foreign_keys='GuestNote.user_id', back_populates='user', cascade='all, delete-orphan')
guest_tags = relationship('GuestTag', secondary='guest_tag_associations', back_populates='users')
guest_communications = relationship('GuestCommunication', foreign_keys='GuestCommunication.user_id', back_populates='user', cascade='all, delete-orphan')
guest_segments = relationship('GuestSegment', secondary='guest_segment_associations', back_populates='users')

View File

View File

@@ -0,0 +1,413 @@
from fastapi import APIRouter, Depends, HTTPException, status, Cookie, Response, Request, UploadFile, File
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pathlib import Path
import aiofiles
import uuid
import os
from ...shared.config.database import get_db
from ..services.auth_service import auth_service
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
from ...security.middleware.auth import get_current_user
from ..models.user import User
from ...analytics.services.audit_service import audit_service
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
router = APIRouter(prefix='/auth', tags=['auth'])
# Stricter rate limits for authentication endpoints
AUTH_RATE_LIMIT = "5/minute" # 5 attempts per minute per IP
PASSWORD_RESET_LIMIT = "3/hour" # 3 password reset requests per hour per IP
LOGIN_RATE_LIMIT = "10/minute" # 10 login attempts per minute per IP
# Initialize limiter - will be set from app state
limiter = None
def get_limiter(request: Request) -> Limiter:
"""Get limiter instance from app state."""
global limiter
if hasattr(request.app.state, 'limiter'):
limiter = request.app.state.limiter
return limiter
def get_base_url(request: Request) -> str:
return os.getenv('SERVER_URL') or f'http://{request.headers.get('host', 'localhost:8000')}'
def normalize_image_url(image_url: str, base_url: str) -> str:
if not image_url:
return image_url
if image_url.startswith('http://') or image_url.startswith('https://'):
return image_url
if image_url.startswith('/'):
return f'{base_url}{image_url}'
return f'{base_url}/{image_url}'
@router.post('/register', status_code=status.HTTP_201_CREATED)
async def register(
request: Request,
register_request: RegisterRequest,
response: Response,
db: Session=Depends(get_db)
):
# Rate limiting is handled by middleware, but we can add additional checks here if needed
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
request_id = getattr(request.state, 'request_id', None)
try:
result = await auth_service.register(db=db, name=register_request.name, email=register_request.email, password=register_request.password, phone=register_request.phone)
from ...shared.config.settings import settings
max_age = 7 * 24 * 60 * 60 # 7 days for registration
# Use secure cookies in production (HTTPS required)
# Set access token in httpOnly cookie for security
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Set refresh token in httpOnly cookie
response.set_cookie(
key='refreshToken',
value=result['refreshToken'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Log successful registration
await audit_service.log_action(
db=db,
action='user_registered',
resource_type='user',
user_id=result['user']['id'],
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': register_request.email, 'name': register_request.name},
status='success'
)
# Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'message': 'Registration successful', 'data': {'user': result['user']}}
except ValueError as e:
error_message = str(e)
# Log failed registration attempt
await audit_service.log_action(
db=db,
action='user_registration_failed',
resource_type='user',
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': register_request.email, 'name': register_request.name},
status='failed',
error_message=error_message
)
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={'status': 'error', 'message': error_message})
@router.post('/login')
async def login(
request: Request,
login_request: LoginRequest,
response: Response,
db: Session=Depends(get_db)
):
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
request_id = getattr(request.state, 'request_id', None)
try:
result = await auth_service.login(db=db, email=login_request.email, password=login_request.password, remember_me=login_request.rememberMe or False, mfa_token=login_request.mfaToken)
if result.get('requires_mfa'):
# Log MFA required
user = db.query(User).filter(User.email == login_request.email.lower().strip()).first()
if user:
await audit_service.log_action(
db=db,
action='login_mfa_required',
resource_type='authentication',
user_id=user.id,
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': login_request.email},
status='success'
)
return {'status': 'success', 'requires_mfa': True, 'user_id': result['user_id']}
from ...shared.config.settings import settings
max_age = 7 * 24 * 60 * 60 if login_request.rememberMe else 1 * 24 * 60 * 60
# Use secure cookies in production (HTTPS required)
# Set access token in httpOnly cookie for security
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Set refresh token in httpOnly cookie
response.set_cookie(
key='refreshToken',
value=result['refreshToken'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Log successful login
await audit_service.log_action(
db=db,
action='login_success',
resource_type='authentication',
user_id=result['user']['id'],
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': login_request.email, 'remember_me': login_request.rememberMe},
status='success'
)
# Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'data': {'user': result['user']}}
except ValueError as e:
error_message = str(e)
status_code = status.HTTP_401_UNAUTHORIZED if 'Invalid email or password' in error_message or 'Invalid MFA token' in error_message else status.HTTP_400_BAD_REQUEST
# Log failed login attempt
await audit_service.log_action(
db=db,
action='login_failed',
resource_type='authentication',
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': login_request.email},
status='failed',
error_message=error_message
)
return JSONResponse(status_code=status_code, content={'status': 'error', 'message': error_message})
@router.post('/refresh-token', response_model=TokenResponse)
async def refresh_token(
request: Request,
response: Response,
refreshToken: str=Cookie(None),
db: Session=Depends(get_db)
):
if not refreshToken:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Refresh token not found')
try:
result = await auth_service.refresh_access_token(db, refreshToken)
from ...shared.config.settings import settings
# Set new access token in httpOnly cookie
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
max_age = 7 * 24 * 60 * 60 # 7 days
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production,
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'data': {'user': result.get('user')}}
except ValueError as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
@router.post('/logout', response_model=MessageResponse)
async def logout(
request: Request,
response: Response,
refreshToken: str=Cookie(None),
current_user: User=Depends(get_current_user),
db: Session=Depends(get_db)
):
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
request_id = getattr(request.state, 'request_id', None)
if refreshToken:
await auth_service.logout(db, refreshToken)
# Delete both access and refresh token cookies
from ...shared.config.settings import settings
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.delete_cookie(key='refreshToken', path='/', secure=settings.is_production, samesite=samesite_value)
response.delete_cookie(key='accessToken', path='/', secure=settings.is_production, samesite=samesite_value)
# Log logout
await audit_service.log_action(
db=db,
action='logout',
resource_type='authentication',
user_id=current_user.id,
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={'email': current_user.email},
status='success'
)
return {'status': 'success', 'message': 'Logout successful'}
@router.get('/profile')
async def get_profile(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
user = await auth_service.get_profile(db, current_user.id)
return {'status': 'success', 'data': {'user': user}}
except ValueError as e:
if 'User not found' in str(e):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.put('/profile')
async def update_profile(profile_data: UpdateProfileRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
user = await auth_service.update_profile(
db=db,
user_id=current_user.id,
full_name=profile_data.full_name,
email=profile_data.email,
phone_number=profile_data.phone_number,
password=profile_data.password,
current_password=profile_data.currentPassword,
currency=profile_data.currency
)
return {'status': 'success', 'message': 'Profile updated successfully', 'data': {'user': user}}
except ValueError as e:
error_message = str(e)
status_code = status.HTTP_400_BAD_REQUEST
if 'not found' in error_message.lower():
status_code = status.HTTP_404_NOT_FOUND
raise HTTPException(status_code=status_code, detail=error_message)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'An error occurred: {str(e)}')
@router.post('/forgot-password', response_model=MessageResponse)
async def forgot_password(request: ForgotPasswordRequest, db: Session=Depends(get_db)):
result = await auth_service.forgot_password(db, request.email)
return {'status': 'success', 'message': result['message']}
@router.post('/reset-password', response_model=MessageResponse)
async def reset_password(request: ResetPasswordRequest, db: Session=Depends(get_db)):
try:
result = await auth_service.reset_password(db=db, token=request.token, password=request.password)
return {'status': 'success', 'message': result['message']}
except ValueError as e:
status_code = status.HTTP_400_BAD_REQUEST
if 'User not found' in str(e):
status_code = status.HTTP_404_NOT_FOUND
raise HTTPException(status_code=status_code, detail=str(e))
from ..services.mfa_service import mfa_service
from ...shared.config.settings import settings
@router.get('/mfa/init')
async def init_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
if current_user.mfa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='MFA is already enabled')
secret = mfa_service.generate_secret()
app_name = getattr(settings, 'APP_NAME', 'Hotel Booking')
qr_code = mfa_service.generate_qr_code(secret, current_user.email, app_name)
return {'status': 'success', 'data': {'secret': secret, 'qr_code': qr_code}}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error initializing MFA: {str(e)}')
@router.post('/mfa/enable')
async def enable_mfa(request: EnableMFARequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
success, backup_codes = mfa_service.enable_mfa(db=db, user_id=current_user.id, secret=request.secret, verification_token=request.verification_token)
return {'status': 'success', 'message': 'MFA enabled successfully', 'data': {'backup_codes': backup_codes}}
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error enabling MFA: {str(e)}')
@router.post('/mfa/disable')
async def disable_mfa(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
mfa_service.disable_mfa(db=db, user_id=current_user.id)
return {'status': 'success', 'message': 'MFA disabled successfully'}
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error disabling MFA: {str(e)}')
@router.get('/mfa/status', response_model=MFAStatusResponse)
async def get_mfa_status(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
status_data = mfa_service.get_mfa_status(db=db, user_id=current_user.id)
return status_data
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error getting MFA status: {str(e)}')
@router.post('/mfa/regenerate-backup-codes')
async def regenerate_backup_codes(current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
backup_codes = mfa_service.regenerate_backup_codes(db=db, user_id=current_user.id)
return {'status': 'success', 'message': 'Backup codes regenerated successfully', 'data': {'backup_codes': backup_codes}}
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error regenerating backup codes: {str(e)}')
@router.post('/avatar/upload')
async def upload_avatar(request: Request, image: UploadFile=File(...), current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
# Use comprehensive file validation (magic bytes + size)
from ...shared.utils.file_validation import validate_uploaded_image
max_avatar_size = 2 * 1024 * 1024 # 2MB for avatars
# Validate file completely (MIME type, size, magic bytes, integrity)
content = await validate_uploaded_image(image, max_avatar_size)
upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'avatars'
upload_dir.mkdir(parents=True, exist_ok=True)
if current_user.avatar:
old_avatar_path = Path(__file__).parent.parent.parent / current_user.avatar.lstrip('/')
if old_avatar_path.exists() and old_avatar_path.is_file():
try:
old_avatar_path.unlink()
except Exception:
pass
ext = Path(image.filename).suffix or '.png'
filename = f'avatar-{current_user.id}-{uuid.uuid4()}{ext}'
file_path = upload_dir / filename
async with aiofiles.open(file_path, 'wb') as f:
await f.write(content)
image_url = f'/uploads/avatars/{filename}'
current_user.avatar = image_url
db.commit()
db.refresh(current_user)
base_url = get_base_url(request)
full_url = normalize_image_url(image_url, base_url)
return {'success': True, 'status': 'success', 'message': 'Avatar uploaded successfully', 'data': {'avatar_url': image_url, 'full_url': full_url, 'user': {'id': current_user.id, 'name': current_user.full_name, 'email': current_user.email, 'phone': current_user.phone, 'avatar': image_url, 'role': current_user.role.name if current_user.role else 'customer'}}}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error uploading avatar: {str(e)}')

View File

@@ -0,0 +1,164 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import Optional
import bcrypt
from ...shared.config.database import get_db
from ...security.middleware.auth import get_current_user, authorize_roles
from ..models.user import User
from ..models.role import Role
from ...bookings.models.booking import Booking, BookingStatus
from ...shared.utils.role_helpers import can_manage_users
from ...shared.utils.response_helpers import success_response
from ...analytics.services.audit_service import audit_service
from ..schemas.user import CreateUserRequest, UpdateUserRequest
router = APIRouter(prefix='/users', tags=['users'])
@router.get('/', dependencies=[Depends(authorize_roles('admin'))])
async def get_users(search: Optional[str]=Query(None), role: Optional[str]=Query(None), status_filter: Optional[str]=Query(None, alias='status'), page: int=Query(1, ge=1), limit: int=Query(10, ge=1, le=100), current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
try:
query = db.query(User)
if search:
query = query.filter(or_(User.full_name.like(f'%{search}%'), User.email.like(f'%{search}%'), User.phone.like(f'%{search}%')))
if role:
role_map = {'admin': 1, 'staff': 2, 'customer': 3, 'accountant': 4}
if role in role_map:
query = query.filter(User.role_id == role_map[role])
if status_filter:
is_active = status_filter == 'active'
query = query.filter(User.is_active == is_active)
total = query.count()
offset = (page - 1) * limit
users = query.order_by(User.created_at.desc()).offset(offset).limit(limit).all()
result = []
for user in users:
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'address': user.address, 'avatar': user.avatar, 'currency': getattr(user, 'currency', 'VND'), 'is_active': user.is_active, 'status': 'active' if user.is_active else 'inactive', 'role_id': user.role_id, 'role': user.role.name if user.role else 'customer', 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None}
result.append(user_dict)
return success_response(data={'users': result, 'pagination': {'total': total, 'page': page, 'limit': limit, 'totalPages': (total + limit - 1) // limit}})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get('/{id}', dependencies=[Depends(authorize_roles('admin'))])
async def get_user_by_id(id: int, current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
try:
user = db.query(User).filter(User.id == id).first()
if not user:
raise HTTPException(status_code=404, detail='User not found')
bookings = db.query(Booking).filter(Booking.user_id == id).order_by(Booking.created_at.desc()).limit(5).all()
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'address': user.address, 'avatar': user.avatar, 'currency': getattr(user, 'currency', 'VND'), 'is_active': user.is_active, 'status': 'active' if user.is_active else 'inactive', 'role_id': user.role_id, 'role': user.role.name if user.role else 'customer', 'created_at': user.created_at.isoformat() if user.created_at else None, 'updated_at': user.updated_at.isoformat() if user.updated_at else None, 'bookings': [{'id': b.id, 'booking_number': b.booking_number, 'status': b.status.value if isinstance(b.status, BookingStatus) else b.status, 'created_at': b.created_at.isoformat() if b.created_at else None} for b in bookings]}
return success_response(data={'user': user_dict})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post('/', dependencies=[Depends(authorize_roles('admin'))])
async def create_user(
request: Request,
user_data: CreateUserRequest,
current_user: User=Depends(authorize_roles('admin')),
db: Session=Depends(get_db)
):
"""Create a user with validated input using Pydantic schema."""
client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent')
request_id = getattr(request.state, 'request_id', None)
try:
email = user_data.email
password = user_data.password
full_name = user_data.full_name
phone_number = user_data.phone_number
role_id = user_data.role_id or 3 # Default to customer role
existing = db.query(User).filter(User.email == email).first()
if existing:
raise HTTPException(status_code=400, detail='Email already exists')
password_bytes = password.encode('utf-8')
salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password_bytes, salt).decode('utf-8')
user = User(email=email, password=hashed_password, full_name=full_name, phone=phone_number, role_id=role_id, is_active=True)
db.add(user)
db.commit()
db.refresh(user)
# Log admin action - user creation
await audit_service.log_action(
db=db,
action='admin_user_created',
resource_type='user',
user_id=current_user.id,
resource_id=user.id,
ip_address=client_ip,
user_agent=user_agent,
request_id=request_id,
details={
'created_user_email': user.email,
'created_user_name': user.full_name,
'role_id': user.role_id,
'is_active': user.is_active
},
status='success'
)
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'currency': getattr(user, 'currency', 'VND'), 'role_id': user.role_id, 'is_active': user.is_active}
return success_response(data={'user': user_dict}, message='User created successfully')
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@router.put('/{id}')
async def update_user(id: int, user_data: UpdateUserRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
"""Update a user with validated input using Pydantic schema."""
try:
if not can_manage_users(current_user, db) and current_user.id != id:
raise HTTPException(status_code=403, detail='Forbidden')
user = db.query(User).filter(User.id == id).first()
if not user:
raise HTTPException(status_code=404, detail='User not found')
# Check email uniqueness if being updated
if user_data.email and user_data.email != user.email:
existing = db.query(User).filter(User.email == user_data.email).first()
if existing:
raise HTTPException(status_code=400, detail='Email already exists')
# Update fields if provided
if user_data.full_name is not None:
user.full_name = user_data.full_name
if user_data.email is not None and can_manage_users(current_user, db):
user.email = user_data.email
if user_data.phone_number is not None:
user.phone = user_data.phone_number
if user_data.role_id is not None and can_manage_users(current_user, db):
user.role_id = user_data.role_id
if user_data.is_active is not None and can_manage_users(current_user, db):
user.is_active = user_data.is_active
db.commit()
db.refresh(user)
user_dict = {'id': user.id, 'email': user.email, 'full_name': user.full_name, 'phone': user.phone, 'phone_number': user.phone, 'currency': getattr(user, 'currency', 'VND'), 'role_id': user.role_id, 'is_active': user.is_active}
return success_response(data={'user': user_dict}, message='User updated successfully')
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@router.delete('/{id}', dependencies=[Depends(authorize_roles('admin'))])
async def delete_user(id: int, current_user: User=Depends(authorize_roles('admin')), db: Session=Depends(get_db)):
try:
user = db.query(User).filter(User.id == id).first()
if not user:
raise HTTPException(status_code=404, detail='User not found')
active_bookings = db.query(Booking).filter(Booking.user_id == id, Booking.status.in_([BookingStatus.pending, BookingStatus.confirmed, BookingStatus.checked_in])).count()
if active_bookings > 0:
raise HTTPException(status_code=400, detail='Cannot delete user with active bookings')
db.delete(user)
db.commit()
return success_response(message='User deleted successfully')
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))

View File

View File

@@ -0,0 +1,135 @@
from pydantic import BaseModel, EmailStr, Field, validator
from typing import Optional
class RegisterRequest(BaseModel):
name: str = Field(..., min_length=2, max_length=50)
email: EmailStr
password: str = Field(..., min_length=8)
phone: Optional[str] = None
@validator('password')
def validate_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any((c.isupper() for c in v)):
raise ValueError('Password must contain at least one uppercase letter')
if not any((c.islower() for c in v)):
raise ValueError('Password must contain at least one lowercase letter')
if not any((c.isdigit() for c in v)):
raise ValueError('Password must contain at least one number')
return v
@validator('phone')
def validate_phone(cls, v):
if v:
cleaned = ''.join(c for c in v if c.isdigit())
if len(cleaned) < 5:
raise ValueError('Phone number must contain at least 5 digits')
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
rememberMe: Optional[bool] = False
mfaToken: Optional[str] = None
class RefreshTokenRequest(BaseModel):
refreshToken: Optional[str] = None
class ForgotPasswordRequest(BaseModel):
email: EmailStr
class ResetPasswordRequest(BaseModel):
token: str
password: str = Field(..., min_length=8)
@validator('password')
def validate_password(cls, v):
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any((c.isupper() for c in v)):
raise ValueError('Password must contain at least one uppercase letter')
if not any((c.islower() for c in v)):
raise ValueError('Password must contain at least one lowercase letter')
if not any((c.isdigit() for c in v)):
raise ValueError('Password must contain at least one number')
return v
class UserResponse(BaseModel):
id: int
name: str
email: str
phone: Optional[str]
role: str
createdAt: Optional[str]
updatedAt: Optional[str]
class Config:
from_attributes = True
class AuthResponse(BaseModel):
user: UserResponse
token: str
refreshToken: Optional[str] = None
class TokenResponse(BaseModel):
token: str
class MessageResponse(BaseModel):
status: str
message: str
class MFAInitResponse(BaseModel):
secret: str
qr_code: str
class EnableMFARequest(BaseModel):
secret: str
verification_token: str
class VerifyMFARequest(BaseModel):
token: str
is_backup_code: Optional[bool] = False
class MFAStatusResponse(BaseModel):
mfa_enabled: bool
backup_codes_count: int
class UpdateProfileRequest(BaseModel):
full_name: Optional[str] = Field(None, min_length=2, max_length=100, description='Full name')
email: Optional[EmailStr] = Field(None, description='Email address')
phone_number: Optional[str] = Field(None, min_length=5, max_length=20, description='Phone number')
password: Optional[str] = Field(None, min_length=8, description='New password')
currentPassword: Optional[str] = Field(None, alias='current_password', description='Current password (required when changing password)')
currency: Optional[str] = Field(None, min_length=3, max_length=3, description='Currency code (ISO 4217, e.g., USD, EUR, VND)')
@validator('password')
def validate_password(cls, v):
if v is not None:
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any((c.isupper() for c in v)):
raise ValueError('Password must contain at least one uppercase letter')
if not any((c.islower() for c in v)):
raise ValueError('Password must contain at least one lowercase letter')
if not any((c.isdigit() for c in v)):
raise ValueError('Password must contain at least one number')
return v
@validator('phone_number')
def validate_phone(cls, v):
if v is not None:
cleaned = ''.join(c for c in v if c.isdigit())
if len(cleaned) < 5:
raise ValueError('Phone number must contain at least 5 digits')
return v
@validator('currency')
def validate_currency(cls, v):
if v is not None:
if len(v) != 3 or not v.isalpha():
raise ValueError('Currency must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)')
return v.upper() if v else v
class Config:
allow_population_by_field_name = True

View File

@@ -0,0 +1,38 @@
"""
Pydantic schemas for user-related requests and responses.
"""
from pydantic import BaseModel, Field, EmailStr, field_validator
from typing import Optional
class CreateUserRequest(BaseModel):
"""Schema for creating a user."""
full_name: str = Field(..., min_length=2, max_length=100, description="Full name")
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., min_length=8, description="Password")
phone_number: Optional[str] = Field(None, max_length=20, description="Phone number")
role_id: Optional[int] = Field(None, gt=0, description="Role ID")
@field_validator('password')
@classmethod
def validate_password(cls, v: str) -> str:
"""Validate password strength."""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one number')
return v
class UpdateUserRequest(BaseModel):
"""Schema for updating a user."""
full_name: Optional[str] = Field(None, min_length=2, max_length=100)
email: Optional[EmailStr] = None
phone_number: Optional[str] = Field(None, max_length=20)
role_id: Optional[int] = Field(None, gt=0)
is_active: Optional[bool] = None

View File

View File

@@ -0,0 +1,500 @@
from jose import jwt
import bcrypt
from datetime import datetime, timedelta
import secrets
import hashlib
from sqlalchemy.orm import Session
from typing import Optional
import logging
from ..models.user import User
from ..models.refresh_token import RefreshToken
from ..models.password_reset_token import PasswordResetToken
from ..models.role import Role
from ...shared.utils.mailer import send_email
from ...shared.utils.email_templates import (
welcome_email_template,
password_reset_email_template,
password_changed_email_template
)
from ...shared.config.settings import settings
import os
logger = logging.getLogger(__name__)
class AuthService:
def __init__(self):
# Security: Fail fast if JWT_SECRET is not configured - never use default values
self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET")
if not self.jwt_secret:
error_msg = (
'CRITICAL: JWT_SECRET is not configured. '
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
)
logger.error(error_msg)
if settings.is_production:
raise ValueError(error_msg)
else:
# In development, generate a secure secret but warn
import secrets
self.jwt_secret = secrets.token_urlsafe(64)
logger.warning(
f'JWT_SECRET not configured. Auto-generated secret for development. '
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
)
# Validate JWT secret strength
if len(self.jwt_secret) < 32:
error_msg = 'JWT_SECRET must be at least 32 characters long for security.'
logger.error(error_msg)
if settings.is_production:
raise ValueError(error_msg)
else:
logger.warning(error_msg)
# Refresh secret should be different from access secret
self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET")
if not self.jwt_refresh_secret:
# Use a derived secret if not explicitly set, but different from access secret
import hashlib
self.jwt_refresh_secret = hashlib.sha256((self.jwt_secret + "-refresh").encode()).hexdigest()
if not settings.is_production:
logger.info('JWT_REFRESH_SECRET not set, using derived secret')
self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h")
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")
def generate_tokens(self, user_id: int) -> dict:
access_token = jwt.encode(
{"userId": user_id},
self.jwt_secret,
algorithm="HS256"
)
refresh_token = jwt.encode(
{"userId": user_id},
self.jwt_refresh_secret,
algorithm="HS256"
)
return {"accessToken": access_token, "refreshToken": refresh_token}
def verify_access_token(self, token: str) -> dict:
return jwt.decode(token, self.jwt_secret, algorithms=["HS256"])
def verify_refresh_token(self, token: str) -> dict:
return jwt.decode(token, self.jwt_refresh_secret, algorithms=["HS256"])
def hash_password(self, password: str) -> str:
password_bytes = password.encode('utf-8')
salt = bcrypt.gensalt()
hashed = bcrypt.hashpw(password_bytes, salt)
return hashed.decode('utf-8')
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
try:
password_bytes = plain_password.encode('utf-8')
hashed_bytes = hashed_password.encode('utf-8')
return bcrypt.checkpw(password_bytes, hashed_bytes)
except Exception:
return False
def format_user_response(self, user: User) -> dict:
return {
"id": user.id,
"name": user.full_name,
"email": user.email,
"phone": user.phone,
"avatar": user.avatar,
"currency": getattr(user, 'currency', 'VND'),
"role": user.role.name if user.role else "customer",
"createdAt": user.created_at.isoformat() if user.created_at else None,
"updatedAt": user.updated_at.isoformat() if user.updated_at else None,
}
async def register(self, db: Session, name: str, email: str, password: str, phone: Optional[str] = None) -> dict:
# Validate password strength
from ...shared.utils.password_validation import validate_password_strength
is_valid, errors = validate_password_strength(password)
if not is_valid:
error_message = 'Password does not meet requirements: ' + '; '.join(errors)
raise ValueError(error_message)
existing_user = db.query(User).filter(User.email == email).first()
if existing_user:
raise ValueError("Email already registered")
hashed_password = self.hash_password(password)
user = User(
full_name=name,
email=email,
password=hashed_password,
phone=phone,
role_id=3
)
db.add(user)
db.commit()
db.refresh(user)
user.role = db.query(Role).filter(Role.id == user.role_id).first()
tokens = self.generate_tokens(user.id)
expires_at = datetime.utcnow() + timedelta(days=7)
refresh_token = RefreshToken(
user_id=user.id,
token=tokens["refreshToken"],
expires_at=expires_at
)
db.add(refresh_token)
db.commit()
try:
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
email_html = welcome_email_template(user.full_name, user.email, client_url)
await send_email(
to=user.email,
subject="Welcome to Hotel Booking",
html=email_html
)
logger.info(f"Welcome email sent successfully to {user.email}")
except Exception as e:
logger.error(f"Failed to send welcome email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
return {
"user": self.format_user_response(user),
"token": tokens["accessToken"],
"refreshToken": tokens["refreshToken"]
}
async def login(self, db: Session, email: str, password: str, remember_me: bool = False, mfa_token: str = None) -> dict:
email = email.lower().strip() if email else ""
if not email:
raise ValueError("Invalid email or password")
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"Login attempt with non-existent email: {email}")
raise ValueError("Invalid email or password")
if not user.is_active:
logger.warning(f"Login attempt for inactive user: {email}")
raise ValueError("Account is disabled. Please contact support.")
# Check if account is locked (reset if lockout expired)
if user.locked_until:
if user.locked_until > datetime.utcnow():
remaining_minutes = int((user.locked_until - datetime.utcnow()).total_seconds() / 60)
logger.warning(f"Login attempt for locked account: {email} (locked until {user.locked_until})")
raise ValueError(f"Account is temporarily locked due to multiple failed login attempts. Please try again in {remaining_minutes} minute(s).")
else:
# Lockout expired, reset it
user.locked_until = None
user.failed_login_attempts = 0
db.commit()
user.role = db.query(Role).filter(Role.id == user.role_id).first()
password_valid = self.verify_password(password, user.password)
# Handle failed login attempt
if not password_valid:
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
max_attempts = settings.MAX_LOGIN_ATTEMPTS
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
# Lock account if max attempts reached
if user.failed_login_attempts >= max_attempts:
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
logger.warning(f"Account locked due to {user.failed_login_attempts} failed login attempts: {email}")
db.commit()
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
else:
remaining_attempts = max_attempts - user.failed_login_attempts
logger.warning(f"Login attempt with invalid password for user: {email} ({user.failed_login_attempts}/{max_attempts} failed attempts)")
db.commit()
raise ValueError(f"Invalid email or password. {remaining_attempts} attempt(s) remaining before account lockout.")
if user.mfa_enabled:
if not mfa_token:
return {
"requires_mfa": True,
"user_id": user.id
}
from ..services.mfa_service import mfa_service
is_backup_code = len(mfa_token) == 8
if not mfa_service.verify_mfa(db, user.id, mfa_token, is_backup_code):
# Increment failed attempts on MFA failure
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
max_attempts = settings.MAX_LOGIN_ATTEMPTS
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
if user.failed_login_attempts >= max_attempts:
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
logger.warning(f"Account locked due to {user.failed_login_attempts} failed attempts (MFA failure): {email}")
db.commit()
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
else:
remaining_attempts = max_attempts - user.failed_login_attempts
db.commit()
raise ValueError(f"Invalid MFA token. {remaining_attempts} attempt(s) remaining before account lockout.")
# Reset failed login attempts and unlock account on successful login
if user.failed_login_attempts > 0 or user.locked_until:
user.failed_login_attempts = 0
user.locked_until = None
db.commit()
tokens = self.generate_tokens(user.id)
expiry_days = 7 if remember_me else 1
expires_at = datetime.utcnow() + timedelta(days=expiry_days)
try:
db.query(RefreshToken).filter(
RefreshToken.user_id == user.id
).delete()
db.flush()
refresh_token = RefreshToken(
user_id=user.id,
token=tokens["refreshToken"],
expires_at=expires_at
)
db.add(refresh_token)
db.commit()
except Exception as e:
db.rollback()
logger.error(f"Error saving refresh token for user {user.id}: {str(e)}", exc_info=True)
try:
db.query(RefreshToken).filter(
RefreshToken.token == tokens["refreshToken"]
).delete()
db.flush()
refresh_token = RefreshToken(
user_id=user.id,
token=tokens["refreshToken"],
expires_at=expires_at
)
db.add(refresh_token)
db.commit()
except Exception as retry_error:
db.rollback()
logger.error(f"Retry failed for refresh token: {str(retry_error)}", exc_info=True)
raise ValueError("Failed to create session. Please try again.")
return {
"user": self.format_user_response(user),
"token": tokens["accessToken"],
"refreshToken": tokens["refreshToken"]
}
async def refresh_access_token(self, db: Session, refresh_token_str: str) -> dict:
if not refresh_token_str:
raise ValueError("Refresh token is required")
decoded = self.verify_refresh_token(refresh_token_str)
stored_token = db.query(RefreshToken).filter(
RefreshToken.token == refresh_token_str,
RefreshToken.user_id == decoded["userId"]
).first()
if not stored_token:
raise ValueError("Invalid refresh token")
if datetime.utcnow() > stored_token.expires_at:
db.delete(stored_token)
db.commit()
raise ValueError("Refresh token expired")
access_token = jwt.encode(
{"userId": decoded["userId"]},
self.jwt_secret,
algorithm="HS256"
)
return {"token": access_token}
async def logout(self, db: Session, refresh_token_str: str) -> bool:
if refresh_token_str:
db.query(RefreshToken).filter(RefreshToken.token == refresh_token_str).delete()
db.commit()
return True
async def get_profile(self, db: Session, user_id: int) -> dict:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError("User not found")
user.role = db.query(Role).filter(Role.id == user.role_id).first()
return self.format_user_response(user)
async def update_profile(
self,
db: Session,
user_id: int,
full_name: Optional[str] = None,
email: Optional[str] = None,
phone_number: Optional[str] = None,
password: Optional[str] = None,
current_password: Optional[str] = None,
currency: Optional[str] = None
) -> dict:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError("User not found")
if password:
if not current_password:
raise ValueError("Current password is required to change password")
if not self.verify_password(current_password, user.password):
raise ValueError("Current password is incorrect")
# Validate new password strength
from ...shared.utils.password_validation import validate_password_strength
is_valid, errors = validate_password_strength(password)
if not is_valid:
error_message = 'New password does not meet requirements: ' + '; '.join(errors)
raise ValueError(error_message)
user.password = self.hash_password(password)
if full_name is not None:
user.full_name = full_name
if email is not None:
existing_user = db.query(User).filter(
User.email == email,
User.id != user_id
).first()
if existing_user:
raise ValueError("Email already registered")
user.email = email
if phone_number is not None:
user.phone = phone_number
if currency is not None:
if len(currency) == 3 and currency.isalpha():
user.currency = currency.upper()
else:
raise ValueError("Invalid currency code. Must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)")
db.commit()
db.refresh(user)
user.role = db.query(Role).filter(Role.id == user.role_id).first()
return self.format_user_response(user)
def generate_reset_token(self) -> tuple:
reset_token = secrets.token_hex(32)
hashed_token = hashlib.sha256(reset_token.encode()).hexdigest()
return reset_token, hashed_token
async def forgot_password(self, db: Session, email: str) -> dict:
user = db.query(User).filter(User.email == email).first()
if not user:
return {
"success": True,
"message": "If email exists, reset link has been sent"
}
reset_token, hashed_token = self.generate_reset_token()
db.query(PasswordResetToken).filter(PasswordResetToken.user_id == user.id).delete()
expires_at = datetime.utcnow() + timedelta(hours=1)
reset_token_obj = PasswordResetToken(
user_id=user.id,
token=hashed_token,
expires_at=expires_at
)
db.add(reset_token_obj)
db.commit()
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
reset_url = f"{client_url}/reset-password/{reset_token}"
try:
logger.info(f"Attempting to send password reset email to {user.email}")
logger.info(f"Reset URL: {reset_url}")
email_html = password_reset_email_template(reset_url)
plain_text = f"Please click the following link to reset your password: {reset_url}\n\nIf you did not request this, please ignore this email."
await send_email(
to=user.email,
subject="Reset password - Hotel Booking",
html=email_html,
text=plain_text
)
logger.info(f"Password reset email sent successfully to {user.email} with reset URL: {reset_url}")
except Exception as e:
logger.error(f"Failed to send password reset email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
return {
"success": True,
"message": "Password reset link has been sent to your email"
}
async def reset_password(self, db: Session, token: str, password: str) -> dict:
if not token or not password:
raise ValueError("Token and password are required")
hashed_token = hashlib.sha256(token.encode()).hexdigest()
reset_token = db.query(PasswordResetToken).filter(
PasswordResetToken.token == hashed_token,
PasswordResetToken.expires_at > datetime.utcnow(),
PasswordResetToken.used == False
).first()
if not reset_token:
raise ValueError("Invalid or expired reset token")
user = db.query(User).filter(User.id == reset_token.user_id).first()
if not user:
raise ValueError("User not found")
if self.verify_password(password, user.password):
raise ValueError("New password must be different from the old password")
hashed_password = self.hash_password(password)
user.password = hashed_password
db.commit()
reset_token.used = True
db.commit()
try:
logger.info(f"Attempting to send password changed confirmation email to {user.email}")
email_html = password_changed_email_template(user.email)
await send_email(
to=user.email,
subject="Password Changed",
html=email_html
)
logger.info(f"Password changed confirmation email sent successfully to {user.email}")
except Exception as e:
logger.error(f"Failed to send password changed confirmation email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
return {
"success": True,
"message": "Password has been reset successfully"
}
auth_service = AuthService()

View File

@@ -0,0 +1,127 @@
import pyotp
import qrcode
import secrets
import hashlib
import json
import base64
import io
from typing import List, Optional, Dict, Tuple
from sqlalchemy.orm import Session
from ..models.user import User
import logging
logger = logging.getLogger(__name__)
class MFAService:
@staticmethod
def generate_secret() -> str:
return pyotp.random_base32()
@staticmethod
def generate_qr_code(secret: str, email: str, app_name: str='Hotel Booking') -> str:
totp_uri = pyotp.totp.TOTP(secret).provisioning_uri(name=email, issuer_name=app_name)
qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=10, border=4)
qr.add_data(totp_uri)
qr.make(fit=True)
img = qr.make_image(fill_color='black', back_color='white')
buffer = io.BytesIO()
img.save(buffer, format='PNG')
img_data = base64.b64encode(buffer.getvalue()).decode()
return f'data:image/png;base64,{img_data}'
@staticmethod
def generate_backup_codes(count: int=10) -> List[str]:
codes = []
for _ in range(count):
code = secrets.token_urlsafe(6).upper()[:8]
codes.append(code)
return codes
@staticmethod
def hash_backup_code(code: str) -> str:
return hashlib.sha256(code.encode()).hexdigest()
@staticmethod
def verify_backup_code(code: str, hashed_codes: List[str]) -> bool:
code_hash = MFAService.hash_backup_code(code)
return code_hash in hashed_codes
@staticmethod
def verify_totp(token: str, secret: str) -> bool:
try:
totp = pyotp.TOTP(secret)
return totp.verify(token, valid_window=1)
except Exception as e:
logger.error(f'Error verifying TOTP: {str(e)}')
return False
@staticmethod
def enable_mfa(db: Session, user_id: int, secret: str, verification_token: str) -> Tuple[bool, List[str]]:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
if not MFAService.verify_totp(verification_token, secret):
raise ValueError('Invalid verification token')
backup_codes = MFAService.generate_backup_codes()
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
user.mfa_enabled = True
user.mfa_secret = secret
user.mfa_backup_codes = json.dumps(hashed_codes)
db.commit()
return (True, backup_codes)
@staticmethod
def disable_mfa(db: Session, user_id: int) -> bool:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
user.mfa_enabled = False
user.mfa_secret = None
user.mfa_backup_codes = None
db.commit()
return True
@staticmethod
def verify_mfa(db: Session, user_id: int, token: str, is_backup_code: bool=False) -> bool:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
if not user.mfa_enabled or not user.mfa_secret:
raise ValueError('MFA is not enabled for this user')
if is_backup_code:
if not user.mfa_backup_codes:
return False
hashed_codes = json.loads(user.mfa_backup_codes)
if not MFAService.verify_backup_code(token, hashed_codes):
return False
code_hash = MFAService.hash_backup_code(token)
hashed_codes.remove(code_hash)
user.mfa_backup_codes = json.dumps(hashed_codes) if hashed_codes else None
db.commit()
return True
else:
return MFAService.verify_totp(token, user.mfa_secret)
@staticmethod
def regenerate_backup_codes(db: Session, user_id: int) -> List[str]:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
if not user.mfa_enabled:
raise ValueError('MFA is not enabled for this user')
backup_codes = MFAService.generate_backup_codes()
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
user.mfa_backup_codes = json.dumps(hashed_codes)
db.commit()
return backup_codes
@staticmethod
def get_mfa_status(db: Session, user_id: int) -> Dict:
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError('User not found')
backup_codes_count = 0
if user.mfa_backup_codes:
backup_codes_count = len(json.loads(user.mfa_backup_codes))
return {'mfa_enabled': user.mfa_enabled, 'backup_codes_count': backup_codes_count}
mfa_service = MFAService()

View File

@@ -0,0 +1,209 @@
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
import httpx
import secrets
from urllib.parse import urlencode
import logging
from ...security.models.security_event import OAuthProvider, OAuthToken
from ..models.user import User
from ...shared.config.logging_config import get_logger
logger = get_logger(__name__)
class OAuthService:
"""Service for handling OAuth 2.0 / OpenID Connect authentication"""
@staticmethod
def get_authorization_url(db: Session, provider_name: str, redirect_uri: str, state: Optional[str] = None) -> str:
"""Generate OAuth authorization URL"""
provider = db.query(OAuthProvider).filter(
OAuthProvider.name == provider_name,
OAuthProvider.is_active == True
).first()
if not provider:
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
if not state:
state = secrets.token_urlsafe(32)
params = {
'client_id': provider.client_id,
'redirect_uri': redirect_uri,
'response_type': 'code',
'scope': provider.scopes or 'openid profile email',
'state': state,
}
return f"{provider.authorization_url}?{urlencode(params)}"
@staticmethod
async def exchange_code_for_token(
db: Session,
provider_name: str,
code: str,
redirect_uri: str
) -> Dict[str, Any]:
"""Exchange authorization code for access token"""
provider = db.query(OAuthProvider).filter(
OAuthProvider.name == provider_name,
OAuthProvider.is_active == True
).first()
if not provider:
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
async with httpx.AsyncClient() as client:
response = await client.post(
provider.token_url,
data={
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': redirect_uri,
'client_id': provider.client_id,
'client_secret': provider.client_secret,
},
headers={'Accept': 'application/json'}
)
if response.status_code != 200:
logger.error(f"OAuth token exchange failed: {response.text}")
raise ValueError("Failed to exchange authorization code for token")
token_data = response.json()
return token_data
@staticmethod
async def get_user_info(
db: Session,
provider_name: str,
access_token: str
) -> Dict[str, Any]:
"""Get user information from OAuth provider"""
provider = db.query(OAuthProvider).filter(
OAuthProvider.name == provider_name,
OAuthProvider.is_active == True
).first()
if not provider:
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
async with httpx.AsyncClient() as client:
response = await client.get(
provider.userinfo_url,
headers={
'Authorization': f'Bearer {access_token}',
'Accept': 'application/json'
}
)
if response.status_code != 200:
logger.error(f"Failed to get user info: {response.text}")
raise ValueError("Failed to get user information from OAuth provider")
return response.json()
@staticmethod
def save_oauth_token(
db: Session,
user_id: int,
provider_id: int,
provider_user_id: str,
access_token: str,
refresh_token: Optional[str] = None,
expires_in: Optional[int] = None,
scopes: Optional[str] = None
) -> OAuthToken:
"""Save or update OAuth token for user"""
expires_at = None
if expires_in:
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# Check if token already exists
existing_token = db.query(OAuthToken).filter(
OAuthToken.user_id == user_id,
OAuthToken.provider_id == provider_id
).first()
if existing_token:
existing_token.access_token = access_token
existing_token.refresh_token = refresh_token
existing_token.expires_at = expires_at
existing_token.scopes = scopes
existing_token.updated_at = datetime.utcnow()
db.commit()
db.refresh(existing_token)
return existing_token
else:
new_token = OAuthToken(
user_id=user_id,
provider_id=provider_id,
provider_user_id=provider_user_id,
access_token=access_token,
refresh_token=refresh_token,
expires_at=expires_at,
scopes=scopes
)
db.add(new_token)
db.commit()
db.refresh(new_token)
return new_token
@staticmethod
def find_or_create_user_from_oauth(
db: Session,
provider_name: str,
user_info: Dict[str, Any]
) -> User:
"""Find existing user or create new user from OAuth user info"""
provider = db.query(OAuthProvider).filter(
OAuthProvider.name == provider_name
).first()
if not provider:
raise ValueError(f"OAuth provider '{provider_name}' not found")
# Try to find user by OAuth token
provider_user_id = user_info.get('sub') or user_info.get('id')
oauth_token = db.query(OAuthToken).filter(
OAuthToken.provider_id == provider.id,
OAuthToken.provider_user_id == str(provider_user_id)
).first()
if oauth_token:
return oauth_token.user
# Try to find user by email
email = user_info.get('email')
if email:
user = db.query(User).filter(User.email == email.lower()).first()
if user:
return user
# Create new user
from ..models.role import Role
customer_role = db.query(Role).filter(Role.name == 'customer').first()
if not customer_role:
raise ValueError("Customer role not found")
name = user_info.get('name') or user_info.get('given_name', '') + ' ' + user_info.get('family_name', '')
if not name.strip():
name = email.split('@')[0] if email else 'User'
new_user = User(
email=email.lower() if email else f"{provider_user_id}@{provider_name}.oauth",
full_name=name.strip(),
role_id=customer_role.id,
is_active=True,
email_verified=True # OAuth providers verify emails
)
db.add(new_user)
db.commit()
db.refresh(new_user)
return new_user
oauth_service = OAuthService()