From 84b046508adf4121a01fc8e3a38059fd0cb14f1c Mon Sep 17 00:00:00 2001 From: mohamad Date: Sun, 25 May 2025 12:51:02 +0200 Subject: [PATCH] feat: Implement refresh token functionality in authentication flow - Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens. - Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens. - Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support. - Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed. - Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication. - Updated frontend API service to send the refresh token in the Authorization header for token refresh requests. --- be/app/api/auth/oauth.py | 32 ++++++++-------- be/app/auth.py | 64 ++++++++++++++++++++++++++++++- be/app/main.py | 83 +++++++++++++++++++++++++++++++++++++++- fe/src/services/api.ts | 7 +++- 4 files changed, 163 insertions(+), 23 deletions(-) diff --git a/be/app/api/auth/oauth.py b/be/app/api/auth/oauth.py index f758a14..13223d6 100644 --- a/be/app/api/auth/oauth.py +++ b/be/app/api/auth/oauth.py @@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from app.database import get_transactional_session from app.models import User -from app.auth import oauth, fastapi_users, auth_backend +from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy from app.config import settings router = APIRouter() @@ -34,16 +34,15 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans await db.flush() # Use flush instead of commit since we're in a transaction user_to_login = new_user - # Generate JWT token - strategy = auth_backend.get_strategy() - token_response = await strategy.write_token(user_to_login) - access_token = token_response["access_token"] - refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there + # Generate JWT tokens using the new backend + access_strategy = get_jwt_strategy() + refresh_strategy = get_refresh_jwt_strategy() + + access_token = await access_strategy.write_token(user_to_login) + refresh_token = await refresh_strategy.write_token(user_to_login) # Redirect to frontend with tokens - redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}" - if refresh_token: - redirect_url += f"&refresh_token={refresh_token}" + redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}" return RedirectResponse(url=redirect_url) @@ -83,15 +82,14 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa await db.flush() # Use flush instead of commit since we're in a transaction user_to_login = new_user - # Generate JWT token - strategy = auth_backend.get_strategy() - token_response = await strategy.write_token(user_to_login) - access_token = token_response["access_token"] - refresh_token = token_response.get("refresh_token") # Use .get for safety + # Generate JWT tokens using the new backend + access_strategy = get_jwt_strategy() + refresh_strategy = get_refresh_jwt_strategy() + + access_token = await access_strategy.write_token(user_to_login) + refresh_token = await refresh_strategy.write_token(user_to_login) # Redirect to frontend with tokens - redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}" - if refresh_token: - redirect_url += f"&refresh_token={refresh_token}" + redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}" return RedirectResponse(url=redirect_url) \ No newline at end of file diff --git a/be/app/auth.py b/be/app/auth.py index b7b95ad..f8a7518 100644 --- a/be/app/auth.py +++ b/be/app/auth.py @@ -14,6 +14,8 @@ from authlib.integrations.starlette_client import OAuth from starlette.config import Config from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response +from pydantic import BaseModel +from fastapi.responses import JSONResponse from .database import get_session from .models import User @@ -43,6 +45,57 @@ oauth.register( } ) +# Custom Bearer Response with Refresh Token +class BearerResponseWithRefresh(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + +# Custom Bearer Transport that supports refresh tokens +class BearerTransportWithRefresh(BearerTransport): + async def get_login_response(self, token: str, refresh_token: str = None) -> Response: + if refresh_token: + bearer_response = BearerResponseWithRefresh( + access_token=token, + refresh_token=refresh_token, + token_type="bearer" + ) + else: + # Fallback to standard response if no refresh token + bearer_response = { + "access_token": token, + "token_type": "bearer" + } + return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response) + +# Custom Authentication Backend with Refresh Token Support +class AuthenticationBackendWithRefresh(AuthenticationBackend): + def __init__( + self, + name: str, + transport: BearerTransportWithRefresh, + get_strategy, + get_refresh_strategy, + ): + self.name = name + self.transport = transport + self.get_strategy = get_strategy + self.get_refresh_strategy = get_refresh_strategy + + async def login(self, strategy, user) -> Response: + # Generate both access and refresh tokens + access_token = await strategy.write_token(user) + refresh_strategy = self.get_refresh_strategy() + refresh_token = await refresh_strategy.write_token(user) + + return await self.transport.get_login_response( + token=access_token, + refresh_token=refresh_token + ) + + async def logout(self, strategy, user, token) -> Response: + return await self.transport.get_logout_response() + class UserManager(IntegerIDMixin, BaseUserManager[User, int]): reset_password_token_secret = settings.SECRET_KEY verification_token_secret = settings.SECRET_KEY @@ -71,15 +124,22 @@ async def get_user_db(session: AsyncSession = Depends(get_session)): async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): yield UserManager(user_db) -bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") +# Updated transport with refresh token support +bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login") def get_jwt_strategy() -> JWTStrategy: return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60) -auth_backend = AuthenticationBackend( +def get_refresh_jwt_strategy() -> JWTStrategy: + # Refresh tokens last longer - 7 days + return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60) + +# Updated auth backend with refresh token support +auth_backend = AuthenticationBackendWithRefresh( name="jwt", transport=bearer_transport, get_strategy=get_jwt_strategy, + get_refresh_strategy=get_refresh_jwt_strategy, ) fastapi_users = FastAPIUsers[User, int]( diff --git a/be/app/main.py b/be/app/main.py index 683c555..3365eb6 100644 --- a/be/app/main.py +++ b/be/app/main.py @@ -1,20 +1,31 @@ # app/main.py import logging import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException, Depends, status, Request from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware import sentry_sdk from sentry_sdk.integrations.fastapi import FastApiIntegration +from fastapi_users.authentication import JWTStrategy +from pydantic import BaseModel +from jose import jwt, JWTError from app.api.api_router import api_router from app.config import settings from app.core.api_config import API_METADATA, API_TAGS -from app.auth import fastapi_users, auth_backend +from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy from app.models import User from app.api.auth.oauth import router as oauth_router from app.schemas.user import UserPublic, UserCreate, UserUpdate from app.core.scheduler import init_scheduler, shutdown_scheduler +from app.database import get_session +from sqlalchemy import select + +# Response model for refresh endpoint +class RefreshResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" # Initialize Sentry sentry_sdk.init( @@ -60,6 +71,74 @@ app.add_middleware( ) # --- End CORS Middleware --- +# Refresh token endpoint +@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"]) +async def refresh_jwt_token( + request: Request, + refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy), + access_strategy: JWTStrategy = Depends(get_jwt_strategy), +): + """ + Refresh access token using a valid refresh token. + Send refresh token in Authorization header: Bearer + """ + try: + # Get refresh token from Authorization header + authorization = request.headers.get("Authorization") + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token missing or invalid format", + headers={"WWW-Authenticate": "Bearer"}, + ) + + refresh_token = authorization.split(" ")[1] + + # Validate refresh token and get user data + try: + # Decode the refresh token to get the user identifier + payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"]) + user_id = payload.get("sub") + if user_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) + except JWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) + + # Get user from database + async with get_session() as session: + result = await session.execute(select(User).where(User.id == int(user_id))) + user = result.scalar_one_or_none() + + if not user or not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found or inactive", + ) + + # Generate new tokens + new_access_token = await access_strategy.write_token(user) + new_refresh_token = await refresh_strategy.write_token(user) + + return RefreshResponse( + access_token=new_access_token, + refresh_token=new_refresh_token, + token_type="bearer" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error refreshing token: {e}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token" + ) # --- Include API Routers --- # Include FastAPI-Users routes diff --git a/fe/src/services/api.ts b/fe/src/services/api.ts index ff69c67..77268d0 100644 --- a/fe/src/services/api.ts +++ b/fe/src/services/api.ts @@ -54,8 +54,11 @@ api.interceptors.response.use( return Promise.reject(error) } - const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, { - refresh_token: refreshTokenValue, + // Send refresh token in Authorization header as expected by backend + const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {}, { + headers: { + Authorization: `Bearer ${refreshTokenValue}` + } }) const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data