Compare commits

..

No commits in common. "eb19230b22f57fe861befb4de7dcc75cdd45a795" and "515534dcce70a870a6f700377019cf30dd1a9a8c" have entirely different histories.

55 changed files with 2421 additions and 5714 deletions

View File

@ -1,57 +0,0 @@
---
description: FastAPI Database Transactions
globs:
alwaysApply: false
---
## FastAPI Database Transaction Management: Technical Specification
**Objective:** Ensure atomic, consistent, isolated, and durable (ACID) database operations through a standardized transaction management strategy.
**1. API Endpoint Transaction Scope (Primary Strategy):**
* **Mechanism:** A FastAPI dependency `get_transactional_session` (from `app.database` or `app.core.dependencies`) wraps database-modifying API request handlers.
* **Behavior:**
* `async with AsyncSessionLocal() as session:` obtains a session.
* `async with session.begin():` starts a transaction.
* **Commit:** Automatic on successful completion of the `yield session` block (i.e., endpoint handler success).
* **Rollback:** Automatic on any exception raised from the `yield session` block.
* **Usage:** Endpoints performing CUD (Create, Update, Delete) operations **MUST** use `db: AsyncSession = Depends(get_transactional_session)`.
* **Read-Only Endpoints:** May use `get_async_session` (alias `get_db`) or `get_transactional_session` (results in an empty transaction).
**2. CRUD Layer Function Design:**
* **Transaction Participation:** CRUD functions (in `app/crud/`) operate on the session provided by the caller.
* **Composability Pattern:** Employ `async with db.begin_nested() if db.in_transaction() else db.begin():` to wrap database modification logic within the CRUD function.
* If an outer transaction exists (e.g., from `get_transactional_session`), `begin_nested()` creates a **savepoint**. The `async with` block commits/rolls back this savepoint.
* If no outer transaction exists (e.g., direct call from a script), `begin()` starts a **new transaction**. The `async with` block commits/rolls back this transaction.
* **NO Direct `db.commit()` / `db.rollback()`:** CRUD functions **MUST NOT** call these directly. The `async with begin_nested()/begin()` block and the outermost transaction manager are responsible.
* **`await db.flush()`:** Use only when necessary within the `async with` block to:
1. Obtain auto-generated IDs for subsequent operations in the *same* transaction.
2. Force database constraint checks mid-transaction.
* **Error Handling:** Raise specific custom exceptions (e.g., `ListNotFoundError`, `DatabaseIntegrityError`). These exceptions will trigger rollbacks in the managing transaction contexts.
**3. Non-API Operations (Background Tasks, Scripts):**
* **Explicit Management:** These contexts **MUST** manage their own session and transaction lifecycles.
* **Pattern:**
```python
async with AsyncSessionLocal() as session:
async with session.begin(): # Manages transaction for the task's scope
try:
# Call CRUD functions, which will participate via savepoints
await crud_operation_1(db=session, ...)
await crud_operation_2(db=session, ...)
# Commit is handled by session.begin() context manager on success
except Exception:
# Rollback is handled by session.begin() context manager on error
raise
```
**4. Key Principles Summary:**
* **API:** `get_transactional_session` for CUD.
* **CRUD:** Use `async with db.begin_nested() if db.in_transaction() else db.begin():`. No direct commit/rollback. Use `flush()` strategically.
* **Background Tasks:** Explicit `AsyncSessionLocal()` and `session.begin()` context managers.
This strategy ensures a clear separation of concerns, promotes composable CRUD operations, and centralizes final transaction control at the appropriate layer.

View File

@ -1,42 +0,0 @@
"""Initial database schema
Revision ID: 5271d18372e5
Revises: 5e8b6dde50fc
Create Date: 2025-05-17 14:39:03.690180
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '5271d18372e5'
down_revision: Union[str, None] = '5e8b6dde50fc'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('expenses', sa.Column('created_by_user_id', sa.Integer(), nullable=False))
op.create_index(op.f('ix_expenses_created_by_user_id'), 'expenses', ['created_by_user_id'], unique=False)
op.create_foreign_key(None, 'expenses', 'users', ['created_by_user_id'], ['id'])
op.add_column('settlements', sa.Column('created_by_user_id', sa.Integer(), nullable=False))
op.create_index(op.f('ix_settlements_created_by_user_id'), 'settlements', ['created_by_user_id'], unique=False)
op.create_foreign_key(None, 'settlements', 'users', ['created_by_user_id'], ['id'])
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'settlements', type_='foreignkey')
op.drop_index(op.f('ix_settlements_created_by_user_id'), table_name='settlements')
op.drop_column('settlements', 'created_by_user_id')
op.drop_constraint(None, 'expenses', type_='foreignkey')
op.drop_index(op.f('ix_expenses_created_by_user_id'), table_name='expenses')
op.drop_column('expenses', 'created_by_user_id')
# ### end Alembic commands ###

View File

@ -6,8 +6,6 @@ Create Date: 2025-05-13 23:30:02.005611
"""
from typing import Sequence, Union
import secrets
from passlib.context import CryptContext
from alembic import op
import sqlalchemy as sa
@ -22,21 +20,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# Create password hasher
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Generate a secure random password and hash it
random_password = secrets.token_urlsafe(32) # 32 bytes of randomness
secure_hash = pwd_context.hash(random_password)
# 1. Add columns as nullable or with a default
op.add_column('users', sa.Column('hashed_password', sa.String(), nullable=True))
op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=True, server_default=sa.sql.expression.true()))
op.add_column('users', sa.Column('is_superuser', sa.Boolean(), nullable=True, server_default=sa.sql.expression.false()))
op.add_column('users', sa.Column('is_verified', sa.Boolean(), nullable=True, server_default=sa.sql.expression.false()))
# 2. Set default values for existing rows with secure hash
op.execute(f"UPDATE users SET hashed_password = '{secure_hash}' WHERE hashed_password IS NULL")
# 2. Set default values for existing rows
op.execute("UPDATE users SET hashed_password = '' WHERE hashed_password IS NULL")
op.execute("UPDATE users SET is_active = true WHERE is_active IS NULL")
op.execute("UPDATE users SET is_superuser = false WHERE is_superuser IS NULL")
op.execute("UPDATE users SET is_verified = false WHERE is_verified IS NULL")

View File

@ -1,32 +0,0 @@
"""Initial database schema
Revision ID: 5ed3ccbf05f7
Revises: 5271d18372e5
Create Date: 2025-05-17 14:40:52.165607
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '5ed3ccbf05f7'
down_revision: Union[str, None] = '5271d18372e5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -1,32 +0,0 @@
"""check_models_alignment
Revision ID: 8efbdc779a76
Revises: 5ed3ccbf05f7
Create Date: 2025-05-17 15:03:08.242908
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '8efbdc779a76'
down_revision: Union[str, None] = '5ed3ccbf05f7'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -2,9 +2,9 @@ from fastapi import APIRouter, Depends, Request
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_transactional_session
from app.database import get_async_session
from app.models import User
from app.auth import oauth, fastapi_users, auth_backend
from app.auth import oauth, fastapi_users
from app.config import settings
router = APIRouter()
@ -14,7 +14,7 @@ async def google_login(request: Request):
return await oauth.google.authorize_redirect(request, settings.GOOGLE_REDIRECT_URI)
@router.get('/google/callback')
async def google_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
async def google_callback(request: Request, db: AsyncSession = Depends(get_async_session)):
token_data = await oauth.google.authorize_access_token(request)
user_info = await oauth.google.parse_id_token(request, token_data)
@ -31,28 +31,25 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans
is_active=True
)
db.add(new_user)
await db.flush() # Use flush instead of commit since we're in a transaction
await db.commit()
await db.refresh(new_user)
user_to_login = new_user
# Generate JWT token
strategy = auth_backend.get_strategy()
token_response = await strategy.write_token(user_to_login)
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety, though it should be there
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
strategy = fastapi_users._auth_backends[0].get_strategy()
token = await strategy.write_token(user_to_login)
return RedirectResponse(url=redirect_url)
# Redirect to frontend with token
return RedirectResponse(
url=f"{settings.FRONTEND_URL}/auth/callback?token={token}"
)
@router.get('/apple/login')
async def apple_login(request: Request):
return await oauth.apple.authorize_redirect(request, settings.APPLE_REDIRECT_URI)
@router.get('/apple/callback')
async def apple_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
async def apple_callback(request: Request, db: AsyncSession = Depends(get_async_session)):
token_data = await oauth.apple.authorize_access_token(request)
user_info = token_data.get('user', await oauth.apple.userinfo(token=token_data) if hasattr(oauth.apple, 'userinfo') else {})
if 'email' not in user_info and 'sub' in token_data:
@ -80,18 +77,15 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa
is_active=True
)
db.add(new_user)
await db.flush() # Use flush instead of commit since we're in a transaction
await db.commit()
await db.refresh(new_user)
user_to_login = new_user
# Generate JWT token
strategy = auth_backend.get_strategy()
token_response = await strategy.write_token(user_to_login)
access_token = token_response["access_token"]
refresh_token = token_response.get("refresh_token") # Use .get for safety
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}"
if refresh_token:
redirect_url += f"&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url)
strategy = fastapi_users._auth_backends[0].get_strategy()
token = await strategy.write_token(user_to_login)
# Redirect to frontend with token
return RedirectResponse(
url=f"{settings.FRONTEND_URL}/auth/callback?token={token}"
)

View File

@ -4,10 +4,9 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List
from decimal import Decimal, ROUND_HALF_UP
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
from app.models import (
User as UserModel,
@ -20,7 +19,7 @@ from app.models import (
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
@ -29,85 +28,6 @@ from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotF
logger = logging.getLogger(__name__)
router = APIRouter()
def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
"""
Calculate suggested settlements to balance the finances within a group.
This function takes the current balances of all users and suggests optimal settlements
to minimize the number of transactions needed to settle all debts.
Args:
user_balances: List of UserBalanceDetail objects with their current balances
Returns:
List of SuggestedSettlement objects representing the suggested payments
"""
# Create list of users who owe money (negative balance) and who are owed money (positive balance)
debtors = [] # Users who owe money (negative balance)
creditors = [] # Users who are owed money (positive balance)
# Threshold to consider a balance as zero due to floating point precision
epsilon = Decimal('0.01')
# Sort users into debtors and creditors
for user in user_balances:
# Skip users with zero balance (or very close to zero)
if abs(user.net_balance) < epsilon:
continue
if user.net_balance < Decimal('0'):
# User owes money
debtors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': -user.net_balance # Convert to positive amount
})
else:
# User is owed money
creditors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': user.net_balance
})
# Sort by amount (descending) to handle largest debts first
debtors.sort(key=lambda x: x['amount'], reverse=True)
creditors.sort(key=lambda x: x['amount'], reverse=True)
settlements = []
# Iterate through debtors and match them with creditors
while debtors and creditors:
debtor = debtors[0]
creditor = creditors[0]
# Determine the settlement amount (the smaller of the two amounts)
amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
# Create settlement record
if amount > Decimal('0'):
settlements.append(
SuggestedSettlement(
from_user_id=debtor['user_id'],
from_user_identifier=debtor['user_identifier'],
to_user_id=creditor['user_id'],
to_user_identifier=creditor['user_identifier'],
amount=amount
)
)
# Update balances
debtor['amount'] -= amount
creditor['amount'] -= amount
# Remove users who have settled their debts/credits
if debtor['amount'] < epsilon:
debtors.pop(0)
if creditor['amount'] < epsilon:
creditors.pop(0)
return settlements
@router.get(
"/lists/{list_id}/cost-summary",
response_model=ListCostSummary,
@ -120,7 +40,7 @@ def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> L
)
async def get_list_cost_summary(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -185,7 +105,7 @@ async def get_list_cost_summary(
total_amount=total_amount,
list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=db_list.creator.id
paid_by_user_id=current_user.id # Use current user as payer for now
)
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
@ -217,36 +137,17 @@ async def get_list_cost_summary(
user_balances=[]
)
# This is the ideal equal share, returned in the summary
equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# Sort users for deterministic remainder distribution
sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id)
user_final_shares = {}
if num_participating_users > 0:
base_share_unrounded = total_list_cost / Decimal(num_participating_users)
# Calculate initial share for each user, rounding down
for user in sorted_participating_users:
user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
# Calculate sum of rounded down shares
sum_of_rounded_shares = sum(user_final_shares.values())
# Calculate remaining pennies to be distributed
remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP))
# Distribute remaining pennies one by one to sorted users
for i in range(remaining_pennies):
user_to_adjust = sorted_participating_users[i % num_participating_users]
user_final_shares[user_to_adjust.id] += Decimal("0.01")
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
user_balances = []
for user in sorted_participating_users: # Iterate over sorted users
first_user_processed = False
for user in participating_users:
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
# current_user_share is now the precisely calculated share for this user
current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email
@ -266,7 +167,7 @@ async def get_list_cost_summary(
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user_for_response, # Use the ideal share for the response field
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
@ -282,7 +183,7 @@ async def get_list_cost_summary(
)
async def get_group_balance_summary(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import List as PyList, Optional, Sequence
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum
from app.schemas.expense import (
@ -46,7 +46,7 @@ async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_
)
async def create_new_expense(
expense_in: ExpenseCreate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
@ -109,7 +109,7 @@ async def create_new_expense(
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
expense_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
@ -130,7 +130,7 @@ async def list_list_expenses(
list_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
@ -143,7 +143,7 @@ async def list_group_expenses(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
@ -155,7 +155,7 @@ async def list_group_expenses(
async def update_expense_details(
expense_id: int,
expense_in: ExpenseUpdate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -209,7 +209,7 @@ async def update_expense_details(
async def delete_expense_record(
expense_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -273,7 +273,7 @@ async def delete_expense_record(
)
async def create_new_settlement(
settlement_in: SettlementCreate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
@ -299,7 +299,7 @@ async def create_new_settlement(
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
settlement_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
@ -321,7 +321,7 @@ async def list_group_settlements(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
@ -333,7 +333,7 @@ async def list_group_settlements(
async def update_settlement_details(
settlement_id: int,
settlement_in: SettlementUpdate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -387,7 +387,7 @@ async def update_settlement_details(
async def delete_settlement_record(
settlement_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""

View File

@ -5,13 +5,13 @@ from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum # Import model and enum
from app.schemas.group import GroupCreate, GroupPublic
from app.schemas.invite import InviteCodePublic
from app.schemas.message import Message # For simple responses
from app.schemas.list import ListPublic, ListDetail
from app.schemas.list import ListPublic
from app.crud import group as crud_group
from app.crud import invite as crud_invite
from app.crud import list as crud_list
@ -36,7 +36,7 @@ router = APIRouter()
)
async def create_group(
group_in: GroupCreate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new group, adding the creator as the owner."""
@ -54,7 +54,7 @@ async def create_group(
tags=["Groups"]
)
async def read_user_groups(
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all groups the current user is a member of."""
@ -71,7 +71,7 @@ async def read_user_groups(
)
async def read_group(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves details for a specific group, including members, if the user is part of it."""
@ -98,7 +98,7 @@ async def read_group(
)
async def create_group_invite(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Generates a new invite code for the group. Requires owner/admin role (MVP: owner only)."""
@ -118,49 +118,11 @@ async def create_group_invite(
invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id)
if not invite:
logger.error(f"Failed to generate unique invite code for group {group_id}")
# This case should ideally be covered by exceptions from create_invite now
raise InviteCreationError(group_id)
logger.info(f"User {current_user.email} created invite code for group {group_id}")
return invite
@router.get(
"/{group_id}/invites",
response_model=InviteCodePublic, # Or Optional[InviteCodePublic] if it can be null
summary="Get Group Active Invite Code",
tags=["Groups", "Invites"]
)
async def get_group_active_invite(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves the active invite code for the group. Requires group membership (owner/admin to be stricter later if needed)."""
logger.info(f"User {current_user.email} attempting to get active invite for group {group_id}")
# Permission check: Ensure user is a member of the group to view invite code
# Using get_user_role_in_group which also checks membership indirectly
user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)
if user_role is None: # Not a member
logger.warning(f"Permission denied: User {current_user.email} is not a member of group {group_id} and cannot view invite code.")
# More specific error or let GroupPermissionError handle if we want to be generic
raise GroupMembershipError(group_id, "view invite code for this group (not a member)")
# Fetch the active invite for the group
invite = await crud_invite.get_active_invite_for_group(db, group_id=group_id)
if not invite:
# This case means no active (non-expired, active=true) invite exists.
# The frontend can then prompt to generate one.
logger.info(f"No active invite code found for group {group_id} when requested by {current_user.email}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active invite code found for this group. Please generate one."
)
logger.info(f"User {current_user.email} retrieved active invite code for group {group_id}")
return invite # Pydantic will convert InviteModel to InviteCodePublic
@router.delete(
"/{group_id}/leave",
response_model=Message,
@ -169,7 +131,7 @@ async def get_group_active_invite(
)
async def leave_group(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Removes the current user from the specified group."""
@ -208,7 +170,7 @@ async def leave_group(
async def remove_group_member(
group_id: int,
user_id_to_remove: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Removes a specified user from the group. Requires current user to be owner."""
@ -241,13 +203,13 @@ async def remove_group_member(
@router.get(
"/{group_id}/lists",
response_model=List[ListDetail],
response_model=List[ListPublic],
summary="Get Group Lists",
tags=["Groups", "Lists"]
)
async def read_group_lists(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all lists belonging to a specific group, if the user is a member."""

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from app.database import get_transactional_session
from app.database import get_async_session
from app.schemas.health import HealthStatus
from app.core.exceptions import DatabaseConnectionError
@ -18,7 +18,7 @@ router = APIRouter()
description="Checks the operational status of the API and its connection to the database.",
tags=["Health"]
)
async def check_health(db: AsyncSession = Depends(get_transactional_session)):
async def check_health(db: AsyncSession = Depends(get_async_session)):
"""
Health check endpoint. Verifies API reachability and database connection.
"""

View File

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum
from app.schemas.invite import InviteAccept
@ -30,7 +30,7 @@ router = APIRouter()
)
async def accept_invite(
invite_in: InviteAccept,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Accepts a group invite using the provided invite code."""

View File

@ -5,7 +5,7 @@ from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
# --- Import Models Correctly ---
from app.models import User as UserModel
@ -23,7 +23,7 @@ router = APIRouter()
# Now ItemModel is defined before being used as a type hint
async def get_item_and_verify_access(
item_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user)
) -> ItemModel:
"""Dependency to get an item and verify the user has access to its list."""
@ -52,7 +52,7 @@ async def get_item_and_verify_access(
async def create_list_item(
list_id: int,
item_in: ItemCreate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""Adds a new item to a specific list. User must have access to the list."""
@ -80,7 +80,7 @@ async def create_list_item(
)
async def read_list_items(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
# Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
):
@ -99,7 +99,7 @@ async def read_list_items(
@router.put(
"/lists/{list_id}/items/{item_id}", # Nested under lists
"/items/{item_id}", # Operate directly on item ID
response_model=ItemPublic,
summary="Update Item",
tags=["Items"],
@ -108,11 +108,10 @@ async def read_list_items(
}
)
async def update_item(
list_id: int,
item_id: int,
item_id: int, # Item ID from path
item_in: ItemUpdate,
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user), # Need user ID for completed_by
):
"""
@ -141,7 +140,7 @@ async def update_item(
@router.delete(
"/lists/{list_id}/items/{item_id}", # Nested under lists
"/items/{item_id}", # Operate directly on item ID
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Item",
tags=["Items"],
@ -150,11 +149,10 @@ async def update_item(
}
)
async def delete_item(
list_id: int,
item_id: int,
item_id: int, # Item ID from path
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user), # Log who deleted it
):
"""

View File

@ -5,7 +5,7 @@ from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session
from app.database import get_db
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail
@ -40,7 +40,7 @@ router = APIRouter()
)
async def create_list(
list_in: ListCreate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -86,12 +86,12 @@ async def create_list(
@router.get(
"", # Route relative to prefix "/lists"
response_model=PyList[ListDetail], # Return a list of detailed list info including items
response_model=PyList[ListPublic], # Return a list of basic list info
summary="List Accessible Lists",
tags=["Lists"]
)
async def read_lists(
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
# Add pagination parameters later if needed: skip: int = 0, limit: int = 100
):
@ -113,7 +113,7 @@ async def read_lists(
)
async def read_list(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -138,7 +138,7 @@ async def read_list(
async def update_list(
list_id: int,
list_in: ListUpdate,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -176,7 +176,7 @@ async def update_list(
async def delete_list(
list_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -211,7 +211,7 @@ async def delete_list(
)
async def read_list_status(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(current_active_user),
):
"""

View File

@ -7,7 +7,7 @@ from google.api_core import exceptions as google_exceptions
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import GeminiOCRService, gemini_initialization_error
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
@ -56,8 +56,11 @@ async def ocr_extract_items(
raise FileTooLargeError()
try:
# Use the ocr_service instance instead of the standalone function
extracted_items = await ocr_service.extract_items(image_data=contents)
# Call the Gemini helper function
extracted_items = await extract_items_from_image_gemini(
image_bytes=contents,
mime_type=image_file.content_type
)
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items)

View File

@ -3,7 +3,7 @@ import pytest
from httpx import AsyncClient
from app.schemas.user import UserPublic # For response validation
# from app.core.security import create_access_token # Commented out as FastAPI-Users handles token creation
from app.core.security import create_access_token
pytestmark = pytest.mark.asyncio
@ -51,15 +51,15 @@ async def test_read_users_me_invalid_token(client: AsyncClient):
assert response.status_code == 401
assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency
# async def test_read_users_me_expired_token(client: AsyncClient):
# # Create a short-lived token manually (or adjust settings temporarily)
# email = "testexpired@example.com"
# # Assume create_access_token allows timedelta override
# # expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
# # headers = {"Authorization": f"Bearer {expired_token}"}
async def test_read_users_me_expired_token(client: AsyncClient):
# Create a short-lived token manually (or adjust settings temporarily)
email = "testexpired@example.com"
# Assume create_access_token allows timedelta override
expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
headers = {"Authorization": f"Bearer {expired_token}"}
# # response = await client.get("/api/v1/users/me", headers=headers)
# # assert response.status_code == 401
# # assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency
response = await client.get("/api/v1/users/me", headers=headers)
assert response.status_code == 401
assert response.json()["detail"] == "Could not validate credentials"
# Add test case for valid token but user deleted from DB if needed

View File

@ -15,7 +15,7 @@ from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response
from .database import get_session
from .database import get_async_session
from .models import User
from .config import settings
@ -65,7 +65,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
):
print(f"User {user.id} has logged in.")
async def get_user_db(session: AsyncSession = Depends(get_session)):
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User)
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):

View File

@ -16,8 +16,8 @@ class Settings(BaseSettings):
# --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users)
SECRET_KEY: str # Must be set via environment variable
TOKEN_TYPE: str = "bearer" # Default token type for JWT authentication
# FastAPI-Users handles JWT algorithm internally
# ALGORITHM: str = "HS256" # Handled by FastAPI-Users strategy
# ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # This specific line is commented, the one under Session Settings is used.
# --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
@ -36,14 +36,6 @@ Bread
__Apples__
Organic Bananas
"""
# --- OCR Error Messages ---
OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later."
OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support."
OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing."
OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later."
OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}"
OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB"
OCR_PROCESSING_ERROR: str = "Error processing image: {detail}"
# --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
@ -106,14 +98,6 @@ Organic Bananas
DB_TRANSACTION_ERROR: str = "Database transaction error"
DB_QUERY_ERROR: str = "Database query error"
# --- Auth Error Messages ---
AUTH_INVALID_CREDENTIALS: str = "Invalid username or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
AUTH_JWT_ERROR: str = "JWT token error: {error}"
AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}"
AUTH_HEADER_NAME: str = "WWW-Authenticate"
AUTH_HEADER_PREFIX: str = "Bearer"
# OAuth Settings
GOOGLE_CLIENT_ID: str = ""
GOOGLE_CLIENT_SECRET: str = ""

View File

@ -128,14 +128,6 @@ class DatabaseQueryError(HTTPException):
detail=detail
)
class ExpenseOperationError(HTTPException):
"""Raised when an expense operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class OCRServiceUnavailableError(HTTPException):
"""Raised when the OCR service is unavailable."""
def __init__(self):
@ -248,22 +240,6 @@ class ListStatusNotFoundError(HTTPException):
detail=f"Status for list {list_id} not found"
)
class InviteOperationError(HTTPException):
"""Raised when an invite operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class SettlementOperationError(HTTPException):
"""Raised when a settlement operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ConflictError(HTTPException):
"""Raised when an optimistic lock version conflict occurs."""
def __init__(self, detail: str):
@ -295,7 +271,7 @@ class JWTError(HTTPException):
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_JWT_ERROR.format(error=error),
detail=settings.JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
@ -304,30 +280,6 @@ class JWTUnexpectedError(HTTPException):
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error),
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
class ListOperationError(HTTPException):
"""Raised when a list operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ItemOperationError(HTTPException):
"""Raised when an item operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class UserOperationError(HTTPException):
"""Raised when a user operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)

View File

@ -9,8 +9,7 @@ from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
OCRProcessingError
OCRQuotaExceededError
)
logger = logging.getLogger(__name__)
@ -26,6 +25,12 @@ try:
# Initialize the specific model we want to use
gemini_flash_client = genai.GenerativeModel(
model_name=settings.GEMINI_MODEL_NAME,
# Safety settings from config
safety_settings={
getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
},
# Generation config from settings
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
@ -50,10 +55,10 @@ def get_gemini_client():
Raises an exception if initialization failed.
"""
if gemini_initialization_error:
raise OCRServiceConfigError()
raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}")
if gemini_flash_client is None:
# This case should ideally be covered by the check above, but as a safeguard:
raise OCRServiceConfigError()
raise RuntimeError("Gemini client is not available (unknown initialization issue).")
return gemini_flash_client
# Define the prompt as a constant
@ -83,29 +88,26 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
RuntimeError: If the Gemini client is not initialized.
google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
ValueError: If the response is blocked or contains no usable text.
"""
client = get_gemini_client() # Raises RuntimeError if not initialized
# Prepare image part for multimodal input
image_part = {
"mime_type": mime_type,
"data": image_bytes
}
# Prepare the full prompt content
prompt_parts = [
settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
image_part # Then the image
]
logger.info("Sending image to Gemini for item extraction...")
try:
client = get_gemini_client() # Raises OCRServiceConfigError if not initialized
# Prepare image part for multimodal input
image_part = {
"mime_type": mime_type,
"data": image_bytes
}
# Prepare the full prompt content
prompt_parts = [
settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
image_part # Then the image
]
logger.info("Sending image to Gemini for item extraction...")
# Make the API call
# Use generate_content_async for async FastAPI
response = await client.generate_content_async(prompt_parts)
@ -118,9 +120,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
if finish_reason == 'SAFETY':
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
else:
raise OCRUnexpectedError()
raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")
# Extract text - assumes the first part of the first candidate is the text response
raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text
@ -141,53 +143,32 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "
except google_exceptions.GoogleAPIError as e:
logger.error(f"Gemini API Error: {e}", exc_info=True)
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
# Re-raise specific Google API errors for endpoint to handle (e.g., quota)
raise e
except Exception as e:
# Catch other unexpected errors during generation or processing
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a custom exception
raise OCRUnexpectedError()
# Wrap in a generic ValueError or re-raise
raise ValueError(f"Failed to process image with Gemini: {e}") from e
class GeminiOCRService:
def __init__(self):
try:
genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(
model_name=settings.GEMINI_MODEL_NAME,
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
async def extract_items(self, image_data: bytes) -> List[str]:
"""
Extract shopping list items from an image using Gemini Vision.
Args:
image_data: The image content as bytes.
mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
Returns:
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
"""
try:
# Create image part
image_parts = [{"mime_type": mime_type, "data": image_data}]
image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
# Generate content
response = await self.model.generate_content_async(
@ -196,34 +177,19 @@ class GeminiOCRService:
# Process response
if not response.text:
logger.warning("Gemini response is empty")
raise OCRUnexpectedError()
# Check for safety blocks
if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'):
finish_reason = response.candidates[0].finish_reason
if finish_reason == 'SAFETY':
safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A'
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
# Split response into lines and clean up
items = []
for line in response.text.splitlines():
cleaned_line = line.strip()
if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"):
items.append(cleaned_line)
items = [
item.strip()
for item in response.text.split("\n")
if item.strip() and not item.strip().startswith("Example")
]
logger.info(f"Extracted {len(items)} potential items.")
return items
except google_exceptions.GoogleAPIError as e:
except Exception as e:
logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
raise OCRUnexpectedError()
raise OCRServiceUnavailableError()

View File

@ -8,9 +8,6 @@ from passlib.context import CryptContext
from app.config import settings # Import settings from config
# --- Password Hashing ---
# These functions are used for password hashing and verification
# They complement FastAPI-Users but provide direct access to the underlying password functionality
# when needed outside of the FastAPI-Users authentication flow.
# Configure passlib context
# Using bcrypt as the default hashing scheme
@ -20,8 +17,6 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verifies a plain text password against a hashed password.
This is used by FastAPI-Users internally, but also exposed here for custom authentication flows
if needed.
Args:
plain_password: The password attempt.
@ -39,8 +34,6 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def hash_password(password: str) -> str:
"""
Hashes a plain text password using the configured context (bcrypt).
This is used by FastAPI-Users internally, but also exposed here for
custom user creation or password reset flows if needed.
Args:
password: The plain text password to hash.
@ -52,22 +45,14 @@ def hash_password(password: str) -> str:
# --- JSON Web Tokens (JWT) ---
# FastAPI-Users now handles all JWT token creation and validation.
# The code below is commented out because FastAPI-Users provides these features.
# It's kept for reference in case a custom implementation is needed later.
# FastAPI-Users now handles all tokenization.
# Example of a potential future implementation:
# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]:
# """
# Extract the subject (user ID) from a JWT token.
# This would be used if we need to validate tokens outside of FastAPI-Users flow.
# For now, use fastapi_users.current_user dependency instead.
# """
# # This would need to use FastAPI-Users' token verification if ever implemented
# # For example, by decoding the token using the strategy from the auth backend
# try:
# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# payload = {} # Placeholder for actual token decoding logic
# if payload:
# return payload.get("sub")
# except JWTError:
# return None
# return None

View File

@ -3,9 +3,8 @@ import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict
from datetime import datetime, timezone # Added timezone
from app.models import (
@ -24,12 +23,7 @@ from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError, # Import the new exception
DatabaseConnectionError, # Added
DatabaseIntegrityError, # Added
DatabaseQueryError, # Added
DatabaseTransactionError,# Added
ExpenseOperationError # Added specific exception
InvalidOperationError # Import the new exception
)
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
@ -114,98 +108,60 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
GroupNotFoundError: If specified group doesn't exist
InvalidOperationError: For various validation failures
"""
# Helper function to round decimals consistently
def round_money(amount: Decimal) -> Decimal:
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# 1. Context Validation
# Validate basic context requirements first
if not expense_in.list_id and not expense_in.group_id:
raise InvalidOperationError("Expense must be associated with a list or a group.")
# 2. User Validation
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
# 3. List/Group Context Resolution
final_group_id = await _resolve_expense_context(db, expense_in)
# 4. Create the expense object
db_expense = ExpenseModel(
description=expense_in.description,
total_amount=round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id,
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id # Track who created this expense
)
# 5. Generate splits based on split type
splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
# 6. Single transaction for expense and all splits
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
# 1. Validate payer
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")
# 2. Context Resolution and Validation (now part of the transaction)
if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
final_group_id = await _resolve_expense_context(db, expense_in)
# Further validation for item_id if provided
db_item_instance = None
if expense_in.item_id:
db_item_instance = await db.get(ItemModel, expense_in.item_id)
if not db_item_instance:
raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
# Potentially link item's list/group if not already set on expense_in
if db_item_instance.list_id and not expense_in.list_id:
expense_in.list_id = db_item_instance.list_id
# Re-resolve context if list_id was derived from item
final_group_id = await _resolve_expense_context(db, expense_in)
# 3. Create the ExpenseModel instance
db_expense = ExpenseModel(
description=expense_in.description,
total_amount=_round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id, # Use resolved group_id
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id
)
db.add(db_expense)
await db.flush() # Get expense ID
# 4. Generate splits (passing current_user_id through kwargs if needed by specific split types)
splits_to_create = await _generate_expense_splits(
db=db,
expense_model=db_expense,
expense_in=expense_in,
current_user_id=current_user_id # Pass for item-based splits needing creator info
)
for split_model in splits_to_create:
split_model.expense_id = db_expense.id # Set FK after db_expense has ID
db.add_all(splits_to_create)
await db.flush() # Persist splits
# 5. Re-fetch the expense with all necessary relationships for the response
stmt = (
select(ExpenseModel)
.where(ExpenseModel.id == db_expense.id)
.options(
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.created_by_user), # If you have this relationship
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group),
selectinload(ExpenseModel.item),
selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user)
)
)
result = await db.execute(stmt)
loaded_expense = result.scalar_one_or_none()
if loaded_expense is None:
# The context manager will handle rollback if an exception is raised.
# await transaction.rollback() # Should be handled by context manager
raise ExpenseOperationError("Failed to load expense after creation.")
# await transaction.commit() # Explicit commit removed, context manager handles it.
return loaded_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are business logic validation errors, re-raise them.
# If a transaction was started, the context manager handles rollback.
raise
except IntegrityError as e:
# Context manager handles rollback.
logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}")
except SQLAlchemyError as e:
# Context manager handles rollback.
logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}")
db.add(db_expense)
await db.flush() # Get expense ID without committing
# Update all splits with the expense ID
for split in splits_to_create:
split.expense_id = db_expense.id
db.add_all(splits_to_create)
await db.commit()
except Exception as e:
await db.rollback()
logger.error(f"Failed to save expense: {str(e)}", exc_info=True)
raise InvalidOperationError(f"Failed to save expense: {str(e)}")
# Refresh to get the splits relationship populated
await db.refresh(db_expense, attribute_names=["splits"])
return db_expense
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
@ -241,32 +197,39 @@ async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate)
async def _generate_expense_splits(
db: AsyncSession,
expense_model: ExpenseModel,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
**kwargs: Any
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Generates appropriate expense splits based on split type."""
splits_to_create: PyList[ExpenseSplitModel] = []
# Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based)
common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs}
# Create splits based on the split type
if expense_in.split_type == SplitTypeEnum.EQUAL:
splits_to_create = await _create_equal_splits(**common_args)
splits_to_create = await _create_equal_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
splits_to_create = await _create_exact_amount_splits(**common_args)
splits_to_create = await _create_exact_amount_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
splits_to_create = await _create_percentage_splits(**common_args)
splits_to_create = await _create_percentage_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.SHARES:
splits_to_create = await _create_shares_splits(**common_args)
splits_to_create = await _create_shares_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
splits_to_create = await _create_item_based_splits(**common_args)
splits_to_create = await _create_item_based_splits(
db, db_expense, expense_in, round_money
)
else:
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
@ -277,24 +240,29 @@ async def _generate_expense_splits(
return splits_to_create
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
async def _create_equal_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates equal splits among users."""
users_for_splitting = await get_users_for_splitting(
db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
)
if not users_for_splitting:
raise InvalidOperationError("No users found for EQUAL split.")
num_users = len(users_for_splitting)
amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
remainder = expense_model.total_amount - (amount_per_user * num_users)
amount_per_user = round_money(db_expense.total_amount / Decimal(num_users))
remainder = db_expense.total_amount - (amount_per_user * num_users)
splits = []
for i, user in enumerate(users_for_splitting):
split_amount = amount_per_user
if i == 0 and remainder != Decimal('0'):
split_amount = round_money_func(amount_per_user + remainder)
split_amount = round_money(amount_per_user + remainder)
splits.append(ExpenseSplitModel(
user_id=user.id,
@ -304,7 +272,12 @@ async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, ex
return splits
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
async def _create_exact_amount_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits with exact amounts."""
if not expense_in.splits_in:
@ -320,7 +293,7 @@ async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseMo
if split_in.owed_amount <= Decimal('0'):
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
rounded_amount = round_money_func(split_in.owed_amount)
rounded_amount = round_money(split_in.owed_amount)
current_total += rounded_amount
splits.append(ExpenseSplitModel(
@ -328,15 +301,20 @@ async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseMo
owed_amount=rounded_amount
))
if round_money_func(current_total) != expense_model.total_amount:
if round_money(current_total) != db_expense.total_amount:
raise InvalidOperationError(
f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})."
)
return splits
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
async def _create_percentage_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on percentages."""
if not expense_in.splits_in:
@ -356,7 +334,7 @@ async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseMode
)
total_percentage += split_in.share_percentage
owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100")))
owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100")))
current_total += owed_amount
splits.append(ExpenseSplitModel(
@ -365,18 +343,23 @@ async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseMode
share_percentage=split_in.share_percentage
))
if round_money_func(total_percentage) != Decimal("100.00"):
if round_money(total_percentage) != Decimal("100.00"):
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
# Adjust for rounding differences
if current_total != expense_model.total_amount and splits:
diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
if current_total != db_expense.total_amount and splits:
diff = db_expense.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
return splits
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
async def _create_shares_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on shares."""
if not expense_in.splits_in:
@ -398,7 +381,7 @@ async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, e
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
owed_amount = round_money_func(expense_model.total_amount * share_ratio)
owed_amount = round_money(db_expense.total_amount * share_ratio)
current_total += owed_amount
splits.append(ExpenseSplitModel(
@ -408,26 +391,31 @@ async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, e
))
# Adjust for rounding differences
if current_total != expense_model.total_amount and splits:
diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
if current_total != db_expense.total_amount and splits:
diff = db_expense.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
return splits
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
async def _create_item_based_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on items in a shopping list."""
if not expense_model.list_id:
if not expense_in.list_id:
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
if expense_in.splits_in:
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
# Build query to fetch relevant items
items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
if expense_model.item_id:
items_query = items_query.where(ItemModel.id == expense_model.item_id)
items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
if expense_in.item_id:
items_query = items_query.where(ItemModel.id == expense_in.item_id)
else:
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
@ -437,9 +425,9 @@ async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseMode
if not relevant_items:
error_msg = (
f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
if expense_model.item_id else
f"List {expense_model.list_id} has no priced items to base the expense on."
f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
if expense_in.item_id else
f"List {expense_in.list_id} has no priced items to base the expense on."
)
raise InvalidOperationError(error_msg)
@ -450,9 +438,9 @@ async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseMode
for item in relevant_items:
if item.price is None or item.price <= Decimal("0"):
if expense_model.item_id:
if expense_in.item_id:
raise InvalidOperationError(
f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
)
continue
@ -466,13 +454,13 @@ async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseMode
if processed_items == 0:
raise InvalidOperationError(
f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
)
# Validate total matches calculated total
if round_money_func(calculated_total) != expense_model.total_amount:
if round_money(calculated_total) != db_expense.total_amount:
raise InvalidOperationError(
f"Expense total amount ({expense_model.total_amount}) does not match the "
f"Expense total amount ({db_expense.total_amount}) does not match the "
f"calculated total from item prices ({calculated_total})."
)
@ -481,7 +469,7 @@ async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseMode
for user_id, owed_amount in user_owed_amounts.items():
splits.append(ExpenseSplitModel(
user_id=user_id,
owed_amount=round_money_func(owed_amount)
owed_amount=round_money(owed_amount)
))
return splits
@ -535,7 +523,7 @@ async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0,
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
"""
Updates an existing expense.
Updates an existing expense.
Only allows updates to description, currency, and expense_date to avoid split complexities.
Requires version matching for optimistic locking.
"""
@ -566,27 +554,18 @@ async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in:
# For now, if only version was sent, we still increment if it matched.
pass # Or raise InvalidOperationError("No updatable fields provided.")
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
# db.add(expense_db) # Not strictly necessary as expense_db is already tracked by the session
await db.flush() # Persist changes to the DB and run constraints
await db.refresh(expense_db) # Refresh the object from the DB
return expense_db
except InvalidOperationError: # Re-raise validation errors to be handled by the caller
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
# No generic Exception catch here, let other unexpected errors propagate if not SQLAlchemy related.
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
try:
await db.commit()
await db.refresh(expense_db)
except Exception as e:
await db.rollback()
# Consider specific DB error types if needed
raise InvalidOperationError(f"Failed to update expense: {str(e)}")
return expense_db
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
"""
@ -600,20 +579,12 @@ async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_ve
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(expense_db)
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(expense_db)
await db.flush() # Ensure the delete operation is sent to the database
except InvalidOperationError: # Re-raise validation errors
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
await db.commit()
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
return None
# Note: The InvalidOperationError is a simple ValueError placeholder.

View File

@ -4,8 +4,7 @@ from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List
from sqlalchemy import delete, func
import logging # Add logging import
from sqlalchemy import func
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate
@ -21,19 +20,14 @@ from app.core.exceptions import (
GroupPermissionError # Import GroupPermissionError
)
logger = logging.getLogger(__name__) # Initialize logger
# --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
"""Creates a group and adds the creator as the owner."""
try:
# Use the composability pattern for transactions as per fastapi-db-strategy.
# This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
# or starts a new transaction if called outside of one (e.g., from a script).
async with db.begin_nested() if db.in_transaction() else db.begin():
async with db.begin():
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
db.add(db_group)
await db.flush() # Assigns ID to db_group
await db.flush()
db_user_group = UserGroupModel(
user_id=creator_id,
@ -41,33 +35,15 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
role=UserRoleEnum.owner
)
db.add(db_user_group)
await db.flush() # Commits user_group, links to group
# After creation and linking, explicitly load the group with its member associations and users
stmt = (
select(GroupModel)
.where(GroupModel.id == db_group.id)
.options(
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
)
)
result = await db.execute(stmt)
loaded_group = result.scalar_one_or_none()
if loaded_group is None:
# This should not happen if we just created it, but as a safeguard
raise GroupOperationError("Failed to load group after creation.")
return loaded_group
await db.flush()
await db.refresh(db_group)
return db_group
except IntegrityError as e:
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
raise DatabaseIntegrityError(f"Failed to create group: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
raise DatabaseTransactionError(f"Failed to create group: {str(e)}")
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
"""Gets all groups a user is a member of."""
@ -76,9 +52,7 @@ async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
select(GroupModel)
.join(UserGroupModel)
.where(UserGroupModel.user_id == user_id)
.options(
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
)
.options(selectinload(GroupModel.member_associations))
)
return result.scalars().all()
except OperationalError as e:
@ -132,48 +106,29 @@ async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int)
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
"""Adds a user to a group if they aren't already a member."""
try:
# Check if user is already a member before starting a transaction
existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
existing_result = await db.execute(existing_stmt)
if existing_result.scalar_one_or_none():
return None
async with db.begin():
existing = await db.execute(
select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
)
if existing.scalar_one_or_none():
return None
# Use a single transaction
async with db.begin_nested() if db.in_transaction() else db.begin():
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
db.add(db_user_group)
await db.flush() # Assigns ID to db_user_group
# Eagerly load the 'user' and 'group' relationships for the response
stmt = (
select(UserGroupModel)
.where(UserGroupModel.id == db_user_group.id)
.options(
selectinload(UserGroupModel.user),
selectinload(UserGroupModel.group)
)
)
result = await db.execute(stmt)
loaded_user_group = result.scalar_one_or_none()
if loaded_user_group is None:
raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")
return loaded_user_group
await db.flush()
await db.refresh(db_user_group)
return db_user_group
except IntegrityError as e:
logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
"""Removes a user from a group."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
async with db.begin():
result = await db.execute(
delete(UserGroupModel)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
@ -181,10 +136,8 @@ async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int)
)
return result.scalar_one_or_none() is not None
except OperationalError as e:
logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:

View File

@ -1,199 +1,69 @@
# app/crud/invite.py
import logging # Add logging import
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete # Import delete statement
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
InviteOperationError # Add new specific exception
)
logger = logging.getLogger(__name__) # Initialize logger
from app.models import Invite as InviteModel
# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
"""Deactivates all currently active invite codes for a specific group."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
stmt = (
select(InviteModel)
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
)
result = await db.execute(stmt)
active_invites = result.scalars().all()
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
"""Creates a new invite code for a group."""
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
code = None
attempts = 0
if not active_invites:
return # No active invites to deactivate
for invite in active_invites:
invite.is_active = False
db.add(invite)
await db.flush() # Flush changes within this transaction block
# await db.flush() # Removed: Rely on caller to flush/commit
# No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
except OperationalError as e:
logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
"""Creates a new invite code for a group, deactivating any existing active ones for that group first."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Deactivate existing active invites for this group
await deactivate_all_active_invites_for_group(db, group_id)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
potential_code = None
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
potential_code = secrets.token_urlsafe(16)
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
existing_result = await db.execute(existing_check_stmt)
if existing_result.scalar_one_or_none() is None:
break
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
raise InviteOperationError("Invite code collision detected just before creation attempt.")
db_invite = InviteModel(
code=potential_code,
group_id=group_id,
created_by_id=creator_id,
expires_at=expires_at,
is_active=True
)
db.add(db_invite)
await db.flush()
stmt = (
select(InviteModel)
.where(InviteModel.id == db_invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
loaded_invite = result.scalar_one_or_none()
if loaded_invite is None:
raise InviteOperationError("Failed to load invite after creation and flush.")
return loaded_invite
except InviteOperationError: # Already specific, re-raise
raise
except IntegrityError as e:
logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
"""Gets the currently active and non-expired invite for a specific group."""
now = datetime.now(timezone.utc)
try:
stmt = (
select(InviteModel).where(
InviteModel.group_id == group_id,
InviteModel.is_active == True,
InviteModel.expires_at > now # Still respect expiry, even if very long
)
.order_by(InviteModel.created_at.desc()) # Get the most recent one if multiple (should not happen)
.limit(1)
.options(
selectinload(InviteModel.group), # Eager load group
selectinload(InviteModel.creator) # Eager load creator
)
# Generate a unique code, retrying if a collision occurs (highly unlikely but safe)
while attempts < MAX_CODE_GENERATION_ATTEMPTS:
attempts += 1
potential_code = secrets.token_urlsafe(16)
# Check if an *active* invite with this code already exists
existing = await db.execute(
select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
if existing.scalar_one_or_none() is None:
code = potential_code
break
if code is None:
# Failed to generate a unique code after several attempts
return None
db_invite = InviteModel(
code=code,
group_id=group_id,
created_by_id=creator_id,
expires_at=expires_at,
is_active=True
)
db.add(db_invite)
await db.commit()
await db.refresh(db_invite)
return db_invite
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
"""Gets an active and non-expired invite by its code."""
now = datetime.now(timezone.utc)
try:
stmt = (
select(InviteModel).where(
InviteModel.code == code,
InviteModel.is_active == True,
InviteModel.expires_at > now
)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
result = await db.execute(
select(InviteModel).where(
InviteModel.code == code,
InviteModel.is_active == True,
InviteModel.expires_at > now
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}")
)
return result.scalars().first()
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
"""Marks an invite as inactive (used) and reloads with relationships."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
invite.is_active = False
db.add(invite) # Add to session to track change
await db.flush() # Persist is_active change
# Re-fetch with relationships
stmt = (
select(InviteModel)
.where(InviteModel.id == invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
updated_invite = result.scalar_one_or_none()
if updated_invite is None: # Should not happen as invite is passed in
raise InviteOperationError("Failed to load invite after deactivation.")
return updated_invite
except OperationalError as e:
logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass
"""Marks an invite as inactive (used)."""
invite.is_active = False
db.add(invite) # Add to session to track change
await db.commit()
await db.refresh(invite)
return invite
# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...

View File

@ -1,14 +1,12 @@
# app/crud/item.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from datetime import datetime, timezone
import logging # Add logging import
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
from app.models import Item as ItemModel
from app.schemas.item import ItemCreate, ItemUpdate
from app.core.exceptions import (
ItemNotFoundError,
@ -16,68 +14,46 @@ from app.core.exceptions import (
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError,
ItemOperationError # Add if specific item operation errors are needed
ConflictError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
"""Creates a new item record for a specific list."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
db_item = ItemModel(
name=item_in.name,
quantity=item_in.quantity,
list_id=list_id,
added_by_id=user_id,
is_complete=False
)
db.add(db_item)
await db.flush() # Assigns ID
# Re-fetch with relationships
stmt = (
select(ItemModel)
.where(ItemModel.id == db_item.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user) # Will be None but loads relationship
)
)
result = await db.execute(stmt)
loaded_item = result.scalar_one_or_none()
if loaded_item is None:
# await transaction.rollback() # Redundant, context manager handles rollback on exception
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
return loaded_item
db_item = ItemModel(
name=item_in.name,
quantity=item_in.quantity,
list_id=list_id,
added_by_id=user_id,
is_complete=False # Default on creation
# version is implicitly set to 1 by model default
)
db.add(db_item)
await db.flush()
await db.refresh(db_item)
await db.commit() # Explicitly commit here
return db_item
except IntegrityError as e:
logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
await db.rollback() # Rollback on integrity error
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
await db.rollback() # Rollback on operational error
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
await db.rollback() # Rollback on other SQLAlchemy errors
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
# and context manager handles rollback.
except Exception as e: # Catch any other exception and attempt rollback
await db.rollback()
raise # Re-raise the original exception
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
"""Gets all items belonging to a specific list, ordered by creation time."""
try:
stmt = (
result = await db.execute(
select(ItemModel)
.where(ItemModel.list_id == list_id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user)
)
.order_by(ItemModel.created_at.asc())
.order_by(ItemModel.created_at.asc()) # Or desc() if preferred
)
result = await db.execute(stmt)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
@ -87,16 +63,7 @@ async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemMod
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
"""Gets a single item by its ID."""
try:
stmt = (
select(ItemModel)
.where(ItemModel.id == item_id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list) # Often useful to get the parent list
)
)
result = await db.execute(stmt)
result = await db.execute(select(ItemModel).where(ItemModel.id == item_id))
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
@ -106,74 +73,59 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
"""Updates an existing item record, checking for version conflicts."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if item_db.version != item_in.version:
# No need to rollback here, as the transaction hasn't committed.
# The context manager will handle rollback if an exception is raised.
raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})
if 'is_complete' in update_data:
if update_data['is_complete'] is True:
if item_db.completed_by_id is None:
update_data['completed_by_id'] = user_id
else:
update_data['completed_by_id'] = None
for key, value in update_data.items():
setattr(item_db, key, value)
item_db.version += 1
db.add(item_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(ItemModel)
.where(ItemModel.id == item_db.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list)
)
# Check version conflict
if item_db.version != item_in.version:
raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
result = await db.execute(stmt)
updated_item = result.scalar_one_or_none()
if updated_item is None: # Should not happen
# Rollback will be handled by context manager on raise
raise ItemOperationError("Failed to load item after update.")
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
return updated_item
# Special handling for is_complete
if 'is_complete' in update_data:
if update_data['is_complete'] is True:
if item_db.completed_by_id is None: # Only set if not already completed by someone
update_data['completed_by_id'] = user_id
else:
update_data['completed_by_id'] = None # Clear if marked incomplete
# Apply updates
for key, value in update_data.items():
setattr(item_db, key, value)
item_db.version += 1 # Increment version
db.add(item_db)
await db.flush()
await db.refresh(item_db)
# Commit the transaction if not part of a larger transaction
await db.commit()
return item_db
except IntegrityError as e:
logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
except ConflictError: # Re-raise ConflictError
await db.rollback()
raise
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
"""Deletes an item record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(item_db)
# await transaction.commit() # Removed
# No return needed for None
await db.delete(item_db)
await db.commit()
return None
except OperationalError as e:
logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
# Ensure ItemOperationError is defined in app.core.exceptions if used
# Example: class ItemOperationError(AppException): pass
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")

View File

@ -5,7 +5,6 @@ from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
import logging # Add logging import
from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
@ -18,16 +17,15 @@ from app.core.exceptions import (
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError,
ListOperationError
ConflictError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
"""Creates a new list record."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
# Check if we're already in a transaction
if db.in_transaction():
# If we're already in a transaction, just create the list
db_list = ListModel(
name=list_in.name,
description=list_in.description,
@ -36,33 +34,28 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
is_complete=False
)
db.add(db_list)
await db.flush() # Assigns ID
# Re-fetch with relationships for the response
stmt = (
select(ListModel)
.where(ListModel.id == db_list.id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Optionally add if items are always needed in response
await db.flush()
await db.refresh(db_list)
return db_list
else:
# If no transaction is active, start one
async with db.begin():
db_list = ListModel(
name=list_in.name,
description=list_in.description,
group_id=list_in.group_id,
created_by_id=creator_id,
is_complete=False
)
)
result = await db.execute(stmt)
loaded_list = result.scalar_one_or_none()
if loaded_list is None:
raise ListOperationError("Failed to load list after creation.")
return loaded_list
db.add(db_list)
await db.flush()
await db.refresh(db_list)
return db_list
except IntegrityError as e:
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
@ -73,25 +66,14 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
)
user_group_ids = group_ids_result.scalars().all()
# Build conditions for the OR clause dynamically
conditions = [
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
]
if user_group_ids:
if user_group_ids: # Only add the IN clause if there are group IDs
conditions.append(ListModel.group_id.in_(user_group_ids))
query = (
select(ListModel)
.where(or_(*conditions))
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group),
selectinload(ListModel.items).options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
)
.order_by(ListModel.updated_at.desc())
)
query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc())
result = await db.execute(query)
return result.scalars().all()
@ -103,17 +85,11 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
"""Gets a single list by ID, optionally loading its items."""
try:
query = (
select(ListModel)
.where(ListModel.id == list_id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
)
)
query = select(ListModel).where(ListModel.id == list_id)
if load_items:
query = query.options(
selectinload(ListModel.items).options(
selectinload(ListModel.items)
.options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
@ -128,8 +104,8 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
"""Updates an existing list record, checking for version conflicts."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
async with db.begin():
if list_db.version != list_in.version:
raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
@ -142,48 +118,34 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
list_db.version += 1
db.add(list_db) # Add the already attached list_db to mark it dirty for the session
db.add(list_db)
await db.flush()
# Re-fetch with relationships for the response
stmt = (
select(ListModel)
.where(ListModel.id == list_db.id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Optionally add if items are always needed in response
)
)
result = await db.execute(stmt)
updated_list = result.scalar_one_or_none()
if updated_list is None: # Should not happen
raise ListOperationError("Failed to load list after update.")
return updated_list
await db.refresh(list_db)
return list_db
except IntegrityError as e:
logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError:
await db.rollback()
raise
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
async with db.begin():
await db.delete(list_db)
return None
except OperationalError as e:
logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
await db.rollback()
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
@ -250,48 +212,39 @@ async def get_list_by_name_and_group(
db: AsyncSession,
name: str,
group_id: Optional[int],
user_id: int # user_id is for permission check, not direct list attribute
user_id: int
) -> Optional[ListModel]:
"""
Gets a list by name and group, ensuring the user has permission to access it.
Used for conflict resolution when creating lists.
"""
try:
# Base query for the list itself
base_query = select(ListModel).where(ListModel.name == name)
# Build the base query
query = select(ListModel).where(ListModel.name == name)
# Add group condition
if group_id is not None:
base_query = base_query.where(ListModel.group_id == group_id)
query = query.where(ListModel.group_id == group_id)
else:
base_query = base_query.where(ListModel.group_id.is_(None))
query = query.where(ListModel.group_id.is_(None))
# Add eager loading for common relationships
base_query = base_query.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
)
# Add permission conditions
conditions = [
ListModel.created_by_id == user_id # User is creator
]
if group_id is not None:
# User is member of the group
conditions.append(
and_(
ListModel.group_id == group_id,
ListModel.created_by_id != user_id # Not the creator
)
)
list_result = await db.execute(base_query)
target_list = list_result.scalar_one_or_none()
if not target_list:
return None
# Permission check
is_creator = target_list.created_by_id == user_id
if is_creator:
return target_list
if target_list.group_id:
from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction
is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
if is_member_of_group:
return target_list
# If not creator and (not a group list or not a member of the group list)
return None
query = query.where(or_(*conditions))
result = await db.execute(query)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:

View File

@ -3,144 +3,84 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from decimal import Decimal
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
import logging # Add logging import
from app.models import (
Settlement as SettlementModel,
User as UserModel,
Group as GroupModel,
UserGroup as UserGroupModel
Group as GroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import (
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
SettlementOperationError,
ConflictError
)
logger = logging.getLogger(__name__) # Initialize logger
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record."""
# Validate Payer, Payee, and Group exist
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Optional: Check if current_user_id is part of the group or is one of the parties involved
# This is more of an API-level permission check but could be added here if strict.
# For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
# is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
# if not is_in_group.first():
# raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
)
db.add(db_settlement)
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Permission check example (can be in API layer too)
# if current_user_id not in [payer.id, payee.id]:
# is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
# is_member_result = await db.execute(is_member_stmt)
# if not is_member_result.scalar_one_or_none():
# raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description,
created_by_user_id=current_user_id
)
db.add(db_settlement)
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == db_settlement.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)
loaded_settlement = result.scalar_one_or_none()
if loaded_settlement is None:
raise SettlementOperationError("Failed to load settlement after creation.")
return loaded_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are validation errors, re-raise them.
# If a transaction was started, context manager handles rollback.
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
await db.commit()
await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
return db_settlement
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
try:
result = await db.execute(
select(SettlementModel)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
.where(SettlementModel.id == settlement_id)
result = await db.execute(
select(SettlementModel)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
return result.scalars().first()
except OperationalError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
except SQLAlchemyError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
.where(SettlementModel.id == settlement_id)
)
return result.scalars().first()
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
try:
result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}")
result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
)
return result.scalars().all()
async def get_settlements_involving_user(
db: AsyncSession,
@ -149,29 +89,18 @@ async def get_settlements_involving_user(
skip: int = 0,
limit: int = 100
) -> Sequence[SettlementModel]:
try:
query = (
select(SettlementModel)
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
if group_id:
query = query.where(SettlementModel.group_id == group_id)
result = await db.execute(query)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}")
query = (
select(SettlementModel)
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
)
if group_id:
query = query.where(SettlementModel.group_id == group_id)
result = await db.execute(query)
return result.scalars().all()
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
"""
@ -179,103 +108,58 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
Only allows updates to description and settlement_date.
Requires version matching for optimistic locking.
Assumes SettlementUpdate schema includes a version field.
Assumes SettlementModel has version and updated_at fields.
"""
# Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version does not match current version {settlement_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
else:
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
pass # No actual updatable fields provided, but version matched.
settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
settlement_db.updated_at = datetime.now(timezone.utc)
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Ensure the settlement_db passed is managed by the current session if not already.
# This is usually true if fetched by an endpoint dependency using the same session.
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")
if settlement_db.version != settlement_in.version:
raise ConflictError( # Make sure ConflictError is defined in exceptions
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
# Silently ignore fields not allowed to update or raise error:
# else:
# raise InvalidOperationError(f"Field '{field}' cannot be updated.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
# No updatable fields were actually provided, or they didn't change
# Still, we might want to return the re-loaded settlement if version matched.
pass
settlement_db.version += 1
settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field
db.add(settlement_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == settlement_db.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)
updated_settlement = result.scalar_one_or_none()
if updated_settlement is None: # Should not happen
raise SettlementOperationError("Failed to load settlement after update.")
return updated_settlement
except ConflictError as e: # ConflictError should be defined in exceptions
raise
except InvalidOperationError as e:
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
await db.commit()
await db.refresh(settlement_db)
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
return settlement_db
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
"""
Deletes a settlement. Requires version matching if expected_version is provided.
Assumes SettlementModel has a version field.
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise ConflictError( # Make sure ConflictError is defined
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh."
)
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(settlement_db)
except ConflictError as e: # ConflictError should be defined
raise
except OperationalError as e:
logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
# Example: class SettlementOperationError(AppException): pass
# Example: class ConflictError(AppException): status_code = 409
await db.delete(settlement_db)
try:
await db.commit()
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to delete settlement: {str(e)}")
return None

View File

@ -1,12 +1,10 @@
# app/crud/user.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional
import logging # Add logging import
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
from app.models import User as UserModel # Alias to avoid name clash
from app.schemas.user import UserCreate
from app.core.security import hash_password
from app.core.exceptions import (
@ -15,76 +13,39 @@ from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
UserOperationError # Add if specific user operation errors are needed
DatabaseTransactionError
)
logger = logging.getLogger(__name__) # Initialize logger
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
"""Fetches a user from the database by email, with common relationships."""
"""Fetches a user from the database by email."""
try:
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
# For a single select, it can be omitted if preferred, session handles connection.
stmt = (
select(UserModel)
.filter(UserModel.email == email)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.created_groups) # Groups user created
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
return result.scalars().first()
async with db.begin():
result = await db.execute(select(UserModel).filter(UserModel.email == email))
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
"""Creates a new user record in the database with common relationships loaded."""
"""Creates a new user record in the database."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
async with db.begin():
_hashed_password = hash_password(user_in.password)
db_user = UserModel(
email=user_in.email,
hashed_password=_hashed_password, # Field name in model is hashed_password
password_hash=_hashed_password,
name=user_in.name
)
db.add(db_user)
await db.flush() # Flush to get DB-generated values like ID
# Re-fetch with relationships
stmt = (
select(UserModel)
.where(UserModel.id == db_user.id)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),
selectinload(UserModel.created_groups)
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
loaded_user = result.scalar_one_or_none()
if loaded_user is None:
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
return loaded_user
await db.flush() # Flush to get DB-generated values
await db.refresh(db_user)
return db_user
except IntegrityError as e:
logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
raise EmailAlreadyRegisteredError(email=user_in.email)
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
if "unique constraint" in str(e).lower():
raise EmailAlreadyRegisteredError()
raise DatabaseIntegrityError(f"Failed to create user: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
# Ensure UserOperationError is defined in app.core.exceptions if used
# Example: class UserOperationError(AppException): pass
raise DatabaseTransactionError(f"Failed to create user: {str(e)}")

View File

@ -30,32 +30,21 @@ AsyncSessionLocal = sessionmaker(
Base = declarative_base()
# Dependency to get DB session in path operations
async def get_session() -> AsyncSession: # type: ignore
async def get_async_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession.
Ensures the session is closed after the request.
"""
async with AsyncSessionLocal() as session:
yield session
# The 'async with' block handles session.close() automatically.
# Commit/rollback should be handled by the functions using the session.
async def get_transactional_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession and manages a transaction.
Commits the transaction if the request handler succeeds, otherwise rollbacks.
Ensures the session is closed after the request.
"""
async with AsyncSessionLocal() as session:
try:
await session.begin()
yield session
# Commit the transaction if no errors occurred
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
await session.close() # Not strictly necessary with async context manager, but explicit
# Alias for backward compatibility
get_db = get_session
get_db = get_async_session

View File

@ -65,11 +65,9 @@ class User(Base):
# --- Relationships for Cost Splitting ---
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
@ -199,7 +197,6 @@ class Expense(Base):
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True)
item_id = Column(Integer, ForeignKey("items.id"), nullable=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -207,7 +204,6 @@ class Expense(Base):
# Relationships
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created")
list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
@ -250,7 +246,6 @@ class Settlement(Base):
amount = Column(Numeric(10, 2), nullable=False)
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
description = Column(Text, nullable=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
@ -260,7 +255,6 @@ class Settlement(Base):
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created")
__table_args__ = (
# Ensure payer and payee are different users

View File

@ -79,7 +79,6 @@ class ExpensePublic(ExpenseBase):
created_at: datetime
updated_at: datetime
version: int
created_by_user_id: int
splits: List[ExpenseSplitPublic] = []
# paid_by_user: Optional[UserPublic] # If nesting user details
# list: Optional[ListPublic] # If nesting list details
@ -120,11 +119,9 @@ class SettlementPublic(SettlementBase):
id: int
created_at: datetime
updated_at: datetime
version: int
created_by_user_id: int
# payer: Optional[UserPublic] # If we want to include payer details
# payee: Optional[UserPublic] # If we want to include payee details
# group: Optional[GroupPublic] # If we want to include group details
# payer: Optional[UserPublic]
# payee: Optional[UserPublic]
# group: Optional[GroupPublic]
model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed

View File

@ -1,5 +0,0 @@
[pytest]
pythonpath = .
testpaths = tests
python_files = test_*.py
asyncio_mode = auto

View File

@ -16,9 +16,4 @@ fastapi-users[sqlalchemy]>=12.1.2
email-validator>=2.0.0
fastapi-users[oauth]>=12.1.2
authlib>=1.3.0
itsdangerous>=2.1.2
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-cov>=4.1.0
httpx>=0.24.0 # For async HTTP testing
aiosqlite>=0.19.0 # For async SQLite support in tests
itsdangerous>=2.1.2

View File

@ -3,52 +3,41 @@ from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Callable, Dict, Any
from unittest.mock import patch, MagicMock
from app.models import User as UserModel, Group as GroupModel, List as ListModel
from app.schemas.expense import ExpenseCreate, ExpensePublic, ExpenseUpdate
# from app.config import settings # Comment out the original import
from app.schemas.expense import ExpenseCreate
from app.core.config import settings
# Helper to create a URL for an endpoint
# API_V1_STR = settings.API_V1_STR # Comment out the original assignment
@pytest.fixture(scope="module")
def mock_settings_financials():
mock_settings = MagicMock()
mock_settings.API_V1_STR = "/api/v1"
return mock_settings
# Patch the settings in the test module
@pytest.fixture(autouse=True)
def patch_settings_financials(mock_settings_financials):
with patch("app.config.settings", mock_settings_financials):
yield
API_V1_STR = settings.API_V1_STR
def expense_url(endpoint: str = "") -> str:
# Use the mocked API_V1_STR via the patched settings object
from app.config import settings # Import settings here to use the patched version
return f"{settings.API_V1_STR}/financials/expenses{endpoint}"
return f"{API_V1_STR}/financials/expenses{endpoint}"
def settlement_url(endpoint: str = "") -> str:
# Use the mocked API_V1_STR via the patched settings object
from app.config import settings # Import settings here to use the patched version
return f"{settings.API_V1_STR}/financials/settlements{endpoint}"
return f"{API_V1_STR}/financials/settlements{endpoint}"
@pytest.mark.asyncio
async def test_create_new_expense_success_list_context(
client: AsyncClient,
db_session: AsyncSession,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_user_is_member: ListModel,
db_session: AsyncSession, # Assuming a fixture for db session
normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth
test_user: UserModel, # Assuming a fixture for a test user
test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of
) -> None:
"""
Test successful creation of a new expense linked to a list.
"""
expense_data = ExpenseCreate(
description="Test Expense for List",
amount=100.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=test_list_user_is_member.id,
group_id=None,
group_id=None, # group_id should be derived from list if list is in a group
# category_id: Optional[int] = None # Assuming category is optional
# expense_date: Optional[date] = None # Assuming date is optional
# splits: Optional[List[SplitCreate]] = [] # Assuming splits are optional for now
)
response = await client.post(
@ -64,6 +53,7 @@ async def test_create_new_expense_success_list_context(
assert content["currency"] == expense_data.currency
assert content["paid_by_user_id"] == test_user.id
assert content["list_id"] == test_list_user_is_member.id
# If test_list_user_is_member has a group_id, it should be set in the response
if test_list_user_is_member.group_id:
assert content["group_id"] == test_list_user_is_member.group_id
else:
@ -79,8 +69,11 @@ async def test_create_new_expense_success_group_context(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of
) -> None:
"""
Test successful creation of a new expense linked directly to a group.
"""
expense_data = ExpenseCreate(
description="Test Expense for Group",
amount=50.00,
@ -110,6 +103,9 @@ async def test_create_new_expense_fail_no_list_or_group(
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
) -> None:
"""
Test expense creation fails if neither list_id nor group_id is provided.
"""
expense_data = ExpenseCreate(
description="Test Invalid Expense",
amount=10.00,
@ -132,23 +128,28 @@ async def test_create_new_expense_fail_no_list_or_group(
@pytest.mark.asyncio
async def test_create_new_expense_fail_paid_by_other_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
another_user_in_group: UserModel,
normal_user_token_headers: Dict[str, str], # User is member, not owner
test_user: UserModel, # This is the current_user (member)
test_group_user_is_member: GroupModel, # Group the current_user is a member of
another_user_in_group: UserModel, # Another user in the same group
# Ensure test_user is NOT an owner of test_group_user_is_member for this test
) -> None:
"""
Test creation fails if paid_by_user_id is another user, and current_user is not a group owner.
Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member.
"""
expense_data = ExpenseCreate(
description="Expense paid by other",
amount=75.00,
currency="GBP",
paid_by_user_id=another_user_in_group.id,
paid_by_user_id=another_user_in_group.id, # Paid by someone else
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
headers=normal_user_token_headers, # Current user is a member, not owner
json=expense_data.model_dump(exclude_unset=True)
)
@ -156,13 +157,22 @@ async def test_create_new_expense_fail_paid_by_other_not_owner(
content = response.json()
assert "Only group owners can create expenses paid by others" in content["detail"]
# --- Add tests for other endpoints below ---
# GET /expenses/{expense_id}
@pytest.mark.asyncio
async def test_get_expense_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
created_expense: ExpensePublic,
# Assume an existing expense created by test_user or in a group/list they have access to
# This would typically be created by another test or a fixture
created_expense: ExpensePublic, # Assuming a fixture that provides a created expense
) -> None:
"""
Test successfully retrieving an existing expense.
User has access either by being the payer, or via list/group membership.
"""
response = await client.get(
expense_url(f"/{created_expense.id}"),
headers=normal_user_token_headers
@ -171,136 +181,148 @@ async def test_get_expense_success(
content = response.json()
assert content["id"] == created_expense.id
assert content["description"] == created_expense.description
assert content["amount"] == created_expense.amount
assert content["paid_by_user_id"] == created_expense.paid_by_user_id
if created_expense.list_id:
assert content["list_id"] == created_expense.list_id
if created_expense.group_id:
assert content["group_id"] == created_expense.group_id
# TODO: Add more tests for get_expense:
# - expense not found -> 404
# - user has no access (not payer, not in list, not in group if applicable) -> 403
# - expense in list, user has list access
# - expense in group, user has group access
# - expense personal (no list, no group), user is payer
# - expense personal (no list, no group), user is NOT payer -> 403
@pytest.mark.asyncio
async def test_get_expense_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
"""
Test retrieving a non-existent expense results in 404.
"""
non_existent_expense_id = 9999999
response = await client.get(
expense_url("/999"),
expense_url(f"/{non_existent_expense_id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
assert "not found" in content["detail"].lower()
@pytest.mark.asyncio
async def test_get_expense_forbidden_personal_expense_other_user(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
personal_expense_of_another_user: ExpensePublic,
normal_user_token_headers: Dict[str, str], # Belongs to test_user
# Fixture for an expense paid by another_user, not linked to any list/group test_user has access to
personal_expense_of_another_user: ExpensePublic
) -> None:
"""
Test retrieving a personal expense of another user (no shared list/group) results in 403.
"""
response = await client.get(
expense_url(f"/{personal_expense_of_another_user.id}"),
headers=normal_user_token_headers
headers=normal_user_token_headers # Current user querying
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this expense" in content["detail"]
assert "Not authorized to view this expense" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_forbidden_not_member_of_list_or_group(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
another_user: UserModel,
expense_in_inaccessible_list_or_group: ExpensePublic,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_inaccessible_list_or_group.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this expense" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_success_in_list_user_has_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_in_accessible_list: ExpensePublic,
test_list_user_is_member: ListModel,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_accessible_list.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == expense_in_accessible_list.id
assert content["list_id"] == test_list_user_is_member.id
@pytest.mark.asyncio
async def test_get_expense_success_in_group_user_has_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_in_accessible_group: ExpensePublic,
test_group_user_is_member: GroupModel,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_accessible_group.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == expense_in_accessible_group.id
assert content["group_id"] == test_group_user_is_member.id
# GET /lists/{list_id}/expenses
@pytest.mark.asyncio
async def test_list_list_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_user_is_member: ListModel,
test_list_user_is_member: ListModel, # List the user is a member of
# Assume some expenses have been created for this list by a fixture or previous tests
) -> None:
"""
Test successfully listing expenses for a list the user has access to.
"""
response = await client.get(
expense_url(f"?list_id={test_list_user_is_member.id}"),
f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
for expense in content:
assert expense["list_id"] == test_list_user_is_member.id
for expense_item in content: # Renamed from expense to avoid conflict if a fixture is named expense
assert expense_item["list_id"] == test_list_user_is_member.id
# TODO: Add more tests for list_list_expenses:
# - list not found -> 404 (ListNotFoundError from check_list_access_for_financials)
# - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials)
# - list exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
@pytest.mark.asyncio
async def test_list_list_expenses_list_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
"""
Test listing expenses for a non-existent list results in 404 (or appropriate error from permission check).
The check_list_access_for_financials raises ListNotFoundError, which might be caught and raised as 404.
The endpoint itself also has a get for ListModel, which would 404 first if permission check passed (not possible here).
Based on financials.py, ListNotFoundError is raised by check_list_access_for_financials.
This should translate to a 404 or a 403 if ListPermissionError wraps it with an action.
The current ListPermissionError in check_list_access_for_financials re-raises ListNotFoundError if that's the cause.
ListNotFoundError is a custom exception often mapped to 404.
Let's assume ListNotFoundError results in a 404 response from an exception handler.
"""
non_existent_list_id = 99999
response = await client.get(
expense_url("?list_id=999"),
f"{API_V1_STR}/financials/lists/{non_existent_list_id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
# The ListNotFoundError is raised by the check_list_access_for_financials helper,
# which is then re-raised. FastAPI default exception handlers or custom ones
# would convert this to an HTTP response. Typically NotFoundError -> 404.
# If ListPermissionError catches it and re-raises it specifically, it might be 403.
# From the code: `except ListNotFoundError: raise` means it propagates.
# Let's assume a global handler for NotFoundError derived exceptions leads to 404.
assert response.status_code == status.HTTP_404_NOT_FOUND
# The actual detail might vary based on how ListNotFoundError is handled by FastAPI
# For now, we check the status code. If financials.py maps it differently, this will need adjustment.
# Based on `raise ListNotFoundError(expense_in.list_id)` in create_new_expense, and if that leads to 400,
# this might be inconsistent. However, `check_list_access_for_financials` just re-raises ListNotFoundError.
# Let's stick to expecting 404 for a direct not found error from a path parameter.
content = response.json()
assert "List not found" in content["detail"]
assert "list not found" in content["detail"].lower() # Common detail for not found errors
@pytest.mark.asyncio
async def test_list_list_expenses_no_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_list_user_not_member: ListModel,
normal_user_token_headers: Dict[str, str], # User who will attempt access
test_list_user_not_member: ListModel, # A list current user is NOT a member of
) -> None:
"""
Test listing expenses for a list the user does not have access to (403 Forbidden).
"""
response = await client.get(
expense_url(f"?list_id={test_list_user_not_member.id}"),
f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this list" in content["detail"]
assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in content["detail"]
@pytest.mark.asyncio
async def test_list_list_expenses_empty(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_list_user_is_member_no_expenses: ListModel,
test_list_user_is_member_no_expenses: ListModel, # List user is member of, but has no expenses
) -> None:
"""
Test listing expenses for an accessible list that has no expenses (empty list, 200 OK).
"""
response = await client.get(
expense_url(f"?list_id={test_list_user_is_member_no_expenses.id}"),
f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
@ -308,342 +330,44 @@ async def test_list_list_expenses_empty(
assert isinstance(content, list)
assert len(content) == 0
@pytest.mark.asyncio
async def test_list_list_expenses_pagination(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_with_multiple_expenses: ListModel,
created_expenses_for_list: list[ExpensePublic],
) -> None:
# Test first page
response = await client.get(
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=0&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_list[0].id
assert content[1]["id"] == created_expenses_for_list[1].id
# Test second page
response = await client.get(
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=2&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_list[2].id
assert content[1]["id"] == created_expenses_for_list[3].id
# GET /groups/{group_id}/expenses
@pytest.mark.asyncio
async def test_list_group_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
test_group_user_is_member: GroupModel, # Group the user is a member of
# Assume some expenses have been created for this group by a fixture or previous tests
) -> None:
"""
Test successfully listing expenses for a group the user has access to.
"""
response = await client.get(
expense_url(f"?group_id={test_group_user_is_member.id}"),
f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
for expense in content:
assert expense["group_id"] == test_group_user_is_member.id
# Further assertions can be made here, e.g., checking if all expenses belong to the group
for expense_item in content:
assert expense_item["group_id"] == test_group_user_is_member.id
# Expenses in a group might also have a list_id if they were added via a list belonging to that group
@pytest.mark.asyncio
async def test_list_group_expenses_group_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.get(
expense_url("?group_id=999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Group not found" in content["detail"]
# TODO: Add more tests for list_group_expenses:
# - group not found -> 404 (GroupNotFoundError from check_group_membership)
# - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership)
# - group exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
@pytest.mark.asyncio
async def test_list_group_expenses_no_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_group_user_not_member: GroupModel,
) -> None:
response = await client.get(
expense_url(f"?group_id={test_group_user_not_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this group" in content["detail"]
@pytest.mark.asyncio
async def test_list_group_expenses_empty(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_group_user_is_member_no_expenses: GroupModel,
) -> None:
response = await client.get(
expense_url(f"?group_id={test_group_user_is_member_no_expenses.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 0
@pytest.mark.asyncio
async def test_list_group_expenses_pagination(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_with_multiple_expenses: GroupModel,
created_expenses_for_group: list[ExpensePublic],
) -> None:
# Test first page
response = await client.get(
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=0&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_group[0].id
assert content[1]["id"] == created_expenses_for_group[1].id
# Test second page
response = await client.get(
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=2&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_group[2].id
assert content[1]["id"] == created_expenses_for_group[3].id
@pytest.mark.asyncio
async def test_update_expense_success_payer_updates_details(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user: ExpensePublic,
) -> None:
update_data = ExpenseUpdate(
description="Updated expense description",
version=expense_paid_by_test_user.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["description"] == update_data.description
assert content["version"] == expense_paid_by_test_user.version + 1
@pytest.mark.asyncio
async def test_update_expense_success_group_owner_updates_others_expense(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
another_user_in_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
description="Updated by group owner",
version=expense_paid_by_another_in_group_where_test_user_is_owner.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
headers=group_owner_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["description"] == update_data.description
assert content["version"] == expense_paid_by_another_in_group_where_test_user_is_owner.version + 1
@pytest.mark.asyncio
async def test_update_expense_fail_not_payer_nor_group_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
another_user_in_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
description="Attempted update by non-owner",
version=expense_paid_by_another_in_group_where_test_user_is_member.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to update this expense" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_fail_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
update_data = ExpenseUpdate(
description="Update attempt on non-existent expense",
version=1,
)
response = await client.put(
expense_url("/999"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_fail_change_paid_by_user_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user_in_group: ExpensePublic,
another_user_in_same_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
paid_by_user_id=another_user_in_same_group.id,
version=expense_paid_by_test_user_in_group.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_test_user_in_group.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Only group owners can change the payer of an expense" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_success_owner_changes_paid_by_user(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_in_group_owner_group: ExpensePublic,
another_user_in_same_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
paid_by_user_id=another_user_in_same_group.id,
version=expense_in_group_owner_group.version,
)
response = await client.put(
expense_url(f"/{expense_in_group_owner_group.id}"),
headers=group_owner_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["paid_by_user_id"] == another_user_in_same_group.id
assert content["version"] == expense_in_group_owner_group.version + 1
@pytest.mark.asyncio
async def test_delete_expense_success_payer(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@pytest.mark.asyncio
async def test_delete_expense_success_group_owner(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
headers=group_owner_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@pytest.mark.asyncio
async def test_delete_expense_fail_not_payer_nor_group_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to delete this expense" in content["detail"]
@pytest.mark.asyncio
async def test_delete_expense_fail_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.delete(
expense_url("/999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
@pytest.mark.asyncio
async def test_delete_expense_idempotency(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
expense_paid_by_test_user: ExpensePublic,
) -> None:
# First delete
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
# Second delete should also succeed
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
# PUT /expenses/{expense_id}
# DELETE /expenses/{expense_id}
# GET /settlements/{settlement_id}
# POST /settlements
# GET /groups/{group_id}/settlements
# PUT /settlements/{settlement_id}
# DELETE /settlements/{settlement_id}
# DELETE /settlements/{settlement_id}
pytest.skip("Still implementing other tests", allow_module_level=True)

View File

@ -1,56 +0,0 @@
import pytest
import asyncio
from typing import AsyncGenerator
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.models import Base
from app.database import get_db
from app.config import settings
# Create test database engine
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def test_db():
"""Create test database and tables."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db_session(test_db) -> AsyncGenerator[AsyncSession, None]:
"""Create a fresh database session for each test."""
async with TestingSessionLocal() as session:
yield session
@pytest.fixture
async def client(db_session) -> AsyncGenerator[TestClient, None]:
"""Create a test client with the test database session."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()

View File

@ -30,15 +30,16 @@ def mock_gemini_settings():
@pytest.fixture
def mock_generative_model_instance():
model_instance = AsyncMock(spec=genai.GenerativeModel)
model_instance = MagicMock(spec=genai.GenerativeModel)
model_instance.generate_content_async = AsyncMock()
return model_instance
@pytest.fixture
def patch_google_ai_client(mock_generative_model_instance):
with patch('google.generativeai.GenerativeModel', return_value=mock_generative_model_instance) as mock_generative_model, \
patch('google.generativeai.configure') as mock_configure:
yield mock_configure, mock_generative_model, mock_generative_model_instance
@patch('google.generativeai.GenerativeModel')
@patch('google.generativeai.configure')
def patch_google_ai_client(mock_configure, mock_generative_model, mock_generative_model_instance):
mock_generative_model.return_value = mock_generative_model_instance
return mock_configure, mock_generative_model, mock_generative_model_instance
# --- Test Gemini Client Initialization (Global Client) ---
@ -136,22 +137,25 @@ def test_get_gemini_client_none_client_unknown_issue(mock_client_var, mock_error
async def test_extract_items_from_image_gemini_success(
mock_gemini_settings,
mock_generative_model_instance,
patch_google_ai_client
patch_google_ai_client # This fixture patches google.generativeai for the module
):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
""" Test successful item extraction """
# Ensure the global client is mocked to be the one we control
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
# Simulate the structure for safety checks if needed
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP' # Or whatever is appropriate for success
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
image_bytes = b"dummy_image_bytes"
mime_type = "image/png"
@ -164,7 +168,9 @@ async def test_extract_items_from_image_gemini_success(
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(mock_gemini_settings):
async def test_extract_items_from_image_gemini_client_not_init(
mock_gemini_settings
):
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', None), \
patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
@ -174,16 +180,16 @@ async def test_extract_items_from_image_gemini_client_not_init(mock_gemini_setti
await gemini.extract_items_from_image_gemini(image_bytes)
@pytest.mark.asyncio
@patch('app.core.gemini.get_gemini_client') # Mock the getter to control the client directly
async def test_extract_items_from_image_gemini_api_quota_error(
mock_gemini_settings,
mock_get_client,
mock_gemini_settings,
mock_generative_model_instance
):
mock_get_client.return_value = mock_generative_model_instance
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
with patch('app.core.gemini.settings', mock_gemini_settings):
image_bytes = b"dummy_image_bytes"
with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
await gemini.extract_items_from_image_gemini(image_bytes)
@ -210,91 +216,61 @@ def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, moc
gemini.GeminiOCRService()
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(
mock_gemini_settings,
mock_generative_model_instance
):
async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored"
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
mime_type = "image/png"
items = await service.extract_items(image_bytes, mime_type)
mock_generative_model_instance.generate_content_async.assert_called_once_with([
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": mime_type, "data": image_bytes}
])
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
with patch('app.core.gemini.settings', mock_gemini_settings):
# Patch the model instance within the service for this test
with patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance) as patched_model_class,
patch.object(genai, 'configure') as patched_configure:
service = gemini.GeminiOCRService() # Re-init to use the patched model
items = await service.extract_items(b"dummy_image")
expected_call_args = [
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": "image/jpeg", "data": b"dummy_image"}
]
service.model.generate_content_async.assert_called_once_with(contents=expected_call_args)
assert items == ["Apple", "Banana", "Orange"]
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(
mock_gemini_settings,
mock_generative_model_instance
):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
with pytest.raises(OCRQuotaExceededError):
await service.extract_items(image_bytes)
await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(
mock_gemini_settings,
mock_generative_model_instance
):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ServiceUnavailable("Service unavailable")
async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance):
# Simulate a generic API error that isn't quota related
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
with pytest.raises(OCRServiceUnavailableError):
await service.extract_items(image_bytes)
await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(
mock_gemini_settings,
mock_generative_model_instance
):
async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance):
mock_response = MagicMock()
mock_response.text = ""
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_response.text = None # Simulate no text in response
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
items = await service.extract_items(image_bytes)
assert items == []
with pytest.raises(OCRUnexpectedError):
await service.extract_items(b"dummy_image")

View File

@ -8,10 +8,10 @@ from passlib.context import CryptContext
from app.core.security import (
verify_password,
hash_password,
# create_access_token,
# create_refresh_token,
# verify_access_token,
# verify_refresh_token,
create_access_token,
create_refresh_token,
verify_access_token,
verify_refresh_token,
pwd_context, # Import for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
@ -44,173 +44,173 @@ def test_verify_password_invalid_hash_format():
invalid_hash = "notarealhash"
assert verify_password(password, invalid_hash) is False
# --- Tests for JWT Creation ---
# --- Tests for JWT Creation ---
# Mock settings for JWT tests
# @pytest.fixture(scope="module")
# def mock_jwt_settings():
# mock_settings = MagicMock()
# mock_settings.SECRET_KEY = "testsecretkey"
# mock_settings.ALGORITHM = "HS256"
# mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
# mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days
# return mock_settings
@pytest.fixture(scope="module")
def mock_jwt_settings():
mock_settings = MagicMock()
mock_settings.SECRET_KEY = "testsecretkey"
mock_settings.ALGORITHM = "HS256"
mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days
return mock_settings
# @patch('app.core.security.settings')
# def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
@patch('app.core.security.settings')
def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "user@example.com"
# token = create_access_token(subject)
# assert isinstance(token, str)
subject = "user@example.com"
token = create_access_token(subject)
assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == subject
# assert decoded_payload["type"] == "access"
# assert "exp" in decoded_payload
# # Check if expiry is roughly correct (within a small delta)
# expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == subject
assert decoded_payload["type"] == "access"
assert "exp" in decoded_payload
# Check if expiry is roughly correct (within a small delta)
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# @patch('app.core.security.settings')
# def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# # ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
@patch('app.core.security.settings')
def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
# subject = 123 # Subject can be int
# custom_delta = timedelta(hours=1)
# token = create_access_token(subject, expires_delta=custom_delta)
# assert isinstance(token, str)
subject = 123 # Subject can be int
custom_delta = timedelta(hours=1)
token = create_access_token(subject, expires_delta=custom_delta)
assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == str(subject)
# assert decoded_payload["type"] == "access"
# expected_expiry = datetime.now(timezone.utc) + custom_delta
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == str(subject)
assert decoded_payload["type"] == "access"
expected_expiry = datetime.now(timezone.utc) + custom_delta
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# @patch('app.core.security.settings')
# def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
@patch('app.core.security.settings')
def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
# subject = "refresh_subject"
# token = create_refresh_token(subject)
# assert isinstance(token, str)
subject = "refresh_subject"
token = create_refresh_token(subject)
assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == subject
# assert decoded_payload["type"] == "refresh"
# expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == subject
assert decoded_payload["type"] == "refresh"
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# --- Tests for JWT Verification --- (More tests to be added here)
# @patch('app.core.security.settings')
# def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
@patch('app.core.security.settings')
def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "test_user_valid_access"
token = create_access_token(subject)
payload = verify_access_token(token)
assert payload is not None
assert payload["sub"] == subject
assert payload["type"] == "access"
# subject = "test_user_valid_access"
# token = create_access_token(subject)
# payload = verify_access_token(token)
# assert payload is not None
# assert payload["sub"] == subject
# assert payload["type"] == "access"
@patch('app.core.security.settings')
def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# @patch('app.core.security.settings')
# def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "test_user_invalid_sig"
# Create token with correct key
token = create_access_token(subject)
# Try to verify with wrong key
mock_settings_global.SECRET_KEY = "wrongsecretkey"
payload = verify_access_token(token)
assert payload is None
# subject = "test_user_invalid_sig"
# # Create token with correct key
# token = create_access_token(subject)
@patch('app.core.security.settings')
@patch('app.core.security.datetime') # Mock datetime to control token expiry
def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
# # Try to verify with wrong key
# mock_settings_global.SECRET_KEY = "wrongsecretkey"
# payload = verify_access_token(token)
# assert payload is None
# Set current time for token creation
now = datetime.now(timezone.utc)
mock_datetime.now.return_value = now
mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
mock_datetime.timedelta = timedelta # Ensure original timedelta is used
# @patch('app.core.security.settings')
# @patch('app.core.security.datetime') # Mock datetime to control token expiry
# def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
subject = "test_user_expired"
token = create_access_token(subject)
# # Set current time for token creation
# now = datetime.now(timezone.utc)
# mock_datetime.now.return_value = now
# mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
# mock_datetime.timedelta = timedelta # Ensure original timedelta is used
# Advance time beyond expiry for verification
mock_datetime.now.return_value = now + timedelta(minutes=5)
payload = verify_access_token(token)
assert payload is None
# subject = "test_user_expired"
# token = create_access_token(subject)
@patch('app.core.security.settings')
def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation
# # Advance time beyond expiry for verification
# mock_datetime.now.return_value = now + timedelta(minutes=5)
# payload = verify_access_token(token)
# assert payload is None
subject = "test_user_wrong_type"
# Create a refresh token
refresh_token = create_refresh_token(subject)
# Try to verify it as an access token
payload = verify_access_token(refresh_token)
assert payload is None
# @patch('app.core.security.settings')
# def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation
@patch('app.core.security.settings')
def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
# subject = "test_user_wrong_type"
# # Create a refresh token
# refresh_token = create_refresh_token(subject)
subject = "test_user_valid_refresh"
token = create_refresh_token(subject)
payload = verify_refresh_token(token)
assert payload is not None
assert payload["sub"] == subject
assert payload["type"] == "refresh"
# # Try to verify it as an access token
# payload = verify_access_token(refresh_token)
# assert payload is None
@patch('app.core.security.settings')
@patch('app.core.security.datetime')
def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
# @patch('app.core.security.settings')
# def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
now = datetime.now(timezone.utc)
mock_datetime.now.return_value = now
mock_datetime.fromtimestamp = datetime.fromtimestamp
mock_datetime.timedelta = timedelta
# subject = "test_user_valid_refresh"
# token = create_refresh_token(subject)
# payload = verify_refresh_token(token)
# assert payload is not None
# assert payload["sub"] == subject
# assert payload["type"] == "refresh"
subject = "test_user_expired_refresh"
token = create_refresh_token(subject)
# @patch('app.core.security.settings')
# @patch('app.core.security.datetime')
# def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
mock_datetime.now.return_value = now + timedelta(minutes=5)
payload = verify_refresh_token(token)
assert payload is None
# now = datetime.now(timezone.utc)
# mock_datetime.now.return_value = now
# mock_datetime.fromtimestamp = datetime.fromtimestamp
# mock_datetime.timedelta = timedelta
@patch('app.core.security.settings')
def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "test_user_expired_refresh"
# token = create_refresh_token(subject)
# mock_datetime.now.return_value = now + timedelta(minutes=5)
# payload = verify_refresh_token(token)
# assert payload is None
# @patch('app.core.security.settings')
# def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "test_user_wrong_type_refresh"
# access_token = create_access_token(subject)
# payload = verify_refresh_token(access_token)
# assert payload is None
subject = "test_user_wrong_type_refresh"
access_token = create_access_token(subject)
payload = verify_refresh_token(access_token)
assert payload is None

View File

@ -36,8 +36,6 @@ from app.core.exceptions import (
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.begin_nested = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
@ -45,8 +43,7 @@ def mock_db_session():
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
session.flush = AsyncMock() # create_expense uses flush
return session
@pytest.fixture
@ -125,9 +122,7 @@ def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model
group_id=expense_create_data_equal_split_group_ctx.group_id,
item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
created_by_user_id=basic_user_model.id,
paid_by=basic_user_model, # Assuming paid_by relation is loaded
created_by_user=basic_user_model, # Assuming created_by_user relation is loaded
# splits would be populated after creation usually
version=1
)
@ -152,60 +147,47 @@ async def test_get_users_for_splitting_group_context(mock_db_session, basic_grou
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
mock_db_session.get.side_effect = [basic_user_model, basic_group_model]
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = ExpenseModel(
id=1,
description=expense_create_data_equal_split_group_ctx.description,
total_amount=expense_create_data_equal_split_group_ctx.total_amount,
currency=expense_create_data_equal_split_group_ctx.currency,
expense_date=expense_create_data_equal_split_group_ctx.expense_date,
split_type=expense_create_data_equal_split_group_ctx.split_type,
list_id=expense_create_data_equal_split_group_ctx.list_id,
group_id=expense_create_data_equal_split_group_ctx.group_id,
item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
created_by_user_id=basic_user_model.id,
version=1
)
mock_db_session.execute.return_value = mock_result
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
# Mock get_users_for_splitting call within create_expense
# This is a bit tricky as it's an internal call. Patching is an option.
with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
mock_get_users.return_value = [basic_user_model, another_user_model]
created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)
mock_db_session.add.assert_called()
mock_db_session.flush.assert_called_once()
# mock_db_session.commit.assert_called_once() # create_expense does not commit itself
# mock_db_session.refresh.assert_called_once() # create_expense does not refresh itself
assert created_expense is not None
assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
assert created_expense.split_type == SplitTypeEnum.EQUAL
assert len(created_expense.splits) == 2
assert len(created_expense.splits) == 2 # Expect splits to be added to the model instance
# Check split amounts
expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
for split in created_expense.splits:
assert split.owed_amount == expected_amount_per_user
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
mock_db_session.get.side_effect = [basic_user_model, basic_group_model]
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = ExpenseModel(
id=1,
description=expense_create_data_exact_split.description,
total_amount=expense_create_data_exact_split.total_amount,
currency="USD",
expense_date=expense_create_data_exact_split.expense_date,
split_type=expense_create_data_exact_split.split_type,
list_id=expense_create_data_exact_split.list_id,
group_id=expense_create_data_exact_split.group_id,
item_id=expense_create_data_exact_split.item_id,
paid_by_user_id=expense_create_data_exact_split.paid_by_user_id,
created_by_user_id=basic_user_model.id,
version=1
)
mock_db_session.execute.return_value = mock_result
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
# Mock the select for user validation in exact splits
mock_user_select_result = AsyncMock()
mock_user_select_result.all.return_value = [(basic_user_model.id,), (another_user_model.id,)] # Simulate (id,) tuples
# To make it behave like scalars().all() that returns a list of IDs:
# We need to mock the scalars().all() part, or the whole execute chain for user validation.
# A simpler way for this specific case might be to mock the select for User.id
mock_execute_user_ids = AsyncMock()
# Construct a mock result that mimics what `await db.execute(select(UserModel.id)...)` would give then process
# It's a bit involved, usually `found_user_ids = {row[0] for row in user_results}`
# Let's assume the select returns a list of Row objects or tuples with one element
mock_user_ids_result_proxy = MagicMock()
mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)])
mock_db_session.execute.return_value = mock_user_ids_result_proxy
created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1)
@ -214,6 +196,8 @@ async def test_create_expense_exact_split_success(mock_db_session, expense_creat
assert created_expense is not None
assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
assert len(created_expense.splits) == 2
assert created_expense.splits[0].owed_amount == Decimal("60.00")
assert created_expense.splits[1].owed_amount == Decimal("40.00")
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
@ -236,7 +220,7 @@ async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_expense_model
mock_db_session.execute.return_value = mock_result
expense = await get_expense_by_id(mock_db_session, 1)
assert expense is not None
assert expense.id == 1
@ -250,7 +234,6 @@ async def test_get_expense_by_id_not_found(mock_db_session):
expense = await get_expense_by_id(mock_db_session, 999)
assert expense is None
mock_db_session.execute.assert_called_once()
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
@ -261,7 +244,7 @@ async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
expenses = await get_expenses_for_list(mock_db_session, list_id=1)
assert len(expenses) == 1
assert expenses[0].list_id == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- get_expenses_for_group Tests ---
@ -273,7 +256,7 @@ async def test_get_expenses_for_group_success(mock_db_session, db_expense_model)
expenses = await get_expenses_for_group(mock_db_session, group_id=1)
assert len(expenses) == 1
assert expenses[0].group_id == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- Stubs for update_expense and delete_expense ---

View File

@ -30,27 +30,16 @@ from app.core.exceptions import (
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock() # Overall session mock
# For session.begin() and session.begin_nested()
# These are sync methods returning an async context manager.
# The returned AsyncMock will act as the async context manager.
mock_transaction_context = AsyncMock()
session.begin = MagicMock(return_value=mock_transaction_context)
session.begin_nested = MagicMock(return_value=mock_transaction_context) # Can use the same or a new one
# Async methods on the session itself
session = AsyncMock()
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.execute = AsyncMock() # Correct: execute is async
session.get = AsyncMock() # Correct: get is async
session.flush = AsyncMock() # Correct: flush is async
# Sync methods on the session
session.add = MagicMock()
session.delete = MagicMock()
session.in_transaction = MagicMock(return_value=False)
session.execute = AsyncMock()
session.get = AsyncMock() # Used by check_list_permission via get_list_by_id
session.flush = AsyncMock()
return session
@pytest.fixture
@ -95,45 +84,28 @@ async def test_create_list_success(mock_db_session, list_create_data, user_model
instance.version = 1
instance.updated_at = datetime.now(timezone.utc)
return None
mock_db_session.refresh.return_value = None
mock_db_session.refresh.side_effect = mock_refresh
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = ListModel(
id=100,
name=list_create_data.name,
description=list_create_data.description,
created_by_id=user_model.id,
version=1,
updated_at=datetime.now(timezone.utc)
)
mock_db_session.execute.return_value = mock_result
created_list = await create_list(mock_db_session, list_create_data, user_model.id)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_list.name == list_create_data.name
assert created_list.created_by_id == user_model.id
# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
# Mock for the object returned by .scalars() for group_ids query
mock_group_ids_scalar_result = MagicMock()
mock_group_ids_scalar_result.all.return_value = [db_list_group_model.group_id]
# Mock for the object returned by await session.execute() for group_ids query
mock_group_ids_execute_result = MagicMock()
mock_group_ids_execute_result.scalars.return_value = mock_group_ids_scalar_result
# Mock for the object returned by .scalars() for lists query
mock_lists_scalar_result = MagicMock()
mock_lists_scalar_result.all.return_value = [db_list_personal_model, db_list_group_model]
# Mock for the object returned by await session.execute() for lists query
mock_lists_execute_result = MagicMock()
mock_lists_execute_result.scalars.return_value = mock_lists_scalar_result
# Simulate user is part of group for db_list_group_model
mock_group_ids_result = AsyncMock()
mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]
mock_db_session.execute.side_effect = [mock_group_ids_execute_result, mock_lists_execute_result]
mock_lists_result = AsyncMock()
# Order should be personal list (created by user_id) then group list
mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]
mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result]
lists = await get_lists_for_user(mock_db_session, user_model.id)
assert len(lists) == 2
@ -144,55 +116,44 @@ async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_perso
# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)
assert found_list is not None
assert found_list.id == db_list_personal_model.id
# query options should not include selectinload for items
# (difficult to assert directly without inspecting query object in detail)
@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
# Simulate items loaded for the list
db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)
assert found_list is not None
assert len(found_list.items) == 1
# query options should include selectinload for items
# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
list_update_data.version = db_list_personal_model.version # Match version
updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)
assert updated_list.name == list_update_data.name
assert updated_list.version == db_list_personal_model.version + 1
assert updated_list.version == db_list_personal_model.version # version incremented in db_list_personal_model
mock_db_session.add.assert_called_once_with(db_list_personal_model)
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_list_personal_model)
@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version + 1
list_update_data.version = db_list_personal_model.version + 1 # Version mismatch
with pytest.raises(ConflictError):
await update_list(mock_db_session, db_list_personal_model, list_update_data)
mock_db_session.rollback.assert_called_once()
@ -202,109 +163,95 @@ async def test_update_list_conflict(mock_db_session, db_list_personal_model, lis
async def test_delete_list_success(mock_db_session, db_list_personal_model):
await delete_list(mock_db_session, db_list_personal_model)
mock_db_session.delete.assert_called_once_with(db_list_personal_model)
mock_db_session.commit.assert_called_once() # from async with db.begin()
# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
# get_list_by_id (called by check_list_permission) will mock execute
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_list_fetch_result
ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
assert ret_list.id == db_list_personal_model.id
@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_group_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
# User `another_user_model` is not creator but member of the group
db_list_group_model.creator_id = user_model.id # Original creator is user_model
db_list_group_model.creator = user_model
# Mock get_list_by_id internal call
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
# Mock is_user_member call
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = True
mock_is_member.return_value = True # another_user_model is a member
mock_db_session.execute.return_value = mock_list_fetch_result
ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
assert ret_list.id == db_list_group_model.id
mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_group_model
db_list_group_model.creator_id = user_model.id # Creator is not another_user_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = False
mock_is_member.return_value = False # another_user_model is NOT a member
mock_db_session.execute.return_value = mock_list_fetch_result
with pytest.raises(ListPermissionError):
await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = None
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = None # List not found
mock_db_session.execute.return_value = mock_list_fetch_result
with pytest.raises(ListNotFoundError):
await check_list_permission(mock_db_session, 999, user_model.id)
# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
# This test is more complex due to multiple potential execute calls or specific query structures
# For simplicity, assuming the primary query for the list model uses the same pattern:
mock_list_scalar_result = MagicMock()
mock_list_scalar_result.first.return_value = db_list_personal_model
mock_list_execute_result = MagicMock()
mock_list_execute_result.scalars.return_value = mock_list_scalar_result
list_updated_at = datetime.now(timezone.utc) - timezone.timedelta(hours=1)
item_updated_at = datetime.now(timezone.utc)
item_count = 5
db_list_personal_model.updated_at = list_updated_at
# Mock for ListModel.updated_at query
mock_list_updated_result = AsyncMock()
mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at
# If get_list_status makes other db calls (e.g., for items, counts), they need similar mocking.
# For now, let's assume the first execute call is for the list itself.
# If the error persists as "'coroutine' object has no attribute 'latest_item_updated_at'",
# it means the `get_list_status` function is not awaiting something before accessing that attribute,
# or the mock for the object that *should* have `latest_item_updated_at` is incorrect.
# Mock for ItemModel status query
mock_item_status_result = AsyncMock()
# SQLAlchemy query for func.max and func.count returns a Row-like object or None
mock_item_status_row = MagicMock()
mock_item_status_row.latest_item_updated_at = item_updated_at
mock_item_status_row.item_count = item_count
mock_item_status_result.first.return_value = mock_item_status_row
# A simplified mock for a single execute call. You might need to adjust if get_list_status does more.
mock_db_session.execute.return_value = mock_list_execute_result
mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result]
# Patching sql_func.max if it's directly used and causing issues with AsyncMock
with patch('app.crud.list.sql_func.max') as mock_sql_max:
# Example: if sql_func.max is part of a subquery or column expression
# this mock might not be hit directly if the execute call itself is fully mocked.
# This part is speculative without seeing the `get_list_status` implementation.
mock_sql_max.return_value = "mocked_max_value"
status = await get_list_status(mock_db_session, db_list_personal_model.id)
assert isinstance(status, ListStatus)
status = await get_list_status(mock_db_session, db_list_personal_model.id)
assert status.list_updated_at == list_updated_at
assert status.latest_item_updated_at == item_updated_at
assert status.item_count == item_count
assert mock_db_session.execute.call_count == 2
@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = None
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
mock_list_updated_result = AsyncMock()
mock_list_updated_result.scalar_one_or_none.return_value = None # List not found
mock_db_session.execute.return_value = mock_list_updated_result
with pytest.raises(ListNotFoundError):
await get_list_status(mock_db_session, 999)

View File

@ -16,14 +16,12 @@ from app.crud.settlement import (
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError, ConflictError
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.begin_nested = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
@ -31,8 +29,6 @@ def mock_db_session():
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
return session
@pytest.fixture
@ -64,14 +60,12 @@ def db_settlement_model():
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Original settlement",
created_by_user_id=1,
version=1, # Initial version
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
payer=UserModel(id=1, name="Payer User"),
payee=UserModel(id=2, name="Payee User"),
group=GroupModel(id=1, name="Test Group"),
created_by_user=UserModel(id=1, name="Payer User") # Same as payer for simplicity
group=GroupModel(id=1, name="Test Group")
)
@pytest.fixture
@ -89,31 +83,19 @@ def group_model():
# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = SettlementModel(
id=1,
group_id=settlement_create_data.group_id,
paid_by_user_id=settlement_create_data.paid_by_user_id,
paid_to_user_id=settlement_create_data.paid_to_user_id,
amount=settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_create_data.settlement_date,
description=settlement_create_data.description,
created_by_user_id=1,
version=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_db_session.execute.return_value = mock_result
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model] # Order of gets
created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_settlement is not None
assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data):
mock_db_session.get.side_effect = [None, payee_user_model, group_model]
@ -155,10 +137,7 @@ async def test_create_settlement_commit_failure(mock_db_session, settlement_crea
# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_settlement_model
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.first.return_value = db_settlement_model
settlement = await get_settlement_by_id(mock_db_session, 1)
assert settlement is not None
assert settlement.id == 1
@ -166,20 +145,14 @@ async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
settlement = await get_settlement_by_id(mock_db_session, 999)
assert settlement is None
# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_for_group(mock_db_session, group_id=1)
assert len(settlements) == 1
assert settlements[0].group_id == 1
@ -188,10 +161,7 @@ async def test_get_settlements_for_group_success(mock_db_session, db_settlement_
# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
assert len(settlements) == 1
assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
@ -199,37 +169,39 @@ async def test_get_settlements_involving_user_success(mock_db_session, db_settle
@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
assert len(settlements) == 1
# More specific assertions about the query would require deeper mocking of SQLAlchemy query construction
mock_db_session.execute.assert_called_once()
# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
# Ensure settlement_update_data.version matches db_settlement_model.version
settlement_update_data.version = db_settlement_model.version
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = db_settlement_model
mock_db_session.execute.return_value = mock_result
# Mock datetime.now()
fixed_datetime_now = datetime.now(timezone.utc)
with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime:
mock_datetime.now.return_value = fixed_datetime_now
updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
mock_db_session.add.assert_called_once_with(db_settlement_model)
mock_db_session.flush.assert_called_once()
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert updated_settlement.description == settlement_update_data.description
assert updated_settlement.settlement_date == settlement_update_data.settlement_date
assert updated_settlement.version == db_settlement_model.version + 1
assert updated_settlement.version == db_settlement_model.version + 1 # Version incremented
assert updated_settlement.updated_at == fixed_datetime_now
@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version + 1
with pytest.raises(ConflictError):
settlement_update_data.version = db_settlement_model.version + 1 # Mismatched version
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
mock_db_session.rollback.assert_called_once()
assert "version does not match" in str(excinfo.value)
@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
@ -263,10 +235,11 @@ async def test_delete_settlement_success_with_version_check(mock_db_session, db_
@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
db_settlement_model.version = 2
with pytest.raises(ConflictError):
await delete_settlement(mock_db_session, db_settlement_model, expected_version=1)
mock_db_session.rollback.assert_called_once()
with pytest.raises(InvalidOperationError) as excinfo:
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version + 1)
assert "Expected version" in str(excinfo.value)
assert "does not match current version" in str(excinfo.value)
mock_db_session.delete.assert_not_called()
@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):

View File

@ -17,19 +17,7 @@ from app.core.exceptions import (
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.begin_nested = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
return session
return AsyncMock()
@pytest.fixture
def user_create_data():
@ -42,10 +30,7 @@ def existing_user_data():
# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = existing_user_data
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.first.return_value = existing_user_data
user = await get_user_by_email(mock_db_session, "exists@example.com")
assert user is not None
assert user.email == "exists@example.com"
@ -53,10 +38,7 @@ async def test_get_user_by_email_found(mock_db_session, existing_user_data):
@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
assert user is None
mock_db_session.execute.assert_called_once()
@ -78,22 +60,29 @@ async def test_get_user_by_email_db_query_error(mock_db_session):
# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = UserModel(
id=1,
email=user_create_data.email,
name=user_create_data.name,
password_hash="hashed_password" # This would be set by the actual hash_password function
)
mock_db_session.execute.return_value = mock_result
# The actual user object returned would be created by SQLAlchemy based on db_user
# We mock the process: db.add is called, then db.flush, then db.refresh updates db_user
async def mock_refresh(user_model_instance):
user_model_instance.id = 1 # Simulate DB assigning an ID
# Simulate other db-generated fields if necessary
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
mock_db_session.flush = AsyncMock()
mock_db_session.add = MagicMock()
created_user = await create_user(mock_db_session, user_create_data)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_user is not None
assert created_user.email == user_create_data.email
assert created_user.name == user_create_data.name
assert created_user.id == 1
assert hasattr(created_user, 'id') # Check if ID was assigned (simulated by mock_refresh)
# Password hash check would be more involved, ensure hash_password was called correctly
# For now, we assume hash_password works as intended and is tested elsewhere.
@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):

View File

@ -1,65 +1,32 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> <!-- Or your favicon -->
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="description" content="mitlist pwa">
<meta name="format-detection" content="telephone=no">
<meta name="msapplication-tap-highlight" content="no">
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
<!-- PWA manifest and theme color will be injected by vite-plugin-pwa -->
<title>mitlist</title>
</head>
<body>
<svg width="0" height="0" style="position: absolute">
<defs>
<symbol viewBox="0 0 24 24" id="icon-plus">
<path d="M19 13h-6v6h-2v-6H5v-2h6V5h2v6h6v2z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-edit">
<path
d="M3 17.25V21h3.75L17.81 9.94l-3.75-3.75L3 17.25zM20.71 7.04c.39-.39.39-1.02 0-1.41l-2.34-2.34a.9959.9959 0 0 0-1.41 0l-1.83 1.83 3.75 3.75 1.83-1.83z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-trash">
<path d="M6 19c0 1.1.9 2 2 2h8c1.1 0 2-.9 2-2V7H6v12zM19 4h-3.5l-1-1h-5l-1 1H5v2h14V4z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-check">
<path d="M9 16.17 4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-close">
<path
d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-alert-triangle">
<path d="M1 21h22L12 2 1 21zm12-3h-2v-2h2v2zm0-4h-2v-4h2v4z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-clipboard">
<path
d="M16 2H8C6.9 2 6 2.9 6 4v16c0 1.1.9 2 2 2h12c1.1 0 2-.9 2-2V8l-6-6zm-4 18c-1.1 0-2-.9-2-2s.9-2 2-2 2 .9 2 2-.9 2-2 2zm4-10H8V8h8v2zm2-4V4l4 4h-4z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-info">
<path
d="M11 7h2v2h-2zm0 4h2v6h-2zm1-9C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-settings">
<path
d="M19.43 12.98c.04-.32.07-.64.07-.98s-.03-.66-.07-.98l2.11-1.65c.19-.15.24-.42.12-.64l-2-3.46c-.12-.22-.39-.3-.61-.22l-2.49 1c-.52-.4-1.08-.73-1.69-.98l-.38-2.65C14.46 2.18 14.25 2 14 2h-4c-.25 0-.46.18-.49.42l-.38 2.65c-.61.25-1.17.59-1.69.98l-2.49-1c-.23-.09-.49 0-.61.22l-2 3.46c-.13.22-.07.49.12.64l2.11 1.65c-.04.32-.07.65-.07.98s.03.66.07.98l-2.11 1.65c-.19.15-.24.42-.12.64l2 3.46c.12.22.39.3.61.22l2.49-1c.52.4 1.08.73 1.69.98l.38 2.65c.03.24.24.42.49.42h4c.25 0 .46-.18.49-.42l.38-2.65c.61-.25 1.17-.59 1.69-.98l2.49 1c.23.09.49 0 .61-.22l2-3.46c.12-.22.07-.49-.12-.64l-2.11-1.65zM12 15.5c-1.93 0-3.5-1.57-3.5-3.5s1.57-3.5 3.5-3.5 3.5 1.57 3.5 3.5-1.57 3.5-3.5 3.5z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-user">
<path
d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z" />
</symbol>
<symbol viewBox="0 0 24 24" id="icon-bell">
<path
d="M12 22c1.1 0 2-.9 2-2h-4c0 1.1.9 2 2 2zm6-6v-5c0-3.07-1.63-5.64-4.5-6.32V4c0-.83-.67-1.5-1.5-1.5s-1.5.67-1.5 1.5v.68C7.64 5.36 6 7.92 6 11v5l-2 2v1h16v-1l-2-2zm-2 1H8v-6c0-2.21 1.79-4 4-4s4 1.79 4 4v6z" />
</symbol>
</defs>
</svg>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> <!-- Or your favicon -->
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="description" content="mitlist pwa">
<meta name="format-detection" content="telephone=no">
<meta name="msapplication-tap-highlight" content="no">
<!-- PWA manifest and theme color will be injected by vite-plugin-pwa -->
<title>mitlist</title>
</head>
<body>
<svg width="0" height="0" style="position: absolute">
<defs>
<symbol viewBox="0 0 24 24" id="icon-plus"><path d="M19 13h-6v6h-2v-6H5v-2h6V5h2v6h6v2z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-edit"><path d="M3 17.25V21h3.75L17.81 9.94l-3.75-3.75L3 17.25zM20.71 7.04c.39-.39.39-1.02 0-1.41l-2.34-2.34a.9959.9959 0 0 0-1.41 0l-1.83 1.83 3.75 3.75 1.83-1.83z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-trash"><path d="M6 19c0 1.1.9 2 2 2h8c1.1 0 2-.9 2-2V7H6v12zM19 4h-3.5l-1-1h-5l-1 1H5v2h14V4z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-check"><path d="M9 16.17 4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-close"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-alert-triangle"><path d="M1 21h22L12 2 1 21zm12-3h-2v-2h2v2zm0-4h-2v-4h2v4z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-clipboard"><path d="M16 2H8C6.9 2 6 2.9 6 4v16c0 1.1.9 2 2 2h12c1.1 0 2-.9 2-2V8l-6-6zm-4 18c-1.1 0-2-.9-2-2s.9-2 2-2 2 .9 2 2-.9 2-2 2zm4-10H8V8h8v2zm2-4V4l4 4h-4z" /></symbol>
<symbol viewBox="0 0 24 24" id="icon-info"><path d="M11 7h2v2h-2zm0 4h2v6h-2zm1-9C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm0 18c-4.41 0-8-3.59-8-8s3.59-8 8-8 8 3.59 8 8-3.59 8-8 8z"/></symbol>
<symbol viewBox="0 0 24 24" id="icon-settings"><path d="M19.43 12.98c.04-.32.07-.64.07-.98s-.03-.66-.07-.98l2.11-1.65c.19-.15.24-.42.12-.64l-2-3.46c-.12-.22-.39-.3-.61-.22l-2.49 1c-.52-.4-1.08-.73-1.69-.98l-.38-2.65C14.46 2.18 14.25 2 14 2h-4c-.25 0-.46.18-.49.42l-.38 2.65c-.61.25-1.17.59-1.69.98l-2.49-1c-.23-.09-.49 0-.61.22l-2 3.46c-.13.22-.07.49.12.64l2.11 1.65c-.04.32-.07.65-.07.98s.03.66.07.98l-2.11 1.65c-.19.15-.24.42-.12.64l2 3.46c.12.22.39.3.61.22l2.49-1c.52.4 1.08.73 1.69.98l.38 2.65c.03.24.24.42.49.42h4c.25 0 .46-.18.49-.42l.38-2.65c.61-.25 1.17-.59 1.69-.98l2.49 1c.23.09.49 0 .61-.22l2-3.46c.12-.22.07-.49-.12-.64l-2.11-1.65zM12 15.5c-1.93 0-3.5-1.57-3.5-3.5s1.57-3.5 3.5-3.5 3.5 1.57 3.5 3.5-1.57 3.5-3.5 3.5z"/></symbol>
<symbol viewBox="0 0 24 24" id="icon-user"><path d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"/></symbol>
<symbol viewBox="0 0 24 24" id="icon-bell"> <path d="M12 22c1.1 0 2-.9 2-2h-4c0 1.1.9 2 2 2zm6-6v-5c0-3.07-1.63-5.64-4.5-6.32V4c0-.83-.67-1.5-1.5-1.5s-1.5.67-1.5 1.5v.68C7.64 5.36 6 7.92 6 11v5l-2 2v1h16v-1l-2-2zm-2 1H8v-6c0-2.21 1.79-4 4-4s4 1.79 4 4v6z"/> </symbol>
</defs>
</svg>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

View File

@ -18,8 +18,7 @@ body {
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
color: #2c3e50;
background-color: #f0f2f5;
/* Example background */
background-color: #f0f2f5; /* Example background */
}
#app {

File diff suppressed because it is too large Load Diff

View File

@ -40,7 +40,7 @@
<li v-for="(value, key) in conflictData?.localVersion.data" :key="key" class="list-item-simple">
<strong class="text-caption-strong">{{ formatKey(key) }}</strong>
<span :class="{ 'text-positive-inline': isDifferent(key as string) }">{{ formatValue(value)
}}</span>
}}</span>
</li>
</ul>
</div>
@ -59,7 +59,7 @@
<li v-for="(value, key) in conflictData?.serverVersion.data" :key="key" class="list-item-simple">
<strong class="text-caption-strong">{{ formatKey(key) }}</strong>
<span :class="{ 'text-positive-inline': isDifferent(key as string) }">{{ formatValue(value)
}}</span>
}}</span>
</li>
</ul>
</div>

View File

@ -57,7 +57,7 @@ const props = defineProps<{
const emit = defineEmits<{
(e: 'update:modelValue', value: boolean): void;
(e: 'created', newList: any): void;
(e: 'created'): void;
}>();
const isOpen = useVModel(props, 'modelValue', emit);
@ -108,7 +108,7 @@ const onSubmit = async () => {
}
loading.value = true;
try {
const response = await apiClient.post(API_ENDPOINTS.LISTS.BASE, {
await apiClient.post(API_ENDPOINTS.LISTS.BASE, {
name: listName.value,
description: description.value,
group_id: selectedGroupId.value,
@ -116,7 +116,7 @@ const onSubmit = async () => {
notificationStore.addNotification({ message: 'List created successfully', type: 'success' });
emit('created', response.data);
emit('created');
closeModal();
} catch (error: unknown) {
const message = error instanceof Error ? error.message : 'Failed to create list';

View File

@ -51,8 +51,6 @@ export const API_ENDPOINTS = {
LISTS: (groupId: string) => `/groups/${groupId}/lists`,
MEMBERS: (groupId: string) => `/groups/${groupId}/members`,
MEMBER: (groupId: string, userId: string) => `/groups/${groupId}/members/${userId}`,
CREATE_INVITE: (groupId: string) => `/groups/${groupId}/invites`,
GET_ACTIVE_INVITE: (groupId: string) => `/groups/${groupId}/invites`,
LEAVE: (groupId: string) => `/groups/${groupId}/leave`,
DELETE: (groupId: string) => `/groups/${groupId}`,
SETTINGS: (groupId: string) => `/groups/${groupId}/settings`,
@ -64,9 +62,9 @@ export const API_ENDPOINTS = {
INVITES: {
BASE: '/invites',
BY_ID: (id: string) => `/invites/${id}`,
ACCEPT: (id: string) => `/invites/accept/${id}`,
DECLINE: (id: string) => `/invites/decline/${id}`,
REVOKE: (id: string) => `/invites/revoke/${id}`,
ACCEPT: (id: string) => `/invites/${id}/accept`,
DECLINE: (id: string) => `/invites/${id}/decline`,
REVOKE: (id: string) => `/invites/${id}/revoke`,
LIST: '/invites',
PENDING: '/invites/pending',
SENT: '/invites/sent',

View File

@ -5,11 +5,7 @@
<div class="user-menu" v-if="authStore.isAuthenticated">
<button @click="toggleUserMenu" class="user-menu-button">
<!-- Placeholder for user icon -->
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 0 24 24" width="24px" fill="#ff7b54">
<path d="M0 0h24v24H0z" fill="none" />
<path
d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z" />
</svg>
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 0 24 24" width="24px" fill="#ff7b54"><path d="M0 0h24v24H0z" fill="none"/><path d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"/></svg>
</button>
<div v-if="userMenuOpen" class="dropdown-menu" ref="userMenuDropdown">
<a href="#" @click.prevent="handleLogout">Logout</a>
@ -18,47 +14,29 @@
</header>
<main class="page-container">
<keep-alive>
<router-view v-slot="{ Component, route }">
<component :is="Component" v-if="route.meta.keepAlive" />
<component :is="Component" v-else :key="route.fullPath" />
</router-view>
</keep-alive>
<router-view />
</main>
<OfflineIndicator />
<footer class="app-footer">
<nav class="tabs">
<router-link to="/lists" class="tab-item" active-class="active">
<span class="material-icons">list</span>
<span class="tab-text">Lists</span>
</router-link>
<router-link to="/groups" class="tab-item" active-class="active">
<span class="material-icons">group</span>
<span class="tab-text">Groups</span>
</router-link>
<!-- <router-link to="/account" class="tab-item" active-class="active">
<span class="material-icons">person</span>
<span class="tab-text">Account</span>
</router-link> -->
<router-link to="/lists" class="tab-item" active-class="active">Lists</router-link>
<router-link to="/groups" class="tab-item" active-class="active">Groups</router-link>
<router-link to="/account" class="tab-item" active-class="active">Account</router-link>
</nav>
</footer>
</div>
</template>
<script setup lang="ts">
import { ref, defineComponent } from 'vue';
import { ref } from 'vue';
import { useRouter } from 'vue-router';
import { useAuthStore } from '@/stores/auth';
import OfflineIndicator from '@/components/OfflineIndicator.vue';
import { onClickOutside } from '@vueuse/core';
import { useNotificationStore } from '@/stores/notifications';
defineComponent({
name: 'MainLayout'
});
const router = useRouter();
const authStore = useAuthStore();
const notificationStore = useNotificationStore();
@ -108,7 +86,7 @@ const handleLogout = async () => {
display: flex;
align-items: center;
justify-content: space-between;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
position: sticky;
top: 0;
z-index: 100;
@ -135,9 +113,8 @@ const handleLogout = async () => {
display: flex;
align-items: center;
justify-content: center;
&:hover {
background-color: rgba(255, 255, 255, 0.1);
background-color: rgba(255,255,255,0.1);
}
}
@ -149,7 +126,7 @@ const handleLogout = async () => {
background-color: #f3f3f3;
border: 1px solid #ddd;
border-radius: 4px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
box-shadow: 0 2px 8px rgba(0,0,0,0.15);
min-width: 150px;
z-index: 101;
@ -158,7 +135,6 @@ const handleLogout = async () => {
padding: 0.5rem 1rem;
color: var(--text-color);
text-decoration: none;
&:hover {
background-color: #f5f5f5;
}
@ -194,29 +170,15 @@ const handleLogout = async () => {
flex-direction: column;
align-items: center;
justify-content: center;
color: var(--text-color);
color: var(--text-color); // Or a specific inactive tab color
text-decoration: none;
font-size: 0.8rem;
font-size: 0.8rem; // Example size
padding: 0.5rem 0;
border-bottom: 2px solid transparent;
gap: 4px;
.material-icons {
font-size: 24px;
}
.tab-text {
display: none;
}
@media (min-width: 768px) {
flex-direction: row;
gap: 8px;
.tab-text {
display: inline;
}
}
// Icon would go here if you add them
// Example: svg or <i> for icon fonts
&.active {
color: var(--primary-color);

View File

@ -3,18 +3,16 @@
<h1 class="mb-3">Account Settings</h1>
<div v-if="loading" class="text-center">
<div class="spinner-dots" role="status"><span /><span /><span /></div>
<div class="spinner-dots" role="status"><span/><span/><span/></div>
<p>Loading profile...</p>
</div>
<div v-else-if="error" class="alert alert-error mb-3" role="alert">
<div class="alert-content">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
</svg>
<svg class="icon" aria-hidden="true"><use xlink:href="#icon-alert-triangle" /></svg>
{{ error }}
</div>
<button type="button" class="btn btn-sm btn-danger" @click="fetchProfile">Retry</button>
<button type="button" class="btn btn-sm btn-danger" @click="fetchProfile">Retry</button>
</div>
<form v-else @submit.prevent="onSubmitProfile">
@ -37,7 +35,7 @@
</div>
<div class="card-footer">
<button type="submit" class="btn btn-primary" :disabled="saving">
<span v-if="saving" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
<span v-if="saving" class="spinner-dots-sm" role="status"><span/><span/><span/></span>
Save Changes
</button>
</div>
@ -64,7 +62,7 @@
</div>
<div class="card-footer">
<button type="submit" class="btn btn-primary" :disabled="changingPassword">
<span v-if="changingPassword" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
<span v-if="changingPassword" class="spinner-dots-sm" role="status"><span/><span/><span/></span>
Change Password
</button>
</div>
@ -195,8 +193,8 @@ const onChangePassword = async () => {
try {
// API endpoint expects 'new' not 'newPassword'
await apiClient.put(API_ENDPOINTS.USERS.PASSWORD, {
current: password.value.current,
new: password.value.newPassword
current: password.value.current,
new: password.value.newPassword
});
password.value = { current: '', newPassword: '' };
notificationStore.addNotification({ message: 'Password changed successfully', type: 'success' });
@ -231,44 +229,31 @@ onMounted(() => {
<style scoped>
.page-padding {
padding: 1rem;
/* Or use var(--padding-page) if defined in Valerie UI */
}
.mb-3 {
margin-bottom: 1.5rem;
}
/* From Valerie UI */
.flex-grow {
flex-grow: 1;
padding: 1rem; /* Or use var(--padding-page) if defined in Valerie UI */
}
.mb-3 { margin-bottom: 1.5rem; } /* From Valerie UI */
.flex-grow { flex-grow: 1; }
.preference-list {
list-style: none;
padding: 0;
margin: 0;
}
.preference-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.75rem 0;
border-bottom: 1px solid #eee;
/* Softer border for list items */
border-bottom: 1px solid #eee; /* Softer border for list items */
}
.preference-item:last-child {
border-bottom: none;
}
.preference-label {
display: flex;
flex-direction: column;
margin-right: 1rem;
}
.preference-label small {
font-size: 0.85rem;
opacity: 0.7;

View File

@ -28,17 +28,12 @@ const error = ref<string | null>(null);
onMounted(async () => {
try {
const accessToken = route.query.access_token as string | undefined;
const refreshToken = route.query.refresh_token as string | undefined;
const legacyToken = route.query.token as string | undefined;
const tokenToUse = accessToken || legacyToken;
if (!tokenToUse) {
const token = route.query.token as string;
if (!token) {
throw new Error('No token provided');
}
await authStore.setTokens({ access_token: tokenToUse, refresh_token: refreshToken });
await authStore.setTokens({ access_token: token, refresh_token: '' });
notificationStore.addNotification({ message: 'Login successful', type: 'success' });
router.push('/');
} catch (err) {

View File

@ -4,86 +4,74 @@
<div class="spinner-dots" role="status"><span /><span /><span /></div>
<p>Loading group details...</p>
</div>
<div v-else-if="error" class="alert alert-error mb-3" role="alert">
<div v-else-if="error" class="alert alert-error" role="alert">
<div class="alert-content">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
</svg>
{{ error }}
</div>
<button type="button" class="btn btn-sm btn-danger" @click="fetchGroupDetails">Retry</button>
</div>
<div v-else-if="group">
<h1 class="mb-3">{{ group.name }}</h1>
<h1 class="mb-3">Group: {{ group.name }}</h1>
<div class="neo-grid">
<!-- Group Members Section -->
<div class="neo-card">
<div class="neo-card-header">
<h3>Group Members</h3>
</div>
<div class="neo-card-body">
<div v-if="group.members && group.members.length > 0" class="neo-members-list">
<div v-for="member in group.members" :key="member.id" class="neo-member-item">
<div class="neo-member-info">
<span class="neo-member-name">{{ member.email }}</span>
<span class="neo-member-role" :class="member.role?.toLowerCase()">{{ member.role || 'Member' }}</span>
</div>
<button v-if="canRemoveMember(member)" class="btn btn-danger btn-sm" @click="removeMember(member.id)"
:disabled="removingMember === member.id">
<span v-if="removingMember === member.id" class="spinner-dots-sm"
role="status"><span /><span /><span /></span>
Remove
</button>
</div>
</div>
<div v-else class="neo-empty-state">
<svg class="icon icon-lg" aria-hidden="true">
<use xlink:href="#icon-users" />
</svg>
<p>No members found.</p>
</div>
</div>
<!-- Group Members Section -->
<div class="card mt-3">
<div class="card-header">
<h3>Group Members</h3>
</div>
<!-- Invite Members Section -->
<div class="neo-card">
<div class="neo-card-header">
<h3>Invite Members</h3>
</div>
<div class="neo-card-body">
<button class="btn btn-primary w-full" @click="generateInviteCode" :disabled="generatingInvite">
<span v-if="generatingInvite" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
{{ inviteCode ? 'Regenerate Invite Code' : 'Generate Invite Code' }}
</button>
<div v-if="inviteCode" class="neo-invite-code mt-3">
<label for="inviteCodeInput" class="neo-label">Current Active Invite Code:</label>
<div class="neo-input-group">
<input id="inviteCodeInput" type="text" :value="inviteCode" class="neo-input" readonly />
<button class="btn btn-neutral btn-icon-only" @click="copyInviteCodeHandler"
aria-label="Copy invite code">
<svg class="icon">
<use xlink:href="#icon-clipboard"></use>
</svg>
</button>
<div class="card-body">
<div v-if="group.members && group.members.length > 0" class="members-list">
<div v-for="member in group.members" :key="member.id" class="member-item">
<div class="member-info">
<span class="member-name">{{ member.email }}</span>
<span class="member-role" :class="member.role?.toLowerCase()">{{ member.role || 'Member' }}</span>
</div>
<p v-if="copySuccess" class="neo-success-text">Invite code copied to clipboard!</p>
</div>
<div v-else class="neo-empty-state mt-3">
<svg class="icon icon-lg" aria-hidden="true">
<use xlink:href="#icon-link" />
</svg>
<p>No active invite code. Click the button above to generate one.</p>
<button v-if="canRemoveMember(member)" class="btn btn-danger btn-sm" @click="removeMember(member.id)"
:disabled="removingMember === member.id">
<span v-if="removingMember === member.id" class="spinner-dots-sm"
role="status"><span /><span /><span /></span>
Remove
</button>
</div>
</div>
<div v-else class="text-muted">
No members found.
</div>
</div>
</div>
<!-- Lists Section -->
<!-- Placeholder for lists related to this group -->
<div class="mt-4">
<ListsPage :group-id="groupId" />
</div>
<!-- Invite Members Section -->
<div class="card mt-3">
<div class="card-header">
<h3>Invite Members</h3>
</div>
<div class="card-body">
<button class="btn btn-secondary" @click="generateInviteCode" :disabled="generatingInvite">
<span v-if="generatingInvite" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
Generate Invite Code
</button>
<div v-if="inviteCode" class="form-group mt-2">
<label for="inviteCodeInput" class="form-label">Invite Code:</label>
<div class="flex items-center">
<input id="inviteCodeInput" type="text" :value="inviteCode" class="form-input flex-grow" readonly />
<button class="btn btn-neutral btn-icon-only ml-1" @click="copyInviteCodeHandler"
aria-label="Copy invite code">
<svg class="icon">
<use xlink:href="#icon-clipboard"></use>
</svg> <!-- Assuming #icon-clipboard is 'content_copy' -->
</button>
</div>
<p v-if="copySuccess" class="form-success-text mt-1">Invite code copied to clipboard!</p>
</div>
</div>
</div>
</div>
<div v-else class="alert alert-info" role="status">
@ -124,7 +112,6 @@ const group = ref<Group | null>(null);
const loading = ref(true);
const error = ref<string | null>(null);
const inviteCode = ref<string | null>(null);
const inviteExpiresAt = ref<string | null>(null);
const generatingInvite = ref(false);
const copySuccess = ref(false);
const removingMember = ref<number | null>(null);
@ -136,33 +123,6 @@ const { copy, copied, isSupported: clipboardIsSupported } = useClipboard({
source: computed(() => inviteCode.value || '')
});
const fetchActiveInviteCode = async () => {
if (!groupId.value) return;
// Consider adding a loading state for this fetch if needed, e.g., initialInviteCodeLoading
try {
const response = await apiClient.get(API_ENDPOINTS.GROUPS.GET_ACTIVE_INVITE(String(groupId.value)));
if (response.data && response.data.code) {
inviteCode.value = response.data.code;
inviteExpiresAt.value = response.data.expires_at; // Store expiry
} else {
inviteCode.value = null; // No active code found
inviteExpiresAt.value = null;
}
} catch (err: any) {
if (err.response && err.response.status === 404) {
inviteCode.value = null; // Explicitly set to null on 404
inviteExpiresAt.value = null;
// Optional: notify user or set a flag to show "generate one" message more prominently
console.info('No active invite code found for this group.');
} else {
const message = err instanceof Error ? err.message : 'Failed to fetch active invite code.';
// error.value = message; // This would display a large error banner, might be too much
console.error('Error fetching active invite code:', err);
notificationStore.addNotification({ message, type: 'error' });
}
}
};
const fetchGroupDetails = async () => {
if (!groupId.value) return;
loading.value = true;
@ -178,24 +138,19 @@ const fetchGroupDetails = async () => {
} finally {
loading.value = false;
}
// Fetch active invite code after group details are loaded
await fetchActiveInviteCode();
};
const generateInviteCode = async () => {
if (!groupId.value) return;
generatingInvite.value = true;
inviteCode.value = null;
copySuccess.value = false;
try {
const response = await apiClient.post(API_ENDPOINTS.GROUPS.CREATE_INVITE(String(groupId.value)));
if (response.data && response.data.code) {
inviteCode.value = response.data.code;
inviteExpiresAt.value = response.data.expires_at; // Update with new expiry
notificationStore.addNotification({ message: 'New invite code generated successfully!', type: 'success' });
} else {
// Should not happen if POST is successful and returns the code
throw new Error('New invite code data is invalid.');
}
const response = await apiClient.post(API_ENDPOINTS.INVITES.BASE, {
group_id: groupId.value, // Ensure this matches API expectation (string or number)
});
inviteCode.value = response.data.invite_code;
notificationStore.addNotification({ message: 'Invite code generated successfully!', type: 'success' });
} catch (err: unknown) {
const message = err instanceof Error ? err.message : 'Failed to generate invite code.';
console.error('Error generating invite code:', err);
@ -256,8 +211,6 @@ onMounted(() => {
<style scoped>
.page-padding {
padding: 1rem;
max-width: 1200px;
margin: 0 auto;
}
.mt-1 {
@ -284,167 +237,64 @@ onMounted(() => {
margin-left: 0.25rem;
}
.w-full {
width: 100%;
/* Adjusted from Valerie UI for tighter fit */
.form-success-text {
color: var(--success);
/* Or a darker green for text */
font-size: 0.9rem;
font-weight: bold;
}
/* Neo Grid Layout */
.neo-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 2rem;
margin-bottom: 2rem;
.flex-grow {
flex-grow: 1;
}
/* Neo Card Styles */
.neo-card {
border-radius: 18px;
box-shadow: 6px 6px 0 #111;
background: #fff;
border: 3px solid #111;
overflow: hidden;
}
.neo-card-header {
padding: 1.5rem;
border-bottom: 3px solid #111;
background: #fafafa;
}
.neo-card-header h3 {
font-weight: 900;
font-size: 1.25rem;
margin: 0;
letter-spacing: 0.5px;
}
.neo-card-body {
padding: 1.5rem;
}
/* Members List Styles */
.neo-members-list {
/* Members list styles */
.members-list {
display: flex;
flex-direction: column;
gap: 1rem;
gap: 0.75rem;
}
.neo-member-item {
.member-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 1rem;
border-radius: 12px;
background: #fafafa;
border: 2px solid #111;
transition: transform 0.1s ease-in-out;
padding: 0.5rem;
border-radius: 0.25rem;
background-color: var(--surface-2);
}
.neo-member-item:hover {
transform: translateY(-2px);
}
.neo-member-info {
.member-info {
display: flex;
align-items: center;
gap: 1rem;
gap: 0.75rem;
}
.neo-member-name {
font-weight: 600;
font-size: 1.1rem;
.member-name {
font-weight: 500;
}
.neo-member-role {
.member-role {
font-size: 0.875rem;
padding: 0.25rem 0.75rem;
padding: 0.25rem 0.5rem;
border-radius: 1rem;
background: #e0e0e0;
font-weight: 600;
background-color: var(--surface-3);
}
.neo-member-role.owner {
background: #111;
.member-role.owner {
background-color: var(--primary);
color: white;
}
/* Invite Code Styles */
.neo-invite-code {
background: #fafafa;
padding: 1rem;
border-radius: 12px;
border: 2px solid #111;
.btn-sm {
padding: 0.25rem 0.5rem;
font-size: 0.875rem;
}
.neo-label {
display: block;
font-weight: 600;
margin-bottom: 0.5rem;
}
.neo-input-group {
display: flex;
gap: 0.5rem;
}
.neo-input {
flex: 1;
padding: 0.75rem;
border: 2px solid #111;
border-radius: 8px;
font-family: monospace;
font-size: 1rem;
background: white;
}
.neo-success-text {
color: var(--success);
font-size: 0.9rem;
font-weight: 600;
margin-top: 0.5rem;
}
/* Empty State Styles */
.neo-empty-state {
text-align: center;
padding: 2rem;
color: #666;
}
.neo-empty-state .icon {
width: 3rem;
height: 3rem;
margin-bottom: 1rem;
opacity: 0.5;
}
/* Responsive Adjustments */
@media (max-width: 900px) {
.neo-grid {
gap: 1.5rem;
}
}
@media (max-width: 600px) {
.page-padding {
padding: 0.5rem;
}
.neo-card-header,
.neo-card-body {
padding: 1rem;
}
.neo-member-item {
flex-direction: column;
gap: 0.75rem;
align-items: flex-start;
}
.neo-member-info {
flex-direction: column;
align-items: flex-start;
gap: 0.5rem;
}
.text-muted {
color: var(--text-2);
font-style: italic;
}
</style>

View File

@ -1,75 +1,72 @@
<template>
<main class="container page-padding">
<!-- <h1 class="mb-3">Your Groups</h1> -->
<div class="flex justify-between items-center mb-3">
<h1>Your Groups</h1>
<button class="btn btn-primary" @click="openCreateGroupDialog">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-plus" />
</svg>
Create Group
</button>
</div>
<div v-if="fetchError" class="alert alert-error mb-3" role="alert">
<div v-if="loading" class="text-center">
<div class="spinner-dots" role="status"><span /><span /><span /></div>
<p>Loading groups...</p>
</div>
<div v-else-if="fetchError" class="alert alert-error" role="alert">
<div class="alert-content">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
</svg>
{{ fetchError }}
</div>
<button type="button" class="btn btn-sm btn-danger" @click="fetchGroups">Retry</button>
</div>
<div v-else-if="groups.length === 0" class="card empty-state-card">
<ul v-else-if="groups.length" class="item-list">
<li v-for="group in groups" :key="group.id" class="list-item interactive-list-item" @click="selectGroup(group)"
@keydown.enter="selectGroup(group)" tabindex="0">
<div class="list-item-content">
<span class="item-text">{{ group.name }}</span>
<!-- Could add more details here if needed -->
</div>
</li>
</ul>
<div v-else class="card empty-state-card">
<svg class="icon icon-lg" aria-hidden="true">
<use xlink:href="#icon-clipboard" />
</svg>
<h3>No Groups Yet!</h3>
<p>You are not a member of any groups yet. Create one or join using an invite code.</p>
<button class="btn btn-primary mt-2" @click="openCreateGroupDialog">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-plus" />
</svg>
Create New Group
</button>
</div>
<div v-else class="mb-3">
<div class="neo-groups-grid">
<div v-for="group in groups" :key="group.id" class="neo-group-card" @click="selectGroup(group)">
<h1 class="neo-group-header">{{ group.name }}</h1>
<div class="neo-group-actions">
<button class="btn btn-sm btn-secondary" @click.stop="openCreateListDialog(group)">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-plus" />
</svg>
List
</button>
<details class="card mb-3">
<summary class="card-header flex items-center cursor-pointer"
style="display: flex; justify-content: space-between;">
<h3>
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-user" />
</svg> <!-- Placeholder icon -->
Join a Group with Invite Code
</h3>
<span class="expand-icon" aria-hidden="true"></span> <!-- Basic expand indicator -->
</summary>
<div class="card-body">
<form @submit.prevent="handleJoinGroup" class="flex items-center" style="gap: 0.5rem;">
<div class="form-group flex-grow" style="margin-bottom: 0;">
<label for="joinInviteCodeInput" class="sr-only">Enter Invite Code</label>
<input type="text" id="joinInviteCodeInput" v-model="inviteCodeToJoin" class="form-input"
placeholder="Enter Invite Code" required ref="joinInviteCodeInputRef" />
</div>
</div>
<div class="neo-create-group-card" @click="openCreateGroupDialog">
+ Group
</div>
<button type="submit" class="btn btn-secondary" :disabled="joiningGroup">
<span v-if="joiningGroup" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
Join
</button>
</form>
<p v-if="joinGroupFormError" class="form-error-text mt-1">{{ joinGroupFormError }}</p>
</div>
<details class="card mb-3 mt-4">
<summary class="card-header flex items-center cursor-pointer justify-between">
<h3>
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-user" />
</svg>
Join a Group with Invite Code
</h3>
<span class="expand-icon" aria-hidden="true"></span>
</summary>
<div class="card-body">
<form @submit.prevent="handleJoinGroup" class="flex items-center" style="gap: 0.5rem;">
<div class="form-group flex-grow" style="margin-bottom: 0;">
<label for="joinInviteCodeInput" class="sr-only">Enter Invite Code</label>
<input type="text" id="joinInviteCodeInput" v-model="inviteCodeToJoin" class="form-input"
placeholder="Enter Invite Code" required ref="joinInviteCodeInputRef" />
</div>
<button type="submit" class="btn btn-secondary" :disabled="joiningGroup">
<span v-if="joiningGroup" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
Join
</button>
</form>
<p v-if="joinGroupFormError" class="form-error-text mt-1">{{ joinGroupFormError }}</p>
</div>
</details>
</div>
</details>
<!-- Create Group Dialog -->
<div v-if="showCreateGroupDialog" class="modal-backdrop open" @click.self="closeCreateGroupDialog">
@ -102,34 +99,26 @@
</form>
</div>
</div>
<!-- Create List Modal -->
<CreateListModal v-model="showCreateListModal" :groups="availableGroupsForModal" @created="onListCreated" />
</main>
</template>
<script setup lang="ts">
import { ref, onMounted, nextTick } from 'vue';
import { useRouter } from 'vue-router';
import { apiClient, API_ENDPOINTS } from '@/config/api';
import { useStorage } from '@vueuse/core';
import { apiClient, API_ENDPOINTS } from '@/config/api'; // Assuming path
import { onClickOutside } from '@vueuse/core';
import { useNotificationStore } from '@/stores/notifications';
import CreateListModal from '@/components/CreateListModal.vue';
interface Group {
id: number;
id: string | number;
name: string;
description?: string;
member_count: number;
created_at: string;
updated_at: string;
}
const router = useRouter();
const notificationStore = useNotificationStore();
const groups = ref<Group[]>([]);
const loading = ref(false);
const loading = ref(true);
const fetchError = ref<string | null>(null);
const showCreateGroupDialog = ref(false);
@ -144,37 +133,20 @@ const joiningGroup = ref(false);
const joinInviteCodeInputRef = ref<HTMLInputElement | null>(null);
const joinGroupFormError = ref<string | null>(null);
const showCreateListModal = ref(false);
const availableGroupsForModal = ref<{ label: string; value: number; }[]>([]);
// Cache groups in localStorage
const cachedGroups = useStorage<Group[]>('cached-groups', []);
const cachedTimestamp = useStorage<number>('cached-groups-timestamp', 0);
const CACHE_DURATION = 5 * 60 * 1000; // 5 minutes in milliseconds
// Load cached data immediately if available and not expired
const loadCachedData = () => {
const now = Date.now();
if (cachedGroups.value.length > 0 && (now - cachedTimestamp.value) < CACHE_DURATION) {
groups.value = cachedGroups.value;
}
};
// Fetch fresh data from API
const fetchGroups = async () => {
loading.value = true;
fetchError.value = null;
try {
const response = await apiClient.get(API_ENDPOINTS.GROUPS.BASE);
groups.value = response.data;
// Update cache
cachedGroups.value = response.data;
cachedTimestamp.value = Date.now();
} catch (err) {
fetchError.value = err instanceof Error ? err.message : 'Failed to load groups';
// If we have cached data, keep showing it even if refresh failed
if (cachedGroups.value.length === 0) {
groups.value = [];
}
groups.value = Array.isArray(response.data) ? response.data : [];
} catch (error: unknown) {
const message = error instanceof Error ? error.message : 'Failed to load groups. Please try again.';
fetchError.value = message;
groups.value = [];
console.error('Error fetching groups:', error);
notificationStore.addNotification({ message, type: 'error' });
} finally {
loading.value = false;
}
};
@ -210,9 +182,6 @@ const handleCreateGroup = async () => {
groups.value.push(newGroup);
closeCreateGroupDialog();
notificationStore.addNotification({ message: `Group '${newGroup.name}' created successfully.`, type: 'success' });
// Update cache
cachedGroups.value = groups.value;
cachedTimestamp.value = Date.now();
} else {
throw new Error('Invalid data received from server.');
}
@ -244,9 +213,6 @@ const handleJoinGroup = async () => {
}
inviteCodeToJoin.value = '';
notificationStore.addNotification({ message: `Successfully joined group '${joinedGroup.name}'.`, type: 'success' });
// Update cache
cachedGroups.value = groups.value;
cachedTimestamp.value = Date.now();
} else {
// If API returns only success message, re-fetch groups
await fetchGroups(); // Refresh the list of groups
@ -267,45 +233,20 @@ const selectGroup = (group: Group) => {
router.push(`/groups/${group.id}`);
};
const openCreateListDialog = (group: Group) => {
availableGroupsForModal.value = [{
label: group.name,
value: group.id
}];
showCreateListModal.value = true;
};
const onListCreated = (newList: any) => {
notificationStore.addNotification({
message: `List '${newList.name}' created successfully.`,
type: 'success'
});
};
onMounted(async () => {
// Load cached data immediately
loadCachedData();
// Then fetch fresh data in background
await fetchGroups();
onMounted(() => {
fetchGroups();
});
</script>
<style scoped>
.page-padding {
padding: 1rem;
max-width: 1200px;
margin: 0 auto;
}
.mb-3 {
margin-bottom: 1.5rem;
}
.mt-4 {
margin-top: 2rem;
}
.mt-1 {
margin-top: 0.5rem;
}
@ -314,74 +255,17 @@ onMounted(async () => {
margin-left: 0.5rem;
}
/* Responsive grid for cards */
.neo-groups-grid {
display: flex;
flex-wrap: wrap;
gap: 2rem;
justify-content: center;
align-items: flex-start;
margin-bottom: 2rem;
}
/* Card styles */
.neo-group-card,
.neo-create-group-card {
border-radius: 18px;
box-shadow: 6px 6px 0 #111;
max-width: 420px;
min-width: 260px;
width: 100%;
margin: 0 auto;
background: #fff;
display: flex;
flex-direction: row;
align-items: center;
justify-content: space-between;
margin-bottom: 0;
padding: 2rem 2rem 1.5rem 2rem;
.interactive-list-item {
cursor: pointer;
transition: transform 0.1s ease-in-out, box-shadow 0.1s ease-in-out;
border: 3px solid #111;
transition: background-color var(--transition-speed) var(--transition-ease-out);
}
.neo-group-card:hover {
transform: translateY(-3px);
box-shadow: 6px 9px 0 #111;
}
.neo-group-header {
font-weight: 900;
font-size: 1.25rem;
/* margin-bottom: 1rem; */
letter-spacing: 0.5px;
text-transform: none;
}
.neo-group-actions {
margin-top: 0;
}
.neo-create-group-card {
border: 3px dashed #111;
background: #fafafa;
padding: 2.5rem 0;
text-align: center;
font-weight: 900;
font-size: 1.1rem;
color: #222;
cursor: pointer;
margin-top: 0;
transition: background 0.1s;
display: flex;
align-items: center;
justify-content: center;
min-height: 120px;
margin-bottom: 2.5rem;
}
.neo-create-group-card:hover {
background: #f0f0f0;
.interactive-list-item:hover,
.interactive-list-item:focus-visible {
background-color: rgba(0, 0, 0, 0.03);
outline: var(--focus-outline);
outline-offset: -3px;
/* Adjust to be inside the border */
}
.form-error-text {
@ -395,10 +279,12 @@ onMounted(async () => {
details>summary {
list-style: none;
/* Hide default marker */
}
details>summary::-webkit-details-marker {
display: none;
/* Hide default marker for Chrome */
}
.expand-icon {
@ -412,35 +298,4 @@ details[open] .expand-icon {
.cursor-pointer {
cursor: pointer;
}
/* Responsive adjustments */
@media (max-width: 900px) {
.neo-groups-grid {
gap: 1.2rem;
}
.neo-group-card,
.neo-create-group-card {
max-width: 95vw;
min-width: 180px;
padding-left: 1rem;
padding-right: 1rem;
}
}
@media (max-width: 600px) {
.page-padding {
padding: 0.5rem;
}
.neo-group-card,
.neo-create-group-card {
padding: 1.2rem 0.7rem 1rem 0.7rem;
font-size: 1rem;
}
.neo-group-header {
font-size: 1.1rem;
}
}
</style>

View File

@ -1,100 +1,113 @@
<template>
<main class="neo-container page-padding">
<div v-if="loading" class="neo-loading-state">
<main class="container page-padding">
<div v-if="loading" class="text-center">
<div class="spinner-dots" role="status"><span /><span /><span /></div>
<p>Loading list...</p>
<p>Loading list details...</p>
</div>
<div v-else-if="error" class="neo-error-state">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
</svg>
{{ error }}
<button class="neo-button" @click="fetchListDetails">Retry</button>
<div v-else-if="error" class="alert alert-error mb-3" role="alert">
<div class="alert-content">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
</svg>
{{ error }}
</div>
<button type="button" class="btn btn-sm btn-danger" @click="fetchListDetails">Retry</button>
</div>
<template v-else-if="list">
<!-- Header -->
<div class="neo-list-header">
<h1 class="neo-title mb-3">{{ list.name }}</h1>
<div class="neo-header-actions">
<button class="neo-action-button" @click="showCostSummaryDialog = true"
:class="{ 'neo-disabled': !isOnline }">
<svg class="icon">
<div class="flex justify-between items-center flex-wrap mb-2">
<h1>{{ list.name }}</h1>
<div class="flex items-center flex-wrap" style="gap: 0.5rem;">
<button class="btn btn-neutral btn-sm" @click="showCostSummaryDialog = true"
:class="{ 'feature-offline-disabled': !isOnline }"
:data-tooltip="!isOnline ? 'Cost summary requires online connection' : ''">
<svg class="icon icon-sm">
<use xlink:href="#icon-clipboard" />
</svg> Cost Summary
</svg>
Cost Summary
</button>
<button class="neo-action-button" @click="openOcrDialog" :class="{ 'neo-disabled': !isOnline }">
<svg class="icon">
<button class="btn btn-secondary btn-sm" @click="openOcrDialog"
:class="{ 'feature-offline-disabled': !isOnline }"
:data-tooltip="!isOnline ? 'OCR requires online connection' : ''">
<svg class="icon icon-sm">
<use xlink:href="#icon-plus" />
</svg> Add via OCR
</svg>
Add via OCR
</button>
<div class="neo-status" :class="list.is_complete ? 'neo-status-complete' : 'neo-status-active'">
<span v-if="list.group_id">Group List</span>
<span v-else>Personal List</span>
</div>
<span class="item-badge ml-1" :class="list.is_complete ? 'badge-settled' : 'badge-pending'">
{{ list.is_complete ? 'Complete' : 'Active' }}
</span>
</div>
</div>
<p v-if="list.description" class="neo-description">{{ list.description }}</p>
<!-- Add Item Form -->
<form @submit.prevent="onAddItem" class="card mb-3">
<div class="card-body">
<div class="flex items-end flex-wrap" style="gap: 1rem;">
<div class="form-group flex-grow" style="margin-bottom: 0;">
<label for="newItemName" class="form-label">Item Name</label>
<input type="text" id="newItemName" v-model="newItem.name" class="form-input" required
ref="itemNameInputRef" />
</div>
<div class="form-group" style="margin-bottom: 0; min-width: 120px;">
<label for="newItemQuantity" class="form-label">Quantity</label>
<input type="number" id="newItemQuantity" v-model="newItem.quantity" class="form-input" min="1" />
</div>
<button type="submit" class="btn btn-primary" :disabled="addingItem">
<span v-if="addingItem" class="spinner-dots-sm" role="status"><span /><span /><span /></span>
Add Item
</button>
</div>
</div>
</form>
<!-- Items List -->
<div v-if="list.items.length === 0" class="neo-empty-state">
<div v-if="list.items.length === 0" class="card empty-state-card">
<svg class="icon icon-lg" aria-hidden="true">
<use xlink:href="#icon-clipboard" />
</svg>
<h3>No Items Yet!</h3>
<p>Add some items using the form below.</p>
<p>This list is empty. Add some items using the form above.</p>
</div>
<div v-else class="neo-list-card">
<ul class="neo-item-list">
<li v-for="item in list.items" :key="item.id" class="neo-item"
:class="{ 'neo-item-complete': item.is_complete }">
<div class="neo-item-content">
<label class="neo-checkbox-label">
<ul v-else class="item-list">
<li v-for="item in list.items" :key="item.id" class="list-item" :class="{
'completed': item.is_complete,
'is-swiped': item.swiped,
'offline-item': isItemPendingSync(item),
'synced': !isItemPendingSync(item)
}" @touchstart="handleTouchStart" @touchmove="handleTouchMove" @touchend="handleTouchEnd">
<div class="list-item-content">
<div class="list-item-main">
<label class="checkbox-label mb-0 flex-shrink-0">
<input type="checkbox" :checked="item.is_complete"
@change="confirmUpdateItem(item, ($event.target as HTMLInputElement).checked)"
:disabled="item.updating" :aria-label="item.name" />
<span class="neo-checkmark"></span>
<span class="checkmark"></span>
</label>
<div class="neo-item-details">
<span class="neo-item-name">{{ item.name }}</span>
<span v-if="item.quantity" class="neo-item-quantity">× {{ item.quantity }}</span>
<div v-if="item.is_complete" class="neo-price-input">
<input type="number" v-model.number="item.priceInput" class="neo-number-input" placeholder="Price"
step="0.01" @blur="updateItemPrice(item)"
<div class="item-text flex-grow">
<span :class="{ 'text-decoration-line-through': item.is_complete }">{{ item.name }}</span>
<small v-if="item.quantity" class="item-caption">Quantity: {{ item.quantity }}</small>
<div v-if="item.is_complete" class="form-group mt-1" style="max-width: 150px; margin-bottom: 0;">
<label :for="`price-${item.id}`" class="sr-only">Price for {{ item.name }}</label>
<input :id="`price-${item.id}`" type="number" v-model.number="item.priceInput"
class="form-input form-input-sm" placeholder="Price" step="0.01" @blur="updateItemPrice(item)"
@keydown.enter.prevent="($event.target as HTMLInputElement).blur()" />
</div>
</div>
<div class="neo-item-actions">
<button class="neo-icon-button neo-edit-button" @click.stop="editItem(item)" aria-label="Edit item">
<svg class="icon">
<use xlink:href="#icon-edit"></use>
</svg>
</button>
<button class="neo-icon-button neo-delete-button" @click.stop="confirmDeleteItem(item)"
:disabled="item.deleting" aria-label="Delete item">
<svg class="icon">
<use xlink:href="#icon-trash"></use>
</svg>
</button>
</div>
</div>
</li>
<li class="neo-item new-item-input">
<form @submit.prevent="onAddItem" class="neo-checkbox-label neo-new-item-form">
<input type="checkbox" disabled />
<input type="text" v-model="newItem.name" class="neo-new-item-input" placeholder="Add a new item" required
ref="itemNameInputRef" />
<input type="number" v-model="newItem.quantity" class="neo-quantity-input" placeholder="Qty" min="1" />
<button type="submit" class="neo-add-button" :disabled="addingItem">
<span v-if="addingItem" class="spinner-dots-sm"><span /><span /><span /></span>
<span v-else>Add</span>
<div class="list-item-actions">
<button class="btn btn-danger btn-sm btn-icon-only" @click.stop="confirmDeleteItem(item)"
:disabled="item.deleting" aria-label="Delete item">
<svg class="icon icon-sm">
<use xlink:href="#icon-trash"></use>
</svg>
</button>
</form>
</li>
</ul>
</div>
</div>
</div>
</li>
</ul>
</template>
<!-- OCR Dialog -->
@ -248,7 +261,15 @@ interface Item {
swiped?: boolean; // For swipe UI
}
interface List { id: number; name: string; description?: string; is_complete: boolean; items: Item[]; version: number; updated_at: string; group_id?: number; }
interface List {
id: number;
name: string;
description?: string;
is_complete: boolean;
items: Item[];
version: number;
updated_at: string;
}
interface UserCostShare {
user_id: number;
@ -728,524 +749,151 @@ onUnmounted(() => {
stopPolling();
});
// Add after deleteItem function
const editItem = (item: Item) => {
// For now, just simulate editing by toggling name and adding "(Edited)" when clicked
// In a real implementation, you would show a modal or inline form
if (!item.name.includes('(Edited)')) {
item.name += ' (Edited)';
}
// Placeholder for future edit functionality
notificationStore.addNotification({
message: 'Edit functionality would show here (modal or inline form)',
type: 'info'
});
};
</script>
<style scoped>
.neo-container {
padding: 1rem;
max-width: 1200px;
margin: 0 auto;
}
.page-padding {
padding: 1rem;
max-width: 1200px;
margin: 0 auto;
}
.mb-1 {
margin-bottom: 0.5rem;
}
.mb-2 {
margin-bottom: 1rem;
}
.mb-3 {
margin-bottom: 1.5rem;
}
.neo-loading-state,
.neo-error-state,
.neo-empty-state {
text-align: center;
padding: 3rem 1rem;
margin: 2rem 0;
border: 3px solid #111;
border-radius: 18px;
background: #fff;
box-shadow: 6px 6px 0 #111;
}
.neo-error-state {
border-color: #e74c3c;
}
.neo-list-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1.5rem;
}
.neo-title {
font-size: 2.5rem;
font-weight: 900;
margin: 0;
line-height: 1.2;
}
.neo-header-actions {
display: flex;
align-items: center;
gap: 0.75rem;
}
.neo-description {
font-size: 1.2rem;
margin-bottom: 2rem;
color: #555;
}
.neo-status {
font-weight: 900;
font-size: 1rem;
padding: 0.4rem 1rem;
border: 3px solid #111;
border-radius: 50px;
background: #fff;
box-shadow: 3px 3px 0 #111;
}
.neo-status-active {
background: #f7f7d4;
}
.neo-status-complete {
background: #d4f7dd;
}
.neo-list-card {
break-inside: avoid;
border-radius: 18px;
box-shadow: 6px 6px 0 #111;
width: 100%;
margin: 0 0 2rem 0;
background: #fff;
display: flex;
flex-direction: column;
cursor: pointer;
border: 3px solid #111;
}
.neo-item-list {
list-style: none;
padding: 0;
margin: 0 0 2rem 0;
break-inside: avoid;
width: 100%;
background: #fff;
display: flex;
flex-direction: column;
}
.neo-item {
padding: 1.2rem;
margin-bottom: 0;
border-bottom: 1px solid #eee;
background: #fff;
transition: background-color 0.1s ease-in-out;
}
.neo-item:last-child {
border-bottom: none;
}
.neo-item:hover {
background-color: #f9f9f9;
}
.neo-item-complete {
background: #f9f9f9;
}
.neo-item-content {
display: flex;
align-items: center;
}
.neo-checkbox-label {
display: flex;
align-items: center;
gap: 0.7em;
cursor: pointer;
}
.neo-checkbox-label input[type="checkbox"] {
width: 1.2em;
height: 1.2em;
accent-color: #111;
border: 2px solid #111;
border-radius: 4px;
}
.neo-item-details {
flex-grow: 1;
display: flex;
flex-direction: column;
}
.neo-item-name {
font-size: 1.1rem;
font-weight: 700;
}
.neo-item-complete .neo-item-name {
text-decoration: line-through;
opacity: 0.6;
}
.neo-item-quantity {
font-size: 0.9rem;
color: #555;
margin-top: 0.2rem;
}
.neo-price-input {
.mt-1 {
margin-top: 0.5rem;
}
.neo-item-actions {
display: flex;
gap: 0.5rem;
}
.neo-icon-button {
background: none;
border: none;
cursor: pointer;
padding: 0.5rem;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
}
.neo-edit-button {
color: #3498db;
}
.neo-edit-button:hover {
background: #eef7fd;
}
.neo-delete-button {
background: none;
border: none;
cursor: pointer;
color: #e74c3c;
padding: 0.5rem;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
margin-left: 0;
}
.neo-delete-button:hover {
background: #fee;
}
.neo-actions {
display: flex;
gap: 1rem;
margin-bottom: 2rem;
}
.neo-action-button {
background: #fff;
border: 3px solid #111;
border-radius: 8px;
padding: 0.6rem 1rem;
font-weight: 700;
cursor: pointer;
display: flex;
align-items: center;
gap: 0.5rem;
box-shadow: 3px 3px 0 #111;
transition: transform 0.1s ease-in-out, box-shadow 0.1s ease-in-out;
}
.neo-action-button:hover {
transform: translateY(-2px);
box-shadow: 3px 5px 0 #111;
}
.neo-action-button .icon {
width: 1.2rem;
height: 1.2rem;
}
.neo-disabled {
opacity: 0.5;
cursor: not-allowed;
}
.neo-add-item-form {
display: flex;
gap: 0.5rem;
margin-top: 2rem;
border: 3px solid #111;
border-radius: 12px;
padding: 1rem;
background: #f9f9f9;
box-shadow: 4px 4px 0 #111;
}
.neo-new-item-form {
width: 100%;
gap: 10px;
}
.neo-text-input {
flex-grow: 1;
border: 2px solid #111;
border-radius: 8px;
padding: 0.8rem;
font-size: 1.1rem;
font-weight: 500;
}
.neo-new-item-input {
background: transparent;
border: none;
outline: none;
all: unset;
width: 100%;
font-size: 1.1rem;
font-weight: 500;
color: #444;
flex-grow: 1;
}
.neo-new-item-input::placeholder {
color: #999;
font-weight: 500;
}
.neo-quantity-input {
width: 80px;
border: 2px solid #111;
border-radius: 8px;
padding: 0.4rem;
font-size: 1rem;
font-weight: 500;
}
.neo-number-input {
border: 2px solid #111;
border-radius: 6px;
padding: 0.5rem;
font-size: 1rem;
width: 100px;
}
.neo-add-button {
background: #111;
color: white;
border: none;
border-radius: 8px;
padding: 0 1rem;
font-weight: 700;
cursor: pointer;
min-width: 60px;
height: 2rem;
}
.neo-button {
background: #111;
color: white;
border: none;
border-radius: 8px;
padding: 0.8rem 1.5rem;
font-weight: 700;
.mt-2 {
margin-top: 1rem;
cursor: pointer;
}
.new-item-input {
margin-top: 0.5rem;
padding: 0.5rem;
.ml-1 {
margin-left: 0.25rem;
}
/* Responsive adjustments */
@media (max-width: 900px) {
.neo-container {
padding: 0.8rem;
}
.neo-title {
font-size: 1.8rem;
}
.neo-item {
padding: 1rem;
}
}
@media (max-width: 600px) {
.neo-container {
padding: 0.5rem;
}
.neo-header-actions {
flex-direction: column;
align-items: flex-start;
gap: 0.5rem;
}
.neo-title {
font-size: 1.5rem;
}
.neo-item-name {
font-size: 1rem;
}
.neo-add-item-form {
flex-direction: column;
padding: 0.8rem;
}
.neo-quantity-input {
width: 100%;
}
}
.modal-backdrop {
background-color: rgba(0, 0, 0, 0.5);
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
display: flex;
align-items: center;
justify-content: center;
z-index: 1000;
}
.modal-container {
background: white;
border-radius: 18px;
border: 3px solid #111;
box-shadow: 6px 6px 0 #111;
width: 90%;
max-width: 500px;
max-height: 90vh;
overflow-y: auto;
padding: 0;
}
.modal-header {
padding: 1.5rem;
border-bottom: 1px solid #eee;
display: flex;
justify-content: space-between;
align-items: center;
}
.modal-body {
padding: 1.5rem;
}
.modal-footer {
padding: 1.5rem;
border-top: 1px solid #eee;
display: flex;
justify-content: flex-end;
gap: 0.5rem;
}
.close-button {
background: none;
border: none;
cursor: pointer;
color: #666;
}
.item-badge {
display: inline-block;
padding: 0.25rem 0.5rem;
border-radius: 16px;
font-weight: 700;
font-size: 0.9rem;
}
.badge-settled {
background-color: #d4f7dd;
color: #2c784c;
}
.badge-pending {
background-color: #ffe1d6;
color: #c64600;
.ml-2 {
margin-left: 0.5rem;
}
.text-right {
text-align: right;
}
.text-center {
text-align: center;
.flex-grow {
flex-grow: 1;
}
.spinner-dots {
display: flex;
align-items: center;
justify-content: center;
gap: 0.3rem;
margin: 0 auto;
.item-caption {
display: block;
font-size: 0.8rem;
opacity: 0.6;
margin-top: 0.25rem;
}
.spinner-dots span {
width: 8px;
height: 8px;
background-color: #555;
border-radius: 50%;
animation: dot-pulse 1.4s infinite ease-in-out both;
.text-decoration-line-through {
text-decoration: line-through;
}
.spinner-dots-sm {
display: inline-flex;
align-items: center;
gap: 0.2rem;
.form-input-sm {
/* For price input */
padding: 0.4rem 0.6rem;
font-size: 0.9rem;
}
.spinner-dots-sm span {
width: 4px;
height: 4px;
background-color: white;
border-radius: 50%;
animation: dot-pulse 1.4s infinite ease-in-out both;
.cost-overview p {
margin-bottom: 0.5rem;
font-size: 1.05rem;
}
.spinner-dots span:nth-child(1),
.spinner-dots-sm span:nth-child(1) {
animation-delay: -0.32s;
.form-error-text {
color: var(--danger);
font-size: 0.85rem;
}
.spinner-dots span:nth-child(2),
.spinner-dots-sm span:nth-child(2) {
animation-delay: -0.16s;
.list-item.completed .item-text {
/* text-decoration: line-through; is handled by span class */
opacity: 0.7;
}
@keyframes dot-pulse {
.list-item-actions {
margin-left: auto;
/* Pushes actions to the right */
padding-left: 1rem;
/* Space before actions */
}
0%,
80%,
100% {
transform: scale(0);
.offline-item {
position: relative;
opacity: 0.8;
transition: opacity 0.3s ease;
}
.offline-item::after {
content: '';
position: absolute;
top: 0.5rem;
right: 0.5rem;
width: 1rem;
height: 1rem;
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpath d='M21 12a9 9 0 0 0-9-9 9.75 9.75 0 0 0-6.74 2.74L3 8'/%3E%3Cpath d='M3 3v5h5'/%3E%3Cpath d='M3 12a9 9 0 0 0 9 9 9.75 9.75 0 0 0 6.74-2.74L21 16'/%3E%3Cpath d='M16 21h5v-5'/%3E%3C/svg%3E");
background-size: contain;
background-repeat: no-repeat;
animation: spin 1s linear infinite;
}
.offline-item.synced {
opacity: 1;
}
.offline-item.synced::after {
display: none;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
40% {
transform: scale(1);
to {
transform: rotate(360deg);
}
}
.feature-offline-disabled {
position: relative;
cursor: not-allowed;
opacity: 0.6;
}
.feature-offline-disabled::before {
content: attr(data-tooltip);
position: absolute;
bottom: 100%;
left: 50%;
transform: translateX(-50%);
padding: 0.5rem;
background-color: var(--bg-color-tooltip, #333);
color: white;
border-radius: 0.25rem;
font-size: 0.875rem;
white-space: nowrap;
opacity: 0;
visibility: hidden;
transition: all 0.2s ease;
z-index: 1000;
}
.feature-offline-disabled:hover::before {
opacity: 1;
visibility: visible;
}
</style>

View File

@ -1,8 +1,13 @@
<template>
<main class="container page-padding">
<!-- <h1 class="mb-3">{{ pageTitle }}</h1> -->
<h1 class="mb-3">{{ pageTitle }}</h1>
<div v-if="error" class="alert alert-error mb-3" role="alert">
<div v-if="loading" class="text-center">
<div class="spinner-dots" role="status"><span /><span /><span /></div>
<p>Loading lists...</p>
</div>
<div v-else-if="error" class="alert alert-error mb-3" role="alert">
<div class="alert-content">
<svg class="icon" aria-hidden="true">
<use xlink:href="#icon-alert-triangle" />
@ -27,31 +32,47 @@
</button>
</div>
<div v-else>
<div class="neo-lists-grid">
<div v-for="list in lists" :key="list.id" class="neo-list-card" @click="navigateToList(list.id)">
<div class="neo-list-header">{{ list.name }}</div>
<div class="neo-list-desc">{{ list.description || 'No description' }}</div>
<ul class="neo-item-list">
<li v-for="item in list.items" :key="item.id" class="neo-list-item">
<label class="neo-checkbox-label" @click.stop>
<input type="checkbox" :checked="item.is_complete" @change="toggleItem(list, item)" />
<span :class="{ 'neo-completed': item.is_complete }">{{ item.name }}</span>
</label>
</li>
<li class="neo-list-item new-item-input">
<label class="neo-checkbox-label">
<input type="checkbox" disabled />
<input type="text" class="neo-new-item-input" placeholder="Add new item..."
@keyup.enter="addNewItem(list, $event)" @blur="addNewItem(list, $event)" @click.stop />
</label>
</li>
</ul>
<ul v-else class="item-list">
<li v-for="list in lists" :key="list.id" class="list-item interactive-list-item" tabindex="0"
@click="navigateToList(list.id)" @keydown.enter="navigateToList(list.id)">
<div class="list-item-content">
<div class="list-item-main" style="flex-direction: column; align-items: flex-start;">
<span class="item-text" style="font-size: 1.1rem; font-weight: bold;">{{ list.name }}</span>
<small class="item-caption">{{ list.description || 'No description' }}</small>
<small v-if="!list.group_id && !props.groupId" class="item-caption icon-caption">
<svg class="icon icon-sm">
<use xlink:href="#icon-user" />
</svg> Personal List
</small>
<small v-if="list.group_id && !props.groupId" class="item-caption icon-caption">
<svg class="icon icon-sm">
<use xlink:href="#icon-user" />
</svg> <!-- Placeholder, group icon not in Valerie -->
Group List ({{ getGroupName(list.group_id) || `ID: ${list.group_id}` }})
</small>
</div>
<div class="list-item-details" style="flex-direction: column; align-items: flex-end;">
<span class="item-badge" :class="list.is_complete ? 'badge-settled' : 'badge-pending'">
{{ list.is_complete ? 'Complete' : 'Active' }}
</span>
<small class="item-caption mt-1">
Updated: {{ new Date(list.updated_at).toLocaleDateString() }}
</small>
</div>
</div>
<div class="neo-create-list-card" @click="showCreateModal = true">
+ Create a new list
</div>
</div>
</li>
</ul>
<div class="page-sticky-bottom-right">
<button class="btn btn-primary btn-icon-only" style="width: 56px; height: 56px; border-radius: 50%; padding: 0;"
@click="showCreateModal = true" :aria-label="currentGroupId ? 'Create Group List' : 'Create List'"
data-tooltip="Create New List">
<svg class="icon icon-lg" style="margin-right:0;">
<use xlink:href="#icon-plus" />
</svg>
</button>
<!-- Basic Tooltip (requires JS from Valerie UI example to function on hover/focus) -->
<!-- <span class="tooltip-text" role="tooltip">{{ currentGroupId ? 'Create Group List' : 'Create List' }}</span> -->
</div>
<CreateListModal v-model="showCreateModal" :groups="availableGroupsForModal" @created="onListCreated" />
@ -62,8 +83,7 @@
import { ref, onMounted, computed, watch } from 'vue';
import { useRoute, useRouter } from 'vue-router';
import { apiClient, API_ENDPOINTS } from '@/config/api';
import CreateListModal from '@/components/CreateListModal.vue';
import { useStorage } from '@vueuse/core';
import CreateListModal from '@/components/CreateListModal.vue'; // Adjusted path
interface List {
id: number;
@ -75,7 +95,6 @@ interface List {
group_id?: number | null;
created_at: string;
version: number;
items: Item[];
}
interface Group {
@ -83,17 +102,6 @@ interface Group {
name: string;
}
interface Item {
id: number;
name: string;
quantity?: string | number;
is_complete: boolean;
price?: number | null;
version: number;
updating?: boolean;
updated_at: string;
}
const props = defineProps<{
groupId?: number | string; // Prop for when ListsPage is embedded (e.g. in GroupDetailPage)
}>();
@ -101,11 +109,12 @@ const props = defineProps<{
const route = useRoute();
const router = useRouter();
const loading = ref(false);
const loading = ref(true);
const error = ref<string | null>(null);
const lists = ref<(List & { items: Item[] })[]>([]);
const allFetchedGroups = ref<Group[]>([]);
const currentViewedGroup = ref<Group | null>(null);
const lists = ref<List[]>([]);
const allFetchedGroups = ref<Group[]>([]); // Store all groups user has access to for display
const currentViewedGroup = ref<Group | null>(null); // For the title if on a specific group's list page
const showCreateModal = ref(false);
const currentGroupId = computed<number | null>(() => {
@ -167,47 +176,35 @@ const fetchAllAccessibleGroups = async () => {
}
};
// Cache lists in localStorage
const cachedLists = useStorage<(List & { items: Item[] })[]>('cached-lists', []);
const cachedTimestamp = useStorage<number>('cached-lists-timestamp', 0);
const CACHE_DURATION = 5 * 60 * 1000; // 5 minutes in milliseconds
const loadCachedData = () => {
const now = Date.now();
if (cachedLists.value.length > 0 && (now - cachedTimestamp.value) < CACHE_DURATION) {
lists.value = cachedLists.value;
}
};
const fetchLists = async () => {
loading.value = true;
error.value = null;
try {
// If currentGroupId is set, fetch lists for that group. Otherwise, fetch all user's lists.
const endpoint = currentGroupId.value
? API_ENDPOINTS.GROUPS.LISTS(String(currentGroupId.value))
: API_ENDPOINTS.LISTS.BASE;
const response = await apiClient.get(endpoint);
lists.value = response.data as (List & { items: Item[] })[];
// Update cache
cachedLists.value = response.data;
cachedTimestamp.value = Date.now();
lists.value = response.data as List[];
} catch (err: unknown) {
error.value = err instanceof Error ? err.message : 'Failed to fetch lists.';
console.error(error.value, err);
// If we have cached data, keep showing it even if refresh failed
if (cachedLists.value.length === 0) {
lists.value = [];
}
} finally {
loading.value = false;
}
};
const fetchListsAndGroups = async () => {
loading.value = true;
await Promise.all([
fetchLists(),
fetchAllAccessibleGroups()
]);
await fetchCurrentViewGroupName();
await fetchCurrentViewGroupName(); // Depends on allFetchedGroups
loading.value = false;
};
const availableGroupsForModal = computed(() => {
return allFetchedGroups.value.map(group => ({
label: group.name,
@ -220,76 +217,20 @@ const getGroupName = (groupId?: number | null): string | undefined => {
return allFetchedGroups.value.find(g => g.id === groupId)?.name;
}
const onListCreated = (newList: List & { items: Item[] }) => {
lists.value = [...lists.value, newList];
// Update cache
cachedLists.value = lists.value;
cachedTimestamp.value = Date.now();
};
const toggleItem = async (list: (List & { items: Item[] }), item: Item) => {
const original = item.is_complete;
item.is_complete = !item.is_complete;
item.updating = true;
try {
await apiClient.put(
API_ENDPOINTS.LISTS.ITEM(String(list.id), String(item.id)),
{
is_complete: item.is_complete,
version: item.version,
name: item.name,
quantity: item.quantity,
price: item.price
}
);
item.version++;
} catch (err) {
item.is_complete = original;
console.error('Failed to update item:', err);
} finally {
item.updating = false;
}
};
const addNewItem = async (list: (List & { items: Item[] }), event: Event) => {
const input = event.target as HTMLInputElement;
const itemName = input.value.trim();
if (!itemName) {
input.value = '';
return;
}
try {
const response = await apiClient.post(API_ENDPOINTS.LISTS.ITEMS(String(list.id)), {
name: itemName,
is_complete: false,
quantity: null,
price: null
});
list.items.push(response.data as Item);
input.value = '';
} catch (err) {
console.error('Failed to add new item:', err);
}
const onListCreated = () => {
fetchLists(); // Refresh lists after one is created
};
const navigateToList = (listId: number) => {
router.push({ name: 'ListDetail', params: { id: listId } });
router.push(`/lists/${listId}`);
};
onMounted(() => {
// Load cached data immediately
loadCachedData();
// Then fetch fresh data in background
fetchListsAndGroups();
});
// Watch for changes in groupId
// Watch for changes in groupId (e.g., if used as a component and prop changes)
watch(currentGroupId, () => {
loadCachedData();
fetchListsAndGroups();
});
@ -298,173 +239,75 @@ watch(currentGroupId, () => {
<style scoped>
.page-padding {
padding: 1rem;
max-width: 1200px;
margin: 0 auto;
}
.mb-3 {
margin-bottom: 1.5rem;
}
/* Masonry grid for cards */
.neo-lists-grid {
columns: 3 500px;
column-gap: 2rem;
margin-bottom: 2rem;
}
/* Card styles */
.neo-list-card,
.neo-create-list-card {
break-inside: avoid;
border-radius: 18px;
box-shadow: 6px 6px 0 #111;
width: 100%;
margin: 0 0 2rem 0;
background: #fff;
display: flex;
flex-direction: column;
/* padding: 2rem 2rem 1.5rem 2rem;
padding: 2rem 2rem 1.5rem 2rem; */
/* padding-inline: ; */
cursor: pointer;
/* transition: transform 0.1s ease-in-out, box-shadow 0.1s ease-in-out; */
border: 3px solid #111;
}
.neo-list-card:hover {
/* transform: translateY(-3px); */
box-shadow: 6px 9px 0 #111;
/* padding: 2rem 2rem 1.5rem 2rem; */
border: 3px solid #111;
}
.neo-list-header {
padding-block-start: 1rem;
font-weight: 900;
font-size: 1.25rem;
margin-bottom: 0.5rem;
letter-spacing: 0.5px;
text-transform: none;
}
.neo-list-desc {
font-size: 1rem;
color: #444;
margin-bottom: 1.2rem;
font-weight: 500;
}
.neo-item-list {
list-style: none;
padding: 0;
margin: 0;
}
.neo-list-item {
margin-bottom: 1.1rem;
font-size: 1.1rem;
font-weight: 700;
display: flex;
align-items: center;
}
.neo-checkbox-label {
display: flex;
align-items: center;
gap: 0.7em;
cursor: pointer;
}
.neo-checkbox-label input[type="checkbox"] {
width: 1.2em;
height: 1.2em;
accent-color: #111;
border: 2px solid #111;
border-radius: 4px;
margin-right: 0.5em;
}
.neo-completed {
text-decoration: line-through;
opacity: 0.5;
}
.neo-create-list-card {
border: 3px dashed #111;
background: #fafafa;
padding: 2.5rem 0;
text-align: center;
font-weight: 900;
font-size: 1.1rem;
color: #222;
cursor: pointer;
margin-top: 0;
transition: background 0.1s;
display: flex;
align-items: center;
justify-content: center;
min-height: 120px;
margin-bottom: 2.5rem;
}
.neo-create-list-card:hover {
background: #f0f0f0;
}
/* Responsive adjustments */
@media (max-width: 900px) {
.neo-lists-grid {
columns: 2 260px;
column-gap: 1.2rem;
}
.neo-list-card,
.neo-create-list-card {
margin-bottom: 1.2rem;
padding-left: 1rem;
padding-right: 1rem;
}
}
@media (max-width: 600px) {
.page-padding {
padding: 0.5rem;
}
.neo-lists-grid {
columns: 1 280px;
}
.neo-list-card,
.neo-create-list-card {
padding: 1.2rem 0.7rem 1rem 0.7rem;
font-size: 1rem;
}
.neo-list-header {
font-size: 1.1rem;
}
}
.neo-new-item-input {
.mt-1 {
margin-top: 0.5rem;
padding: 0.5rem;
}
.neo-new-item-input input[type="text"] {
background: transparent;
border: none;
outline: none;
all: unset;
.mt-2 {
margin-top: 1rem;
}
.interactive-list-item {
cursor: pointer;
transition: background-color var(--transition-speed) var(--transition-ease-out);
}
.interactive-list-item:hover,
.interactive-list-item:focus-visible {
background-color: rgba(0, 0, 0, 0.03);
outline: var(--focus-outline);
outline-offset: -3px;
}
.item-caption {
display: block;
font-size: 0.85rem;
opacity: 0.7;
margin-top: 0.25rem;
}
.icon-caption .icon {
vertical-align: -0.1em;
/* Align icon better with text */
}
.page-sticky-bottom-right {
position: fixed;
bottom: 5rem;
right: 1.5rem;
z-index: 999;
/* Below modals */
}
.page-sticky-bottom-right .btn {
box-shadow: var(--shadow-lg);
/* Make it pop more */
}
/* Ensure list item content uses full width for proper layout */
.list-item-content {
display: flex;
justify-content: space-between;
width: 100%;
font-size: 1.1rem;
font-weight: 700;
color: #444;
align-items: flex-start;
/* Align items to top if they wrap */
}
.neo-new-item-input input[type="text"]::placeholder {
color: #999;
font-weight: 500;
.list-item-main {
flex-grow: 1;
margin-right: 1rem;
/* Space before details */
}
.list-item-details {
flex-shrink: 0;
/* Prevent badges from shrinking */
text-align: right;
}
</style>

View File

@ -8,45 +8,27 @@ const routes: RouteRecordRaw[] = [
component: () => import('../layouts/MainLayout.vue'), // Use .. alias
children: [
{ path: '', redirect: '/lists' },
{
path: 'lists',
name: 'PersonalLists',
component: () => import('../pages/ListsPage.vue'),
meta: { keepAlive: true }
},
{ path: 'lists', name: 'PersonalLists', component: () => import('../pages/ListsPage.vue') },
{
path: 'lists/:id',
name: 'ListDetail',
component: () => import('../pages/ListDetailPage.vue'),
props: true,
meta: { keepAlive: true }
},
{
path: 'groups',
name: 'GroupsList',
component: () => import('../pages/GroupsPage.vue'),
meta: { keepAlive: true }
},
{ path: 'groups', name: 'GroupsList', component: () => import('../pages/GroupsPage.vue') },
{
path: 'groups/:id',
name: 'GroupDetail',
component: () => import('../pages/GroupDetailPage.vue'),
props: true,
meta: { keepAlive: true }
},
{
path: 'groups/:groupId/lists',
name: 'GroupLists',
component: () => import('../pages/ListsPage.vue'), // Reusing ListsPage
props: true,
meta: { keepAlive: true }
},
{
path: 'account',
name: 'Account',
component: () => import('../pages/AccountPage.vue'),
meta: { keepAlive: true }
},
{ path: 'account', name: 'Account', component: () => import('../pages/AccountPage.vue') },
],
},
{

View File

@ -7,7 +7,6 @@ import router from '@/router';
interface AuthState {
accessToken: string | null;
refreshToken: string | null;
user: {
email: string;
name: string;
@ -18,7 +17,6 @@ interface AuthState {
export const useAuthStore = defineStore('auth', () => {
// State
const accessToken = ref<string | null>(localStorage.getItem('token'));
const refreshToken = ref<string | null>(localStorage.getItem('refreshToken'));
const user = ref<AuthState['user']>(null);
// Getters
@ -26,21 +24,15 @@ export const useAuthStore = defineStore('auth', () => {
const getUser = computed(() => user.value);
// Actions
const setTokens = (tokens: { access_token: string; refresh_token?: string }) => {
const setTokens = (tokens: { access_token: string }) => {
accessToken.value = tokens.access_token;
localStorage.setItem('token', tokens.access_token);
if (tokens.refresh_token) {
refreshToken.value = tokens.refresh_token;
localStorage.setItem('refreshToken', tokens.refresh_token);
}
};
const clearTokens = () => {
accessToken.value = null;
refreshToken.value = null;
user.value = null;
localStorage.removeItem('token');
localStorage.removeItem('refreshToken');
};
const setUser = (userData: AuthState['user']) => {
@ -74,8 +66,8 @@ export const useAuthStore = defineStore('auth', () => {
},
});
const { access_token, refresh_token } = response.data;
setTokens({ access_token, refresh_token });
const { access_token } = response.data;
setTokens({ access_token });
await fetchCurrentUser();
return response.data;
};
@ -93,7 +85,6 @@ export const useAuthStore = defineStore('auth', () => {
return {
accessToken,
user,
refreshToken,
isAuthenticated,
getUser,
setTokens,