add_version_to_lists_and_items
This commit is contained in:
parent
d2d484c327
commit
423d345fdf
@ -32,4 +32,4 @@ EXPOSE 8000
|
|||||||
# The default command for production (can be overridden in docker-compose for development)
|
# The default command for production (can be overridden in docker-compose for development)
|
||||||
# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
|
# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
|
||||||
# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
|
# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"]
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
@ -0,0 +1,28 @@
|
|||||||
|
"""add_version_to_lists_and_items
|
||||||
|
|
||||||
|
Revision ID: d53eedd151b7
|
||||||
|
Revises: d25788f63e2c
|
||||||
|
Create Date: 2025-05-07 21:05:50.396430
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'd53eedd151b7'
|
||||||
|
down_revision: Union[str, None] = 'd25788f63e2c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Upgrade schema."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
pass
|
@ -11,13 +11,14 @@ from app.database import get_db
|
|||||||
from app.core.security import verify_access_token
|
from app.core.security import verify_access_token
|
||||||
from app.crud import user as crud_user
|
from app.crud import user as crud_user
|
||||||
from app.models import User as UserModel # Import the SQLAlchemy model
|
from app.models import User as UserModel # Import the SQLAlchemy model
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Define the OAuth2 scheme
|
# Define the OAuth2 scheme
|
||||||
# tokenUrl should point to your login endpoint relative to the base path
|
# tokenUrl should point to your login endpoint relative to the base path
|
||||||
# It's used by Swagger UI for the "Authorize" button flow.
|
# It's used by Swagger UI for the "Authorize" button flow.
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") # Corrected path
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL)
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
@ -36,8 +37,8 @@ async def get_current_user(
|
|||||||
"""
|
"""
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail=settings.AUTH_CREDENTIALS_ERROR,
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX},
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = verify_access_token(token)
|
payload = verify_access_token(token)
|
||||||
|
@ -9,6 +9,7 @@ from app.api.v1.endpoints import invites
|
|||||||
from app.api.v1.endpoints import lists
|
from app.api.v1.endpoints import lists
|
||||||
from app.api.v1.endpoints import items
|
from app.api.v1.endpoints import items
|
||||||
from app.api.v1.endpoints import ocr
|
from app.api.v1.endpoints import ocr
|
||||||
|
from app.api.v1.endpoints import costs
|
||||||
|
|
||||||
api_router_v1 = APIRouter()
|
api_router_v1 = APIRouter()
|
||||||
|
|
||||||
@ -20,5 +21,6 @@ api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"]
|
|||||||
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
|
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
|
||||||
api_router_v1.include_router(items.router, tags=["Items"])
|
api_router_v1.include_router(items.router, tags=["Items"])
|
||||||
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
|
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
|
||||||
|
api_router_v1.include_router(costs.router, tags=["Costs"])
|
||||||
# Add other v1 endpoint routers here later
|
# Add other v1 endpoint routers here later
|
||||||
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
|
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
|
@ -1,6 +1,7 @@
|
|||||||
# app/api/v1/endpoints/auth.py
|
# app/api/v1/endpoints/auth.py
|
||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@ -8,12 +9,18 @@ from app.database import get_db
|
|||||||
from app.schemas.user import UserCreate, UserPublic
|
from app.schemas.user import UserCreate, UserPublic
|
||||||
from app.schemas.auth import Token
|
from app.schemas.auth import Token
|
||||||
from app.crud import user as crud_user
|
from app.crud import user as crud_user
|
||||||
from app.core.security import verify_password, create_access_token
|
from app.core.security import (
|
||||||
|
verify_password,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
verify_refresh_token
|
||||||
|
)
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
EmailAlreadyRegisteredError,
|
EmailAlreadyRegisteredError,
|
||||||
InvalidCredentialsError,
|
InvalidCredentialsError,
|
||||||
UserCreationError
|
UserCreationError
|
||||||
)
|
)
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -55,28 +62,74 @@ async def signup(
|
|||||||
"/login",
|
"/login",
|
||||||
response_model=Token,
|
response_model=Token,
|
||||||
summary="User Login",
|
summary="User Login",
|
||||||
description="Authenticates a user and returns an access token.",
|
description="Authenticates a user and returns an access and refresh token.",
|
||||||
tags=["Authentication"]
|
tags=["Authentication"]
|
||||||
)
|
)
|
||||||
async def login(
|
async def login(
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Handles user login.
|
Handles user login.
|
||||||
- Finds user by email (provided in 'username' field of form).
|
- Finds user by email (provided in 'username' field of form).
|
||||||
- Verifies the provided password against the stored hash.
|
- Verifies the provided password against the stored hash.
|
||||||
- Generates and returns a JWT access token upon successful authentication.
|
- Generates and returns JWT access and refresh tokens upon successful authentication.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Login attempt for user: {form_data.username}")
|
logger.info(f"Login attempt for user: {form_data.username}")
|
||||||
user = await crud_user.get_user_by_email(db, email=form_data.username)
|
user = await crud_user.get_user_by_email(db, email=form_data.username)
|
||||||
|
|
||||||
# Check if user exists and password is correct
|
|
||||||
if not user or not verify_password(form_data.password, user.password_hash):
|
if not user or not verify_password(form_data.password, user.password_hash):
|
||||||
logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
|
logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
|
||||||
raise InvalidCredentialsError()
|
raise InvalidCredentialsError()
|
||||||
|
|
||||||
# Generate JWT
|
|
||||||
access_token = create_access_token(subject=user.email)
|
access_token = create_access_token(subject=user.email)
|
||||||
logger.info(f"Login successful, token generated for user: {user.email}")
|
refresh_token = create_refresh_token(subject=user.email)
|
||||||
return Token(access_token=access_token, token_type="bearer")
|
logger.info(f"Login successful, tokens generated for user: {user.email}")
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_type=settings.TOKEN_TYPE
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/refresh",
|
||||||
|
response_model=Token,
|
||||||
|
summary="Refresh Access Token",
|
||||||
|
description="Refreshes an access token using a refresh token.",
|
||||||
|
tags=["Authentication"]
|
||||||
|
)
|
||||||
|
async def refresh_token(
|
||||||
|
refresh_token_str: str,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Handles access token refresh.
|
||||||
|
- Verifies the provided refresh token.
|
||||||
|
- If valid, generates and returns a new JWT access token and the same refresh token.
|
||||||
|
"""
|
||||||
|
logger.info("Access token refresh attempt")
|
||||||
|
payload = verify_refresh_token(refresh_token_str)
|
||||||
|
if not payload:
|
||||||
|
logger.warning("Refresh token invalid or expired")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired refresh token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
user_email = payload.get("sub")
|
||||||
|
if not user_email:
|
||||||
|
logger.error("User email not found in refresh token payload")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token payload",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_access_token = create_access_token(subject=user_email)
|
||||||
|
logger.info(f"Access token refreshed for user: {user_email}")
|
||||||
|
return Token(
|
||||||
|
access_token=new_access_token,
|
||||||
|
refresh_token=refresh_token_str,
|
||||||
|
token_type=settings.TOKEN_TYPE
|
||||||
|
)
|
69
be/app/api/v1/endpoints/costs.py
Normal file
69
be/app/api/v1/endpoints/costs.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# app/api/v1/endpoints/costs.py
|
||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.api.dependencies import get_current_user
|
||||||
|
from app.models import User as UserModel # For get_current_user dependency
|
||||||
|
from app.schemas.cost import ListCostSummary
|
||||||
|
from app.crud import cost as crud_cost
|
||||||
|
from app.crud import list as crud_list # For permission checking
|
||||||
|
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/lists/{list_id}/cost-summary",
|
||||||
|
response_model=ListCostSummary,
|
||||||
|
summary="Get Cost Summary for a List",
|
||||||
|
tags=["Costs"],
|
||||||
|
responses={
|
||||||
|
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
|
||||||
|
status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async def get_list_cost_summary(
|
||||||
|
list_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Retrieves a calculated cost summary for a specific list, detailing total costs,
|
||||||
|
equal shares per user, and individual user balances based on their contributions.
|
||||||
|
|
||||||
|
The user must have access to the list to view its cost summary.
|
||||||
|
Costs are split among group members if the list belongs to a group, or just for
|
||||||
|
the creator if it's a personal list. All users who added items with prices are
|
||||||
|
included in the calculation.
|
||||||
|
"""
|
||||||
|
logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")
|
||||||
|
|
||||||
|
# 1. Verify user has access to the target list
|
||||||
|
try:
|
||||||
|
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
||||||
|
except ListPermissionError as e:
|
||||||
|
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
|
||||||
|
raise # Re-raise the original exception to be handled by FastAPI
|
||||||
|
except ListNotFoundError as e:
|
||||||
|
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
|
||||||
|
raise # Re-raise
|
||||||
|
|
||||||
|
# 2. Calculate the cost summary
|
||||||
|
try:
|
||||||
|
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id)
|
||||||
|
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}")
|
||||||
|
return cost_summary
|
||||||
|
except ListNotFoundError as e:
|
||||||
|
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}")
|
||||||
|
# This might be redundant if check_list_permission already confirmed list existence,
|
||||||
|
# but calculate_list_cost_summary also fetches the list.
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||||
|
except UserNotFoundError as e:
|
||||||
|
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}")
|
||||||
|
# This indicates a data integrity issue (e.g., list creator or item adder missing)
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
|
@ -1,8 +1,8 @@
|
|||||||
# app/api/v1/endpoints/items.py
|
# app/api/v1/endpoints/items.py
|
||||||
import logging
|
import logging
|
||||||
from typing import List as PyList
|
from typing import List as PyList, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
@ -14,7 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it
|
|||||||
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
|
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
|
||||||
from app.crud import item as crud_item
|
from app.crud import item as crud_item
|
||||||
from app.crud import list as crud_list
|
from app.crud import list as crud_list
|
||||||
from app.core.exceptions import ItemNotFoundError, ListPermissionError
|
from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -100,7 +100,10 @@ async def read_list_items(
|
|||||||
"/items/{item_id}", # Operate directly on item ID
|
"/items/{item_id}", # Operate directly on item ID
|
||||||
response_model=ItemPublic,
|
response_model=ItemPublic,
|
||||||
summary="Update Item",
|
summary="Update Item",
|
||||||
tags=["Items"]
|
tags=["Items"],
|
||||||
|
responses={
|
||||||
|
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
async def update_item(
|
async def update_item(
|
||||||
item_id: int, # Item ID from path
|
item_id: int, # Item ID from path
|
||||||
@ -112,37 +115,61 @@ async def update_item(
|
|||||||
"""
|
"""
|
||||||
Updates an item's details (name, quantity, is_complete, price).
|
Updates an item's details (name, quantity, is_complete, price).
|
||||||
User must have access to the list the item belongs to.
|
User must have access to the list the item belongs to.
|
||||||
|
The client MUST provide the current `version` of the item in the `item_in` payload.
|
||||||
|
If the version does not match, a 409 Conflict is returned.
|
||||||
Sets/unsets `completed_by_id` based on `is_complete` flag.
|
Sets/unsets `completed_by_id` based on `is_complete` flag.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {current_user.email} attempting to update item ID: {item_id}")
|
logger.info(f"User {current_user.email} attempting to update item ID: {item_id} with version {item_in.version}")
|
||||||
# Permission check is handled by get_item_and_verify_access dependency
|
# Permission check is handled by get_item_and_verify_access dependency
|
||||||
|
|
||||||
|
try:
|
||||||
updated_item = await crud_item.update_item(
|
updated_item = await crud_item.update_item(
|
||||||
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
|
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
|
||||||
)
|
)
|
||||||
logger.info(f"Item {item_id} updated successfully by user {current_user.email}.")
|
logger.info(f"Item {item_id} updated successfully by user {current_user.email} to version {updated_item.version}.")
|
||||||
return updated_item
|
return updated_item
|
||||||
|
except ConflictError as e:
|
||||||
|
logger.warning(f"Conflict updating item {item_id} for user {current_user.email}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating item {item_id} for user {current_user.email}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/items/{item_id}", # Operate directly on item ID
|
"/items/{item_id}", # Operate directly on item ID
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
summary="Delete Item",
|
summary="Delete Item",
|
||||||
tags=["Items"]
|
tags=["Items"],
|
||||||
|
responses={
|
||||||
|
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
async def delete_item(
|
async def delete_item(
|
||||||
item_id: int, # Item ID from path
|
item_id: int, # Item ID from path
|
||||||
|
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
|
||||||
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
|
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: UserModel = Depends(get_current_user), # Log who deleted it
|
current_user: UserModel = Depends(get_current_user), # Log who deleted it
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Deletes an item. User must have access to the list the item belongs to.
|
Deletes an item. User must have access to the list the item belongs to.
|
||||||
(MVP: Any member with list access can delete items).
|
If `expected_version` is provided and does not match the item's current version,
|
||||||
|
a 409 Conflict is returned.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}")
|
logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
|
||||||
# Permission check is handled by get_item_and_verify_access dependency
|
# Permission check is handled by get_item_and_verify_access dependency
|
||||||
|
|
||||||
|
if expected_version is not None and item_db.version != expected_version:
|
||||||
|
logger.warning(
|
||||||
|
f"Conflict deleting item {item_id} for user {current_user.email}. "
|
||||||
|
f"Expected version {expected_version}, actual version {item_db.version}."
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
|
||||||
|
)
|
||||||
|
|
||||||
await crud_item.delete_item(db=db, item_db=item_db)
|
await crud_item.delete_item(db=db, item_db=item_db)
|
||||||
logger.info(f"Item {item_id} deleted successfully by user {current_user.email}.")
|
logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {current_user.email}.")
|
||||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
@ -1,8 +1,8 @@
|
|||||||
# app/api/v1/endpoints/lists.py
|
# app/api/v1/endpoints/lists.py
|
||||||
import logging
|
import logging
|
||||||
from typing import List as PyList # Alias for Python List type hint
|
from typing import List as PyList, Optional # Alias for Python List type hint
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
@ -17,7 +17,8 @@ from app.core.exceptions import (
|
|||||||
GroupMembershipError,
|
GroupMembershipError,
|
||||||
ListNotFoundError,
|
ListNotFoundError,
|
||||||
ListPermissionError,
|
ListPermissionError,
|
||||||
ListStatusNotFoundError
|
ListStatusNotFoundError,
|
||||||
|
ConflictError # Added ConflictError
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -101,7 +102,10 @@ async def read_list(
|
|||||||
"/{list_id}",
|
"/{list_id}",
|
||||||
response_model=ListPublic, # Return updated basic info
|
response_model=ListPublic, # Return updated basic info
|
||||||
summary="Update List",
|
summary="Update List",
|
||||||
tags=["Lists"]
|
tags=["Lists"],
|
||||||
|
responses={ # Add 409 to responses
|
||||||
|
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
async def update_list(
|
async def update_list(
|
||||||
list_id: int,
|
list_id: int,
|
||||||
@ -112,40 +116,62 @@ async def update_list(
|
|||||||
"""
|
"""
|
||||||
Updates a list's details (name, description, is_complete).
|
Updates a list's details (name, description, is_complete).
|
||||||
Requires user to be the creator or a member of the list's group.
|
Requires user to be the creator or a member of the list's group.
|
||||||
(MVP: Allows any member to update these fields).
|
The client MUST provide the current `version` of the list in the `list_in` payload.
|
||||||
|
If the version does not match, a 409 Conflict is returned.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {current_user.email} attempting to update list ID: {list_id}")
|
logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
|
||||||
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
||||||
|
|
||||||
# Prevent changing group_id or creator via this endpoint for simplicity
|
try:
|
||||||
# if list_in.group_id is not None or list_in.created_by_id is not None:
|
|
||||||
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint")
|
|
||||||
|
|
||||||
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
|
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
|
||||||
logger.info(f"List {list_id} updated successfully by user {current_user.email}.")
|
logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
|
||||||
return updated_list
|
return updated_list
|
||||||
|
except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response
|
||||||
|
logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
|
||||||
|
except Exception as e: # Catch other potential errors from crud operation
|
||||||
|
logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}")
|
||||||
|
# Consider a more generic error, but for now, let's keep it specific if possible
|
||||||
|
# Re-raising might be better if crud layer already raises appropriate HTTPExceptions
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.")
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"/{list_id}",
|
"/{list_id}",
|
||||||
status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
|
status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
|
||||||
summary="Delete List",
|
summary="Delete List",
|
||||||
tags=["Lists"]
|
tags=["Lists"],
|
||||||
|
responses={ # Add 409 to responses
|
||||||
|
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
async def delete_list(
|
async def delete_list(
|
||||||
list_id: int,
|
list_id: int,
|
||||||
|
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: UserModel = Depends(get_current_user),
|
current_user: UserModel = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Deletes a list. Requires user to be the creator of the list.
|
Deletes a list. Requires user to be the creator of the list.
|
||||||
(Alternatively, could allow group owner).
|
If `expected_version` is provided and does not match the list's current version,
|
||||||
|
a 409 Conflict is returned.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}")
|
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
|
||||||
# Use the helper, requiring creator permission
|
# Use the helper, requiring creator permission
|
||||||
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
|
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
|
||||||
|
|
||||||
|
if expected_version is not None and list_db.version != expected_version:
|
||||||
|
logger.warning(
|
||||||
|
f"Conflict deleting list {list_id} for user {current_user.email}. "
|
||||||
|
f"Expected version {expected_version}, actual version {list_db.version}."
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
|
||||||
|
)
|
||||||
|
|
||||||
await crud_list.delete_list(db=db, list_db=list_db)
|
await crud_list.delete_list(db=db, list_db=list_db)
|
||||||
logger.info(f"List {list_id} deleted successfully by user {current_user.email}.")
|
logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
|
||||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,28 +1,27 @@
|
|||||||
# app/api/v1/endpoints/ocr.py
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, UploadFile, File
|
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
|
||||||
from google.api_core import exceptions as google_exceptions
|
from google.api_core import exceptions as google_exceptions
|
||||||
|
|
||||||
from app.api.dependencies import get_current_user
|
from app.api.dependencies import get_current_user
|
||||||
from app.models import User as UserModel
|
from app.models import User as UserModel
|
||||||
from app.schemas.ocr import OcrExtractResponse
|
from app.schemas.ocr import OcrExtractResponse
|
||||||
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error
|
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
OcrServiceUnavailableError,
|
OCRServiceUnavailableError,
|
||||||
|
OCRServiceConfigError,
|
||||||
|
OCRUnexpectedError,
|
||||||
|
OCRQuotaExceededError,
|
||||||
InvalidFileTypeError,
|
InvalidFileTypeError,
|
||||||
FileTooLargeError,
|
FileTooLargeError,
|
||||||
OcrProcessingError,
|
OCRProcessingError
|
||||||
OcrQuotaExceededError
|
|
||||||
)
|
)
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
ocr_service = GeminiOCRService()
|
||||||
# Allowed image MIME types
|
|
||||||
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
|
|
||||||
MAX_FILE_SIZE_MB = 10
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/extract-items",
|
"/extract-items",
|
||||||
@ -41,20 +40,20 @@ async def ocr_extract_items(
|
|||||||
# Check if Gemini client initialized correctly
|
# Check if Gemini client initialized correctly
|
||||||
if gemini_initialization_error:
|
if gemini_initialization_error:
|
||||||
logger.error("OCR endpoint called but Gemini client failed to initialize.")
|
logger.error("OCR endpoint called but Gemini client failed to initialize.")
|
||||||
raise OcrServiceUnavailableError(gemini_initialization_error)
|
raise OCRServiceUnavailableError(gemini_initialization_error)
|
||||||
|
|
||||||
logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
|
logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
|
||||||
|
|
||||||
# --- File Validation ---
|
# --- File Validation ---
|
||||||
if image_file.content_type not in ALLOWED_IMAGE_TYPES:
|
if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
|
||||||
logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
|
logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
|
||||||
raise InvalidFileTypeError(ALLOWED_IMAGE_TYPES)
|
raise InvalidFileTypeError()
|
||||||
|
|
||||||
# Simple size check
|
# Simple size check
|
||||||
contents = await image_file.read()
|
contents = await image_file.read()
|
||||||
if len(contents) > MAX_FILE_SIZE_MB * 1024 * 1024:
|
if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
|
||||||
logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
|
logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
|
||||||
raise FileTooLargeError(MAX_FILE_SIZE_MB)
|
raise FileTooLargeError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Call the Gemini helper function
|
# Call the Gemini helper function
|
||||||
@ -66,30 +65,14 @@ async def ocr_extract_items(
|
|||||||
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
|
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
|
||||||
return OcrExtractResponse(extracted_items=extracted_items)
|
return OcrExtractResponse(extracted_items=extracted_items)
|
||||||
|
|
||||||
except ValueError as e:
|
except OCRServiceUnavailableError:
|
||||||
# Handle errors from Gemini processing (blocked, empty response, etc.)
|
raise OCRServiceUnavailableError()
|
||||||
logger.warning(f"Gemini processing error for user {current_user.email}: {e}")
|
except OCRServiceConfigError:
|
||||||
raise OcrProcessingError(str(e))
|
raise OCRServiceConfigError()
|
||||||
|
except OCRQuotaExceededError:
|
||||||
except google_exceptions.ResourceExhausted as e:
|
raise OCRQuotaExceededError()
|
||||||
# Specific handling for quota errors
|
|
||||||
logger.error(f"Gemini Quota Exceeded for user {current_user.email}: {e}", exc_info=True)
|
|
||||||
raise OcrQuotaExceededError()
|
|
||||||
|
|
||||||
except google_exceptions.GoogleAPIError as e:
|
|
||||||
# Handle other Google API errors (e.g., invalid key, permissions)
|
|
||||||
logger.error(f"Gemini API Error for user {current_user.email}: {e}", exc_info=True)
|
|
||||||
raise OcrServiceUnavailableError(str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
# Catch initialization errors from get_gemini_client()
|
|
||||||
logger.error(f"Gemini client runtime error during OCR request: {e}")
|
|
||||||
raise OcrServiceUnavailableError(f"OCR service configuration error: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch any other unexpected errors
|
raise OCRProcessingError(str(e))
|
||||||
logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}")
|
|
||||||
raise OcrServiceUnavailableError("An unexpected error occurred during item extraction.")
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Ensure file handle is closed
|
# Ensure file handle is closed
|
||||||
|
@ -16,6 +16,86 @@ class Settings(BaseSettings):
|
|||||||
SECRET_KEY: str # Must be set via environment variable
|
SECRET_KEY: str # Must be set via environment variable
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
|
||||||
|
REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days
|
||||||
|
|
||||||
|
# --- OCR Settings ---
|
||||||
|
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
|
||||||
|
ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
|
||||||
|
OCR_ITEM_EXTRACTION_PROMPT: str = """
|
||||||
|
Extract the shopping list items from this image.
|
||||||
|
List each distinct item on a new line.
|
||||||
|
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
|
||||||
|
Focus only on the names of the products or items to be purchased.
|
||||||
|
Add 2 underscores before and after the item name, if it is struck through.
|
||||||
|
If the image does not appear to be a shopping list or receipt, state that clearly.
|
||||||
|
Example output for a grocery list:
|
||||||
|
Milk
|
||||||
|
Eggs
|
||||||
|
Bread
|
||||||
|
__Apples__
|
||||||
|
Organic Bananas
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- Gemini AI Settings ---
|
||||||
|
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
|
||||||
|
GEMINI_SAFETY_SETTINGS: dict = {
|
||||||
|
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
|
}
|
||||||
|
GEMINI_GENERATION_CONFIG: dict = {
|
||||||
|
"candidate_count": 1,
|
||||||
|
"max_output_tokens": 2048,
|
||||||
|
"temperature": 0.9,
|
||||||
|
"top_p": 1,
|
||||||
|
"top_k": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- API Settings ---
|
||||||
|
API_PREFIX: str = "/api" # Base path for all API endpoints
|
||||||
|
API_OPENAPI_URL: str = "/api/openapi.json"
|
||||||
|
API_DOCS_URL: str = "/api/docs"
|
||||||
|
API_REDOC_URL: str = "/api/redoc"
|
||||||
|
CORS_ORIGINS: list[str] = [
|
||||||
|
"http://localhost:5174",
|
||||||
|
"http://localhost:8000",
|
||||||
|
# Add your deployed frontend URL here later
|
||||||
|
# "https://your-frontend-domain.com",
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- API Metadata ---
|
||||||
|
API_TITLE: str = "Shared Lists API"
|
||||||
|
API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
|
||||||
|
API_VERSION: str = "0.1.0"
|
||||||
|
ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
|
||||||
|
|
||||||
|
# --- Logging Settings ---
|
||||||
|
LOG_LEVEL: str = "INFO"
|
||||||
|
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
|
||||||
|
# --- Health Check Settings ---
|
||||||
|
HEALTH_STATUS_OK: str = "ok"
|
||||||
|
HEALTH_STATUS_ERROR: str = "error"
|
||||||
|
|
||||||
|
# --- Auth Settings ---
|
||||||
|
OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint
|
||||||
|
TOKEN_TYPE: str = "bearer" # Default token type for OAuth2
|
||||||
|
AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header
|
||||||
|
AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header
|
||||||
|
AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials"
|
||||||
|
AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password"
|
||||||
|
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
|
||||||
|
|
||||||
|
# --- HTTP Status Messages ---
|
||||||
|
HTTP_400_DETAIL: str = "Bad Request"
|
||||||
|
HTTP_401_DETAIL: str = "Unauthorized"
|
||||||
|
HTTP_403_DETAIL: str = "Forbidden"
|
||||||
|
HTTP_404_DETAIL: str = "Not Found"
|
||||||
|
HTTP_422_DETAIL: str = "Unprocessable Entity"
|
||||||
|
HTTP_429_DETAIL: str = "Too Many Requests"
|
||||||
|
HTTP_500_DETAIL: str = "Internal Server Error"
|
||||||
|
HTTP_503_DETAIL: str = "Service Unavailable"
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from app.config import settings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
class ListNotFoundError(HTTPException):
|
class ListNotFoundError(HTTPException):
|
||||||
"""Raised when a list is not found."""
|
"""Raised when a list is not found."""
|
||||||
@ -72,76 +75,105 @@ class ItemNotFoundError(HTTPException):
|
|||||||
detail=f"Item {item_id} not found"
|
detail=f"Item {item_id} not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class UserNotFoundError(HTTPException):
|
||||||
|
"""Raised when a user is not found."""
|
||||||
|
def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None):
|
||||||
|
detail_msg = "User not found."
|
||||||
|
if user_id:
|
||||||
|
detail_msg = f"User with ID {user_id} not found."
|
||||||
|
elif identifier:
|
||||||
|
detail_msg = f"User with identifier '{identifier}' not found."
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=detail_msg
|
||||||
|
)
|
||||||
|
|
||||||
class DatabaseConnectionError(HTTPException):
|
class DatabaseConnectionError(HTTPException):
|
||||||
"""Raised when there is an error connecting to the database."""
|
"""Raised when there is an error connecting to the database."""
|
||||||
def __init__(self, detail: str = "Database connection error"):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail=detail
|
detail=settings.DB_CONNECTION_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class DatabaseIntegrityError(HTTPException):
|
class DatabaseIntegrityError(HTTPException):
|
||||||
"""Raised when a database integrity constraint is violated."""
|
"""Raised when a database integrity constraint is violated."""
|
||||||
def __init__(self, detail: str = "Database integrity error"):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=detail
|
detail=settings.DB_INTEGRITY_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class DatabaseTransactionError(HTTPException):
|
class DatabaseTransactionError(HTTPException):
|
||||||
"""Raised when a database transaction fails."""
|
"""Raised when a database transaction fails."""
|
||||||
def __init__(self, detail: str = "Database transaction error"):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=detail
|
detail=settings.DB_TRANSACTION_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class DatabaseQueryError(HTTPException):
|
class DatabaseQueryError(HTTPException):
|
||||||
"""Raised when a database query fails."""
|
"""Raised when a database query fails."""
|
||||||
def __init__(self, detail: str = "Database query error"):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=detail
|
detail=settings.DB_QUERY_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class OcrServiceUnavailableError(HTTPException):
|
class OCRServiceUnavailableError(HTTPException):
|
||||||
"""Raised when the OCR service is unavailable."""
|
"""Raised when the OCR service is unavailable."""
|
||||||
def __init__(self, detail: str):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail=f"OCR service unavailable: {detail}"
|
detail=settings.OCR_SERVICE_UNAVAILABLE
|
||||||
)
|
)
|
||||||
|
|
||||||
class InvalidFileTypeError(HTTPException):
|
class OCRServiceConfigError(HTTPException):
|
||||||
"""Raised when an invalid file type is uploaded for OCR."""
|
"""Raised when there is an error in the OCR service configuration."""
|
||||||
def __init__(self, allowed_types: list[str]):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Invalid file type. Allowed types: {', '.join(allowed_types)}"
|
detail=settings.OCR_SERVICE_CONFIG_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class FileTooLargeError(HTTPException):
|
class OCRUnexpectedError(HTTPException):
|
||||||
"""Raised when an uploaded file exceeds the size limit."""
|
"""Raised when there is an unexpected error in the OCR service."""
|
||||||
def __init__(self, max_size_mb: int):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"File size exceeds limit of {max_size_mb} MB."
|
detail=settings.OCR_UNEXPECTED_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
class OcrProcessingError(HTTPException):
|
class OCRQuotaExceededError(HTTPException):
|
||||||
"""Raised when there is an error processing the image with OCR."""
|
|
||||||
def __init__(self, detail: str):
|
|
||||||
super().__init__(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
detail=f"Could not extract items from image: {detail}"
|
|
||||||
)
|
|
||||||
|
|
||||||
class OcrQuotaExceededError(HTTPException):
|
|
||||||
"""Raised when the OCR service quota is exceeded."""
|
"""Raised when the OCR service quota is exceeded."""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
detail="OCR service quota exceeded. Please try again later."
|
detail=settings.OCR_QUOTA_EXCEEDED
|
||||||
|
)
|
||||||
|
|
||||||
|
class InvalidFileTypeError(HTTPException):
|
||||||
|
"""Raised when an invalid file type is uploaded for OCR."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
|
||||||
|
)
|
||||||
|
|
||||||
|
class FileTooLargeError(HTTPException):
|
||||||
|
"""Raised when an uploaded file exceeds the size limit."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
|
||||||
|
)
|
||||||
|
|
||||||
|
class OCRProcessingError(HTTPException):
|
||||||
|
"""Raised when there is an error processing the image with OCR."""
|
||||||
|
def __init__(self, detail: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=settings.OCR_PROCESSING_ERROR.format(detail=detail)
|
||||||
)
|
)
|
||||||
|
|
||||||
class EmailAlreadyRegisteredError(HTTPException):
|
class EmailAlreadyRegisteredError(HTTPException):
|
||||||
@ -152,15 +184,6 @@ class EmailAlreadyRegisteredError(HTTPException):
|
|||||||
detail="Email already registered."
|
detail="Email already registered."
|
||||||
)
|
)
|
||||||
|
|
||||||
class InvalidCredentialsError(HTTPException):
|
|
||||||
"""Raised when login credentials are invalid."""
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Incorrect email or password",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"}
|
|
||||||
)
|
|
||||||
|
|
||||||
class UserCreationError(HTTPException):
|
class UserCreationError(HTTPException):
|
||||||
"""Raised when there is an error creating a new user."""
|
"""Raised when there is an error creating a new user."""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -208,3 +231,47 @@ class ListStatusNotFoundError(HTTPException):
|
|||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=f"Status for list {list_id} not found"
|
detail=f"Status for list {list_id} not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ConflictError(HTTPException):
|
||||||
|
"""Raised when an optimistic lock version conflict occurs."""
|
||||||
|
def __init__(self, detail: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=detail
|
||||||
|
)
|
||||||
|
|
||||||
|
class InvalidCredentialsError(HTTPException):
|
||||||
|
"""Raised when login credentials are invalid."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=settings.AUTH_INVALID_CREDENTIALS,
|
||||||
|
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""}
|
||||||
|
)
|
||||||
|
|
||||||
|
class NotAuthenticatedError(HTTPException):
|
||||||
|
"""Raised when the user is not authenticated."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=settings.AUTH_NOT_AUTHENTICATED,
|
||||||
|
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""}
|
||||||
|
)
|
||||||
|
|
||||||
|
class JWTError(HTTPException):
|
||||||
|
"""Raised when there is an error with the JWT token."""
|
||||||
|
def __init__(self, error: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=settings.JWT_ERROR.format(error=error),
|
||||||
|
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
|
||||||
|
)
|
||||||
|
|
||||||
|
class JWTUnexpectedError(HTTPException):
|
||||||
|
"""Raised when there is an unexpected error with the JWT token."""
|
||||||
|
def __init__(self, error: str):
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
|
||||||
|
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
|
||||||
|
)
|
@ -4,8 +4,13 @@ from typing import List
|
|||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
|
||||||
from google.api_core import exceptions as google_exceptions
|
from google.api_core import exceptions as google_exceptions
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.core.exceptions import (
|
||||||
|
OCRServiceUnavailableError,
|
||||||
|
OCRServiceConfigError,
|
||||||
|
OCRUnexpectedError,
|
||||||
|
OCRQuotaExceededError
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -19,26 +24,18 @@ try:
|
|||||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||||
# Initialize the specific model we want to use
|
# Initialize the specific model we want to use
|
||||||
gemini_flash_client = genai.GenerativeModel(
|
gemini_flash_client = genai.GenerativeModel(
|
||||||
model_name="gemini-2.0-flash",
|
model_name=settings.GEMINI_MODEL_NAME,
|
||||||
# Optional: Add default safety settings
|
# Safety settings from config
|
||||||
# Adjust these based on your expected content and risk tolerance
|
|
||||||
safety_settings={
|
safety_settings={
|
||||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
|
||||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
|
||||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
||||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
||||||
},
|
},
|
||||||
# Optional: Add default generation config (can be overridden per request)
|
# Generation config from settings
|
||||||
# generation_config=genai.types.GenerationConfig(
|
generation_config=genai.types.GenerationConfig(
|
||||||
# # candidate_count=1, # Usually default is 1
|
**settings.GEMINI_GENERATION_CONFIG
|
||||||
# # stop_sequences=["\n"],
|
|
||||||
# # max_output_tokens=2048,
|
|
||||||
# # temperature=0.9, # Controls randomness (0=deterministic, >1=more random)
|
|
||||||
# # top_p=1,
|
|
||||||
# # top_k=1
|
|
||||||
# )
|
|
||||||
)
|
)
|
||||||
logger.info("Gemini AI client initialized successfully for model 'gemini-1.5-flash-latest'.")
|
)
|
||||||
|
logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
|
||||||
else:
|
else:
|
||||||
# Store error if API key is missing
|
# Store error if API key is missing
|
||||||
gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
|
gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
|
||||||
@ -105,7 +102,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
|
|||||||
|
|
||||||
# Prepare the full prompt content
|
# Prepare the full prompt content
|
||||||
prompt_parts = [
|
prompt_parts = [
|
||||||
OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
|
settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
|
||||||
image_part # Then the image
|
image_part # Then the image
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -153,3 +150,46 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
|
|||||||
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
|
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
|
||||||
# Wrap in a generic ValueError or re-raise
|
# Wrap in a generic ValueError or re-raise
|
||||||
raise ValueError(f"Failed to process image with Gemini: {e}") from e
|
raise ValueError(f"Failed to process image with Gemini: {e}") from e
|
||||||
|
|
||||||
|
class GeminiOCRService:
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||||
|
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
|
||||||
|
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
|
||||||
|
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize Gemini client: {e}")
|
||||||
|
raise OCRServiceConfigError()
|
||||||
|
|
||||||
|
async def extract_items(self, image_data: bytes) -> List[str]:
|
||||||
|
"""
|
||||||
|
Extract shopping list items from an image using Gemini Vision.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create image part
|
||||||
|
image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
|
||||||
|
|
||||||
|
# Generate content
|
||||||
|
response = await self.model.generate_content_async(
|
||||||
|
contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process response
|
||||||
|
if not response.text:
|
||||||
|
raise OCRUnexpectedError()
|
||||||
|
|
||||||
|
# Split response into lines and clean up
|
||||||
|
items = [
|
||||||
|
item.strip()
|
||||||
|
for item in response.text.split("\n")
|
||||||
|
if item.strip() and not item.strip().startswith("Example")
|
||||||
|
]
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during OCR extraction: {e}")
|
||||||
|
if "quota" in str(e).lower():
|
||||||
|
raise OCRQuotaExceededError()
|
||||||
|
raise OCRServiceUnavailableError()
|
@ -66,7 +66,34 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Data to encode in the token payload
|
# Data to encode in the token payload
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
|
||||||
|
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||||
|
)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
"""
|
||||||
|
Creates a JWT refresh token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subject: The subject of the token (e.g., user ID or email).
|
||||||
|
expires_delta: Optional timedelta object for token expiry. If None,
|
||||||
|
uses REFRESH_TOKEN_EXPIRE_MINUTES from settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The encoded JWT refresh token string.
|
||||||
|
"""
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(
|
||||||
|
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data to encode in the token payload
|
||||||
|
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(
|
encoded_jwt = jwt.encode(
|
||||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||||
@ -91,6 +118,8 @@ def verify_access_token(token: str) -> Optional[dict]:
|
|||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
)
|
)
|
||||||
|
if payload.get("type") != "access":
|
||||||
|
raise JWTError("Invalid token type")
|
||||||
return payload
|
return payload
|
||||||
except JWTError as e:
|
except JWTError as e:
|
||||||
# Handles InvalidSignatureError, ExpiredSignatureError, etc.
|
# Handles InvalidSignatureError, ExpiredSignatureError, etc.
|
||||||
@ -101,6 +130,31 @@ def verify_access_token(token: str) -> Optional[dict]:
|
|||||||
print(f"Unexpected error decoding JWT: {e}")
|
print(f"Unexpected error decoding JWT: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def verify_refresh_token(token: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Verifies a JWT refresh token and returns its payload if valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The JWT token string to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded token payload (dict) if the token is valid, not expired,
|
||||||
|
and is a refresh token, otherwise None.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
if payload.get("type") != "refresh":
|
||||||
|
raise JWTError("Invalid token type")
|
||||||
|
return payload
|
||||||
|
except JWTError as e:
|
||||||
|
print(f"JWT Error: {e}") # Log the error for debugging
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error decoding JWT: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
# You might add a function here later to extract the 'sub' (subject/user id)
|
# You might add a function here later to extract the 'sub' (subject/user id)
|
||||||
# specifically, often used in dependency injection for authentication.
|
# specifically, often used in dependency injection for authentication.
|
||||||
# def get_subject_from_token(token: str) -> Optional[str]:
|
# def get_subject_from_token(token: str) -> Optional[str]:
|
||||||
|
116
be/app/crud/cost.py
Normal file
116
be/app/crud/cost.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm import selectinload, joinedload
|
||||||
|
from decimal import Decimal, ROUND_HALF_UP
|
||||||
|
from typing import List as PyList, Dict, Set
|
||||||
|
|
||||||
|
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
|
||||||
|
from app.schemas.cost import ListCostSummary, UserCostShare
|
||||||
|
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
|
||||||
|
|
||||||
|
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
|
||||||
|
"""
|
||||||
|
Calculates the cost summary for a given list, splitting costs equally
|
||||||
|
among relevant users (group members if list is in a group, or creator if personal).
|
||||||
|
"""
|
||||||
|
# 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members)
|
||||||
|
list_result = await db.execute(
|
||||||
|
select(ListModel)
|
||||||
|
.options(
|
||||||
|
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
|
||||||
|
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user)))
|
||||||
|
)
|
||||||
|
.where(ListModel.id == list_id)
|
||||||
|
)
|
||||||
|
db_list: Optional[ListModel] = list_result.scalars().first()
|
||||||
|
|
||||||
|
if not db_list:
|
||||||
|
raise ListNotFoundError(list_id)
|
||||||
|
|
||||||
|
# 2. Determine participating users
|
||||||
|
participating_users: Dict[int, UserModel] = {}
|
||||||
|
if db_list.group:
|
||||||
|
# If list is part of a group, all group members participate
|
||||||
|
for ug_assoc in db_list.group.user_associations:
|
||||||
|
if ug_assoc.user: # Ensure user object is loaded
|
||||||
|
participating_users[ug_assoc.user.id] = ug_assoc.user
|
||||||
|
else:
|
||||||
|
# If personal list, only the creator participates (or if items were added by others somehow, include them)
|
||||||
|
# For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit.
|
||||||
|
# Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator)
|
||||||
|
creator_user = await db.get(UserModel, db_list.created_by_id)
|
||||||
|
if not creator_user:
|
||||||
|
# This case should ideally not happen if data integrity is maintained
|
||||||
|
raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error
|
||||||
|
participating_users[creator_user.id] = creator_user
|
||||||
|
|
||||||
|
# Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness)
|
||||||
|
for item in db_list.items:
|
||||||
|
if item.added_by_user and item.added_by_user.id not in participating_users:
|
||||||
|
participating_users[item.added_by_user.id] = item.added_by_user
|
||||||
|
|
||||||
|
|
||||||
|
num_participating_users = len(participating_users)
|
||||||
|
if num_participating_users == 0:
|
||||||
|
# Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent)
|
||||||
|
# Or if list has no items and is personal, creator might not be in participating_users if logic changes.
|
||||||
|
# For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary.
|
||||||
|
return ListCostSummary(
|
||||||
|
list_id=db_list.id,
|
||||||
|
list_name=db_list.name,
|
||||||
|
total_list_cost=Decimal("0.00"),
|
||||||
|
num_participating_users=0,
|
||||||
|
equal_share_per_user=Decimal("0.00"),
|
||||||
|
user_balances=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 3. Calculate total cost and items_added_value for each user
|
||||||
|
total_list_cost = Decimal("0.00")
|
||||||
|
user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()}
|
||||||
|
|
||||||
|
for item in db_list.items:
|
||||||
|
if item.price is not None and item.price > Decimal("0"):
|
||||||
|
total_list_cost += item.price
|
||||||
|
if item.added_by_id in user_items_added_value: # Item adder must be in participating users
|
||||||
|
user_items_added_value[item.added_by_id] += item.price
|
||||||
|
# If item.added_by_id is not in participating_users (e.g. user left group),
|
||||||
|
# their contribution still counts to total cost, but they aren't part of the split.
|
||||||
|
# The current logic adds item adders to participating_users, so this else is less likely.
|
||||||
|
|
||||||
|
# 4. Calculate equal share per user
|
||||||
|
# Using ROUND_HALF_UP to handle cents appropriately.
|
||||||
|
# Ensure division by zero is handled if num_participating_users could be 0 (already handled above)
|
||||||
|
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00")
|
||||||
|
|
||||||
|
# 5. For each user, calculate their balance
|
||||||
|
user_balances: PyList[UserCostShare] = []
|
||||||
|
for user_id, user_obj in participating_users.items():
|
||||||
|
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
|
||||||
|
balance = items_added - equal_share_per_user
|
||||||
|
|
||||||
|
user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email
|
||||||
|
|
||||||
|
user_balances.append(
|
||||||
|
UserCostShare(
|
||||||
|
user_id=user_id,
|
||||||
|
user_identifier=user_identifier,
|
||||||
|
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||||
|
amount_due=equal_share_per_user, # Already quantized
|
||||||
|
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort user_balances for consistent output, e.g., by user_id or identifier
|
||||||
|
user_balances.sort(key=lambda x: x.user_identifier)
|
||||||
|
|
||||||
|
|
||||||
|
# 6. Return the populated ListCostSummary schema
|
||||||
|
return ListCostSummary(
|
||||||
|
list_id=db_list.id,
|
||||||
|
list_name=db_list.name,
|
||||||
|
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||||
|
num_participating_users=num_participating_users,
|
||||||
|
equal_share_per_user=equal_share_per_user, # Already quantized
|
||||||
|
user_balances=user_balances
|
||||||
|
)
|
@ -13,7 +13,8 @@ from app.core.exceptions import (
|
|||||||
DatabaseConnectionError,
|
DatabaseConnectionError,
|
||||||
DatabaseIntegrityError,
|
DatabaseIntegrityError,
|
||||||
DatabaseQueryError,
|
DatabaseQueryError,
|
||||||
DatabaseTransactionError
|
DatabaseTransactionError,
|
||||||
|
ConflictError
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
|
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
|
||||||
@ -26,6 +27,7 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
|
|||||||
list_id=list_id,
|
list_id=list_id,
|
||||||
added_by_id=user_id,
|
added_by_id=user_id,
|
||||||
is_complete=False # Default on creation
|
is_complete=False # Default on creation
|
||||||
|
# version is implicitly set to 1 by model default
|
||||||
)
|
)
|
||||||
db.add(db_item)
|
db.add(db_item)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
@ -65,44 +67,57 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
|
|||||||
raise DatabaseQueryError(f"Failed to query item: {str(e)}")
|
raise DatabaseQueryError(f"Failed to query item: {str(e)}")
|
||||||
|
|
||||||
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
|
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
|
||||||
"""Updates an existing item record."""
|
"""Updates an existing item record, checking for version conflicts."""
|
||||||
try:
|
try:
|
||||||
async with db.begin():
|
async with db.begin():
|
||||||
update_data = item_in.model_dump(exclude_unset=True) # Get only provided fields
|
if item_db.version != item_in.version:
|
||||||
|
raise ConflictError(
|
||||||
|
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
|
||||||
|
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
|
||||||
|
|
||||||
# Special handling for is_complete
|
# Special handling for is_complete
|
||||||
if 'is_complete' in update_data:
|
if 'is_complete' in update_data:
|
||||||
if update_data['is_complete'] is True:
|
if update_data['is_complete'] is True:
|
||||||
# Mark as complete: set completed_by_id if not already set
|
if item_db.completed_by_id is None: # Only set if not already completed by someone
|
||||||
if item_db.completed_by_id is None:
|
|
||||||
update_data['completed_by_id'] = user_id
|
update_data['completed_by_id'] = user_id
|
||||||
else:
|
else:
|
||||||
# Mark as incomplete: clear completed_by_id
|
update_data['completed_by_id'] = None # Clear if marked incomplete
|
||||||
update_data['completed_by_id'] = None
|
|
||||||
# Ensure updated_at is refreshed (handled by onupdate in model, but explicit is fine too)
|
|
||||||
# update_data['updated_at'] = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(item_db, key, value)
|
setattr(item_db, key, value)
|
||||||
|
|
||||||
db.add(item_db) # Add to session to track changes
|
item_db.version += 1 # Increment version
|
||||||
|
|
||||||
|
db.add(item_db)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await db.refresh(item_db)
|
await db.refresh(item_db)
|
||||||
return item_db
|
return item_db
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
raise DatabaseIntegrityError(f"Failed to update item: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
|
||||||
|
except ConflictError: # Re-raise ConflictError
|
||||||
|
await db.rollback()
|
||||||
|
raise
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
await db.rollback()
|
||||||
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
|
||||||
|
|
||||||
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
||||||
"""Deletes an item record."""
|
"""Deletes an item record. Version check should be done by the caller (API endpoint)."""
|
||||||
try:
|
try:
|
||||||
async with db.begin():
|
async with db.begin():
|
||||||
await db.delete(item_db)
|
await db.delete(item_db)
|
||||||
return None # Or return True/False
|
# await db.commit() # Not needed with async with db.begin()
|
||||||
|
return None
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
await db.rollback()
|
||||||
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
|
@ -16,7 +16,8 @@ from app.core.exceptions import (
|
|||||||
DatabaseConnectionError,
|
DatabaseConnectionError,
|
||||||
DatabaseIntegrityError,
|
DatabaseIntegrityError,
|
||||||
DatabaseQueryError,
|
DatabaseQueryError,
|
||||||
DatabaseTransactionError
|
DatabaseTransactionError,
|
||||||
|
ConflictError
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
||||||
@ -85,32 +86,50 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
|
|||||||
raise DatabaseQueryError(f"Failed to query list: {str(e)}")
|
raise DatabaseQueryError(f"Failed to query list: {str(e)}")
|
||||||
|
|
||||||
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
|
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
|
||||||
"""Updates an existing list record."""
|
"""Updates an existing list record, checking for version conflicts."""
|
||||||
try:
|
try:
|
||||||
async with db.begin():
|
async with db.begin():
|
||||||
update_data = list_in.model_dump(exclude_unset=True)
|
if list_db.version != list_in.version:
|
||||||
|
raise ConflictError(
|
||||||
|
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
|
||||||
|
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
|
||||||
|
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(list_db, key, value)
|
setattr(list_db, key, value)
|
||||||
|
|
||||||
|
list_db.version += 1
|
||||||
|
|
||||||
db.add(list_db)
|
db.add(list_db)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await db.refresh(list_db)
|
await db.refresh(list_db)
|
||||||
return list_db
|
return list_db
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
raise DatabaseIntegrityError(f"Failed to update list: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
|
||||||
|
except ConflictError:
|
||||||
|
await db.rollback()
|
||||||
|
raise
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
await db.rollback()
|
||||||
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
||||||
|
|
||||||
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
||||||
"""Deletes a list record."""
|
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
|
||||||
try:
|
try:
|
||||||
async with db.begin():
|
async with db.begin():
|
||||||
await db.delete(list_db)
|
await db.delete(list_db)
|
||||||
return None
|
return None
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
await db.rollback()
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
await db.rollback()
|
||||||
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
||||||
|
|
||||||
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
||||||
|
@ -5,22 +5,25 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from app.api.api_router import api_router # Import the main combined router
|
from app.api.api_router import api_router # Import the main combined router
|
||||||
|
from app.config import settings
|
||||||
# Import database and models if needed for startup/shutdown events later
|
# Import database and models if needed for startup/shutdown events later
|
||||||
# from . import database, models
|
# from . import database, models
|
||||||
|
|
||||||
# --- Logging Setup ---
|
# --- Logging Setup ---
|
||||||
# Configure logging (can be more sophisticated later, e.g., using logging.yaml)
|
logging.basicConfig(
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
level=getattr(logging, settings.LOG_LEVEL),
|
||||||
|
format=settings.LOG_FORMAT
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# --- FastAPI App Instance ---
|
# --- FastAPI App Instance ---
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Shared Lists API",
|
title=settings.API_TITLE,
|
||||||
description="API for managing shared shopping lists, OCR, and cost splitting.",
|
description=settings.API_DESCRIPTION,
|
||||||
version="0.1.0",
|
version=settings.API_VERSION,
|
||||||
openapi_url="/api/openapi.json", # Place OpenAPI spec under /api
|
openapi_url=settings.API_OPENAPI_URL,
|
||||||
docs_url="/api/docs", # Place Swagger UI under /api
|
docs_url=settings.API_DOCS_URL,
|
||||||
redoc_url="/api/redoc" # Place ReDoc under /api
|
redoc_url=settings.API_REDOC_URL
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- CORS Middleware ---
|
# --- CORS Middleware ---
|
||||||
@ -37,17 +40,17 @@ origins = [
|
|||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins, # List of origins that are allowed to make requests
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
allow_credentials=True, # Allow cookies to be included in requests
|
allow_credentials=True,
|
||||||
allow_methods=["*"], # Allow all methods (GET, POST, PUT, DELETE, etc.)
|
allow_methods=["*"],
|
||||||
allow_headers=["*"], # Allow all headers
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
# --- End CORS Middleware ---
|
# --- End CORS Middleware ---
|
||||||
|
|
||||||
|
|
||||||
# --- Include API Routers ---
|
# --- Include API Routers ---
|
||||||
# All API endpoints will be prefixed with /api
|
# All API endpoints will be prefixed with /api
|
||||||
app.include_router(api_router, prefix="/api")
|
app.include_router(api_router, prefix=settings.API_PREFIX)
|
||||||
# --- End Include API Routers ---
|
# --- End Include API Routers ---
|
||||||
|
|
||||||
|
|
||||||
@ -59,10 +62,7 @@ async def read_root():
|
|||||||
Useful for basic reachability checks.
|
Useful for basic reachability checks.
|
||||||
"""
|
"""
|
||||||
logger.info("Root endpoint '/' accessed.")
|
logger.info("Root endpoint '/' accessed.")
|
||||||
# You could redirect to the docs or return a simple message
|
return {"message": settings.ROOT_MESSAGE}
|
||||||
# from fastapi.responses import RedirectResponse
|
|
||||||
# return RedirectResponse(url="/api/docs")
|
|
||||||
return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"}
|
|
||||||
# --- End Root Endpoint ---
|
# --- End Root Endpoint ---
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,6 +117,7 @@ class List(Base):
|
|||||||
is_complete = Column(Boolean, default=False, nullable=False)
|
is_complete = Column(Boolean, default=False, nullable=False)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||||
|
version = Column(Integer, nullable=False, default=1, server_default='1')
|
||||||
|
|
||||||
# --- Relationships ---
|
# --- Relationships ---
|
||||||
creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
|
creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
|
||||||
@ -138,6 +139,7 @@ class Item(Base):
|
|||||||
completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
|
completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||||
|
version = Column(Integer, nullable=False, default=1, server_default='1')
|
||||||
|
|
||||||
# --- Relationships ---
|
# --- Relationships ---
|
||||||
list = relationship("List", back_populates="items") # Link to List.items
|
list = relationship("List", back_populates="items") # Link to List.items
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
# app/schemas/auth.py
|
# app/schemas/auth.py
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str = "bearer" # Default token type
|
refresh_token: str # Added refresh token
|
||||||
|
token_type: str = settings.TOKEN_TYPE # Use configured token type
|
||||||
|
|
||||||
# Optional: If you preferred not to use OAuth2PasswordRequestForm
|
# Optional: If you preferred not to use OAuth2PasswordRequestForm
|
||||||
# class UserLogin(BaseModel):
|
# class UserLogin(BaseModel):
|
||||||
|
22
be/app/schemas/cost.py
Normal file
22
be/app/schemas/cost.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from typing import List, Optional
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
class UserCostShare(BaseModel):
|
||||||
|
user_id: int
|
||||||
|
user_identifier: str # Name or email
|
||||||
|
items_added_value: Decimal = Decimal("0.00") # Total value of items this user added
|
||||||
|
amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users)
|
||||||
|
balance: Decimal # items_added_value - amount_due
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
class ListCostSummary(BaseModel):
|
||||||
|
list_id: int
|
||||||
|
list_name: str
|
||||||
|
total_list_cost: Decimal
|
||||||
|
num_participating_users: int
|
||||||
|
equal_share_per_user: Decimal
|
||||||
|
user_balances: List[UserCostShare]
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
@ -1,9 +1,10 @@
|
|||||||
# app/schemas/health.py
|
# app/schemas/health.py
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
class HealthStatus(BaseModel):
|
class HealthStatus(BaseModel):
|
||||||
"""
|
"""
|
||||||
Response model for the health check endpoint.
|
Response model for the health check endpoint.
|
||||||
"""
|
"""
|
||||||
status: str = "ok" # Provide a default value
|
status: str = settings.HEALTH_STATUS_OK # Use configured default value
|
||||||
database: str
|
database: str
|
@ -16,6 +16,7 @@ class ItemPublic(BaseModel):
|
|||||||
completed_by_id: Optional[int] = None
|
completed_by_id: Optional[int] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
version: int
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
# Properties to receive via API on creation
|
# Properties to receive via API on creation
|
||||||
@ -31,4 +32,5 @@ class ItemUpdate(BaseModel):
|
|||||||
quantity: Optional[str] = None
|
quantity: Optional[str] = None
|
||||||
is_complete: Optional[bool] = None
|
is_complete: Optional[bool] = None
|
||||||
price: Optional[Decimal] = None # Price added here for update
|
price: Optional[Decimal] = None # Price added here for update
|
||||||
|
version: int
|
||||||
# completed_by_id will be set internally if is_complete is true
|
# completed_by_id will be set internally if is_complete is true
|
@ -16,6 +16,7 @@ class ListUpdate(BaseModel):
|
|||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
is_complete: Optional[bool] = None
|
is_complete: Optional[bool] = None
|
||||||
|
version: int # Client must provide the version for updates
|
||||||
# Potentially add group_id update later if needed
|
# Potentially add group_id update later if needed
|
||||||
|
|
||||||
# Base properties returned by API (common fields)
|
# Base properties returned by API (common fields)
|
||||||
@ -28,6 +29,7 @@ class ListBase(BaseModel):
|
|||||||
is_complete: bool
|
is_complete: bool
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
version: int # Include version in responses
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user