update
This commit is contained in:
0
Backend/src/auth/services/__init__.py
Normal file
0
Backend/src/auth/services/__init__.py
Normal file
500
Backend/src/auth/services/auth_service.py
Normal file
500
Backend/src/auth/services/auth_service.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from jose import jwt
|
||||
import bcrypt
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
import hashlib
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from ..models.user import User
|
||||
from ..models.refresh_token import RefreshToken
|
||||
from ..models.password_reset_token import PasswordResetToken
|
||||
from ..models.role import Role
|
||||
from ...shared.utils.mailer import send_email
|
||||
from ...shared.utils.email_templates import (
|
||||
welcome_email_template,
|
||||
password_reset_email_template,
|
||||
password_changed_email_template
|
||||
)
|
||||
from ...shared.config.settings import settings
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuthService:
|
||||
def __init__(self):
|
||||
# Security: Fail fast if JWT_SECRET is not configured - never use default values
|
||||
self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET")
|
||||
if not self.jwt_secret:
|
||||
error_msg = (
|
||||
'CRITICAL: JWT_SECRET is not configured. '
|
||||
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
# In development, generate a secure secret but warn
|
||||
import secrets
|
||||
self.jwt_secret = secrets.token_urlsafe(64)
|
||||
logger.warning(
|
||||
f'JWT_SECRET not configured. Auto-generated secret for development. '
|
||||
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
|
||||
)
|
||||
|
||||
# Validate JWT secret strength
|
||||
if len(self.jwt_secret) < 32:
|
||||
error_msg = 'JWT_SECRET must be at least 32 characters long for security.'
|
||||
logger.error(error_msg)
|
||||
if settings.is_production:
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
logger.warning(error_msg)
|
||||
|
||||
# Refresh secret should be different from access secret
|
||||
self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET")
|
||||
if not self.jwt_refresh_secret:
|
||||
# Use a derived secret if not explicitly set, but different from access secret
|
||||
import hashlib
|
||||
self.jwt_refresh_secret = hashlib.sha256((self.jwt_secret + "-refresh").encode()).hexdigest()
|
||||
if not settings.is_production:
|
||||
logger.info('JWT_REFRESH_SECRET not set, using derived secret')
|
||||
|
||||
self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h")
|
||||
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")
|
||||
|
||||
def generate_tokens(self, user_id: int) -> dict:
|
||||
access_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
refresh_token = jwt.encode(
|
||||
{"userId": user_id},
|
||||
self.jwt_refresh_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
return {"accessToken": access_token, "refreshToken": refresh_token}
|
||||
|
||||
def verify_access_token(self, token: str) -> dict:
|
||||
return jwt.decode(token, self.jwt_secret, algorithms=["HS256"])
|
||||
|
||||
def verify_refresh_token(self, token: str) -> dict:
|
||||
return jwt.decode(token, self.jwt_refresh_secret, algorithms=["HS256"])
|
||||
|
||||
def hash_password(self, password: str) -> str:
|
||||
|
||||
password_bytes = password.encode('utf-8')
|
||||
|
||||
salt = bcrypt.gensalt()
|
||||
hashed = bcrypt.hashpw(password_bytes, salt)
|
||||
return hashed.decode('utf-8')
|
||||
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
try:
|
||||
password_bytes = plain_password.encode('utf-8')
|
||||
hashed_bytes = hashed_password.encode('utf-8')
|
||||
return bcrypt.checkpw(password_bytes, hashed_bytes)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def format_user_response(self, user: User) -> dict:
|
||||
return {
|
||||
"id": user.id,
|
||||
"name": user.full_name,
|
||||
"email": user.email,
|
||||
"phone": user.phone,
|
||||
"avatar": user.avatar,
|
||||
"currency": getattr(user, 'currency', 'VND'),
|
||||
"role": user.role.name if user.role else "customer",
|
||||
"createdAt": user.created_at.isoformat() if user.created_at else None,
|
||||
"updatedAt": user.updated_at.isoformat() if user.updated_at else None,
|
||||
}
|
||||
|
||||
async def register(self, db: Session, name: str, email: str, password: str, phone: Optional[str] = None) -> dict:
|
||||
# Validate password strength
|
||||
from ...shared.utils.password_validation import validate_password_strength
|
||||
is_valid, errors = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
error_message = 'Password does not meet requirements: ' + '; '.join(errors)
|
||||
raise ValueError(error_message)
|
||||
|
||||
existing_user = db.query(User).filter(User.email == email).first()
|
||||
if existing_user:
|
||||
raise ValueError("Email already registered")
|
||||
|
||||
hashed_password = self.hash_password(password)
|
||||
|
||||
user = User(
|
||||
full_name=name,
|
||||
email=email,
|
||||
password=hashed_password,
|
||||
phone=phone,
|
||||
role_id=3
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
tokens = self.generate_tokens(user.id)
|
||||
|
||||
expires_at = datetime.utcnow() + timedelta(days=7)
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
|
||||
email_html = welcome_email_template(user.full_name, user.email, client_url)
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Welcome to Hotel Booking",
|
||||
html=email_html
|
||||
)
|
||||
logger.info(f"Welcome email sent successfully to {user.email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send welcome email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"user": self.format_user_response(user),
|
||||
"token": tokens["accessToken"],
|
||||
"refreshToken": tokens["refreshToken"]
|
||||
}
|
||||
|
||||
async def login(self, db: Session, email: str, password: str, remember_me: bool = False, mfa_token: str = None) -> dict:
|
||||
|
||||
email = email.lower().strip() if email else ""
|
||||
if not email:
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
if not user:
|
||||
logger.warning(f"Login attempt with non-existent email: {email}")
|
||||
raise ValueError("Invalid email or password")
|
||||
|
||||
if not user.is_active:
|
||||
logger.warning(f"Login attempt for inactive user: {email}")
|
||||
raise ValueError("Account is disabled. Please contact support.")
|
||||
|
||||
# Check if account is locked (reset if lockout expired)
|
||||
if user.locked_until:
|
||||
if user.locked_until > datetime.utcnow():
|
||||
remaining_minutes = int((user.locked_until - datetime.utcnow()).total_seconds() / 60)
|
||||
logger.warning(f"Login attempt for locked account: {email} (locked until {user.locked_until})")
|
||||
raise ValueError(f"Account is temporarily locked due to multiple failed login attempts. Please try again in {remaining_minutes} minute(s).")
|
||||
else:
|
||||
# Lockout expired, reset it
|
||||
user.locked_until = None
|
||||
user.failed_login_attempts = 0
|
||||
db.commit()
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
password_valid = self.verify_password(password, user.password)
|
||||
|
||||
# Handle failed login attempt
|
||||
if not password_valid:
|
||||
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
|
||||
max_attempts = settings.MAX_LOGIN_ATTEMPTS
|
||||
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
|
||||
# Lock account if max attempts reached
|
||||
if user.failed_login_attempts >= max_attempts:
|
||||
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
|
||||
logger.warning(f"Account locked due to {user.failed_login_attempts} failed login attempts: {email}")
|
||||
db.commit()
|
||||
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
|
||||
else:
|
||||
remaining_attempts = max_attempts - user.failed_login_attempts
|
||||
logger.warning(f"Login attempt with invalid password for user: {email} ({user.failed_login_attempts}/{max_attempts} failed attempts)")
|
||||
db.commit()
|
||||
raise ValueError(f"Invalid email or password. {remaining_attempts} attempt(s) remaining before account lockout.")
|
||||
|
||||
if user.mfa_enabled:
|
||||
if not mfa_token:
|
||||
|
||||
return {
|
||||
"requires_mfa": True,
|
||||
"user_id": user.id
|
||||
}
|
||||
|
||||
|
||||
from ..services.mfa_service import mfa_service
|
||||
is_backup_code = len(mfa_token) == 8
|
||||
if not mfa_service.verify_mfa(db, user.id, mfa_token, is_backup_code):
|
||||
# Increment failed attempts on MFA failure
|
||||
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
|
||||
max_attempts = settings.MAX_LOGIN_ATTEMPTS
|
||||
lockout_duration = settings.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
|
||||
if user.failed_login_attempts >= max_attempts:
|
||||
user.locked_until = datetime.utcnow() + timedelta(minutes=lockout_duration)
|
||||
logger.warning(f"Account locked due to {user.failed_login_attempts} failed attempts (MFA failure): {email}")
|
||||
db.commit()
|
||||
raise ValueError(f"Account has been temporarily locked due to {max_attempts} failed login attempts. Please try again in {lockout_duration} minute(s).")
|
||||
else:
|
||||
remaining_attempts = max_attempts - user.failed_login_attempts
|
||||
db.commit()
|
||||
raise ValueError(f"Invalid MFA token. {remaining_attempts} attempt(s) remaining before account lockout.")
|
||||
|
||||
# Reset failed login attempts and unlock account on successful login
|
||||
if user.failed_login_attempts > 0 or user.locked_until:
|
||||
user.failed_login_attempts = 0
|
||||
user.locked_until = None
|
||||
db.commit()
|
||||
|
||||
tokens = self.generate_tokens(user.id)
|
||||
|
||||
expiry_days = 7 if remember_me else 1
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiry_days)
|
||||
|
||||
try:
|
||||
db.query(RefreshToken).filter(
|
||||
RefreshToken.user_id == user.id
|
||||
).delete()
|
||||
db.flush()
|
||||
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error saving refresh token for user {user.id}: {str(e)}", exc_info=True)
|
||||
try:
|
||||
db.query(RefreshToken).filter(
|
||||
RefreshToken.token == tokens["refreshToken"]
|
||||
).delete()
|
||||
db.flush()
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user.id,
|
||||
token=tokens["refreshToken"],
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(refresh_token)
|
||||
db.commit()
|
||||
except Exception as retry_error:
|
||||
db.rollback()
|
||||
logger.error(f"Retry failed for refresh token: {str(retry_error)}", exc_info=True)
|
||||
raise ValueError("Failed to create session. Please try again.")
|
||||
|
||||
return {
|
||||
"user": self.format_user_response(user),
|
||||
"token": tokens["accessToken"],
|
||||
"refreshToken": tokens["refreshToken"]
|
||||
}
|
||||
|
||||
async def refresh_access_token(self, db: Session, refresh_token_str: str) -> dict:
|
||||
if not refresh_token_str:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
decoded = self.verify_refresh_token(refresh_token_str)
|
||||
|
||||
stored_token = db.query(RefreshToken).filter(
|
||||
RefreshToken.token == refresh_token_str,
|
||||
RefreshToken.user_id == decoded["userId"]
|
||||
).first()
|
||||
|
||||
if not stored_token:
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
if datetime.utcnow() > stored_token.expires_at:
|
||||
db.delete(stored_token)
|
||||
db.commit()
|
||||
raise ValueError("Refresh token expired")
|
||||
|
||||
access_token = jwt.encode(
|
||||
{"userId": decoded["userId"]},
|
||||
self.jwt_secret,
|
||||
algorithm="HS256"
|
||||
)
|
||||
|
||||
return {"token": access_token}
|
||||
|
||||
async def logout(self, db: Session, refresh_token_str: str) -> bool:
|
||||
if refresh_token_str:
|
||||
db.query(RefreshToken).filter(RefreshToken.token == refresh_token_str).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
async def get_profile(self, db: Session, user_id: int) -> dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
return self.format_user_response(user)
|
||||
|
||||
async def update_profile(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: int,
|
||||
full_name: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
phone_number: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
current_password: Optional[str] = None,
|
||||
currency: Optional[str] = None
|
||||
) -> dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if password:
|
||||
if not current_password:
|
||||
raise ValueError("Current password is required to change password")
|
||||
if not self.verify_password(current_password, user.password):
|
||||
raise ValueError("Current password is incorrect")
|
||||
|
||||
# Validate new password strength
|
||||
from ...shared.utils.password_validation import validate_password_strength
|
||||
is_valid, errors = validate_password_strength(password)
|
||||
if not is_valid:
|
||||
error_message = 'New password does not meet requirements: ' + '; '.join(errors)
|
||||
raise ValueError(error_message)
|
||||
|
||||
user.password = self.hash_password(password)
|
||||
|
||||
if full_name is not None:
|
||||
user.full_name = full_name
|
||||
if email is not None:
|
||||
|
||||
existing_user = db.query(User).filter(
|
||||
User.email == email,
|
||||
User.id != user_id
|
||||
).first()
|
||||
if existing_user:
|
||||
raise ValueError("Email already registered")
|
||||
user.email = email
|
||||
if phone_number is not None:
|
||||
user.phone = phone_number
|
||||
if currency is not None:
|
||||
|
||||
if len(currency) == 3 and currency.isalpha():
|
||||
user.currency = currency.upper()
|
||||
else:
|
||||
raise ValueError("Invalid currency code. Must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)")
|
||||
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
user.role = db.query(Role).filter(Role.id == user.role_id).first()
|
||||
|
||||
return self.format_user_response(user)
|
||||
|
||||
def generate_reset_token(self) -> tuple:
|
||||
reset_token = secrets.token_hex(32)
|
||||
hashed_token = hashlib.sha256(reset_token.encode()).hexdigest()
|
||||
return reset_token, hashed_token
|
||||
|
||||
async def forgot_password(self, db: Session, email: str) -> dict:
|
||||
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "If email exists, reset link has been sent"
|
||||
}
|
||||
|
||||
reset_token, hashed_token = self.generate_reset_token()
|
||||
|
||||
db.query(PasswordResetToken).filter(PasswordResetToken.user_id == user.id).delete()
|
||||
|
||||
expires_at = datetime.utcnow() + timedelta(hours=1)
|
||||
reset_token_obj = PasswordResetToken(
|
||||
user_id=user.id,
|
||||
token=hashed_token,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(reset_token_obj)
|
||||
db.commit()
|
||||
|
||||
client_url = settings.CLIENT_URL or os.getenv("CLIENT_URL", "http://localhost:5173")
|
||||
reset_url = f"{client_url}/reset-password/{reset_token}"
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to send password reset email to {user.email}")
|
||||
logger.info(f"Reset URL: {reset_url}")
|
||||
email_html = password_reset_email_template(reset_url)
|
||||
|
||||
plain_text = f"Please click the following link to reset your password: {reset_url}\n\nIf you did not request this, please ignore this email."
|
||||
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Reset password - Hotel Booking",
|
||||
html=email_html,
|
||||
text=plain_text
|
||||
)
|
||||
logger.info(f"Password reset email sent successfully to {user.email} with reset URL: {reset_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Password reset link has been sent to your email"
|
||||
}
|
||||
|
||||
async def reset_password(self, db: Session, token: str, password: str) -> dict:
|
||||
if not token or not password:
|
||||
raise ValueError("Token and password are required")
|
||||
|
||||
hashed_token = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
reset_token = db.query(PasswordResetToken).filter(
|
||||
PasswordResetToken.token == hashed_token,
|
||||
PasswordResetToken.expires_at > datetime.utcnow(),
|
||||
PasswordResetToken.used == False
|
||||
).first()
|
||||
|
||||
if not reset_token:
|
||||
raise ValueError("Invalid or expired reset token")
|
||||
|
||||
user = db.query(User).filter(User.id == reset_token.user_id).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if self.verify_password(password, user.password):
|
||||
raise ValueError("New password must be different from the old password")
|
||||
|
||||
hashed_password = self.hash_password(password)
|
||||
|
||||
user.password = hashed_password
|
||||
db.commit()
|
||||
|
||||
reset_token.used = True
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to send password changed confirmation email to {user.email}")
|
||||
email_html = password_changed_email_template(user.email)
|
||||
await send_email(
|
||||
to=user.email,
|
||||
subject="Password Changed",
|
||||
html=email_html
|
||||
)
|
||||
logger.info(f"Password changed confirmation email sent successfully to {user.email}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password changed confirmation email to {user.email}: {type(e).__name__}: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Password has been reset successfully"
|
||||
}
|
||||
|
||||
auth_service = AuthService()
|
||||
|
||||
127
Backend/src/auth/services/mfa_service.py
Normal file
127
Backend/src/auth/services/mfa_service.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import pyotp
|
||||
import qrcode
|
||||
import secrets
|
||||
import hashlib
|
||||
import json
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.user import User
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MFAService:
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
return pyotp.random_base32()
|
||||
|
||||
@staticmethod
|
||||
def generate_qr_code(secret: str, email: str, app_name: str='Hotel Booking') -> str:
|
||||
totp_uri = pyotp.totp.TOTP(secret).provisioning_uri(name=email, issuer_name=app_name)
|
||||
qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=10, border=4)
|
||||
qr.add_data(totp_uri)
|
||||
qr.make(fit=True)
|
||||
img = qr.make_image(fill_color='black', back_color='white')
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_data = base64.b64encode(buffer.getvalue()).decode()
|
||||
return f'data:image/png;base64,{img_data}'
|
||||
|
||||
@staticmethod
|
||||
def generate_backup_codes(count: int=10) -> List[str]:
|
||||
codes = []
|
||||
for _ in range(count):
|
||||
code = secrets.token_urlsafe(6).upper()[:8]
|
||||
codes.append(code)
|
||||
return codes
|
||||
|
||||
@staticmethod
|
||||
def hash_backup_code(code: str) -> str:
|
||||
return hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def verify_backup_code(code: str, hashed_codes: List[str]) -> bool:
|
||||
code_hash = MFAService.hash_backup_code(code)
|
||||
return code_hash in hashed_codes
|
||||
|
||||
@staticmethod
|
||||
def verify_totp(token: str, secret: str) -> bool:
|
||||
try:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(token, valid_window=1)
|
||||
except Exception as e:
|
||||
logger.error(f'Error verifying TOTP: {str(e)}')
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def enable_mfa(db: Session, user_id: int, secret: str, verification_token: str) -> Tuple[bool, List[str]]:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not MFAService.verify_totp(verification_token, secret):
|
||||
raise ValueError('Invalid verification token')
|
||||
backup_codes = MFAService.generate_backup_codes()
|
||||
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
|
||||
user.mfa_enabled = True
|
||||
user.mfa_secret = secret
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes)
|
||||
db.commit()
|
||||
return (True, backup_codes)
|
||||
|
||||
@staticmethod
|
||||
def disable_mfa(db: Session, user_id: int) -> bool:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
user.mfa_enabled = False
|
||||
user.mfa_secret = None
|
||||
user.mfa_backup_codes = None
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def verify_mfa(db: Session, user_id: int, token: str, is_backup_code: bool=False) -> bool:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not user.mfa_enabled or not user.mfa_secret:
|
||||
raise ValueError('MFA is not enabled for this user')
|
||||
if is_backup_code:
|
||||
if not user.mfa_backup_codes:
|
||||
return False
|
||||
hashed_codes = json.loads(user.mfa_backup_codes)
|
||||
if not MFAService.verify_backup_code(token, hashed_codes):
|
||||
return False
|
||||
code_hash = MFAService.hash_backup_code(token)
|
||||
hashed_codes.remove(code_hash)
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes) if hashed_codes else None
|
||||
db.commit()
|
||||
return True
|
||||
else:
|
||||
return MFAService.verify_totp(token, user.mfa_secret)
|
||||
|
||||
@staticmethod
|
||||
def regenerate_backup_codes(db: Session, user_id: int) -> List[str]:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
if not user.mfa_enabled:
|
||||
raise ValueError('MFA is not enabled for this user')
|
||||
backup_codes = MFAService.generate_backup_codes()
|
||||
hashed_codes = [MFAService.hash_backup_code(code) for code in backup_codes]
|
||||
user.mfa_backup_codes = json.dumps(hashed_codes)
|
||||
db.commit()
|
||||
return backup_codes
|
||||
|
||||
@staticmethod
|
||||
def get_mfa_status(db: Session, user_id: int) -> Dict:
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError('User not found')
|
||||
backup_codes_count = 0
|
||||
if user.mfa_backup_codes:
|
||||
backup_codes_count = len(json.loads(user.mfa_backup_codes))
|
||||
return {'mfa_enabled': user.mfa_enabled, 'backup_codes_count': backup_codes_count}
|
||||
mfa_service = MFAService()
|
||||
209
Backend/src/auth/services/oauth_service.py
Normal file
209
Backend/src/auth/services/oauth_service.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
import httpx
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
|
||||
from ...security.models.security_event import OAuthProvider, OAuthToken
|
||||
from ..models.user import User
|
||||
from ...shared.config.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth 2.0 / OpenID Connect authentication"""
|
||||
|
||||
@staticmethod
|
||||
def get_authorization_url(db: Session, provider_name: str, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||
"""Generate OAuth authorization URL"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
if not state:
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
params = {
|
||||
'client_id': provider.client_id,
|
||||
'redirect_uri': redirect_uri,
|
||||
'response_type': 'code',
|
||||
'scope': provider.scopes or 'openid profile email',
|
||||
'state': state,
|
||||
}
|
||||
|
||||
return f"{provider.authorization_url}?{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
async def exchange_code_for_token(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
code: str,
|
||||
redirect_uri: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
provider.token_url,
|
||||
data={
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
'redirect_uri': redirect_uri,
|
||||
'client_id': provider.client_id,
|
||||
'client_secret': provider.client_secret,
|
||||
},
|
||||
headers={'Accept': 'application/json'}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth token exchange failed: {response.text}")
|
||||
raise ValueError("Failed to exchange authorization code for token")
|
||||
|
||||
token_data = response.json()
|
||||
return token_data
|
||||
|
||||
@staticmethod
|
||||
async def get_user_info(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
access_token: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Get user information from OAuth provider"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name,
|
||||
OAuthProvider.is_active == True
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found or inactive")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
provider.userinfo_url,
|
||||
headers={
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to get user info: {response.text}")
|
||||
raise ValueError("Failed to get user information from OAuth provider")
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def save_oauth_token(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
provider_id: int,
|
||||
provider_user_id: str,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str] = None,
|
||||
expires_in: Optional[int] = None,
|
||||
scopes: Optional[str] = None
|
||||
) -> OAuthToken:
|
||||
"""Save or update OAuth token for user"""
|
||||
expires_at = None
|
||||
if expires_in:
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# Check if token already exists
|
||||
existing_token = db.query(OAuthToken).filter(
|
||||
OAuthToken.user_id == user_id,
|
||||
OAuthToken.provider_id == provider_id
|
||||
).first()
|
||||
|
||||
if existing_token:
|
||||
existing_token.access_token = access_token
|
||||
existing_token.refresh_token = refresh_token
|
||||
existing_token.expires_at = expires_at
|
||||
existing_token.scopes = scopes
|
||||
existing_token.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(existing_token)
|
||||
return existing_token
|
||||
else:
|
||||
new_token = OAuthToken(
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
provider_user_id=provider_user_id,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
scopes=scopes
|
||||
)
|
||||
db.add(new_token)
|
||||
db.commit()
|
||||
db.refresh(new_token)
|
||||
return new_token
|
||||
|
||||
@staticmethod
|
||||
def find_or_create_user_from_oauth(
|
||||
db: Session,
|
||||
provider_name: str,
|
||||
user_info: Dict[str, Any]
|
||||
) -> User:
|
||||
"""Find existing user or create new user from OAuth user info"""
|
||||
provider = db.query(OAuthProvider).filter(
|
||||
OAuthProvider.name == provider_name
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"OAuth provider '{provider_name}' not found")
|
||||
|
||||
# Try to find user by OAuth token
|
||||
provider_user_id = user_info.get('sub') or user_info.get('id')
|
||||
oauth_token = db.query(OAuthToken).filter(
|
||||
OAuthToken.provider_id == provider.id,
|
||||
OAuthToken.provider_user_id == str(provider_user_id)
|
||||
).first()
|
||||
|
||||
if oauth_token:
|
||||
return oauth_token.user
|
||||
|
||||
# Try to find user by email
|
||||
email = user_info.get('email')
|
||||
if email:
|
||||
user = db.query(User).filter(User.email == email.lower()).first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Create new user
|
||||
from ..models.role import Role
|
||||
customer_role = db.query(Role).filter(Role.name == 'customer').first()
|
||||
if not customer_role:
|
||||
raise ValueError("Customer role not found")
|
||||
|
||||
name = user_info.get('name') or user_info.get('given_name', '') + ' ' + user_info.get('family_name', '')
|
||||
if not name.strip():
|
||||
name = email.split('@')[0] if email else 'User'
|
||||
|
||||
new_user = User(
|
||||
email=email.lower() if email else f"{provider_user_id}@{provider_name}.oauth",
|
||||
full_name=name.strip(),
|
||||
role_id=customer_role.id,
|
||||
is_active=True,
|
||||
email_verified=True # OAuth providers verify emails
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
|
||||
return new_user
|
||||
|
||||
oauth_service = OAuthService()
|
||||
|
||||
Reference in New Issue
Block a user