mitlist/be/app/main.py
mohamad 9d404d04d5 Update OAuth redirect URIs and API routing structure
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-02 18:07:41 +02:00

266 lines
8.4 KiB
Python

# app/main.py
import logging
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration
from fastapi_users.authentication import JWTStrategy
from pydantic import BaseModel
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic.config import Config
from alembic import command
import os
import sys
from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User
from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session
from sqlalchemy import select
# Response model for refresh endpoint
class RefreshResponse(BaseModel):
    # Freshly minted access token for the authenticated user.
    access_token: str
    # Freshly minted refresh token issued alongside the access token.
    refresh_token: str
    # Authorization scheme the client should use with access_token.
    token_type: str = "bearer"
# Initialize Sentry only if DSN is provided
if settings.SENTRY_DSN:
    sentry_sdk.init(
        dsn=settings.SENTRY_DSN,
        environment=settings.ENVIRONMENT,
        integrations=[FastApiIntegration()],
        # Full tracing outside production, 10% sampling in production.
        traces_sample_rate=1.0 if not settings.is_production else 0.1,
        # Default PII is only sent outside production.
        send_default_pii=not settings.is_production,
    )
# --- Logging Setup ---
# LOG_LEVEL is interpreted as a level *name*. Upper-casing it lets values such
# as "info" or "debug" work as well as "INFO"/"DEBUG"; basicConfig validates
# the name and raises ValueError for anything logging does not recognise.
logging.basicConfig(
    level=str(settings.LOG_LEVEL).upper(),
    format=settings.LOG_FORMAT,
)
logger = logging.getLogger(__name__)
# --- FastAPI App Instance ---
# Static API metadata merged with the environment-dependent documentation
# URLs; the settings values win over any same-named keys in API_METADATA.
api_metadata = dict(
    API_METADATA,
    docs_url=settings.docs_url,
    redoc_url=settings.redoc_url,
    openapi_url=settings.openapi_url,
)
app = FastAPI(openapi_tags=API_TAGS, **api_metadata)
# Add session middleware for OAuth
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY)

# --- CORS Middleware ---
# Allowed origins come from configuration; methods, request headers and
# exposed headers are unrestricted, and credentials are allowed.
# Keep the add_middleware calls in this order.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    allow_credentials=True,
)
# --- End CORS Middleware ---
# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
    request: Request,
    refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
    access_strategy: JWTStrategy = Depends(get_jwt_strategy),
    user_manager=Depends(fastapi_users.get_user_manager),
):
    """
    Refresh access token using a valid refresh token.

    Send refresh token in Authorization header: Bearer <refresh_token>

    The refresh token is validated by ``refresh_strategy`` itself, so the same
    key, algorithm and audience used to issue it are used to verify it. The
    user is then loaded through the fastapi-users user manager.

    Returns:
        RefreshResponse holding a new access token and a new refresh token.

    Raises:
        HTTPException(401): the header is missing or malformed, the token is
            invalid or expired, the user does not exist or is inactive, or an
            unexpected error occurs while refreshing (logged with traceback).
    """
    # Expect exactly "Bearer <token>".
    scheme, _, refresh_token = (request.headers.get("Authorization") or "").partition(" ")
    refresh_token = refresh_token.strip()
    if scheme != "Bearer" or not refresh_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token missing or invalid format",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        # read_token verifies signature, expiry and the "aud" claim that
        # JWTStrategy.write_token embeds, then resolves "sub" to a user.
        # It returns None for any invalid token or unknown user.
        # NOTE(review): if the access and refresh strategies share the same
        # secret and audience, an access token is also accepted here —
        # confirm app.auth configures them distinctly.
        user = await refresh_strategy.read_token(refresh_token, user_manager)
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive",
            )

        # Issue a new pair; the refresh token is rotated on every refresh.
        new_access_token = await access_strategy.write_token(user)
        new_refresh_token = await refresh_strategy.write_token(user)
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback; callers still only see a generic 401.
        logger.exception("Error refreshing token: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from None

    return RefreshResponse(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
    )
# --- Include API Routers ---
# Route registration order is preserved: OAuth first (no auth required),
# then the FastAPI-Users routers, then the versioned application API.
# NOTE(review): the OAuth router is mounted under "/auth" without
# settings.API_PREFIX — confirm this matches the configured OAuth redirect URIs.
app.include_router(oauth_router, prefix="/auth", tags=["auth"])

# FastAPI-Users routers.
app.include_router(fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"])
app.include_router(fastapi_users.get_register_router(UserPublic, UserCreate), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_verify_router(UserPublic), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_users_router(UserPublic, UserUpdate), prefix="/users", tags=["users"])

# Application API, mounted under the configured prefix.
app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers ---
# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
    """
    Report service status, environment and API version.

    Meant for load balancers and monitoring probes.
    """
    payload = dict(
        status=settings.HEALTH_STATUS_OK,
        environment=settings.ENVIRONMENT,
        version=settings.API_VERSION,
    )
    return payload
# --- Root Endpoint (Optional - outside the main API structure) ---
@app.get("/", tags=["Root"])
async def read_root():
    """
    Return the configured welcome message with environment and version.

    Serves as a basic reachability check.
    """
    logger.info("Root endpoint '/' accessed.")
    welcome = settings.ROOT_MESSAGE
    return dict(
        message=welcome,
        environment=settings.ENVIRONMENT,
        version=settings.API_VERSION,
    )
# --- End Root Endpoint ---
async def run_migrations():
    """Run database migrations.

    Puts the ``alembic`` directory (a sibling of this file's parent package
    directory) on ``sys.path`` and imports ``run_migrations`` from the
    ``migrations`` module expected there. It then awaits that coroutine.

    Raises:
        Exception: any failure is logged together with its traceback and
            re-raised unchanged.
    """
    try:
        logger.info("Running database migrations...")
        # <base>/alembic, where <base> is the directory containing app/.
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        alembic_path = os.path.join(base_path, "alembic")
        # Make the alembic directory importable (only once per process).
        if alembic_path not in sys.path:
            sys.path.insert(0, alembic_path)
        from migrations import run_migrations as run_db_migrations
        await run_db_migrations()
        logger.info("Database migrations completed successfully.")
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error running migrations: %s", e)
        raise
@app.on_event("startup")
async def startup_event():
    """
    Start background services when the application boots.

    Only the scheduler is started; the migration call is commented out.
    """
    logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")
    # Database migrations are disabled at startup; to re-enable:
    # await run_migrations()
    init_scheduler()
    logger.info("Application startup complete.")
@app.on_event("shutdown")
async def shutdown_event():
    """Stop background services when the application shuts down."""
    logger.info("Application shutdown: Disconnecting from database...")
    # NOTE(review): engine disposal is commented out, so despite the message
    # above no database connection pool is closed here.
    # await database.engine.dispose()  # Close connection pool
    shutdown_scheduler()
    logger.info("Application shutdown complete.")
# --- End Events ---
# --- Direct Run (for simple local testing if needed) ---
# It's better to use `uvicorn app.main:app --reload` from the terminal
# if __name__ == "__main__":
# logger.info("Starting Uvicorn server directly from main.py")
# uvicorn.run(app, host="0.0.0.0", port=8000)
# ------------------------------------------------------