This commit is contained in:
Iliyan Angelov
2025-11-29 01:21:11 +02:00
parent cf97df9aeb
commit fb16d7ae34
2856 changed files with 5558 additions and 248 deletions

View File

@@ -0,0 +1,57 @@
"""add blog posts table
Revision ID: add_blog_posts
Revises: fff4b67466b3
Create Date: 2024-01-01 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = 'add_blog_posts'
down_revision = 'fff4b67466b3' # Update this to the latest migration
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
'blog_posts',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('title', sa.String(length=500), nullable=False),
sa.Column('slug', sa.String(length=500), nullable=False),
sa.Column('excerpt', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('featured_image', sa.String(length=1000), nullable=True),
sa.Column('author_id', sa.Integer(), nullable=False),
sa.Column('published_at', sa.DateTime(), nullable=True),
sa.Column('is_published', sa.Boolean(), nullable=False, server_default='0'),
sa.Column('tags', sa.Text(), nullable=True),
sa.Column('meta_title', sa.String(length=500), nullable=True),
sa.Column('meta_description', sa.Text(), nullable=True),
sa.Column('meta_keywords', sa.String(length=1000), nullable=True),
sa.Column('views_count', sa.Integer(), nullable=False, server_default='0'),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['author_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_blog_posts_id'), 'blog_posts', ['id'], unique=False)
op.create_index(op.f('ix_blog_posts_title'), 'blog_posts', ['title'], unique=False)
op.create_index(op.f('ix_blog_posts_slug'), 'blog_posts', ['slug'], unique=True)
op.create_index(op.f('ix_blog_posts_author_id'), 'blog_posts', ['author_id'], unique=False)
op.create_index(op.f('ix_blog_posts_published_at'), 'blog_posts', ['published_at'], unique=False)
op.create_index(op.f('ix_blog_posts_is_published'), 'blog_posts', ['is_published'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_blog_posts_is_published'), table_name='blog_posts')
op.drop_index(op.f('ix_blog_posts_published_at'), table_name='blog_posts')
op.drop_index(op.f('ix_blog_posts_author_id'), table_name='blog_posts')
op.drop_index(op.f('ix_blog_posts_slug'), table_name='blog_posts')
op.drop_index(op.f('ix_blog_posts_title'), table_name='blog_posts')
op.drop_index(op.f('ix_blog_posts_id'), table_name='blog_posts')
op.drop_table('blog_posts')

View File

@@ -0,0 +1,27 @@
"""add sections to blog posts
Revision ID: add_sections_blog
Revises: add_blog_posts
Create Date: 2024-01-02 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = 'add_sections_blog'
down_revision = 'add_blog_posts' # Depends on blog_posts table migration
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add sections JSON column to blog_posts table
op.add_column('blog_posts', sa.Column('sections', sa.JSON(), nullable=True))
def downgrade() -> None:
# Remove sections column
op.drop_column('blog_posts', 'sections')

View File

@@ -0,0 +1,50 @@
"""
Fix blog post published_at dates to be in the past
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.orm import Session
from src.config.database import SessionLocal
from src.models.blog import BlogPost
from datetime import datetime, timedelta
def fix_blog_dates():
"""Update all blog posts to have published_at dates in the past"""
db: Session = SessionLocal()
try:
# Get all published posts
posts = db.query(BlogPost).filter(BlogPost.is_published == True).order_by(BlogPost.created_at.asc()).all()
if not posts:
print("No published posts found.")
return
# Set base date to 60 days ago
base_date = datetime.utcnow() - timedelta(days=60)
updated = 0
for i, post in enumerate(posts):
# Set each post's date going backwards from base_date
# Each post is 2 days earlier than the previous one
new_date = base_date - timedelta(days=i * 2)
post.published_at = new_date
updated += 1
db.commit()
print(f"Successfully updated {updated} blog posts with past published_at dates")
except Exception as e:
db.rollback()
print(f"Error fixing blog dates: {str(e)}")
raise
finally:
db.close()
if __name__ == "__main__":
print("Fixing blog post published_at dates...")
fix_blog_dates()
print("Done!")

File diff suppressed because it is too large Load Diff

View File

@@ -66,7 +66,15 @@ class Settings(BaseSettings):
@property @property
def database_url(self) -> str: def database_url(self) -> str:
return f'mysql+pymysql://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}' """Generate database URL with proper credential escaping to prevent injection."""
from urllib.parse import quote_plus
# Properly escape credentials to handle special characters
user = quote_plus(self.DB_USER)
password = quote_plus(self.DB_PASS)
host = quote_plus(self.DB_HOST)
port = str(self.DB_PORT)
name = quote_plus(self.DB_NAME)
return f'mysql+pymysql://{user}:{password}@{host}:{port}/{name}'
@property @property
def is_production(self) -> bool: def is_production(self) -> bool:

View File

@@ -65,9 +65,20 @@ if settings.RATE_LIMIT_ENABLED:
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
logger.info(f'Rate limiting enabled: {settings.RATE_LIMIT_PER_MINUTE} requests/minute') logger.info(f'Rate limiting enabled: {settings.RATE_LIMIT_PER_MINUTE} requests/minute')
# CORS middleware must be added LAST to handle OPTIONS preflight requests before other middleware
# In FastAPI/Starlette, middleware is executed in reverse order (last added = first executed = outermost)
# So adding CORS last ensures it wraps all other middleware and handles OPTIONS requests early
if settings.is_development: if settings.is_development:
app.add_middleware(CORSMiddleware, allow_origin_regex='http://(localhost|127\\.0\\.0\\.1)(:\\d+)?', allow_credentials=True, allow_methods=['*'], allow_headers=['*']) # More restrictive CORS even in development for better security practices
logger.info('CORS configured for development (allowing localhost)') app.add_middleware(
CORSMiddleware,
allow_origin_regex='http://(localhost|127\\.0\\.0\\.1)(:\\d+)?',
allow_credentials=True,
allow_methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], # Explicit methods
allow_headers=['Content-Type', 'Authorization', 'X-XSRF-TOKEN', 'X-Requested-With', 'X-Request-ID'] # Explicit headers
)
logger.info('CORS configured for development (allowing localhost with explicit methods/headers)')
else: else:
# Validate CORS_ORIGINS in production # Validate CORS_ORIGINS in production
if not settings.CORS_ORIGINS or len(settings.CORS_ORIGINS) == 0: if not settings.CORS_ORIGINS or len(settings.CORS_ORIGINS) == 0:
@@ -125,7 +136,7 @@ from .routes import (
faq_routes, loyalty_routes, guest_profile_routes, analytics_routes, faq_routes, loyalty_routes, guest_profile_routes, analytics_routes,
workflow_routes, task_routes, notification_routes, group_booking_routes, workflow_routes, task_routes, notification_routes, group_booking_routes,
advanced_room_routes, rate_plan_routes, package_routes, security_routes, advanced_room_routes, rate_plan_routes, package_routes, security_routes,
email_campaign_routes email_campaign_routes, blog_routes
) )
# Register all routes with /api prefix (removed duplicate registrations) # Register all routes with /api prefix (removed duplicate registrations)
@@ -172,6 +183,7 @@ app.include_router(package_routes.router, prefix=api_prefix)
app.include_router(security_routes.router, prefix=api_prefix) app.include_router(security_routes.router, prefix=api_prefix)
app.include_router(email_campaign_routes.router, prefix=api_prefix) app.include_router(email_campaign_routes.router, prefix=api_prefix)
app.include_router(page_content_routes.router, prefix=api_prefix) app.include_router(page_content_routes.router, prefix=api_prefix)
app.include_router(blog_routes.router, prefix=api_prefix)
logger.info('All routes registered successfully') logger.info('All routes registered successfully')
def ensure_jwt_secret(): def ensure_jwt_secret():

View File

@@ -82,6 +82,10 @@ class AdminIPWhitelistMiddleware(BaseHTTPMiddleware):
if not self.enabled: if not self.enabled:
return await call_next(request) return await call_next(request)
# Skip OPTIONS requests (CORS preflight) - let CORS middleware handle them
if request.method == 'OPTIONS':
return await call_next(request)
# Only apply to admin routes # Only apply to admin routes
if not self._is_admin_route(request.url.path): if not self._is_admin_route(request.url.path):
return await call_next(request) return await call_next(request)

View File

@@ -1,4 +1,4 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status, Request, Cookie
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt from jose import JWTError, jwt
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -8,7 +8,7 @@ from ..config.database import get_db
from ..config.settings import settings from ..config.settings import settings
from ..models.user import User from ..models.user import User
from ..models.role import Role from ..models.role import Role
security = HTTPBearer() security = HTTPBearer(auto_error=False)
def get_jwt_secret() -> str: def get_jwt_secret() -> str:
""" """
@@ -38,9 +38,34 @@ def get_jwt_secret() -> str:
return jwt_secret return jwt_secret
def get_current_user(credentials: HTTPAuthorizationCredentials=Depends(security), db: Session=Depends(get_db)) -> User: def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
access_token: Optional[str] = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
) -> User:
"""
Get current user from either Authorization header or httpOnly cookie.
Prefers Authorization header for backward compatibility, falls back to cookie.
"""
# Try to get token from Authorization header first
token = None
if credentials:
token = credentials.credentials token = credentials.credentials
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Could not validate credentials', headers={'WWW-Authenticate': 'Bearer'})
# Fall back to cookie if no header token
if not token and access_token:
token = access_token
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Could not validate credentials',
headers={'WWW-Authenticate': 'Bearer'}
)
if not token:
raise credentials_exception
try: try:
jwt_secret = get_jwt_secret() jwt_secret = get_jwt_secret()
payload = jwt.decode(token, jwt_secret, algorithms=['HS256']) payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])
@@ -69,10 +94,28 @@ def authorize_roles(*allowed_roles: str):
return current_user return current_user
return role_checker return role_checker
def get_current_user_optional(credentials: Optional[HTTPAuthorizationCredentials]=Depends(HTTPBearer(auto_error=False)), db: Session=Depends(get_db)) -> Optional[User]: def get_current_user_optional(
if not credentials: request: Request,
return None credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
access_token: Optional[str] = Cookie(None, alias='accessToken'),
db: Session = Depends(get_db)
) -> Optional[User]:
"""
Get current user optionally from either Authorization header or httpOnly cookie.
Returns None if no valid token is found.
"""
# Try to get token from Authorization header first
token = None
if credentials:
token = credentials.credentials token = credentials.credentials
# Fall back to cookie if no header token
if not token and access_token:
token = access_token
if not token:
return None
try: try:
jwt_secret = get_jwt_secret() jwt_secret = get_jwt_secret()
payload = jwt.decode(token, jwt_secret, algorithms=['HS256']) payload = jwt.decode(token, jwt_secret, algorithms=['HS256'])

View File

@@ -12,13 +12,15 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'} security_headers = {'X-Content-Type-Options': 'nosniff', 'X-Frame-Options': 'DENY', 'X-XSS-Protection': '1; mode=block', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Permissions-Policy': 'geolocation=(), microphone=(), camera=()'}
security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin') security_headers.setdefault('Cross-Origin-Resource-Policy', 'cross-origin')
if settings.is_production: if settings.is_production:
# Enhanced CSP with additional directives # Enhanced CSP with stricter directives
# Note: unsafe-inline and unsafe-eval are kept for React/Vite compatibility # Using 'strict-dynamic' for better security with nonce-based scripts
# Consider moving to nonces/hashes in future for stricter policy # Note: For React/Vite, consider implementing nonce-based CSP in the future
# Current policy balances security with framework requirements
security_headers['Content-Security-Policy'] = ( security_headers['Content-Security-Policy'] = (
"default-src 'self'; " "default-src 'self'; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval' https://js.stripe.com; " "script-src 'self' 'strict-dynamic' https://js.stripe.com; "
"style-src 'self' 'unsafe-inline'; " "script-src-elem 'self' 'unsafe-inline' https://js.stripe.com; " # Allow inline scripts for Vite/React
"style-src 'self' 'unsafe-inline'; " # Required for React/Vite
"img-src 'self' data: https:; " "img-src 'self' data: https:; "
"font-src 'self' data:; " "font-src 'self' data:; "
"connect-src 'self' https: https://js.stripe.com https://hooks.stripe.com; " "connect-src 'self' https: https://js.stripe.com https://hooks.stripe.com; "
@@ -27,7 +29,8 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"form-action 'self'; " "form-action 'self'; "
"frame-ancestors 'none'; " "frame-ancestors 'none'; "
"object-src 'none'; " "object-src 'none'; "
"upgrade-insecure-requests" "upgrade-insecure-requests; "
"block-all-mixed-content" # Block mixed HTTP/HTTPS content
) )
# HSTS with preload directive (only add preload if domain is ready for it) # HSTS with preload directive (only add preload if domain is ready for it)
# Preload requires manual submission to hstspreload.org # Preload requires manual submission to hstspreload.org

View File

@@ -12,6 +12,7 @@ from .service_booking import ServiceBooking, ServiceBookingItem, ServicePayment,
from .promotion import Promotion from .promotion import Promotion
from .checkin_checkout import CheckInCheckOut from .checkin_checkout import CheckInCheckOut
from .banner import Banner from .banner import Banner
from .blog import BlogPost
from .review import Review from .review import Review
from .favorite import Favorite from .favorite import Favorite
from .audit_log import AuditLog from .audit_log import AuditLog
@@ -44,4 +45,4 @@ from .package import Package, PackageItem, PackageStatus, PackageItemType
from .security_event import SecurityEvent, SecurityEventType, SecurityEventSeverity, IPWhitelist, IPBlacklist, OAuthProvider, OAuthToken from .security_event import SecurityEvent, SecurityEventType, SecurityEventSeverity, IPWhitelist, IPBlacklist, OAuthProvider, OAuthToken
from .gdpr_compliance import DataSubjectRequest, DataSubjectRequestType, DataSubjectRequestStatus, DataRetentionPolicy, ConsentRecord from .gdpr_compliance import DataSubjectRequest, DataSubjectRequestType, DataSubjectRequestStatus, DataRetentionPolicy, ConsentRecord
from .email_campaign import Campaign, CampaignStatus, CampaignType, CampaignSegment, EmailTemplate, CampaignEmail, EmailStatus, EmailClick, DripSequence, DripSequenceStep, DripSequenceEnrollment, Unsubscribe from .email_campaign import Campaign, CampaignStatus, CampaignType, CampaignSegment, EmailTemplate, CampaignEmail, EmailStatus, EmailClick, DripSequence, DripSequenceStep, DripSequenceEnrollment, Unsubscribe
__all__ = ['Role', 'User', 'RefreshToken', 'PasswordResetToken', 'RoomType', 'Room', 'Booking', 'Payment', 'Service', 'ServiceUsage', 'ServiceBooking', 'ServiceBookingItem', 'ServicePayment', 'ServiceBookingStatus', 'ServicePaymentStatus', 'ServicePaymentMethod', 'Promotion', 'CheckInCheckOut', 'Banner', 'Review', 'Favorite', 'AuditLog', 'CookiePolicy', 'CookieIntegrationConfig', 'SystemSettings', 'Invoice', 'InvoiceItem', 'PageContent', 'PageType', 'Chat', 'ChatMessage', 'ChatStatus', 'LoyaltyTier', 'TierLevel', 'UserLoyalty', 'LoyaltyPointTransaction', 'TransactionType', 'TransactionSource', 'LoyaltyReward', 'RewardType', 'RewardStatus', 'RewardRedemption', 'RedemptionStatus', 'Referral', 'ReferralStatus', 'GuestPreference', 'GuestNote', 'GuestTag', 'guest_tag_association', 'GuestCommunication', 'CommunicationType', 'CommunicationDirection', 'GuestSegment', 'guest_segment_association', 'Workflow', 'WorkflowInstance', 'Task', 'TaskComment', 'WorkflowType', 'WorkflowStatus', 'WorkflowTrigger', 'TaskStatus', 'TaskPriority', 'Notification', 'NotificationTemplate', 'NotificationPreference', 'NotificationDeliveryLog', 'NotificationChannel', 'NotificationStatus', 'NotificationType', 'GroupBooking', 'GroupBookingMember', 'GroupRoomBlock', 'GroupPayment', 'GroupBookingStatus', 'PaymentOption', 'RoomMaintenance', 'MaintenanceType', 'MaintenanceStatus', 'HousekeepingTask', 'HousekeepingStatus', 'HousekeepingType', 'RoomInspection', 'InspectionType', 'InspectionStatus', 'RoomAttribute', 'RatePlan', 'RatePlanRule', 'RatePlanType', 'RatePlanStatus', 'Package', 'PackageItem', 'PackageStatus', 'PackageItemType', 'SecurityEvent', 'SecurityEventType', 'SecurityEventSeverity', 'IPWhitelist', 'IPBlacklist', 'OAuthProvider', 'OAuthToken', 'DataSubjectRequest', 'DataSubjectRequestType', 'DataSubjectRequestStatus', 'DataRetentionPolicy', 'ConsentRecord', 'Campaign', 'CampaignStatus', 'CampaignType', 'CampaignSegment', 'EmailTemplate', 'CampaignEmail', 'EmailStatus', 'EmailClick', 'DripSequence', 'DripSequenceStep', 'DripSequenceEnrollment', 'Unsubscribe'] __all__ = ['Role', 'User', 'RefreshToken', 'PasswordResetToken', 'RoomType', 'Room', 'Booking', 'Payment', 'Service', 'ServiceUsage', 'ServiceBooking', 'ServiceBookingItem', 'ServicePayment', 'ServiceBookingStatus', 'ServicePaymentStatus', 'ServicePaymentMethod', 'Promotion', 'CheckInCheckOut', 'Banner', 'BlogPost', 'Review', 'Favorite', 'AuditLog', 'CookiePolicy', 'CookieIntegrationConfig', 'SystemSettings', 'Invoice', 'InvoiceItem', 'PageContent', 'PageType', 'Chat', 'ChatMessage', 'ChatStatus', 'LoyaltyTier', 'TierLevel', 'UserLoyalty', 'LoyaltyPointTransaction', 'TransactionType', 'TransactionSource', 'LoyaltyReward', 'RewardType', 'RewardStatus', 'RewardRedemption', 'RedemptionStatus', 'Referral', 'ReferralStatus', 'GuestPreference', 'GuestNote', 'GuestTag', 'guest_tag_association', 'GuestCommunication', 'CommunicationType', 'CommunicationDirection', 'GuestSegment', 'guest_segment_association', 'Workflow', 'WorkflowInstance', 'Task', 'TaskComment', 'WorkflowType', 'WorkflowStatus', 'WorkflowTrigger', 'TaskStatus', 'TaskPriority', 'Notification', 'NotificationTemplate', 'NotificationPreference', 'NotificationDeliveryLog', 'NotificationChannel', 'NotificationStatus', 'NotificationType', 'GroupBooking', 'GroupBookingMember', 'GroupRoomBlock', 'GroupPayment', 'GroupBookingStatus', 'PaymentOption', 'RoomMaintenance', 'MaintenanceType', 'MaintenanceStatus', 'HousekeepingTask', 'HousekeepingStatus', 'HousekeepingType', 'RoomInspection', 'InspectionType', 'InspectionStatus', 'RoomAttribute', 'RatePlan', 'RatePlanRule', 'RatePlanType', 'RatePlanStatus', 'Package', 'PackageItem', 'PackageStatus', 'PackageItemType', 'SecurityEvent', 'SecurityEventType', 'SecurityEventSeverity', 'IPWhitelist', 'IPBlacklist', 'OAuthProvider', 'OAuthToken', 'DataSubjectRequest', 'DataSubjectRequestType', 'DataSubjectRequestStatus', 'DataRetentionPolicy', 'ConsentRecord', 'Campaign', 'CampaignStatus', 'CampaignType', 'CampaignSegment', 'EmailTemplate', 'CampaignEmail', 'EmailStatus', 'EmailClick', 'DripSequence', 'DripSequenceStep', 'DripSequenceEnrollment', 'Unsubscribe']

Binary file not shown.

View File

@@ -0,0 +1,29 @@
from sqlalchemy import Column, Integer, String, Text, DateTime, Boolean, ForeignKey, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from ..config.database import Base
class BlogPost(Base):
__tablename__ = 'blog_posts'
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
title = Column(String(500), nullable=False, index=True)
slug = Column(String(500), unique=True, nullable=False, index=True)
excerpt = Column(Text, nullable=True)
content = Column(Text, nullable=False)
featured_image = Column(String(1000), nullable=True)
author_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True)
published_at = Column(DateTime, nullable=True, index=True)
is_published = Column(Boolean, nullable=False, default=False, index=True)
tags = Column(Text, nullable=True) # JSON array stored as text
meta_title = Column(String(500), nullable=True)
meta_description = Column(Text, nullable=True)
meta_keywords = Column(String(1000), nullable=True)
views_count = Column(Integer, nullable=False, default=0)
sections = Column(JSON, nullable=True) # Structured content sections
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
# Relationships
author = relationship('User', backref='blog_posts')

View File

@@ -7,7 +7,7 @@ import uuid
import os import os
from ..config.database import get_db from ..config.database import get_db
from ..services.auth_service import auth_service from ..services.auth_service import auth_service
from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse from ..schemas.auth import RegisterRequest, LoginRequest, RefreshTokenRequest, ForgotPasswordRequest, ResetPasswordRequest, AuthResponse, TokenResponse, MessageResponse, MFAInitResponse, EnableMFARequest, VerifyMFARequest, MFAStatusResponse, UpdateProfileRequest
from ..middleware.auth import get_current_user from ..middleware.auth import get_current_user
from ..models.user import User from ..models.user import User
from ..services.audit_service import audit_service from ..services.audit_service import audit_service
@@ -22,16 +22,15 @@ AUTH_RATE_LIMIT = "5/minute" # 5 attempts per minute per IP
PASSWORD_RESET_LIMIT = "3/hour" # 3 password reset requests per hour per IP PASSWORD_RESET_LIMIT = "3/hour" # 3 password reset requests per hour per IP
LOGIN_RATE_LIMIT = "10/minute" # 10 login attempts per minute per IP LOGIN_RATE_LIMIT = "10/minute" # 10 login attempts per minute per IP
# Initialize limiter - will be set from app state
limiter = None
def get_limiter(request: Request) -> Limiter: def get_limiter(request: Request) -> Limiter:
"""Get limiter instance from app state.""" """Get limiter instance from app state."""
return request.app.state.limiter if hasattr(request.app.state, 'limiter') else None global limiter
if hasattr(request.app.state, 'limiter'):
def apply_rate_limit(func, limit_value: str): limiter = request.app.state.limiter
"""Helper to apply rate limiting decorator if limiter is available.""" return limiter
def decorator(*args, **kwargs):
# This will be applied at runtime when route is called
return func(*args, **kwargs)
return decorator
def get_base_url(request: Request) -> str: def get_base_url(request: Request) -> str:
return os.getenv('SERVER_URL') or f'http://{request.headers.get('host', 'localhost:8000')}' return os.getenv('SERVER_URL') or f'http://{request.headers.get('host', 'localhost:8000')}'
@@ -52,6 +51,7 @@ async def register(
response: Response, response: Response,
db: Session=Depends(get_db) db: Session=Depends(get_db)
): ):
# Rate limiting is handled by middleware, but we can add additional checks here if needed
client_ip = request.client.host if request.client else None client_ip = request.client.host if request.client else None
user_agent = request.headers.get('User-Agent') user_agent = request.headers.get('User-Agent')
request_id = getattr(request.state, 'request_id', None) request_id = getattr(request.state, 'request_id', None)
@@ -59,14 +59,28 @@ async def register(
try: try:
result = await auth_service.register(db=db, name=register_request.name, email=register_request.email, password=register_request.password, phone=register_request.phone) result = await auth_service.register(db=db, name=register_request.name, email=register_request.email, password=register_request.password, phone=register_request.phone)
from ..config.settings import settings from ..config.settings import settings
max_age = 7 * 24 * 60 * 60 # 7 days for registration
# Use secure cookies in production (HTTPS required) # Use secure cookies in production (HTTPS required)
# Set access token in httpOnly cookie for security
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Set refresh token in httpOnly cookie
response.set_cookie( response.set_cookie(
key='refreshToken', key='refreshToken',
value=result['refreshToken'], value=result['refreshToken'],
httponly=True, httponly=True,
secure=settings.is_production, # Secure flag enabled in production secure=settings.is_production, # Secure flag enabled in production
samesite='strict', samesite=samesite_value,
max_age=7 * 24 * 60 * 60, max_age=max_age,
path='/' path='/'
) )
@@ -83,7 +97,8 @@ async def register(
status='success' status='success'
) )
return {'status': 'success', 'message': 'Registration successful', 'data': {'token': result['token'], 'user': result['user']}} # Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'message': 'Registration successful', 'data': {'user': result['user']}}
except ValueError as e: except ValueError as e:
error_message = str(e) error_message = str(e)
# Log failed registration attempt # Log failed registration attempt
@@ -132,12 +147,25 @@ async def login(
from ..config.settings import settings from ..config.settings import settings
max_age = 7 * 24 * 60 * 60 if login_request.rememberMe else 1 * 24 * 60 * 60 max_age = 7 * 24 * 60 * 60 if login_request.rememberMe else 1 * 24 * 60 * 60
# Use secure cookies in production (HTTPS required) # Use secure cookies in production (HTTPS required)
# Set access token in httpOnly cookie for security
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production, # Secure flag enabled in production
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Set refresh token in httpOnly cookie
response.set_cookie( response.set_cookie(
key='refreshToken', key='refreshToken',
value=result['refreshToken'], value=result['refreshToken'],
httponly=True, httponly=True,
secure=settings.is_production, # Secure flag enabled in production secure=settings.is_production, # Secure flag enabled in production
samesite='strict', samesite=samesite_value,
max_age=max_age, max_age=max_age,
path='/' path='/'
) )
@@ -155,7 +183,8 @@ async def login(
status='success' status='success'
) )
return {'status': 'success', 'data': {'token': result['token'], 'user': result['user']}} # Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'data': {'user': result['user']}}
except ValueError as e: except ValueError as e:
error_message = str(e) error_message = str(e)
status_code = status.HTTP_401_UNAUTHORIZED if 'Invalid email or password' in error_message or 'Invalid MFA token' in error_message else status.HTTP_400_BAD_REQUEST status_code = status.HTTP_401_UNAUTHORIZED if 'Invalid email or password' in error_message or 'Invalid MFA token' in error_message else status.HTTP_400_BAD_REQUEST
@@ -176,12 +205,32 @@ async def login(
return JSONResponse(status_code=status_code, content={'status': 'error', 'message': error_message}) return JSONResponse(status_code=status_code, content={'status': 'error', 'message': error_message})
@router.post('/refresh-token', response_model=TokenResponse) @router.post('/refresh-token', response_model=TokenResponse)
async def refresh_token(refreshToken: str=Cookie(None), db: Session=Depends(get_db)): async def refresh_token(
request: Request,
response: Response,
refreshToken: str=Cookie(None),
db: Session=Depends(get_db)
):
if not refreshToken: if not refreshToken:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Refresh token not found') raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Refresh token not found')
try: try:
result = await auth_service.refresh_access_token(db, refreshToken) result = await auth_service.refresh_access_token(db, refreshToken)
return result from ..config.settings import settings
# Set new access token in httpOnly cookie
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
max_age = 7 * 24 * 60 * 60 # 7 days
response.set_cookie(
key='accessToken',
value=result['token'],
httponly=True,
secure=settings.is_production,
samesite=samesite_value,
max_age=max_age,
path='/'
)
# Return user data but NOT the token (it's in httpOnly cookie now)
return {'status': 'success', 'data': {'user': result.get('user')}}
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
@@ -199,7 +248,13 @@ async def logout(
if refreshToken: if refreshToken:
await auth_service.logout(db, refreshToken) await auth_service.logout(db, refreshToken)
response.delete_cookie(key='refreshToken', path='/')
# Delete both access and refresh token cookies
from ..config.settings import settings
# Use 'lax' in development for cross-origin support, 'strict' in production
samesite_value = 'strict' if settings.is_production else 'lax'
response.delete_cookie(key='refreshToken', path='/', secure=settings.is_production, samesite=samesite_value)
response.delete_cookie(key='accessToken', path='/', secure=settings.is_production, samesite=samesite_value)
# Log logout # Log logout
await audit_service.log_action( await audit_service.log_action(
@@ -227,9 +282,18 @@ async def get_profile(current_user: User=Depends(get_current_user), db: Session=
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.put('/profile') @router.put('/profile')
async def update_profile(profile_data: dict, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)): async def update_profile(profile_data: UpdateProfileRequest, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try: try:
user = await auth_service.update_profile(db=db, user_id=current_user.id, full_name=profile_data.get('full_name'), email=profile_data.get('email'), phone_number=profile_data.get('phone_number'), password=profile_data.get('password'), current_password=profile_data.get('currentPassword'), currency=profile_data.get('currency')) user = await auth_service.update_profile(
db=db,
user_id=current_user.id,
full_name=profile_data.full_name,
email=profile_data.email,
phone_number=profile_data.phone_number,
password=profile_data.password,
current_password=profile_data.currentPassword,
currency=profile_data.currency
)
return {'status': 'success', 'message': 'Profile updated successfully', 'data': {'user': user}} return {'status': 'success', 'message': 'Profile updated successfully', 'data': {'user': user}}
except ValueError as e: except ValueError as e:
error_message = str(e) error_message = str(e)

View File

@@ -0,0 +1,564 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request, UploadFile, File
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import or_, func, case
from typing import Optional, List
from datetime import datetime
import json
import re
import os
import uuid
import aiofiles
from pathlib import Path
from ..config.database import get_db
from ..config.logging_config import get_logger
from ..middleware.auth import get_current_user, get_current_user_optional, authorize_roles
from ..models.user import User
from ..models.blog import BlogPost
from ..schemas.blog import BlogPostCreate, BlogPostUpdate, BlogPostResponse, BlogPostListResponse
from ..utils.response_helpers import success_response
logger = get_logger(__name__)
router = APIRouter(prefix='/blog', tags=['blog'])
def normalize_image_url(image_url: str, base_url: str) -> str:
if not image_url:
return image_url
if image_url.startswith('http://') or image_url.startswith('https://'):
return image_url
if image_url.startswith('/'):
return f'{base_url}{image_url}'
return f'{base_url}/{image_url}'
def get_base_url(request: Request) -> str:
return os.getenv('SERVER_URL') or f'http://{request.headers.get("host", "localhost:8000")}'
def generate_slug(title: str) -> str:
"""Generate a URL-friendly slug from title"""
slug = title.lower()
slug = re.sub(r'[^\w\s-]', '', slug)
slug = re.sub(r'[-\s]+', '-', slug)
return slug.strip('-')
@router.get('/', response_model=dict)
async def get_blog_posts(
request: Request,
page: int = Query(1, ge=1),
limit: int = Query(10, ge=1, le=50),
search: Optional[str] = Query(None),
tag: Optional[str] = Query(None),
published_only: bool = Query(True),
db: Session = Depends(get_db)
):
"""Get list of blog posts (public endpoint)"""
try:
query = db.query(BlogPost)
if published_only:
query = query.filter(BlogPost.is_published == True)
# Only show posts with published_at in the past or null
query = query.filter(
or_(
BlogPost.published_at <= datetime.utcnow(),
BlogPost.published_at.is_(None)
)
)
if search:
search_term = f"%{search}%"
query = query.filter(
or_(
BlogPost.title.ilike(search_term),
BlogPost.excerpt.ilike(search_term),
BlogPost.content.ilike(search_term)
)
)
if tag:
query = query.filter(BlogPost.tags.ilike(f'%"{tag}"%'))
total = query.count()
# Order by published_at descending, handling null values
# MySQL doesn't support NULLS LAST, so we use a CASE statement to put NULLs last
# Eager load author relationship to avoid N+1 queries
posts = query.options(joinedload(BlogPost.author)).order_by(
case((BlogPost.published_at.is_(None), 1), else_=0),
BlogPost.published_at.desc(),
BlogPost.created_at.desc()
).offset((page - 1) * limit).limit(limit).all()
base_url = get_base_url(request)
result = []
for post in posts:
tags_list = json.loads(post.tags) if post.tags else []
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'views_count': post.views_count,
'tags': tags_list,
'created_at': post.created_at.isoformat()
}
result.append(post_dict)
# Get all unique tags from all published posts (for filter display)
all_tags_query = db.query(BlogPost).filter(BlogPost.is_published == True)
if published_only:
all_tags_query = all_tags_query.filter(
or_(
BlogPost.published_at <= datetime.utcnow(),
BlogPost.published_at.is_(None)
)
)
all_posts_for_tags = all_tags_query.all()
all_unique_tags = set()
for post in all_posts_for_tags:
if post.tags:
try:
tags_list = json.loads(post.tags)
all_unique_tags.update(tags_list)
except:
pass
return success_response({
'posts': result,
'pagination': {
'page': page,
'limit': limit,
'total': total,
'pages': (total + limit - 1) // limit
},
'all_tags': sorted(list(all_unique_tags))
})
except Exception as e:
logger.error(f"Error in get_blog_posts: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get('/{slug}', response_model=dict)
async def get_blog_post_by_slug(
slug: str,
request: Request,
current_user: Optional[User] = Depends(get_current_user_optional),
db: Session = Depends(get_db)
):
"""Get a single blog post by slug (public endpoint)"""
try:
post = db.query(BlogPost).filter(BlogPost.slug == slug).first()
if not post:
raise HTTPException(status_code=404, detail='Blog post not found')
# Only show published posts to non-admin users
# Check if user is admin
from ..models.role import Role
is_admin = False
if current_user:
role = db.query(Role).filter(Role.id == current_user.role_id).first()
is_admin = role and role.name == 'admin'
if not is_admin and (not post.is_published or (post.published_at and post.published_at > datetime.utcnow())):
raise HTTPException(status_code=404, detail='Blog post not found')
# Increment views count
post.views_count += 1
db.commit()
base_url = get_base_url(request)
tags_list = json.loads(post.tags) if post.tags else []
sections_data = post.sections if post.sections else []
# Filter sections by is_visible if the field exists
if sections_data:
sections_data = [
section for section in sections_data
if section.get('is_visible', True) # Default to True if not specified
]
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'content': post.content,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'is_published': post.is_published,
'views_count': post.views_count,
'tags': tags_list,
'meta_title': post.meta_title,
'meta_description': post.meta_description,
'meta_keywords': post.meta_keywords,
'sections': sections_data,
'created_at': post.created_at.isoformat(),
'updated_at': post.updated_at.isoformat()
}
return success_response({'post': post_dict})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get('/admin/posts', response_model=dict)
async def get_all_blog_posts_admin(
request: Request,
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None),
published: Optional[bool] = Query(None),
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get all blog posts for admin (includes unpublished)"""
try:
query = db.query(BlogPost)
if search:
search_term = f"%{search}%"
query = query.filter(
or_(
BlogPost.title.ilike(search_term),
BlogPost.excerpt.ilike(search_term),
BlogPost.content.ilike(search_term)
)
)
if published is not None:
query = query.filter(BlogPost.is_published == published)
total = query.count()
posts = query.order_by(BlogPost.created_at.desc()).offset((page - 1) * limit).limit(limit).all()
base_url = get_base_url(request)
result = []
for post in posts:
tags_list = json.loads(post.tags) if post.tags else []
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'is_published': post.is_published,
'views_count': post.views_count,
'tags': tags_list,
'created_at': post.created_at.isoformat(),
'updated_at': post.updated_at.isoformat()
}
result.append(post_dict)
return success_response({
'posts': result,
'pagination': {
'page': page,
'limit': limit,
'total': total,
'pages': (total + limit - 1) // limit
}
})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get('/admin/posts/{id}', response_model=dict)
async def get_blog_post_admin(
id: int,
request: Request,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Get a single blog post by ID for admin"""
try:
post = db.query(BlogPost).filter(BlogPost.id == id).first()
if not post:
raise HTTPException(status_code=404, detail='Blog post not found')
base_url = get_base_url(request)
tags_list = json.loads(post.tags) if post.tags else []
sections_data = post.sections if post.sections else []
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'content': post.content,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'is_published': post.is_published,
'views_count': post.views_count,
'tags': tags_list,
'meta_title': post.meta_title,
'meta_description': post.meta_description,
'meta_keywords': post.meta_keywords,
'sections': sections_data,
'created_at': post.created_at.isoformat(),
'updated_at': post.updated_at.isoformat()
}
return success_response({'post': post_dict})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post('/admin/posts', response_model=dict, status_code=status.HTTP_201_CREATED)
async def create_blog_post(
post_data: BlogPostCreate,
request: Request,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Create a new blog post"""
try:
# Generate slug if not provided
slug = post_data.slug or generate_slug(post_data.title)
# Ensure slug is unique
existing = db.query(BlogPost).filter(BlogPost.slug == slug).first()
if existing:
counter = 1
base_slug = slug
while existing:
slug = f"{base_slug}-{counter}"
existing = db.query(BlogPost).filter(BlogPost.slug == slug).first()
counter += 1
# Convert tags list to JSON string
tags_json = json.dumps(post_data.tags) if post_data.tags else None
# Set published_at if publishing
published_at = post_data.published_at
if post_data.is_published and not published_at:
published_at = datetime.utcnow()
# Handle sections - ensure it's a valid list or None
sections_data = post_data.sections if post_data.sections else []
post = BlogPost(
title=post_data.title,
slug=slug,
excerpt=post_data.excerpt,
content=post_data.content,
featured_image=post_data.featured_image,
author_id=current_user.id,
tags=tags_json,
meta_title=post_data.meta_title,
meta_description=post_data.meta_description,
meta_keywords=post_data.meta_keywords,
is_published=post_data.is_published,
published_at=published_at,
sections=sections_data
)
db.add(post)
db.commit()
db.refresh(post)
base_url = get_base_url(request)
tags_list = json.loads(post.tags) if post.tags else []
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'content': post.content,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'is_published': post.is_published,
'views_count': post.views_count,
'tags': tags_list,
'meta_title': post.meta_title,
'meta_description': post.meta_description,
'meta_keywords': post.meta_keywords,
'created_at': post.created_at.isoformat(),
'updated_at': post.updated_at.isoformat()
}
return success_response({'post': post_dict}, message='Blog post created successfully')
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@router.put('/admin/posts/{id}', response_model=dict)
async def update_blog_post(
id: int,
post_data: BlogPostUpdate,
request: Request,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Update a blog post"""
try:
post = db.query(BlogPost).filter(BlogPost.id == id).first()
if not post:
raise HTTPException(status_code=404, detail='Blog post not found')
# Update fields
if post_data.title is not None:
post.title = post_data.title
# Regenerate slug if title changed and slug wasn't explicitly provided
if post_data.slug is None:
post.slug = generate_slug(post_data.title)
# Ensure slug is unique
existing = db.query(BlogPost).filter(BlogPost.slug == post.slug, BlogPost.id != id).first()
if existing:
counter = 1
base_slug = post.slug
while existing:
post.slug = f"{base_slug}-{counter}"
existing = db.query(BlogPost).filter(BlogPost.slug == post.slug, BlogPost.id != id).first()
counter += 1
if post_data.slug is not None:
# Ensure new slug is unique
existing = db.query(BlogPost).filter(BlogPost.slug == post_data.slug, BlogPost.id != id).first()
if existing:
raise HTTPException(status_code=400, detail='Slug already exists')
post.slug = post_data.slug
if post_data.excerpt is not None:
post.excerpt = post_data.excerpt
if post_data.content is not None:
post.content = post_data.content
if post_data.featured_image is not None:
post.featured_image = post_data.featured_image
if post_data.tags is not None:
post.tags = json.dumps(post_data.tags)
if post_data.meta_title is not None:
post.meta_title = post_data.meta_title
if post_data.meta_description is not None:
post.meta_description = post_data.meta_description
if post_data.meta_keywords is not None:
post.meta_keywords = post_data.meta_keywords
if post_data.sections is not None:
# Handle sections - ensure it's a valid JSON-serializable list
if isinstance(post_data.sections, list):
post.sections = post_data.sections if post_data.sections else []
else:
post.sections = []
if post_data.is_published is not None:
post.is_published = post_data.is_published
# Set published_at if publishing for the first time
if post_data.is_published and not post.published_at:
post.published_at = datetime.utcnow()
if post_data.published_at is not None:
post.published_at = post_data.published_at
db.commit()
db.refresh(post)
base_url = get_base_url(request)
tags_list = json.loads(post.tags) if post.tags else []
sections_data = post.sections if post.sections else []
post_dict = {
'id': post.id,
'title': post.title,
'slug': post.slug,
'excerpt': post.excerpt,
'content': post.content,
'featured_image': normalize_image_url(post.featured_image, base_url) if post.featured_image else None,
'author_id': post.author_id,
'author_name': post.author.full_name if post.author else None,
'published_at': post.published_at.isoformat() if post.published_at else None,
'is_published': post.is_published,
'views_count': post.views_count,
'tags': tags_list,
'meta_title': post.meta_title,
'meta_description': post.meta_description,
'meta_keywords': post.meta_keywords,
'sections': sections_data,
'created_at': post.created_at.isoformat(),
'updated_at': post.updated_at.isoformat()
}
return success_response({'post': post_dict}, message='Blog post updated successfully')
except HTTPException:
raise
except Exception as e:
db.rollback()
import traceback
error_detail = f'{str(e)}\n{traceback.format_exc()}'
logger.error(f'Error updating blog post: {error_detail}')
raise HTTPException(status_code=500, detail=f'Error updating blog post: {str(e)}')
@router.delete('/admin/posts/{id}', response_model=dict)
async def delete_blog_post(
id: int,
current_user: User = Depends(authorize_roles('admin')),
db: Session = Depends(get_db)
):
"""Delete a blog post"""
try:
post = db.query(BlogPost).filter(BlogPost.id == id).first()
if not post:
raise HTTPException(status_code=404, detail='Blog post not found')
db.delete(post)
db.commit()
return success_response(None, message='Blog post deleted successfully')
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=500, detail=str(e))
@router.post('/admin/upload-image')
async def upload_blog_image(
request: Request,
image: UploadFile = File(...),
current_user: User = Depends(authorize_roles('admin'))
):
"""Upload an image for blog posts"""
try:
if not image:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='No file provided')
if not image.content_type or not image.content_type.startswith('image/'):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f'File must be an image. Received: {image.content_type}')
if not image.filename:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail='Filename is required')
upload_dir = Path(__file__).parent.parent.parent / 'uploads' / 'blog'
upload_dir.mkdir(parents=True, exist_ok=True)
ext = Path(image.filename).suffix or '.jpg'
filename = f'blog-{uuid.uuid4()}{ext}'
file_path = upload_dir / filename
# Validate file
from ..config.settings import settings
from ..utils.file_validation import validate_uploaded_image
max_size = settings.MAX_UPLOAD_SIZE
content = await validate_uploaded_image(image, max_size)
async with aiofiles.open(file_path, 'wb') as f:
await f.write(content)
image_url = f'/uploads/blog/{filename}'
base_url = get_base_url(request)
full_url = normalize_image_url(image_url, base_url)
return success_response({
'image_url': image_url,
'full_url': full_url
}, message='Image uploaded successfully')
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f'Error uploading image: {str(e)}')

View File

@@ -92,6 +92,7 @@ manager = ConnectionManager()
@router.post('/create', status_code=status.HTTP_201_CREATED) @router.post('/create', status_code=status.HTTP_201_CREATED)
async def create_chat(visitor_name: Optional[str]=None, visitor_email: Optional[str]=None, visitor_phone: Optional[str]=None, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)): async def create_chat(visitor_name: Optional[str]=None, visitor_email: Optional[str]=None, visitor_phone: Optional[str]=None, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)):
try:
if current_user: if current_user:
chat = Chat(visitor_id=current_user.id, visitor_name=current_user.full_name, visitor_email=current_user.email, status=ChatStatus.pending) chat = Chat(visitor_id=current_user.id, visitor_name=current_user.full_name, visitor_email=current_user.email, status=ChatStatus.pending)
else: else:
@@ -106,9 +107,15 @@ async def create_chat(visitor_name: Optional[str]=None, visitor_email: Optional[
chat_data = {'id': chat.id, 'visitor_name': chat.visitor_name, 'visitor_email': chat.visitor_email, 'status': chat.status.value, 'created_at': chat.created_at.isoformat()} chat_data = {'id': chat.id, 'visitor_name': chat.visitor_name, 'visitor_email': chat.visitor_email, 'status': chat.status.value, 'created_at': chat.created_at.isoformat()}
await manager.notify_staff_new_chat(chat_data) await manager.notify_staff_new_chat(chat_data)
return {'success': True, 'data': {'id': chat.id, 'visitor_name': chat.visitor_name, 'visitor_email': chat.visitor_email, 'status': chat.status.value, 'created_at': chat.created_at.isoformat()}} return {'success': True, 'data': {'id': chat.id, 'visitor_name': chat.visitor_name, 'visitor_email': chat.visitor_email, 'status': chat.status.value, 'created_at': chat.created_at.isoformat()}}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
@router.post('/{chat_id}/accept') @router.post('/{chat_id}/accept')
async def accept_chat(chat_id: int, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)): async def accept_chat(chat_id: int, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
try:
if current_user.role.name not in ['staff', 'admin']: if current_user.role.name not in ['staff', 'admin']:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='Only staff members can accept chats') raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail='Only staff members can accept chats')
chat = db.query(Chat).filter(Chat.id == chat_id).first() chat = db.query(Chat).filter(Chat.id == chat_id).first()
@@ -122,6 +129,11 @@ async def accept_chat(chat_id: int, current_user: User=Depends(get_current_user)
db.refresh(chat) db.refresh(chat)
await manager.broadcast_to_chat({'type': 'chat_accepted', 'data': {'staff_name': current_user.full_name, 'staff_id': current_user.id}}, chat_id) await manager.broadcast_to_chat({'type': 'chat_accepted', 'data': {'staff_name': current_user.full_name, 'staff_id': current_user.id}}, chat_id)
return {'success': True, 'data': {'id': chat.id, 'staff_id': chat.staff_id, 'staff_name': current_user.full_name, 'status': chat.status.value}} return {'success': True, 'data': {'id': chat.id, 'staff_id': chat.staff_id, 'staff_name': current_user.full_name, 'status': chat.status.value}}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
@router.get('/list') @router.get('/list')
async def list_chats(status_filter: Optional[str]=None, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)): async def list_chats(status_filter: Optional[str]=None, current_user: User=Depends(get_current_user), db: Session=Depends(get_db)):
@@ -165,6 +177,7 @@ async def get_messages(chat_id: int, current_user: Optional[User]=Depends(get_cu
@router.post('/{chat_id}/message') @router.post('/{chat_id}/message')
async def send_message(chat_id: int, message: str, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)): async def send_message(chat_id: int, message: str, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)):
try:
chat = db.query(Chat).filter(Chat.id == chat_id).first() chat = db.query(Chat).filter(Chat.id == chat_id).first()
if not chat: if not chat:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail='Chat not found') raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail='Chat not found')
@@ -193,9 +206,15 @@ async def send_message(chat_id: int, message: str, current_user: Optional[User]=
if chat_message.sender_type == 'visitor': if chat_message.sender_type == 'visitor':
await manager.notify_staff_new_message(chat_id, message_data['data'], chat) await manager.notify_staff_new_message(chat_id, message_data['data'], chat)
return {'success': True, 'data': {'id': chat_message.id, 'chat_id': chat_message.chat_id, 'sender_type': chat_message.sender_type, 'message': chat_message.message, 'created_at': chat_message.created_at.isoformat()}} return {'success': True, 'data': {'id': chat_message.id, 'chat_id': chat_message.chat_id, 'sender_type': chat_message.sender_type, 'message': chat_message.message, 'created_at': chat_message.created_at.isoformat()}}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
@router.post('/{chat_id}/close') @router.post('/{chat_id}/close')
async def close_chat(chat_id: int, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)): async def close_chat(chat_id: int, current_user: Optional[User]=Depends(get_current_user_optional), db: Session=Depends(get_db)):
try:
chat = db.query(Chat).filter(Chat.id == chat_id).first() chat = db.query(Chat).filter(Chat.id == chat_id).first()
if not chat: if not chat:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail='Chat not found') raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail='Chat not found')
@@ -203,19 +222,32 @@ async def close_chat(chat_id: int, current_user: Optional[User]=Depends(get_curr
if current_user.role.name not in ['staff', 'admin']: if current_user.role.name not in ['staff', 'admin']:
if chat.visitor_id != current_user.id: if chat.visitor_id != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to close this chat") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to close this chat")
else:
pass
chat.status = ChatStatus.closed chat.status = ChatStatus.closed
chat.closed_at = datetime.utcnow() chat.closed_at = datetime.utcnow()
db.commit() db.commit()
await manager.broadcast_to_chat({'type': 'chat_closed', 'data': {'chat_id': chat_id}}, chat_id) await manager.broadcast_to_chat({'type': 'chat_closed', 'data': {'chat_id': chat_id}}, chat_id)
return {'success': True, 'data': {'id': chat.id, 'status': chat.status.value}} return {'success': True, 'data': {'id': chat.id, 'status': chat.status.value}}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
@router.websocket('/ws/{chat_id}') @router.websocket('/ws/{chat_id}')
async def websocket_chat(websocket: WebSocket, chat_id: int, user_type: str=None, token: Optional[str]=None): async def websocket_chat(websocket: WebSocket, chat_id: int, user_type: str=None, token: Optional[str]=None):
query_params = dict(websocket.query_params) query_params = dict(websocket.query_params)
user_type = query_params.get('user_type', 'visitor') user_type = query_params.get('user_type', 'visitor')
# Security: Read token from cookies instead of query parameters
# Cookies are sent automatically with WebSocket connections
token = None
if websocket.cookies:
token = websocket.cookies.get('accessToken')
# Fallback to query parameter for backward compatibility (deprecated)
if not token:
token = query_params.get('token') token = query_params.get('token')
current_user = None current_user = None
if user_type == 'staff' and token: if user_type == 'staff' and token:
try: try:
@@ -256,6 +288,9 @@ async def websocket_chat(websocket: WebSocket, chat_id: int, user_type: str=None
db.add(chat_message) db.add(chat_message)
db.commit() db.commit()
db.refresh(chat_message) db.refresh(chat_message)
except Exception:
db.rollback()
raise
finally: finally:
db.close() db.close()
message_data = {'type': 'new_message', 'data': {'id': chat_message.id, 'chat_id': chat_message.chat_id, 'sender_id': chat_message.sender_id, 'sender_type': chat_message.sender_type, 'sender_name': chat_message.sender.full_name if chat_message.sender else None, 'message': chat_message.message, 'is_read': chat_message.is_read, 'created_at': chat_message.created_at.isoformat()}} message_data = {'type': 'new_message', 'data': {'id': chat_message.id, 'chat_id': chat_message.chat_id, 'sender_id': chat_message.sender_id, 'sender_type': chat_message.sender_type, 'sender_name': chat_message.sender.full_name if chat_message.sender else None, 'message': chat_message.message, 'is_read': chat_message.is_read, 'created_at': chat_message.created_at.isoformat()}}
@@ -272,8 +307,18 @@ async def websocket_staff_notifications(websocket: WebSocket):
current_user = None current_user = None
try: try:
await websocket.accept() await websocket.accept()
# Security: Read token from cookies instead of query parameters
# Cookies are sent automatically with WebSocket connections
token = None
if websocket.cookies:
token = websocket.cookies.get('accessToken')
# Fallback to query parameter for backward compatibility (deprecated)
if not token:
query_params = dict(websocket.query_params) query_params = dict(websocket.query_params)
token = query_params.get('token') token = query_params.get('token')
if not token: if not token:
await websocket.close(code=1008, reason='Token required') await websocket.close(code=1008, reason='Token required')
return return

View File

@@ -1068,7 +1068,10 @@ async def test_smtp_email(
): ):
try: try:
test_email = str(request.email) test_email = str(request.email)
admin_name = str(current_user.full_name or current_user.email or "Admin") # Sanitize admin name to prevent XSS/email injection
from html import escape
admin_name_raw = current_user.full_name or current_user.email or "Admin"
admin_name = escape(str(admin_name_raw)) # HTML escape to prevent XSS
timestamp_str = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") timestamp_str = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")

Binary file not shown.

View File

@@ -94,3 +94,42 @@ class VerifyMFARequest(BaseModel):
class MFAStatusResponse(BaseModel): class MFAStatusResponse(BaseModel):
mfa_enabled: bool mfa_enabled: bool
backup_codes_count: int backup_codes_count: int
class UpdateProfileRequest(BaseModel):
full_name: Optional[str] = Field(None, min_length=2, max_length=100, description='Full name')
email: Optional[EmailStr] = Field(None, description='Email address')
phone_number: Optional[str] = Field(None, min_length=5, max_length=20, description='Phone number')
password: Optional[str] = Field(None, min_length=8, description='New password')
currentPassword: Optional[str] = Field(None, alias='current_password', description='Current password (required when changing password)')
currency: Optional[str] = Field(None, min_length=3, max_length=3, description='Currency code (ISO 4217, e.g., USD, EUR, VND)')
@validator('password')
def validate_password(cls, v):
if v is not None:
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any((c.isupper() for c in v)):
raise ValueError('Password must contain at least one uppercase letter')
if not any((c.islower() for c in v)):
raise ValueError('Password must contain at least one lowercase letter')
if not any((c.isdigit() for c in v)):
raise ValueError('Password must contain at least one number')
return v
@validator('phone_number')
def validate_phone(cls, v):
if v is not None:
cleaned = ''.join(c for c in v if c.isdigit())
if len(cleaned) < 5:
raise ValueError('Phone number must contain at least 5 digits')
return v
@validator('currency')
def validate_currency(cls, v):
if v is not None:
if len(v) != 3 or not v.isalpha():
raise ValueError('Currency must be a 3-letter ISO 4217 code (e.g., USD, EUR, VND)')
return v.upper() if v else v
class Config:
allow_population_by_field_name = True

View File

@@ -0,0 +1,62 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
class BlogPostBase(BaseModel):
title: str = Field(..., min_length=1, max_length=500)
slug: Optional[str] = None
excerpt: Optional[str] = None
content: str = Field(..., min_length=1)
featured_image: Optional[str] = None
tags: Optional[List[str]] = None
meta_title: Optional[str] = Field(None, max_length=500)
meta_description: Optional[str] = None
meta_keywords: Optional[str] = Field(None, max_length=1000)
is_published: bool = False
published_at: Optional[datetime] = None
sections: Optional[List[Dict[str, Any]]] = None # Structured content sections
class BlogPostCreate(BlogPostBase):
pass
class BlogPostUpdate(BaseModel):
title: Optional[str] = Field(None, min_length=1, max_length=500)
slug: Optional[str] = None
excerpt: Optional[str] = None
content: Optional[str] = Field(None, min_length=1)
featured_image: Optional[str] = None
tags: Optional[List[str]] = None
meta_title: Optional[str] = Field(None, max_length=500)
meta_description: Optional[str] = None
meta_keywords: Optional[str] = Field(None, max_length=1000)
is_published: Optional[bool] = None
published_at: Optional[datetime] = None
sections: Optional[List[Dict[str, Any]]] = None # Structured content sections
class BlogPostResponse(BlogPostBase):
id: int
author_id: int
author_name: Optional[str] = None
views_count: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class BlogPostListResponse(BaseModel):
id: int
title: str
slug: str
excerpt: Optional[str]
featured_image: Optional[str]
author_id: int
author_name: Optional[str] = None
published_at: Optional[datetime]
views_count: int
tags: Optional[List[str]] = None
created_at: datetime
class Config:
from_attributes = True

View File

@@ -24,9 +24,43 @@ logger = logging.getLogger(__name__)
class AuthService: class AuthService:
def __init__(self): def __init__(self):
# Security: Fail fast if JWT_SECRET is not configured - never use default values
self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET")
if not self.jwt_secret:
error_msg = (
'CRITICAL: JWT_SECRET is not configured. '
'Please set JWT_SECRET environment variable to a secure random string (minimum 32 characters).'
)
logger.error(error_msg)
if settings.is_production:
raise ValueError(error_msg)
else:
# In development, generate a secure secret but warn
import secrets
self.jwt_secret = secrets.token_urlsafe(64)
logger.warning(
f'JWT_SECRET not configured. Auto-generated secret for development. '
f'Set JWT_SECRET environment variable for production: {self.jwt_secret}'
)
# Validate JWT secret strength
if len(self.jwt_secret) < 32:
error_msg = 'JWT_SECRET must be at least 32 characters long for security.'
logger.error(error_msg)
if settings.is_production:
raise ValueError(error_msg)
else:
logger.warning(error_msg)
# Refresh secret should be different from access secret
self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET")
if not self.jwt_refresh_secret:
# Use a derived secret if not explicitly set, but different from access secret
import hashlib
self.jwt_refresh_secret = hashlib.sha256((self.jwt_secret + "-refresh").encode()).hexdigest()
if not settings.is_production:
logger.info('JWT_REFRESH_SECRET not set, using derived secret')
self.jwt_secret = getattr(settings, 'JWT_SECRET', None) or os.getenv("JWT_SECRET", "dev-secret-key-change-in-production-12345")
self.jwt_refresh_secret = os.getenv("JWT_REFRESH_SECRET") or (self.jwt_secret + "-refresh")
self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h") self.jwt_expires_in = os.getenv("JWT_EXPIRES_IN", "1h")
self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d") self.jwt_refresh_expires_in = os.getenv("JWT_REFRESH_EXPIRES_IN", "7d")

View File

@@ -177,7 +177,7 @@ def test_staff_user(db_session, test_staff_role):
@pytest.fixture @pytest.fixture
def auth_token(client, test_user): def auth_token(client, test_user):
"""Get authentication token for test user.""" """Get authentication token for test user (from cookies)."""
response = client.post( response = client.post(
"/api/auth/login", "/api/auth/login",
json={ json={
@@ -186,13 +186,16 @@ def auth_token(client, test_user):
} }
) )
if response.status_code == 200: if response.status_code == 200:
return response.json()["data"]["token"] # Token is now in httpOnly cookie, return cookie value for testing
# In real usage, cookies are sent automatically
cookie_token = response.cookies.get("accessToken")
return cookie_token
return None return None
@pytest.fixture @pytest.fixture
def admin_token(client, test_admin_user): def admin_token(client, test_admin_user):
"""Get authentication token for admin user.""" """Get authentication token for admin user (from cookies)."""
response = client.post( response = client.post(
"/api/auth/login", "/api/auth/login",
json={ json={
@@ -201,13 +204,15 @@ def admin_token(client, test_admin_user):
} }
) )
if response.status_code == 200: if response.status_code == 200:
return response.json()["data"]["token"] # Token is now in httpOnly cookie
cookie_token = response.cookies.get("accessToken")
return cookie_token
return None return None
@pytest.fixture @pytest.fixture
def staff_token(client, test_staff_user): def staff_token(client, test_staff_user):
"""Get authentication token for staff user.""" """Get authentication token for staff user (from cookies)."""
response = client.post( response = client.post(
"/api/auth/login", "/api/auth/login",
json={ json={
@@ -216,21 +221,39 @@ def staff_token(client, test_staff_user):
} }
) )
if response.status_code == 200: if response.status_code == 200:
return response.json()["data"]["token"] # Token is now in httpOnly cookie
cookie_token = response.cookies.get("accessToken")
return cookie_token
return None return None
@pytest.fixture @pytest.fixture
def authenticated_client(client, auth_token): def authenticated_client(client, test_user):
"""Create an authenticated test client.""" """Create an authenticated test client (uses cookies)."""
client.headers.update({"Authorization": f"Bearer {auth_token}"}) # Login to set cookies
response = client.post(
"/api/auth/login",
json={
"email": "test@example.com",
"password": "testpassword123"
}
)
# Cookies are automatically sent with subsequent requests
return client return client
@pytest.fixture @pytest.fixture
def admin_client(client, admin_token): def admin_client(client, test_admin_user):
"""Create an authenticated admin test client.""" """Create an authenticated admin test client (uses cookies)."""
client.headers.update({"Authorization": f"Bearer {admin_token}"}) # Login to set cookies
response = client.post(
"/api/auth/login",
json={
"email": "admin@example.com",
"password": "adminpassword123"
}
)
# Cookies are automatically sent with subsequent requests
return client return client

View File

@@ -0,0 +1,229 @@
"""
Tests for authentication routes.
Tests critical authentication flows including login, registration, and token management.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from ..main import app
from ..config.database import Base, get_db
from ..models.user import User
from ..models.role import Role
import bcrypt
# Test database setup
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
@pytest.fixture(scope="function")
def db():
"""Create a fresh database for each test."""
Base.metadata.create_all(bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client():
"""Create a test client."""
return TestClient(app)
@pytest.fixture(scope="function")
def test_user(db: Session):
"""Create a test user."""
# Create customer role if it doesn't exist
role = db.query(Role).filter(Role.name == 'customer').first()
if not role:
role = Role(name='customer', description='Customer role')
db.add(role)
db.commit()
db.refresh(role)
# Create test user
hashed_password = bcrypt.hashpw("testpassword123".encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
user = User(
email="test@example.com",
password=hashed_password,
full_name="Test User",
role_id=role.id,
is_active=True
)
db.add(user)
db.commit()
db.refresh(user)
return user
def test_register_success(client: TestClient, db: Session):
"""Test successful user registration."""
response = client.post(
"/api/auth/register",
json={
"name": "New User",
"email": "newuser@example.com",
"password": "SecurePass123!",
"phone": "1234567890"
}
)
assert response.status_code == 201
data = response.json()
assert data["status"] == "success"
assert "user" in data["data"]
assert data["data"]["user"]["email"] == "newuser@example.com"
# Verify user was created in database
user = db.query(User).filter(User.email == "newuser@example.com").first()
assert user is not None
assert user.full_name == "New User"
# Verify cookies were set
assert "accessToken" in response.cookies
assert "refreshToken" in response.cookies
def test_register_duplicate_email(client: TestClient, test_user: User):
"""Test registration with duplicate email fails."""
response = client.post(
"/api/auth/register",
json={
"name": "Another User",
"email": test_user.email,
"password": "SecurePass123!",
"phone": "1234567890"
}
)
assert response.status_code == 400
data = response.json()
assert data["status"] == "error"
def test_login_success(client: TestClient, test_user: User):
"""Test successful login."""
response = client.post(
"/api/auth/login",
json={
"email": test_user.email,
"password": "testpassword123",
"rememberMe": False
}
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert "user" in data["data"]
assert data["data"]["user"]["email"] == test_user.email
# Verify cookies were set
assert "accessToken" in response.cookies
assert "refreshToken" in response.cookies
def test_login_invalid_credentials(client: TestClient, test_user: User):
"""Test login with invalid credentials fails."""
response = client.post(
"/api/auth/login",
json={
"email": test_user.email,
"password": "wrongpassword",
"rememberMe": False
}
)
assert response.status_code == 401
data = response.json()
assert data["status"] == "error"
def test_login_nonexistent_user(client: TestClient):
"""Test login with non-existent user fails."""
response = client.post(
"/api/auth/login",
json={
"email": "nonexistent@example.com",
"password": "password123",
"rememberMe": False
}
)
assert response.status_code == 401
data = response.json()
assert data["status"] == "error"
def test_logout_success(client: TestClient, test_user: User):
"""Test successful logout."""
# First login to get cookies
login_response = client.post(
"/api/auth/login",
json={
"email": test_user.email,
"password": "testpassword123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Now logout
logout_response = client.post("/api/auth/logout")
assert logout_response.status_code == 200
data = logout_response.json()
assert data["status"] == "success"
# Verify cookies were deleted
# Note: FastAPI TestClient doesn't properly handle cookie deletion in tests
# In real usage, cookies would be deleted
def test_get_profile_authenticated(client: TestClient, test_user: User):
"""Test getting profile when authenticated."""
# Login first
login_response = client.post(
"/api/auth/login",
json={
"email": test_user.email,
"password": "testpassword123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Get profile
profile_response = client.get("/api/auth/profile")
assert profile_response.status_code == 200
data = profile_response.json()
assert data["status"] == "success"
assert data["data"]["user"]["email"] == test_user.email
def test_get_profile_unauthenticated(client: TestClient):
"""Test getting profile when not authenticated fails."""
response = client.get("/api/auth/profile")
assert response.status_code == 401
def test_transaction_rollback_on_error(client: TestClient, db: Session):
"""Test that database transactions rollback on error."""
initial_count = db.query(User).count()
# Try to create user with invalid data (missing required field)
# This should fail and rollback
response = client.post(
"/api/auth/register",
json={
"name": "Test User",
# Missing email - should fail validation
"password": "SecurePass123!"
}
)
# Should fail validation
assert response.status_code == 422
# Verify no user was created (transaction rolled back)
final_count = db.query(User).count()
assert initial_count == final_count

View File

@@ -0,0 +1,224 @@
"""
Tests for booking routes.
Tests critical booking flows including creation, updates, and transaction rollbacks.
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from datetime import datetime, timedelta
from ..main import app
from ..config.database import Base, get_db
from ..models.user import User
from ..models.role import Role
from ..models.room import Room, RoomStatus
from ..models.room_type import RoomType
from ..models.booking import Booking, BookingStatus
import bcrypt
# Test database setup
SQLALCHEMY_DATABASE_URL = "sqlite:///./test_booking.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
@pytest.fixture(scope="function")
def db():
"""Create a fresh database for each test."""
Base.metadata.create_all(bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client():
"""Create a test client."""
return TestClient(app)
@pytest.fixture(scope="function")
def test_customer(db: Session):
"""Create a test customer user."""
role = db.query(Role).filter(Role.name == 'customer').first()
if not role:
role = Role(name='customer', description='Customer role')
db.add(role)
db.commit()
db.refresh(role)
hashed_password = bcrypt.hashpw("password123".encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
user = User(
email="customer@example.com",
password=hashed_password,
full_name="Test Customer",
role_id=role.id,
is_active=True
)
db.add(user)
db.commit()
db.refresh(user)
return user
@pytest.fixture(scope="function")
def test_room(db: Session):
"""Create a test room."""
room_type = RoomType(
name="Standard Room",
description="A standard room",
base_price=100.0,
capacity=2
)
db.add(room_type)
db.commit()
db.refresh(room_type)
room = Room(
room_type_id=room_type.id,
room_number="101",
floor=1,
status=RoomStatus.available,
price=100.0
)
db.add(room)
db.commit()
db.refresh(room)
return room
def test_create_booking_success(client: TestClient, test_customer: User, test_room: Room):
"""Test successful booking creation."""
# Login first
login_response = client.post(
"/api/auth/login",
json={
"email": test_customer.email,
"password": "password123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Create booking
check_in = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d')
check_out = (datetime.now() + timedelta(days=3)).strftime('%Y-%m-%d')
response = client.post(
"/api/bookings/",
json={
"room_id": test_room.id,
"check_in_date": check_in,
"check_out_date": check_out,
"total_price": 200.0,
"guest_count": 2,
"payment_method": "cash"
}
)
assert response.status_code == 201
data = response.json()
assert data["status"] == "success"
assert "booking" in data["data"]
assert data["data"]["booking"]["room_id"] == test_room.id
def test_create_booking_invalid_dates(client: TestClient, test_customer: User, test_room: Room):
"""Test booking creation with invalid dates fails."""
# Login first
login_response = client.post(
"/api/auth/login",
json={
"email": test_customer.email,
"password": "password123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Try to create booking with check-out before check-in
check_in = (datetime.now() + timedelta(days=3)).strftime('%Y-%m-%d')
check_out = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d')
response = client.post(
"/api/bookings/",
json={
"room_id": test_room.id,
"check_in_date": check_in,
"check_out_date": check_out,
"total_price": 200.0,
"guest_count": 2,
"payment_method": "cash"
}
)
assert response.status_code == 400
def test_create_booking_transaction_rollback(client: TestClient, db: Session, test_customer: User, test_room: Room):
"""Test that booking creation rolls back on error."""
initial_count = db.query(Booking).count()
# Login first
login_response = client.post(
"/api/auth/login",
json={
"email": test_customer.email,
"password": "password123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Try to create booking with invalid data (missing required field)
check_in = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d')
response = client.post(
"/api/bookings/",
json={
"room_id": test_room.id,
"check_in_date": check_in,
# Missing check_out_date - should fail validation
"total_price": 200.0,
"guest_count": 2
}
)
# Should fail validation
assert response.status_code == 422
# Verify no booking was created (transaction rolled back)
final_count = db.query(Booking).count()
assert initial_count == final_count
def test_get_my_bookings(client: TestClient, test_customer: User):
"""Test getting user's bookings."""
# Login first
login_response = client.post(
"/api/auth/login",
json={
"email": test_customer.email,
"password": "password123",
"rememberMe": False
}
)
assert login_response.status_code == 200
# Get bookings
response = client.get("/api/bookings/me")
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert "bookings" in data["data"]
def test_get_my_bookings_unauthenticated(client: TestClient):
"""Test getting bookings when not authenticated fails."""
response = client.get("/api/bookings/me")
assert response.status_code == 401

View File

@@ -0,0 +1,93 @@
"""
Standardized error handling utilities for route handlers.
This module provides decorators and helpers for consistent error handling
across all route handlers.
"""
from functools import wraps
from fastapi import HTTPException
from sqlalchemy.orm import Session
from typing import Callable, Any
from ..config.logging_config import get_logger
logger = get_logger(__name__)
def handle_db_errors(func: Callable) -> Callable:
"""
Decorator to handle database errors consistently.
Automatically rolls back transactions on errors.
Usage:
@router.post('/')
@handle_db_errors
async def create_item(..., db: Session = Depends(get_db)):
db.add(item)
db.commit()
return item
Note: This decorator expects 'db' to be in kwargs or as a positional argument.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
# Find db session in kwargs or args
db = kwargs.get('db')
if not db:
# Try to find in args (if db is positional)
for arg in args:
if isinstance(arg, Session):
db = arg
break
try:
return await func(*args, **kwargs)
except HTTPException:
# Re-raise HTTPExceptions as-is (they're intentional)
raise
except ValueError as e:
# ValueErrors are usually validation errors
if db:
db.rollback()
logger.warning(f'ValueError in {func.__name__}: {str(e)}')
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# All other exceptions are server errors
if db:
db.rollback()
logger.error(f'Error in {func.__name__}: {type(e).__name__}: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail="Internal server error")
return wrapper
def safe_db_operation(operation: Callable, db: Session, error_message: str = "Database operation failed") -> Any:
"""
Safely execute a database operation with automatic rollback on error.
Usage:
result = safe_db_operation(
lambda: db.add(item) or db.commit(),
db,
"Failed to create item"
)
Args:
operation: Callable that performs the database operation
db: Database session
error_message: Custom error message to use on failure
Returns:
Result of the operation
Raises:
HTTPException: If the operation fails
"""
try:
return operation()
except HTTPException:
raise
except Exception as e:
db.rollback()
logger.error(f'{error_message}: {type(e).__name__}: {str(e)}', exc_info=True)
raise HTTPException(status_code=500, detail=error_message)

View File

@@ -0,0 +1,43 @@
"""
Database transaction context manager for consistent transaction handling.
This module provides a context manager that automatically handles
database commits and rollbacks, ensuring data consistency.
"""
from contextlib import contextmanager
from sqlalchemy.orm import Session
from typing import Generator
from ..config.logging_config import get_logger
logger = get_logger(__name__)
@contextmanager
def transaction(db: Session) -> Generator[Session, None, None]:
"""
Context manager for database transactions.
Automatically commits on success, rolls back on error.
Usage:
with transaction(db):
db.add(booking)
# Auto-commits on success, rolls back on error
Args:
db: SQLAlchemy database session
Yields:
The database session
Raises:
Any exception that occurs during the transaction
"""
try:
yield db
db.commit()
logger.debug('Transaction committed successfully')
except Exception as e:
db.rollback()
logger.error(f'Transaction rolled back due to error: {type(e).__name__}: {str(e)}', exc_info=True)
raise

Some files were not shown because too many files have changed in this diff Show More