diff --git a/be/Dockerfile b/be/Dockerfile index a6e5d5d..a2f5925 100644 --- a/be/Dockerfile +++ b/be/Dockerfile @@ -32,4 +32,4 @@ EXPOSE 8000 # The default command for production (can be overridden in docker-compose for development) # Note: Make sure 'app.main:app' correctly points to your FastAPI app instance # relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct. -CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"] \ No newline at end of file +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/be/alembic/versions/d53eedd151b7_add_version_to_lists_and_items.py b/be/alembic/versions/d53eedd151b7_add_version_to_lists_and_items.py new file mode 100644 index 0000000..dca040f --- /dev/null +++ b/be/alembic/versions/d53eedd151b7_add_version_to_lists_and_items.py @@ -0,0 +1,28 @@ +"""add_version_to_lists_and_items + +Revision ID: d53eedd151b7 +Revises: d25788f63e2c +Create Date: 2025-05-07 21:05:50.396430 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd53eedd151b7' +down_revision: Union[str, None] = 'd25788f63e2c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + pass + + +def downgrade() -> None: + """Downgrade schema.""" + pass diff --git a/be/app/api/dependencies.py b/be/app/api/dependencies.py index 98d24e2..b0ad596 100644 --- a/be/app/api/dependencies.py +++ b/be/app/api/dependencies.py @@ -11,13 +11,14 @@ from app.database import get_db from app.core.security import verify_access_token from app.crud import user as crud_user from app.models import User as UserModel # Import the SQLAlchemy model +from app.config import settings logger = logging.getLogger(__name__) # Define the OAuth2 scheme # tokenUrl should point to your login endpoint relative to the base path # It's used by Swagger UI for the "Authorize" button flow. -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") # Corrected path +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL) async def get_current_user( token: str = Depends(oauth2_scheme), @@ -36,8 +37,8 @@ async def get_current_user( """ credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, + detail=settings.AUTH_CREDENTIALS_ERROR, + headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX}, ) payload = verify_access_token(token) diff --git a/be/app/api/v1/api.py b/be/app/api/v1/api.py index 640c569..feb43c1 100644 --- a/be/app/api/v1/api.py +++ b/be/app/api/v1/api.py @@ -9,6 +9,7 @@ from app.api.v1.endpoints import invites from app.api.v1.endpoints import lists from app.api.v1.endpoints import items from app.api.v1.endpoints import ocr +from app.api.v1.endpoints import costs api_router_v1 = APIRouter() @@ -20,5 +21,6 @@ api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"] api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"]) api_router_v1.include_router(items.router, tags=["Items"]) api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"]) +api_router_v1.include_router(costs.router, tags=["Costs"]) # Add other v1 endpoint routers here later # e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) \ No newline at end of file diff --git a/be/app/api/v1/endpoints/auth.py b/be/app/api/v1/endpoints/auth.py index fcc4532..6871d65 100644 --- a/be/app/api/v1/endpoints/auth.py +++ b/be/app/api/v1/endpoints/auth.py @@ -1,6 +1,7 @@ # app/api/v1/endpoints/auth.py import logging -from fastapi import APIRouter, Depends +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.ext.asyncio import AsyncSession @@ -8,12 +9,18 @@ from app.database import get_db from app.schemas.user import UserCreate, UserPublic from app.schemas.auth import Token from app.crud import user as crud_user -from app.core.security import verify_password, create_access_token +from app.core.security import ( + verify_password, + create_access_token, + create_refresh_token, + verify_refresh_token +) from app.core.exceptions import ( EmailAlreadyRegisteredError, InvalidCredentialsError, UserCreationError ) +from app.config import settings logger = logging.getLogger(__name__) router = APIRouter() @@ -55,28 +62,74 @@ async def signup( "/login", response_model=Token, summary="User Login", - description="Authenticates a user and returns an access token.", + description="Authenticates a user and returns an access and refresh token.", tags=["Authentication"] ) async def login( - form_data: OAuth2PasswordRequestForm = Depends(), + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: AsyncSession = Depends(get_db) ): """ Handles user login. - Finds user by email (provided in 'username' field of form). - Verifies the provided password against the stored hash. - - Generates and returns a JWT access token upon successful authentication. + - Generates and returns JWT access and refresh tokens upon successful authentication. """ logger.info(f"Login attempt for user: {form_data.username}") user = await crud_user.get_user_by_email(db, email=form_data.username) - # Check if user exists and password is correct if not user or not verify_password(form_data.password, user.password_hash): logger.warning(f"Login failed: Invalid credentials for user {form_data.username}") raise InvalidCredentialsError() - # Generate JWT access_token = create_access_token(subject=user.email) - logger.info(f"Login successful, token generated for user: {user.email}") - return Token(access_token=access_token, token_type="bearer") \ No newline at end of file + refresh_token = create_refresh_token(subject=user.email) + logger.info(f"Login successful, tokens generated for user: {user.email}") + return Token( + access_token=access_token, + refresh_token=refresh_token, + token_type=settings.TOKEN_TYPE + ) + +@router.post( + "/refresh", + response_model=Token, + summary="Refresh Access Token", + description="Refreshes an access token using a refresh token.", + tags=["Authentication"] +) +async def refresh_token( + refresh_token_str: str, + db: AsyncSession = Depends(get_db) +): + """ + Handles access token refresh. + - Verifies the provided refresh token. + - If valid, generates and returns a new JWT access token and the same refresh token. + """ + logger.info("Access token refresh attempt") + payload = verify_refresh_token(refresh_token_str) + if not payload: + logger.warning("Refresh token invalid or expired") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired refresh token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + user_email = payload.get("sub") + if not user_email: + logger.error("User email not found in refresh token payload") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token payload", + headers={"WWW-Authenticate": "Bearer"}, + ) + + new_access_token = create_access_token(subject=user_email) + logger.info(f"Access token refreshed for user: {user_email}") + return Token( + access_token=new_access_token, + refresh_token=refresh_token_str, + token_type=settings.TOKEN_TYPE + ) \ No newline at end of file diff --git a/be/app/api/v1/endpoints/costs.py b/be/app/api/v1/endpoints/costs.py new file mode 100644 index 0000000..9f8035e --- /dev/null +++ b/be/app/api/v1/endpoints/costs.py @@ -0,0 +1,69 @@ +# app/api/v1/endpoints/costs.py +import logging +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.api.dependencies import get_current_user +from app.models import User as UserModel # For get_current_user dependency +from app.schemas.cost import ListCostSummary +from app.crud import cost as crud_cost +from app.crud import list as crud_list # For permission checking +from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError + +logger = logging.getLogger(__name__) +router = APIRouter() + +@router.get( + "/lists/{list_id}/cost-summary", + response_model=ListCostSummary, + summary="Get Cost Summary for a List", + tags=["Costs"], + responses={ + status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"}, + status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"} + } +) +async def get_list_cost_summary( + list_id: int, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user), +): + """ + Retrieves a calculated cost summary for a specific list, detailing total costs, + equal shares per user, and individual user balances based on their contributions. + + The user must have access to the list to view its cost summary. + Costs are split among group members if the list belongs to a group, or just for + the creator if it's a personal list. All users who added items with prices are + included in the calculation. + """ + logger.info(f"User {current_user.email} requesting cost summary for list {list_id}") + + # 1. Verify user has access to the target list + try: + await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) + except ListPermissionError as e: + logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}") + raise # Re-raise the original exception to be handled by FastAPI + except ListNotFoundError as e: + logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}") + raise # Re-raise + + # 2. Calculate the cost summary + try: + cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id) + logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}") + return cost_summary + except ListNotFoundError as e: + logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}") + # This might be redundant if check_list_permission already confirmed list existence, + # but calculate_list_cost_summary also fetches the list. + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + except UserNotFoundError as e: + logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}") + # This indicates a data integrity issue (e.g., list creator or item adder missing) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + except Exception as e: + logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.") \ No newline at end of file diff --git a/be/app/api/v1/endpoints/items.py b/be/app/api/v1/endpoints/items.py index 2188ddd..f44554d 100644 --- a/be/app/api/v1/endpoints/items.py +++ b/be/app/api/v1/endpoints/items.py @@ -1,8 +1,8 @@ # app/api/v1/endpoints/items.py import logging -from typing import List as PyList +from typing import List as PyList, Optional -from fastapi import APIRouter, Depends, HTTPException, status, Response +from fastapi import APIRouter, Depends, HTTPException, status, Response, Query from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db @@ -14,7 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic from app.crud import item as crud_item from app.crud import list as crud_list -from app.core.exceptions import ItemNotFoundError, ListPermissionError +from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError logger = logging.getLogger(__name__) router = APIRouter() @@ -100,7 +100,10 @@ async def read_list_items( "/items/{item_id}", # Operate directly on item ID response_model=ItemPublic, summary="Update Item", - tags=["Items"] + tags=["Items"], + responses={ + status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"} + } ) async def update_item( item_id: int, # Item ID from path @@ -112,37 +115,61 @@ async def update_item( """ Updates an item's details (name, quantity, is_complete, price). User must have access to the list the item belongs to. + The client MUST provide the current `version` of the item in the `item_in` payload. + If the version does not match, a 409 Conflict is returned. Sets/unsets `completed_by_id` based on `is_complete` flag. """ - logger.info(f"User {current_user.email} attempting to update item ID: {item_id}") + logger.info(f"User {current_user.email} attempting to update item ID: {item_id} with version {item_in.version}") # Permission check is handled by get_item_and_verify_access dependency - updated_item = await crud_item.update_item( - db=db, item_db=item_db, item_in=item_in, user_id=current_user.id - ) - logger.info(f"Item {item_id} updated successfully by user {current_user.email}.") - return updated_item + try: + updated_item = await crud_item.update_item( + db=db, item_db=item_db, item_in=item_in, user_id=current_user.id + ) + logger.info(f"Item {item_id} updated successfully by user {current_user.email} to version {updated_item.version}.") + return updated_item + except ConflictError as e: + logger.warning(f"Conflict updating item {item_id} for user {current_user.email}: {str(e)}") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + except Exception as e: + logger.error(f"Error updating item {item_id} for user {current_user.email}: {str(e)}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.") @router.delete( "/items/{item_id}", # Operate directly on item ID status_code=status.HTTP_204_NO_CONTENT, summary="Delete Item", - tags=["Items"] + tags=["Items"], + responses={ + status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"} + } ) async def delete_item( item_id: int, # Item ID from path + expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."), item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), # Log who deleted it ): """ Deletes an item. User must have access to the list the item belongs to. - (MVP: Any member with list access can delete items). + If `expected_version` is provided and does not match the item's current version, + a 409 Conflict is returned. """ - logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}") + logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}, expected version: {expected_version}") # Permission check is handled by get_item_and_verify_access dependency + if expected_version is not None and item_db.version != expected_version: + logger.warning( + f"Conflict deleting item {item_id} for user {current_user.email}. " + f"Expected version {expected_version}, actual version {item_db.version}." + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh." + ) + await crud_item.delete_item(db=db, item_db=item_db) - logger.info(f"Item {item_id} deleted successfully by user {current_user.email}.") + logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {current_user.email}.") return Response(status_code=status.HTTP_204_NO_CONTENT) \ No newline at end of file diff --git a/be/app/api/v1/endpoints/lists.py b/be/app/api/v1/endpoints/lists.py index af75f24..5683b07 100644 --- a/be/app/api/v1/endpoints/lists.py +++ b/be/app/api/v1/endpoints/lists.py @@ -1,8 +1,8 @@ # app/api/v1/endpoints/lists.py import logging -from typing import List as PyList # Alias for Python List type hint +from typing import List as PyList, Optional # Alias for Python List type hint -from fastapi import APIRouter, Depends, HTTPException, status, Response +from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db @@ -17,7 +17,8 @@ from app.core.exceptions import ( GroupMembershipError, ListNotFoundError, ListPermissionError, - ListStatusNotFoundError + ListStatusNotFoundError, + ConflictError # Added ConflictError ) logger = logging.getLogger(__name__) @@ -101,7 +102,10 @@ async def read_list( "/{list_id}", response_model=ListPublic, # Return updated basic info summary="Update List", - tags=["Lists"] + tags=["Lists"], + responses={ # Add 409 to responses + status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"} + } ) async def update_list( list_id: int, @@ -112,40 +116,62 @@ async def update_list( """ Updates a list's details (name, description, is_complete). Requires user to be the creator or a member of the list's group. - (MVP: Allows any member to update these fields). + The client MUST provide the current `version` of the list in the `list_in` payload. + If the version does not match, a 409 Conflict is returned. """ - logger.info(f"User {current_user.email} attempting to update list ID: {list_id}") + logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}") list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - # Prevent changing group_id or creator via this endpoint for simplicity - # if list_in.group_id is not None or list_in.created_by_id is not None: - # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint") - - updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) - logger.info(f"List {list_id} updated successfully by user {current_user.email}.") - return updated_list + try: + updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) + logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.") + return updated_list + except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response + logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + except Exception as e: # Catch other potential errors from crud operation + logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}") + # Consider a more generic error, but for now, let's keep it specific if possible + # Re-raising might be better if crud layer already raises appropriate HTTPExceptions + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.") @router.delete( "/{list_id}", status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body summary="Delete List", - tags=["Lists"] + tags=["Lists"], + responses={ # Add 409 to responses + status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"} + } ) async def delete_list( list_id: int, + expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."), db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """ Deletes a list. Requires user to be the creator of the list. - (Alternatively, could allow group owner). + If `expected_version` is provided and does not match the list's current version, + a 409 Conflict is returned. """ - logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}") + logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}") # Use the helper, requiring creator permission list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True) + + if expected_version is not None and list_db.version != expected_version: + logger.warning( + f"Conflict deleting list {list_id} for user {current_user.email}. " + f"Expected version {expected_version}, actual version {list_db.version}." + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh." + ) + await crud_list.delete_list(db=db, list_db=list_db) - logger.info(f"List {list_id} deleted successfully by user {current_user.email}.") + logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.") return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/be/app/api/v1/endpoints/ocr.py b/be/app/api/v1/endpoints/ocr.py index 4d0fbaf..4600152 100644 --- a/be/app/api/v1/endpoints/ocr.py +++ b/be/app/api/v1/endpoints/ocr.py @@ -1,28 +1,27 @@ -# app/api/v1/endpoints/ocr.py import logging from typing import List -from fastapi import APIRouter, Depends, UploadFile, File +from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status from google.api_core import exceptions as google_exceptions from app.api.dependencies import get_current_user from app.models import User as UserModel from app.schemas.ocr import OcrExtractResponse -from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error +from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService from app.core.exceptions import ( - OcrServiceUnavailableError, + OCRServiceUnavailableError, + OCRServiceConfigError, + OCRUnexpectedError, + OCRQuotaExceededError, InvalidFileTypeError, FileTooLargeError, - OcrProcessingError, - OcrQuotaExceededError + OCRProcessingError ) +from app.config import settings logger = logging.getLogger(__name__) router = APIRouter() - -# Allowed image MIME types -ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"] -MAX_FILE_SIZE_MB = 10 +ocr_service = GeminiOCRService() @router.post( "/extract-items", @@ -41,20 +40,20 @@ async def ocr_extract_items( # Check if Gemini client initialized correctly if gemini_initialization_error: logger.error("OCR endpoint called but Gemini client failed to initialize.") - raise OcrServiceUnavailableError(gemini_initialization_error) + raise OCRServiceUnavailableError(gemini_initialization_error) logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.") # --- File Validation --- - if image_file.content_type not in ALLOWED_IMAGE_TYPES: + if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES: logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}") - raise InvalidFileTypeError(ALLOWED_IMAGE_TYPES) + raise InvalidFileTypeError() # Simple size check contents = await image_file.read() - if len(contents) > MAX_FILE_SIZE_MB * 1024 * 1024: + if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024: logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes") - raise FileTooLargeError(MAX_FILE_SIZE_MB) + raise FileTooLargeError() try: # Call the Gemini helper function @@ -66,30 +65,14 @@ async def ocr_extract_items( logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") return OcrExtractResponse(extracted_items=extracted_items) - except ValueError as e: - # Handle errors from Gemini processing (blocked, empty response, etc.) - logger.warning(f"Gemini processing error for user {current_user.email}: {e}") - raise OcrProcessingError(str(e)) - - except google_exceptions.ResourceExhausted as e: - # Specific handling for quota errors - logger.error(f"Gemini Quota Exceeded for user {current_user.email}: {e}", exc_info=True) - raise OcrQuotaExceededError() - - except google_exceptions.GoogleAPIError as e: - # Handle other Google API errors (e.g., invalid key, permissions) - logger.error(f"Gemini API Error for user {current_user.email}: {e}", exc_info=True) - raise OcrServiceUnavailableError(str(e)) - - except RuntimeError as e: - # Catch initialization errors from get_gemini_client() - logger.error(f"Gemini client runtime error during OCR request: {e}") - raise OcrServiceUnavailableError(f"OCR service configuration error: {e}") - + except OCRServiceUnavailableError: + raise OCRServiceUnavailableError() + except OCRServiceConfigError: + raise OCRServiceConfigError() + except OCRQuotaExceededError: + raise OCRQuotaExceededError() except Exception as e: - # Catch any other unexpected errors - logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}") - raise OcrServiceUnavailableError("An unexpected error occurred during item extraction.") + raise OCRProcessingError(str(e)) finally: # Ensure file handle is closed diff --git a/be/app/config.py b/be/app/config.py index 017e8e1..8333247 100644 --- a/be/app/config.py +++ b/be/app/config.py @@ -16,6 +16,86 @@ class Settings(BaseSettings): SECRET_KEY: str # Must be set via environment variable ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes + REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days + + # --- OCR Settings --- + MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing + ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats + OCR_ITEM_EXTRACTION_PROMPT: str = """ +Extract the shopping list items from this image. +List each distinct item on a new line. +Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. +Focus only on the names of the products or items to be purchased. +Add 2 underscores before and after the item name, if it is struck through. +If the image does not appear to be a shopping list or receipt, state that clearly. +Example output for a grocery list: +Milk +Eggs +Bread +__Apples__ +Organic Bananas +""" + + # --- Gemini AI Settings --- + GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR + GEMINI_SAFETY_SETTINGS: dict = { + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", + } + GEMINI_GENERATION_CONFIG: dict = { + "candidate_count": 1, + "max_output_tokens": 2048, + "temperature": 0.9, + "top_p": 1, + "top_k": 1 + } + + # --- API Settings --- + API_PREFIX: str = "/api" # Base path for all API endpoints + API_OPENAPI_URL: str = "/api/openapi.json" + API_DOCS_URL: str = "/api/docs" + API_REDOC_URL: str = "/api/redoc" + CORS_ORIGINS: list[str] = [ + "http://localhost:5174", + "http://localhost:8000", + # Add your deployed frontend URL here later + # "https://your-frontend-domain.com", + ] + + # --- API Metadata --- + API_TITLE: str = "Shared Lists API" + API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting." + API_VERSION: str = "0.1.0" + ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs" + + # --- Logging Settings --- + LOG_LEVEL: str = "INFO" + LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + # --- Health Check Settings --- + HEALTH_STATUS_OK: str = "ok" + HEALTH_STATUS_ERROR: str = "error" + + # --- Auth Settings --- + OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint + TOKEN_TYPE: str = "bearer" # Default token type for OAuth2 + AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header + AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header + AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials" + AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password" + AUTH_NOT_AUTHENTICATED: str = "Not authenticated" + + # --- HTTP Status Messages --- + HTTP_400_DETAIL: str = "Bad Request" + HTTP_401_DETAIL: str = "Unauthorized" + HTTP_403_DETAIL: str = "Forbidden" + HTTP_404_DETAIL: str = "Not Found" + HTTP_422_DETAIL: str = "Unprocessable Entity" + HTTP_429_DETAIL: str = "Too Many Requests" + HTTP_500_DETAIL: str = "Internal Server Error" + HTTP_503_DETAIL: str = "Service Unavailable" class Config: env_file = ".env" diff --git a/be/app/core/exceptions.py b/be/app/core/exceptions.py index 698df55..2522d20 100644 --- a/be/app/core/exceptions.py +++ b/be/app/core/exceptions.py @@ -1,4 +1,7 @@ from fastapi import HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from app.config import settings +from typing import Optional class ListNotFoundError(HTTPException): """Raised when a list is not found.""" @@ -72,76 +75,105 @@ class ItemNotFoundError(HTTPException): detail=f"Item {item_id} not found" ) +class UserNotFoundError(HTTPException): + """Raised when a user is not found.""" + def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None): + detail_msg = "User not found." + if user_id: + detail_msg = f"User with ID {user_id} not found." + elif identifier: + detail_msg = f"User with identifier '{identifier}' not found." + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail_msg + ) + class DatabaseConnectionError(HTTPException): """Raised when there is an error connecting to the database.""" - def __init__(self, detail: str = "Database connection error"): + def __init__(self): super().__init__( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=detail + detail=settings.DB_CONNECTION_ERROR ) class DatabaseIntegrityError(HTTPException): """Raised when a database integrity constraint is violated.""" - def __init__(self, detail: str = "Database integrity error"): + def __init__(self): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, - detail=detail + detail=settings.DB_INTEGRITY_ERROR ) class DatabaseTransactionError(HTTPException): """Raised when a database transaction fails.""" - def __init__(self, detail: str = "Database transaction error"): + def __init__(self): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail + detail=settings.DB_TRANSACTION_ERROR ) class DatabaseQueryError(HTTPException): """Raised when a database query fails.""" - def __init__(self, detail: str = "Database query error"): + def __init__(self): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail + detail=settings.DB_QUERY_ERROR ) -class OcrServiceUnavailableError(HTTPException): +class OCRServiceUnavailableError(HTTPException): """Raised when the OCR service is unavailable.""" - def __init__(self, detail: str): + def __init__(self): super().__init__( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=f"OCR service unavailable: {detail}" + detail=settings.OCR_SERVICE_UNAVAILABLE ) -class InvalidFileTypeError(HTTPException): - """Raised when an invalid file type is uploaded for OCR.""" - def __init__(self, allowed_types: list[str]): +class OCRServiceConfigError(HTTPException): + """Raised when there is an error in the OCR service configuration.""" + def __init__(self): super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid file type. Allowed types: {', '.join(allowed_types)}" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=settings.OCR_SERVICE_CONFIG_ERROR ) -class FileTooLargeError(HTTPException): - """Raised when an uploaded file exceeds the size limit.""" - def __init__(self, max_size_mb: int): +class OCRUnexpectedError(HTTPException): + """Raised when there is an unexpected error in the OCR service.""" + def __init__(self): super().__init__( - status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - detail=f"File size exceeds limit of {max_size_mb} MB." + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=settings.OCR_UNEXPECTED_ERROR ) -class OcrProcessingError(HTTPException): - """Raised when there is an error processing the image with OCR.""" - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Could not extract items from image: {detail}" - ) - -class OcrQuotaExceededError(HTTPException): +class OCRQuotaExceededError(HTTPException): """Raised when the OCR service quota is exceeded.""" def __init__(self): super().__init__( status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="OCR service quota exceeded. Please try again later." + detail=settings.OCR_QUOTA_EXCEEDED + ) + +class InvalidFileTypeError(HTTPException): + """Raised when an invalid file type is uploaded for OCR.""" + def __init__(self): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES)) + ) + +class FileTooLargeError(HTTPException): + """Raised when an uploaded file exceeds the size limit.""" + def __init__(self): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB) + ) + +class OCRProcessingError(HTTPException): + """Raised when there is an error processing the image with OCR.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail=settings.OCR_PROCESSING_ERROR.format(detail=detail) ) class EmailAlreadyRegisteredError(HTTPException): @@ -152,15 +184,6 @@ class EmailAlreadyRegisteredError(HTTPException): detail="Email already registered." ) -class InvalidCredentialsError(HTTPException): - """Raised when login credentials are invalid.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password", - headers={"WWW-Authenticate": "Bearer"} - ) - class UserCreationError(HTTPException): """Raised when there is an error creating a new user.""" def __init__(self): @@ -207,4 +230,48 @@ class ListStatusNotFoundError(HTTPException): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"Status for list {list_id} not found" + ) + +class ConflictError(HTTPException): + """Raised when an optimistic lock version conflict occurs.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_409_CONFLICT, + detail=detail + ) + +class InvalidCredentialsError(HTTPException): + """Raised when login credentials are invalid.""" + def __init__(self): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=settings.AUTH_INVALID_CREDENTIALS, + headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""} + ) + +class NotAuthenticatedError(HTTPException): + """Raised when the user is not authenticated.""" + def __init__(self): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=settings.AUTH_NOT_AUTHENTICATED, + headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""} + ) + +class JWTError(HTTPException): + """Raised when there is an error with the JWT token.""" + def __init__(self, error: str): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=settings.JWT_ERROR.format(error=error), + headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} + ) + +class JWTUnexpectedError(HTTPException): + """Raised when there is an unexpected error with the JWT token.""" + def __init__(self, error: str): + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), + headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} ) \ No newline at end of file diff --git a/be/app/core/gemini.py b/be/app/core/gemini.py index e86e56a..c09e983 100644 --- a/be/app/core/gemini.py +++ b/be/app/core/gemini.py @@ -4,8 +4,13 @@ from typing import List import google.generativeai as genai from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings from google.api_core import exceptions as google_exceptions - from app.config import settings +from app.core.exceptions import ( + OCRServiceUnavailableError, + OCRServiceConfigError, + OCRUnexpectedError, + OCRQuotaExceededError +) logger = logging.getLogger(__name__) @@ -19,26 +24,18 @@ try: genai.configure(api_key=settings.GEMINI_API_KEY) # Initialize the specific model we want to use gemini_flash_client = genai.GenerativeModel( - model_name="gemini-2.0-flash", - # Optional: Add default safety settings - # Adjust these based on your expected content and risk tolerance + model_name=settings.GEMINI_MODEL_NAME, + # Safety settings from config safety_settings={ - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold) + for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items() }, - # Optional: Add default generation config (can be overridden per request) - # generation_config=genai.types.GenerationConfig( - # # candidate_count=1, # Usually default is 1 - # # stop_sequences=["\n"], - # # max_output_tokens=2048, - # # temperature=0.9, # Controls randomness (0=deterministic, >1=more random) - # # top_p=1, - # # top_k=1 - # ) + # Generation config from settings + generation_config=genai.types.GenerationConfig( + **settings.GEMINI_GENERATION_CONFIG + ) ) - logger.info("Gemini AI client initialized successfully for model 'gemini-1.5-flash-latest'.") + logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.") else: # Store error if API key is missing gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized." @@ -105,7 +102,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " # Prepare the full prompt content prompt_parts = [ - OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first + settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first image_part # Then the image ] @@ -152,4 +149,47 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " # Catch other unexpected errors during generation or processing logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) # Wrap in a generic ValueError or re-raise - raise ValueError(f"Failed to process image with Gemini: {e}") from e \ No newline at end of file + raise ValueError(f"Failed to process image with Gemini: {e}") from e + +class GeminiOCRService: + def __init__(self): + try: + genai.configure(api_key=settings.GEMINI_API_KEY) + self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME) + self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS + self.model.generation_config = settings.GEMINI_GENERATION_CONFIG + except Exception as e: + logger.error(f"Failed to initialize Gemini client: {e}") + raise OCRServiceConfigError() + + async def extract_items(self, image_data: bytes) -> List[str]: + """ + Extract shopping list items from an image using Gemini Vision. + """ + try: + # Create image part + image_parts = [{"mime_type": "image/jpeg", "data": image_data}] + + # Generate content + response = await self.model.generate_content_async( + contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts] + ) + + # Process response + if not response.text: + raise OCRUnexpectedError() + + # Split response into lines and clean up + items = [ + item.strip() + for item in response.text.split("\n") + if item.strip() and not item.strip().startswith("Example") + ] + + return items + + except Exception as e: + logger.error(f"Error during OCR extraction: {e}") + if "quota" in str(e).lower(): + raise OCRQuotaExceededError() + raise OCRServiceUnavailableError() \ No newline at end of file diff --git a/be/app/core/security.py b/be/app/core/security.py index 64ce3f7..f269a9a 100644 --- a/be/app/core/security.py +++ b/be/app/core/security.py @@ -66,7 +66,34 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede ) # Data to encode in the token payload - to_encode = {"exp": expire, "sub": str(subject)} + to_encode = {"exp": expire, "sub": str(subject), "type": "access"} + + encoded_jwt = jwt.encode( + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM + ) + return encoded_jwt + +def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: + """ + Creates a JWT refresh token. + + Args: + subject: The subject of the token (e.g., user ID or email). + expires_delta: Optional timedelta object for token expiry. If None, + uses REFRESH_TOKEN_EXPIRE_MINUTES from settings. + + Returns: + The encoded JWT refresh token string. + """ + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta( + minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES + ) + + # Data to encode in the token payload + to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"} encoded_jwt = jwt.encode( to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM @@ -91,6 +118,8 @@ def verify_access_token(token: str) -> Optional[dict]: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] ) + if payload.get("type") != "access": + raise JWTError("Invalid token type") return payload except JWTError as e: # Handles InvalidSignatureError, ExpiredSignatureError, etc. @@ -101,6 +130,31 @@ def verify_access_token(token: str) -> Optional[dict]: print(f"Unexpected error decoding JWT: {e}") return None +def verify_refresh_token(token: str) -> Optional[dict]: + """ + Verifies a JWT refresh token and returns its payload if valid. + + Args: + token: The JWT token string to verify. + + Returns: + The decoded token payload (dict) if the token is valid, not expired, + and is a refresh token, otherwise None. + """ + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) + if payload.get("type") != "refresh": + raise JWTError("Invalid token type") + return payload + except JWTError as e: + print(f"JWT Error: {e}") # Log the error for debugging + return None + except Exception as e: + print(f"Unexpected error decoding JWT: {e}") + return None + # You might add a function here later to extract the 'sub' (subject/user id) # specifically, often used in dependency injection for authentication. # def get_subject_from_token(token: str) -> Optional[str]: diff --git a/be/app/crud/cost.py b/be/app/crud/cost.py new file mode 100644 index 0000000..2b5d3c1 --- /dev/null +++ b/be/app/crud/cost.py @@ -0,0 +1,116 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload, joinedload +from decimal import Decimal, ROUND_HALF_UP +from typing import List as PyList, Dict, Set + +from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel +from app.schemas.cost import ListCostSummary, UserCostShare +from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful + +async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary: + """ + Calculates the cost summary for a given list, splitting costs equally + among relevant users (group members if list is in a group, or creator if personal). + """ + # 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members) + list_result = await db.execute( + select(ListModel) + .options( + selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)), + selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))) + ) + .where(ListModel.id == list_id) + ) + db_list: Optional[ListModel] = list_result.scalars().first() + + if not db_list: + raise ListNotFoundError(list_id) + + # 2. Determine participating users + participating_users: Dict[int, UserModel] = {} + if db_list.group: + # If list is part of a group, all group members participate + for ug_assoc in db_list.group.user_associations: + if ug_assoc.user: # Ensure user object is loaded + participating_users[ug_assoc.user.id] = ug_assoc.user + else: + # If personal list, only the creator participates (or if items were added by others somehow, include them) + # For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit. + # Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator) + creator_user = await db.get(UserModel, db_list.created_by_id) + if not creator_user: + # This case should ideally not happen if data integrity is maintained + raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error + participating_users[creator_user.id] = creator_user + + # Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness) + for item in db_list.items: + if item.added_by_user and item.added_by_user.id not in participating_users: + participating_users[item.added_by_user.id] = item.added_by_user + + + num_participating_users = len(participating_users) + if num_participating_users == 0: + # Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent) + # Or if list has no items and is personal, creator might not be in participating_users if logic changes. + # For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary. + return ListCostSummary( + list_id=db_list.id, + list_name=db_list.name, + total_list_cost=Decimal("0.00"), + num_participating_users=0, + equal_share_per_user=Decimal("0.00"), + user_balances=[] + ) + + + # 3. Calculate total cost and items_added_value for each user + total_list_cost = Decimal("0.00") + user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()} + + for item in db_list.items: + if item.price is not None and item.price > Decimal("0"): + total_list_cost += item.price + if item.added_by_id in user_items_added_value: # Item adder must be in participating users + user_items_added_value[item.added_by_id] += item.price + # If item.added_by_id is not in participating_users (e.g. user left group), + # their contribution still counts to total cost, but they aren't part of the split. + # The current logic adds item adders to participating_users, so this else is less likely. + + # 4. Calculate equal share per user + # Using ROUND_HALF_UP to handle cents appropriately. + # Ensure division by zero is handled if num_participating_users could be 0 (already handled above) + equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00") + + # 5. For each user, calculate their balance + user_balances: PyList[UserCostShare] = [] + for user_id, user_obj in participating_users.items(): + items_added = user_items_added_value.get(user_id, Decimal("0.00")) + balance = items_added - equal_share_per_user + + user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email + + user_balances.append( + UserCostShare( + user_id=user_id, + user_identifier=user_identifier, + items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + amount_due=equal_share_per_user, # Already quantized + balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + ) + ) + + # Sort user_balances for consistent output, e.g., by user_id or identifier + user_balances.sort(key=lambda x: x.user_identifier) + + + # 6. Return the populated ListCostSummary schema + return ListCostSummary( + list_id=db_list.id, + list_name=db_list.name, + total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + num_participating_users=num_participating_users, + equal_share_per_user=equal_share_per_user, # Already quantized + user_balances=user_balances + ) \ No newline at end of file diff --git a/be/app/crud/item.py b/be/app/crud/item.py index da934a1..16ec082 100644 --- a/be/app/crud/item.py +++ b/be/app/crud/item.py @@ -13,7 +13,8 @@ from app.core.exceptions import ( DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, - DatabaseTransactionError + DatabaseTransactionError, + ConflictError ) async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: @@ -26,6 +27,7 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_ list_id=list_id, added_by_id=user_id, is_complete=False # Default on creation + # version is implicitly set to 1 by model default ) db.add(db_item) await db.flush() @@ -65,44 +67,57 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: raise DatabaseQueryError(f"Failed to query item: {str(e)}") async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: - """Updates an existing item record.""" + """Updates an existing item record, checking for version conflicts.""" try: async with db.begin(): - update_data = item_in.model_dump(exclude_unset=True) # Get only provided fields + if item_db.version != item_in.version: + raise ConflictError( + f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " + f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." + ) + + update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version # Special handling for is_complete if 'is_complete' in update_data: if update_data['is_complete'] is True: - # Mark as complete: set completed_by_id if not already set - if item_db.completed_by_id is None: + if item_db.completed_by_id is None: # Only set if not already completed by someone update_data['completed_by_id'] = user_id else: - # Mark as incomplete: clear completed_by_id - update_data['completed_by_id'] = None - # Ensure updated_at is refreshed (handled by onupdate in model, but explicit is fine too) - # update_data['updated_at'] = datetime.now(timezone.utc) - + update_data['completed_by_id'] = None # Clear if marked incomplete + for key, value in update_data.items(): setattr(item_db, key, value) - db.add(item_db) # Add to session to track changes + item_db.version += 1 # Increment version + + db.add(item_db) await db.flush() await db.refresh(item_db) return item_db except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to update item: {str(e)}") + await db.rollback() + raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + await db.rollback() + raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") + except ConflictError: # Re-raise ConflictError + await db.rollback() + raise except SQLAlchemyError as e: + await db.rollback() raise DatabaseTransactionError(f"Failed to update item: {str(e)}") async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: - """Deletes an item record.""" + """Deletes an item record. Version check should be done by the caller (API endpoint).""" try: async with db.begin(): await db.delete(item_db) - return None # Or return True/False + # await db.commit() # Not needed with async with db.begin() + return None except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + await db.rollback() + raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") except SQLAlchemyError as e: + await db.rollback() raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") \ No newline at end of file diff --git a/be/app/crud/list.py b/be/app/crud/list.py index d657a44..b674563 100644 --- a/be/app/crud/list.py +++ b/be/app/crud/list.py @@ -16,7 +16,8 @@ from app.core.exceptions import ( DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, - DatabaseTransactionError + DatabaseTransactionError, + ConflictError ) async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: @@ -85,32 +86,50 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals raise DatabaseQueryError(f"Failed to query list: {str(e)}") async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: - """Updates an existing list record.""" + """Updates an existing list record, checking for version conflicts.""" try: async with db.begin(): - update_data = list_in.model_dump(exclude_unset=True) + if list_db.version != list_in.version: + raise ConflictError( + f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " + f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." + ) + + update_data = list_in.model_dump(exclude_unset=True, exclude={'version'}) + for key, value in update_data.items(): setattr(list_db, key, value) + + list_db.version += 1 + db.add(list_db) await db.flush() await db.refresh(list_db) return list_db except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to update list: {str(e)}") + await db.rollback() + raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + await db.rollback() + raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") + except ConflictError: + await db.rollback() + raise except SQLAlchemyError as e: + await db.rollback() raise DatabaseTransactionError(f"Failed to update list: {str(e)}") async def delete_list(db: AsyncSession, list_db: ListModel) -> None: - """Deletes a list record.""" + """Deletes a list record. Version check should be done by the caller (API endpoint).""" try: async with db.begin(): await db.delete(list_db) return None except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + await db.rollback() + raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") except SQLAlchemyError as e: + await db.rollback() raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: diff --git a/be/app/main.py b/be/app/main.py index 0efd641..0230e98 100644 --- a/be/app/main.py +++ b/be/app/main.py @@ -5,22 +5,25 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.api.api_router import api_router # Import the main combined router +from app.config import settings # Import database and models if needed for startup/shutdown events later # from . import database, models # --- Logging Setup --- -# Configure logging (can be more sophisticated later, e.g., using logging.yaml) -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=getattr(logging, settings.LOG_LEVEL), + format=settings.LOG_FORMAT +) logger = logging.getLogger(__name__) # --- FastAPI App Instance --- app = FastAPI( - title="Shared Lists API", - description="API for managing shared shopping lists, OCR, and cost splitting.", - version="0.1.0", - openapi_url="/api/openapi.json", # Place OpenAPI spec under /api - docs_url="/api/docs", # Place Swagger UI under /api - redoc_url="/api/redoc" # Place ReDoc under /api + title=settings.API_TITLE, + description=settings.API_DESCRIPTION, + version=settings.API_VERSION, + openapi_url=settings.API_OPENAPI_URL, + docs_url=settings.API_DOCS_URL, + redoc_url=settings.API_REDOC_URL ) # --- CORS Middleware --- @@ -37,17 +40,17 @@ origins = [ app.add_middleware( CORSMiddleware, - allow_origins=origins, # List of origins that are allowed to make requests - allow_credentials=True, # Allow cookies to be included in requests - allow_methods=["*"], # Allow all methods (GET, POST, PUT, DELETE, etc.) - allow_headers=["*"], # Allow all headers + allow_origins=settings.CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) # --- End CORS Middleware --- # --- Include API Routers --- # All API endpoints will be prefixed with /api -app.include_router(api_router, prefix="/api") +app.include_router(api_router, prefix=settings.API_PREFIX) # --- End Include API Routers --- @@ -59,10 +62,7 @@ async def read_root(): Useful for basic reachability checks. """ logger.info("Root endpoint '/' accessed.") - # You could redirect to the docs or return a simple message - # from fastapi.responses import RedirectResponse - # return RedirectResponse(url="/api/docs") - return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"} + return {"message": settings.ROOT_MESSAGE} # --- End Root Endpoint --- diff --git a/be/app/models.py b/be/app/models.py index 851b8c0..50f56b5 100644 --- a/be/app/models.py +++ b/be/app/models.py @@ -117,6 +117,7 @@ class List(Base): is_complete = Column(Boolean, default=False, nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + version = Column(Integer, nullable=False, default=1, server_default='1') # --- Relationships --- creator = relationship("User", back_populates="created_lists") # Link to User.created_lists @@ -138,6 +139,7 @@ class Item(Base): completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + version = Column(Integer, nullable=False, default=1, server_default='1') # --- Relationships --- list = relationship("List", back_populates="items") # Link to List.items diff --git a/be/app/schemas/auth.py b/be/app/schemas/auth.py index d3a76fb..c0c4fcb 100644 --- a/be/app/schemas/auth.py +++ b/be/app/schemas/auth.py @@ -1,9 +1,11 @@ # app/schemas/auth.py from pydantic import BaseModel, EmailStr +from app.config import settings class Token(BaseModel): access_token: str - token_type: str = "bearer" # Default token type + refresh_token: str # Added refresh token + token_type: str = settings.TOKEN_TYPE # Use configured token type # Optional: If you preferred not to use OAuth2PasswordRequestForm # class UserLogin(BaseModel): diff --git a/be/app/schemas/cost.py b/be/app/schemas/cost.py new file mode 100644 index 0000000..b30a18f --- /dev/null +++ b/be/app/schemas/cost.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, ConfigDict +from typing import List, Optional +from decimal import Decimal + +class UserCostShare(BaseModel): + user_id: int + user_identifier: str # Name or email + items_added_value: Decimal = Decimal("0.00") # Total value of items this user added + amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users) + balance: Decimal # items_added_value - amount_due + + model_config = ConfigDict(from_attributes=True) + +class ListCostSummary(BaseModel): + list_id: int + list_name: str + total_list_cost: Decimal + num_participating_users: int + equal_share_per_user: Decimal + user_balances: List[UserCostShare] + + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/be/app/schemas/health.py b/be/app/schemas/health.py index 1d4f6e5..bbb00b7 100644 --- a/be/app/schemas/health.py +++ b/be/app/schemas/health.py @@ -1,9 +1,10 @@ # app/schemas/health.py from pydantic import BaseModel +from app.config import settings class HealthStatus(BaseModel): """ Response model for the health check endpoint. """ - status: str = "ok" # Provide a default value + status: str = settings.HEALTH_STATUS_OK # Use configured default value database: str \ No newline at end of file diff --git a/be/app/schemas/item.py b/be/app/schemas/item.py index 1f281d0..fe2f2f6 100644 --- a/be/app/schemas/item.py +++ b/be/app/schemas/item.py @@ -16,6 +16,7 @@ class ItemPublic(BaseModel): completed_by_id: Optional[int] = None created_at: datetime updated_at: datetime + version: int model_config = ConfigDict(from_attributes=True) # Properties to receive via API on creation @@ -31,4 +32,5 @@ class ItemUpdate(BaseModel): quantity: Optional[str] = None is_complete: Optional[bool] = None price: Optional[Decimal] = None # Price added here for update + version: int # completed_by_id will be set internally if is_complete is true \ No newline at end of file diff --git a/be/app/schemas/list.py b/be/app/schemas/list.py index 66006ca..a2d4314 100644 --- a/be/app/schemas/list.py +++ b/be/app/schemas/list.py @@ -16,6 +16,7 @@ class ListUpdate(BaseModel): name: Optional[str] = None description: Optional[str] = None is_complete: Optional[bool] = None + version: int # Client must provide the version for updates # Potentially add group_id update later if needed # Base properties returned by API (common fields) @@ -28,6 +29,7 @@ class ListBase(BaseModel): is_complete: bool created_at: datetime updated_at: datetime + version: int # Include version in responses model_config = ConfigDict(from_attributes=True)