409 lines
14 KiB
Python
409 lines
14 KiB
Python
import hashlib
import logging
import os
import re
import secrets
from datetime import datetime, timedelta
from typing import Optional

import bcrypt
from jose import jwt
from jose.exceptions import ExpiredSignatureError
from sqlalchemy.orm import Session

from ..models.user import User
from ..models.refresh_token import RefreshToken
from ..models.password_reset_token import PasswordResetToken
from ..models.role import Role
from ..utils.mailer import send_email
from ..utils.email_templates import (
    welcome_email_template,
    password_reset_email_template,
    password_changed_email_template,
)
from ..config.settings import settings
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
|
|
class AuthService:
    """Authentication and account-management service.

    Issues and verifies HS256 JWTs carrying a ``userId`` claim, hashes and
    checks passwords with bcrypt, and implements registration, login (with
    optional MFA), token refresh/logout, profile updates and the
    password-reset flow. Refresh tokens are additionally persisted in the
    ``RefreshToken`` table so they can be revoked server-side.
    """

    # Fallback signing secret used only when neither settings nor the
    # environment provide JWT_SECRET. Never acceptable outside development.
    _DEV_JWT_SECRET = "dev-secret-key-change-in-production-12345"

    # Seconds per unit for "<int><unit>" lifetimes such as "15m", "1h", "7d"
    # (the format of JWT_EXPIRES_IN / JWT_REFRESH_EXPIRES_IN).
    _DURATION_UNITS = {"s": 1, "m": 60, "h": 3600, "d": 86400, "w": 604800}

    def __init__(self):
        """Load JWT secrets and token lifetimes from settings / environment."""
        self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET", self._DEV_JWT_SECRET)
        if self.jwt_secret == self._DEV_JWT_SECRET:
            logger.warning("JWT_SECRET is not configured; falling back to the insecure development secret")
        # Distinct secret, so access tokens do not verify as refresh tokens and vice versa.
        self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET") or (self.jwt_secret + "-refresh")
        # Lifetimes such as "1h" / "7d"; converted by _parse_duration when tokens are issued.
        self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h")
        self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _parse_duration(cls, value: Optional[str], default: timedelta) -> timedelta:
        """Convert ``"900"``, ``"15m"``, ``"1h"``, ``"7d"`` or ``"2w"`` to a timedelta.

        A bare integer is interpreted as seconds. Malformed or zero values fall
        back to *default* (with a warning) so that a bad environment variable
        cannot make token issuing fail.
        """
        match = re.fullmatch(r"\s*([0-9]+)\s*([smhdw]?)\s*", str(value or "").lower())
        if match is None or int(match.group(1)) == 0:
            logger.warning(f"Invalid token lifetime {value!r}; using default of {default}")
            return default
        amount, unit = int(match.group(1)), match.group(2) or "s"
        return timedelta(seconds=amount * cls._DURATION_UNITS[unit])

    @staticmethod
    def _normalize_email(email: Optional[str]) -> str:
        """Return *email* lower-cased and stripped ("" for None/empty).

        This is the form ``login`` looks users up by, so every code path that
        stores or searches an email must use it.
        """
        return email.lower().strip() if email else ""

    @staticmethod
    def _client_url() -> str:
        """Base URL of the front-end used to build links in outgoing emails."""
        return getattr(settings, 'CLIENT_URL', None) or os.getenv("CLIENT_URL", "http://localhost:5173")

    def _create_access_token(self, user_id: int, issued_at: Optional[datetime] = None) -> str:
        """Sign an access token for *user_id* that expires after JWT_EXPIRES_IN."""
        now = issued_at or datetime.utcnow()
        lifetime = self._parse_duration(self.jwt_expires_in, timedelta(hours=1))
        # python-jose converts datetime values of iat/exp to NumericDate (UTC seconds).
        return jwt.encode(
            {"userId": user_id, "iat": now, "exp": now + lifetime},
            self.jwt_secret,
            algorithm="HS256",
        )

    def _replace_refresh_tokens(self, db: Session, user_id: int, token: str, expires_at: datetime) -> None:
        """Make *token* the user's only stored refresh token.

        All previously stored refresh tokens of the user are deleted first, so
        a new login revokes earlier sessions. On failure the operation is
        retried once.

        Raises:
            ValueError: if both attempts fail.
        """
        try:
            db.query(RefreshToken).filter(RefreshToken.user_id == user_id).delete()
            db.flush()
            db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires_at))
            db.commit()
        except Exception as e:
            db.rollback()
            logger.error(f"Error saving refresh token for user {user_id}: {str(e)}", exc_info=True)
            try:
                db.query(RefreshToken).filter(RefreshToken.token == token).delete()
                db.flush()
                db.add(RefreshToken(user_id=user_id, token=token, expires_at=expires_at))
                db.commit()
            except Exception as retry_error:
                db.rollback()
                logger.error(f"Retry failed for refresh token: {str(retry_error)}", exc_info=True)
                raise ValueError("Failed to create session. Please try again.") from retry_error

    # ------------------------------------------------------------------ tokens

    def generate_tokens(self, user_id: int, refresh_expires_in: Optional[timedelta] = None) -> dict:
        """Create an access/refresh JWT pair for *user_id*.

        Args:
            user_id: id placed in the ``userId`` claim of both tokens.
            refresh_expires_in: lifetime of the refresh JWT. Defaults to
                JWT_REFRESH_EXPIRES_IN. Callers that persist the token should
                pass the same lifetime they store as ``expires_at``.

        Returns:
            ``{"accessToken": str, "refreshToken": str}``.
        """
        now = datetime.utcnow()
        refresh_lifetime = refresh_expires_in or self._parse_duration(
            self.jwt_refresh_expires_in, timedelta(days=7)
        )
        refresh_token = jwt.encode(
            {
                "userId": user_id,
                "iat": now,
                "exp": now + refresh_lifetime,
                # Random id keeps every refresh token unique. Otherwise two
                # tokens for one user signed in the same second would be
                # byte-identical and collide in the RefreshToken table.
                "jti": secrets.token_hex(16),
            },
            self.jwt_refresh_secret,
            algorithm="HS256",
        )
        return {"accessToken": self._create_access_token(user_id, now), "refreshToken": refresh_token}

    def verify_access_token(self, token: str) -> dict:
        """Decode an access token and return its claims.

        ``algorithms`` is pinned to HS256 so the token header cannot pick the
        algorithm.

        Raises:
            jose.JWTError: invalid signature or malformed token; for tokens
                carrying ``exp``, ``ExpiredSignatureError`` (a subclass) once expired.
        """
        return jwt.decode(token, self.jwt_secret, algorithms=["HS256"])

    def verify_refresh_token(self, token: str) -> dict:
        """Decode a refresh token (signed with the refresh secret) and return its claims.

        Raises:
            jose.JWTError: as for :meth:`verify_access_token`.
        """
        return jwt.decode(token, self.jwt_refresh_secret, algorithms=["HS256"])

    # ---------------------------------------------------------------- passwords

    def hash_password(self, password: str) -> str:
        """Return a bcrypt hash (freshly salted) of *password* as a UTF-8 string."""
        password_bytes = password.encode('utf-8')
        salt = bcrypt.gensalt()
        hashed = bcrypt.hashpw(password_bytes, salt)
        return hashed.decode('utf-8')

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Return True if *plain_password* matches the bcrypt *hashed_password*.

        Any error (e.g. a missing or malformed stored hash) is deliberately
        treated as a mismatch rather than propagated.
        """
        try:
            password_bytes = plain_password.encode('utf-8')
            hashed_bytes = hashed_password.encode('utf-8')
            return bcrypt.checkpw(password_bytes, hashed_bytes)
        except Exception:
            return False

    def format_user_response(self, user: User) -> dict:
        """Serialize *user* into the public API shape (camelCase timestamps)."""
        return {
            "id": user.id,
            "name": user.full_name,
            "email": user.email,
            "phone": user.phone,
            "avatar": user.avatar,
            "currency": getattr(user, 'currency', 'VND'),
            "role": user.role.name if user.role else "customer",
            "createdAt": user.created_at.isoformat() if user.created_at else None,
            "updatedAt": user.updated_at.isoformat() if user.updated_at else None,
        }

    # --------------------------------------------------------- account lifecycle

    async def register(self, db: Session, name: str, email: str, password: str, phone: Optional[str] = None) -> dict:
        """Create a user, persist a 7-day refresh token and send a welcome email.

        The email is normalized exactly like in :meth:`login`. Otherwise an
        address registered with upper-case letters could never log in. A
        failure to send the welcome email is logged and does not fail
        registration.

        Returns:
            ``{"user": dict, "token": str, "refreshToken": str}``.

        Raises:
            ValueError: empty email, or email already registered.
        """
        email = self._normalize_email(email)
        if not email:
            raise ValueError("Email is required")

        if db.query(User).filter(User.email == email).first():
            raise ValueError("Email already registered")

        user = User(
            full_name=name,
            email=email,
            password=self.hash_password(password),
            phone=phone,
            # NOTE(review): 3 is presumably the default "customer" role — confirm against seed data.
            role_id=3,
        )
        db.add(user)
        db.commit()
        db.refresh(user)

        user.role = db.query(Role).filter(Role.id == user.role_id).first()

        refresh_lifetime = timedelta(days=7)
        tokens = self.generate_tokens(user.id, refresh_expires_in=refresh_lifetime)
        db.add(RefreshToken(
            user_id=user.id,
            token=tokens["refreshToken"],
            expires_at=datetime.utcnow() + refresh_lifetime,
        ))
        db.commit()

        try:
            email_html = welcome_email_template(user.full_name, user.email, self._client_url())
            await send_email(
                to=user.email,
                subject="Welcome to Hotel Booking",
                html=email_html
            )
            logger.info(f"Welcome email sent successfully to {user.email}")
        except Exception as e:
            logger.error(f"Failed to send welcome email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)

        return {
            "user": self.format_user_response(user),
            "token": tokens["accessToken"],
            "refreshToken": tokens["refreshToken"]
        }

    async def login(self, db: Session, email: str, password: str, remember_me: bool = False, mfa_token: Optional[str] = None) -> dict:
        """Authenticate a user and start a new session.

        Args:
            remember_me: refresh token lives 7 days instead of 1.
            mfa_token: required when the user has MFA enabled. An 8-character
                value is passed to the MFA service as a backup code.

        Returns:
            ``{"requires_mfa": True, "user_id": int}`` when MFA is enabled and
            no *mfa_token* was given; otherwise
            ``{"user": dict, "token": str, "refreshToken": str}``.

        Raises:
            ValueError: bad credentials, disabled account, invalid MFA token,
                or failure to persist the session.
        """
        email = self._normalize_email(email)
        if not email:
            raise ValueError("Invalid email or password")

        user = db.query(User).filter(User.email == email).first()
        if not user:
            logger.warning(f"Login attempt with non-existent email: {email}")
            raise ValueError("Invalid email or password")

        if not self.verify_password(password, user.password):
            logger.warning(f"Login attempt with invalid password for user: {email}")
            raise ValueError("Invalid email or password")

        # Checked only after the password, so the distinct "disabled" message
        # cannot be used to probe which email addresses have accounts.
        if not user.is_active:
            logger.warning(f"Login attempt for inactive user: {email}")
            raise ValueError("Account is disabled. Please contact support.")

        user.role = db.query(Role).filter(Role.id == user.role_id).first()

        if user.mfa_enabled:
            if not mfa_token:
                # Password accepted; the client must resubmit with an MFA code.
                return {
                    "requires_mfa": True,
                    "user_id": user.id
                }
            # Function-level import, presumably to avoid a circular import — confirm.
            from ..services.mfa_service import mfa_service
            is_backup_code = len(mfa_token) == 8
            if not mfa_service.verify_mfa(db, user.id, mfa_token, is_backup_code):
                raise ValueError("Invalid MFA token")

        session_lifetime = timedelta(days=7 if remember_me else 1)
        tokens = self.generate_tokens(user.id, refresh_expires_in=session_lifetime)
        self._replace_refresh_tokens(
            db, user.id, tokens["refreshToken"], datetime.utcnow() + session_lifetime
        )

        return {
            "user": self.format_user_response(user),
            "token": tokens["accessToken"],
            "refreshToken": tokens["refreshToken"]
        }

    async def refresh_access_token(self, db: Session, refresh_token_str: str) -> dict:
        """Exchange a stored, unexpired refresh token for a new access token.

        Returns:
            ``{"token": str}``.

        Raises:
            ValueError: missing, unknown or expired refresh token.
            jose.JWTError: signature/format errors from :meth:`verify_refresh_token`.
        """
        if not refresh_token_str:
            raise ValueError("Refresh token is required")

        try:
            decoded = self.verify_refresh_token(refresh_token_str)
        except ExpiredSignatureError:
            # Refresh JWTs carry ``exp``. Report expiry the same way as the
            # database-side check below, and drop the stale row.
            db.query(RefreshToken).filter(RefreshToken.token == refresh_token_str).delete()
            db.commit()
            raise ValueError("Refresh token expired") from None

        user_id = decoded.get("userId")
        if user_id is None:
            raise ValueError("Invalid refresh token")

        stored_token = db.query(RefreshToken).filter(
            RefreshToken.token == refresh_token_str,
            RefreshToken.user_id == user_id
        ).first()
        if not stored_token:
            raise ValueError("Invalid refresh token")

        if datetime.utcnow() > stored_token.expires_at:
            db.delete(stored_token)
            db.commit()
            raise ValueError("Refresh token expired")

        return {"token": self._create_access_token(user_id)}

    async def logout(self, db: Session, refresh_token_str: str) -> bool:
        """Revoke the given refresh token (no-op when empty); always returns True."""
        if refresh_token_str:
            db.query(RefreshToken).filter(RefreshToken.token == refresh_token_str).delete()
            db.commit()
        return True

    async def get_profile(self, db: Session, user_id: int) -> dict:
        """Return the formatted profile of *user_id*.

        Raises:
            ValueError: user not found.
        """
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError("User not found")

        user.role = db.query(Role).filter(Role.id == user.role_id).first()
        return self.format_user_response(user)

    async def update_profile(
        self,
        db: Session,
        user_id: int,
        full_name: Optional[str] = None,
        email: Optional[str] = None,
        phone_number: Optional[str] = None,
        password: Optional[str] = None,
        current_password: Optional[str] = None,
        currency: Optional[str] = None
    ) -> dict:
        """Update the given profile fields; ``None`` leaves a field unchanged.

        All inputs are validated before any attribute is modified, so a
        rejected request leaves no half-applied changes in the session.

        Raises:
            ValueError: user not found; missing or incorrect current password;
                empty or already-registered email; invalid currency code.
        """
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError("User not found")

        # ---- validation (no mutations yet) ----
        if password:
            if not current_password:
                raise ValueError("Current password is required to change password")
            if not self.verify_password(current_password, user.password):
                raise ValueError("Current password is incorrect")

        if email is not None:
            email = self._normalize_email(email)
            if not email:
                raise ValueError("Email is required")
            existing_user = db.query(User).filter(
                User.email == email,
                User.id != user_id
            ).first()
            if existing_user:
                raise ValueError("Email already registered")

        if currency is not None:
            # isascii(): isalpha() alone accepts non-Latin letters such as "ÄÖÜ".
            if not (len(currency) == 3 and currency.isascii() and currency.isalpha()):
                raise ValueError("Invalid currency code. Must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)")
            currency = currency.upper()

        # ---- apply ----
        if password:
            user.password = self.hash_password(password)
        if full_name is not None:
            user.full_name = full_name
        if email is not None:
            user.email = email
        if phone_number is not None:
            user.phone = phone_number
        if currency is not None:
            user.currency = currency

        db.commit()
        db.refresh(user)

        user.role = db.query(Role).filter(Role.id == user.role_id).first()
        return self.format_user_response(user)

    # ------------------------------------------------------------ password reset

    def generate_reset_token(self) -> tuple:
        """Return ``(raw_token, sha256_hex_of_token)``.

        The raw token goes into the emailed link. Only its hash is stored.
        """
        reset_token = secrets.token_hex(32)
        hashed_token = hashlib.sha256(reset_token.encode()).hexdigest()
        return reset_token, hashed_token

    async def forgot_password(self, db: Session, email: str) -> dict:
        """Issue a 1-hour password-reset token and email the reset link.

        The same response is returned whether or not the email belongs to an
        account, so the endpoint cannot be used to enumerate users. The reset
        URL contains a live credential and is therefore never logged. Email
        delivery failures are logged and swallowed.
        """
        response = {
            "success": True,
            "message": "If email exists, reset link has been sent"
        }

        email = self._normalize_email(email)
        user = db.query(User).filter(User.email == email).first() if email else None
        if not user:
            return response

        reset_token, hashed_token = self.generate_reset_token()

        # Only the most recent reset request stays valid.
        db.query(PasswordResetToken).filter(PasswordResetToken.user_id == user.id).delete()
        db.add(PasswordResetToken(
            user_id=user.id,
            token=hashed_token,
            expires_at=datetime.utcnow() + timedelta(hours=1)
        ))
        db.commit()

        reset_url = f"{self._client_url()}/reset-password/{reset_token}"

        try:
            logger.info(f"Attempting to send password reset email to {user.email}")
            email_html = password_reset_email_template(reset_url)
            # Plain-text alternative for clients that do not render HTML.
            plain_text = (
                "Password Reset Request\n\n"
                "You requested to reset the password of your Hotel Booking account.\n"
                "Open the link below to choose a new password:\n\n"
                f"{reset_url}\n\n"
                "This link will expire in 1 hour.\n\n"
                "If you did not request a password reset, you can safely ignore this email."
            )
            await send_email(
                to=user.email,
                subject="Reset password - Hotel Booking",
                html=email_html,
                text=plain_text
            )
            logger.info(f"Password reset email sent successfully to {user.email}")
        except Exception as e:
            logger.error(f"Failed to send password reset email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)

        return response

    async def reset_password(self, db: Session, token: str, password: str) -> dict:
        """Set a new password using a valid, unused, unexpired reset token.

        The password change, consumption of the token and revocation of the
        user's refresh tokens happen in one commit. A failure therefore cannot
        leave the token reusable, and sessions that may belong to whoever
        compromised the account are ended.

        Raises:
            ValueError: missing inputs, invalid/expired token, user not found,
                or new password equal to the old one.
        """
        if not token or not password:
            raise ValueError("Token and password are required")

        hashed_token = hashlib.sha256(token.encode()).hexdigest()

        reset_token = db.query(PasswordResetToken).filter(
            PasswordResetToken.token == hashed_token,
            PasswordResetToken.expires_at > datetime.utcnow(),
            PasswordResetToken.used == False  # noqa: E712 - SQLAlchemy column expression
        ).first()
        if not reset_token:
            raise ValueError("Invalid or expired reset token")

        user = db.query(User).filter(User.id == reset_token.user_id).first()
        if not user:
            raise ValueError("User not found")

        if self.verify_password(password, user.password):
            raise ValueError("New password must be different from the old password")

        user.password = self.hash_password(password)
        reset_token.used = True
        db.query(RefreshToken).filter(RefreshToken.user_id == user.id).delete()
        db.commit()

        try:
            logger.info(f"Attempting to send password changed confirmation email to {user.email}")
            email_html = password_changed_email_template(user.email)
            await send_email(
                to=user.email,
                subject="Password Changed",
                html=email_html
            )
            logger.info(f"Password changed confirmation email sent successfully to {user.email}")
        except Exception as e:
            logger.error(f"Failed to send password changed confirmation email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)

        return {
            "success": True,
            "message": "Password has been reset successfully"
        }
|
|
|
|
# Shared module-level instance, created at import time. AuthService.__init__
# reads the JWT settings/environment variables once, here, so later changes
# to those variables are not picked up by this instance.
auth_service = AuthService()
|
|
|