# app/main.py
import logging
import os
import sys

import sentry_sdk
import uvicorn
from alembic import command
from alembic.config import Config
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi_users.authentication import JWTStrategy
from jose import JWTError, jwt
from pydantic import BaseModel
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine
from starlette.middleware.sessions import SessionMiddleware

from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User
from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session


# Response model for refresh endpoint
class RefreshResponse(BaseModel):
    """Body returned by ``POST /auth/jwt/refresh``: a fresh access/refresh token pair."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"


# Initialize Sentry only if DSN is provided
if settings.SENTRY_DSN:
    sentry_sdk.init(
        dsn=settings.SENTRY_DSN,
        integrations=[
            FastApiIntegration(),
        ],
        # Sample 10% of transactions in production, all of them elsewhere.
        traces_sample_rate=0.1 if settings.is_production else 1.0,
        environment=settings.ENVIRONMENT,
        # Send PII (user/request data) in every non-production environment.
        send_default_pii=not settings.is_production,
    )


# --- Logging Setup ---
# The level name is upper-cased so settings like "info" resolve to logging.INFO
# instead of raising AttributeError; an unknown name still fails fast here.
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL.upper()),
    format=settings.LOG_FORMAT,
)
logger = logging.getLogger(__name__)


# --- FastAPI App Instance ---
# Create API metadata with environment-dependent settings (the docs/OpenAPI
# URLs come from settings, so they can differ per environment).
api_metadata = {
    **API_METADATA,
    "docs_url": settings.docs_url,
    "redoc_url": settings.redoc_url,
    "openapi_url": settings.openapi_url,
}

app = FastAPI(
    **api_metadata,
    openapi_tags=API_TAGS,
)
# Add session middleware for OAuth
app.add_middleware(
    SessionMiddleware,
    secret_key=settings.SESSION_SECRET_KEY,
)

# --- CORS Middleware ---
# NOTE(review): with allow_credentials=True, browsers treat "*" in
# Access-Control-Expose-Headers as a literal header name rather than a
# wildcard; list headers explicitly if clients must read custom ones.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
# --- End CORS Middleware ---


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 HTTPException carrying the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _extract_bearer_token(authorization) -> str:
    """Return the token from an ``Authorization: Bearer <token>`` header value.

    The scheme is matched case-insensitively (auth schemes are
    case-insensitive per RFC 7235) and whitespace around the token is ignored.

    Args:
        authorization: Raw header value, or None when the header is absent.

    Raises:
        HTTPException: 401 when the header is missing, uses another scheme,
            or carries no token.
    """
    scheme, _, token = (authorization or "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise _unauthorized("Refresh token missing or invalid format")
    return token


# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
    request: Request,
    refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
    access_strategy: JWTStrategy = Depends(get_jwt_strategy),
    user_manager=Depends(fastapi_users.get_user_manager),
):
    """
    Refresh access token using a valid refresh token.

    Send refresh token in Authorization header: Bearer <refresh_token>

    The token is validated by the refresh strategy itself (its own key,
    algorithm and audience) and the user is loaded through the fastapi-users
    user manager. This keeps validation consistent with how the token was
    issued and works for any primary-key type.

    Returns:
        RefreshResponse with a newly issued access token and refresh token.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token is
            invalid or expired, the user does not exist or is inactive, or any
            unexpected error occurs while refreshing.
    """
    refresh_token = _extract_bearer_token(request.headers.get("Authorization"))

    try:
        # read_token returns None for a bad signature, expiry, wrong audience,
        # unparsable id, or unknown user.
        # NOTE(review): if the access and refresh strategies share secret and
        # audience, an access token is still accepted here; give the refresh
        # strategy a distinct secret or audience to prevent that.
        user = await refresh_strategy.read_token(refresh_token, user_manager)
        if user is None:
            raise _unauthorized("Invalid refresh token")
        if not user.is_active:
            raise _unauthorized("User not found or inactive")

        # Generate new tokens (refresh-token rotation: a new refresh token too)
        new_access_token = await access_strategy.write_token(user)
        new_refresh_token = await refresh_strategy.write_token(user)
    except HTTPException:
        raise
    except Exception:
        # Keep the original 401 contract for unexpected failures, but log the
        # full traceback so they can be diagnosed.
        logger.exception("Error refreshing token")
        raise _unauthorized("Invalid refresh token")

    return RefreshResponse(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
    )


# --- Include API Routers ---
# Include FastAPI-Users routes
app.include_router(
    fastapi_users.get_auth_router(auth_backend),
    prefix="/auth/jwt",
    tags=["auth"],
)
app.include_router(
    fastapi_users.get_register_router(UserPublic, UserCreate),
    prefix="/auth",
    tags=["auth"],
)
app.include_router(
    fastapi_users.get_reset_password_router(),
    prefix="/auth",
    tags=["auth"],
)
app.include_router(
    fastapi_users.get_verify_router(UserPublic),
    prefix="/auth",
    tags=["auth"],
)
app.include_router(
    fastapi_users.get_users_router(UserPublic, UserUpdate),
    prefix="/users",
    tags=["users"],
)

# Include OAuth routes
app.include_router(oauth_router, prefix="/auth", tags=["auth"])

# Include your API router
app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers ---


# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
    """
    Health check endpoint for load balancers and monitoring.

    Returns the configured OK status, the environment name and the API version.
    """
    return {
        "status": settings.HEALTH_STATUS_OK,
        "environment": settings.ENVIRONMENT,
        "version": settings.API_VERSION,
    }


# --- Root Endpoint (Optional - outside the main API structure) ---
@app.get("/", tags=["Root"])
async def read_root():
    """
    Provides a simple welcome message at the root path.
    Useful for basic reachability checks.
    """
    logger.info("Root endpoint '/' accessed.")
    return {
        "message": settings.ROOT_MESSAGE,
        "environment": settings.ENVIRONMENT,
        "version": settings.API_VERSION,
    }
# --- End Root Endpoint ---


async def run_migrations():
    """Run database migrations through the project's alembic ``env`` module.

    The ``alembic`` directory next to the ``app`` package is put on
    ``sys.path`` so its ``env.py`` can be imported as a top-level module.
    Its ``run_migrations_online_async`` coroutine is then awaited.

    Raises:
        Exception: any failure is logged with its traceback and re-raised.
    """
    try:
        logger.info("Running database migrations...")

        # Get the absolute path to the alembic directory
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        alembic_path = os.path.join(base_path, 'alembic')

        # Add alembic directory to Python path
        if alembic_path not in sys.path:
            sys.path.insert(0, alembic_path)

        # NOTE(review): "env" is a generic top-level module name; if another
        # module called "env" is already in sys.modules, that one is used.
        from env import run_migrations_online_async
        await run_migrations_online_async()

        logger.info("Database migrations completed successfully.")
    except Exception as e:
        logger.exception("Error running migrations: %s", e)
        raise


# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favour
# of a ``lifespan`` context manager passed to FastAPI(); consider migrating.
@app.on_event("startup")
async def startup_event():
    """Initialize services on startup: run migrations, then start the scheduler."""
    logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")

    # Run database migrations (an exception here propagates out of the handler)
    await run_migrations()

    # Initialize scheduler
    init_scheduler()

    logger.info("Application startup complete.")


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup services on shutdown (currently only stops the scheduler)."""
    logger.info("Application shutdown: Disconnecting from database...")
    # TODO: no database disposal happens yet despite the log line above;
    # dispose the engine's connection pool here (was: await database.engine.dispose()).
    shutdown_scheduler()
    logger.info("Application shutdown complete.")
# --- End Events ---


# --- Direct Run (for simple local testing if needed) ---
# It's better to use `uvicorn app.main:app --reload` from the terminal
# if __name__ == "__main__":
#     logger.info("Starting Uvicorn server directly from main.py")
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# ------------------------------------------------------