diff --git a/be/app/api/v1/endpoints/costs.py b/be/app/api/v1/endpoints/costs.py index fbb3f1a..5489aed 100644 --- a/be/app/api/v1/endpoints/costs.py +++ b/be/app/api/v1/endpoints/costs.py @@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, selectinload -from decimal import Decimal, ROUND_HALF_UP +from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN +from typing import List from app.database import get_db from app.auth import current_active_user @@ -19,7 +20,7 @@ from app.models import ( ExpenseSplit as ExpenseSplitModel, Settlement as SettlementModel ) -from app.schemas.cost import ListCostSummary, GroupBalanceSummary +from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement from app.schemas.expense import ExpenseCreate from app.crud import list as crud_list from app.crud import expense as crud_expense @@ -28,6 +29,85 @@ from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotF logger = logging.getLogger(__name__) router = APIRouter() +def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]: + """ + Calculate suggested settlements to balance the finances within a group. + + This function takes the current balances of all users and suggests optimal settlements + to minimize the number of transactions needed to settle all debts. + + Args: + user_balances: List of UserBalanceDetail objects with their current balances + + Returns: + List of SuggestedSettlement objects representing the suggested payments + """ + # Create list of users who owe money (negative balance) and who are owed money (positive balance) + debtors = [] # Users who owe money (negative balance) + creditors = [] # Users who are owed money (positive balance) + + # Threshold to consider a balance as zero due to floating point precision + epsilon = Decimal('0.01') + + # Sort users into debtors and creditors + for user in user_balances: + # Skip users with zero balance (or very close to zero) + if abs(user.net_balance) < epsilon: + continue + + if user.net_balance < Decimal('0'): + # User owes money + debtors.append({ + 'user_id': user.user_id, + 'user_identifier': user.user_identifier, + 'amount': -user.net_balance # Convert to positive amount + }) + else: + # User is owed money + creditors.append({ + 'user_id': user.user_id, + 'user_identifier': user.user_identifier, + 'amount': user.net_balance + }) + + # Sort by amount (descending) to handle largest debts first + debtors.sort(key=lambda x: x['amount'], reverse=True) + creditors.sort(key=lambda x: x['amount'], reverse=True) + + settlements = [] + + # Iterate through debtors and match them with creditors + while debtors and creditors: + debtor = debtors[0] + creditor = creditors[0] + + # Determine the settlement amount (the smaller of the two amounts) + amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP) + + # Create settlement record + if amount > Decimal('0'): + settlements.append( + SuggestedSettlement( + from_user_id=debtor['user_id'], + from_user_identifier=debtor['user_identifier'], + to_user_id=creditor['user_id'], + to_user_identifier=creditor['user_identifier'], + amount=amount + ) + ) + + # Update balances + debtor['amount'] -= amount + creditor['amount'] -= amount + + # Remove users who have settled their debts/credits + if debtor['amount'] < epsilon: + debtors.pop(0) + if creditor['amount'] < epsilon: + creditors.pop(0) + + return settlements + @router.get( "/lists/{list_id}/cost-summary", response_model=ListCostSummary, @@ -105,7 +185,7 @@ async def get_list_cost_summary( total_amount=total_amount, list_id=list_id, split_type=SplitTypeEnum.ITEM_BASED, - paid_by_user_id=current_user.id # Use current user as payer for now + paid_by_user_id=db_list.creator.id ) db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in) @@ -137,17 +217,36 @@ async def get_list_cost_summary( user_balances=[] ) - equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - remainder = total_list_cost - (equal_share_per_user * num_participating_users) + # This is the ideal equal share, returned in the summary + equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + + # Sort users for deterministic remainder distribution + sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id) + + user_final_shares = {} + if num_participating_users > 0: + base_share_unrounded = total_list_cost / Decimal(num_participating_users) + + # Calculate initial share for each user, rounding down + for user in sorted_participating_users: + user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN) + + # Calculate sum of rounded down shares + sum_of_rounded_shares = sum(user_final_shares.values()) + + # Calculate remaining pennies to be distributed + remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP)) + + # Distribute remaining pennies one by one to sorted users + for i in range(remaining_pennies): + user_to_adjust = sorted_participating_users[i % num_participating_users] + user_final_shares[user_to_adjust.id] += Decimal("0.01") user_balances = [] - first_user_processed = False - for user in participating_users: + for user in sorted_participating_users: # Iterate over sorted users items_added = user_items_added_value.get(user.id, Decimal("0.00")) - current_user_share = equal_share_per_user - if not first_user_processed and remainder != Decimal("0"): - current_user_share += remainder - first_user_processed = True + # current_user_share is now the precisely calculated share for this user + current_user_share = user_final_shares.get(user.id, Decimal("0.00")) balance = items_added - current_user_share user_identifier = user.name if user.name else user.email @@ -167,7 +266,7 @@ async def get_list_cost_summary( list_name=db_list.name, total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), num_participating_users=num_participating_users, - equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + equal_share_per_user=equal_share_per_user_for_response, # Use the ideal share for the response field user_balances=user_balances ) diff --git a/be/app/api/v1/endpoints/ocr.py b/be/app/api/v1/endpoints/ocr.py index 14192a4..9a21689 100644 --- a/be/app/api/v1/endpoints/ocr.py +++ b/be/app/api/v1/endpoints/ocr.py @@ -7,7 +7,7 @@ from google.api_core import exceptions as google_exceptions from app.auth import current_active_user from app.models import User as UserModel from app.schemas.ocr import OcrExtractResponse -from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService +from app.core.gemini import GeminiOCRService, gemini_initialization_error from app.core.exceptions import ( OCRServiceUnavailableError, OCRServiceConfigError, @@ -56,11 +56,8 @@ async def ocr_extract_items( raise FileTooLargeError() try: - # Call the Gemini helper function - extracted_items = await extract_items_from_image_gemini( - image_bytes=contents, - mime_type=image_file.content_type - ) + # Use the ocr_service instance instead of the standalone function + extracted_items = await ocr_service.extract_items(image_data=contents) logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") return OcrExtractResponse(extracted_items=extracted_items) diff --git a/be/app/config.py b/be/app/config.py index b92d91e..86ca83a 100644 --- a/be/app/config.py +++ b/be/app/config.py @@ -16,8 +16,7 @@ class Settings(BaseSettings): # --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users) SECRET_KEY: str # Must be set via environment variable - # ALGORITHM: str = "HS256" # Handled by FastAPI-Users strategy - # ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # This specific line is commented, the one under Session Settings is used. + # FastAPI-Users handles JWT algorithm internally # --- OCR Settings --- MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing @@ -36,6 +35,14 @@ Bread __Apples__ Organic Bananas """ + # --- OCR Error Messages --- + OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later." + OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support." + OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing." + OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later." + OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}" + OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB" + OCR_PROCESSING_ERROR: str = "Error processing image: {detail}" # --- Gemini AI Settings --- GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR @@ -98,6 +105,14 @@ Organic Bananas DB_TRANSACTION_ERROR: str = "Database transaction error" DB_QUERY_ERROR: str = "Database query error" + # --- Auth Error Messages --- + AUTH_INVALID_CREDENTIALS: str = "Invalid username or password" + AUTH_NOT_AUTHENTICATED: str = "Not authenticated" + AUTH_JWT_ERROR: str = "JWT token error: {error}" + AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}" + AUTH_HEADER_NAME: str = "WWW-Authenticate" + AUTH_HEADER_PREFIX: str = "Bearer" + # OAuth Settings GOOGLE_CLIENT_ID: str = "" GOOGLE_CLIENT_SECRET: str = "" diff --git a/be/app/core/exceptions.py b/be/app/core/exceptions.py index 6bae250..ed77fb7 100644 --- a/be/app/core/exceptions.py +++ b/be/app/core/exceptions.py @@ -295,7 +295,7 @@ class JWTError(HTTPException): def __init__(self, error: str): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.JWT_ERROR.format(error=error), + detail=settings.AUTH_JWT_ERROR.format(error=error), headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} ) @@ -304,7 +304,7 @@ class JWTUnexpectedError(HTTPException): def __init__(self, error: str): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), + detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error), headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} ) diff --git a/be/app/core/gemini.py b/be/app/core/gemini.py index c09e983..f5a4f8a 100644 --- a/be/app/core/gemini.py +++ b/be/app/core/gemini.py @@ -9,7 +9,8 @@ from app.core.exceptions import ( OCRServiceUnavailableError, OCRServiceConfigError, OCRUnexpectedError, - OCRQuotaExceededError + OCRQuotaExceededError, + OCRProcessingError ) logger = logging.getLogger(__name__) @@ -55,10 +56,10 @@ def get_gemini_client(): Raises an exception if initialization failed. """ if gemini_initialization_error: - raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}") + raise OCRServiceConfigError() if gemini_flash_client is None: # This case should ideally be covered by the check above, but as a safeguard: - raise RuntimeError("Gemini client is not available (unknown initialization issue).") + raise OCRServiceConfigError() return gemini_flash_client # Define the prompt as a constant @@ -88,26 +89,29 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " A list of extracted item strings. Raises: - RuntimeError: If the Gemini client is not initialized. - google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.). - ValueError: If the response is blocked or contains no usable text. + OCRServiceConfigError: If the Gemini client is not initialized. + OCRQuotaExceededError: If API quota is exceeded. + OCRServiceUnavailableError: For general API call errors. + OCRProcessingError: If the response is blocked or contains no usable text. + OCRUnexpectedError: For any other unexpected errors. """ - client = get_gemini_client() # Raises RuntimeError if not initialized - - # Prepare image part for multimodal input - image_part = { - "mime_type": mime_type, - "data": image_bytes - } - - # Prepare the full prompt content - prompt_parts = [ - settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first - image_part # Then the image - ] - - logger.info("Sending image to Gemini for item extraction...") try: + client = get_gemini_client() # Raises OCRServiceConfigError if not initialized + + # Prepare image part for multimodal input + image_part = { + "mime_type": mime_type, + "data": image_bytes + } + + # Prepare the full prompt content + prompt_parts = [ + settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first + image_part # Then the image + ] + + logger.info("Sending image to Gemini for item extraction...") + # Make the API call # Use generate_content_async for async FastAPI response = await client.generate_content_async(prompt_parts) @@ -120,9 +124,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN' safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A' if finish_reason == 'SAFETY': - raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") + raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") else: - raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}") + raise OCRUnexpectedError() # Extract text - assumes the first part of the first candidate is the text response raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text @@ -143,32 +147,59 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " except google_exceptions.GoogleAPIError as e: logger.error(f"Gemini API Error: {e}", exc_info=True) - # Re-raise specific Google API errors for endpoint to handle (e.g., quota) - raise e + if "quota" in str(e).lower(): + raise OCRQuotaExceededError() + raise OCRServiceUnavailableError() + except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError): + # Re-raise specific OCR exceptions + raise except Exception as e: # Catch other unexpected errors during generation or processing logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) - # Wrap in a generic ValueError or re-raise - raise ValueError(f"Failed to process image with Gemini: {e}") from e + # Wrap in a custom exception + raise OCRUnexpectedError() class GeminiOCRService: def __init__(self): try: genai.configure(api_key=settings.GEMINI_API_KEY) - self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME) - self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS - self.model.generation_config = settings.GEMINI_GENERATION_CONFIG + self.model = genai.GenerativeModel( + model_name=settings.GEMINI_MODEL_NAME, + # Safety settings from config + safety_settings={ + getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold) + for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items() + }, + # Generation config from settings + generation_config=genai.types.GenerationConfig( + **settings.GEMINI_GENERATION_CONFIG + ) + ) except Exception as e: logger.error(f"Failed to initialize Gemini client: {e}") raise OCRServiceConfigError() - async def extract_items(self, image_data: bytes) -> List[str]: + async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]: """ Extract shopping list items from an image using Gemini Vision. + + Args: + image_data: The image content as bytes. + mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp"). + + Returns: + A list of extracted item strings. + + Raises: + OCRServiceConfigError: If the Gemini client is not initialized. + OCRQuotaExceededError: If API quota is exceeded. + OCRServiceUnavailableError: For general API call errors. + OCRProcessingError: If the response is blocked or contains no usable text. + OCRUnexpectedError: For any other unexpected errors. """ try: # Create image part - image_parts = [{"mime_type": "image/jpeg", "data": image_data}] + image_parts = [{"mime_type": mime_type, "data": image_data}] # Generate content response = await self.model.generate_content_async( @@ -177,19 +208,34 @@ class GeminiOCRService: # Process response if not response.text: + logger.warning("Gemini response is empty") raise OCRUnexpectedError() + + # Check for safety blocks + if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'): + finish_reason = response.candidates[0].finish_reason + if finish_reason == 'SAFETY': + safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A' + raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") # Split response into lines and clean up - items = [ - item.strip() - for item in response.text.split("\n") - if item.strip() and not item.strip().startswith("Example") - ] + items = [] + for line in response.text.splitlines(): + cleaned_line = line.strip() + if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"): + items.append(cleaned_line) + logger.info(f"Extracted {len(items)} potential items.") return items - except Exception as e: + except google_exceptions.GoogleAPIError as e: logger.error(f"Error during OCR extraction: {e}") if "quota" in str(e).lower(): raise OCRQuotaExceededError() - raise OCRServiceUnavailableError() \ No newline at end of file + raise OCRServiceUnavailableError() + except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError): + # Re-raise specific OCR exceptions + raise + except Exception as e: + logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) + raise OCRUnexpectedError() \ No newline at end of file diff --git a/be/app/core/security.py b/be/app/core/security.py index 87ee4a1..197c732 100644 --- a/be/app/core/security.py +++ b/be/app/core/security.py @@ -8,6 +8,9 @@ from passlib.context import CryptContext from app.config import settings # Import settings from config # --- Password Hashing --- +# These functions are used for password hashing and verification +# They complement FastAPI-Users but provide direct access to the underlying password functionality +# when needed outside of the FastAPI-Users authentication flow. # Configure passlib context # Using bcrypt as the default hashing scheme @@ -17,6 +20,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def verify_password(plain_password: str, hashed_password: str) -> bool: """ Verifies a plain text password against a hashed password. + This is used by FastAPI-Users internally, but also exposed here for custom authentication flows + if needed. Args: plain_password: The password attempt. @@ -34,6 +39,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: def hash_password(password: str) -> str: """ Hashes a plain text password using the configured context (bcrypt). + This is used by FastAPI-Users internally, but also exposed here for + custom user creation or password reset flows if needed. Args: password: The plain text password to hash. @@ -45,14 +52,22 @@ def hash_password(password: str) -> str: # --- JSON Web Tokens (JWT) --- -# FastAPI-Users now handles all tokenization. +# FastAPI-Users now handles all JWT token creation and validation. +# The code below is commented out because FastAPI-Users provides these features. +# It's kept for reference in case a custom implementation is needed later. -# You might add a function here later to extract the 'sub' (subject/user id) -# specifically, often used in dependency injection for authentication. +# Example of a potential future implementation: # def get_subject_from_token(token: str) -> Optional[str]: +# """ +# Extract the subject (user ID) from a JWT token. +# This would be used if we need to validate tokens outside of FastAPI-Users flow. +# For now, use fastapi_users.current_user dependency instead. +# """ # # This would need to use FastAPI-Users' token verification if ever implemented # # For example, by decoding the token using the strategy from the auth backend -# payload = {} # Placeholder for actual token decoding logic -# if payload: +# try: +# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) # return payload.get("sub") +# except JWTError: +# return None # return None \ No newline at end of file diff --git a/be/app/crud/expense.py b/be/app/crud/expense.py index 43c58d4..957e136 100644 --- a/be/app/crud/expense.py +++ b/be/app/crud/expense.py @@ -148,7 +148,7 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us group_id=final_group_id, # Use resolved group_id item_id=expense_in.item_id, paid_by_user_id=expense_in.paid_by_user_id, - created_by_user_id=current_user_id + created_by_user_id=current_user_id ) db.add(db_expense) await db.flush() # Get expense ID diff --git a/be/app/crud/settlement.py b/be/app/crud/settlement.py index 30ec0ce..3f7f784 100644 --- a/be/app/crud/settlement.py +++ b/be/app/crud/settlement.py @@ -59,8 +59,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c paid_to_user_id=settlement_in.paid_to_user_id, amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc), - description=settlement_in.description - # created_by_user_id = current_user_id # Optional: Who recorded this settlement + description=settlement_in.description, + created_by_user_id=current_user_id ) db.add(db_settlement) await db.flush() @@ -72,7 +72,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c .options( selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + selectinload(SettlementModel.group), + selectinload(SettlementModel.created_by_user) ) ) result = await db.execute(stmt) @@ -103,7 +104,8 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional .options( selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + selectinload(SettlementModel.group), + selectinload(SettlementModel.created_by_user) ) .where(SettlementModel.id == settlement_id) ) @@ -123,7 +125,8 @@ async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = .options( selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + selectinload(SettlementModel.group), + selectinload(SettlementModel.created_by_user) ) ) return result.scalars().all() @@ -149,7 +152,8 @@ async def get_settlements_involving_user( .options( selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + selectinload(SettlementModel.group), + selectinload(SettlementModel.created_by_user) ) ) if group_id: @@ -216,7 +220,8 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se .options( selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + selectinload(SettlementModel.group), + selectinload(SettlementModel.created_by_user) ) ) result = await db.execute(stmt) diff --git a/be/app/database.py b/be/app/database.py index db7768d..20fb047 100644 --- a/be/app/database.py +++ b/be/app/database.py @@ -2,6 +2,9 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base from app.config import settings +import logging + +logger = logging.getLogger(__name__) # Ensure DATABASE_URL is set before proceeding if not settings.DATABASE_URL: @@ -41,4 +44,23 @@ async def get_async_session() -> AsyncSession: # type: ignore # Commit/rollback should be handled by the functions using the session. # Alias for backward compatibility -get_db = get_async_session \ No newline at end of file +get_db = get_async_session + +async def get_transactional_session() -> AsyncSession: # type: ignore + """ + Dependency function that yields an AsyncSession wrapped in a transaction. + Commits on successful completion of the request handler, rolls back on exceptions. + """ + async with AsyncSessionLocal() as session: + async with session.begin(): # Start a transaction + try: + logger.debug(f"Transaction started for session {id(session)}") + yield session + # If no exceptions were raised by the endpoint, the 'session.begin()' + # context manager will automatically commit here. + logger.debug(f"Transaction committed for session {id(session)}") + except Exception as e: + # The 'session.begin()' context manager will automatically + # rollback on any exception. + logger.error(f"Transaction rolled back for session {id(session)} due to: {e}", exc_info=True) + raise # Re-raise the exception to be handled by FastAPI's error handlers \ No newline at end of file diff --git a/be/app/models.py b/be/app/models.py index 0bfaece..e31d227 100644 --- a/be/app/models.py +++ b/be/app/models.py @@ -65,9 +65,11 @@ class User(Base): # --- Relationships for Cost Splitting --- expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan") + expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan") expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan") settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan") settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan") + settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan") # --- End Relationships for Cost Splitting --- @@ -197,6 +199,7 @@ class Expense(Base): group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True) item_id = Column(Integer, ForeignKey("items.id"), nullable=True) paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) @@ -204,6 +207,7 @@ class Expense(Base): # Relationships paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid") + created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created") list = relationship("List", foreign_keys=[list_id], back_populates="expenses") group = relationship("Group", foreign_keys=[group_id], back_populates="expenses") item = relationship("Item", foreign_keys=[item_id], back_populates="expenses") @@ -246,6 +250,7 @@ class Settlement(Base): amount = Column(Numeric(10, 2), nullable=False) settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) description = Column(Text, nullable=True) + created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) @@ -255,6 +260,7 @@ class Settlement(Base): group = relationship("Group", foreign_keys=[group_id], back_populates="settlements") payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made") payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received") + created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created") __table_args__ = ( # Ensure payer and payee are different users diff --git a/be/app/schemas/expense.py b/be/app/schemas/expense.py index d561abf..8f27630 100644 --- a/be/app/schemas/expense.py +++ b/be/app/schemas/expense.py @@ -79,6 +79,7 @@ class ExpensePublic(ExpenseBase): created_at: datetime updated_at: datetime version: int + created_by_user_id: int splits: List[ExpenseSplitPublic] = [] # paid_by_user: Optional[UserPublic] # If nesting user details # list: Optional[ListPublic] # If nesting list details @@ -119,9 +120,11 @@ class SettlementPublic(SettlementBase): id: int created_at: datetime updated_at: datetime - # payer: Optional[UserPublic] - # payee: Optional[UserPublic] - # group: Optional[GroupPublic] + version: int + created_by_user_id: int + # payer: Optional[UserPublic] # If we want to include payer details + # payee: Optional[UserPublic] # If we want to include payee details + # group: Optional[GroupPublic] # If we want to include group details model_config = ConfigDict(from_attributes=True) # Placeholder for nested schemas (e.g., UserPublic) if needed diff --git a/be/tests/crud/test_expense.py b/be/tests/crud/test_expense.py index 3181889..b4a7044 100644 --- a/be/tests/crud/test_expense.py +++ b/be/tests/crud/test_expense.py @@ -122,7 +122,9 @@ def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model group_id=expense_create_data_equal_split_group_ctx.group_id, item_id=expense_create_data_equal_split_group_ctx.item_id, paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id, + created_by_user_id=basic_user_model.id, paid_by=basic_user_model, # Assuming paid_by relation is loaded + created_by_user=basic_user_model, # Assuming created_by_user relation is loaded # splits would be populated after creation usually version=1 ) diff --git a/be/tests/crud/test_settlement.py b/be/tests/crud/test_settlement.py index dbae0f0..50a380a 100644 --- a/be/tests/crud/test_settlement.py +++ b/be/tests/crud/test_settlement.py @@ -60,12 +60,14 @@ def db_settlement_model(): amount=Decimal("10.50"), settlement_date=datetime.now(timezone.utc), description="Original settlement", + created_by_user_id=1, version=1, # Initial version created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), payer=UserModel(id=1, name="Payer User"), payee=UserModel(id=2, name="Payee User"), - group=GroupModel(id=1, name="Test Group") + group=GroupModel(id=1, name="Test Group"), + created_by_user=UserModel(id=1, name="Payer User") # Same as payer for simplicity ) @pytest.fixture diff --git a/fe/src/config/api-config.ts b/fe/src/config/api-config.ts index 5e998ca..f6bd1ae 100644 --- a/fe/src/config/api-config.ts +++ b/fe/src/config/api-config.ts @@ -64,9 +64,9 @@ export const API_ENDPOINTS = { INVITES: { BASE: '/invites', BY_ID: (id: string) => `/invites/${id}`, - ACCEPT: (id: string) => `/invites/${id}/accept`, - DECLINE: (id: string) => `/invites/${id}/decline`, - REVOKE: (id: string) => `/invites/${id}/revoke`, + ACCEPT: (id: string) => `/invites/accept/${id}`, + DECLINE: (id: string) => `/invites/decline/${id}`, + REVOKE: (id: string) => `/invites/revoke/${id}`, LIST: '/invites', PENDING: '/invites/pending', SENT: '/invites/sent',