add_version_to_lists_and_items

This commit is contained in:
mohamad 2025-05-07 23:30:23 +02:00
parent d2d484c327
commit 423d345fdf
23 changed files with 796 additions and 185 deletions

View File

@ -32,4 +32,4 @@ EXPOSE 8000
# The default command for production (can be overridden in docker-compose for development)
# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@ -0,0 +1,28 @@
"""add_version_to_lists_and_items
Revision ID: d53eedd151b7
Revises: d25788f63e2c
Create Date: 2025-05-07 21:05:50.396430
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd53eedd151b7'
down_revision: Union[str, None] = 'd25788f63e2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -11,13 +11,14 @@ from app.database import get_db
from app.core.security import verify_access_token
from app.crud import user as crud_user
from app.models import User as UserModel # Import the SQLAlchemy model
from app.config import settings
logger = logging.getLogger(__name__)
# Define the OAuth2 scheme
# tokenUrl should point to your login endpoint relative to the base path
# It's used by Swagger UI for the "Authorize" button flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") # Corrected path
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL)
async def get_current_user(
token: str = Depends(oauth2_scheme),
@ -36,8 +37,8 @@ async def get_current_user(
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
detail=settings.AUTH_CREDENTIALS_ERROR,
headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX},
)
payload = verify_access_token(token)

View File

@ -9,6 +9,7 @@ from app.api.v1.endpoints import invites
from app.api.v1.endpoints import lists
from app.api.v1.endpoints import items
from app.api.v1.endpoints import ocr
from app.api.v1.endpoints import costs
api_router_v1 = APIRouter()
@ -20,5 +21,6 @@ api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"]
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
api_router_v1.include_router(items.router, tags=["Items"])
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
api_router_v1.include_router(costs.router, tags=["Costs"])
# Add other v1 endpoint routers here later
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])

View File

@ -1,6 +1,7 @@
# app/api/v1/endpoints/auth.py
import logging
from fastapi import APIRouter, Depends
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
@ -8,12 +9,18 @@ from app.database import get_db
from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token
from app.crud import user as crud_user
from app.core.security import verify_password, create_access_token
from app.core.security import (
verify_password,
create_access_token,
create_refresh_token,
verify_refresh_token
)
from app.core.exceptions import (
EmailAlreadyRegisteredError,
InvalidCredentialsError,
UserCreationError
)
from app.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
@ -55,28 +62,74 @@ async def signup(
"/login",
response_model=Token,
summary="User Login",
description="Authenticates a user and returns an access token.",
description="Authenticates a user and returns an access and refresh token.",
tags=["Authentication"]
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: AsyncSession = Depends(get_db)
):
"""
Handles user login.
- Finds user by email (provided in 'username' field of form).
- Verifies the provided password against the stored hash.
- Generates and returns a JWT access token upon successful authentication.
- Generates and returns JWT access and refresh tokens upon successful authentication.
"""
logger.info(f"Login attempt for user: {form_data.username}")
user = await crud_user.get_user_by_email(db, email=form_data.username)
# Check if user exists and password is correct
if not user or not verify_password(form_data.password, user.password_hash):
logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
raise InvalidCredentialsError()
# Generate JWT
access_token = create_access_token(subject=user.email)
logger.info(f"Login successful, token generated for user: {user.email}")
return Token(access_token=access_token, token_type="bearer")
refresh_token = create_refresh_token(subject=user.email)
logger.info(f"Login successful, tokens generated for user: {user.email}")
return Token(
access_token=access_token,
refresh_token=refresh_token,
token_type=settings.TOKEN_TYPE
)
@router.post(
"/refresh",
response_model=Token,
summary="Refresh Access Token",
description="Refreshes an access token using a refresh token.",
tags=["Authentication"]
)
async def refresh_token(
refresh_token_str: str,
db: AsyncSession = Depends(get_db)
):
"""
Handles access token refresh.
- Verifies the provided refresh token.
- If valid, generates and returns a new JWT access token and the same refresh token.
"""
logger.info("Access token refresh attempt")
payload = verify_refresh_token(refresh_token_str)
if not payload:
logger.warning("Refresh token invalid or expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
user_email = payload.get("sub")
if not user_email:
logger.error("User email not found in refresh token payload")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token payload",
headers={"WWW-Authenticate": "Bearer"},
)
new_access_token = create_access_token(subject=user_email)
logger.info(f"Access token refreshed for user: {user_email}")
return Token(
access_token=new_access_token,
refresh_token=refresh_token_str,
token_type=settings.TOKEN_TYPE
)

View File

@ -0,0 +1,69 @@
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel # For get_current_user dependency
from app.schemas.cost import ListCostSummary
from app.crud import cost as crud_cost
from app.crud import list as crud_list # For permission checking
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
"/lists/{list_id}/cost-summary",
response_model=ListCostSummary,
summary="Get Cost Summary for a List",
tags=["Costs"],
responses={
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
}
)
async def get_list_cost_summary(
list_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Retrieves a calculated cost summary for a specific list, detailing total costs,
equal shares per user, and individual user balances based on their contributions.
The user must have access to the list to view its cost summary.
Costs are split among group members if the list belongs to a group, or just for
the creator if it's a personal list. All users who added items with prices are
included in the calculation.
"""
logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")
# 1. Verify user has access to the target list
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
raise # Re-raise the original exception to be handled by FastAPI
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
raise # Re-raise
# 2. Calculate the cost summary
try:
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id)
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}")
return cost_summary
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}")
# This might be redundant if check_list_permission already confirmed list existence,
# but calculate_list_cost_summary also fetches the list.
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e:
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}")
# This indicates a data integrity issue (e.g., list creator or item adder missing)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")

View File

@ -1,8 +1,8 @@
# app/api/v1/endpoints/items.py
import logging
from typing import List as PyList
from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
@ -14,7 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
from app.crud import item as crud_item
from app.crud import list as crud_list
from app.core.exceptions import ItemNotFoundError, ListPermissionError
from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
logger = logging.getLogger(__name__)
router = APIRouter()
@ -100,7 +100,10 @@ async def read_list_items(
"/items/{item_id}", # Operate directly on item ID
response_model=ItemPublic,
summary="Update Item",
tags=["Items"]
tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
}
)
async def update_item(
item_id: int, # Item ID from path
@ -112,37 +115,61 @@ async def update_item(
"""
Updates an item's details (name, quantity, is_complete, price).
User must have access to the list the item belongs to.
The client MUST provide the current `version` of the item in the `item_in` payload.
If the version does not match, a 409 Conflict is returned.
Sets/unsets `completed_by_id` based on `is_complete` flag.
"""
logger.info(f"User {current_user.email} attempting to update item ID: {item_id}")
logger.info(f"User {current_user.email} attempting to update item ID: {item_id} with version {item_in.version}")
# Permission check is handled by get_item_and_verify_access dependency
updated_item = await crud_item.update_item(
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
)
logger.info(f"Item {item_id} updated successfully by user {current_user.email}.")
return updated_item
try:
updated_item = await crud_item.update_item(
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
)
logger.info(f"Item {item_id} updated successfully by user {current_user.email} to version {updated_item.version}.")
return updated_item
except ConflictError as e:
logger.warning(f"Conflict updating item {item_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e:
logger.error(f"Error updating item {item_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")
@router.delete(
"/items/{item_id}", # Operate directly on item ID
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Item",
tags=["Items"]
tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
}
)
async def delete_item(
item_id: int, # Item ID from path
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user), # Log who deleted it
):
"""
Deletes an item. User must have access to the list the item belongs to.
(MVP: Any member with list access can delete items).
If `expected_version` is provided and does not match the item's current version,
a 409 Conflict is returned.
"""
logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}")
logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
# Permission check is handled by get_item_and_verify_access dependency
if expected_version is not None and item_db.version != expected_version:
logger.warning(
f"Conflict deleting item {item_id} for user {current_user.email}. "
f"Expected version {expected_version}, actual version {item_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
)
await crud_item.delete_item(db=db, item_db=item_db)
logger.info(f"Item {item_id} deleted successfully by user {current_user.email}.")
logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {current_user.email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@ -1,8 +1,8 @@
# app/api/v1/endpoints/lists.py
import logging
from typing import List as PyList # Alias for Python List type hint
from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
@ -17,7 +17,8 @@ from app.core.exceptions import (
GroupMembershipError,
ListNotFoundError,
ListPermissionError,
ListStatusNotFoundError
ListStatusNotFoundError,
ConflictError # Added ConflictError
)
logger = logging.getLogger(__name__)
@ -101,7 +102,10 @@ async def read_list(
"/{list_id}",
response_model=ListPublic, # Return updated basic info
summary="Update List",
tags=["Lists"]
tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
}
)
async def update_list(
list_id: int,
@ -112,40 +116,62 @@ async def update_list(
"""
Updates a list's details (name, description, is_complete).
Requires user to be the creator or a member of the list's group.
(MVP: Allows any member to update these fields).
The client MUST provide the current `version` of the list in the `list_in` payload.
If the version does not match, a 409 Conflict is returned.
"""
logger.info(f"User {current_user.email} attempting to update list ID: {list_id}")
logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
# Prevent changing group_id or creator via this endpoint for simplicity
# if list_in.group_id is not None or list_in.created_by_id is not None:
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint")
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
logger.info(f"List {list_id} updated successfully by user {current_user.email}.")
return updated_list
try:
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
return updated_list
except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response
logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e: # Catch other potential errors from crud operation
logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}")
# Consider a more generic error, but for now, let's keep it specific if possible
# Re-raising might be better if crud layer already raises appropriate HTTPExceptions
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.")
@router.delete(
"/{list_id}",
status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
summary="Delete List",
tags=["Lists"]
tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
}
)
async def delete_list(
list_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Deletes a list. Requires user to be the creator of the list.
(Alternatively, could allow group owner).
If `expected_version` is provided and does not match the list's current version,
a 409 Conflict is returned.
"""
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}")
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
# Use the helper, requiring creator permission
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
if expected_version is not None and list_db.version != expected_version:
logger.warning(
f"Conflict deleting list {list_id} for user {current_user.email}. "
f"Expected version {expected_version}, actual version {list_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
)
await crud_list.delete_list(db=db, list_db=list_db)
logger.info(f"List {list_id} deleted successfully by user {current_user.email}.")
logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT)

View File

@ -1,28 +1,27 @@
# app/api/v1/endpoints/ocr.py
import logging
from typing import List
from fastapi import APIRouter, Depends, UploadFile, File
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
from google.api_core import exceptions as google_exceptions
from app.api.dependencies import get_current_user
from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
from app.core.exceptions import (
OcrServiceUnavailableError,
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OcrProcessingError,
OcrQuotaExceededError
OCRProcessingError
)
from app.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Allowed image MIME types
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
MAX_FILE_SIZE_MB = 10
ocr_service = GeminiOCRService()
@router.post(
"/extract-items",
@ -41,20 +40,20 @@ async def ocr_extract_items(
# Check if Gemini client initialized correctly
if gemini_initialization_error:
logger.error("OCR endpoint called but Gemini client failed to initialize.")
raise OcrServiceUnavailableError(gemini_initialization_error)
raise OCRServiceUnavailableError(gemini_initialization_error)
logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
# --- File Validation ---
if image_file.content_type not in ALLOWED_IMAGE_TYPES:
if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
raise InvalidFileTypeError(ALLOWED_IMAGE_TYPES)
raise InvalidFileTypeError()
# Simple size check
contents = await image_file.read()
if len(contents) > MAX_FILE_SIZE_MB * 1024 * 1024:
if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
raise FileTooLargeError(MAX_FILE_SIZE_MB)
raise FileTooLargeError()
try:
# Call the Gemini helper function
@ -66,30 +65,14 @@ async def ocr_extract_items(
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items)
except ValueError as e:
# Handle errors from Gemini processing (blocked, empty response, etc.)
logger.warning(f"Gemini processing error for user {current_user.email}: {e}")
raise OcrProcessingError(str(e))
except google_exceptions.ResourceExhausted as e:
# Specific handling for quota errors
logger.error(f"Gemini Quota Exceeded for user {current_user.email}: {e}", exc_info=True)
raise OcrQuotaExceededError()
except google_exceptions.GoogleAPIError as e:
# Handle other Google API errors (e.g., invalid key, permissions)
logger.error(f"Gemini API Error for user {current_user.email}: {e}", exc_info=True)
raise OcrServiceUnavailableError(str(e))
except RuntimeError as e:
# Catch initialization errors from get_gemini_client()
logger.error(f"Gemini client runtime error during OCR request: {e}")
raise OcrServiceUnavailableError(f"OCR service configuration error: {e}")
except OCRServiceUnavailableError:
raise OCRServiceUnavailableError()
except OCRServiceConfigError:
raise OCRServiceConfigError()
except OCRQuotaExceededError:
raise OCRQuotaExceededError()
except Exception as e:
# Catch any other unexpected errors
logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}")
raise OcrServiceUnavailableError("An unexpected error occurred during item extraction.")
raise OCRProcessingError(str(e))
finally:
# Ensure file handle is closed

View File

@ -16,6 +16,86 @@ class Settings(BaseSettings):
SECRET_KEY: str # Must be set via environment variable
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days
# --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
OCR_ITEM_EXTRACTION_PROMPT: str = """
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
Add 2 underscores before and after the item name, if it is struck through.
If the image does not appear to be a shopping list or receipt, state that clearly.
Example output for a grocery list:
Milk
Eggs
Bread
__Apples__
Organic Bananas
"""
# --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
GEMINI_SAFETY_SETTINGS: dict = {
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
}
GEMINI_GENERATION_CONFIG: dict = {
"candidate_count": 1,
"max_output_tokens": 2048,
"temperature": 0.9,
"top_p": 1,
"top_k": 1
}
# --- API Settings ---
API_PREFIX: str = "/api" # Base path for all API endpoints
API_OPENAPI_URL: str = "/api/openapi.json"
API_DOCS_URL: str = "/api/docs"
API_REDOC_URL: str = "/api/redoc"
CORS_ORIGINS: list[str] = [
"http://localhost:5174",
"http://localhost:8000",
# Add your deployed frontend URL here later
# "https://your-frontend-domain.com",
]
# --- API Metadata ---
API_TITLE: str = "Shared Lists API"
API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
API_VERSION: str = "0.1.0"
ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
# --- Logging Settings ---
LOG_LEVEL: str = "INFO"
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# --- Health Check Settings ---
HEALTH_STATUS_OK: str = "ok"
HEALTH_STATUS_ERROR: str = "error"
# --- Auth Settings ---
OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint
TOKEN_TYPE: str = "bearer" # Default token type for OAuth2
AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header
AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header
AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials"
AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
# --- HTTP Status Messages ---
HTTP_400_DETAIL: str = "Bad Request"
HTTP_401_DETAIL: str = "Unauthorized"
HTTP_403_DETAIL: str = "Forbidden"
HTTP_404_DETAIL: str = "Not Found"
HTTP_422_DETAIL: str = "Unprocessable Entity"
HTTP_429_DETAIL: str = "Too Many Requests"
HTTP_500_DETAIL: str = "Internal Server Error"
HTTP_503_DETAIL: str = "Service Unavailable"
class Config:
env_file = ".env"

View File

@ -1,4 +1,7 @@
from fastapi import HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from typing import Optional
class ListNotFoundError(HTTPException):
"""Raised when a list is not found."""
@ -72,76 +75,105 @@ class ItemNotFoundError(HTTPException):
detail=f"Item {item_id} not found"
)
class UserNotFoundError(HTTPException):
"""Raised when a user is not found."""
def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None):
detail_msg = "User not found."
if user_id:
detail_msg = f"User with ID {user_id} not found."
elif identifier:
detail_msg = f"User with identifier '{identifier}' not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail_msg
)
class DatabaseConnectionError(HTTPException):
"""Raised when there is an error connecting to the database."""
def __init__(self, detail: str = "Database connection error"):
def __init__(self):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=detail
detail=settings.DB_CONNECTION_ERROR
)
class DatabaseIntegrityError(HTTPException):
"""Raised when a database integrity constraint is violated."""
def __init__(self, detail: str = "Database integrity error"):
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail
detail=settings.DB_INTEGRITY_ERROR
)
class DatabaseTransactionError(HTTPException):
"""Raised when a database transaction fails."""
def __init__(self, detail: str = "Database transaction error"):
def __init__(self):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
detail=settings.DB_TRANSACTION_ERROR
)
class DatabaseQueryError(HTTPException):
"""Raised when a database query fails."""
def __init__(self, detail: str = "Database query error"):
def __init__(self):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
detail=settings.DB_QUERY_ERROR
)
class OcrServiceUnavailableError(HTTPException):
class OCRServiceUnavailableError(HTTPException):
"""Raised when the OCR service is unavailable."""
def __init__(self, detail: str):
def __init__(self):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"OCR service unavailable: {detail}"
detail=settings.OCR_SERVICE_UNAVAILABLE
)
class InvalidFileTypeError(HTTPException):
"""Raised when an invalid file type is uploaded for OCR."""
def __init__(self, allowed_types: list[str]):
class OCRServiceConfigError(HTTPException):
"""Raised when there is an error in the OCR service configuration."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid file type. Allowed types: {', '.join(allowed_types)}"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=settings.OCR_SERVICE_CONFIG_ERROR
)
class FileTooLargeError(HTTPException):
"""Raised when an uploaded file exceeds the size limit."""
def __init__(self, max_size_mb: int):
class OCRUnexpectedError(HTTPException):
"""Raised when there is an unexpected error in the OCR service."""
def __init__(self):
super().__init__(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File size exceeds limit of {max_size_mb} MB."
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=settings.OCR_UNEXPECTED_ERROR
)
class OcrProcessingError(HTTPException):
"""Raised when there is an error processing the image with OCR."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Could not extract items from image: {detail}"
)
class OcrQuotaExceededError(HTTPException):
class OCRQuotaExceededError(HTTPException):
"""Raised when the OCR service quota is exceeded."""
def __init__(self):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="OCR service quota exceeded. Please try again later."
detail=settings.OCR_QUOTA_EXCEEDED
)
class InvalidFileTypeError(HTTPException):
"""Raised when an invalid file type is uploaded for OCR."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
)
class FileTooLargeError(HTTPException):
"""Raised when an uploaded file exceeds the size limit."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
)
class OCRProcessingError(HTTPException):
"""Raised when there is an error processing the image with OCR."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_PROCESSING_ERROR.format(detail=detail)
)
class EmailAlreadyRegisteredError(HTTPException):
@ -152,15 +184,6 @@ class EmailAlreadyRegisteredError(HTTPException):
detail="Email already registered."
)
class InvalidCredentialsError(HTTPException):
"""Raised when login credentials are invalid."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"}
)
class UserCreationError(HTTPException):
"""Raised when there is an error creating a new user."""
def __init__(self):
@ -208,3 +231,47 @@ class ListStatusNotFoundError(HTTPException):
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Status for list {list_id} not found"
)
class ConflictError(HTTPException):
"""Raised when an optimistic lock version conflict occurs."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail
)
class InvalidCredentialsError(HTTPException):
"""Raised when login credentials are invalid."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_INVALID_CREDENTIALS,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""}
)
class NotAuthenticatedError(HTTPException):
"""Raised when the user is not authenticated."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_NOT_AUTHENTICATED,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""}
)
class JWTError(HTTPException):
"""Raised when there is an error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
class JWTUnexpectedError(HTTPException):
"""Raised when there is an unexpected error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)

View File

@ -4,8 +4,13 @@ from typing import List
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
from google.api_core import exceptions as google_exceptions
from app.config import settings
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
logger = logging.getLogger(__name__)
@ -19,26 +24,18 @@ try:
genai.configure(api_key=settings.GEMINI_API_KEY)
# Initialize the specific model we want to use
gemini_flash_client = genai.GenerativeModel(
model_name="gemini-2.0-flash",
# Optional: Add default safety settings
# Adjust these based on your expected content and risk tolerance
model_name=settings.GEMINI_MODEL_NAME,
# Safety settings from config
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
},
# Optional: Add default generation config (can be overridden per request)
# generation_config=genai.types.GenerationConfig(
# # candidate_count=1, # Usually default is 1
# # stop_sequences=["\n"],
# # max_output_tokens=2048,
# # temperature=0.9, # Controls randomness (0=deterministic, >1=more random)
# # top_p=1,
# # top_k=1
# )
# Generation config from settings
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
logger.info("Gemini AI client initialized successfully for model 'gemini-1.5-flash-latest'.")
logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
else:
# Store error if API key is missing
gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
@ -105,7 +102,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
# Prepare the full prompt content
prompt_parts = [
OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
image_part # Then the image
]
@ -153,3 +150,46 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a generic ValueError or re-raise
raise ValueError(f"Failed to process image with Gemini: {e}") from e
class GeminiOCRService:
def __init__(self):
try:
genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes) -> List[str]:
"""
Extract shopping list items from an image using Gemini Vision.
"""
try:
# Create image part
image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
# Generate content
response = await self.model.generate_content_async(
contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
)
# Process response
if not response.text:
raise OCRUnexpectedError()
# Split response into lines and clean up
items = [
item.strip()
for item in response.text.split("\n")
if item.strip() and not item.strip().startswith("Example")
]
return items
except Exception as e:
logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()

View File

@ -66,7 +66,34 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede
)
# Data to encode in the token payload
to_encode = {"exp": expire, "sub": str(subject)}
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Creates a JWT refresh token.
Args:
subject: The subject of the token (e.g., user ID or email).
expires_delta: Optional timedelta object for token expiry. If None,
uses REFRESH_TOKEN_EXPIRE_MINUTES from settings.
Returns:
The encoded JWT refresh token string.
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
)
# Data to encode in the token payload
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
@ -91,6 +118,8 @@ def verify_access_token(token: str) -> Optional[dict]:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
if payload.get("type") != "access":
raise JWTError("Invalid token type")
return payload
except JWTError as e:
# Handles InvalidSignatureError, ExpiredSignatureError, etc.
@ -101,6 +130,31 @@ def verify_access_token(token: str) -> Optional[dict]:
print(f"Unexpected error decoding JWT: {e}")
return None
def verify_refresh_token(token: str) -> Optional[dict]:
"""
Verifies a JWT refresh token and returns its payload if valid.
Args:
token: The JWT token string to verify.
Returns:
The decoded token payload (dict) if the token is valid, not expired,
and is a refresh token, otherwise None.
"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
if payload.get("type") != "refresh":
raise JWTError("Invalid token type")
return payload
except JWTError as e:
print(f"JWT Error: {e}") # Log the error for debugging
return None
except Exception as e:
print(f"Unexpected error decoding JWT: {e}")
return None
# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]:

116
be/app/crud/cost.py Normal file
View File

@ -0,0 +1,116 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict, Set
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
from app.schemas.cost import ListCostSummary, UserCostShare
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
"""
Calculates the cost summary for a given list, splitting costs equally
among relevant users (group members if list is in a group, or creator if personal).
"""
# 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members)
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user)))
)
.where(ListModel.id == list_id)
)
db_list: Optional[ListModel] = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
# 2. Determine participating users
participating_users: Dict[int, UserModel] = {}
if db_list.group:
# If list is part of a group, all group members participate
for ug_assoc in db_list.group.user_associations:
if ug_assoc.user: # Ensure user object is loaded
participating_users[ug_assoc.user.id] = ug_assoc.user
else:
# If personal list, only the creator participates (or if items were added by others somehow, include them)
# For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit.
# Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator)
creator_user = await db.get(UserModel, db_list.created_by_id)
if not creator_user:
# This case should ideally not happen if data integrity is maintained
raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error
participating_users[creator_user.id] = creator_user
# Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness)
for item in db_list.items:
if item.added_by_user and item.added_by_user.id not in participating_users:
participating_users[item.added_by_user.id] = item.added_by_user
num_participating_users = len(participating_users)
if num_participating_users == 0:
# Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent)
# Or if list has no items and is personal, creator might not be in participating_users if logic changes.
# For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary.
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# 3. Calculate total cost and items_added_value for each user
total_list_cost = Decimal("0.00")
user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()}
for item in db_list.items:
if item.price is not None and item.price > Decimal("0"):
total_list_cost += item.price
if item.added_by_id in user_items_added_value: # Item adder must be in participating users
user_items_added_value[item.added_by_id] += item.price
# If item.added_by_id is not in participating_users (e.g. user left group),
# their contribution still counts to total cost, but they aren't part of the split.
# The current logic adds item adders to participating_users, so this else is less likely.
# 4. Calculate equal share per user
# Using ROUND_HALF_UP to handle cents appropriately.
# Ensure division by zero is handled if num_participating_users could be 0 (already handled above)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00")
# 5. For each user, calculate their balance
user_balances: PyList[UserCostShare] = []
for user_id, user_obj in participating_users.items():
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
balance = items_added - equal_share_per_user
user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email
user_balances.append(
UserCostShare(
user_id=user_id,
user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=equal_share_per_user, # Already quantized
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
# Sort user_balances for consistent output, e.g., by user_id or identifier
user_balances.sort(key=lambda x: x.user_identifier)
# 6. Return the populated ListCostSummary schema
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user, # Already quantized
user_balances=user_balances
)

View File

@ -13,7 +13,8 @@ from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
DatabaseTransactionError,
ConflictError
)
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
@ -26,6 +27,7 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
list_id=list_id,
added_by_id=user_id,
is_complete=False # Default on creation
# version is implicitly set to 1 by model default
)
db.add(db_item)
await db.flush()
@ -65,44 +67,57 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
raise DatabaseQueryError(f"Failed to query item: {str(e)}")
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
"""Updates an existing item record."""
"""Updates an existing item record, checking for version conflicts."""
try:
async with db.begin():
update_data = item_in.model_dump(exclude_unset=True) # Get only provided fields
if item_db.version != item_in.version:
raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
# Special handling for is_complete
if 'is_complete' in update_data:
if update_data['is_complete'] is True:
# Mark as complete: set completed_by_id if not already set
if item_db.completed_by_id is None:
if item_db.completed_by_id is None: # Only set if not already completed by someone
update_data['completed_by_id'] = user_id
else:
# Mark as incomplete: clear completed_by_id
update_data['completed_by_id'] = None
# Ensure updated_at is refreshed (handled by onupdate in model, but explicit is fine too)
# update_data['updated_at'] = datetime.now(timezone.utc)
update_data['completed_by_id'] = None # Clear if marked incomplete
for key, value in update_data.items():
setattr(item_db, key, value)
db.add(item_db) # Add to session to track changes
item_db.version += 1 # Increment version
db.add(item_db)
await db.flush()
await db.refresh(item_db)
return item_db
except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to update item: {str(e)}")
await db.rollback()
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError
await db.rollback()
raise
except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
"""Deletes an item record."""
"""Deletes an item record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin():
await db.delete(item_db)
return None # Or return True/False
# await db.commit() # Not needed with async with db.begin()
return None
except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")

View File

@ -16,7 +16,8 @@ from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
DatabaseTransactionError,
ConflictError
)
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
@ -85,32 +86,50 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
raise DatabaseQueryError(f"Failed to query list: {str(e)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
"""Updates an existing list record."""
"""Updates an existing list record, checking for version conflicts."""
try:
async with db.begin():
update_data = list_in.model_dump(exclude_unset=True)
if list_db.version != list_in.version:
raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
)
update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
for key, value in update_data.items():
setattr(list_db, key, value)
list_db.version += 1
db.add(list_db)
await db.flush()
await db.refresh(list_db)
return list_db
except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to update list: {str(e)}")
await db.rollback()
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError:
await db.rollback()
raise
except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
"""Deletes a list record."""
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin():
await db.delete(list_db)
return None
except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:

View File

@ -5,22 +5,25 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.api_router import api_router # Import the main combined router
from app.config import settings
# Import database and models if needed for startup/shutdown events later
# from . import database, models
# --- Logging Setup ---
# Configure logging (can be more sophisticated later, e.g., using logging.yaml)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=getattr(logging, settings.LOG_LEVEL),
format=settings.LOG_FORMAT
)
logger = logging.getLogger(__name__)
# --- FastAPI App Instance ---
app = FastAPI(
title="Shared Lists API",
description="API for managing shared shopping lists, OCR, and cost splitting.",
version="0.1.0",
openapi_url="/api/openapi.json", # Place OpenAPI spec under /api
docs_url="/api/docs", # Place Swagger UI under /api
redoc_url="/api/redoc" # Place ReDoc under /api
title=settings.API_TITLE,
description=settings.API_DESCRIPTION,
version=settings.API_VERSION,
openapi_url=settings.API_OPENAPI_URL,
docs_url=settings.API_DOCS_URL,
redoc_url=settings.API_REDOC_URL
)
# --- CORS Middleware ---
@ -37,17 +40,17 @@ origins = [
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # List of origins that are allowed to make requests
allow_credentials=True, # Allow cookies to be included in requests
allow_methods=["*"], # Allow all methods (GET, POST, PUT, DELETE, etc.)
allow_headers=["*"], # Allow all headers
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# --- End CORS Middleware ---
# --- Include API Routers ---
# All API endpoints will be prefixed with /api
app.include_router(api_router, prefix="/api")
app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers ---
@ -59,10 +62,7 @@ async def read_root():
Useful for basic reachability checks.
"""
logger.info("Root endpoint '/' accessed.")
# You could redirect to the docs or return a simple message
# from fastapi.responses import RedirectResponse
# return RedirectResponse(url="/api/docs")
return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"}
return {"message": settings.ROOT_MESSAGE}
# --- End Root Endpoint ---

View File

@ -117,6 +117,7 @@ class List(Base):
is_complete = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships ---
creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
@ -138,6 +139,7 @@ class Item(Base):
completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships ---
list = relationship("List", back_populates="items") # Link to List.items

View File

@ -1,9 +1,11 @@
# app/schemas/auth.py
from pydantic import BaseModel, EmailStr
from app.config import settings
class Token(BaseModel):
access_token: str
token_type: str = "bearer" # Default token type
refresh_token: str # Added refresh token
token_type: str = settings.TOKEN_TYPE # Use configured token type
# Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel):

22
be/app/schemas/cost.py Normal file
View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal
class UserCostShare(BaseModel):
user_id: int
user_identifier: str # Name or email
items_added_value: Decimal = Decimal("0.00") # Total value of items this user added
amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users)
balance: Decimal # items_added_value - amount_due
model_config = ConfigDict(from_attributes=True)
class ListCostSummary(BaseModel):
list_id: int
list_name: str
total_list_cost: Decimal
num_participating_users: int
equal_share_per_user: Decimal
user_balances: List[UserCostShare]
model_config = ConfigDict(from_attributes=True)

View File

@ -1,9 +1,10 @@
# app/schemas/health.py
from pydantic import BaseModel
from app.config import settings
class HealthStatus(BaseModel):
"""
Response model for the health check endpoint.
"""
status: str = "ok" # Provide a default value
status: str = settings.HEALTH_STATUS_OK # Use configured default value
database: str

View File

@ -16,6 +16,7 @@ class ItemPublic(BaseModel):
completed_by_id: Optional[int] = None
created_at: datetime
updated_at: datetime
version: int
model_config = ConfigDict(from_attributes=True)
# Properties to receive via API on creation
@ -31,4 +32,5 @@ class ItemUpdate(BaseModel):
quantity: Optional[str] = None
is_complete: Optional[bool] = None
price: Optional[Decimal] = None # Price added here for update
version: int
# completed_by_id will be set internally if is_complete is true

View File

@ -16,6 +16,7 @@ class ListUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
is_complete: Optional[bool] = None
version: int # Client must provide the version for updates
# Potentially add group_id update later if needed
# Base properties returned by API (common fields)
@ -28,6 +29,7 @@ class ListBase(BaseModel):
is_complete: bool
created_at: datetime
updated_at: datetime
version: int # Include version in responses
model_config = ConfigDict(from_attributes=True)