feat: Implement refresh token functionality in authentication flow
- Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens. - Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens. - Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support. - Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed. - Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication. - Updated frontend API service to send the refresh token in the Authorization header for token refresh requests.
This commit is contained in:
parent
a0d67f6c66
commit
84b046508a
@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.database import get_transactional_session
|
||||
from app.models import User
|
||||
from app.auth import oauth, fastapi_users, auth_backend
|
||||
from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
@ -34,16 +34,15 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans
|
||||
await db.flush() # Use flush instead of commit since we're in a transaction
|
||||
user_to_login = new_user
|
||||
|
||||
# Generate JWT token
|
||||
strategy = auth_backend.get_strategy()
|
||||
token_response = await strategy.write_token(user_to_login)
|
||||
access_token = token_response["access_token"]
|
||||
refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there
|
||||
# Generate JWT tokens using the new backend
|
||||
access_strategy = get_jwt_strategy()
|
||||
refresh_strategy = get_refresh_jwt_strategy()
|
||||
|
||||
access_token = await access_strategy.write_token(user_to_login)
|
||||
refresh_token = await refresh_strategy.write_token(user_to_login)
|
||||
|
||||
# Redirect to frontend with tokens
|
||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
|
||||
if refresh_token:
|
||||
redirect_url += f"&refresh_token={refresh_token}"
|
||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
@ -83,15 +82,14 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa
|
||||
await db.flush() # Use flush instead of commit since we're in a transaction
|
||||
user_to_login = new_user
|
||||
|
||||
# Generate JWT token
|
||||
strategy = auth_backend.get_strategy()
|
||||
token_response = await strategy.write_token(user_to_login)
|
||||
access_token = token_response["access_token"]
|
||||
refresh_token = token_response.get("refresh_token") # Use .get for safety
|
||||
# Generate JWT tokens using the new backend
|
||||
access_strategy = get_jwt_strategy()
|
||||
refresh_strategy = get_refresh_jwt_strategy()
|
||||
|
||||
access_token = await access_strategy.write_token(user_to_login)
|
||||
refresh_token = await refresh_strategy.write_token(user_to_login)
|
||||
|
||||
# Redirect to frontend with tokens
|
||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
|
||||
if refresh_token:
|
||||
redirect_url += f"&refresh_token={refresh_token}"
|
||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url)
|
@ -14,6 +14,8 @@ from authlib.integrations.starlette_client import OAuth
|
||||
from starlette.config import Config
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from .database import get_session
|
||||
from .models import User
|
||||
@ -43,6 +45,57 @@ oauth.register(
|
||||
}
|
||||
)
|
||||
|
||||
# Custom Bearer Response with Refresh Token
|
||||
class BearerResponseWithRefresh(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
# Custom Bearer Transport that supports refresh tokens
|
||||
class BearerTransportWithRefresh(BearerTransport):
|
||||
async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
|
||||
if refresh_token:
|
||||
bearer_response = BearerResponseWithRefresh(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
else:
|
||||
# Fallback to standard response if no refresh token
|
||||
bearer_response = {
|
||||
"access_token": token,
|
||||
"token_type": "bearer"
|
||||
}
|
||||
return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
|
||||
|
||||
# Custom Authentication Backend with Refresh Token Support
|
||||
class AuthenticationBackendWithRefresh(AuthenticationBackend):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
transport: BearerTransportWithRefresh,
|
||||
get_strategy,
|
||||
get_refresh_strategy,
|
||||
):
|
||||
self.name = name
|
||||
self.transport = transport
|
||||
self.get_strategy = get_strategy
|
||||
self.get_refresh_strategy = get_refresh_strategy
|
||||
|
||||
async def login(self, strategy, user) -> Response:
|
||||
# Generate both access and refresh tokens
|
||||
access_token = await strategy.write_token(user)
|
||||
refresh_strategy = self.get_refresh_strategy()
|
||||
refresh_token = await refresh_strategy.write_token(user)
|
||||
|
||||
return await self.transport.get_login_response(
|
||||
token=access_token,
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
async def logout(self, strategy, user, token) -> Response:
|
||||
return await self.transport.get_logout_response()
|
||||
|
||||
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
||||
reset_password_token_secret = settings.SECRET_KEY
|
||||
verification_token_secret = settings.SECRET_KEY
|
||||
@ -71,15 +124,22 @@ async def get_user_db(session: AsyncSession = Depends(get_session)):
|
||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||
yield UserManager(user_db)
|
||||
|
||||
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
|
||||
# Updated transport with refresh token support
|
||||
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
def get_refresh_jwt_strategy() -> JWTStrategy:
|
||||
# Refresh tokens last longer - 7 days
|
||||
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
|
||||
|
||||
# Updated auth backend with refresh token support
|
||||
auth_backend = AuthenticationBackendWithRefresh(
|
||||
name="jwt",
|
||||
transport=bearer_transport,
|
||||
get_strategy=get_jwt_strategy,
|
||||
get_refresh_strategy=get_refresh_jwt_strategy,
|
||||
)
|
||||
|
||||
fastapi_users = FastAPIUsers[User, int](
|
||||
|
@ -1,20 +1,31 @@
|
||||
# app/main.py
|
||||
import logging
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException, Depends, status, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from pydantic import BaseModel
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.api.api_router import api_router
|
||||
from app.config import settings
|
||||
from app.core.api_config import API_METADATA, API_TAGS
|
||||
from app.auth import fastapi_users, auth_backend
|
||||
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
|
||||
from app.models import User
|
||||
from app.api.auth.oauth import router as oauth_router
|
||||
from app.schemas.user import UserPublic, UserCreate, UserUpdate
|
||||
from app.core.scheduler import init_scheduler, shutdown_scheduler
|
||||
from app.database import get_session
|
||||
from sqlalchemy import select
|
||||
|
||||
# Response model for refresh endpoint
|
||||
class RefreshResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
# Initialize Sentry
|
||||
sentry_sdk.init(
|
||||
@ -60,6 +71,74 @@ app.add_middleware(
|
||||
)
|
||||
# --- End CORS Middleware ---
|
||||
|
||||
# Refresh token endpoint
|
||||
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
|
||||
async def refresh_jwt_token(
|
||||
request: Request,
|
||||
refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
|
||||
access_strategy: JWTStrategy = Depends(get_jwt_strategy),
|
||||
):
|
||||
"""
|
||||
Refresh access token using a valid refresh token.
|
||||
Send refresh token in Authorization header: Bearer <refresh_token>
|
||||
"""
|
||||
try:
|
||||
# Get refresh token from Authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token missing or invalid format",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
refresh_token = authorization.split(" ")[1]
|
||||
|
||||
# Validate refresh token and get user data
|
||||
try:
|
||||
# Decode the refresh token to get the user identifier
|
||||
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
async with get_session() as session:
|
||||
result = await session.execute(select(User).where(User.id == int(user_id)))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
|
||||
# Generate new tokens
|
||||
new_access_token = await access_strategy.write_token(user)
|
||||
new_refresh_token = await refresh_strategy.write_token(user)
|
||||
|
||||
return RefreshResponse(
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
# --- Include API Routers ---
|
||||
# Include FastAPI-Users routes
|
||||
|
@ -54,8 +54,11 @@ api.interceptors.response.use(
|
||||
return Promise.reject(error)
|
||||
}
|
||||
|
||||
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {
|
||||
refresh_token: refreshTokenValue,
|
||||
// Send refresh token in Authorization header as expected by backend
|
||||
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {}, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${refreshTokenValue}`
|
||||
}
|
||||
})
|
||||
|
||||
const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data
|
||||
|
Loading…
Reference in New Issue
Block a user