feat: Implement refresh token functionality in authentication flow

- Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens.
- Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens.
- Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support.
- Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed.
- Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication.
- Updated frontend API service to send the refresh token in the Authorization header for token refresh requests.
This commit is contained in:
mohamad 2025-05-25 12:51:02 +02:00
parent a0d67f6c66
commit 84b046508a
4 changed files with 163 additions and 23 deletions

View File

@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
from app.database import get_transactional_session from app.database import get_transactional_session
from app.models import User from app.models import User
from app.auth import oauth, fastapi_users, auth_backend from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
from app.config import settings from app.config import settings
router = APIRouter() router = APIRouter()
@ -34,16 +34,15 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans
await db.flush() # Use flush instead of commit since we're in a transaction await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user user_to_login = new_user
# Generate JWT token # Generate JWT tokens using the new backend
strategy = auth_backend.get_strategy() access_strategy = get_jwt_strategy()
token_response = await strategy.write_token(user_to_login) refresh_strategy = get_refresh_jwt_strategy()
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens # Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}" redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url) return RedirectResponse(url=redirect_url)
@ -83,15 +82,14 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa
await db.flush() # Use flush instead of commit since we're in a transaction await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user user_to_login = new_user
# Generate JWT token # Generate JWT tokens using the new backend
strategy = auth_backend.get_strategy() access_strategy = get_jwt_strategy()
token_response = await strategy.write_token(user_to_login) refresh_strategy = get_refresh_jwt_strategy()
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens # Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}" redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url) return RedirectResponse(url=redirect_url)

View File

@ -14,6 +14,8 @@ from authlib.integrations.starlette_client import OAuth
from starlette.config import Config from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response from starlette.responses import Response
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from .database import get_session from .database import get_session
from .models import User from .models import User
@ -43,6 +45,57 @@ oauth.register(
} }
) )
# Custom Bearer Response with Refresh Token
class BearerResponseWithRefresh(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Custom Bearer Transport that supports refresh tokens
class BearerTransportWithRefresh(BearerTransport):
async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
if refresh_token:
bearer_response = BearerResponseWithRefresh(
access_token=token,
refresh_token=refresh_token,
token_type="bearer"
)
else:
# Fallback to standard response if no refresh token
bearer_response = {
"access_token": token,
"token_type": "bearer"
}
return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
# Custom Authentication Backend with Refresh Token Support
class AuthenticationBackendWithRefresh(AuthenticationBackend):
def __init__(
self,
name: str,
transport: BearerTransportWithRefresh,
get_strategy,
get_refresh_strategy,
):
self.name = name
self.transport = transport
self.get_strategy = get_strategy
self.get_refresh_strategy = get_refresh_strategy
async def login(self, strategy, user) -> Response:
# Generate both access and refresh tokens
access_token = await strategy.write_token(user)
refresh_strategy = self.get_refresh_strategy()
refresh_token = await refresh_strategy.write_token(user)
return await self.transport.get_login_response(
token=access_token,
refresh_token=refresh_token
)
async def logout(self, strategy, user, token) -> Response:
return await self.transport.get_logout_response()
class UserManager(IntegerIDMixin, BaseUserManager[User, int]): class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
reset_password_token_secret = settings.SECRET_KEY reset_password_token_secret = settings.SECRET_KEY
verification_token_secret = settings.SECRET_KEY verification_token_secret = settings.SECRET_KEY
@ -71,15 +124,22 @@ async def get_user_db(session: AsyncSession = Depends(get_session)):
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db) yield UserManager(user_db)
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") # Updated transport with refresh token support
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
def get_jwt_strategy() -> JWTStrategy: def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60) return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
auth_backend = AuthenticationBackend( def get_refresh_jwt_strategy() -> JWTStrategy:
# Refresh tokens last longer - 7 days
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
# Updated auth backend with refresh token support
auth_backend = AuthenticationBackendWithRefresh(
name="jwt", name="jwt",
transport=bearer_transport, transport=bearer_transport,
get_strategy=get_jwt_strategy, get_strategy=get_jwt_strategy,
get_refresh_strategy=get_refresh_jwt_strategy,
) )
fastapi_users = FastAPIUsers[User, int]( fastapi_users = FastAPIUsers[User, int](

View File

@ -1,20 +1,31 @@
# app/main.py # app/main.py
import logging import logging
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI, HTTPException, Depends, status, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
import sentry_sdk import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.fastapi import FastApiIntegration
from fastapi_users.authentication import JWTStrategy
from pydantic import BaseModel
from jose import jwt, JWTError
from app.api.api_router import api_router from app.api.api_router import api_router
from app.config import settings from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User from app.models import User
from app.api.auth.oauth import router as oauth_router from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session
from sqlalchemy import select
# Response model for refresh endpoint
class RefreshResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Initialize Sentry # Initialize Sentry
sentry_sdk.init( sentry_sdk.init(
@ -60,6 +71,74 @@ app.add_middleware(
) )
# --- End CORS Middleware --- # --- End CORS Middleware ---
# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
request: Request,
refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
access_strategy: JWTStrategy = Depends(get_jwt_strategy),
):
"""
Refresh access token using a valid refresh token.
Send refresh token in Authorization header: Bearer <refresh_token>
"""
try:
# Get refresh token from Authorization header
authorization = request.headers.get("Authorization")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token missing or invalid format",
headers={"WWW-Authenticate": "Bearer"},
)
refresh_token = authorization.split(" ")[1]
# Validate refresh token and get user data
try:
# Decode the refresh token to get the user identifier
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
# Get user from database
async with get_session() as session:
result = await session.execute(select(User).where(User.id == int(user_id)))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
# Generate new tokens
new_access_token = await access_strategy.write_token(user)
new_refresh_token = await refresh_strategy.write_token(user)
return RefreshResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error refreshing token: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
)
# --- Include API Routers --- # --- Include API Routers ---
# Include FastAPI-Users routes # Include FastAPI-Users routes

View File

@ -54,8 +54,11 @@ api.interceptors.response.use(
return Promise.reject(error) return Promise.reject(error)
} }
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, { // Send refresh token in Authorization header as expected by backend
refresh_token: refreshTokenValue, const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {}, {
headers: {
Authorization: `Bearer ${refreshTokenValue}`
}
}) })
const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data