add_version_to_lists_and_items

This commit is contained in:
mohamad 2025-05-07 23:30:23 +02:00
parent d2d484c327
commit 423d345fdf
23 changed files with 796 additions and 185 deletions

View File

@ -32,4 +32,4 @@ EXPOSE 8000
# The default command for production (can be overridden in docker-compose for development) # The default command for production (can be overridden in docker-compose for development)
# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance # Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct. # relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"] CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@ -0,0 +1,28 @@
"""add_version_to_lists_and_items
Revision ID: d53eedd151b7
Revises: d25788f63e2c
Create Date: 2025-05-07 21:05:50.396430
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd53eedd151b7'
down_revision: Union[str, None] = 'd25788f63e2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -11,13 +11,14 @@ from app.database import get_db
from app.core.security import verify_access_token from app.core.security import verify_access_token
from app.crud import user as crud_user from app.crud import user as crud_user
from app.models import User as UserModel # Import the SQLAlchemy model from app.models import User as UserModel # Import the SQLAlchemy model
from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define the OAuth2 scheme # Define the OAuth2 scheme
# tokenUrl should point to your login endpoint relative to the base path # tokenUrl should point to your login endpoint relative to the base path
# It's used by Swagger UI for the "Authorize" button flow. # It's used by Swagger UI for the "Authorize" button flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") # Corrected path oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL)
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
@ -36,8 +37,8 @@ async def get_current_user(
""" """
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail=settings.AUTH_CREDENTIALS_ERROR,
headers={"WWW-Authenticate": "Bearer"}, headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX},
) )
payload = verify_access_token(token) payload = verify_access_token(token)

View File

@ -9,6 +9,7 @@ from app.api.v1.endpoints import invites
from app.api.v1.endpoints import lists from app.api.v1.endpoints import lists
from app.api.v1.endpoints import items from app.api.v1.endpoints import items
from app.api.v1.endpoints import ocr from app.api.v1.endpoints import ocr
from app.api.v1.endpoints import costs
api_router_v1 = APIRouter() api_router_v1 = APIRouter()
@ -20,5 +21,6 @@ api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"]
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"]) api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
api_router_v1.include_router(items.router, tags=["Items"]) api_router_v1.include_router(items.router, tags=["Items"])
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"]) api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
api_router_v1.include_router(costs.router, tags=["Costs"])
# Add other v1 endpoint routers here later # Add other v1 endpoint routers here later
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) # e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])

View File

@ -1,6 +1,7 @@
# app/api/v1/endpoints/auth.py # app/api/v1/endpoints/auth.py
import logging import logging
from fastapi import APIRouter, Depends from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -8,12 +9,18 @@ from app.database import get_db
from app.schemas.user import UserCreate, UserPublic from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token from app.schemas.auth import Token
from app.crud import user as crud_user from app.crud import user as crud_user
from app.core.security import verify_password, create_access_token from app.core.security import (
verify_password,
create_access_token,
create_refresh_token,
verify_refresh_token
)
from app.core.exceptions import ( from app.core.exceptions import (
EmailAlreadyRegisteredError, EmailAlreadyRegisteredError,
InvalidCredentialsError, InvalidCredentialsError,
UserCreationError UserCreationError
) )
from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -55,28 +62,74 @@ async def signup(
"/login", "/login",
response_model=Token, response_model=Token,
summary="User Login", summary="User Login",
description="Authenticates a user and returns an access token.", description="Authenticates a user and returns an access and refresh token.",
tags=["Authentication"] tags=["Authentication"]
) )
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
Handles user login. Handles user login.
- Finds user by email (provided in 'username' field of form). - Finds user by email (provided in 'username' field of form).
- Verifies the provided password against the stored hash. - Verifies the provided password against the stored hash.
- Generates and returns a JWT access token upon successful authentication. - Generates and returns JWT access and refresh tokens upon successful authentication.
""" """
logger.info(f"Login attempt for user: {form_data.username}") logger.info(f"Login attempt for user: {form_data.username}")
user = await crud_user.get_user_by_email(db, email=form_data.username) user = await crud_user.get_user_by_email(db, email=form_data.username)
# Check if user exists and password is correct
if not user or not verify_password(form_data.password, user.password_hash): if not user or not verify_password(form_data.password, user.password_hash):
logger.warning(f"Login failed: Invalid credentials for user {form_data.username}") logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
raise InvalidCredentialsError() raise InvalidCredentialsError()
# Generate JWT
access_token = create_access_token(subject=user.email) access_token = create_access_token(subject=user.email)
logger.info(f"Login successful, token generated for user: {user.email}") refresh_token = create_refresh_token(subject=user.email)
return Token(access_token=access_token, token_type="bearer") logger.info(f"Login successful, tokens generated for user: {user.email}")
return Token(
access_token=access_token,
refresh_token=refresh_token,
token_type=settings.TOKEN_TYPE
)
@router.post(
"/refresh",
response_model=Token,
summary="Refresh Access Token",
description="Refreshes an access token using a refresh token.",
tags=["Authentication"]
)
async def refresh_token(
refresh_token_str: str,
db: AsyncSession = Depends(get_db)
):
"""
Handles access token refresh.
- Verifies the provided refresh token.
- If valid, generates and returns a new JWT access token and the same refresh token.
"""
logger.info("Access token refresh attempt")
payload = verify_refresh_token(refresh_token_str)
if not payload:
logger.warning("Refresh token invalid or expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
user_email = payload.get("sub")
if not user_email:
logger.error("User email not found in refresh token payload")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token payload",
headers={"WWW-Authenticate": "Bearer"},
)
new_access_token = create_access_token(subject=user_email)
logger.info(f"Access token refreshed for user: {user_email}")
return Token(
access_token=new_access_token,
refresh_token=refresh_token_str,
token_type=settings.TOKEN_TYPE
)

View File

@ -0,0 +1,69 @@
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel # For get_current_user dependency
from app.schemas.cost import ListCostSummary
from app.crud import cost as crud_cost
from app.crud import list as crud_list # For permission checking
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
"/lists/{list_id}/cost-summary",
response_model=ListCostSummary,
summary="Get Cost Summary for a List",
tags=["Costs"],
responses={
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
}
)
async def get_list_cost_summary(
list_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Retrieves a calculated cost summary for a specific list, detailing total costs,
equal shares per user, and individual user balances based on their contributions.
The user must have access to the list to view its cost summary.
Costs are split among group members if the list belongs to a group, or just for
the creator if it's a personal list. All users who added items with prices are
included in the calculation.
"""
logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")
# 1. Verify user has access to the target list
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
raise # Re-raise the original exception to be handled by FastAPI
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
raise # Re-raise
# 2. Calculate the cost summary
try:
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id)
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}")
return cost_summary
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}")
# This might be redundant if check_list_permission already confirmed list existence,
# but calculate_list_cost_summary also fetches the list.
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e:
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}")
# This indicates a data integrity issue (e.g., list creator or item adder missing)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")

View File

@ -1,8 +1,8 @@
# app/api/v1/endpoints/items.py # app/api/v1/endpoints/items.py
import logging import logging
from typing import List as PyList from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db from app.database import get_db
@ -14,7 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
from app.crud import item as crud_item from app.crud import item as crud_item
from app.crud import list as crud_list from app.crud import list as crud_list
from app.core.exceptions import ItemNotFoundError, ListPermissionError from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -100,7 +100,10 @@ async def read_list_items(
"/items/{item_id}", # Operate directly on item ID "/items/{item_id}", # Operate directly on item ID
response_model=ItemPublic, response_model=ItemPublic,
summary="Update Item", summary="Update Item",
tags=["Items"] tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
}
) )
async def update_item( async def update_item(
item_id: int, # Item ID from path item_id: int, # Item ID from path
@ -112,37 +115,61 @@ async def update_item(
""" """
Updates an item's details (name, quantity, is_complete, price). Updates an item's details (name, quantity, is_complete, price).
User must have access to the list the item belongs to. User must have access to the list the item belongs to.
The client MUST provide the current `version` of the item in the `item_in` payload.
If the version does not match, a 409 Conflict is returned.
Sets/unsets `completed_by_id` based on `is_complete` flag. Sets/unsets `completed_by_id` based on `is_complete` flag.
""" """
logger.info(f"User {current_user.email} attempting to update item ID: {item_id}") logger.info(f"User {current_user.email} attempting to update item ID: {item_id} with version {item_in.version}")
# Permission check is handled by get_item_and_verify_access dependency # Permission check is handled by get_item_and_verify_access dependency
updated_item = await crud_item.update_item( try:
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id updated_item = await crud_item.update_item(
) db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
logger.info(f"Item {item_id} updated successfully by user {current_user.email}.") )
return updated_item logger.info(f"Item {item_id} updated successfully by user {current_user.email} to version {updated_item.version}.")
return updated_item
except ConflictError as e:
logger.warning(f"Conflict updating item {item_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e:
logger.error(f"Error updating item {item_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")
@router.delete( @router.delete(
"/items/{item_id}", # Operate directly on item ID "/items/{item_id}", # Operate directly on item ID
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Item", summary="Delete Item",
tags=["Items"] tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
}
) )
async def delete_item( async def delete_item(
item_id: int, # Item ID from path item_id: int, # Item ID from path
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user), # Log who deleted it current_user: UserModel = Depends(get_current_user), # Log who deleted it
): ):
""" """
Deletes an item. User must have access to the list the item belongs to. Deletes an item. User must have access to the list the item belongs to.
(MVP: Any member with list access can delete items). If `expected_version` is provided and does not match the item's current version,
a 409 Conflict is returned.
""" """
logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}") logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
# Permission check is handled by get_item_and_verify_access dependency # Permission check is handled by get_item_and_verify_access dependency
if expected_version is not None and item_db.version != expected_version:
logger.warning(
f"Conflict deleting item {item_id} for user {current_user.email}. "
f"Expected version {expected_version}, actual version {item_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
)
await crud_item.delete_item(db=db, item_db=item_db) await crud_item.delete_item(db=db, item_db=item_db)
logger.info(f"Item {item_id} deleted successfully by user {current_user.email}.") logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {current_user.email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT) return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@ -1,8 +1,8 @@
# app/api/v1/endpoints/lists.py # app/api/v1/endpoints/lists.py
import logging import logging
from typing import List as PyList # Alias for Python List type hint from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db from app.database import get_db
@ -17,7 +17,8 @@ from app.core.exceptions import (
GroupMembershipError, GroupMembershipError,
ListNotFoundError, ListNotFoundError,
ListPermissionError, ListPermissionError,
ListStatusNotFoundError ListStatusNotFoundError,
ConflictError # Added ConflictError
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -101,7 +102,10 @@ async def read_list(
"/{list_id}", "/{list_id}",
response_model=ListPublic, # Return updated basic info response_model=ListPublic, # Return updated basic info
summary="Update List", summary="Update List",
tags=["Lists"] tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
}
) )
async def update_list( async def update_list(
list_id: int, list_id: int,
@ -112,40 +116,62 @@ async def update_list(
""" """
Updates a list's details (name, description, is_complete). Updates a list's details (name, description, is_complete).
Requires user to be the creator or a member of the list's group. Requires user to be the creator or a member of the list's group.
(MVP: Allows any member to update these fields). The client MUST provide the current `version` of the list in the `list_in` payload.
If the version does not match, a 409 Conflict is returned.
""" """
logger.info(f"User {current_user.email} attempting to update list ID: {list_id}") logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
# Prevent changing group_id or creator via this endpoint for simplicity try:
# if list_in.group_id is not None or list_in.created_by_id is not None: updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint") logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
return updated_list
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response
logger.info(f"List {list_id} updated successfully by user {current_user.email}.") logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
return updated_list raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e: # Catch other potential errors from crud operation
logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}")
# Consider a more generic error, but for now, let's keep it specific if possible
# Re-raising might be better if crud layer already raises appropriate HTTPExceptions
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.")
@router.delete( @router.delete(
"/{list_id}", "/{list_id}",
status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
summary="Delete List", summary="Delete List",
tags=["Lists"] tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
}
) )
async def delete_list( async def delete_list(
list_id: int, list_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user), current_user: UserModel = Depends(get_current_user),
): ):
""" """
Deletes a list. Requires user to be the creator of the list. Deletes a list. Requires user to be the creator of the list.
(Alternatively, could allow group owner). If `expected_version` is provided and does not match the list's current version,
a 409 Conflict is returned.
""" """
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}") logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
# Use the helper, requiring creator permission # Use the helper, requiring creator permission
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True) list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
if expected_version is not None and list_db.version != expected_version:
logger.warning(
f"Conflict deleting list {list_id} for user {current_user.email}. "
f"Expected version {expected_version}, actual version {list_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
)
await crud_list.delete_list(db=db, list_db=list_db) await crud_list.delete_list(db=db, list_db=list_db)
logger.info(f"List {list_id} deleted successfully by user {current_user.email}.") logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT) return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@ -1,28 +1,27 @@
# app/api/v1/endpoints/ocr.py
import logging import logging
from typing import List from typing import List
from fastapi import APIRouter, Depends, UploadFile, File from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
from google.api_core import exceptions as google_exceptions from google.api_core import exceptions as google_exceptions
from app.api.dependencies import get_current_user from app.api.dependencies import get_current_user
from app.models import User as UserModel from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
from app.core.exceptions import ( from app.core.exceptions import (
OcrServiceUnavailableError, OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError, InvalidFileTypeError,
FileTooLargeError, FileTooLargeError,
OcrProcessingError, OCRProcessingError
OcrQuotaExceededError
) )
from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
ocr_service = GeminiOCRService()
# Allowed image MIME types
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
MAX_FILE_SIZE_MB = 10
@router.post( @router.post(
"/extract-items", "/extract-items",
@ -41,20 +40,20 @@ async def ocr_extract_items(
# Check if Gemini client initialized correctly # Check if Gemini client initialized correctly
if gemini_initialization_error: if gemini_initialization_error:
logger.error("OCR endpoint called but Gemini client failed to initialize.") logger.error("OCR endpoint called but Gemini client failed to initialize.")
raise OcrServiceUnavailableError(gemini_initialization_error) raise OCRServiceUnavailableError(gemini_initialization_error)
logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.") logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
# --- File Validation --- # --- File Validation ---
if image_file.content_type not in ALLOWED_IMAGE_TYPES: if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}") logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
raise InvalidFileTypeError(ALLOWED_IMAGE_TYPES) raise InvalidFileTypeError()
# Simple size check # Simple size check
contents = await image_file.read() contents = await image_file.read()
if len(contents) > MAX_FILE_SIZE_MB * 1024 * 1024: if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes") logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
raise FileTooLargeError(MAX_FILE_SIZE_MB) raise FileTooLargeError()
try: try:
# Call the Gemini helper function # Call the Gemini helper function
@ -66,30 +65,14 @@ async def ocr_extract_items(
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items) return OcrExtractResponse(extracted_items=extracted_items)
except ValueError as e: except OCRServiceUnavailableError:
# Handle errors from Gemini processing (blocked, empty response, etc.) raise OCRServiceUnavailableError()
logger.warning(f"Gemini processing error for user {current_user.email}: {e}") except OCRServiceConfigError:
raise OcrProcessingError(str(e)) raise OCRServiceConfigError()
except OCRQuotaExceededError:
except google_exceptions.ResourceExhausted as e: raise OCRQuotaExceededError()
# Specific handling for quota errors
logger.error(f"Gemini Quota Exceeded for user {current_user.email}: {e}", exc_info=True)
raise OcrQuotaExceededError()
except google_exceptions.GoogleAPIError as e:
# Handle other Google API errors (e.g., invalid key, permissions)
logger.error(f"Gemini API Error for user {current_user.email}: {e}", exc_info=True)
raise OcrServiceUnavailableError(str(e))
except RuntimeError as e:
# Catch initialization errors from get_gemini_client()
logger.error(f"Gemini client runtime error during OCR request: {e}")
raise OcrServiceUnavailableError(f"OCR service configuration error: {e}")
except Exception as e: except Exception as e:
# Catch any other unexpected errors raise OCRProcessingError(str(e))
logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}")
raise OcrServiceUnavailableError("An unexpected error occurred during item extraction.")
finally: finally:
# Ensure file handle is closed # Ensure file handle is closed

View File

@ -16,6 +16,86 @@ class Settings(BaseSettings):
SECRET_KEY: str # Must be set via environment variable SECRET_KEY: str # Must be set via environment variable
ALGORITHM: str = "HS256" ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days
# --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
OCR_ITEM_EXTRACTION_PROMPT: str = """
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
Add 2 underscores before and after the item name, if it is struck through.
If the image does not appear to be a shopping list or receipt, state that clearly.
Example output for a grocery list:
Milk
Eggs
Bread
__Apples__
Organic Bananas
"""
# --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
GEMINI_SAFETY_SETTINGS: dict = {
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
}
GEMINI_GENERATION_CONFIG: dict = {
"candidate_count": 1,
"max_output_tokens": 2048,
"temperature": 0.9,
"top_p": 1,
"top_k": 1
}
# --- API Settings ---
API_PREFIX: str = "/api" # Base path for all API endpoints
API_OPENAPI_URL: str = "/api/openapi.json"
API_DOCS_URL: str = "/api/docs"
API_REDOC_URL: str = "/api/redoc"
CORS_ORIGINS: list[str] = [
"http://localhost:5174",
"http://localhost:8000",
# Add your deployed frontend URL here later
# "https://your-frontend-domain.com",
]
# --- API Metadata ---
API_TITLE: str = "Shared Lists API"
API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
API_VERSION: str = "0.1.0"
ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
# --- Logging Settings ---
LOG_LEVEL: str = "INFO"
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# --- Health Check Settings ---
HEALTH_STATUS_OK: str = "ok"
HEALTH_STATUS_ERROR: str = "error"
# --- Auth Settings ---
OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint
TOKEN_TYPE: str = "bearer" # Default token type for OAuth2
AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header
AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header
AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials"
AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
# --- HTTP Status Messages ---
HTTP_400_DETAIL: str = "Bad Request"
HTTP_401_DETAIL: str = "Unauthorized"
HTTP_403_DETAIL: str = "Forbidden"
HTTP_404_DETAIL: str = "Not Found"
HTTP_422_DETAIL: str = "Unprocessable Entity"
HTTP_429_DETAIL: str = "Too Many Requests"
HTTP_500_DETAIL: str = "Internal Server Error"
HTTP_503_DETAIL: str = "Service Unavailable"
class Config: class Config:
env_file = ".env" env_file = ".env"

View File

@ -1,4 +1,7 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from typing import Optional
class ListNotFoundError(HTTPException): class ListNotFoundError(HTTPException):
"""Raised when a list is not found.""" """Raised when a list is not found."""
@ -72,76 +75,105 @@ class ItemNotFoundError(HTTPException):
detail=f"Item {item_id} not found" detail=f"Item {item_id} not found"
) )
class UserNotFoundError(HTTPException):
"""Raised when a user is not found."""
def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None):
detail_msg = "User not found."
if user_id:
detail_msg = f"User with ID {user_id} not found."
elif identifier:
detail_msg = f"User with identifier '{identifier}' not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail_msg
)
class DatabaseConnectionError(HTTPException): class DatabaseConnectionError(HTTPException):
"""Raised when there is an error connecting to the database.""" """Raised when there is an error connecting to the database."""
def __init__(self, detail: str = "Database connection error"): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=detail detail=settings.DB_CONNECTION_ERROR
) )
class DatabaseIntegrityError(HTTPException): class DatabaseIntegrityError(HTTPException):
"""Raised when a database integrity constraint is violated.""" """Raised when a database integrity constraint is violated."""
def __init__(self, detail: str = "Database integrity error"): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=detail detail=settings.DB_INTEGRITY_ERROR
) )
class DatabaseTransactionError(HTTPException): class DatabaseTransactionError(HTTPException):
"""Raised when a database transaction fails.""" """Raised when a database transaction fails."""
def __init__(self, detail: str = "Database transaction error"): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail detail=settings.DB_TRANSACTION_ERROR
) )
class DatabaseQueryError(HTTPException): class DatabaseQueryError(HTTPException):
"""Raised when a database query fails.""" """Raised when a database query fails."""
def __init__(self, detail: str = "Database query error"): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail detail=settings.DB_QUERY_ERROR
) )
class OcrServiceUnavailableError(HTTPException): class OCRServiceUnavailableError(HTTPException):
"""Raised when the OCR service is unavailable.""" """Raised when the OCR service is unavailable."""
def __init__(self, detail: str): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"OCR service unavailable: {detail}" detail=settings.OCR_SERVICE_UNAVAILABLE
) )
class InvalidFileTypeError(HTTPException): class OCRServiceConfigError(HTTPException):
"""Raised when an invalid file type is uploaded for OCR.""" """Raised when there is an error in the OCR service configuration."""
def __init__(self, allowed_types: list[str]): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Invalid file type. Allowed types: {', '.join(allowed_types)}" detail=settings.OCR_SERVICE_CONFIG_ERROR
) )
class FileTooLargeError(HTTPException): class OCRUnexpectedError(HTTPException):
"""Raised when an uploaded file exceeds the size limit.""" """Raised when there is an unexpected error in the OCR service."""
def __init__(self, max_size_mb: int): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File size exceeds limit of {max_size_mb} MB." detail=settings.OCR_UNEXPECTED_ERROR
) )
class OcrProcessingError(HTTPException): class OCRQuotaExceededError(HTTPException):
"""Raised when there is an error processing the image with OCR."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Could not extract items from image: {detail}"
)
class OcrQuotaExceededError(HTTPException):
"""Raised when the OCR service quota is exceeded.""" """Raised when the OCR service quota is exceeded."""
def __init__(self): def __init__(self):
super().__init__( super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="OCR service quota exceeded. Please try again later." detail=settings.OCR_QUOTA_EXCEEDED
)
class InvalidFileTypeError(HTTPException):
"""Raised when an invalid file type is uploaded for OCR."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
)
class FileTooLargeError(HTTPException):
"""Raised when an uploaded file exceeds the size limit."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
)
class OCRProcessingError(HTTPException):
"""Raised when there is an error processing the image with OCR."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_PROCESSING_ERROR.format(detail=detail)
) )
class EmailAlreadyRegisteredError(HTTPException): class EmailAlreadyRegisteredError(HTTPException):
@ -152,15 +184,6 @@ class EmailAlreadyRegisteredError(HTTPException):
detail="Email already registered." detail="Email already registered."
) )
class InvalidCredentialsError(HTTPException):
"""Raised when login credentials are invalid."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"}
)
class UserCreationError(HTTPException): class UserCreationError(HTTPException):
"""Raised when there is an error creating a new user.""" """Raised when there is an error creating a new user."""
def __init__(self): def __init__(self):
@ -208,3 +231,47 @@ class ListStatusNotFoundError(HTTPException):
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Status for list {list_id} not found" detail=f"Status for list {list_id} not found"
) )
class ConflictError(HTTPException):
"""Raised when an optimistic lock version conflict occurs."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail
)
class InvalidCredentialsError(HTTPException):
"""Raised when login credentials are invalid."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_INVALID_CREDENTIALS,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""}
)
class NotAuthenticatedError(HTTPException):
"""Raised when the user is not authenticated."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_NOT_AUTHENTICATED,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""}
)
class JWTError(HTTPException):
"""Raised when there is an error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
class JWTUnexpectedError(HTTPException):
"""Raised when there is an unexpected error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)

View File

@ -4,8 +4,13 @@ from typing import List
import google.generativeai as genai import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
from google.api_core import exceptions as google_exceptions from google.api_core import exceptions as google_exceptions
from app.config import settings from app.config import settings
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,26 +24,18 @@ try:
genai.configure(api_key=settings.GEMINI_API_KEY) genai.configure(api_key=settings.GEMINI_API_KEY)
# Initialize the specific model we want to use # Initialize the specific model we want to use
gemini_flash_client = genai.GenerativeModel( gemini_flash_client = genai.GenerativeModel(
model_name="gemini-2.0-flash", model_name=settings.GEMINI_MODEL_NAME,
# Optional: Add default safety settings # Safety settings from config
# Adjust these based on your expected content and risk tolerance
safety_settings={ safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}, },
# Optional: Add default generation config (can be overridden per request) # Generation config from settings
# generation_config=genai.types.GenerationConfig( generation_config=genai.types.GenerationConfig(
# # candidate_count=1, # Usually default is 1 **settings.GEMINI_GENERATION_CONFIG
# # stop_sequences=["\n"], )
# # max_output_tokens=2048,
# # temperature=0.9, # Controls randomness (0=deterministic, >1=more random)
# # top_p=1,
# # top_k=1
# )
) )
logger.info("Gemini AI client initialized successfully for model 'gemini-1.5-flash-latest'.") logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
else: else:
# Store error if API key is missing # Store error if API key is missing
gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized." gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
@ -105,7 +102,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
# Prepare the full prompt content # Prepare the full prompt content
prompt_parts = [ prompt_parts = [
OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
image_part # Then the image image_part # Then the image
] ]
@ -153,3 +150,46 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a generic ValueError or re-raise # Wrap in a generic ValueError or re-raise
raise ValueError(f"Failed to process image with Gemini: {e}") from e raise ValueError(f"Failed to process image with Gemini: {e}") from e
class GeminiOCRService:
def __init__(self):
try:
genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes) -> List[str]:
"""
Extract shopping list items from an image using Gemini Vision.
"""
try:
# Create image part
image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
# Generate content
response = await self.model.generate_content_async(
contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
)
# Process response
if not response.text:
raise OCRUnexpectedError()
# Split response into lines and clean up
items = [
item.strip()
for item in response.text.split("\n")
if item.strip() and not item.strip().startswith("Example")
]
return items
except Exception as e:
logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()

View File

@ -66,7 +66,34 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede
) )
# Data to encode in the token payload # Data to encode in the token payload
to_encode = {"exp": expire, "sub": str(subject)} to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Creates a JWT refresh token.
Args:
subject: The subject of the token (e.g., user ID or email).
expires_delta: Optional timedelta object for token expiry. If None,
uses REFRESH_TOKEN_EXPIRE_MINUTES from settings.
Returns:
The encoded JWT refresh token string.
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
# Data to encode in the token payload
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
@ -91,6 +118,8 @@ def verify_access_token(token: str) -> Optional[dict]:
payload = jwt.decode( payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
) )
if payload.get("type") != "access":
raise JWTError("Invalid token type")
return payload return payload
except JWTError as e: except JWTError as e:
# Handles InvalidSignatureError, ExpiredSignatureError, etc. # Handles InvalidSignatureError, ExpiredSignatureError, etc.
@ -101,6 +130,31 @@ def verify_access_token(token: str) -> Optional[dict]:
print(f"Unexpected error decoding JWT: {e}") print(f"Unexpected error decoding JWT: {e}")
return None return None
def verify_refresh_token(token: str) -> Optional[dict]:
"""
Verifies a JWT refresh token and returns its payload if valid.
Args:
token: The JWT token string to verify.
Returns:
The decoded token payload (dict) if the token is valid, not expired,
and is a refresh token, otherwise None.
"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
if payload.get("type") != "refresh":
raise JWTError("Invalid token type")
return payload
except JWTError as e:
print(f"JWT Error: {e}") # Log the error for debugging
return None
except Exception as e:
print(f"Unexpected error decoding JWT: {e}")
return None
# You might add a function here later to extract the 'sub' (subject/user id) # You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication. # specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]: # def get_subject_from_token(token: str) -> Optional[str]:

116
be/app/crud/cost.py Normal file
View File

@ -0,0 +1,116 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict, Set
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
from app.schemas.cost import ListCostSummary, UserCostShare
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
"""
Calculates the cost summary for a given list, splitting costs equally
among relevant users (group members if list is in a group, or creator if personal).
"""
# 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members)
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user)))
)
.where(ListModel.id == list_id)
)
db_list: Optional[ListModel] = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
# 2. Determine participating users
participating_users: Dict[int, UserModel] = {}
if db_list.group:
# If list is part of a group, all group members participate
for ug_assoc in db_list.group.user_associations:
if ug_assoc.user: # Ensure user object is loaded
participating_users[ug_assoc.user.id] = ug_assoc.user
else:
# If personal list, only the creator participates (or if items were added by others somehow, include them)
# For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit.
# Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator)
creator_user = await db.get(UserModel, db_list.created_by_id)
if not creator_user:
# This case should ideally not happen if data integrity is maintained
raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error
participating_users[creator_user.id] = creator_user
# Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness)
for item in db_list.items:
if item.added_by_user and item.added_by_user.id not in participating_users:
participating_users[item.added_by_user.id] = item.added_by_user
num_participating_users = len(participating_users)
if num_participating_users == 0:
# Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent)
# Or if list has no items and is personal, creator might not be in participating_users if logic changes.
# For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary.
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# 3. Calculate total cost and items_added_value for each user
total_list_cost = Decimal("0.00")
user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()}
for item in db_list.items:
if item.price is not None and item.price > Decimal("0"):
total_list_cost += item.price
if item.added_by_id in user_items_added_value: # Item adder must be in participating users
user_items_added_value[item.added_by_id] += item.price
# If item.added_by_id is not in participating_users (e.g. user left group),
# their contribution still counts to total cost, but they aren't part of the split.
# The current logic adds item adders to participating_users, so this else is less likely.
# 4. Calculate equal share per user
# Using ROUND_HALF_UP to handle cents appropriately.
# Ensure division by zero is handled if num_participating_users could be 0 (already handled above)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00")
# 5. For each user, calculate their balance
user_balances: PyList[UserCostShare] = []
for user_id, user_obj in participating_users.items():
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
balance = items_added - equal_share_per_user
user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email
user_balances.append(
UserCostShare(
user_id=user_id,
user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=equal_share_per_user, # Already quantized
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
# Sort user_balances for consistent output, e.g., by user_id or identifier
user_balances.sort(key=lambda x: x.user_identifier)
# 6. Return the populated ListCostSummary schema
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user, # Already quantized
user_balances=user_balances
)

View File

@ -13,7 +13,8 @@ from app.core.exceptions import (
DatabaseConnectionError, DatabaseConnectionError,
DatabaseIntegrityError, DatabaseIntegrityError,
DatabaseQueryError, DatabaseQueryError,
DatabaseTransactionError DatabaseTransactionError,
ConflictError
) )
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
@ -26,6 +27,7 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
list_id=list_id, list_id=list_id,
added_by_id=user_id, added_by_id=user_id,
is_complete=False # Default on creation is_complete=False # Default on creation
# version is implicitly set to 1 by model default
) )
db.add(db_item) db.add(db_item)
await db.flush() await db.flush()
@ -65,44 +67,57 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
raise DatabaseQueryError(f"Failed to query item: {str(e)}") raise DatabaseQueryError(f"Failed to query item: {str(e)}")
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
"""Updates an existing item record.""" """Updates an existing item record, checking for version conflicts."""
try: try:
async with db.begin(): async with db.begin():
update_data = item_in.model_dump(exclude_unset=True) # Get only provided fields if item_db.version != item_in.version:
raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
# Special handling for is_complete # Special handling for is_complete
if 'is_complete' in update_data: if 'is_complete' in update_data:
if update_data['is_complete'] is True: if update_data['is_complete'] is True:
# Mark as complete: set completed_by_id if not already set if item_db.completed_by_id is None: # Only set if not already completed by someone
if item_db.completed_by_id is None:
update_data['completed_by_id'] = user_id update_data['completed_by_id'] = user_id
else: else:
# Mark as incomplete: clear completed_by_id update_data['completed_by_id'] = None # Clear if marked incomplete
update_data['completed_by_id'] = None
# Ensure updated_at is refreshed (handled by onupdate in model, but explicit is fine too)
# update_data['updated_at'] = datetime.now(timezone.utc)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(item_db, key, value) setattr(item_db, key, value)
db.add(item_db) # Add to session to track changes item_db.version += 1 # Increment version
db.add(item_db)
await db.flush() await db.flush()
await db.refresh(item_db) await db.refresh(item_db)
return item_db return item_db
except IntegrityError as e: except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to update item: {str(e)}") await db.rollback()
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError
await db.rollback()
raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to update item: {str(e)}") raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
"""Deletes an item record.""" """Deletes an item record. Version check should be done by the caller (API endpoint)."""
try: try:
async with db.begin(): async with db.begin():
await db.delete(item_db) await db.delete(item_db)
return None # Or return True/False # await db.commit() # Not needed with async with db.begin()
return None
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")

View File

@ -16,7 +16,8 @@ from app.core.exceptions import (
DatabaseConnectionError, DatabaseConnectionError,
DatabaseIntegrityError, DatabaseIntegrityError,
DatabaseQueryError, DatabaseQueryError,
DatabaseTransactionError DatabaseTransactionError,
ConflictError
) )
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
@ -85,32 +86,50 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
raise DatabaseQueryError(f"Failed to query list: {str(e)}") raise DatabaseQueryError(f"Failed to query list: {str(e)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
"""Updates an existing list record.""" """Updates an existing list record, checking for version conflicts."""
try: try:
async with db.begin(): async with db.begin():
update_data = list_in.model_dump(exclude_unset=True) if list_db.version != list_in.version:
raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
)
update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
for key, value in update_data.items(): for key, value in update_data.items():
setattr(list_db, key, value) setattr(list_db, key, value)
list_db.version += 1
db.add(list_db) db.add(list_db)
await db.flush() await db.flush()
await db.refresh(list_db) await db.refresh(list_db)
return list_db return list_db
except IntegrityError as e: except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to update list: {str(e)}") await db.rollback()
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError:
await db.rollback()
raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to update list: {str(e)}") raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None: async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
"""Deletes a list record.""" """Deletes a list record. Version check should be done by the caller (API endpoint)."""
try: try:
async with db.begin(): async with db.begin():
await db.delete(list_db) await db.delete(list_db)
return None return None
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:

View File

@ -5,22 +5,25 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.api_router import api_router # Import the main combined router from app.api.api_router import api_router # Import the main combined router
from app.config import settings
# Import database and models if needed for startup/shutdown events later # Import database and models if needed for startup/shutdown events later
# from . import database, models # from . import database, models
# --- Logging Setup --- # --- Logging Setup ---
# Configure logging (can be more sophisticated later, e.g., using logging.yaml) logging.basicConfig(
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') level=getattr(logging, settings.LOG_LEVEL),
format=settings.LOG_FORMAT
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- FastAPI App Instance --- # --- FastAPI App Instance ---
app = FastAPI( app = FastAPI(
title="Shared Lists API", title=settings.API_TITLE,
description="API for managing shared shopping lists, OCR, and cost splitting.", description=settings.API_DESCRIPTION,
version="0.1.0", version=settings.API_VERSION,
openapi_url="/api/openapi.json", # Place OpenAPI spec under /api openapi_url=settings.API_OPENAPI_URL,
docs_url="/api/docs", # Place Swagger UI under /api docs_url=settings.API_DOCS_URL,
redoc_url="/api/redoc" # Place ReDoc under /api redoc_url=settings.API_REDOC_URL
) )
# --- CORS Middleware --- # --- CORS Middleware ---
@ -37,17 +40,17 @@ origins = [
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, # List of origins that are allowed to make requests allow_origins=settings.CORS_ORIGINS,
allow_credentials=True, # Allow cookies to be included in requests allow_credentials=True,
allow_methods=["*"], # Allow all methods (GET, POST, PUT, DELETE, etc.) allow_methods=["*"],
allow_headers=["*"], # Allow all headers allow_headers=["*"],
) )
# --- End CORS Middleware --- # --- End CORS Middleware ---
# --- Include API Routers --- # --- Include API Routers ---
# All API endpoints will be prefixed with /api # All API endpoints will be prefixed with /api
app.include_router(api_router, prefix="/api") app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers --- # --- End Include API Routers ---
@ -59,10 +62,7 @@ async def read_root():
Useful for basic reachability checks. Useful for basic reachability checks.
""" """
logger.info("Root endpoint '/' accessed.") logger.info("Root endpoint '/' accessed.")
# You could redirect to the docs or return a simple message return {"message": settings.ROOT_MESSAGE}
# from fastapi.responses import RedirectResponse
# return RedirectResponse(url="/api/docs")
return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"}
# --- End Root Endpoint --- # --- End Root Endpoint ---

View File

@ -117,6 +117,7 @@ class List(Base):
is_complete = Column(Boolean, default=False, nullable=False) is_complete = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships --- # --- Relationships ---
creator = relationship("User", back_populates="created_lists") # Link to User.created_lists creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
@ -138,6 +139,7 @@ class Item(Base):
completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships --- # --- Relationships ---
list = relationship("List", back_populates="items") # Link to List.items list = relationship("List", back_populates="items") # Link to List.items

View File

@ -1,9 +1,11 @@
# app/schemas/auth.py # app/schemas/auth.py
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from app.config import settings
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str = "bearer" # Default token type refresh_token: str # Added refresh token
token_type: str = settings.TOKEN_TYPE # Use configured token type
# Optional: If you preferred not to use OAuth2PasswordRequestForm # Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel): # class UserLogin(BaseModel):

22
be/app/schemas/cost.py Normal file
View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal
class UserCostShare(BaseModel):
user_id: int
user_identifier: str # Name or email
items_added_value: Decimal = Decimal("0.00") # Total value of items this user added
amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users)
balance: Decimal # items_added_value - amount_due
model_config = ConfigDict(from_attributes=True)
class ListCostSummary(BaseModel):
list_id: int
list_name: str
total_list_cost: Decimal
num_participating_users: int
equal_share_per_user: Decimal
user_balances: List[UserCostShare]
model_config = ConfigDict(from_attributes=True)

View File

@ -1,9 +1,10 @@
# app/schemas/health.py # app/schemas/health.py
from pydantic import BaseModel from pydantic import BaseModel
from app.config import settings
class HealthStatus(BaseModel): class HealthStatus(BaseModel):
""" """
Response model for the health check endpoint. Response model for the health check endpoint.
""" """
status: str = "ok" # Provide a default value status: str = settings.HEALTH_STATUS_OK # Use configured default value
database: str database: str

View File

@ -16,6 +16,7 @@ class ItemPublic(BaseModel):
completed_by_id: Optional[int] = None completed_by_id: Optional[int] = None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
version: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Properties to receive via API on creation # Properties to receive via API on creation
@ -31,4 +32,5 @@ class ItemUpdate(BaseModel):
quantity: Optional[str] = None quantity: Optional[str] = None
is_complete: Optional[bool] = None is_complete: Optional[bool] = None
price: Optional[Decimal] = None # Price added here for update price: Optional[Decimal] = None # Price added here for update
version: int
# completed_by_id will be set internally if is_complete is true # completed_by_id will be set internally if is_complete is true

View File

@ -16,6 +16,7 @@ class ListUpdate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
is_complete: Optional[bool] = None is_complete: Optional[bool] = None
version: int # Client must provide the version for updates
# Potentially add group_id update later if needed # Potentially add group_id update later if needed
# Base properties returned by API (common fields) # Base properties returned by API (common fields)
@ -28,6 +29,7 @@ class ListBase(BaseModel):
is_complete: bool is_complete: bool
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
version: int # Include version in responses
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)