from typing import Optional

from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin
from fastapi_users.authentication import (
    AuthenticationBackend,
    BearerTransport,
    JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from authlib.integrations.starlette_client import OAuth
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response
from pydantic import BaseModel
from fastapi.responses import JSONResponse

from .database import get_session
from .models import User
from .config import settings

# OAuth2 configuration: authlib's registry is backed by a Starlette Config
# built from the `.env` file (and the process environment).
config = Config('.env')
oauth = OAuth(config)

# Google OAuth2 setup via its OpenID Connect discovery document.
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={
        'scope': 'openid email profile',
        'redirect_uri': settings.GOOGLE_REDIRECT_URI,
    },
)

# Apple OAuth2 setup via its OpenID Connect discovery document.
# NOTE(review): Apple documents that requesting the `name`/`email` scopes
# requires `response_mode=form_post` (the callback then arrives as a POST).
# Confirm the authorize call and callback route elsewhere handle that.
oauth.register(
    name='apple',
    server_metadata_url='https://appleid.apple.com/.well-known/openid-configuration',
    client_kwargs={
        'scope': 'openid email name',
        'redirect_uri': settings.APPLE_REDIRECT_URI,
    },
)


class BearerResponseWithRefresh(BaseModel):
    """Login response body that carries a refresh token alongside the access token."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class BearerTransportWithRefresh(BearerTransport):
    """Bearer transport whose login response can also carry a refresh token."""

    async def get_login_response(
        self, token: str, refresh_token: Optional[str] = None
    ) -> Response:
        """Build the JSON response returned by the login endpoint.

        Args:
            token: Access token to return.
            refresh_token: Refresh token to include. When falsy (``None`` or
                ``""``), the standard ``{"access_token", "token_type"}`` body
                is returned instead.

        Returns:
            A ``JSONResponse`` containing the token payload.
        """
        if not refresh_token:
            # Fallback to the standard bearer response if no refresh token.
            return JSONResponse({"access_token": token, "token_type": "bearer"})

        bearer_response = BearerResponseWithRefresh(
            access_token=token,
            refresh_token=refresh_token,
            token_type="bearer",
        )
        # Pydantic v2 renamed `.dict()` to `.model_dump()` and deprecates the
        # old name; use whichever the installed Pydantic provides.
        dump = getattr(bearer_response, "model_dump", None) or bearer_response.dict
        return JSONResponse(dump())
import logging

logger = logging.getLogger(__name__)

# Refresh tokens last longer than access tokens: 7 days.
REFRESH_TOKEN_LIFETIME_SECONDS = 7 * 24 * 60 * 60

# Audience claim stamped into refresh tokens. Access and refresh tokens are
# signed with the same SECRET_KEY, so without a distinct audience the
# access-token strategy would also accept a refresh token, which would then
# act as a 7-day access token.
REFRESH_TOKEN_AUDIENCE = "fastapi-users:refresh"


# Custom Authentication Backend with Refresh Token Support
class AuthenticationBackendWithRefresh(AuthenticationBackend):
    """Authentication backend whose login issues both an access and a refresh token.

    Args:
        name: Backend name.
        transport: Transport used to build the login/logout responses.
        get_strategy: Dependency returning the access-token strategy.
        get_refresh_strategy: Zero-argument callable returning the strategy
            used to mint refresh tokens (called directly, not via ``Depends``).
    """

    def __init__(
        self,
        name: str,
        transport: BearerTransportWithRefresh,
        get_strategy,
        get_refresh_strategy,
    ):
        # Let the base class store name/transport/get_strategy rather than
        # bypassing its initialiser.
        super().__init__(name=name, transport=transport, get_strategy=get_strategy)
        self.get_refresh_strategy = get_refresh_strategy

    async def login(self, strategy, user) -> Response:
        """Issue an access token and a refresh token for ``user``."""
        access_token = await strategy.write_token(user)
        refresh_token = await self.get_refresh_strategy().write_token(user)
        return await self.transport.get_login_response(
            token=access_token, refresh_token=refresh_token
        )

    async def logout(self, strategy, user, token) -> Response:
        """Log ``user`` out.

        Delegates to the base implementation. ``BearerTransport`` does not
        support a logout response (its ``get_logout_response`` raises
        ``TransportLogoutNotSupportedError``). Calling it directly would
        surface as a 500; the base class maps it to an empty 204 response.
        """
        return await super().logout(strategy, user, token)


class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
    """fastapi-users manager for ``User`` rows with integer primary keys."""

    # Password-reset and verification tokens are signed with the app secret.
    reset_password_token_secret = settings.SECRET_KEY
    verification_token_secret = settings.SECRET_KEY

    async def on_after_register(self, user: User, request: Optional[Request] = None):
        logger.info("User %s has registered.", user.id)

    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        logger.info("User %s has forgot their password.", user.id)
        # The reset token lets its holder set this user's password. Log it only
        # at DEBUG (local development), never in normal production output.
        # TODO: deliver the token to the user out-of-band (e.g. by email).
        logger.debug("Reset token for user %s: %s", user.id, token)

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        logger.info("Verification requested for user %s.", user.id)
        # The verification token is a credential; keep it at DEBUG only.
        # TODO: deliver the token to the user out-of-band (e.g. by email).
        logger.debug("Verification token for user %s: %s", user.id, token)

    async def on_after_login(
        self,
        user: User,
        request: Optional[Request] = None,
        response: Optional[Response] = None,
    ):
        logger.info("User %s has logged in.", user.id)


async def get_user_db(session: AsyncSession = Depends(get_session)):
    """Yield a SQLAlchemy-backed user database bound to the injected session."""
    yield SQLAlchemyUserDatabase(session, User)


async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
    """Yield a ``UserManager`` operating on ``user_db``."""
    yield UserManager(user_db)


# Bearer transport with refresh token support; `tokenUrl` names the login route.
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")


def get_jwt_strategy() -> JWTStrategy:
    """Return the strategy for short-lived access tokens."""
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


def get_refresh_jwt_strategy() -> JWTStrategy:
    """Return the strategy for long-lived refresh tokens.

    Tokens carry the ``REFRESH_TOKEN_AUDIENCE`` audience, so they must also be
    read/validated with this strategy; the access-token strategy rejects them.
    """
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=REFRESH_TOKEN_LIFETIME_SECONDS,
        token_audience=[REFRESH_TOKEN_AUDIENCE],
    )


# Auth backend that returns both tokens on login.
auth_backend = AuthenticationBackendWithRefresh(
    name="jwt",
    transport=bearer_transport,
    get_strategy=get_jwt_strategy,
    get_refresh_strategy=get_refresh_jwt_strategy,
)

fastapi_users = FastAPIUsers[User, int](
    get_user_manager,
    [auth_backend],
)

# Route dependencies: any active user / active superuser.
current_active_user = fastapi_users.current_user(active=True)
current_superuser = fastapi_users.current_user(active=True, superuser=True)