
- Updated the OAuth routes to be included under the main API prefix for better organization. - Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure. These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
269 lines
8.5 KiB
Python
269 lines
8.5 KiB
Python
# app/main.py
|
|
import logging
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException, Depends, status, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
import sentry_sdk
|
|
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
|
from fastapi_users.authentication import JWTStrategy
|
|
from pydantic import BaseModel
|
|
from jose import jwt, JWTError
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
from alembic.config import Config
|
|
from alembic import command
|
|
import os
|
|
import sys
|
|
|
|
from app.api.api_router import api_router
|
|
from app.config import settings
|
|
from app.core.api_config import API_METADATA, API_TAGS
|
|
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
|
|
from app.models import User
|
|
from app.api.auth.oauth import router as oauth_router
|
|
from app.schemas.user import UserPublic, UserCreate, UserUpdate
|
|
from app.core.scheduler import init_scheduler, shutdown_scheduler
|
|
from app.database import get_session
|
|
from sqlalchemy import select
|
|
|
|
# Response model for refresh endpoint
|
|
class RefreshResponse(BaseModel):
    """Response body returned by the JWT refresh endpoint."""

    # Newly issued access token.
    access_token: str
    # Newly issued refresh token.
    refresh_token: str
    # Token scheme reported to the client; defaults to "bearer".
    token_type: str = "bearer"
|
|
|
|
# Initialize Sentry only if DSN is provided
|
|
if settings.SENTRY_DSN:
    # Sentry is opt-in: sentry_sdk.init() is only called when a DSN is set.
    sentry_sdk.init(
        dsn=settings.SENTRY_DSN,
        integrations=[FastApiIntegration()],
        # Trace 10% of transactions in production, all of them otherwise.
        traces_sample_rate=0.1 if settings.is_production else 1.0,
        environment=settings.ENVIRONMENT,
        # Default PII is sent only when not running in production.
        send_default_pii=not settings.is_production,
    )
|
|
|
|
# --- Logging Setup ---
|
|
# The level constant is looked up by name on the logging module. Level
# constants are upper case, so normalise the configured name first; this
# lets a value such as "info" resolve to logging.INFO instead of raising
# AttributeError at import time.
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL.upper()),
    format=settings.LOG_FORMAT,
)

# Module-level logger used by the endpoints and event handlers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
# --- FastAPI App Instance ---
|
|
# Create API metadata with environment-dependent settings
|
|
# Start from the static API metadata and overlay the documentation URLs
# taken from settings (the overlay wins on duplicate keys).
api_metadata = dict(
    API_METADATA,
    docs_url=settings.docs_url,
    redoc_url=settings.redoc_url,
    openapi_url=settings.openapi_url,
)

# The application object that all middleware, routers and endpoints
# in this module attach to.
app = FastAPI(openapi_tags=API_TAGS, **api_metadata)
|
|
|
|
# Add session middleware for OAuth
|
|
# Session support (used by the OAuth flow), keyed by SESSION_SECRET_KEY.
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY)

# --- CORS Middleware ---
# Origins come from settings; methods, request headers and exposed
# headers are all wildcarded, and credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
|
|
# --- End CORS Middleware ---
|
|
|
|
# Refresh token endpoint
|
|
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
    request: Request,
    refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
    access_strategy: JWTStrategy = Depends(get_jwt_strategy),
):
    """
    Refresh access token using a valid refresh token.

    Send refresh token in Authorization header: Bearer <refresh_token>

    Returns a new access token and a new refresh token. Responds with
    401 when the header is missing or malformed, the token cannot be
    decoded, or the user does not exist or is inactive.
    """
    try:
        # Get refresh token from Authorization header
        authorization = request.headers.get("Authorization")
        if not authorization or not authorization.startswith("Bearer "):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Refresh token missing or invalid format",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Everything after the scheme is the token; strip so stray
        # whitespace around it does not invalidate an otherwise good token.
        refresh_token = authorization[len("Bearer "):].strip()

        # Decode the refresh token and extract the numeric user id.
        # NOTE(review): no `audience` is passed to jwt.decode. If the
        # refresh strategy writes an `aud` claim, python-jose may reject
        # the token here - confirm against get_refresh_jwt_strategy.
        # NOTE(review): nothing here distinguishes a refresh token from
        # any other token signed with SECRET_KEY - confirm this is intended.
        try:
            payload = jwt.decode(
                refresh_token, settings.SECRET_KEY, algorithms=["HS256"]
            )
            # A missing, null or non-numeric "sub" is an invalid token,
            # not a server error, so it is rejected here explicitly.
            user_id = int(payload["sub"])
        except (JWTError, KeyError, TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            ) from exc

        # Get user from database
        async with get_session() as session:
            result = await session.execute(select(User).where(User.id == user_id))
            user = result.scalar_one_or_none()

            if not user or not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="User not found or inactive",
                )

            # Generate new tokens
            new_access_token = await access_strategy.write_token(user)
            new_refresh_token = await refresh_strategy.write_token(user)

            return RefreshResponse(
                access_token=new_access_token,
                refresh_token=new_refresh_token,
                token_type="bearer",
            )

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        # Any unexpected failure is still reported to the client as an
        # invalid token, but the full traceback is kept in the logs so the
        # real cause can be diagnosed.
        logger.exception("Error refreshing token: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e
|
|
|
|
# --- Include API Routers ---
# FastAPI-Users routers: authentication, registration, password reset,
# verification and user routes. Registration order is kept as-is.
app.include_router(fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"])
app.include_router(fastapi_users.get_register_router(UserPublic, UserCreate), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_verify_router(UserPublic), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_users_router(UserPublic, UserUpdate), prefix="/users", tags=["users"])

# The project's own API router lives under the configured API prefix.
app.include_router(api_router, prefix=settings.API_PREFIX)

# OAuth routes are mounted under the main API prefix (they were
# previously mounted directly under "/auth").
app.include_router(
    oauth_router,
    prefix=f"{settings.API_PREFIX}/auth",
    tags=["auth"],
)
# --- End Include API Routers ---
|
|
|
|
# Health check endpoint
|
|
@app.get("/health", tags=["Health"])
async def health_check():
    """
    Health check endpoint for load balancers and monitoring.

    Returns the configured OK status together with the current
    environment name and API version.
    """
    return dict(
        status=settings.HEALTH_STATUS_OK,
        environment=settings.ENVIRONMENT,
        version=settings.API_VERSION,
    )
|
|
|
|
# --- Root Endpoint (Optional - outside the main API structure) ---
|
|
@app.get("/", tags=["Root"])
async def read_root():
    """
    Return a simple welcome payload at the root path.

    Handy as a basic reachability check; each call is logged at INFO level.
    """
    logger.info("Root endpoint '/' accessed.")
    return dict(
        message=settings.ROOT_MESSAGE,
        environment=settings.ENVIRONMENT,
        version=settings.API_VERSION,
    )
|
|
# --- End Root Endpoint ---
|
|
|
|
async def run_migrations():
    """Run database migrations.

    Puts the ``alembic`` directory (located next to the package directory
    that contains this file) on ``sys.path``, imports ``run_migrations``
    from the ``migrations`` module and awaits it.

    Raises:
        Exception: Any error raised while importing or running the
            migrations is logged with its traceback and re-raised.
    """
    try:
        logger.info("Running database migrations...")

        # Get the absolute path to the alembic directory
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        alembic_path = os.path.join(base_path, "alembic")

        # Make the `migrations` module importable; guard against inserting
        # the same path again on repeated calls.
        if alembic_path not in sys.path:
            sys.path.insert(0, alembic_path)

        # Imported lazily because it only resolves after the path tweak above.
        from migrations import run_migrations as run_db_migrations

        await run_db_migrations()
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception("Error running migrations: %s", e)
        raise
    else:
        logger.info("Database migrations completed successfully.")
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: log the environment and start the scheduler."""
    logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")

    # Database migrations are currently disabled at startup:
    # await run_migrations()

    # Start the scheduler.
    init_scheduler()

    logger.info("Application startup complete.")
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook: stop the scheduler and log completion."""
    # NOTE(review): the message mentions the database, but the engine
    # disposal below is commented out; only the scheduler is shut down.
    logger.info("Application shutdown: Disconnecting from database...")
    # await database.engine.dispose()  # Close connection pool

    shutdown_scheduler()

    logger.info("Application shutdown complete.")
|
|
# --- End Events ---
|
|
|
|
|
|
# --- Direct Run (for simple local testing if needed) ---
|
|
# It's better to use `uvicorn app.main:app --reload` from the terminal
|
|
# if __name__ == "__main__":
|
|
# logger.info("Starting Uvicorn server directly from main.py")
|
|
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
# ------------------------------------------------------ |