feat: Implement refresh token functionality in authentication flow
- Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens. - Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens. - Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support. - Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed. - Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication. - Updated frontend API service to send the refresh token in the Authorization header for token refresh requests.
This commit is contained in:
parent
a0d67f6c66
commit
84b046508a
@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.database import get_transactional_session
|
from app.database import get_transactional_session
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.auth import oauth, fastapi_users, auth_backend
|
from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -34,16 +34,15 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans
|
|||||||
await db.flush() # Use flush instead of commit since we're in a transaction
|
await db.flush() # Use flush instead of commit since we're in a transaction
|
||||||
user_to_login = new_user
|
user_to_login = new_user
|
||||||
|
|
||||||
# Generate JWT token
|
# Generate JWT tokens using the new backend
|
||||||
strategy = auth_backend.get_strategy()
|
access_strategy = get_jwt_strategy()
|
||||||
token_response = await strategy.write_token(user_to_login)
|
refresh_strategy = get_refresh_jwt_strategy()
|
||||||
access_token = token_response["access_token"]
|
|
||||||
refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there
|
access_token = await access_strategy.write_token(user_to_login)
|
||||||
|
refresh_token = await refresh_strategy.write_token(user_to_login)
|
||||||
|
|
||||||
# Redirect to frontend with tokens
|
# Redirect to frontend with tokens
|
||||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
|
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
|
||||||
if refresh_token:
|
|
||||||
redirect_url += f"&refresh_token={refresh_token}"
|
|
||||||
|
|
||||||
return RedirectResponse(url=redirect_url)
|
return RedirectResponse(url=redirect_url)
|
||||||
|
|
||||||
@ -83,15 +82,14 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa
|
|||||||
await db.flush() # Use flush instead of commit since we're in a transaction
|
await db.flush() # Use flush instead of commit since we're in a transaction
|
||||||
user_to_login = new_user
|
user_to_login = new_user
|
||||||
|
|
||||||
# Generate JWT token
|
# Generate JWT tokens using the new backend
|
||||||
strategy = auth_backend.get_strategy()
|
access_strategy = get_jwt_strategy()
|
||||||
token_response = await strategy.write_token(user_to_login)
|
refresh_strategy = get_refresh_jwt_strategy()
|
||||||
access_token = token_response["access_token"]
|
|
||||||
refresh_token = token_response.get("refresh_token") # Use .get for safety
|
access_token = await access_strategy.write_token(user_to_login)
|
||||||
|
refresh_token = await refresh_strategy.write_token(user_to_login)
|
||||||
|
|
||||||
# Redirect to frontend with tokens
|
# Redirect to frontend with tokens
|
||||||
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
|
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
|
||||||
if refresh_token:
|
|
||||||
redirect_url += f"&refresh_token={refresh_token}"
|
|
||||||
|
|
||||||
return RedirectResponse(url=redirect_url)
|
return RedirectResponse(url=redirect_url)
|
@ -14,6 +14,8 @@ from authlib.integrations.starlette_client import OAuth
|
|||||||
from starlette.config import Config
|
from starlette.config import Config
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from .database import get_session
|
from .database import get_session
|
||||||
from .models import User
|
from .models import User
|
||||||
@ -43,6 +45,57 @@ oauth.register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Custom Bearer Response with Refresh Token
|
||||||
|
class BearerResponseWithRefresh(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
# Custom Bearer Transport that supports refresh tokens
|
||||||
|
class BearerTransportWithRefresh(BearerTransport):
|
||||||
|
async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
|
||||||
|
if refresh_token:
|
||||||
|
bearer_response = BearerResponseWithRefresh(
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to standard response if no refresh token
|
||||||
|
bearer_response = {
|
||||||
|
"access_token": token,
|
||||||
|
"token_type": "bearer"
|
||||||
|
}
|
||||||
|
return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
|
||||||
|
|
||||||
|
# Custom Authentication Backend with Refresh Token Support
|
||||||
|
class AuthenticationBackendWithRefresh(AuthenticationBackend):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
transport: BearerTransportWithRefresh,
|
||||||
|
get_strategy,
|
||||||
|
get_refresh_strategy,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.transport = transport
|
||||||
|
self.get_strategy = get_strategy
|
||||||
|
self.get_refresh_strategy = get_refresh_strategy
|
||||||
|
|
||||||
|
async def login(self, strategy, user) -> Response:
|
||||||
|
# Generate both access and refresh tokens
|
||||||
|
access_token = await strategy.write_token(user)
|
||||||
|
refresh_strategy = self.get_refresh_strategy()
|
||||||
|
refresh_token = await refresh_strategy.write_token(user)
|
||||||
|
|
||||||
|
return await self.transport.get_login_response(
|
||||||
|
token=access_token,
|
||||||
|
refresh_token=refresh_token
|
||||||
|
)
|
||||||
|
|
||||||
|
async def logout(self, strategy, user, token) -> Response:
|
||||||
|
return await self.transport.get_logout_response()
|
||||||
|
|
||||||
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
||||||
reset_password_token_secret = settings.SECRET_KEY
|
reset_password_token_secret = settings.SECRET_KEY
|
||||||
verification_token_secret = settings.SECRET_KEY
|
verification_token_secret = settings.SECRET_KEY
|
||||||
@ -71,15 +124,22 @@ async def get_user_db(session: AsyncSession = Depends(get_session)):
|
|||||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
||||||
yield UserManager(user_db)
|
yield UserManager(user_db)
|
||||||
|
|
||||||
bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")
|
# Updated transport with refresh token support
|
||||||
|
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
|
||||||
|
|
||||||
def get_jwt_strategy() -> JWTStrategy:
|
def get_jwt_strategy() -> JWTStrategy:
|
||||||
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
|
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
|
||||||
|
|
||||||
auth_backend = AuthenticationBackend(
|
def get_refresh_jwt_strategy() -> JWTStrategy:
|
||||||
|
# Refresh tokens last longer - 7 days
|
||||||
|
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
|
||||||
|
|
||||||
|
# Updated auth backend with refresh token support
|
||||||
|
auth_backend = AuthenticationBackendWithRefresh(
|
||||||
name="jwt",
|
name="jwt",
|
||||||
transport=bearer_transport,
|
transport=bearer_transport,
|
||||||
get_strategy=get_jwt_strategy,
|
get_strategy=get_jwt_strategy,
|
||||||
|
get_refresh_strategy=get_refresh_jwt_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
fastapi_users = FastAPIUsers[User, int](
|
fastapi_users = FastAPIUsers[User, int](
|
||||||
|
@ -1,20 +1,31 @@
|
|||||||
# app/main.py
|
# app/main.py
|
||||||
import logging
|
import logging
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException, Depends, status, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||||
|
from fastapi_users.authentication import JWTStrategy
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from jose import jwt, JWTError
|
||||||
|
|
||||||
from app.api.api_router import api_router
|
from app.api.api_router import api_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.core.api_config import API_METADATA, API_TAGS
|
from app.core.api_config import API_METADATA, API_TAGS
|
||||||
from app.auth import fastapi_users, auth_backend
|
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
|
||||||
from app.models import User
|
from app.models import User
|
||||||
from app.api.auth.oauth import router as oauth_router
|
from app.api.auth.oauth import router as oauth_router
|
||||||
from app.schemas.user import UserPublic, UserCreate, UserUpdate
|
from app.schemas.user import UserPublic, UserCreate, UserUpdate
|
||||||
from app.core.scheduler import init_scheduler, shutdown_scheduler
|
from app.core.scheduler import init_scheduler, shutdown_scheduler
|
||||||
|
from app.database import get_session
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
# Response model for refresh endpoint
|
||||||
|
class RefreshResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
# Initialize Sentry
|
# Initialize Sentry
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
@ -60,6 +71,74 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
# --- End CORS Middleware ---
|
# --- End CORS Middleware ---
|
||||||
|
|
||||||
|
# Refresh token endpoint
|
||||||
|
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
|
||||||
|
async def refresh_jwt_token(
|
||||||
|
request: Request,
|
||||||
|
refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
|
||||||
|
access_strategy: JWTStrategy = Depends(get_jwt_strategy),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Refresh access token using a valid refresh token.
|
||||||
|
Send refresh token in Authorization header: Bearer <refresh_token>
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get refresh token from Authorization header
|
||||||
|
authorization = request.headers.get("Authorization")
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token missing or invalid format",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_token = authorization.split(" ")[1]
|
||||||
|
|
||||||
|
# Validate refresh token and get user data
|
||||||
|
try:
|
||||||
|
# Decode the refresh token to get the user identifier
|
||||||
|
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
if user_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user from database
|
||||||
|
async with get_session() as session:
|
||||||
|
result = await session.execute(select(User).where(User.id == int(user_id)))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found or inactive",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate new tokens
|
||||||
|
new_access_token = await access_strategy.write_token(user)
|
||||||
|
new_refresh_token = await refresh_strategy.write_token(user)
|
||||||
|
|
||||||
|
return RefreshResponse(
|
||||||
|
access_token=new_access_token,
|
||||||
|
refresh_token=new_refresh_token,
|
||||||
|
token_type="bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error refreshing token: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Include API Routers ---
|
# --- Include API Routers ---
|
||||||
# Include FastAPI-Users routes
|
# Include FastAPI-Users routes
|
||||||
|
@ -54,8 +54,11 @@ api.interceptors.response.use(
|
|||||||
return Promise.reject(error)
|
return Promise.reject(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {
|
// Send refresh token in Authorization header as expected by backend
|
||||||
refresh_token: refreshTokenValue,
|
const response = await api.post(API_ENDPOINTS.AUTH.REFRESH, {}, {
|
||||||
|
headers: {
|
||||||
|
Authorization: `Bearer ${refreshTokenValue}`
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data
|
const { access_token: newAccessToken, refresh_token: newRefreshToken } = response.data
|
||||||
|
Loading…
Reference in New Issue
Block a user