Enhance configuration and error handling in the application; add new error messages for OCR and authentication processes. Refactor database session management to include transaction handling, and update models to track user creation for expenses and settlements. Update API endpoints to improve cost-sharing calculations and adjust invite management routes for clarity.

This commit is contained in:
mohamad 2025-05-17 13:56:17 +02:00
parent c2aa62fa03
commit 5abe7839f1
14 changed files with 294 additions and 82 deletions

View File

@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List
from app.database import get_db from app.database import get_db
from app.auth import current_active_user from app.auth import current_active_user
@ -19,7 +20,7 @@ from app.models import (
ExpenseSplit as ExpenseSplitModel, ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel Settlement as SettlementModel
) )
from app.schemas.cost import ListCostSummary, GroupBalanceSummary from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
from app.schemas.expense import ExpenseCreate from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list from app.crud import list as crud_list
from app.crud import expense as crud_expense from app.crud import expense as crud_expense
@ -28,6 +29,85 @@ from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotF
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
"""
Calculate suggested settlements to balance the finances within a group.
This function takes the current balances of all users and suggests optimal settlements
to minimize the number of transactions needed to settle all debts.
Args:
user_balances: List of UserBalanceDetail objects with their current balances
Returns:
List of SuggestedSettlement objects representing the suggested payments
"""
# Create list of users who owe money (negative balance) and who are owed money (positive balance)
debtors = [] # Users who owe money (negative balance)
creditors = [] # Users who are owed money (positive balance)
# Threshold to consider a balance as zero due to floating point precision
epsilon = Decimal('0.01')
# Sort users into debtors and creditors
for user in user_balances:
# Skip users with zero balance (or very close to zero)
if abs(user.net_balance) < epsilon:
continue
if user.net_balance < Decimal('0'):
# User owes money
debtors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': -user.net_balance # Convert to positive amount
})
else:
# User is owed money
creditors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': user.net_balance
})
# Sort by amount (descending) to handle largest debts first
debtors.sort(key=lambda x: x['amount'], reverse=True)
creditors.sort(key=lambda x: x['amount'], reverse=True)
settlements = []
# Iterate through debtors and match them with creditors
while debtors and creditors:
debtor = debtors[0]
creditor = creditors[0]
# Determine the settlement amount (the smaller of the two amounts)
amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
# Create settlement record
if amount > Decimal('0'):
settlements.append(
SuggestedSettlement(
from_user_id=debtor['user_id'],
from_user_identifier=debtor['user_identifier'],
to_user_id=creditor['user_id'],
to_user_identifier=creditor['user_identifier'],
amount=amount
)
)
# Update balances
debtor['amount'] -= amount
creditor['amount'] -= amount
# Remove users who have settled their debts/credits
if debtor['amount'] < epsilon:
debtors.pop(0)
if creditor['amount'] < epsilon:
creditors.pop(0)
return settlements
@router.get( @router.get(
"/lists/{list_id}/cost-summary", "/lists/{list_id}/cost-summary",
response_model=ListCostSummary, response_model=ListCostSummary,
@ -105,7 +185,7 @@ async def get_list_cost_summary(
total_amount=total_amount, total_amount=total_amount,
list_id=list_id, list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED, split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=current_user.id # Use current user as payer for now paid_by_user_id=db_list.creator.id
) )
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in) db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
@ -137,17 +217,36 @@ async def get_list_cost_summary(
user_balances=[] user_balances=[]
) )
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) # This is the ideal equal share, returned in the summary
remainder = total_list_cost - (equal_share_per_user * num_participating_users) equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# Sort users for deterministic remainder distribution
sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id)
user_final_shares = {}
if num_participating_users > 0:
base_share_unrounded = total_list_cost / Decimal(num_participating_users)
# Calculate initial share for each user, rounding down
for user in sorted_participating_users:
user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
# Calculate sum of rounded down shares
sum_of_rounded_shares = sum(user_final_shares.values())
# Calculate remaining pennies to be distributed
remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP))
# Distribute remaining pennies one by one to sorted users
for i in range(remaining_pennies):
user_to_adjust = sorted_participating_users[i % num_participating_users]
user_final_shares[user_to_adjust.id] += Decimal("0.01")
user_balances = [] user_balances = []
first_user_processed = False for user in sorted_participating_users: # Iterate over sorted users
for user in participating_users:
items_added = user_items_added_value.get(user.id, Decimal("0.00")) items_added = user_items_added_value.get(user.id, Decimal("0.00"))
current_user_share = equal_share_per_user # current_user_share is now the precisely calculated share for this user
if not first_user_processed and remainder != Decimal("0"): current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email user_identifier = user.name if user.name else user.email
@ -167,7 +266,7 @@ async def get_list_cost_summary(
list_name=db_list.name, list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users, num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), equal_share_per_user=equal_share_per_user_for_response, # Use the ideal share for the response field
user_balances=user_balances user_balances=user_balances
) )

View File

@ -7,7 +7,7 @@ from google.api_core import exceptions as google_exceptions
from app.auth import current_active_user from app.auth import current_active_user
from app.models import User as UserModel from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService from app.core.gemini import GeminiOCRService, gemini_initialization_error
from app.core.exceptions import ( from app.core.exceptions import (
OCRServiceUnavailableError, OCRServiceUnavailableError,
OCRServiceConfigError, OCRServiceConfigError,
@ -56,11 +56,8 @@ async def ocr_extract_items(
raise FileTooLargeError() raise FileTooLargeError()
try: try:
# Call the Gemini helper function # Use the ocr_service instance instead of the standalone function
extracted_items = await extract_items_from_image_gemini( extracted_items = await ocr_service.extract_items(image_data=contents)
image_bytes=contents,
mime_type=image_file.content_type
)
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items) return OcrExtractResponse(extracted_items=extracted_items)

View File

@ -16,8 +16,7 @@ class Settings(BaseSettings):
# --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users) # --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users)
SECRET_KEY: str # Must be set via environment variable SECRET_KEY: str # Must be set via environment variable
# ALGORITHM: str = "HS256" # Handled by FastAPI-Users strategy # FastAPI-Users handles JWT algorithm internally
# ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # This specific line is commented, the one under Session Settings is used.
# --- OCR Settings --- # --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
@ -36,6 +35,14 @@ Bread
__Apples__ __Apples__
Organic Bananas Organic Bananas
""" """
# --- OCR Error Messages ---
OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later."
OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support."
OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing."
OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later."
OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}"
OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB"
OCR_PROCESSING_ERROR: str = "Error processing image: {detail}"
# --- Gemini AI Settings --- # --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
@ -98,6 +105,14 @@ Organic Bananas
DB_TRANSACTION_ERROR: str = "Database transaction error" DB_TRANSACTION_ERROR: str = "Database transaction error"
DB_QUERY_ERROR: str = "Database query error" DB_QUERY_ERROR: str = "Database query error"
# --- Auth Error Messages ---
AUTH_INVALID_CREDENTIALS: str = "Invalid username or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
AUTH_JWT_ERROR: str = "JWT token error: {error}"
AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}"
AUTH_HEADER_NAME: str = "WWW-Authenticate"
AUTH_HEADER_PREFIX: str = "Bearer"
# OAuth Settings # OAuth Settings
GOOGLE_CLIENT_ID: str = "" GOOGLE_CLIENT_ID: str = ""
GOOGLE_CLIENT_SECRET: str = "" GOOGLE_CLIENT_SECRET: str = ""

View File

@ -295,7 +295,7 @@ class JWTError(HTTPException):
def __init__(self, error: str): def __init__(self, error: str):
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_ERROR.format(error=error), detail=settings.AUTH_JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
) )
@ -304,7 +304,7 @@ class JWTUnexpectedError(HTTPException):
def __init__(self, error: str): def __init__(self, error: str):
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
) )

View File

@ -9,7 +9,8 @@ from app.core.exceptions import (
OCRServiceUnavailableError, OCRServiceUnavailableError,
OCRServiceConfigError, OCRServiceConfigError,
OCRUnexpectedError, OCRUnexpectedError,
OCRQuotaExceededError OCRQuotaExceededError,
OCRProcessingError
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,10 +56,10 @@ def get_gemini_client():
Raises an exception if initialization failed. Raises an exception if initialization failed.
""" """
if gemini_initialization_error: if gemini_initialization_error:
raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}") raise OCRServiceConfigError()
if gemini_flash_client is None: if gemini_flash_client is None:
# This case should ideally be covered by the check above, but as a safeguard: # This case should ideally be covered by the check above, but as a safeguard:
raise RuntimeError("Gemini client is not available (unknown initialization issue).") raise OCRServiceConfigError()
return gemini_flash_client return gemini_flash_client
# Define the prompt as a constant # Define the prompt as a constant
@ -88,11 +89,14 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
A list of extracted item strings. A list of extracted item strings.
Raises: Raises:
RuntimeError: If the Gemini client is not initialized. OCRServiceConfigError: If the Gemini client is not initialized.
google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.). OCRQuotaExceededError: If API quota is exceeded.
ValueError: If the response is blocked or contains no usable text. OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
""" """
client = get_gemini_client() # Raises RuntimeError if not initialized try:
client = get_gemini_client() # Raises OCRServiceConfigError if not initialized
# Prepare image part for multimodal input # Prepare image part for multimodal input
image_part = { image_part = {
@ -107,7 +111,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
] ]
logger.info("Sending image to Gemini for item extraction...") logger.info("Sending image to Gemini for item extraction...")
try:
# Make the API call # Make the API call
# Use generate_content_async for async FastAPI # Use generate_content_async for async FastAPI
response = await client.generate_content_async(prompt_parts) response = await client.generate_content_async(prompt_parts)
@ -120,9 +124,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN' finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A' safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
if finish_reason == 'SAFETY': if finish_reason == 'SAFETY':
raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
else: else:
raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}") raise OCRUnexpectedError()
# Extract text - assumes the first part of the first candidate is the text response # Extract text - assumes the first part of the first candidate is the text response
raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text
@ -143,32 +147,59 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
except google_exceptions.GoogleAPIError as e: except google_exceptions.GoogleAPIError as e:
logger.error(f"Gemini API Error: {e}", exc_info=True) logger.error(f"Gemini API Error: {e}", exc_info=True)
# Re-raise specific Google API errors for endpoint to handle (e.g., quota) if "quota" in str(e).lower():
raise e raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e: except Exception as e:
# Catch other unexpected errors during generation or processing # Catch other unexpected errors during generation or processing
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a generic ValueError or re-raise # Wrap in a custom exception
raise ValueError(f"Failed to process image with Gemini: {e}") from e raise OCRUnexpectedError()
class GeminiOCRService: class GeminiOCRService:
def __init__(self): def __init__(self):
try: try:
genai.configure(api_key=settings.GEMINI_API_KEY) genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME) self.model = genai.GenerativeModel(
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS model_name=settings.GEMINI_MODEL_NAME,
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG # Safety settings from config
safety_settings={
getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
},
# Generation config from settings
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}") logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError() raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes) -> List[str]: async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
""" """
Extract shopping list items from an image using Gemini Vision. Extract shopping list items from an image using Gemini Vision.
Args:
image_data: The image content as bytes.
mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
Returns:
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
""" """
try: try:
# Create image part # Create image part
image_parts = [{"mime_type": "image/jpeg", "data": image_data}] image_parts = [{"mime_type": mime_type, "data": image_data}]
# Generate content # Generate content
response = await self.model.generate_content_async( response = await self.model.generate_content_async(
@ -177,19 +208,34 @@ class GeminiOCRService:
# Process response # Process response
if not response.text: if not response.text:
logger.warning("Gemini response is empty")
raise OCRUnexpectedError() raise OCRUnexpectedError()
# Split response into lines and clean up # Check for safety blocks
items = [ if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'):
item.strip() finish_reason = response.candidates[0].finish_reason
for item in response.text.split("\n") if finish_reason == 'SAFETY':
if item.strip() and not item.strip().startswith("Example") safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A'
] raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
# Split response into lines and clean up
items = []
for line in response.text.splitlines():
cleaned_line = line.strip()
if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"):
items.append(cleaned_line)
logger.info(f"Extracted {len(items)} potential items.")
return items return items
except Exception as e: except google_exceptions.GoogleAPIError as e:
logger.error(f"Error during OCR extraction: {e}") logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower(): if "quota" in str(e).lower():
raise OCRQuotaExceededError() raise OCRQuotaExceededError()
raise OCRServiceUnavailableError() raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
raise OCRUnexpectedError()

View File

@ -8,6 +8,9 @@ from passlib.context import CryptContext
from app.config import settings # Import settings from config from app.config import settings # Import settings from config
# --- Password Hashing --- # --- Password Hashing ---
# These functions are used for password hashing and verification
# They complement FastAPI-Users but provide direct access to the underlying password functionality
# when needed outside of the FastAPI-Users authentication flow.
# Configure passlib context # Configure passlib context
# Using bcrypt as the default hashing scheme # Using bcrypt as the default hashing scheme
@ -17,6 +20,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
""" """
Verifies a plain text password against a hashed password. Verifies a plain text password against a hashed password.
This is used by FastAPI-Users internally, but also exposed here for custom authentication flows
if needed.
Args: Args:
plain_password: The password attempt. plain_password: The password attempt.
@ -34,6 +39,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
""" """
Hashes a plain text password using the configured context (bcrypt). Hashes a plain text password using the configured context (bcrypt).
This is used by FastAPI-Users internally, but also exposed here for
custom user creation or password reset flows if needed.
Args: Args:
password: The plain text password to hash. password: The plain text password to hash.
@ -45,14 +52,22 @@ def hash_password(password: str) -> str:
# --- JSON Web Tokens (JWT) --- # --- JSON Web Tokens (JWT) ---
# FastAPI-Users now handles all tokenization. # FastAPI-Users now handles all JWT token creation and validation.
# The code below is commented out because FastAPI-Users provides these features.
# It's kept for reference in case a custom implementation is needed later.
# You might add a function here later to extract the 'sub' (subject/user id) # Example of a potential future implementation:
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]: # def get_subject_from_token(token: str) -> Optional[str]:
# """
# Extract the subject (user ID) from a JWT token.
# This would be used if we need to validate tokens outside of FastAPI-Users flow.
# For now, use fastapi_users.current_user dependency instead.
# """
# # This would need to use FastAPI-Users' token verification if ever implemented # # This would need to use FastAPI-Users' token verification if ever implemented
# # For example, by decoding the token using the strategy from the auth backend # # For example, by decoding the token using the strategy from the auth backend
# payload = {} # Placeholder for actual token decoding logic # try:
# if payload: # payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# return payload.get("sub") # return payload.get("sub")
# except JWTError:
# return None
# return None # return None

View File

@ -59,8 +59,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
paid_to_user_id=settlement_in.paid_to_user_id, paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc), settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description description=settlement_in.description,
# created_by_user_id = current_user_id # Optional: Who recorded this settlement created_by_user_id=current_user_id
) )
db.add(db_settlement) db.add(db_settlement)
await db.flush() await db.flush()
@ -72,7 +72,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
.options( .options(
selectinload(SettlementModel.payer), selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee), selectinload(SettlementModel.payee),
selectinload(SettlementModel.group) selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)
@ -103,7 +104,8 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional
.options( .options(
selectinload(SettlementModel.payer), selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee), selectinload(SettlementModel.payee),
selectinload(SettlementModel.group) selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
) )
.where(SettlementModel.id == settlement_id) .where(SettlementModel.id == settlement_id)
) )
@ -123,7 +125,8 @@ async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int =
.options( .options(
selectinload(SettlementModel.payer), selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee), selectinload(SettlementModel.payee),
selectinload(SettlementModel.group) selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
) )
) )
return result.scalars().all() return result.scalars().all()
@ -149,7 +152,8 @@ async def get_settlements_involving_user(
.options( .options(
selectinload(SettlementModel.payer), selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee), selectinload(SettlementModel.payee),
selectinload(SettlementModel.group) selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
) )
) )
if group_id: if group_id:
@ -216,7 +220,8 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
.options( .options(
selectinload(SettlementModel.payer), selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee), selectinload(SettlementModel.payee),
selectinload(SettlementModel.group) selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
) )
) )
result = await db.execute(stmt) result = await db.execute(stmt)

View File

@ -2,6 +2,9 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
from app.config import settings from app.config import settings
import logging
logger = logging.getLogger(__name__)
# Ensure DATABASE_URL is set before proceeding # Ensure DATABASE_URL is set before proceeding
if not settings.DATABASE_URL: if not settings.DATABASE_URL:
@ -42,3 +45,22 @@ async def get_async_session() -> AsyncSession: # type: ignore
# Alias for backward compatibility # Alias for backward compatibility
get_db = get_async_session get_db = get_async_session
async def get_transactional_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession wrapped in a transaction.
Commits on successful completion of the request handler, rolls back on exceptions.
"""
async with AsyncSessionLocal() as session:
async with session.begin(): # Start a transaction
try:
logger.debug(f"Transaction started for session {id(session)}")
yield session
# If no exceptions were raised by the endpoint, the 'session.begin()'
# context manager will automatically commit here.
logger.debug(f"Transaction committed for session {id(session)}")
except Exception as e:
# The 'session.begin()' context manager will automatically
# rollback on any exception.
logger.error(f"Transaction rolled back for session {id(session)} due to: {e}", exc_info=True)
raise # Re-raise the exception to be handled by FastAPI's error handlers

View File

@ -65,9 +65,11 @@ class User(Base):
# --- Relationships for Cost Splitting --- # --- Relationships for Cost Splitting ---
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan") expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan") expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan") settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan") settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting --- # --- End Relationships for Cost Splitting ---
@ -197,6 +199,7 @@ class Expense(Base):
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True) group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True)
item_id = Column(Integer, ForeignKey("items.id"), nullable=True) item_id = Column(Integer, ForeignKey("items.id"), nullable=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -204,6 +207,7 @@ class Expense(Base):
# Relationships # Relationships
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid") paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created")
list = relationship("List", foreign_keys=[list_id], back_populates="expenses") list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses") group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses") item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
@ -246,6 +250,7 @@ class Settlement(Base):
amount = Column(Numeric(10, 2), nullable=False) amount = Column(Numeric(10, 2), nullable=False)
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -255,6 +260,7 @@ class Settlement(Base):
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements") group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made") payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received") payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created")
__table_args__ = ( __table_args__ = (
# Ensure payer and payee are different users # Ensure payer and payee are different users

View File

@ -79,6 +79,7 @@ class ExpensePublic(ExpenseBase):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
version: int version: int
created_by_user_id: int
splits: List[ExpenseSplitPublic] = [] splits: List[ExpenseSplitPublic] = []
# paid_by_user: Optional[UserPublic] # If nesting user details # paid_by_user: Optional[UserPublic] # If nesting user details
# list: Optional[ListPublic] # If nesting list details # list: Optional[ListPublic] # If nesting list details
@ -119,9 +120,11 @@ class SettlementPublic(SettlementBase):
id: int id: int
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
# payer: Optional[UserPublic] version: int
# payee: Optional[UserPublic] created_by_user_id: int
# group: Optional[GroupPublic] # payer: Optional[UserPublic] # If we want to include payer details
# payee: Optional[UserPublic] # If we want to include payee details
# group: Optional[GroupPublic] # If we want to include group details
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed # Placeholder for nested schemas (e.g., UserPublic) if needed

View File

@ -122,7 +122,9 @@ def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model
group_id=expense_create_data_equal_split_group_ctx.group_id, group_id=expense_create_data_equal_split_group_ctx.group_id,
item_id=expense_create_data_equal_split_group_ctx.item_id, item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id, paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
created_by_user_id=basic_user_model.id,
paid_by=basic_user_model, # Assuming paid_by relation is loaded paid_by=basic_user_model, # Assuming paid_by relation is loaded
created_by_user=basic_user_model, # Assuming created_by_user relation is loaded
# splits would be populated after creation usually # splits would be populated after creation usually
version=1 version=1
) )

View File

@ -60,12 +60,14 @@ def db_settlement_model():
amount=Decimal("10.50"), amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc), settlement_date=datetime.now(timezone.utc),
description="Original settlement", description="Original settlement",
created_by_user_id=1,
version=1, # Initial version version=1, # Initial version
created_at=datetime.now(timezone.utc), created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
payer=UserModel(id=1, name="Payer User"), payer=UserModel(id=1, name="Payer User"),
payee=UserModel(id=2, name="Payee User"), payee=UserModel(id=2, name="Payee User"),
group=GroupModel(id=1, name="Test Group") group=GroupModel(id=1, name="Test Group"),
created_by_user=UserModel(id=1, name="Payer User") # Same as payer for simplicity
) )
@pytest.fixture @pytest.fixture

View File

@ -64,9 +64,9 @@ export const API_ENDPOINTS = {
INVITES: { INVITES: {
BASE: '/invites', BASE: '/invites',
BY_ID: (id: string) => `/invites/${id}`, BY_ID: (id: string) => `/invites/${id}`,
ACCEPT: (id: string) => `/invites/${id}/accept`, ACCEPT: (id: string) => `/invites/accept/${id}`,
DECLINE: (id: string) => `/invites/${id}/decline`, DECLINE: (id: string) => `/invites/decline/${id}`,
REVOKE: (id: string) => `/invites/${id}/revoke`, REVOKE: (id: string) => `/invites/revoke/${id}`,
LIST: '/invites', LIST: '/invites',
PENDING: '/invites/pending', PENDING: '/invites/pending',
SENT: '/invites/sent', SENT: '/invites/sent',