Enhance configuration and error handling in the application; add new error messages for OCR and authentication processes. Refactor database session management to include transaction handling, and update models to track user creation for expenses and settlements. Update API endpoints to improve cost-sharing calculations and adjust invite management routes for clarity.

This commit is contained in:
mohamad 2025-05-17 13:56:17 +02:00
parent c2aa62fa03
commit 5abe7839f1
14 changed files with 294 additions and 82 deletions

View File

@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List
from app.database import get_db
from app.auth import current_active_user
@ -19,7 +20,7 @@ from app.models import (
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
@ -28,6 +29,85 @@ from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotF
logger = logging.getLogger(__name__)
router = APIRouter()
def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
"""
Calculate suggested settlements to balance the finances within a group.
This function takes the current balances of all users and suggests optimal settlements
to minimize the number of transactions needed to settle all debts.
Args:
user_balances: List of UserBalanceDetail objects with their current balances
Returns:
List of SuggestedSettlement objects representing the suggested payments
"""
# Create list of users who owe money (negative balance) and who are owed money (positive balance)
debtors = [] # Users who owe money (negative balance)
creditors = [] # Users who are owed money (positive balance)
# Threshold to consider a balance as zero due to floating point precision
epsilon = Decimal('0.01')
# Sort users into debtors and creditors
for user in user_balances:
# Skip users with zero balance (or very close to zero)
if abs(user.net_balance) < epsilon:
continue
if user.net_balance < Decimal('0'):
# User owes money
debtors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': -user.net_balance # Convert to positive amount
})
else:
# User is owed money
creditors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': user.net_balance
})
# Sort by amount (descending) to handle largest debts first
debtors.sort(key=lambda x: x['amount'], reverse=True)
creditors.sort(key=lambda x: x['amount'], reverse=True)
settlements = []
# Iterate through debtors and match them with creditors
while debtors and creditors:
debtor = debtors[0]
creditor = creditors[0]
# Determine the settlement amount (the smaller of the two amounts)
amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
# Create settlement record
if amount > Decimal('0'):
settlements.append(
SuggestedSettlement(
from_user_id=debtor['user_id'],
from_user_identifier=debtor['user_identifier'],
to_user_id=creditor['user_id'],
to_user_identifier=creditor['user_identifier'],
amount=amount
)
)
# Update balances
debtor['amount'] -= amount
creditor['amount'] -= amount
# Remove users who have settled their debts/credits
if debtor['amount'] < epsilon:
debtors.pop(0)
if creditor['amount'] < epsilon:
creditors.pop(0)
return settlements
@router.get(
"/lists/{list_id}/cost-summary",
response_model=ListCostSummary,
@ -105,7 +185,7 @@ async def get_list_cost_summary(
total_amount=total_amount,
list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=current_user.id # Use current user as payer for now
paid_by_user_id=db_list.creator.id
)
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
@ -137,17 +217,36 @@ async def get_list_cost_summary(
user_balances=[]
)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
# This is the ideal equal share, returned in the summary
equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# Sort users for deterministic remainder distribution
sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id)
user_final_shares = {}
if num_participating_users > 0:
base_share_unrounded = total_list_cost / Decimal(num_participating_users)
# Calculate initial share for each user, rounding down
for user in sorted_participating_users:
user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
# Calculate sum of rounded down shares
sum_of_rounded_shares = sum(user_final_shares.values())
# Calculate remaining pennies to be distributed
remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP))
# Distribute remaining pennies one by one to sorted users
for i in range(remaining_pennies):
user_to_adjust = sorted_participating_users[i % num_participating_users]
user_final_shares[user_to_adjust.id] += Decimal("0.01")
user_balances = []
first_user_processed = False
for user in participating_users:
for user in sorted_participating_users: # Iterate over sorted users
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
# current_user_share is now the precisely calculated share for this user
current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email
@ -167,7 +266,7 @@ async def get_list_cost_summary(
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
equal_share_per_user=equal_share_per_user_for_response, # Use the ideal share for the response field
user_balances=user_balances
)

View File

@ -7,7 +7,7 @@ from google.api_core import exceptions as google_exceptions
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
from app.core.gemini import GeminiOCRService, gemini_initialization_error
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
@ -56,11 +56,8 @@ async def ocr_extract_items(
raise FileTooLargeError()
try:
# Call the Gemini helper function
extracted_items = await extract_items_from_image_gemini(
image_bytes=contents,
mime_type=image_file.content_type
)
# Use the ocr_service instance instead of the standalone function
extracted_items = await ocr_service.extract_items(image_data=contents)
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items)

View File

@ -16,8 +16,7 @@ class Settings(BaseSettings):
# --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users)
SECRET_KEY: str # Must be set via environment variable
# ALGORITHM: str = "HS256" # Handled by FastAPI-Users strategy
# ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # This specific line is commented, the one under Session Settings is used.
# FastAPI-Users handles JWT algorithm internally
# --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
@ -36,6 +35,14 @@ Bread
__Apples__
Organic Bananas
"""
# --- OCR Error Messages ---
OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later."
OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support."
OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing."
OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later."
OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}"
OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB"
OCR_PROCESSING_ERROR: str = "Error processing image: {detail}"
# --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
@ -98,6 +105,14 @@ Organic Bananas
DB_TRANSACTION_ERROR: str = "Database transaction error"
DB_QUERY_ERROR: str = "Database query error"
# --- Auth Error Messages ---
AUTH_INVALID_CREDENTIALS: str = "Invalid username or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
AUTH_JWT_ERROR: str = "JWT token error: {error}"
AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}"
AUTH_HEADER_NAME: str = "WWW-Authenticate"
AUTH_HEADER_PREFIX: str = "Bearer"
# OAuth Settings
GOOGLE_CLIENT_ID: str = ""
GOOGLE_CLIENT_SECRET: str = ""

View File

@ -295,7 +295,7 @@ class JWTError(HTTPException):
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_ERROR.format(error=error),
detail=settings.AUTH_JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
@ -304,7 +304,7 @@ class JWTUnexpectedError(HTTPException):
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)

View File

@ -9,7 +9,8 @@ from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
OCRQuotaExceededError,
OCRProcessingError
)
logger = logging.getLogger(__name__)
@ -55,10 +56,10 @@ def get_gemini_client():
Raises an exception if initialization failed.
"""
if gemini_initialization_error:
raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}")
raise OCRServiceConfigError()
if gemini_flash_client is None:
# This case should ideally be covered by the check above, but as a safeguard:
raise RuntimeError("Gemini client is not available (unknown initialization issue).")
raise OCRServiceConfigError()
return gemini_flash_client
# Define the prompt as a constant
@ -88,11 +89,14 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
A list of extracted item strings.
Raises:
RuntimeError: If the Gemini client is not initialized.
google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
ValueError: If the response is blocked or contains no usable text.
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
"""
client = get_gemini_client() # Raises RuntimeError if not initialized
try:
client = get_gemini_client() # Raises OCRServiceConfigError if not initialized
# Prepare image part for multimodal input
image_part = {
@ -107,7 +111,7 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
]
logger.info("Sending image to Gemini for item extraction...")
try:
# Make the API call
# Use generate_content_async for async FastAPI
response = await client.generate_content_async(prompt_parts)
@ -120,9 +124,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
if finish_reason == 'SAFETY':
raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
else:
raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")
raise OCRUnexpectedError()
# Extract text - assumes the first part of the first candidate is the text response
raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text
@ -143,32 +147,59 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
except google_exceptions.GoogleAPIError as e:
logger.error(f"Gemini API Error: {e}", exc_info=True)
# Re-raise specific Google API errors for endpoint to handle (e.g., quota)
raise e
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
# Catch other unexpected errors during generation or processing
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a generic ValueError or re-raise
raise ValueError(f"Failed to process image with Gemini: {e}") from e
# Wrap in a custom exception
raise OCRUnexpectedError()
class GeminiOCRService:
def __init__(self):
try:
genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
self.model = genai.GenerativeModel(
model_name=settings.GEMINI_MODEL_NAME,
# Safety settings from config
safety_settings={
getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
},
# Generation config from settings
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes) -> List[str]:
async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
"""
Extract shopping list items from an image using Gemini Vision.
Args:
image_data: The image content as bytes.
mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
Returns:
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
"""
try:
# Create image part
image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
image_parts = [{"mime_type": mime_type, "data": image_data}]
# Generate content
response = await self.model.generate_content_async(
@ -177,19 +208,34 @@ class GeminiOCRService:
# Process response
if not response.text:
logger.warning("Gemini response is empty")
raise OCRUnexpectedError()
# Split response into lines and clean up
items = [
item.strip()
for item in response.text.split("\n")
if item.strip() and not item.strip().startswith("Example")
]
# Check for safety blocks
if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'):
finish_reason = response.candidates[0].finish_reason
if finish_reason == 'SAFETY':
safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A'
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
# Split response into lines and clean up
items = []
for line in response.text.splitlines():
cleaned_line = line.strip()
if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"):
items.append(cleaned_line)
logger.info(f"Extracted {len(items)} potential items.")
return items
except Exception as e:
except google_exceptions.GoogleAPIError as e:
logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
raise OCRUnexpectedError()

View File

@ -8,6 +8,9 @@ from passlib.context import CryptContext
from app.config import settings # Import settings from config
# --- Password Hashing ---
# These functions are used for password hashing and verification
# They complement FastAPI-Users but provide direct access to the underlying password functionality
# when needed outside of the FastAPI-Users authentication flow.
# Configure passlib context
# Using bcrypt as the default hashing scheme
@ -17,6 +20,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verifies a plain text password against a hashed password.
This is used by FastAPI-Users internally, but also exposed here for custom authentication flows
if needed.
Args:
plain_password: The password attempt.
@ -34,6 +39,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def hash_password(password: str) -> str:
"""
Hashes a plain text password using the configured context (bcrypt).
This is used by FastAPI-Users internally, but also exposed here for
custom user creation or password reset flows if needed.
Args:
password: The plain text password to hash.
@ -45,14 +52,22 @@ def hash_password(password: str) -> str:
# --- JSON Web Tokens (JWT) ---
# FastAPI-Users now handles all tokenization.
# FastAPI-Users now handles all JWT token creation and validation.
# The code below is commented out because FastAPI-Users provides these features.
# It's kept for reference in case a custom implementation is needed later.
# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# Example of a potential future implementation:
# def get_subject_from_token(token: str) -> Optional[str]:
# """
# Extract the subject (user ID) from a JWT token.
# This would be used if we need to validate tokens outside of FastAPI-Users flow.
# For now, use fastapi_users.current_user dependency instead.
# """
# # This would need to use FastAPI-Users' token verification if ever implemented
# # For example, by decoding the token using the strategy from the auth backend
# payload = {} # Placeholder for actual token decoding logic
# if payload:
# try:
# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# return payload.get("sub")
# except JWTError:
# return None
# return None

View File

@ -59,8 +59,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
description=settlement_in.description,
created_by_user_id=current_user_id
)
db.add(db_settlement)
await db.flush()
@ -72,7 +72,8 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)
@ -103,7 +104,8 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
.where(SettlementModel.id == settlement_id)
)
@ -123,7 +125,8 @@ async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int =
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
return result.scalars().all()
@ -149,7 +152,8 @@ async def get_settlements_involving_user(
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
if group_id:
@ -216,7 +220,8 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)

View File

@ -2,6 +2,9 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.config import settings
import logging
logger = logging.getLogger(__name__)
# Ensure DATABASE_URL is set before proceeding
if not settings.DATABASE_URL:
@ -42,3 +45,22 @@ async def get_async_session() -> AsyncSession: # type: ignore
# Alias for backward compatibility
get_db = get_async_session
async def get_transactional_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession wrapped in a transaction.
Commits on successful completion of the request handler, rolls back on exceptions.
"""
async with AsyncSessionLocal() as session:
async with session.begin(): # Start a transaction
try:
logger.debug(f"Transaction started for session {id(session)}")
yield session
# If no exceptions were raised by the endpoint, the 'session.begin()'
# context manager will automatically commit here.
logger.debug(f"Transaction committed for session {id(session)}")
except Exception as e:
# The 'session.begin()' context manager will automatically
# rollback on any exception.
logger.error(f"Transaction rolled back for session {id(session)} due to: {e}", exc_info=True)
raise # Re-raise the exception to be handled by FastAPI's error handlers

View File

@ -65,9 +65,11 @@ class User(Base):
# --- Relationships for Cost Splitting ---
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
@ -197,6 +199,7 @@ class Expense(Base):
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True)
item_id = Column(Integer, ForeignKey("items.id"), nullable=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -204,6 +207,7 @@ class Expense(Base):
# Relationships
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created")
list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
@ -246,6 +250,7 @@ class Settlement(Base):
amount = Column(Numeric(10, 2), nullable=False)
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
description = Column(Text, nullable=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -255,6 +260,7 @@ class Settlement(Base):
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created")
__table_args__ = (
# Ensure payer and payee are different users

View File

@ -79,6 +79,7 @@ class ExpensePublic(ExpenseBase):
created_at: datetime
updated_at: datetime
version: int
created_by_user_id: int
splits: List[ExpenseSplitPublic] = []
# paid_by_user: Optional[UserPublic] # If nesting user details
# list: Optional[ListPublic] # If nesting list details
@ -119,9 +120,11 @@ class SettlementPublic(SettlementBase):
id: int
created_at: datetime
updated_at: datetime
# payer: Optional[UserPublic]
# payee: Optional[UserPublic]
# group: Optional[GroupPublic]
version: int
created_by_user_id: int
# payer: Optional[UserPublic] # If we want to include payer details
# payee: Optional[UserPublic] # If we want to include payee details
# group: Optional[GroupPublic] # If we want to include group details
model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed

View File

@ -122,7 +122,9 @@ def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model
group_id=expense_create_data_equal_split_group_ctx.group_id,
item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
created_by_user_id=basic_user_model.id,
paid_by=basic_user_model, # Assuming paid_by relation is loaded
created_by_user=basic_user_model, # Assuming created_by_user relation is loaded
# splits would be populated after creation usually
version=1
)

View File

@ -60,12 +60,14 @@ def db_settlement_model():
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Original settlement",
created_by_user_id=1,
version=1, # Initial version
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
payer=UserModel(id=1, name="Payer User"),
payee=UserModel(id=2, name="Payee User"),
group=GroupModel(id=1, name="Test Group")
group=GroupModel(id=1, name="Test Group"),
created_by_user=UserModel(id=1, name="Payer User") # Same as payer for simplicity
)
@pytest.fixture

View File

@ -64,9 +64,9 @@ export const API_ENDPOINTS = {
INVITES: {
BASE: '/invites',
BY_ID: (id: string) => `/invites/${id}`,
ACCEPT: (id: string) => `/invites/${id}/accept`,
DECLINE: (id: string) => `/invites/${id}/decline`,
REVOKE: (id: string) => `/invites/${id}/revoke`,
ACCEPT: (id: string) => `/invites/accept/${id}`,
DECLINE: (id: string) => `/invites/decline/${id}`,
REVOKE: (id: string) => `/invites/revoke/${id}`,
LIST: '/invites',
PENDING: '/invites/pending',
SENT: '/invites/sent',