feat: Implement refresh token functionality in authentication flow

- Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens.
- Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens.
- Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support.
- Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed.
- Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication.
- Updated frontend API service to send the refresh token in the Authorization header for token refresh requests.
This commit is contained in:
mohamad 2025-05-25 12:51:02 +02:00
parent a0d67f6c66
commit 84b046508a
4 changed files with 163 additions and 23 deletions

View File

@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_transactional_session
from app.models import User
from app.auth import oauth, fastapi_users, auth_backend
from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
from app.config import settings
router = APIRouter()
@ -34,16 +34,15 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans
await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user
# Generate JWT token
strategy = auth_backend.get_strategy()
token_response = await strategy.write_token(user_to_login)
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there
# Generate JWT tokens using the new backend
access_strategy = get_jwt_strategy()
refresh_strategy = get_refresh_jwt_strategy()
access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url)
@ -83,15 +82,14 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa
await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user
# Generate JWT token
strategy = auth_backend.get_strategy()
token_response = await strategy.write_token(user_to_login)
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety
# Generate JWT tokens using the new backend
access_strategy = get_jwt_strategy()
refresh_strategy = get_refresh_jwt_strategy()
access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url)

View File

@ -14,6 +14,8 @@ from authlib.integrations.starlette_client import OAuth
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from .database import get_session
from .models import User
@ -43,6 +45,57 @@ oauth.register(
}
)
# Custom Bearer Response with Refresh Token
class BearerResponseWithRefresh(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Custom Bearer Transport that supports refresh tokens
class BearerTransportWithRefresh(BearerTransport):
async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
if refresh_token:
bearer_response = BearerResponseWithRefresh(
access_token=token,
refresh_token=refresh_token,
token_type="bearer"
)
else:
# Fallback to standard response if no refresh token
bearer_response = {
"access_token": token,
"token_type": "bearer"
}
return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
# Custom Authentication Backend with Refresh Token Support
class AuthenticationBackendWithRefresh(AuthenticationBackend):
def __init__(
self,
name: str,
transport: BearerTransportWithRefresh,
get_strategy,
get_refresh_strategy,
):
self.name = name
self.transport = transport
self.get_strategy = get_strategy
self.get_refresh_strategy = get_refresh_strategy
async def login(self, strategy, user) -> Response:
# Generate both access and refresh tokens
access_token = await strategy.write_token(user)
refresh_strategy = self.get_refresh_strategy()
refresh_token = await refresh_strategy.write_token(user)
return await self.transport.get_login_response(
token=access_token,
refresh_token=refresh_token
)
async def logout(self, strategy, user, token) -> Response:
return await self.transport.get_logout_response()
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
reset_password_token_secret = settings.SECRET_KEY
verification_token_secret = settings.SECRET_KEY
@ -71,15 +124,22 @@ async def get_user_db(session: AsyncSession = Depends(get_session)):
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
# Updated transport with refresh token support
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
auth_backend = AuthenticationBackend(
def get_refresh_jwt_strategy() -> JWTStrategy:
# Refresh tokens last longer - 7 days
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
# Updated auth backend with refresh token support
auth_backend = AuthenticationBackendWithRefresh(
name="jwt",
transport=bearer_transport,
get_strategy=get_jwt_strategy,
get_refresh_strategy=get_refresh_jwt_strategy,
)
fastapi_users = FastAPIUsers[User, int](

View File

@ -1,20 +1,31 @@
# app/main.py
import logging
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Depends, status, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration
from fastapi_users.authentication import JWTStrategy
from pydantic import BaseModel
from jose import jwt, JWTError
from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User
from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session
from sqlalchemy import select
# Response model for refresh endpoint
class RefreshResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Initialize Sentry
sentry_sdk.init(
@ -60,6 +71,74 @@ app.add_middleware(
)
# --- End CORS Middleware ---
# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
request: Request,
refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
access_strategy: JWTStrategy = Depends(get_jwt_strategy),
):
"""
Refresh access token using a valid refresh token.
Send refresh token in Authorization header: Bearer <refresh_token>
"""
try:
# Get refresh token from Authorization header
authorization = request.headers.get("Authorization")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token missing or invalid format",
headers={"WWW-Authenticate": "Bearer"},
)
refresh_token = authorization.split(" ")[1]
# Validate refresh token and get user data
try:
# Decode the refresh token to get the user identifier
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
# Get user from database
async with get_session() as session:
result = await session.execute(select(User).where(User.id == int(user_id)))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
# Generate new tokens
new_access_token = await access_strategy.write_token(user)
new_refresh_token = await refresh_strategy.write_token(user)
return RefreshResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error refreshing token: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
)
# --- Include API Routers ---
# Include FastAPI-Users routes

View File

@ -54,8 +54,11 @@ api.interceptors.response.use(
return Promise.reject(error)
}
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {
refresh_token: refreshTokenValue,
// Send refresh token in Authorization header as expected by backend
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {}, {
headers: {
Authorization: `Bearer ${refreshTokenValue}`
}
})
const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data