This file is a merged representation of the entire codebase, combined into a single document by Repomix.
This section contains a summary of this file.
This file contains a packed representation of the entire repository's contents.
It is designed to be easily consumable by AI systems for analysis, code review,
or other automated processes.
The content is organized as follows:
1. This summary section
2. Repository information
3. Directory structure
4. Repository files, each consisting of:
- File path as an attribute
- Full contents of the file
- This file should be treated as read-only. Any changes should be made to the
original repository files, not this packed version.
- When processing this file, use the file path to distinguish
between different files in the repository.
- Be aware that this file may contain sensitive information. Handle it with
the same level of security as you would the original repository.
- Some files may have been excluded based on .gitignore rules and Repomix's configuration
- Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files
- Files matching patterns in .gitignore are excluded
- Files matching default ignore patterns are excluded
- Files are sorted by Git change count (files with more changes are at the bottom)
.gitea/workflows/ci.yml
be/.dockerignore
be/.gitignore
be/alembic.ini
be/alembic/env.py
be/alembic/README
be/alembic/script.py.mako
be/alembic/versions/bc37e9c7ae19_fresh_start.py
be/app/api/api_router.py
be/app/api/dependencies.py
be/app/api/v1/api.py
be/app/api/v1/endpoints/auth.py
be/app/api/v1/endpoints/costs.py
be/app/api/v1/endpoints/financials.py
be/app/api/v1/endpoints/groups.py
be/app/api/v1/endpoints/health.py
be/app/api/v1/endpoints/invites.py
be/app/api/v1/endpoints/items.py
be/app/api/v1/endpoints/lists.py
be/app/api/v1/endpoints/ocr.py
be/app/api/v1/endpoints/users.py
be/app/api/v1/test_auth.py
be/app/api/v1/test_users.py
be/app/config.py
be/app/core/api_config.py
be/app/core/exceptions.py
be/app/core/gemini.py
be/app/core/security.py
be/app/crud/expense.py
be/app/crud/group.py
be/app/crud/invite.py
be/app/crud/item.py
be/app/crud/list.py
be/app/crud/settlement.py
be/app/crud/user.py
be/app/database.py
be/app/main.py
be/app/models.py
be/app/schemas/auth.py
be/app/schemas/cost.py
be/app/schemas/expense.py
be/app/schemas/group.py
be/app/schemas/health.py
be/app/schemas/invite.py
be/app/schemas/item.py
be/app/schemas/list.py
be/app/schemas/message.py
be/app/schemas/ocr.py
be/app/schemas/user.py
be/Dockerfile
be/requirements.txt
be/tests/api/v1/endpoints/test_financials.py
be/tests/core/__init__.py
be/tests/core/test_exceptions.py
be/tests/core/test_gemini.py
be/tests/core/test_security.py
be/tests/crud/__init__.py
be/tests/crud/test_expense.py
be/tests/crud/test_group.py
be/tests/crud/test_invite.py
be/tests/crud/test_item.py
be/tests/crud/test_list.py
be/tests/crud/test_settlement.py
be/tests/crud/test_user.py
be/Untitled-1.md
docker-compose.yml
fe/.editorconfig
fe/.gitignore
fe/.npmrc
fe/.prettierrc.json
fe/.vscode/extensions.json
fe/.vscode/settings.json
fe/eslint.config.js
fe/index.html
fe/package.json
fe/postcss.config.js
fe/public/icons/safari-pinned-tab.svg
fe/quasar.config.ts
fe/README.md
fe/src-pwa/custom-service-worker.ts
fe/src-pwa/manifest.json
fe/src-pwa/pwa-env.d.ts
fe/src-pwa/register-service-worker.ts
fe/src-pwa/tsconfig.json
fe/src/App.vue
fe/src/assets/quasar-logo-vertical.svg
fe/src/boot/axios.ts
fe/src/boot/i18n.ts
fe/src/components/ConflictResolutionDialog.vue
fe/src/components/CreateListModal.vue
fe/src/components/EssentialLink.vue
fe/src/components/models.ts
fe/src/components/OfflineIndicator.vue
fe/src/config/api-config.ts
fe/src/config/api.ts
fe/src/css/app.scss
fe/src/css/quasar.variables.scss
fe/src/env.d.ts
fe/src/i18n/en-US/index.ts
fe/src/i18n/index.ts
fe/src/layouts/AuthLayout.vue
fe/src/layouts/MainLayout.vue
fe/src/pages/AccountPage.vue
fe/src/pages/ErrorNotFound.vue
fe/src/pages/GroupDetailPage.vue
fe/src/pages/GroupsPage.vue
fe/src/pages/IndexPage.vue
fe/src/pages/ListDetailPage.vue
fe/src/pages/ListsPage.vue
fe/src/pages/LoginPage.vue
fe/src/pages/SignupPage.vue
fe/src/router/index.ts
fe/src/router/routes.ts
fe/src/stores/auth.ts
fe/src/stores/index.ts
fe/src/stores/offline.ts
fe/tsconfig.json
This section contains the contents of the repository's files.
# When you push to the develop branch or open/update a pull request targeting main, Gitea will:
# Trigger the "CI Checks" workflow.
# Execute the checks job on a runner.
# Run each step sequentially.
# If any of the linter/formatter check commands (black --check, ruff check, npm run lint) exit with a non-zero status code (indicating an error or check failure), the step and the entire job will fail.
# You will see the status (success/failure) associated with your commit or pull request in the Gitea interface.
name: CI Checks
# Define triggers for the workflow
on:
push:
branches:
- develop # Run on pushes to the develop branch
pull_request:
branches:
- main # Run on pull requests targeting the main branch
jobs:
checks:
name: Linters and Formatters
runs-on: ubuntu-latest # Use a standard Linux runner environment
steps:
- name: Checkout code
uses: actions/checkout@v4 # Fetches the repository code
# --- Backend Checks (Python/FastAPI) ---
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11' # Match your project/Dockerfile version
cache: 'pip' # Cache pip dependencies based on requirements.txt
cache-dependency-path: 'be/requirements.txt' # Specify path for caching
- name: Install Backend Dependencies and Tools
working-directory: ./be # Run command within the 'be' directory
run: |
pip install --upgrade pip
pip install -r requirements.txt
pip install black ruff # Install formatters/linters for CI check
- name: Run Black Formatter Check (Backend)
working-directory: ./be
run: black --check --diff .
- name: Run Ruff Linter (Backend)
working-directory: ./be
run: ruff check .
# --- Frontend Checks (Quasar/Vue, Node.js) ---
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20.x' # Or specify your required Node.js version (e.g., 'lts/*')
cache: 'npm' # Or 'pnpm' / 'yarn' depending on your package manager
cache-dependency-path: 'fe/package-lock.json' # Adjust lockfile name if needed
- name: Install Frontend Dependencies
working-directory: ./fe # Run command within the 'fe' directory
run: npm install # Or 'pnpm install' / 'yarn install'
- name: Run ESLint and Prettier Check (Frontend)
working-directory: ./fe
# Assuming you have a 'lint' script in fe/package.json that runs both
# Example package.json script: "lint": "prettier --check . && eslint ."
run: npm run lint
# If no combined script, run separately:
# run: |
# npm run format -- --check # Or 'npx prettier --check .'
# npm run lint # Or 'npx eslint .'
# - name: Run Frontend Type Check (Optional but recommended)
# working-directory: ./fe
# # Assuming you have a 'check' script (e.g. "check": "vue-tsc --noEmit" for this Quasar/Vue app)
# run: npm run check
# - name: Run Placeholder Tests (Optional)
# run: |
# # Add commands to run backend tests if available
# # Add commands to run frontend tests (e.g., npm test in ./fe) if available
# echo "No tests configured yet."
# Git files
.git
.gitignore
# Virtual environment
.venv
venv/
env/
ENV/
# Ignore local .env files within the backend directory if any
# (a '#' only starts a comment at line start in ignore files, so the
# comment must not share a line with the pattern)
*.env
# Python cache
__pycache__/
*.py[cod]
*$py.class
# IDE files
.idea/
.vscode/
# Test artifacts
.pytest_cache/
htmlcov/
.coverage*
# Other build/temp files
*.egg-info/
dist/
build/
# e.g., sqlite temp dbs (comment moved to its own line: inline '#' is
# treated as part of the pattern in ignore files)
*.db
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# PEP 582; used by PDM, Flit and potentially other tools.
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv/
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static analysis results
.pytype/
# alembic default temp file
# If using sqlite for alembic versions locally for instance
# (inline '#' after a pattern is literal in .gitignore, so the comment
# must be on its own line)
*.db
# If you use alembic autogenerate, it might create temporary files
# Depending on your DB, adjust if necessary
# *.sql.tmp
# IDE files
.idea/
.vscode/
# OS generated files
.DS_Store
Thumbs.db
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
; sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
from logging.config import fileConfig
import os
import sys
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# Ensure the 'app' directory is in the Python path
# Adjust the path if your project structure is different
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
# Import your app's Base and settings
from app.models import Base # Import Base from your models module
from app.config import settings # Import settings to get DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Set the sqlalchemy.url from your application settings
# Use a synchronous version of the URL for Alembic's operations
sync_db_url = settings.DATABASE_URL.replace("+asyncpg", "") if settings.DATABASE_URL else None
if not sync_db_url:
raise ValueError("DATABASE_URL not found in settings for Alembic.")
config.set_main_option('sqlalchemy.url', sync_db_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
Generic single-database configuration.
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}
"""fresh start
Revision ID: bc37e9c7ae19
Revises:
Create Date: 2025-05-08 16:06:51.208542
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'bc37e9c7ae19'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False)
op.create_table('groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False)
op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False)
op.create_table('invites',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('code', sa.String(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true'))
op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False)
op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False)
op.create_table('lists',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('is_complete', sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False)
op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False)
op.create_table('settlements',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
sa.Column('paid_to_user_id', sa.Integer(), nullable=False),
sa.Column('amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('settlement_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['paid_to_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_settlements_id'), 'settlements', ['id'], unique=False)
op.create_table('user_groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('role', sa.Enum('owner', 'member', name='userroleenum'), nullable=False),
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group')
)
op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False)
op.create_table('items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('list_id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('quantity', sa.String(), nullable=True),
sa.Column('is_complete', sa.Boolean(), nullable=False),
sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True),
sa.Column('added_by_id', sa.Integer(), nullable=False),
sa.Column('completed_by_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False)
op.create_table('expenses',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('currency', sa.String(), nullable=False),
sa.Column('expense_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('split_type', sa.Enum('EQUAL', 'EXACT_AMOUNTS', 'PERCENTAGE', 'SHARES', 'ITEM_BASED', name='splittypeenum'), nullable=False),
sa.Column('list_id', sa.Integer(), nullable=True),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('item_id', sa.Integer(), nullable=True),
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.ForeignKeyConstraint(['item_id'], ['items.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ),
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_expenses_id'), 'expenses', ['id'], unique=False)
op.create_table('expense_splits',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('expense_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('owed_amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('share_percentage', sa.Numeric(precision=5, scale=2), nullable=True),
sa.Column('share_units', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['expense_id'], ['expenses.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split')
)
op.create_index(op.f('ix_expense_splits_id'), 'expense_splits', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_expense_splits_id'), table_name='expense_splits')
op.drop_table('expense_splits')
op.drop_index(op.f('ix_expenses_id'), table_name='expenses')
op.drop_table('expenses')
op.drop_index(op.f('ix_items_name'), table_name='items')
op.drop_index(op.f('ix_items_id'), table_name='items')
op.drop_table('items')
op.drop_index(op.f('ix_user_groups_id'), table_name='user_groups')
op.drop_table('user_groups')
op.drop_index(op.f('ix_settlements_id'), table_name='settlements')
op.drop_table('settlements')
op.drop_index(op.f('ix_lists_name'), table_name='lists')
op.drop_index(op.f('ix_lists_id'), table_name='lists')
op.drop_table('lists')
op.drop_index(op.f('ix_invites_id'), table_name='invites')
op.drop_index(op.f('ix_invites_code'), table_name='invites')
op.drop_index('ix_invites_active_code', table_name='invites', postgresql_where=sa.text('is_active = true'))
op.drop_table('invites')
op.drop_index(op.f('ix_groups_name'), table_name='groups')
op.drop_index(op.f('ix_groups_id'), table_name='groups')
op.drop_table('groups')
op.drop_index(op.f('ix_users_name'), table_name='users')
op.drop_index(op.f('ix_users_id'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###
# app/api/api_router.py
from fastapi import APIRouter
from app.api.v1.api import api_router_v1 # Import the v1 router
api_router = APIRouter()
# Include versioned routers here, adding the /api prefix
api_router.include_router(api_router_v1, prefix="/v1") # Mounts v1 endpoints under /api/v1/...
# Add other API versions later
# e.g., api_router.include_router(api_router_v2, prefix="/v2")
# app/api/v1/endpoints/financials.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import List as PyList, Optional, Sequence
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum
from app.schemas.expense import (
ExpenseCreate, ExpensePublic,
SettlementCreate, SettlementPublic,
ExpenseUpdate, SettlementUpdate
)
from app.crud import expense as crud_expense
from app.crud import settlement as crud_settlement
from app.crud import group as crud_group
from app.crud import list as crud_list
from app.core.exceptions import (
ListNotFoundError, GroupNotFoundError, UserNotFoundError,
InvalidOperationError, GroupPermissionError, ListPermissionError,
ItemNotFoundError, GroupMembershipError
)
logger = logging.getLogger(__name__)
router = APIRouter()
# --- Helper for permissions ---
async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"):
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_member=True)
except ListPermissionError as e:
logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}")
raise ListPermissionError(list_id, action=action)
except ListNotFoundError:
raise
# --- Expense Endpoints ---
@router.post(
"/expenses",
response_model=ExpensePublic,
status_code=status.HTTP_201_CREATED,
summary="Create New Expense",
tags=["Expenses"]
)
async def create_new_expense(
expense_in: ExpenseCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
effective_group_id = expense_in.group_id
is_group_context = False
if expense_in.list_id:
# Check basic access to list (implies membership if list is in group)
await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for")
list_obj = await db.get(ListModel, expense_in.list_id)
if not list_obj:
raise ListNotFoundError(expense_in.list_id)
if list_obj.group_id:
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.")
effective_group_id = list_obj.group_id
is_group_context = True # Expense is tied to a group via the list
elif expense_in.group_id:
raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.")
# If list is personal, no group check needed yet, handled by payer check below.
elif effective_group_id: # Only group_id provided for expense
is_group_context = True
# Ensure user is at least a member to create expense in group context
await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for")
else:
# This case should ideally be caught by earlier checks if list_id was present but list was personal.
# If somehow reached, it means no list_id and no group_id.
raise InvalidOperationError("Expense must be linked to a list_id or group_id.")
# Finalize expense payload with correct group_id if derived
expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id})
# --- Granular Permission Check for Payer ---
if expense_in_final.paid_by_user_id != current_user.id:
logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}")
# If creating expense paid by someone else, user MUST be owner IF in group context
if is_group_context and effective_group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user")
except GroupPermissionError as e:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}")
else:
# Cannot create expense paid by someone else for a personal list (no group context)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.")
# If paying for self, basic list/group membership check above is sufficient.
try:
created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
return created_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except NotImplementedError as e:
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
expense_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense:
raise ItemNotFoundError(item_id=expense_id)
if expense.list_id:
await check_list_access_for_financials(db, expense.list_id, current_user.id)
elif expense.group_id:
await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
elif expense.paid_by_user_id != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
return expense
@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a List", tags=["Expenses", "Lists"])
async def list_list_expenses(
list_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
await check_list_access_for_financials(db, list_id, current_user.id)
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
return expenses
@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
async def list_group_expenses(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
return expenses
@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
async def update_expense_details(
expense_id: int,
expense_in: ExpenseUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Updates an existing expense (description, currency, expense_date only).
Requires the current version number for optimistic locking.
User must have permission to modify the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
raise ItemNotFoundError(item_id=expense_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_modify = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
can_modify = True
logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError: # User not even a member
pass # Keep can_modify as False
except GroupPermissionError: # User is member but not owner
pass # Keep can_modify as False
except GroupNotFoundError: # Group doesn't exist (data integrity issue)
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
pass # Keep can_modify as False
# Note: If expense is only linked to a personal list (no group), only payer can modify.
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")
try:
updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
return updated_expense
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
async def delete_expense_record(
expense_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Deletes an expense and its associated splits.
Requires expected_version query parameter for optimistic locking.
User must have permission to delete the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
# Return 204 even if not found, as the end state is achieved (item is gone)
logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404
# --- Granular Permission Check ---
can_delete = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_delete = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
can_delete = True
logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")
try:
await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
logger.info(f"Expense ID {expense_id} deleted successfully.")
# No need to return content on 204
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# --- Settlement Endpoints ---
@router.post(
    "/settlements",
    response_model=SettlementPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Record New Settlement",
    tags=["Settlements"]
)
async def create_new_settlement(
    settlement_in: SettlementCreate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Records a settlement (a payment between two users) within a group.

    The caller, the payer and the payee must all be members of
    settlement_in.group_id. Payer/payee membership failures map to 400 and a
    missing group to 404; the caller's own membership failure propagates
    uncaught (presumably handled by a global exception handler — confirm).
    """
    logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
    # Caller must belong to the group before the parties are validated.
    await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
    try:
        # Both parties to the settlement must also be members of the group.
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
    except GroupMembershipError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
    except GroupNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    try:
        created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
        logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
        return created_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
settlement_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement:
raise ItemNotFoundError(item_id=settlement_id)
is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
try:
await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
except GroupMembershipError:
if not is_party_to_settlement:
raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
return settlement
@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
async def list_group_settlements(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
return settlements
@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
async def update_settlement_details(
settlement_id: int,
settlement_in: SettlementUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Updates an existing settlement (description, settlement_date only).
Requires the current version number for optimistic locking.
User must have permission (e.g., be involved party or group admin).
"""
logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
raise ItemNotFoundError(item_id=settlement_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_modify = True
# 2. OR User is owner of the group the settlement belongs to
# Note: Settlements always have a group_id based on current model
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
can_modify = True
logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")
pass
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")
try:
updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
return updated_settlement
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
async def delete_settlement_record(
settlement_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Deletes a settlement.
Requires expected_version query parameter for optimistic locking.
User must have permission (e.g., be involved party or group admin).
"""
logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# --- Granular Permission Check ---
can_delete = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_delete = True
# 2. OR User is owner of the group the settlement belongs to
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
can_delete = True
logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")
try:
await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
logger.info(f"Settlement ID {settlement_id} deleted successfully.")
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# TODO (remaining from original list):
# (None - GET/POST/PUT/DELETE implemented for Expense/Settlement)
# app/api/v1/endpoints/users.py
import logging
from fastapi import APIRouter, Depends, HTTPException
from app.api.dependencies import get_current_user # Import the dependency
from app.schemas.user import UserPublic # Import the response schema
from app.models import User as UserModel # Import the DB model for type hinting
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
    "/me",
    response_model=UserPublic, # Use the public schema to avoid exposing hash
    summary="Get Current User",
    description="Retrieves the details of the currently authenticated user.",
    tags=["Users"]
)
async def read_users_me(
    current_user: UserModel = Depends(get_current_user) # Apply the dependency
):
    """Return the profile of the user identified by the presented access token."""
    logger.info(f"Fetching details for current user: {current_user.email}")
    # current_user is the ORM instance resolved by the dependency; FastAPI
    # serializes it through UserPublic, so the password hash never leaves the server.
    return current_user
# Add other user-related endpoints here later (e.g., update user, list users (admin))
# Example: be/tests/api/v1/test_auth.py
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import verify_password
from app.crud.user import get_user_by_email
from app.schemas.user import UserPublic # Import for response validation
from app.schemas.auth import Token # Import for response validation
pytestmark = pytest.mark.asyncio
async def test_signup_success(client: AsyncClient, db: AsyncSession):
    """A fresh signup returns 201 with public fields and persists a hashed password."""
    email = "testsignup@example.com"
    password = "testpassword123"
    response = await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": password, "name": "Test Signup"},
    )
    assert response.status_code == 201
    body = response.json()
    assert body["email"] == email
    assert body["name"] == "Test Signup"
    for required_field in ("id", "created_at"):
        assert required_field in body
    # Secrets must never appear in the API response.
    for secret_field in ("password", "hashed_password"):
        assert secret_field not in body
    # The row must exist and store a verifiable hash, not the plaintext.
    persisted = await get_user_by_email(db, email=email)
    assert persisted is not None
    assert persisted.email == email
    assert verify_password(password, persisted.hashed_password)
async def test_signup_email_exists(client: AsyncClient, db: AsyncSession):
    """Signing up twice with the same email is rejected with 400."""
    email = "testexists@example.com"
    # Seed the account that will collide.
    await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": "testpassword123"},
    )
    duplicate = await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": "anotherpassword"},
    )
    assert duplicate.status_code == 400
    assert "Email already registered" in duplicate.json()["detail"]
async def test_login_success(client: AsyncClient, db: AsyncSession):
    """A signed-up user can log in and receives a bearer access token."""
    email = "testlogin@example.com"
    password = "testpassword123"
    signup_response = await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": password}
    )
    assert signup_response.status_code == 201
    # OAuth2 login expects form encoding, hence data= rather than json=.
    response = await client.post(
        "/api/v1/auth/login", data={"username": email, "password": password}
    )
    assert response.status_code == 200
    token_body = response.json()
    assert "access_token" in token_body
    assert token_body["token_type"] == "bearer"
async def test_login_wrong_password(client: AsyncClient, db: AsyncSession):
    """A bad password yields 401 with a Bearer challenge header."""
    email = "testloginwrong@example.com"
    await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": "testpassword123"}
    )
    response = await client.post(
        "/api/v1/auth/login",
        data={"username": email, "password": "wrongpassword"},
    )
    assert response.status_code == 401
    assert "Incorrect email or password" in response.json()["detail"]
    # RFC 6750: a 401 must carry a WWW-Authenticate challenge.
    assert response.headers.get("WWW-Authenticate") == "Bearer"
async def test_login_user_not_found(client: AsyncClient, db: AsyncSession):
    """Login for an unknown email fails with the same 401 message as a bad password."""
    response = await client.post(
        "/api/v1/auth/login",
        data={"username": "nosuchuser@example.com", "password": "anypassword"},
    )
    assert response.status_code == 401
    assert "Incorrect email or password" in response.json()["detail"]
# Example: be/tests/api/v1/test_users.py
import pytest
from httpx import AsyncClient
from app.schemas.user import UserPublic # For response validation
from app.core.security import create_access_token
pytestmark = pytest.mark.asyncio
# Helper function to get a valid token
async def get_auth_headers(client: AsyncClient, email: str, password: str) -> dict:
    """Log the user in and return an Authorization header dict for later requests."""
    response = await client.post(
        "/api/v1/auth/login",
        data={"username": email, "password": password},
    )
    response.raise_for_status()  # Fail fast if the login itself didn't succeed
    access_token = response.json()["access_token"]
    return {"Authorization": f"Bearer {access_token}"}
async def test_read_users_me_success(client: AsyncClient):
    """/users/me returns the authenticated user's public profile only."""
    email = "testme@example.com"
    password = "password123"
    # Arrange: create the account and validate the signup payload shape.
    signup_res = await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": password, "name": "Test Me"}
    )
    assert signup_res.status_code == 201
    created = UserPublic(**signup_res.json())
    # Act: authenticate and call the endpoint.
    headers = await get_auth_headers(client, email, password)
    response = await client.get("/api/v1/users/me", headers=headers)
    # Assert: profile matches the signup and leaks no secrets.
    assert response.status_code == 200
    profile = response.json()
    assert profile["email"] == email
    assert profile["name"] == "Test Me"
    assert profile["id"] == created.id
    assert "password" not in profile
    assert "hashed_password" not in profile
async def test_read_users_me_no_token(client: AsyncClient):
    """Without an Authorization header the endpoint rejects the request."""
    response = await client.get("/api/v1/users/me")
    # OAuth2PasswordBearer produces this 401 before our dependency runs.
    assert response.status_code == 401
    assert response.json()["detail"] == "Not authenticated"
async def test_read_users_me_invalid_token(client: AsyncClient):
    """A syntactically invalid bearer token is rejected by our dependency."""
    response = await client.get(
        "/api/v1/users/me",
        headers={"Authorization": "Bearer invalid-token-string"},
    )
    assert response.status_code == 401
    assert response.json()["detail"] == "Could not validate credentials"
async def test_read_users_me_expired_token(client: AsyncClient):
    """A token whose expiry lies in the past must be rejected with 401.

    Bug fix: `timedelta` was used without ever being imported in this module,
    so the test crashed with NameError instead of exercising expiry handling.
    """
    from datetime import timedelta  # local import: only this test needs it

    email = "testexpired@example.com"
    # create_access_token accepts an expires_delta override; a negative delta
    # mints an already-expired token.
    expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
    headers = {"Authorization": f"Bearer {expired_token}"}
    response = await client.get("/api/v1/users/me", headers=headers)
    assert response.status_code == 401
    assert response.json()["detail"] == "Could not validate credentials"
# Add test case for valid token but user deleted from DB if needed
from typing import Dict, Any
from app.config import settings
# API Version
API_VERSION = "v1"
# API Prefix
API_PREFIX = f"/api/{API_VERSION}"
# API Endpoints
class APIEndpoints:
    """Centralized catalogue of relative API route templates.

    Each attribute maps a logical operation name to a path relative to
    API_PREFIX; segments in braces (e.g. {id}) are placeholders to be filled
    with keyword arguments before use.
    """
    # Auth
    AUTH = {
        "LOGIN": "/auth/login",
        "SIGNUP": "/auth/signup",
        "REFRESH_TOKEN": "/auth/refresh-token",
    }
    # Users
    USERS = {
        "PROFILE": "/users/profile",
        "UPDATE_PROFILE": "/users/profile",
    }
    # Lists
    LISTS = {
        "BASE": "/lists",
        "BY_ID": "/lists/{id}",
        "ITEMS": "/lists/{list_id}/items",
        "ITEM": "/lists/{list_id}/items/{item_id}",
    }
    # Groups
    GROUPS = {
        "BASE": "/groups",
        "BY_ID": "/groups/{id}",
        "LISTS": "/groups/{group_id}/lists",
        "MEMBERS": "/groups/{group_id}/members",
    }
    # Invites
    INVITES = {
        "BASE": "/invites",
        "BY_ID": "/invites/{id}",
        "ACCEPT": "/invites/{id}/accept",
        "DECLINE": "/invites/{id}/decline",
    }
    # OCR
    OCR = {
        "PROCESS": "/ocr/process",
    }
    # Financials
    FINANCIALS = {
        "EXPENSES": "/financials/expenses",
        "EXPENSE": "/financials/expenses/{id}",
        "SETTLEMENTS": "/financials/settlements",
        "SETTLEMENT": "/financials/settlements/{id}",
    }
    # Health
    HEALTH = {
        "CHECK": "/health",
    }
# API Metadata
# Keyword arguments intended for the FastAPI application constructor,
# all sourced from the central settings object.
API_METADATA = {
    "title": settings.API_TITLE,
    "description": settings.API_DESCRIPTION,
    "version": settings.API_VERSION,
    "openapi_url": settings.API_OPENAPI_URL,
    "docs_url": settings.API_DOCS_URL,
    "redoc_url": settings.API_REDOC_URL,
}
# API Tags
# OpenAPI tag metadata: controls grouping and descriptions in the docs UI.
API_TAGS = [
    {"name": "Authentication", "description": "Authentication and authorization endpoints"},
    {"name": "Users", "description": "User management endpoints"},
    {"name": "Lists", "description": "Shopping list management endpoints"},
    {"name": "Groups", "description": "Group management endpoints"},
    {"name": "Invites", "description": "Group invitation management endpoints"},
    {"name": "OCR", "description": "Optical Character Recognition endpoints"},
    {"name": "Financials", "description": "Financial management endpoints"},
    {"name": "Health", "description": "Health check endpoints"},
]
# Helper function to get full API URL
def get_api_url(endpoint: str, **kwargs) -> str:
    """Build a full API path.

    Fills the brace placeholders in *endpoint* with **kwargs and prefixes
    the versioned API root.

    Args:
        endpoint: Route template, e.g. "/lists/{id}".
        **kwargs: Values for the template's placeholders.

    Returns:
        str: The prefixed, fully-formatted path.
    """
    filled = endpoint.format(**kwargs)
    return f"{API_PREFIX}{filled}"
# app/crud/expense.py
import logging
from collections import defaultdict  # bug fix: defaultdict lives in collections, not typing
from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, Dict, List as PyList, Optional, Sequence

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate
from app.core.exceptions import (
    # Using existing specific exceptions where possible
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError
)

logger = logging.getLogger(__name__)
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Determines the list of users an expense should be split amongst.

    Priority: group members (if expense_group_id), then the list's group
    members or the list creator (if expense_list_id). Falls back to only the
    payer if no other context yields users.

    Args:
        db: Database session.
        expense_group_id: Group the expense belongs to, if any.
        expense_list_id: List the expense belongs to, if any (only consulted
            when no group id is given).
        expense_paid_by_user_id: Payer, used as the last-resort participant.

    Returns:
        De-duplicated list of users to split the expense across.

    Raises:
        GroupNotFoundError: expense_group_id does not exist.
        ListNotFoundError: expense_list_id does not exist.
        UserNotFoundError: the payer fallback user does not exist.
        InvalidOperationError: no users could be determined at all.
    """
    users_to_split_with: PyList[UserModel] = []
    processed_user_ids = set()

    # Fix: was a needless `async def` (it never awaits); a plain closure avoids
    # creating and awaiting a coroutine per candidate user.
    def _add_user(user: Optional[UserModel]) -> None:
        if user and user.id not in processed_user_ids:
            users_to_split_with.append(user)
            processed_user_ids.add(user.id)

    if expense_group_id:
        group_result = await db.execute(
            select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = group_result.scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for assoc in group.member_associations:
            _add_user(assoc.user)
    elif expense_list_id:  # Only if group_id was not primary context
        list_result = await db.execute(
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
                selectinload(ListModel.creator)
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = list_result.scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)
        if db_list.group:
            for assoc in db_list.group.member_associations:
                _add_user(assoc.user)
        elif db_list.creator:
            _add_user(db_list.creator)

    if not users_to_split_with:
        # Last resort: split with the payer alone.
        payer_user = await db.get(UserModel, expense_paid_by_user_id)
        if not payer_user:
            # This should have been caught earlier if paid_by_user_id was validated before calling this helper
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        _add_user(payer_user)
    if not users_to_split_with:
        # Unreachable in practice: the payer fallback always adds one user.
        raise InvalidOperationError("Could not determine any users for splitting the expense.")
    return users_to_split_with
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Creates a new expense and its associated splits.

    Args:
        db: Database session
        expense_in: Expense creation data
        current_user_id: ID of the user creating the expense

    Returns:
        The created expense with splits

    Raises:
        UserNotFoundError: If payer or split users don't exist
        ListNotFoundError: If specified list doesn't exist
        GroupNotFoundError: If specified group doesn't exist
        InvalidOperationError: For various validation failures
    """
    # Helper function to round decimals consistently
    def round_money(amount: Decimal) -> Decimal:
        # Two-decimal currency rounding with HALF_UP, matching total_amount below.
        return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    # 1. Context Validation
    # Validate basic context requirements first
    if not expense_in.list_id and not expense_in.group_id:
        raise InvalidOperationError("Expense must be associated with a list or a group.")
    # 2. User Validation
    payer = await db.get(UserModel, expense_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
    # 3. List/Group Context Resolution
    # May derive the group from the list; raises on mismatch or missing rows.
    final_group_id = await _resolve_expense_context(db, expense_in)
    # 4. Create the expense object
    db_expense = ExpenseModel(
        description=expense_in.description,
        total_amount=round_money(expense_in.total_amount),
        currency=expense_in.currency or "USD",  # default currency when omitted
        expense_date=expense_in.expense_date or datetime.now(timezone.utc),  # default: now, UTC
        split_type=expense_in.split_type,
        list_id=expense_in.list_id,
        group_id=final_group_id,
        item_id=expense_in.item_id,
        paid_by_user_id=expense_in.paid_by_user_id,
        created_by_user_id=current_user_id  # Track who created this expense
    )
    # 5. Generate splits based on split type
    splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
    # 6. Single transaction for expense and all splits
    try:
        db.add(db_expense)
        await db.flush() # Get expense ID without committing
        # Update all splits with the expense ID
        for split in splits_to_create:
            split.expense_id = db_expense.id
        db.add_all(splits_to_create)
        await db.commit()
    except Exception as e:
        # Any persistence failure rolls back both expense and splits atomically.
        await db.rollback()
        logger.error(f"Failed to save expense: {str(e)}", exc_info=True)
        raise InvalidOperationError(f"Failed to save expense: {str(e)}")
    # Refresh to get the splits relationship populated
    await db.refresh(db_expense, attribute_names=["splits"])
    return db_expense
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Validate the expense's list/group context and return the effective group id.

    A list-linked expense inherits its list's group (any explicitly supplied
    group must agree with it); a purely group-linked expense must reference an
    existing group.
    """
    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if list_obj is None:
            raise ListNotFoundError(expense_in.list_id)
        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            return list_obj.group_id  # The list's own group always wins.
        return expense_in.group_id
    if expense_in.group_id:
        group_obj = await db.get(GroupModel, expense_in.group_id)
        if group_obj is None:
            raise GroupNotFoundError(expense_in.group_id)
    return expense_in.group_id
async def _generate_expense_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Generates appropriate expense splits by dispatching on the split type.

    Raises:
        InvalidOperationError: for an unsupported split type, or when the
            chosen builder produces no splits at all.
    """
    # Map each supported split type to its builder coroutine.
    builders = {
        SplitTypeEnum.EQUAL: _create_equal_splits,
        SplitTypeEnum.EXACT_AMOUNTS: _create_exact_amount_splits,
        SplitTypeEnum.PERCENTAGE: _create_percentage_splits,
        SplitTypeEnum.SHARES: _create_shares_splits,
        SplitTypeEnum.ITEM_BASED: _create_item_based_splits,
    }
    builder = builders.get(expense_in.split_type)
    if builder is None:
        raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
    generated = await builder(db, db_expense, expense_in, round_money)
    if not generated:
        raise InvalidOperationError("No expense splits were generated.")
    return generated
async def _create_equal_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Splits the total evenly; any rounding remainder is folded into the first share.

    Raises:
        InvalidOperationError: if no eligible users can be determined.
    """
    participants = await get_users_for_splitting(
        db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
    )
    if not participants:
        raise InvalidOperationError("No users found for EQUAL split.")
    count = len(participants)
    base_share = round_money(db_expense.total_amount / Decimal(count))
    # Whatever rounding left over (positive or negative) goes to the first user.
    leftover = db_expense.total_amount - (base_share * count)
    result: PyList[ExpenseSplitModel] = []
    for idx, participant in enumerate(participants):
        amount = base_share
        if idx == 0 and leftover != Decimal('0'):
            amount = round_money(base_share + leftover)
        result.append(ExpenseSplitModel(
            user_id=participant.id,
            owed_amount=amount
        ))
    return result
async def _create_exact_amount_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Builds splits from caller-supplied exact amounts.

    Every amount must be positive and the (rounded) amounts must sum to the
    expense's total_amount.

    Raises:
        InvalidOperationError: on missing split data, a non-positive amount,
            or a sum mismatch.
        UserNotFoundError: if any referenced user does not exist.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")
    # All referenced users must exist before we build anything.
    await _validate_users_in_splits(db, expense_in.splits_in)
    running_total = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []
    for entry in expense_in.splits_in:
        if entry.owed_amount <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {entry.user_id} must be positive.")
        amount = round_money(entry.owed_amount)
        running_total += amount
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=amount
        ))
    if round_money(running_total) != db_expense.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({running_total}) != expense total ({db_expense.total_amount})."
        )
    return result
async def _create_percentage_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Builds splits from per-user percentages that must sum to exactly 100%.

    Any rounding drift between the summed per-user amounts and the expense
    total is absorbed by the last split.

    Raises:
        InvalidOperationError: on missing split data, a percentage outside
            (0, 100], or a percentage sum other than 100%.
        UserNotFoundError: if any referenced user does not exist.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")
    # All referenced users must exist before we build anything.
    await _validate_users_in_splits(db, expense_in.splits_in)
    pct_sum = Decimal("0.00")
    amount_sum = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []
    for entry in expense_in.splits_in:
        pct = entry.share_percentage
        if not (pct and Decimal("0") < pct <= Decimal("100")):
            raise InvalidOperationError(
                f"Invalid percentage {entry.share_percentage} for user {entry.user_id}."
            )
        pct_sum += pct
        share = round_money(db_expense.total_amount * (pct / Decimal("100")))
        amount_sum += share
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=share,
            share_percentage=pct
        ))
    if round_money(pct_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({pct_sum}%) is not 100%.")
    # Push any rounding difference into the final split.
    if amount_sum != db_expense.total_amount and result:
        drift = db_expense.total_amount - amount_sum
        result[-1].owed_amount = round_money(result[-1].owed_amount + drift)
    return result
async def _create_shares_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Builds splits proportional to each user's share units.

    Any rounding drift between the summed per-user amounts and the expense
    total is absorbed by the last split.

    Raises:
        InvalidOperationError: on missing split data, zero total shares, or a
            non-positive share-unit entry.
        UserNotFoundError: if any referenced user does not exist.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for SHARES split type.")
    # All referenced users must exist before we build anything.
    await _validate_users_in_splits(db, expense_in.splits_in)
    # Only positive unit counts contribute to the denominator.
    total_units = sum(
        entry.share_units for entry in expense_in.splits_in
        if entry.share_units and entry.share_units > 0
    )
    if total_units == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")
    result: PyList[ExpenseSplitModel] = []
    accumulated = Decimal("0.00")
    for entry in expense_in.splits_in:
        units = entry.share_units
        if not (units and units > 0):
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")
        ratio = Decimal(units) / Decimal(total_units)
        portion = round_money(db_expense.total_amount * ratio)
        accumulated += portion
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=portion,
            share_units=units
        ))
    # Push any rounding difference into the final split.
    if accumulated != db_expense.total_amount and result:
        drift = db_expense.total_amount - accumulated
        result[-1].owed_amount = round_money(result[-1].owed_amount + drift)
    return result
async def _create_item_based_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Creates splits based on items in a shopping list.

    Each user owes the summed prices of the items they added. When
    ``expense_in.item_id`` is set, only that single item is considered;
    otherwise every positively-priced item on the list participates. The
    expense's ``total_amount`` must equal the rounded sum of item prices.

    Raises:
        InvalidOperationError: if no list_id is provided, no usable items are
            found, an item lacks its adder, or totals do not match.
    """
    if not expense_in.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
    if expense_in.splits_in:
        # Splits are derived from items; caller-provided splits are ignored on purpose.
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
    # Build query to fetch relevant items
    items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
    if expense_in.item_id:
        items_query = items_query.where(ItemModel.id == expense_in.item_id)
    else:
        # Without a specific item, only items with a positive price participate.
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
    # Load items with their adders (needed to attribute each price to a user)
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()
    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
            if expense_in.item_id else
            f"List {expense_in.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)
    # Aggregate owed amounts by user
    calculated_total = Decimal("0.00")
    user_owed_amounts = defaultdict(Decimal)
    processed_items = 0
    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            # A specifically requested item must be priced; otherwise unpriced items are skipped.
            if expense_in.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue
        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1
    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
        )
    # Validate total matches calculated total
    if round_money(calculated_total) != db_expense.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({db_expense.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )
    # Create splits based on aggregated amounts
    splits = []
    for user_id, owed_amount in user_owed_amounts.items():
        splits.append(ExpenseSplitModel(
            user_id=user_id,
            owed_amount=round_money(owed_amount)
        ))
    return splits
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Ensures every user referenced in splits_in exists, else raises UserNotFoundError."""
    requested_ids = [entry.user_id for entry in splits_in]
    rows = await db.execute(select(UserModel.id).where(UserModel.id.in_(requested_ids)))
    existing_ids = {row[0] for row in rows}
    # Note: duplicate user_ids in the payload also fail this length comparison.
    if len(existing_ids) != len(requested_ids):
        missing = set(requested_ids) - existing_ids
        raise UserNotFoundError(identifier=f"users in split data: {list(missing)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetches one expense with splits (and split users), payer, list, group and item eagerly loaded."""
    query = (
        select(ExpenseModel)
        .where(ExpenseModel.id == expense_id)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.paid_by_user),
            selectinload(ExpenseModel.list),
            selectinload(ExpenseModel.group),
            selectinload(ExpenseModel.item)
        )
    )
    result = await db.execute(query)
    return result.scalars().first()
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Returns a page of a list's expenses, newest first, with splits and split users preloaded."""
    query = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
    )
    result = await db.execute(query)
    return result.scalars().all()
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Returns a page of a group's expenses, newest first, with splits and split users preloaded."""
    query = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
    )
    result = await db.execute(query)
    return result.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Updates an existing expense.

    Only allows updates to description, currency, and expense_date to avoid
    split complexities. Requires version matching for optimistic locking; on
    success the version is incremented and updated_at refreshed.

    Raises:
        InvalidOperationError: on a version mismatch, on an attempt to update
            a disallowed field, or if the commit fails.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
        )
    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data
    # Fields that are safe to update without affecting splits or core logic
    allowed_to_update = {"description", "currency", "expense_date"}
    updated_something = False
    for field, value in update_data.items():
        if field in allowed_to_update:
            setattr(expense_db, field, value)
            updated_something = True
        else:
            # Any other field in the payload is rejected outright rather than silently ignored.
            raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.")
    if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update):
        # No actual updatable fields were provided in the payload, even if others (like version) were.
        # This could be a non-issue, or an indication of a misuse of the endpoint.
        # For now, if only version was sent, we still increment if it matched.
        pass # Or raise InvalidOperationError("No updatable fields provided.")
    expense_db.version += 1
    expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
    try:
        await db.commit()
        await db.refresh(expense_db)
    except Exception as e:
        await db.rollback()
        # Consider specific DB error types if needed
        raise InvalidOperationError(f"Failed to update expense: {str(e)}")
    return expense_db
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Deletes an expense, enforcing optimistic-lock version matching when an
    expected_version is supplied.

    Associated ExpenseSplits are cascade deleted by the database foreign key constraint.

    Raises:
        InvalidOperationError: on a version mismatch or a failed commit.
    """
    version_mismatch = expected_version is not None and expense_db.version != expected_version
    if version_mismatch:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
        )
    await db.delete(expense_db)
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
# Note: The InvalidOperationError is a simple ValueError placeholder.
# For API endpoints, these should be translated to appropriate HTTPExceptions.
# Ensure app.core.exceptions has proper HTTP error classes if needed.
# app/crud/invite.py
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete # Import delete statement
from typing import Optional
from app.models import Invite as InviteModel
# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
    """Creates a new invite code for a group.

    Returns None if a collision-free code could not be generated within
    MAX_CODE_GENERATION_ATTEMPTS tries (extremely unlikely).
    """
    expiry = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
    chosen_code: Optional[str] = None
    for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
        candidate = secrets.token_urlsafe(16)
        # Only an *active* invite with the same code counts as a collision.
        clash = await db.execute(
            select(InviteModel.id).where(InviteModel.code == candidate, InviteModel.is_active == True).limit(1)
        )
        if clash.scalar_one_or_none() is None:
            chosen_code = candidate
            break
    if chosen_code is None:
        # Gave up after the maximum number of attempts.
        return None
    invite = InviteModel(
        code=chosen_code,
        group_id=group_id,
        created_by_id=creator_id,
        expires_at=expiry,
        is_active=True
    )
    db.add(invite)
    await db.commit()
    await db.refresh(invite)
    return invite
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Returns the invite matching `code` if it is still active and unexpired, else None."""
    cutoff = datetime.now(timezone.utc)
    query = select(InviteModel).where(
        InviteModel.code == code,
        InviteModel.is_active == True,
        InviteModel.expires_at > cutoff
    )
    result = await db.execute(query)
    return result.scalars().first()
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Marks an invite as inactive (used) and persists the change."""
    invite.is_active = False
    # Track the mutation on the session before committing.
    db.add(invite)
    await db.commit()
    await db.refresh(invite)
    return invite
# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...
# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from decimal import Decimal
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
from app.models import (
Settlement as SettlementModel,
User as UserModel,
Group as GroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Creates a new settlement record.

    Validates that the payer, payee and group exist and that the payer and
    payee differ, then persists the settlement with the amount rounded to
    cents (half-up).

    Raises:
        UserNotFoundError: if the payer or payee does not exist.
        GroupNotFoundError: if the group does not exist.
        InvalidOperationError: if payer == payee or the commit fails.
    """
    # BUGFIX: ROUND_HALF_UP was referenced below but never imported in this
    # module (only Decimal is), which raised NameError on every call.
    from decimal import ROUND_HALF_UP
    # Validate Payer, Payee, and Group exist
    payer = await db.get(UserModel, settlement_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
    payee = await db.get(UserModel, settlement_in.paid_to_user_id)
    if not payee:
        raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
    if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
        raise InvalidOperationError("Payer and Payee cannot be the same user.")
    group = await db.get(GroupModel, settlement_in.group_id)
    if not group:
        raise GroupNotFoundError(settlement_in.group_id)
    # Optional: Check if current_user_id is part of the group or is one of the parties involved.
    # This is more of an API-level permission check but could be added here if strict.
    db_settlement = SettlementModel(
        group_id=settlement_in.group_id,
        paid_by_user_id=settlement_in.paid_by_user_id,
        paid_to_user_id=settlement_in.paid_to_user_id,
        # Store amounts normalized to cents, rounding half-up.
        amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
        description=settlement_in.description
        # created_by_user_id = current_user_id # Optional: Who recorded this settlement
    )
    db.add(db_settlement)
    try:
        await db.commit()
        await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
    return db_settlement
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    """Fetches one settlement with payer, payee and group eagerly loaded."""
    query = (
        select(SettlementModel)
        .where(SettlementModel.id == settlement_id)
        .options(
            selectinload(SettlementModel.payer),
            selectinload(SettlementModel.payee),
            selectinload(SettlementModel.group)
        )
    )
    result = await db.execute(query)
    return result.scalars().first()
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    """Returns a page of a group's settlements, newest first, with payer and payee preloaded."""
    query = (
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
    )
    result = await db.execute(query)
    return result.scalars().all()
async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100
) -> Sequence[SettlementModel]:
    """Returns a page of settlements where the user is payer or payee, newest first.

    When group_id is given (truthy), results are restricted to that group.
    """
    involves_user = or_(
        SettlementModel.paid_by_user_id == user_id,
        SettlementModel.paid_to_user_id == user_id
    )
    query = (
        select(SettlementModel)
        .where(involves_user)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
    )
    if group_id:
        query = query.where(SettlementModel.group_id == group_id)
    result = await db.execute(query)
    return result.scalars().all()
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """
    Updates an existing settlement.

    Only allows updates to description and settlement_date. Requires version
    matching for optimistic locking; on success the version is incremented
    and updated_at refreshed.

    Raises:
        InvalidOperationError: on a version mismatch (or a payload without a
            version attribute), on an attempt to update a disallowed field,
            or if the commit fails.
    """
    # Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
    if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
        raise InvalidOperationError(
            f"Settlement (ID: {settlement_db.id}) has been modified. "
            f"Your version does not match current version {settlement_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )
    update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
    # Only these fields are safe to change without affecting balances.
    allowed_to_update = {"description", "settlement_date"}
    updated_something = False
    for field, value in update_data.items():
        if field in allowed_to_update:
            setattr(settlement_db, field, value)
            updated_something = True
        else:
            # Disallowed fields are rejected outright rather than silently ignored.
            raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
    if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
        pass # No actual updatable fields provided, but version matched.
    settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
    settlement_db.updated_at = datetime.now(timezone.utc)
    try:
        await db.commit()
        await db.refresh(settlement_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
    return settlement_db
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """
    Deletes a settlement, enforcing optimistic-lock version matching when an
    expected_version is supplied. Assumes SettlementModel has a version field.

    Raises:
        InvalidOperationError: on a version mismatch or a failed commit.
    """
    if expected_version is not None:
        version_ok = hasattr(settlement_db, 'version') and settlement_db.version == expected_version
        if not version_ok:
            raise InvalidOperationError(
                f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                f"Expected version {expected_version} does not match current version. Please refresh.",
            )
    await db.delete(settlement_db)
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete settlement: {str(e)}")
# TODO: Implement update_settlement (consider restrictions, versioning)
# TODO: Implement delete_settlement (consider implications on balances)
# app/database.py
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.config import settings
# Fail fast at import time if the database is not configured.
if not settings.DATABASE_URL:
    raise ValueError("DATABASE_URL is not configured in settings.")
# Create the SQLAlchemy async engine.
# pool_recycle=3600 helps prevent stale connections on some DBs.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=True,  # Log SQL queries (useful for debugging; consider disabling in production)
    future=True,  # Use SQLAlchemy 2.0 style features
    pool_recycle=3600  # Optional: recycle connections after 1 hour
)
# Create a configured "Session" class.
# expire_on_commit=False keeps ORM objects usable after commit (no implicit reload).
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)
# Base class for our ORM models
Base = declarative_base()
# Dependency to get DB session in path operations
async def get_db() -> AsyncSession: # type: ignore
    """
    FastAPI dependency that yields an AsyncSession per request.

    Rolls back the session if the request handler raises, and closes the
    session when the request finishes either way. Commits are left to the
    endpoint logic.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Optionally commit if your endpoints modify data directly
            # await session.commit() # Usually commit happens within endpoint logic
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close() # Not strictly necessary with async context manager, but explicit
# app/schemas/expense.py
from pydantic import BaseModel, ConfigDict, validator
from typing import List, Optional
from decimal import Decimal
from datetime import datetime
# Assuming SplitTypeEnum is accessible here, e.g., from app.models or app.core.enums
# For now, let's redefine it or import it if models.py is parsable by Pydantic directly
# If it's from app.models, you might need to make app.models.SplitTypeEnum Pydantic-compatible or map it.
# For simplicity during schema definition, I'll redefine a string enum here.
# In a real setup, ensure this aligns with the SQLAlchemy enum in models.py.
from app.models import SplitTypeEnum # Try importing directly
# --- ExpenseSplit Schemas ---
class ExpenseSplitBase(BaseModel):
    """Common split fields: who owes how much, and optionally how the share was derived."""
    user_id: int
    owed_amount: Decimal
    share_percentage: Optional[Decimal] = None  # populated for PERCENTAGE splits
    share_units: Optional[int] = None  # populated for SHARES splits
class ExpenseSplitCreate(ExpenseSplitBase):
    """Payload shape for a split supplied at expense-creation time."""
    pass # All fields from base are needed for creation
class ExpenseSplitPublic(ExpenseSplitBase):
    """API representation of a persisted split (adds DB identifiers and timestamps)."""
    id: int
    expense_id: int
    # user: Optional[UserPublic] # If we want to nest user details
    created_at: datetime
    updated_at: datetime
    model_config = ConfigDict(from_attributes=True)
# --- Expense Schemas ---
class ExpenseBase(BaseModel):
    """Common expense fields shared by create/read schemas."""
    description: str
    total_amount: Decimal
    currency: Optional[str] = "USD"
    expense_date: Optional[datetime] = None
    split_type: SplitTypeEnum
    list_id: Optional[int] = None
    group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa
    item_id: Optional[int] = None
    paid_by_user_id: int
class ExpenseCreate(ExpenseBase):
    """Payload for creating an expense.

    For EQUAL (and ITEM_BASED) splits the server generates the splits; for
    EXACT_AMOUNTS, PERCENTAGE and SHARES the caller should provide 'splits_in'.
    """
    splits_in: Optional[List[ExpenseSplitCreate]] = None
    @validator('total_amount')
    def total_amount_must_be_positive(cls, v):
        """Reject zero or negative totals."""
        if v <= Decimal('0'):
            raise ValueError('Total amount must be positive')
        return v
    # Cross-field check: an expense must be anchored to a list or a group.
    # (always=True so the check runs even when group_id is absent.)
    @validator('group_id', always=True)
    def check_list_or_group_id(cls, v, values):
        """Require at least one of list_id / group_id."""
        if values.get('list_id') is None and v is None:
            raise ValueError('Either list_id or group_id must be provided for an expense')
        return v
class ExpenseUpdate(BaseModel):
    """Partial-update payload for an expense; 'version' is mandatory for optimistic locking."""
    description: Optional[str] = None
    total_amount: Optional[Decimal] = None
    currency: Optional[str] = None
    expense_date: Optional[datetime] = None
    split_type: Optional[SplitTypeEnum] = None
    list_id: Optional[int] = None
    group_id: Optional[int] = None
    item_id: Optional[int] = None
    # paid_by_user_id is usually not updatable directly to maintain integrity.
    # Updating splits would be a more complex operation, potentially a separate endpoint or careful logic.
    # NOTE: the CRUD layer only accepts description/currency/expense_date; other fields are rejected there.
    version: int # For optimistic locking
class ExpensePublic(ExpenseBase):
    """API representation of a persisted expense, including its splits."""
    id: int
    created_at: datetime
    updated_at: datetime
    version: int
    splits: List[ExpenseSplitPublic] = []
    # paid_by_user: Optional[UserPublic] # If nesting user details
    # list: Optional[ListPublic] # If nesting list details
    # group: Optional[GroupPublic] # If nesting group details
    # item: Optional[ItemPublic] # If nesting item details
    model_config = ConfigDict(from_attributes=True)
# --- Settlement Schemas ---
class SettlementBase(BaseModel):
    """Common settlement fields: who paid whom, how much, in which group."""
    group_id: int
    paid_by_user_id: int
    paid_to_user_id: int
    amount: Decimal
    settlement_date: Optional[datetime] = None  # defaults to "now" server-side when omitted
    description: Optional[str] = None
class SettlementCreate(SettlementBase):
    """Payload for recording a settlement; validates amount sign and payer != payee."""
    @validator('amount')
    def amount_must_be_positive(cls, v):
        """Reject zero or negative settlement amounts."""
        if v <= Decimal('0'):
            raise ValueError('Settlement amount must be positive')
        return v
    @validator('paid_to_user_id')
    def payer_and_payee_must_be_different(cls, v, values):
        """Cross-field check: a user cannot settle with themselves."""
        if 'paid_by_user_id' in values and v == values['paid_by_user_id']:
            raise ValueError('Payer and payee cannot be the same user')
        return v
class SettlementUpdate(BaseModel):
    """Partial-update payload for a settlement; 'version' is mandatory for optimistic locking."""
    amount: Optional[Decimal] = None
    settlement_date: Optional[datetime] = None
    description: Optional[str] = None
    # group_id, paid_by_user_id, paid_to_user_id are typically not updatable.
    # NOTE: the CRUD layer only accepts description/settlement_date; other fields are rejected there.
    version: int # For optimistic locking
class SettlementPublic(SettlementBase):
    """API representation of a persisted settlement."""
    id: int
    created_at: datetime
    updated_at: datetime
    # payer: Optional[UserPublic]
    # payee: Optional[UserPublic]
    # group: Optional[GroupPublic]
    model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed
# from app.schemas.user import UserPublic
# from app.schemas.list import ListPublic
# from app.schemas.group import GroupPublic
# from app.schemas.item import ItemPublic
# app/schemas/group.py
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Optional, List
from .user import UserPublic # Import UserPublic to represent members
# Properties to receive via API on creation
class GroupCreate(BaseModel):
    """Payload for creating a group; only a name is required."""
    name: str
# Properties to return to client
class GroupPublic(BaseModel):
    """API representation of a group; members are populated only in detailed views."""
    id: int
    name: str
    created_by_id: int
    created_at: datetime
    members: Optional[List[UserPublic]] = None # Include members only in detailed view
    model_config = ConfigDict(from_attributes=True)
# Properties stored in DB (if needed, often GroupPublic is sufficient)
# class GroupInDB(GroupPublic):
# pass
# app/schemas/invite.py
from pydantic import BaseModel
from datetime import datetime
# Properties to receive when accepting an invite
class InviteAccept(BaseModel):
    """Payload for accepting a group invite by code."""
    code: str
# Properties to return when an invite is created
class InviteCodePublic(BaseModel):
    """Response shape returned when an invite is created."""
    code: str
    expires_at: datetime
    group_id: int
# Properties for internal use/DB (optional)
# class Invite(InviteCodePublic):
# id: int
# created_by_id: int
# is_active: bool = True
# model_config = ConfigDict(from_attributes=True)
# app/schemas/message.py
from pydantic import BaseModel
class Message(BaseModel):
    """Generic detail-message response body."""
    detail: str
# app/schemas/ocr.py
from pydantic import BaseModel
from typing import List
class OcrExtractResponse(BaseModel):
    """Response from the OCR endpoint: candidate item names extracted from an image."""
    extracted_items: List[str] # A list of potential item names
import pytest
from decimal import Decimal
from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, Callable, Dict

from app.models import User as UserModel, Group as GroupModel, List as ListModel, SplitTypeEnum
from app.schemas.expense import ExpenseCreate
from app.core.config import settings
# Helper to create a URL for an endpoint
API_V1_STR = settings.API_V1_STR
def expense_url(endpoint: str = "") -> str:
    """Builds the versioned API path for the financials/expenses endpoints."""
    return API_V1_STR + "/financials/expenses" + endpoint
def settlement_url(endpoint: str = "") -> str:
    """Builds the versioned API path for the financials/settlements endpoints."""
    return API_V1_STR + "/financials/settlements" + endpoint
@pytest.mark.asyncio
async def test_create_new_expense_success_list_context(
    client: AsyncClient,
    db_session: AsyncSession, # Assuming a fixture for db session
    normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth
    test_user: UserModel, # Assuming a fixture for a test user
    test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of
) -> None:
    """
    Test successful creation of a new expense linked to a list.

    BUGFIX: the ExpenseCreate schema declares `total_amount` (not `amount`)
    and requires `split_type`; the old payload failed pydantic validation
    before a request was ever sent, and asserted a nonexistent `amount` field
    on the response.
    """
    expense_data = ExpenseCreate(
        description="Test Expense for List",
        total_amount=Decimal("100.00"),
        currency="USD",
        split_type=SplitTypeEnum.EQUAL,
        paid_by_user_id=test_user.id,
        list_id=test_list_user_is_member.id,
        group_id=None, # group_id should be derived from list if list is in a group
    )
    response = await client.post(
        expense_url(),
        headers=normal_user_token_headers,
        # mode="json" makes Decimal/enum values JSON-serializable
        json=expense_data.model_dump(mode="json", exclude_unset=True)
    )
    assert response.status_code == status.HTTP_201_CREATED
    content = response.json()
    assert content["description"] == expense_data.description
    # Compare as Decimal: the API may serialize the amount as str or number.
    assert Decimal(str(content["total_amount"])) == expense_data.total_amount
    assert content["currency"] == expense_data.currency
    assert content["paid_by_user_id"] == test_user.id
    assert content["list_id"] == test_list_user_is_member.id
    # If test_list_user_is_member has a group_id, it should be set in the response
    if test_list_user_is_member.group_id:
        assert content["group_id"] == test_list_user_is_member.group_id
    else:
        assert content["group_id"] is None
    assert "id" in content
    assert "created_at" in content
    assert "updated_at" in content
    assert "version" in content
    assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_success_group_context(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of
) -> None:
"""
Test successful creation of a new expense linked directly to a group.
"""
expense_data = ExpenseCreate(
description="Test Expense for Group",
amount=50.00,
currency="EUR",
paid_by_user_id=test_user.id,
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_201_CREATED
content = response.json()
assert content["description"] == expense_data.description
assert content["paid_by_user_id"] == test_user.id
assert content["group_id"] == test_group_user_is_member.id
assert content["list_id"] is None
assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_fail_no_list_or_group(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
) -> None:
"""
Test expense creation fails if neither list_id nor group_id is provided.
"""
expense_data = ExpenseCreate(
description="Test Invalid Expense",
amount=10.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=None,
group_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
content = response.json()
assert "Expense must be linked to a list_id or group_id" in content["detail"]
@pytest.mark.asyncio
async def test_create_new_expense_fail_paid_by_other_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str], # User is member, not owner
test_user: UserModel, # This is the current_user (member)
test_group_user_is_member: GroupModel, # Group the current_user is a member of
another_user_in_group: UserModel, # Another user in the same group
# Ensure test_user is NOT an owner of test_group_user_is_member for this test
) -> None:
"""
Test creation fails if paid_by_user_id is another user, and current_user is not a group owner.
Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member.
"""
expense_data = ExpenseCreate(
description="Expense paid by other",
amount=75.00,
currency="GBP",
paid_by_user_id=another_user_in_group.id, # Paid by someone else
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers, # Current user is a member, not owner
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Only group owners can create expenses paid by others" in content["detail"]
# --- Add tests for other endpoints below ---
# GET /expenses/{expense_id}
@pytest.mark.asyncio
async def test_get_expense_success(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    created_expense: ExpensePublic,  # fixture supplying a pre-existing, accessible expense
) -> None:
    """Fetching an expense the user can see (payer, or via list/group membership) returns it."""
    resp = await client.get(
        expense_url(f"/{created_expense.id}"),
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["id"] == created_expense.id
    assert body["description"] == created_expense.description
    assert body["amount"] == created_expense.amount
    assert body["paid_by_user_id"] == created_expense.paid_by_user_id
    if created_expense.list_id:
        assert body["list_id"] == created_expense.list_id
    if created_expense.group_id:
        assert body["group_id"] == created_expense.group_id

# TODO: Add more tests for get_expense:
# - expense not found -> 404
# - user has no access (not payer, not in list, not in group if applicable) -> 403
# - expense in list, user has list access
# - expense in group, user has group access
# - expense personal (no list, no group), user is payer
# - expense personal (no list, no group), user is NOT payer -> 403
@pytest.mark.asyncio
async def test_get_expense_not_found(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
) -> None:
    """Requesting an expense id that does not exist yields a 404 with a 'not found' detail."""
    missing_id = 9999999
    resp = await client.get(
        expense_url(f"/{missing_id}"),
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "not found" in resp.json()["detail"].lower()

@pytest.mark.asyncio
async def test_get_expense_forbidden_personal_expense_other_user(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],  # belongs to test_user
    personal_expense_of_another_user: ExpensePublic,  # expense with no shared list/group
) -> None:
    """A personal expense of another user (no shared list/group) is not visible: 403."""
    resp = await client.get(
        expense_url(f"/{personal_expense_of_another_user.id}"),
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert "Not authorized to view this expense" in resp.json()["detail"]
# GET /lists/{list_id}/expenses
@pytest.mark.asyncio
async def test_list_list_expenses_success(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_list_user_is_member: ListModel,
) -> None:
    """Listing expenses for an accessible list returns only that list's expenses."""
    resp = await client.get(
        f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses",
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert isinstance(body, list)
    assert all(row["list_id"] == test_list_user_is_member.id for row in body)

# TODO: Add more tests for list_list_expenses:
# - list not found -> 404 (ListNotFoundError from check_list_access_for_financials)
# - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials)
# - list exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
@pytest.mark.asyncio
async def test_list_list_expenses_list_not_found(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
) -> None:
    """Listing expenses for a non-existent list yields 404.

    check_list_access_for_financials raises ListNotFoundError and re-raises it
    unchanged (`except ListNotFoundError: raise`), so the exception propagates to
    the app's exception handlers. Not-found-style exceptions are conventionally
    mapped to 404; if financials.py ever maps it differently (e.g. via a wrapping
    ListPermissionError -> 403), this expectation must be revisited.
    """
    missing_list_id = 99999
    resp = await client.get(
        f"{API_V1_STR}/financials/lists/{missing_list_id}/expenses",
        headers=normal_user_token_headers,
    )
    # Expect the global handler for not-found errors to answer with 404; the
    # exact detail text depends on how ListNotFoundError is rendered.
    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert "list not found" in resp.json()["detail"].lower()

@pytest.mark.asyncio
async def test_list_list_expenses_no_access(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],  # the user attempting access
    test_list_user_not_member: ListModel,  # a list the current user is NOT a member of
) -> None:
    """Listing expenses for a list the user cannot access is forbidden (403)."""
    resp = await client.get(
        f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses",
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_403_FORBIDDEN
    assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in resp.json()["detail"]

@pytest.mark.asyncio
async def test_list_list_expenses_empty(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_list_user_is_member_no_expenses: ListModel,  # accessible list without expenses
) -> None:
    """An accessible list with no expenses answers 200 with an empty array."""
    resp = await client.get(
        f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses",
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert isinstance(body, list)
    assert not body
# GET /groups/{group_id}/expenses
@pytest.mark.asyncio
async def test_list_group_expenses_success(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_group_user_is_member: GroupModel,
) -> None:
    """Listing expenses for a group the user belongs to returns that group's expenses."""
    resp = await client.get(
        f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses",
        headers=normal_user_token_headers,
    )
    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert isinstance(body, list)
    # Every row belongs to the requested group; rows may additionally carry a
    # list_id when the expense was created through a list owned by that group.
    assert all(row["group_id"] == test_group_user_is_member.id for row in body)

# TODO: Add more tests for list_group_expenses:
# - group not found -> 404 (GroupNotFoundError from check_group_membership)
# - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership)
# - group exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
# Endpoints still lacking coverage:
# PUT /expenses/{expense_id}
# DELETE /expenses/{expense_id}
# GET /settlements/{settlement_id}
# POST /settlements
# GET /groups/{group_id}/settlements
# PUT /settlements/{settlement_id}
# DELETE /settlements/{settlement_id}
# NOTE(review): pytest.skip(..., allow_module_level=True) raises at import/collection
# time and therefore skips EVERY test in this module -- including the implemented
# tests above, not only the missing ones listed here. Presumably the intent was to
# skip just the unwritten tests; confirm, and remove or scope this skip once the
# remaining tests exist.
pytest.skip("Still implementing other tests", allow_module_level=True)
import pytest
from fastapi import status
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
GroupNotFoundError,
GroupPermissionError,
GroupMembershipError,
GroupOperationError,
GroupValidationError,
ItemNotFoundError,
UserNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseTransactionError,
DatabaseQueryError,
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OCRProcessingError,
EmailAlreadyRegisteredError,
UserCreationError,
InviteNotFoundError,
InviteExpiredError,
InviteAlreadyUsedError,
InviteCreationError,
ListStatusNotFoundError,
ConflictError,
InvalidCredentialsError,
NotAuthenticatedError,
JWTError,
JWTUnexpectedError
)
# TODO: It seems like settings are used in some exceptions.
# You will need to mock app.config.settings for these tests to pass.
# Consider using pytest-mock or unittest.mock.patch.
# Example: from app.config import settings
def test_list_not_found_error():
    """ListNotFoundError carries a 404 and names the missing list."""
    target_id = 123
    with pytest.raises(ListNotFoundError) as exc_info:
        raise ListNotFoundError(list_id=target_id)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"List {target_id} not found"

def test_list_permission_error():
    """ListPermissionError is a 403 naming both the action and the list."""
    target_id, verb = 456, "delete"
    with pytest.raises(ListPermissionError) as exc_info:
        raise ListPermissionError(list_id=target_id, action=verb)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"You do not have permission to {verb} list {target_id}"

def test_list_permission_error_default_action():
    """Without an explicit action, ListPermissionError defaults to 'access'."""
    target_id = 789
    with pytest.raises(ListPermissionError) as exc_info:
        raise ListPermissionError(list_id=target_id)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"You do not have permission to access list {target_id}"

def test_list_creator_required_error():
    """ListCreatorRequiredError is a 403 reserved for creator-only actions."""
    target_id, verb = 101, "update"
    with pytest.raises(ListCreatorRequiredError) as exc_info:
        raise ListCreatorRequiredError(list_id=target_id, action=verb)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"Only the list creator can {verb} list {target_id}"
def test_group_not_found_error():
    """GroupNotFoundError carries a 404 and names the missing group."""
    gid = 202
    with pytest.raises(GroupNotFoundError) as exc_info:
        raise GroupNotFoundError(group_id=gid)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"Group {gid} not found"

def test_group_permission_error():
    """GroupPermissionError is a 403 naming the action and the group."""
    gid, verb = 303, "invite"
    with pytest.raises(GroupPermissionError) as exc_info:
        raise GroupPermissionError(group_id=gid, action=verb)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"You do not have permission to {verb} in group {gid}"

def test_group_membership_error():
    """GroupMembershipError is a 403 requiring membership for a given action."""
    gid, verb = 404, "post"
    with pytest.raises(GroupMembershipError) as exc_info:
        raise GroupMembershipError(group_id=gid, action=verb)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"You must be a member of group {gid} to {verb}"

def test_group_membership_error_default_action():
    """Without an explicit action, GroupMembershipError defaults to 'access'."""
    gid = 505
    with pytest.raises(GroupMembershipError) as exc_info:
        raise GroupMembershipError(group_id=gid)
    err = exc_info.value
    assert err.status_code == status.HTTP_403_FORBIDDEN
    assert err.detail == f"You must be a member of group {gid} to access"

def test_group_operation_error():
    """GroupOperationError is a 500 that passes the supplied detail through."""
    message = "Failed to perform group operation."
    with pytest.raises(GroupOperationError) as exc_info:
        raise GroupOperationError(detail=message)
    err = exc_info.value
    assert err.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert err.detail == message

def test_group_validation_error():
    """GroupValidationError is a 400 that passes the supplied detail through."""
    message = "Invalid group data."
    with pytest.raises(GroupValidationError) as exc_info:
        raise GroupValidationError(detail=message)
    err = exc_info.value
    assert err.status_code == status.HTTP_400_BAD_REQUEST
    assert err.detail == message
def test_item_not_found_error():
    """ItemNotFoundError carries a 404 and names the missing item."""
    target_id = 606
    with pytest.raises(ItemNotFoundError) as exc_info:
        raise ItemNotFoundError(item_id=target_id)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"Item {target_id} not found"

def test_user_not_found_error_no_identifier():
    """With no identifier, UserNotFoundError uses the generic message."""
    with pytest.raises(UserNotFoundError) as exc_info:
        raise UserNotFoundError()
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == "User not found."

def test_user_not_found_error_with_id():
    """A numeric user_id is embedded in the UserNotFoundError detail."""
    uid = 707
    with pytest.raises(UserNotFoundError) as exc_info:
        raise UserNotFoundError(user_id=uid)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"User with ID {uid} not found."

def test_user_not_found_error_with_identifier_string():
    """A string identifier is quoted in the UserNotFoundError detail."""
    handle = "test_user"
    with pytest.raises(UserNotFoundError) as exc_info:
        raise UserNotFoundError(identifier=handle)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"User with identifier '{handle}' not found."

def test_invalid_operation_error():
    """InvalidOperationError defaults to 400 and echoes the detail."""
    message = "This operation is not allowed."
    with pytest.raises(InvalidOperationError) as exc_info:
        raise InvalidOperationError(detail=message)
    err = exc_info.value
    assert err.status_code == status.HTTP_400_BAD_REQUEST
    assert err.detail == message

def test_invalid_operation_error_custom_status():
    """InvalidOperationError honors an explicitly supplied status code."""
    message = "This operation is forbidden."
    override = status.HTTP_403_FORBIDDEN
    with pytest.raises(InvalidOperationError) as exc_info:
        raise InvalidOperationError(detail=message, status_code=override)
    err = exc_info.value
    assert err.status_code == override
    assert err.detail == message
# The following exceptions depend on `settings`
# We need to mock `app.config.settings` for these tests.
# For now, I will add placeholder tests that would fail without mocking.
# Consider using pytest-mock or unittest.mock.patch for this.
# def test_database_connection_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_CONNECTION_ERROR", "Test DB connection error")
# with pytest.raises(DatabaseConnectionError) as excinfo:
# raise DatabaseConnectionError()
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
# assert excinfo.value.detail == "Test DB connection error" # settings.DB_CONNECTION_ERROR
# def test_database_integrity_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_INTEGRITY_ERROR", "Test DB integrity error")
# with pytest.raises(DatabaseIntegrityError) as excinfo:
# raise DatabaseIntegrityError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == "Test DB integrity error" # settings.DB_INTEGRITY_ERROR
# def test_database_transaction_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_TRANSACTION_ERROR", "Test DB transaction error")
# with pytest.raises(DatabaseTransactionError) as excinfo:
# raise DatabaseTransactionError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test DB transaction error" # settings.DB_TRANSACTION_ERROR
# def test_database_query_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_QUERY_ERROR", "Test DB query error")
# with pytest.raises(DatabaseQueryError) as excinfo:
# raise DatabaseQueryError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test DB query error" # settings.DB_QUERY_ERROR
# def test_ocr_service_unavailable_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_UNAVAILABLE", "Test OCR unavailable")
# with pytest.raises(OCRServiceUnavailableError) as excinfo:
# raise OCRServiceUnavailableError()
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
# assert excinfo.value.detail == "Test OCR unavailable" # settings.OCR_SERVICE_UNAVAILABLE
# def test_ocr_service_config_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_CONFIG_ERROR", "Test OCR config error")
# with pytest.raises(OCRServiceConfigError) as excinfo:
# raise OCRServiceConfigError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test OCR config error" # settings.OCR_SERVICE_CONFIG_ERROR
# def test_ocr_unexpected_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_UNEXPECTED_ERROR", "Test OCR unexpected error")
# with pytest.raises(OCRUnexpectedError) as excinfo:
# raise OCRUnexpectedError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test OCR unexpected error" # settings.OCR_UNEXPECTED_ERROR
# def test_ocr_quota_exceeded_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_QUOTA_EXCEEDED", "Test OCR quota exceeded")
# with pytest.raises(OCRQuotaExceededError) as excinfo:
# raise OCRQuotaExceededError()
# assert excinfo.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
# assert excinfo.value.detail == "Test OCR quota exceeded" # settings.OCR_QUOTA_EXCEEDED
# def test_invalid_file_type_error(mocker):
# test_types = ["png", "jpg"]
# mocker.patch("app.core.exceptions.settings.ALLOWED_IMAGE_TYPES", test_types)
# mocker.patch("app.core.exceptions.settings.OCR_INVALID_FILE_TYPE", "Invalid type: {types}")
# with pytest.raises(InvalidFileTypeError) as excinfo:
# raise InvalidFileTypeError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"Invalid type: {', '.join(test_types)}" # settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
# def test_file_too_large_error(mocker):
# max_size = 10
# mocker.patch("app.core.exceptions.settings.MAX_FILE_SIZE_MB", max_size)
# mocker.patch("app.core.exceptions.settings.OCR_FILE_TOO_LARGE", "File too large: {size}MB")
# with pytest.raises(FileTooLargeError) as excinfo:
# raise FileTooLargeError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"File too large: {max_size}MB" # settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
# def test_ocr_processing_error(mocker):
# error_detail = "Specific OCR error"
# mocker.patch("app.core.exceptions.settings.OCR_PROCESSING_ERROR", "OCR processing failed: {detail}")
# with pytest.raises(OCRProcessingError) as excinfo:
# raise OCRProcessingError(detail=error_detail)
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"OCR processing failed: {error_detail}" # settings.OCR_PROCESSING_ERROR.format(detail=detail)
def test_email_already_registered_error():
    """EmailAlreadyRegisteredError is a 400 with a fixed message."""
    with pytest.raises(EmailAlreadyRegisteredError) as exc_info:
        raise EmailAlreadyRegisteredError()
    err = exc_info.value
    assert err.status_code == status.HTTP_400_BAD_REQUEST
    assert err.detail == "Email already registered."

def test_user_creation_error():
    """UserCreationError is a 500 with a fixed message."""
    with pytest.raises(UserCreationError) as exc_info:
        raise UserCreationError()
    err = exc_info.value
    assert err.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert err.detail == "An error occurred during user creation."

def test_invite_not_found_error():
    """InviteNotFoundError carries a 404 naming the missing code."""
    code = "TESTCODE123"
    with pytest.raises(InviteNotFoundError) as exc_info:
        raise InviteNotFoundError(invite_code=code)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"Invite code {code} not found"

def test_invite_expired_error():
    """InviteExpiredError maps to 410 Gone."""
    code = "EXPIREDCODE"
    with pytest.raises(InviteExpiredError) as exc_info:
        raise InviteExpiredError(invite_code=code)
    err = exc_info.value
    assert err.status_code == status.HTTP_410_GONE
    assert err.detail == f"Invite code {code} has expired"

def test_invite_already_used_error():
    """InviteAlreadyUsedError also maps to 410 Gone."""
    code = "USEDCODE"
    with pytest.raises(InviteAlreadyUsedError) as exc_info:
        raise InviteAlreadyUsedError(invite_code=code)
    err = exc_info.value
    assert err.status_code == status.HTTP_410_GONE
    assert err.detail == f"Invite code {code} has already been used"

def test_invite_creation_error():
    """InviteCreationError is a 500 naming the affected group."""
    gid = 909
    with pytest.raises(InviteCreationError) as exc_info:
        raise InviteCreationError(group_id=gid)
    err = exc_info.value
    assert err.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert err.detail == f"Failed to create invite for group {gid}"

def test_list_status_not_found_error():
    """ListStatusNotFoundError carries a 404 naming the list."""
    target_id = 808
    with pytest.raises(ListStatusNotFoundError) as exc_info:
        raise ListStatusNotFoundError(list_id=target_id)
    err = exc_info.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"Status for list {target_id} not found"

def test_conflict_error():
    """ConflictError is a 409 that echoes the supplied detail."""
    message = "Resource version mismatch."
    with pytest.raises(ConflictError) as exc_info:
        raise ConflictError(detail=message)
    err = exc_info.value
    assert err.status_code == status.HTTP_409_CONFLICT
    assert err.detail == message
# Tests for auth-related exceptions that likely require mocking app.config.settings
# def test_invalid_credentials_error(mocker):
# mocker.patch("app.core.exceptions.settings.AUTH_INVALID_CREDENTIALS", "Invalid test credentials")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(InvalidCredentialsError) as excinfo:
# raise InvalidCredentialsError()
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == "Invalid test credentials"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_credentials\""}
# def test_not_authenticated_error(mocker):
# mocker.patch("app.core.exceptions.settings.AUTH_NOT_AUTHENTICATED", "Not authenticated test")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(NotAuthenticatedError) as excinfo:
# raise NotAuthenticatedError()
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == "Not authenticated test"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"not_authenticated\""}
# def test_jwt_error(mocker):
# error_msg = "Test JWT issue"
# mocker.patch("app.core.exceptions.settings.JWT_ERROR", "JWT error: {error}")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(JWTError) as excinfo:
# raise JWTError(error=error_msg)
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == f"JWT error: {error_msg}"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
# def test_jwt_unexpected_error(mocker):
# error_msg = "Unexpected test JWT issue"
# mocker.patch("app.core.exceptions.settings.JWT_UNEXPECTED_ERROR", "Unexpected JWT error: {error}")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(JWTUnexpectedError) as excinfo:
# raise JWTUnexpectedError(error=error_msg)
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == f"Unexpected JWT error: {error_msg}"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
# Modules to test
from app.core import gemini
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
# Default Mock Settings
@pytest.fixture
def mock_gemini_settings():
    """Provide a MagicMock standing in for the app settings read by app.core.gemini."""
    cfg = MagicMock()
    cfg.GEMINI_API_KEY = "test_api_key"
    cfg.GEMINI_MODEL_NAME = "gemini-pro-vision"
    cfg.GEMINI_SAFETY_SETTINGS = {
        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    }
    cfg.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
    cfg.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
    return cfg
@pytest.fixture
def mock_generative_model_instance():
    """A GenerativeModel-shaped mock whose async generation call can be programmed."""
    instance = MagicMock(spec=genai.GenerativeModel)
    instance.generate_content_async = AsyncMock()
    return instance
@pytest.fixture
def patch_google_ai_client(mock_generative_model_instance):
    """Patch google.generativeai's configure/GenerativeModel for the whole test.

    Bug fix: the previous version stacked @patch decorators under @pytest.fixture.
    Those patches are torn down as soon as the decorated fixture function returns,
    so they were never active while the test body actually ran (and the decorator
    wrapping confuses pytest's fixture argument introspection). Using context
    managers around a yield keeps the patches in force for the duration of the
    requesting test.

    Yields:
        (mock_configure, mock_generative_model, mock_generative_model_instance) --
        the same triple the original returned.
    """
    with patch('google.generativeai.configure') as mock_configure, \
         patch('google.generativeai.GenerativeModel') as mock_generative_model:
        # Any construction of GenerativeModel inside the test hands back our
        # pre-built instance so generate_content_async can be programmed.
        mock_generative_model.return_value = mock_generative_model_instance
        yield mock_configure, mock_generative_model, mock_generative_model_instance
# --- Test Gemini Client Initialization (Global Client) ---
# Parametrize to test different scenarios for the global client init
@pytest.mark.parametrize(
    "api_key_present, configure_raises, model_init_raises, expected_error_message_part",
    [
        (True, None, None, None),  # Success
        (False, None, None, "GEMINI_API_KEY not configured"),  # API key missing
        (True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"),  # genai.configure error
        (True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"),  # GenerativeModel error
    ]
)
@patch('app.core.gemini.genai')  # Patch genai within the gemini module
def test_global_gemini_client_initialization(
    mock_genai_module,  # injected by @patch (the mocked genai inside app.core.gemini)
    mock_gemini_settings,  # pytest fixture: MagicMock settings
    api_key_present,
    configure_raises,
    model_init_raises,
    expected_error_message_part,
):
    """Tests the global gemini_flash_client initialization logic in app.core.gemini."""
    # We need to reload the module to re-trigger its top-level initialization code.
    # This is a bit tricky. A common pattern is to put init logic in a function.
    # For now, we'll try to simulate it by controlling mocks before module access.
    with patch('app.core.gemini.settings', mock_gemini_settings):
        if not api_key_present:
            # Simulate the missing-key scenario on the mocked settings object.
            mock_gemini_settings.GEMINI_API_KEY = None
        mock_genai_module.configure = MagicMock()
        mock_genai_module.GenerativeModel = MagicMock()
        mock_genai_module.types = genai.types  # Keep original types
        mock_genai_module.HarmCategory = genai.HarmCategory
        mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold
        if configure_raises:
            mock_genai_module.configure.side_effect = configure_raises
        if model_init_raises:
            mock_genai_module.GenerativeModel.side_effect = model_init_raises
        # Python modules are singletons. To re-run top-level code, we need to unload and reload.
        # This is generally discouraged. It's better to have an explicit init function.
        # For this test, we'll check the state variables set by the module's import-time code.
        # NOTE(review): reloading re-executes gemini.py's own `import ... as genai`,
        # which presumably rebinds app.core.gemini.genai to the real module and
        # bypasses the @patch above -- confirm the mocks are actually exercised.
        import importlib
        importlib.reload(gemini)  # This re-runs the try-except block at the top of gemini.py
        if expected_error_message_part:
            # Failure paths: the module records the error and leaves the client unset.
            assert gemini.gemini_initialization_error is not None
            assert expected_error_message_part in gemini.gemini_initialization_error
            assert gemini.gemini_flash_client is None
        else:
            # Success path: client created, configure called with the mocked key.
            assert gemini.gemini_initialization_error is None
            assert gemini.gemini_flash_client is not None
            mock_genai_module.configure.assert_called_once_with(api_key="test_api_key")
            mock_genai_module.GenerativeModel.assert_called_once()
            # Could add more assertions about safety_settings and generation_config here
        # Clean up after reload for other tests
        importlib.reload(gemini)
# --- Test get_gemini_client ---
# Assuming the global client tests above set the stage for these
def test_get_gemini_client_success():
    """get_gemini_client returns the module-level client when initialization succeeded.

    Bug fix: the original stacked @patch decorators where one used an explicit
    `new` value (None). @patch with an explicit `new` injects NO argument into
    the test function, so the declared `mock_error_var` parameter was resolved by
    pytest as a (nonexistent) fixture, breaking collection. Patching via context
    managers sidesteps the argument-injection rules entirely.
    """
    mock_client = MagicMock(spec=genai.GenerativeModel)
    with patch('app.core.gemini.gemini_flash_client', mock_client), \
         patch('app.core.gemini.gemini_initialization_error', None):
        client = gemini.get_gemini_client()
        # The accessor must hand back the initialized module-level client.
        assert client is mock_client
        assert client is not None
@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', "Test init error")
def test_get_gemini_client_init_error():
    """A recorded initialization error surfaces as RuntimeError from get_gemini_client.

    Both patches supply an explicit ``new`` value, so unittest.mock injects NO
    arguments — the original's two parameters caused a TypeError at call time.
    """
    with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
        gemini.get_gemini_client()
@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', None) # No init error, but client is None
def test_get_gemini_client_none_client_unknown_issue():
    """Client unset with no recorded error -> generic 'unknown initialization issue' RuntimeError.

    Fixes two defects: patches with explicit ``new`` inject no arguments (the
    original declared two parameters → TypeError), and the match pattern used
    ``\\(`` inside a non-raw string (invalid escape sequence) — now a raw string.
    """
    with pytest.raises(RuntimeError, match=r"Gemini client is not available \(unknown initialization issue\)\."):
        gemini.get_gemini_client()
# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_success(
    mock_gemini_settings,
    mock_generative_model_instance,
    patch_google_ai_client # This fixture patches google.generativeai for the module
):
    """Successful extraction: response text is split per line, whitespace-trimmed,
    and blank lines are dropped from the returned item list."""
    # Ensure the global client is mocked to be the one we control
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):
        mock_response = MagicMock()
        # Note the padded " Item 3 " and the empty line — both exercised by the final assert.
        mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
        # Simulate the structure for safety checks if needed
        mock_candidate = MagicMock()
        mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
        mock_candidate.finish_reason = 'STOP' # Or whatever is appropriate for success
        mock_candidate.safety_ratings = []
        mock_response.candidates = [mock_candidate]
        mock_generative_model_instance.generate_content_async.return_value = mock_response
        image_bytes = b"dummy_image_bytes"
        mime_type = "image/png"
        items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)
        # The model must be invoked with [prompt, image-part] positionally.
        mock_generative_model_instance.generate_content_async.assert_called_once_with([
            mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
            {"mime_type": mime_type, "data": image_bytes}
        ])
        assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(
    mock_gemini_settings
):
    """A recorded init failure must surface as RuntimeError when extraction is attempted."""
    with patch('app.core.gemini.settings', mock_gemini_settings):
        with patch('app.core.gemini.gemini_flash_client', None):
            with patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
                payload = b"dummy_image_bytes"
                with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
                    await gemini.extract_items_from_image_gemini(payload)
@pytest.mark.asyncio
@patch('app.core.gemini.get_gemini_client') # Mock the getter to control the client directly
async def test_extract_items_from_image_gemini_api_quota_error(
    mock_get_client,
    mock_gemini_settings,
    mock_generative_model_instance
):
    """Quota exhaustion (ResourceExhausted) from the API is propagated unchanged to the caller."""
    mock_get_client.return_value = mock_generative_model_instance
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
    with patch('app.core.gemini.settings', mock_gemini_settings):
        image_bytes = b"dummy_image_bytes"
        with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
            await gemini.extract_items_from_image_gemini(image_bytes)
# --- Tests for GeminiOCRService --- (Example tests, more needed)
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel')
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
    """GeminiOCRService.__init__ must configure genai with the API key and build the model.

    Decorator args are injected bottom-up: MockGenerativeModel is the inner
    (GenerativeModel) patch, MockConfigure the outer (configure) patch.
    """
    MockGenerativeModel.return_value = mock_generative_model_instance
    with patch('app.core.gemini.settings', mock_gemini_settings):
        service = gemini.GeminiOCRService()
        MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
        MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
        assert service.model == mock_generative_model_instance
        # Could add assertions for safety_settings and generation_config if they are set directly on model
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
    """If model construction blows up, the service must raise OCRServiceConfigError."""
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         pytest.raises(OCRServiceConfigError):
        gemini.GeminiOCRService()
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance):
    """GeminiOCRService.extract_items parses response text into a list of item names.

    The trailing "Example output..." line mimics prompt echo the service is
    expected to filter out. Fixes the original's unparenthesized multi-line
    ``with`` (a SyntaxError) by using backslash continuation like the sibling
    tests, and drops the unused ``as`` bindings.
    """
    mock_response = MagicMock()
    mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored"
    mock_generative_model_instance.generate_content_async.return_value = mock_response
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
         patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()  # Re-init to use the patched model
        items = await service.extract_items(b"dummy_image")
        expected_call_args = [
            mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
            {"mime_type": "image/jpeg", "data": b"dummy_image"}
        ]
        service.model.generate_content_async.assert_called_once_with(contents=expected_call_args)
        assert items == ["Apple", "Banana", "Orange"]
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance):
    """ResourceExhausted from the API must map to the service-level OCRQuotaExceededError."""
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.")
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
         patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRQuotaExceededError):
            await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance):
    """A non-quota API failure must map to OCRServiceUnavailableError."""
    # Simulate a generic API error that isn't quota related
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable")
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
         patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRServiceUnavailableError):
            await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance):
    """A response carrying no text must be reported as OCRUnexpectedError."""
    empty_response = MagicMock()
    empty_response.text = None  # simulate an answer without any text payload
    mock_generative_model_instance.generate_content_async.return_value = empty_response
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch.object(genai, 'configure'), \
         patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRUnexpectedError):
            await service.extract_items(b"dummy_image")
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.security import (
verify_password,
hash_password,
create_access_token,
create_refresh_token,
verify_access_token,
verify_refresh_token,
pwd_context, # Import for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
# from app.config import settings
# --- Tests for Password Hashing ---
def test_hash_password():
    """Hashing must produce a bcrypt digest distinct from the plaintext."""
    plaintext = "securepassword123"
    digest = hash_password(plaintext)
    assert isinstance(digest, str)
    assert digest != plaintext
    # bcrypt digests are recognisable by their $2a$/$2b$/$2y$ prefix
    assert digest.startswith(("$2b$", "$2a$", "$2y$"))
def test_verify_password_correct():
    """verify_password accepts the matching plaintext."""
    secret = "testpassword"
    stored = pwd_context.hash(secret)  # hash with the same context for consistency
    assert verify_password(secret, stored) is True
def test_verify_password_incorrect():
    """verify_password rejects a non-matching plaintext."""
    stored = pwd_context.hash("testpassword")
    assert verify_password("wrongpassword", stored) is False
def test_verify_password_invalid_hash_format():
    """A malformed stored hash must fail verification rather than raise."""
    assert verify_password("testpassword", "notarealhash") is False
# --- Tests for JWT Creation ---
# Mock settings for JWT tests
@pytest.fixture(scope="module")
def mock_jwt_settings():
    """Module-scoped stand-in for app settings consumed by the JWT helpers."""
    return MagicMock(
        SECRET_KEY="testsecretkey",
        ALGORITHM="HS256",
        ACCESS_TOKEN_EXPIRE_MINUTES=30,
        REFRESH_TOKEN_EXPIRE_MINUTES=10080,  # 7 days
    )
@patch('app.core.security.settings')
def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
    """Default-expiry access token: sub/type claims are set and exp ~= now + ACCESS_TOKEN_EXPIRE_MINUTES."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
    subject = "user@example.com"
    token = create_access_token(subject)
    assert isinstance(token, str)
    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == subject
    assert decoded_payload["type"] == "access"
    assert "exp" in decoded_payload
    # Check if expiry is roughly correct (within a small delta)
    expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
@patch('app.core.security.settings')
def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
    """An explicit expires_delta overrides the configured default; int subjects are stringified."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    # ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
    subject = 123 # Subject can be int
    custom_delta = timedelta(hours=1)
    token = create_access_token(subject, expires_delta=custom_delta)
    assert isinstance(token, str)
    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == str(subject)
    assert decoded_payload["type"] == "access"
    expected_expiry = datetime.now(timezone.utc) + custom_delta
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
@patch('app.core.security.settings')
def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
    """Refresh token carries type="refresh" and exp ~= now + REFRESH_TOKEN_EXPIRE_MINUTES."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
    subject = "refresh_subject"
    token = create_refresh_token(subject)
    assert isinstance(token, str)
    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == subject
    assert decoded_payload["type"] == "refresh"
    expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# --- Tests for JWT Verification --- (More tests to be added here)
@patch('app.core.security.settings')
def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
    """Round trip: a freshly minted access token verifies and carries sub/type claims."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
    token = create_access_token("test_user_valid_access")
    claims = verify_access_token(token)
    assert claims is not None
    assert claims["sub"] == "test_user_valid_access"
    assert claims["type"] == "access"
@patch('app.core.security.settings')
def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
    """A token signed with one key must not verify under a different key (returns None)."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
    subject = "test_user_invalid_sig"
    # Create token with correct key
    token = create_access_token(subject)
    # Try to verify with wrong key
    mock_settings_global.SECRET_KEY = "wrongsecretkey"
    payload = verify_access_token(token)
    assert payload is None
@patch('app.core.security.settings')
@patch('app.core.security.datetime') # Mock datetime to control token expiry
def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
    """An access token past its exp must fail verification (returns None).

    Only the module's `datetime.now` is frozen; `fromtimestamp` and `timedelta`
    are restored to the real implementations so token decoding still works.
    """
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
    # Set current time for token creation
    now = datetime.now(timezone.utc)
    mock_datetime.now.return_value = now
    mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
    mock_datetime.timedelta = timedelta # Ensure original timedelta is used
    subject = "test_user_expired"
    token = create_access_token(subject)
    # Advance time beyond expiry for verification
    mock_datetime.now.return_value = now + timedelta(minutes=5)
    payload = verify_access_token(token)
    assert payload is None
@patch('app.core.security.settings')
def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
    """A refresh token must never validate as an access token."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    # Refresh expiry config is needed because we mint a refresh token below.
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
    token_of_wrong_kind = create_refresh_token("test_user_wrong_type")
    assert verify_access_token(token_of_wrong_kind) is None
@patch('app.core.security.settings')
def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
    """Round trip: a freshly minted refresh token verifies and carries sub/type claims."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
    token = create_refresh_token("test_user_valid_refresh")
    claims = verify_refresh_token(token)
    assert claims is not None
    assert claims["sub"] == "test_user_valid_refresh"
    assert claims["type"] == "refresh"
@patch('app.core.security.settings')
@patch('app.core.security.datetime')
def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
    """A refresh token past its exp must fail verification (returns None)."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
    now = datetime.now(timezone.utc)
    mock_datetime.now.return_value = now
    # Keep the real fromtimestamp/timedelta so decoding still works under the mock.
    mock_datetime.fromtimestamp = datetime.fromtimestamp
    mock_datetime.timedelta = timedelta
    subject = "test_user_expired_refresh"
    token = create_refresh_token(subject)
    mock_datetime.now.return_value = now + timedelta(minutes=5)
    payload = verify_refresh_token(token)
    assert payload is None
@patch('app.core.security.settings')
def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
    """An access token must never validate as a refresh token."""
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    # Access expiry config is needed because we mint an access token below.
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
    token_of_wrong_kind = create_access_token("test_user_wrong_type_refresh")
    assert verify_refresh_token(token_of_wrong_kind) is None
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional
from app.crud.expense import (
create_expense,
get_expense_by_id,
get_expenses_for_list,
get_expenses_for_group,
update_expense, # Assuming update_expense exists
delete_expense, # Assuming delete_expense exists
get_users_for_splitting # Helper, might test indirectly
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Item as ItemModel,
SplitTypeEnum
)
from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError
)
# General Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession double: awaited methods are AsyncMock, sync ones MagicMock.

    flush is mocked explicitly because create_expense flushes before commit.
    """
    session = AsyncMock()
    for awaited in ("commit", "rollback", "refresh", "execute", "get", "flush"):
        setattr(session, awaited, AsyncMock())
    session.add = MagicMock()
    session.delete = MagicMock()
    return session
@pytest.fixture
def basic_user_model():
    """Primary test user (id=1), used as default payer/creator."""
    return UserModel(email="test@example.com", id=1, name="Test User")
@pytest.fixture
def another_user_model():
    """Secondary test user (id=2), used as a second split participant."""
    return UserModel(email="another@example.com", id=2, name="Another User")
@pytest.fixture
def basic_group_model():
    """Bare group (id=1); tests attach member_associations themselves when
    get_users_for_splitting needs to see members."""
    return GroupModel(id=1, name="Test Group")
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """List (id=1) belonging to the basic group and created by the basic user."""
    test_list = ListModel(
        id=1,
        name="Test List",
        group=basic_group_model,
        group_id=basic_group_model.id,
        creator=basic_user_model,
        creator_id=basic_user_model.id,
    )
    return test_list
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """EQUAL-split expense payload anchored to a list (group derived from the list)."""
    return ExpenseCreate(
        description="Grocery run",
        currency="USD",
        total_amount=Decimal("30.00"),
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        paid_by_user_id=basic_user_model.id,
        list_id=basic_list_model.id,
        group_id=None,  # derived from the list
        item_id=None,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """EQUAL-split expense payload anchored directly to a group (no list)."""
    return ExpenseCreate(
        description="Movies",
        currency="USD",
        total_amount=Decimal("50.00"),
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        paid_by_user_id=basic_user_model.id,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """EXACT_AMOUNTS expense payload: 60/40 split summing to the 100.00 total."""
    exact_splits = [
        ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
        ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=exact_splits,
    )
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    """Persisted-looking ExpenseModel (id=1, version=1) mirroring the
    group-context EQUAL-split payload; splits intentionally left unset."""
    return ExpenseModel(
        id=1,
        description=expense_create_data_equal_split_group_ctx.description,
        total_amount=expense_create_data_equal_split_group_ctx.total_amount,
        currency=expense_create_data_equal_split_group_ctx.currency,
        expense_date=expense_create_data_equal_split_group_ctx.expense_date,
        split_type=expense_create_data_equal_split_group_ctx.split_type,
        list_id=expense_create_data_equal_split_group_ctx.list_id,
        group_id=expense_create_data_equal_split_group_ctx.group_id,
        item_id=expense_create_data_equal_split_group_ctx.item_id,
        paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
        paid_by=basic_user_model, # Assuming paid_by relation is loaded
        # splits would be populated after creation usually
        version=1
    )
# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed)
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """Group context: every member of the group participates in the split."""
    # Setup group with members
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
    # db.execute(...).scalars().first() must yield the group with members loaded
    mock_execute = AsyncMock()
    mock_execute.scalars.return_value.first.return_value = basic_group_model
    mock_db_session.execute.return_value = mock_execute
    users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1)
    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """EQUAL split over a group: total is divided evenly (2dp, ROUND_HALF_UP) among members."""
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
    # Mock get_users_for_splitting call within create_expense
    # This is a bit tricky as it's an internal call. Patching is an option.
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]
        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)
        mock_db_session.add.assert_called()
        mock_db_session.flush.assert_called_once()
        # mock_db_session.commit.assert_called_once() # create_expense does not commit itself
        # mock_db_session.refresh.assert_called_once() # create_expense does not refresh itself
        assert created_expense is not None
        assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
        assert created_expense.split_type == SplitTypeEnum.EQUAL
        assert len(created_expense.splits) == 2 # Expect splits to be added to the model instance
        # Check split amounts
        expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        for split in created_expense.splits:
            assert split.owed_amount == expected_amount_per_user
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """EXACT_AMOUNTS split: per-user owed amounts are taken verbatim from splits_in."""
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
    # Mock the select for user validation in exact splits
    mock_user_select_result = AsyncMock()
    mock_user_select_result.all.return_value = [(basic_user_model.id,), (another_user_model.id,)] # Simulate (id,) tuples
    # To make it behave like scalars().all() that returns a list of IDs:
    # We need to mock the scalars().all() part, or the whole execute chain for user validation.
    # A simpler way for this specific case might be to mock the select for User.id
    mock_execute_user_ids = AsyncMock()
    # Construct a mock result that mimics what `await db.execute(select(UserModel.id)...)` would give then process
    # It's a bit involved, usually `found_user_ids = {row[0] for row in user_results}`
    # Let's assume the select returns a list of Row objects or tuples with one element
    mock_user_ids_result_proxy = MagicMock()
    mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)])
    mock_db_session.execute.return_value = mock_user_ids_result_proxy
    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1)
    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """Unknown payer id -> UserNotFoundError before anything is persisted."""
    payload = expense_create_data_equal_split_group_ctx
    mock_db_session.get.return_value = None  # payer lookup comes back empty
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, payload, 1)
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense tied to neither a list nor a group must be rejected."""
    mock_db_session.get.return_value = basic_user_model  # payer lookup succeeds
    orphan_expense = expense_create_data_equal_split_group_ctx.model_copy()
    orphan_expense.list_id = None
    orphan_expense.group_id = None
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, orphan_expense, 1)
# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """Lookup by id returns the matching expense row."""
    result_proxy = AsyncMock()
    result_proxy.scalars.return_value.first.return_value = db_expense_model
    mock_db_session.execute.return_value = result_proxy
    found = await get_expense_by_id(mock_db_session, 1)
    mock_db_session.execute.assert_called_once()
    assert found is not None
    assert found.id == 1
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """Lookup with an unknown id yields None."""
    result_proxy = AsyncMock()
    result_proxy.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = result_proxy
    assert await get_expense_by_id(mock_db_session, 999) is None
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    """Querying by list id returns every expense attached to the list."""
    result_proxy = AsyncMock()
    result_proxy.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = result_proxy
    fetched = await get_expenses_for_list(mock_db_session, list_id=1)
    mock_db_session.execute.assert_called_once()
    assert len(fetched) == 1
    assert fetched[0].id == db_expense_model.id
# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    """Querying by group id returns every expense attached to the group."""
    result_proxy = AsyncMock()
    result_proxy.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = result_proxy
    fetched = await get_expenses_for_group(mock_db_session, group_id=1)
    mock_db_session.execute.assert_called_once()
    assert len(fetched) == 1
    assert fetched[0].id == db_expense_model.id
# --- Stubs for update_expense and delete_expense ---
# These will need more details once the actual implementation of update/delete is clear
# For example, how splits are handled on update, versioning, etc.
@pytest.mark.asyncio
async def test_update_expense_stub(mock_db_session):
    """Stub: sketches the eventual update_expense test (optimistic versioning, split handling)."""
    # Placeholder: Test logic for update_expense will be more complex
    # Needs ExpenseUpdate schema, existing expense object, and mocking of commit/refresh
    # Also depends on what fields are updatable and how splits are managed.
    expense_to_update = MagicMock(spec=ExpenseModel)
    expense_to_update.version = 1
    update_payload = ExpenseUpdate(description="New description", version=1) # Add other fields as per schema definition
    # Simulate the update_expense function behavior
    # For example, if it loads the expense, modifies, commits, refreshes:
    # mock_db_session.get.return_value = expense_to_update
    # updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
    # assert updated_expense.description == "New description"
    # mock_db_session.commit.assert_called_once()
    # mock_db_session.refresh.assert_called_once()
    pass # Replace with actual test logic
@pytest.mark.asyncio
async def test_delete_expense_stub(mock_db_session):
    """Stub: sketches the eventual delete_expense test (version check, split cascade)."""
    # Placeholder: Test logic for delete_expense
    # Needs an existing expense object and mocking of delete/commit
    # Also, consider implications (e.g., are splits deleted?)
    expense_to_delete = MagicMock(spec=ExpenseModel)
    expense_to_delete.id = 1
    expense_to_delete.version = 1
    # Simulate delete_expense behavior
    # mock_db_session.get.return_value = expense_to_delete # If it re-fetches
    # await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
    # mock_db_session.delete.assert_called_once_with(expense_to_delete)
    # mock_db_session.commit.assert_called_once()
    pass # Replace with actual test logic
# TODO: Add more tests for create_expense covering:
# - List context success
# - Percentage, Shares, Item-based splits
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
# - Validation of list_id/group_id consistency
# - User not found in splits_in
# - Item not found for ITEM_BASED split
# TODO: Flesh out update_expense tests:
# - Success case
# - Version mismatch
# - Trying to update immutable fields
# - How splits are handled (recalculated, deleted/recreated, or not changeable)
# TODO: Flesh out delete_expense tests:
# - Success case
# - Version mismatch (if applicable)
# - Ensure associated splits are also deleted (cascade behavior)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import delete, func # For remove_user_from_group and get_group_member_count
from app.crud.group import (
create_group,
get_user_groups,
get_group_by_id,
is_user_member,
get_user_role_in_group,
add_user_to_group,
remove_user_from_group,
get_group_member_count,
check_group_membership,
check_user_role_in_group
)
from app.schemas.group import GroupCreate
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError
)
# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession double for group CRUD tests.

    `begin` is mocked because the CRUD layer wraps work in `async with db.begin()`;
    sync-style `add`/`delete` stay MagicMock, everything awaited is AsyncMock.
    """
    session = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()  # in case remove_user_from_group uses session.delete
    for awaited in ("begin", "commit", "rollback", "refresh", "execute", "get", "flush"):
        setattr(session, awaited, AsyncMock())
    return session
@pytest.fixture
def group_create_data():
    """Minimal valid GroupCreate payload."""
    payload = GroupCreate(name="Test Group")
    return payload
@pytest.fixture
def creator_user_model():
    """User (id=1) acting as the group creator/owner in these tests."""
    return UserModel(email="creator@example.com", id=1, name="Creator User")
@pytest.fixture
def member_user_model():
    """User (id=2) acting as a plain group member in these tests."""
    return UserModel(email="member@example.com", id=2, name="Member User")
@pytest.fixture
def db_group_model(creator_user_model):
    """Group (id=1) created by the creator user, with the relation pre-wired."""
    return GroupModel(
        id=1,
        name="Test Group",
        creator=creator_user_model,
        created_by_id=creator_user_model.id,
    )
@pytest.fixture
def db_user_group_owner_assoc(db_group_model, creator_user_model):
    """Owner-role membership row linking the creator to the group."""
    return UserGroupModel(
        role=UserRoleEnum.owner,
        user=creator_user_model,
        user_id=creator_user_model.id,
        group=db_group_model,
        group_id=db_group_model.id,
    )
@pytest.fixture
def db_user_group_member_assoc(db_group_model, member_user_model):
    """Member-role membership row linking the plain member to the group."""
    return UserGroupModel(
        role=UserRoleEnum.member,
        user=member_user_model,
        user_id=member_user_model.id,
        group=db_group_model,
        group_id=db_group_model.id,
    )
# --- create_group Tests ---
@pytest.mark.asyncio
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
    """create_group adds the Group plus the owner UserGroup row and refreshes it."""
    async def fake_refresh(obj):
        # The DB would assign a primary key on refresh.
        obj.id = 1

    mock_db_session.refresh = AsyncMock(side_effect=fake_refresh)

    group = await create_group(mock_db_session, group_create_data, creator_user_model.id)

    # One add() for the Group itself and one for the owner association row.
    assert mock_db_session.add.call_count == 2
    mock_db_session.flush.assert_called()  # flushed more than once is fine
    mock_db_session.refresh.assert_called_once_with(group)
    assert group is not None
    assert group.name == group_create_data.name
    assert group.created_by_id == creator_user_model.id
    # The UserGroup contents could additionally be checked via add() call args.
@pytest.mark.asyncio
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
    """A flush-time IntegrityError surfaces as DatabaseIntegrityError and rolls back."""
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")

    with pytest.raises(DatabaseIntegrityError):
        await create_group(mock_db_session, group_create_data, creator_user_model.id)

    # create_group is expected to roll back inside its except block.
    mock_db_session.rollback.assert_called_once()
# --- get_user_groups Tests ---
@pytest.mark.asyncio
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
    """get_user_groups returns every group the user belongs to."""
    query_result = AsyncMock()
    query_result.scalars.return_value.all.return_value = [db_group_model]
    mock_db_session.execute.return_value = query_result

    groups = await get_user_groups(mock_db_session, creator_user_model.id)

    mock_db_session.execute.assert_called_once()
    assert len(groups) == 1
    assert groups[0].name == db_group_model.name
# --- get_group_by_id Tests ---
@pytest.mark.asyncio
async def test_get_group_by_id_found(mock_db_session, db_group_model):
    """Fetching an existing group id returns that group."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = db_group_model
    mock_db_session.execute.return_value = query_result

    fetched = await get_group_by_id(mock_db_session, db_group_model.id)

    assert fetched is not None
    assert fetched.id == db_group_model.id
    # Assertions about eager-loaded members could be added if they were mocked.
@pytest.mark.asyncio
async def test_get_group_by_id_not_found(mock_db_session):
    """An unknown group id yields None."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = query_result

    assert await get_group_by_id(mock_db_session, 999) is None
# --- is_user_member Tests ---
@pytest.mark.asyncio
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model):
    """A matching UserGroup row means the user is a member."""
    query_result = AsyncMock()
    query_result.scalar_one_or_none.return_value = 1  # a UserGroup.id was found
    mock_db_session.execute.return_value = query_result

    assert await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id) is True
@pytest.mark.asyncio
async def test_is_user_member_false(mock_db_session, db_group_model, member_user_model):
    """No UserGroup row means the user is not a member."""
    query_result = AsyncMock()
    query_result.scalar_one_or_none.return_value = None  # no membership row
    mock_db_session.execute.return_value = query_result

    # Query with a user id that has no membership row.
    assert await is_user_member(mock_db_session, db_group_model.id, member_user_model.id + 1) is False
# --- get_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
    """The creator's stored role (owner) is returned unchanged."""
    query_result = AsyncMock()
    query_result.scalar_one_or_none.return_value = UserRoleEnum.owner
    mock_db_session.execute.return_value = query_result

    role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)

    assert role == UserRoleEnum.owner
# --- add_user_to_group Tests ---
@pytest.mark.asyncio
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model):
    """Adding a non-member creates, flushes and refreshes a new UserGroup row.

    NOTE(review): the single execute() return value stands in for the
    existing-membership check — this assumes add_user_to_group issues exactly
    one query before inserting; verify against the CRUD implementation.
    """
    # First execute call for checking existing membership returns None
    mock_existing_check_result = AsyncMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_existing_check_result
    async def mock_refresh_user_group(instance):
        instance.id = 100  # Simulate ID for UserGroupModel
        return None
    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh_user_group)
    user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.member)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert user_group_assoc is not None
    assert user_group_assoc.user_id == member_user_model.id
    assert user_group_assoc.group_id == db_group_model.id
    assert user_group_assoc.role == UserRoleEnum.member
@pytest.mark.asyncio
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
    """Adding a user who is already a member is a no-op returning None."""
    existing_row_result = AsyncMock()
    existing_row_result.scalar_one_or_none.return_value = db_user_group_owner_assoc  # already a member
    mock_db_session.execute.return_value = existing_row_result

    result = await add_user_to_group(mock_db_session, db_group_model.id, creator_user_model.id)

    assert result is None
    mock_db_session.add.assert_not_called()
# --- remove_user_from_group Tests ---
@pytest.mark.asyncio
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model):
    """Removing an existing member reports True."""
    delete_result = AsyncMock()
    delete_result.scalar_one_or_none.return_value = 1  # a row was deleted (its id returned)
    mock_db_session.execute.return_value = delete_result

    removed = await remove_user_from_group(mock_db_session, db_group_model.id, member_user_model.id)

    assert removed is True
    # We only assert execute() ran once; validating the emitted DELETE
    # statement would require inspecting the call arguments in detail.
    mock_db_session.execute.assert_called_once()
# --- get_group_member_count Tests ---
@pytest.mark.asyncio
async def test_get_group_member_count_success(mock_db_session, db_group_model):
    """The scalar COUNT result is returned unchanged."""
    count_result = AsyncMock()
    count_result.scalar_one.return_value = 5
    mock_db_session.execute.return_value = count_result

    assert await get_group_member_count(mock_db_session, db_group_model.id) == 5
# --- check_group_membership Tests ---
@pytest.mark.asyncio
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
    """No exception is raised when the group exists and the user belongs to it."""
    mock_db_session.get.return_value = db_group_model  # group lookup succeeds
    membership_result = AsyncMock()
    membership_result.scalar_one_or_none.return_value = 1  # membership row found
    mock_db_session.execute.return_value = membership_result

    # Success is signalled by the absence of an exception.
    await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
@pytest.mark.asyncio
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
    """A missing group raises GroupNotFoundError."""
    mock_db_session.get.return_value = None  # group lookup fails

    with pytest.raises(GroupNotFoundError):
        await check_group_membership(mock_db_session, 999, creator_user_model.id)
@pytest.mark.asyncio
async def test_check_group_membership_not_member(mock_db_session, db_group_model, member_user_model):
    """An existing group with no membership row raises GroupMembershipError."""
    mock_db_session.get.return_value = db_group_model  # group lookup succeeds
    membership_result = AsyncMock()
    membership_result.scalar_one_or_none.return_value = None  # no membership row
    mock_db_session.execute.return_value = membership_result

    with pytest.raises(GroupMembershipError):
        await check_group_membership(mock_db_session, db_group_model.id, member_user_model.id)
# --- check_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
    """An owner passes a check that only requires the member role.

    NOTE(review): the execute() results are sequenced via side_effect, so this
    relies on the CRUD function issuing exactly two queries in this order:
    membership check first, then role lookup — confirm against the implementation.
    """
    # Mock check_group_membership (implicitly called)
    mock_db_session.get.return_value = db_group_model
    mock_membership_check = AsyncMock()
    mock_membership_check.scalar_one_or_none.return_value = 1  # User is member
    # Mock get_user_role_in_group
    mock_role_check = AsyncMock()
    mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.owner
    mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
    await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
    # No exception means success
@pytest.mark.asyncio
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
    """A plain member fails a check that requires the owner role.

    NOTE(review): like the sufficient-role test, this depends on the exact
    two-query order (membership check, then role lookup).
    """
    mock_db_session.get.return_value = db_group_model  # Group exists
    mock_membership_check = AsyncMock()
    mock_membership_check.scalar_one_or_none.return_value = 1  # User is member (for check_group_membership call)
    mock_role_check = AsyncMock()
    mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.member  # User's actual role
    mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
    with pytest.raises(GroupPermissionError):
        await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)
# TODO: Add tests for DB operational/SQLAlchemy errors for each function similar to create_group_integrity_error
# TODO: Test edge cases like trying to add user to non-existent group (should be caught by FK constraints or prior checks)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised
from datetime import datetime, timedelta, timezone
import secrets
from app.crud.invite import (
create_invite,
get_active_invite_by_code,
deactivate_invite,
MAX_CODE_GENERATION_ATTEMPTS
)
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context
# No specific schemas for invite CRUD usually, but models are used.
# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the invite CRUD tests (no flush/get needed here)."""
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()  # add() is synchronous on AsyncSession
    session.execute = AsyncMock()
    return session
@pytest.fixture
def group_model():
    # The group invites are issued for.
    return GroupModel(id=1, name="Test Group")
@pytest.fixture
def user_model():  # Creator
    # The user who issues the invite.
    return UserModel(id=1, name="Creator User")
@pytest.fixture
def db_invite_model(group_model, user_model):
    # An active invite expiring a week from now; individual tests mutate
    # is_active / expires_at to cover the inactive and expired cases.
    return InviteModel(
        id=1,
        code="test_invite_code_123",
        group_id=group_model.id,
        created_by_id=user_model.id,
        expires_at=datetime.now(timezone.utc) + timedelta(days=7),
        is_active=True
    )
# --- create_invite Tests ---
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')  # Patch secrets.token_urlsafe
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """The first generated code is unique, so the invite is created immediately.

    NOTE(review): assumes create_invite calls token_urlsafe(16) and performs
    one uniqueness query before committing — confirm against the implementation.
    """
    generated_code = "unique_code_123"
    mock_token_urlsafe.return_value = generated_code
    # Mock DB execute for checking existing code (first attempt, no existing code)
    mock_existing_check_result = AsyncMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_existing_check_result
    invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)
    mock_token_urlsafe.assert_called_once_with(16)
    mock_db_session.execute.assert_called_once()  # For the uniqueness check
    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(invite)
    assert invite is not None
    assert invite.code == generated_code
    assert invite.group_id == group_model.id
    assert invite.created_by_id == user_model.id
    assert invite.is_active is True
    assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4)  # Check expiry is roughly correct
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """A code collision on the first attempt is retried with a fresh code.

    The two side_effect sequences pair up: first token collides (uniqueness
    query finds an id), second token is free.
    """
    colliding_code = "colliding_code"
    unique_code = "finally_unique_code"
    mock_token_urlsafe.side_effect = [colliding_code, unique_code]  # First call collides, second is unique
    # Mock DB execute for checking existing code
    mock_collision_check_result = AsyncMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # Simulate collision (ID found)
    mock_no_collision_check_result = AsyncMock()
    mock_no_collision_check_result.scalar_one_or_none.return_value = None  # No collision
    mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]
    invite = await create_invite(mock_db_session, group_model.id, user_model.id)
    assert mock_token_urlsafe.call_count == 2
    assert mock_db_session.execute.call_count == 2
    assert invite is not None
    assert invite.code == unique_code
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """When every generated code collides, create_invite gives up and returns None."""
    mock_token_urlsafe.return_value = "always_colliding_code"
    collision_result = AsyncMock()
    collision_result.scalar_one_or_none.return_value = 1  # every uniqueness check finds a clash
    mock_db_session.execute.return_value = collision_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert invite is None
    assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
    assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
    mock_db_session.add.assert_not_called()
# --- get_active_invite_by_code Tests ---
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
    """An active, unexpired invite is returned by its code."""
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = db_invite_model
    mock_db_session.execute.return_value = query_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    mock_db_session.execute.assert_called_once()
    assert invite is not None
    assert invite.code == db_invite_model.code
@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
    """An unknown code yields None."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = query_result

    assert await get_active_invite_by_code(mock_db_session, "non_existent_code") is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
    """A deactivated invite is filtered out by the query's WHERE clause."""
    db_invite_model.is_active = False  # deactivated
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = None  # the query excludes it
    mock_db_session.execute.return_value = query_result

    assert await get_active_invite_by_code(mock_db_session, db_invite_model.code) is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
    """An expired invite is filtered out by the query's WHERE clause."""
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1)  # already expired
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = None  # the query excludes it
    mock_db_session.execute.return_value = query_result

    assert await get_active_invite_by_code(mock_db_session, db_invite_model.code) is None
# --- deactivate_invite Tests ---
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
    """deactivate_invite flips is_active off and persists the change."""
    db_invite_model.is_active = True  # start from the active state

    result = await deactivate_invite(mock_db_session, db_invite_model)

    assert result.is_active is False
    mock_db_session.add.assert_called_once_with(db_invite_model)
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_invite_model)
# It might be useful to test DB error cases (OperationalError, etc.) for each function
# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that.
# create_invite has its own DB interaction within the loop, so that's covered.
# get_active_invite_by_code and deactivate_invite are simpler DB ops.
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from datetime import datetime, timezone
from app.crud.item import (
create_item,
get_items_by_list_id,
get_item_by_id,
update_item,
delete_item
)
from app.schemas.item import ItemCreate, ItemUpdate
from app.models import Item as ItemModel, User as UserModel, List as ListModel
from app.core.exceptions import (
ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the item CRUD tests."""
    session = AsyncMock()
    session.begin = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()  # add()/delete() are synchronous on AsyncSession
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Though not directly used in item.py, good for consistency
    session.flush = AsyncMock()
    return session
@pytest.fixture
def item_create_data():
    # Minimal valid payload for create_item.
    return ItemCreate(name="Test Item", quantity="1 pack")
@pytest.fixture
def item_update_data():
    # version=1 matches db_item_model's starting version for the happy path.
    return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)
@pytest.fixture
def user_model():
    # User acting on items (adder / completer).
    return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def list_model():
    # List the items belong to.
    return ListModel(id=1, name="Test List")
@pytest.fixture
def db_item_model(list_model, user_model):
    # A persisted-looking, incomplete item at optimistic-locking version 1.
    return ItemModel(
        id=1,
        name="Existing Item",
        quantity="1 unit",
        list_id=list_model.id,
        added_by_id=user_model.id,
        is_complete=False,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )
# --- create_item Tests ---
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
    """create_item persists a new incomplete item at version 1 and refreshes it."""
    async def mock_refresh(instance):
        # Simulate the DB assigning server-side columns on refresh.
        instance.id = 10  # Simulate ID assignment
        instance.version = 1  # Simulate version init
        instance.created_at = datetime.now(timezone.utc)
        instance.updated_at = datetime.now(timezone.utc)
        return None
    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
    created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(created_item)
    assert created_item is not None
    assert created_item.name == item_create_data.name
    assert created_item.list_id == list_model.id
    assert created_item.added_by_id == user_model.id
    assert created_item.is_complete is False
    assert created_item.version == 1
@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
    """A flush-time IntegrityError surfaces as DatabaseIntegrityError with a rollback."""
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")

    with pytest.raises(DatabaseIntegrityError):
        await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)

    mock_db_session.rollback.assert_called_once()
# --- get_items_by_list_id Tests ---
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
    """All items belonging to the list are returned."""
    query_result = AsyncMock()
    query_result.scalars.return_value.all.return_value = [db_item_model]
    mock_db_session.execute.return_value = query_result

    items = await get_items_by_list_id(mock_db_session, list_model.id)

    mock_db_session.execute.assert_called_once()
    assert len(items) == 1
    assert items[0].id == db_item_model.id
# --- get_item_by_id Tests ---
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
    """An existing item id resolves to the item."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = db_item_model
    mock_db_session.execute.return_value = query_result

    fetched = await get_item_by_id(mock_db_session, db_item_model.id)

    assert fetched is not None
    assert fetched.id == db_item_model.id
@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
    """An unknown item id resolves to None."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = query_result

    assert await get_item_by_id(mock_db_session, 999) is None
# --- update_item Tests ---
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
    """update_item applies field changes, bumps the version and records the completer.

    Fix: the old version assertion compared the returned item with itself
    (update_item mutates db_item_model in place), so it could never fail.
    Assert the concrete expected version (1 -> 2) instead, mirroring
    test_update_item_set_incomplete.
    """
    item_update_data.version = db_item_model.version  # Match versions for a successful update
    item_update_data.name = "Newly Updated Name"
    item_update_data.is_complete = True  # Exercise the completion branch

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    mock_db_session.add.assert_called_once_with(db_item_model)  # add() is used for dirty existing objects too
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_item_model)
    assert updated_item.name == "Newly Updated Name"
    assert updated_item.version == 2  # optimistic-locking version incremented from 1
    assert updated_item.is_complete is True
    assert updated_item.completed_by_id == user_model.id
@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
    """A stale version number triggers ConflictError and a rollback."""
    item_update_data.version = db_item_model.version + 1  # deliberately mismatched

    with pytest.raises(ConflictError):
        await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    mock_db_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
    """Marking a completed item incomplete clears completed_by_id and bumps the version."""
    db_item_model.is_complete = True  # Start as complete
    db_item_model.completed_by_id = user_model.id
    db_item_model.version = 1
    item_update_data.version = 1
    item_update_data.is_complete = False
    item_update_data.name = db_item_model.name  # No name change for this test
    item_update_data.quantity = db_item_model.quantity
    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
    assert updated_item.is_complete is False
    assert updated_item.completed_by_id is None
    assert updated_item.version == 2
# --- delete_item Tests ---
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
    """delete_item removes the row and commits (via `async with db.begin()`)."""
    assert await delete_item(mock_db_session, db_item_model) is None
    mock_db_session.delete.assert_called_once_with(db_item_model)
    mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
    """An OperationalError during delete surfaces as DatabaseConnectionError."""
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")

    with pytest.raises(DatabaseConnectionError):
        await delete_item(mock_db_session, db_item_model)

    mock_db_session.rollback.assert_called_once()
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function.
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import func as sql_func # For get_list_status
from datetime import datetime, timedelta, timezone
from app.crud.list import (
create_list,
get_lists_for_user,
get_list_by_id,
update_list,
delete_list,
check_list_permission,
get_list_status
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the list CRUD tests."""
    session = AsyncMock()
    session.begin = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()  # add()/delete() are synchronous on AsyncSession
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Used by check_list_permission via get_list_by_id
    session.flush = AsyncMock()
    return session
@pytest.fixture
def list_create_data():
    # Minimal valid payload for create_list.
    return ListCreate(name="New Shopping List", description="Groceries for the week")
@pytest.fixture
def list_update_data():
    # version=1 matches the list fixtures' starting version for the happy path.
    return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)
@pytest.fixture
def user_model():
    # The list creator.
    return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def another_user_model():
    # A second user used for permission checks (non-creator).
    return UserModel(id=2, name="Another User", email="another@example.com")
@pytest.fixture
def group_model():
    # Group that owns the shared list fixture.
    return GroupModel(id=1, name="Test Group")
@pytest.fixture
def db_list_personal_model(user_model):
    # Personal (non-group) list owned by user_model, version 1, no items.
    return ListModel(
        id=1, name="Personal List", created_by_id=user_model.id, creator=user_model,
        version=1, updated_at=datetime.now(timezone.utc), items=[]
    )
@pytest.fixture
def db_list_group_model(user_model, group_model):
    # Group-shared list created by user_model, version 1, no items.
    return ListModel(
        id=2, name="Group List", created_by_id=user_model.id, creator=user_model,
        group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[]
    )
# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
    """create_list persists a new list and refreshes DB-generated fields.

    Fix: the original set refresh.return_value and then immediately overrode
    it with a side_effect; the redundant assignment is dropped and the side
    effect is installed the same way as in the sibling CRUD test modules.
    """
    async def mock_refresh(instance):
        # Simulate the DB populating server-side columns on refresh.
        instance.id = 100
        instance.version = 1
        instance.updated_at = datetime.now(timezone.utc)

    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)

    created_list = await create_list(mock_db_session, list_create_data, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_list.name == list_create_data.name
    assert created_list.created_by_id == user_model.id
# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
    """Both personal and group lists are returned for a user.

    NOTE(review): the execute() side_effect assumes two queries in this exact
    order — group-id lookup first, then the lists query; confirm against the
    CRUD implementation.
    """
    # Simulate user is part of group for db_list_group_model
    mock_group_ids_result = AsyncMock()
    mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]
    mock_lists_result = AsyncMock()
    # Order should be personal list (created by user_id) then group list
    mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]
    mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result]
    lists = await get_lists_for_user(mock_db_session, user_model.id)
    assert len(lists) == 2
    assert db_list_personal_model in lists
    assert db_list_group_model in lists
    assert mock_db_session.execute.call_count == 2
# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
    """With load_items=False the list itself is still returned."""
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = query_result

    fetched = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)

    assert fetched is not None
    assert fetched.id == db_list_personal_model.id
    # Verifying the absence of selectinload(items) would require inspecting
    # the generated query options, which this mock setup does not expose.
@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
    """With load_items=True the returned list carries its items."""
    db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]  # pretend items were eager-loaded
    query_result = AsyncMock()
    query_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = query_result

    fetched = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)

    assert fetched is not None
    assert len(fetched.items) == 1
    # The query options should include selectinload for items.
# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
    """update_list applies changes and bumps the optimistic-locking version.

    Fix: the old version assertion compared the returned list with itself
    (update_list mutates db_list_personal_model in place), so it was vacuous;
    assert the concrete expected version (1 -> 2) instead.
    """
    list_update_data.version = db_list_personal_model.version  # Match version

    updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)

    assert updated_list.name == list_update_data.name
    assert updated_list.version == 2  # incremented from the fixture's version=1
    mock_db_session.add.assert_called_once_with(db_list_personal_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_list_personal_model)
@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
    """A stale version triggers ConflictError and a rollback."""
    list_update_data.version = db_list_personal_model.version + 1  # deliberately stale

    with pytest.raises(ConflictError):
        await update_list(mock_db_session, db_list_personal_model, list_update_data)

    mock_db_session.rollback.assert_called_once()
# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
    """delete_list removes the row and commits (via `async with db.begin()`)."""
    await delete_list(mock_db_session, db_list_personal_model)
    mock_db_session.delete.assert_called_once_with(db_list_personal_model)
    mock_db_session.commit.assert_called_once()
# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
    """The creator of a personal list always has access."""
    # check_list_permission fetches the list via get_list_by_id -> execute().
    fetch_result = AsyncMock()
    fetch_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = fetch_result

    result = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)

    assert result.id == db_list_personal_model.id
@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, user_model, another_user_model, group_model):
    """A non-creator who is a member of the list's group is granted access.

    Fix: the test referenced the `user_model` fixture without requesting it,
    which raised NameError at runtime; it is now a parameter. The creator is
    also assigned via `created_by_id` (the column the fixtures use) rather
    than a stray `creator_id` attribute.
    """
    # another_user_model is not the creator but belongs to the group.
    db_list_group_model.created_by_id = user_model.id
    db_list_group_model.creator = user_model
    # Mock the get_list_by_id internal call.
    mock_list_fetch_result = AsyncMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
    # Mock the is_user_member call.
    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True  # another_user_model is a member
        mock_db_session.execute.return_value = mock_list_fetch_result
        ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
        assert ret_list.id == db_list_group_model.id
        mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, user_model, another_user_model):
    """A user who is neither creator nor group member is denied access.

    Fix: `user_model` was referenced without being requested as a fixture
    (NameError at runtime), and the creator is set through the column the
    fixtures actually use (`created_by_id`), not a stray `creator_id`.
    """
    db_list_group_model.created_by_id = user_model.id  # Creator is not another_user_model
    mock_list_fetch_result = AsyncMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False  # another_user_model is NOT a member
        mock_db_session.execute.return_value = mock_list_fetch_result
        with pytest.raises(ListPermissionError):
            await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
    """A nonexistent list id raises ListNotFoundError."""
    fetch_result = AsyncMock()
    fetch_result.scalars.return_value.first.return_value = None  # no such list
    mock_db_session.execute.return_value = fetch_result

    with pytest.raises(ListNotFoundError):
        await check_list_permission(mock_db_session, 999, user_model.id)
# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
    """get_list_status aggregates the list timestamp, latest item timestamp and item count."""
    # Local import: this sub-file's module header is out of view here.
    from datetime import timedelta

    # Bug fix: `timedelta` lives in the `datetime` module, not on `timezone`;
    # the original `timezone.timedelta(hours=1)` raises AttributeError at runtime.
    list_updated_at = datetime.now(timezone.utc) - timedelta(hours=1)
    item_updated_at = datetime.now(timezone.utc)
    item_count = 5
    db_list_personal_model.updated_at = list_updated_at
    # Mock for the ListModel.updated_at query.
    mock_list_updated_result = AsyncMock()
    mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at
    # Mock for the ItemModel status query.
    mock_item_status_result = AsyncMock()
    # SQLAlchemy's func.max / func.count query returns a Row-like object or None.
    mock_item_status_row = MagicMock()
    mock_item_status_row.latest_item_updated_at = item_updated_at
    mock_item_status_row.item_count = item_count
    mock_item_status_result.first.return_value = mock_item_status_row
    # Two execute() calls happen in this order: list lookup, then item stats.
    mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result]
    status = await get_list_status(mock_db_session, db_list_personal_model.id)
    assert status.list_updated_at == list_updated_at
    assert status.latest_item_updated_at == item_updated_at
    assert status.item_count == item_count
    assert mock_db_session.execute.call_count == 2
@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
    """Requesting status for a nonexistent list raises ListNotFoundError."""
    missing_result = AsyncMock()
    missing_result.scalar_one_or_none.return_value = None  # list lookup came back empty
    mock_db_session.execute.return_value = missing_result
    with pytest.raises(ListNotFoundError):
        await get_list_status(mock_db_session, 999)
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
# TODO: Test check_list_permission with require_creator=True cases.
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList
from app.crud.settlement import (
create_settlement,
get_settlement_by_id,
get_settlements_for_group,
get_settlements_involving_user,
update_settlement,
delete_settlement
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
# Fixtures
@pytest.fixture
def mock_db_session():
    """Async SQLAlchemy session stand-in exposing the methods the CRUD layer touches."""
    session = AsyncMock()
    # Coroutine methods awaited by the settlement CRUD functions.
    for coro_name in ("commit", "rollback", "refresh", "execute", "get"):
        setattr(session, coro_name, AsyncMock())
    # add/delete are synchronous on AsyncSession, so plain MagicMocks.
    session.add = MagicMock()
    session.delete = MagicMock()
    return session
@pytest.fixture
def settlement_create_data():
    # Valid creation payload: user 1 pays user 2 an amount of 10.50 within group 1.
    return SettlementCreate(
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Test settlement"
    )
@pytest.fixture
def settlement_update_data():
    # Update payload touching only mutable fields; `version` must match the
    # stored row for the optimistic-locking check (tests adjust it as needed).
    return SettlementUpdate(
        description="Updated settlement description",
        settlement_date=datetime.now(timezone.utc),
        version=1  # Assuming version is required for update
    )
@pytest.fixture
def db_settlement_model():
    # Fully-populated "persisted" settlement (id=1) with in-memory payer,
    # payee and group relationship objects, at optimistic-lock version 1.
    return SettlementModel(
        id=1,
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Original settlement",
        version=1,  # Initial version
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        payer=UserModel(id=1, name="Payer User"),
        payee=UserModel(id=2, name="Payee User"),
        group=GroupModel(id=1, name="Test Group")
    )
@pytest.fixture
def payer_user_model():
    # User 1 — the paying side of the settlement.
    return UserModel(id=1, name="Payer User", email="payer@example.com")
@pytest.fixture
def payee_user_model():
    # User 2 — the receiving side of the settlement.
    return UserModel(id=2, name="Payee User", email="payee@example.com")
@pytest.fixture
def group_model():
    # Group 1 — the group both settlement parties belong to.
    return GroupModel(id=1, name="Test Group")
# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    """Happy path: payer, payee and group all exist and the settlement is persisted."""
    # db.get is invoked for payer, payee, then group — in that order.
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    result = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
    assert result is not None
    expected_amount = settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    assert result.amount == expected_amount
    assert result.paid_by_user_id == settlement_create_data.paid_by_user_id
    assert result.paid_to_user_id == settlement_create_data.paid_to_user_id
    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once()
@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data, payee_user_model, group_model):
    """A missing payer raises UserNotFoundError identifying the payer."""
    # Bug fix: `payee_user_model` and `group_model` were used without being
    # declared as fixture parameters, so the raw fixture *functions* were fed
    # into side_effect instead of model instances.
    mock_db_session.get.side_effect = [None, payee_user_model, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model, group_model):
    """A missing payee raises UserNotFoundError identifying the payee."""
    # Bug fix: `group_model` was used without being declared as a fixture
    # parameter, so the fixture function itself went into side_effect.
    mock_db_session.get.side_effect = [payer_user_model, None, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payee" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
    """If the target group does not exist, GroupNotFoundError is raised."""
    # Both users resolve; the third lookup (the group) returns nothing.
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
    with pytest.raises(GroupNotFoundError):
        await create_settlement(mock_db_session, settlement_create_data, 1)
@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
    """Paying yourself is rejected with InvalidOperationError."""
    settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
    # Both user lookups resolve to the same record.
    mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
    with pytest.raises(InvalidOperationError) as exc:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer and Payee cannot be the same user" in str(exc.value)
@pytest.mark.asyncio
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    """A commit failure rolls the transaction back and surfaces InvalidOperationError."""
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as exc:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Failed to save settlement" in str(exc.value)
    mock_db_session.rollback.assert_called_once()
# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
    """Lookup by id returns the matching settlement."""
    scalars = mock_db_session.execute.return_value.scalars.return_value
    scalars.first.return_value = db_settlement_model
    found = await get_settlement_by_id(mock_db_session, 1)
    assert found is not None
    assert found.id == 1
    mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
    """Lookup of an unknown id yields None."""
    scalars = mock_db_session.execute.return_value.scalars.return_value
    scalars.first.return_value = None
    assert await get_settlement_by_id(mock_db_session, 999) is None
# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
    """All settlements belonging to a group are returned."""
    query_result = mock_db_session.execute.return_value.scalars.return_value
    query_result.all.return_value = [db_settlement_model]
    rows = await get_settlements_for_group(mock_db_session, group_id=1)
    assert len(rows) == 1
    assert rows[0].group_id == 1
    mock_db_session.execute.assert_called_once()
# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
    """Settlements where the user appears as payer or payee are returned."""
    query_result = mock_db_session.execute.return_value.scalars.return_value
    query_result.all.return_value = [db_settlement_model]
    rows = await get_settlements_involving_user(mock_db_session, user_id=1)
    assert len(rows) == 1
    assert 1 in (rows[0].paid_by_user_id, rows[0].paid_to_user_id)
    mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
    """The optional group filter still yields the user's settlements."""
    query_result = mock_db_session.execute.return_value.scalars.return_value
    query_result.all.return_value = [db_settlement_model]
    rows = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
    # Asserting on the generated SQL filter would require introspecting the
    # statement passed to execute(); we only verify the call shape here.
    assert len(rows) == 1
    mock_db_session.execute.assert_called_once()
# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
    # Ensure settlement_update_data.version matches db_settlement_model.version
    settlement_update_data.version = db_settlement_model.version
    # Mock datetime.now(); `wraps=datetime` keeps the rest of the datetime API
    # working while pinning now() so updated_at can be asserted exactly.
    fixed_datetime_now = datetime.now(timezone.utc)
    with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime:
        mock_datetime.now.return_value = fixed_datetime_now
        updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
        mock_db_session.commit.assert_called_once()
        mock_db_session.refresh.assert_called_once()
        assert updated_settlement.description == settlement_update_data.description
        assert updated_settlement.settlement_date == settlement_update_data.settlement_date
        # NOTE(review): if update_settlement mutates db_settlement_model in
        # place, both sides of this comparison read the same object and the
        # assertion cannot hold — confirm update_settlement returns a
        # separate/refetched row.
        assert updated_settlement.version == db_settlement_model.version + 1  # Version incremented
        assert updated_settlement.updated_at == fixed_datetime_now
@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
    """A stale/mismatched version is rejected with InvalidOperationError."""
    settlement_update_data.version = db_settlement_model.version + 1  # deliberately stale
    with pytest.raises(InvalidOperationError) as exc:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "version does not match" in str(exc.value)
@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
    """Immutable fields such as `amount` cannot be changed through update."""
    bad_payload = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
    with pytest.raises(InvalidOperationError) as exc:
        await update_settlement(mock_db_session, db_settlement_model, bad_payload)
    assert "Field 'amount' cannot be updated" in str(exc.value)
@pytest.mark.asyncio
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
    """Commit errors during update trigger rollback and InvalidOperationError."""
    settlement_update_data.version = db_settlement_model.version
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as exc:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "Failed to update settlement" in str(exc.value)
    mock_db_session.rollback.assert_called_once()
# Tests for delete_settlement
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
    """Plain delete removes the row and commits."""
    await delete_settlement(mock_db_session, db_settlement_model)
    mock_db_session.commit.assert_called_once()
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
    """Delete succeeds when the caller's expected version matches the stored one."""
    current_version = db_settlement_model.version
    await delete_settlement(mock_db_session, db_settlement_model, expected_version=current_version)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
    """A wrong expected version aborts the delete before touching the session."""
    stale_version = db_settlement_model.version + 1
    with pytest.raises(InvalidOperationError) as exc:
        await delete_settlement(mock_db_session, db_settlement_model, expected_version=stale_version)
    message = str(exc.value)
    assert "Expected version" in message
    assert "does not match current version" in message
    mock_db_session.delete.assert_not_called()
@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
    """Commit errors during delete trigger rollback and InvalidOperationError."""
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as exc:
        await delete_settlement(mock_db_session, db_settlement_model)
    assert "Failed to delete settlement" in str(exc.value)
    mock_db_session.rollback.assert_called_once()
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.exc import IntegrityError, OperationalError
from app.crud.user import get_user_by_email, create_user
from app.schemas.user import UserCreate
from app.models import User as UserModel
from app.core.exceptions import (
UserCreationError,
EmailAlreadyRegisteredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
)
# Fixtures
@pytest.fixture
def mock_db_session():
    """Bare AsyncMock session; each test wires up only the attributes it needs."""
    return AsyncMock()
@pytest.fixture
def user_create_data():
    # Valid signup payload passed to create_user.
    return UserCreate(email="test@example.com", password="password123", name="Test User")
@pytest.fixture
def existing_user_data():
    # Already-persisted user row, used for lookup tests.
    return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")
# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
    """An existing email resolves to its user row."""
    scalars = mock_db_session.execute.return_value.scalars.return_value
    scalars.first.return_value = existing_user_data
    user = await get_user_by_email(mock_db_session, "exists@example.com")
    assert user is not None
    assert user.email == "exists@example.com"
    mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
    """An unknown email yields None."""
    scalars = mock_db_session.execute.return_value.scalars.return_value
    scalars.first.return_value = None
    assert await get_user_by_email(mock_db_session, "nonexistent@example.com") is None
    mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_db_connection_error(mock_db_session):
    """OperationalError from the driver maps to DatabaseConnectionError."""
    mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await get_user_by_email(mock_db_session, "test@example.com")
@pytest.mark.asyncio
async def test_get_user_by_email_db_query_error(mock_db_session):
    """A non-operational SQLAlchemyError maps to DatabaseQueryError."""
    # IntegrityError stands in for a generic SQLAlchemyError that is not
    # an OperationalError.
    mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig")
    with pytest.raises(DatabaseQueryError):
        await get_user_by_email(mock_db_session, "test@example.com")
# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
    # The actual user object returned would be created by SQLAlchemy based on db_user.
    # We mock the process: db.add is called, then db.flush, then db.refresh updates db_user.
    async def mock_refresh(user_model_instance):
        # Simulate the DB assigning a primary key during refresh.
        user_model_instance.id = 1  # Simulate DB assigning an ID
        # Simulate other db-generated fields if necessary
        return None
    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
    mock_db_session.flush = AsyncMock()
    mock_db_session.add = MagicMock()
    created_user = await create_user(mock_db_session, user_create_data)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_user is not None
    assert created_user.email == user_create_data.email
    assert created_user.name == user_create_data.name
    assert hasattr(created_user, 'id')  # Check if ID was assigned (simulated by mock_refresh)
    # Password hash check would be more involved; ensure hash_password was called correctly.
    # For now, we assume hash_password works as intended and is tested elsewhere.
@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
    """An IntegrityError whose message mentions a unique constraint maps to EmailAlreadyRegisteredError."""
    duplicate_error = IntegrityError("mock error (unique constraint)", "params", "orig")
    mock_db_session.flush.side_effect = duplicate_error
    with pytest.raises(EmailAlreadyRegisteredError):
        await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
    """IntegrityErrors unrelated to unique constraints map to DatabaseIntegrityError."""
    failure = IntegrityError("mock error (not unique constraint)", "params", "orig")
    mock_db_session.flush.side_effect = failure
    with pytest.raises(DatabaseIntegrityError):
        await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
    # OperationalError can surface either when the transaction begins or on
    # flush; both paths must map to DatabaseConnectionError.
    mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)
    # also test OperationalError on flush
    mock_db_session.begin.side_effect = None  # reset side effect
    mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
    # Simulate a generic SQLAlchemyError on flush that is not IntegrityError or OperationalError
    # NOTE(review): UserCreationError is a domain exception, not a
    # SQLAlchemyError as the comment above claims; this only maps to
    # DatabaseTransactionError if create_user catches exceptions broadly.
    # Consider raising sqlalchemy.exc.SQLAlchemyError here — confirm intent.
    mock_db_session.flush.side_effect = UserCreationError("Simulated non-specific SQLAlchemyError")  # Or any other SQLAlchemyError
    with pytest.raises(DatabaseTransactionError):
        await create_user(mock_db_session, user_create_data)
# Polished PWA Plan: Shared Lists & Household Management
## 1. Product Overview
**Concept:**
Develop a Progressive Web App (PWA) focused on simplifying household coordination. Users can:
- Create, manage, and **share** shopping lists within defined groups (e.g., households, trip members).
- Capture images of receipts or shopping lists via the browser and extract items using **Google Cloud Vision API** for OCR.
- Track item costs on shared lists and easily split expenses among group participants.
- (Future) Manage and assign household chores.
**Target Audience:** Households, roommates, families, groups organizing shared purchases.
**UX Philosophy:**
- **User-Centered & Collaborative:** Design intuitive flows for both individual use and group collaboration with minimal friction.
- **Native-like PWA Experience:** Leverage service workers, caching, and manifest files for reliable offline use, installability, and smooth performance.
- **Clarity & Accessibility:** Prioritize high contrast, legible typography, straightforward navigation, and adherence to accessibility standards (WCAG).
- **Informative Feedback:** Provide clear visual feedback for actions (animations, loading states), OCR processing status, and data synchronization, including handling potential offline conflicts gracefully.
---
## 2. MVP Scope (Refined & Focused)
The MVP will focus on delivering a robust, shareable shopping list experience with integrated OCR and cost splitting, built as a high-quality PWA. **Chore management is deferred post-MVP** to ensure a polished core experience at launch.
1. **Shared Shopping List Management:**
* **Core Features:** Create, update, delete lists and items. Mark items as complete. Basic item sorting/reordering (e.g., manual drag-and-drop).
* **Collaboration:** Share lists within user-defined groups. Real-time (or near real-time) updates visible to group members (via polling or simple WebSocket for MVP).
* **PWA/UX:** Responsive design, offline access to cached lists, basic conflict indication if offline edits clash (e.g., "Item updated by another user, refresh needed").
2. **OCR Integration (Google Cloud Vision):**
* **Core Features:** Capture images via browser (`<input type="file" capture>` or `getUserMedia`). Upload images to the FastAPI backend. Backend securely calls **Google Cloud Vision API (Text Detection / Document Text Detection)**. Process results, suggest items to add to the list.
* **PWA/UX:** Clear instructions for image capture. Progress indicators during upload/processing. Display editable OCR results for user review and confirmation before adding to the list. Handle potential API errors or low-confidence results gracefully.
3. **Cost Splitting (Integrated with Lists):**
* **Core Features:** Assign prices to items *on the shopping list* as they are purchased. Add participants (from the shared group) to a list's expense split. Calculate totals per list and simple equal splits per participant.
* **PWA/UX:** Clear display of totals and individual shares. Easy interface for marking items as bought and adding their price.
4. **User Authentication & Group Management:**
* **Core Features:** Secure email/password signup & login (JWT-based). Ability to create simple groups (e.g., "Household"). Mechanism to invite/add users to a group (e.g., unique invite code/link). Basic role distinction (e.g., group owner/admin, member) if necessary for managing participants.
* **PWA/UX:** Minimalist forms, clear inline validation, smooth onboarding explaining the group concept.
5. **Core PWA Functionality:**
* **Core Features:** Installable via `manifest.json`. Offline access via service worker caching (app shell, static assets, user data). Basic background sync strategy for offline actions (e.g., "last write wins" for simple edits, potentially queueing adds/deletes).
---
## 3. Feature Breakdown & UX Enhancements (MVP Focus)
### A. Shared Shopping Lists
- **Screens:** Dashboard (list overview), List Detail (items), Group Management.
- **Flows:** Create list -> (Optional) Share with group -> Add/edit/check items -> See updates from others -> Mark list complete.
- **UX Focus:** Smooth transitions, clear indication of shared status, offline caching, simple conflict notification (not full resolution in MVP).
### B. OCR with Google Cloud Vision
- **Flow:** Tap "Add via OCR" -> Capture/Select Image -> Upload -> Show Progress -> Display Review Screen (editable text boxes for potential items) -> User confirms/edits -> Items added to list.
- **UX Focus:** Clear instructions, robust error handling (API errors, poor image quality feedback if possible), easy correction interface, manage user expectations regarding OCR accuracy. Monitor API costs/quotas.
### C. Integrated Cost Splitting
- **Flow:** Open shared list -> Mark item "bought" -> Input price -> View updated list total -> Go to "Split Costs" view for the list -> Confirm participants (group members) -> See calculated equal split.
- **UX Focus:** Seamless transition from shopping to cost entry. Clear, real-time calculation display. Simple participant management within the list context.
### D. User Auth & Groups
- **Flow:** Sign up/Login -> Create a group -> Invite members (e.g., share code) -> Member joins group -> Access shared lists.
- **UX Focus:** Secure and straightforward auth. Simple group creation and joining process. Clear visibility of group members.
### E. PWA Essentials
- **Manifest:** Define app name, icons, theme, display mode.
- **Service Worker:** Cache app shell, assets, API responses (user data). Implement basic offline sync queue for actions performed offline (e.g., adding/checking items). Define a clear sync conflict strategy (e.g., last-write-wins, notify user on conflict).
---
## 4. Architecture & Technology Stack
### Frontend: Svelte PWA
- **Framework:** Svelte/SvelteKit (Excellent for performant, component-based PWAs).
- **State Management:** Svelte Stores for managing UI state and cached data.
- **PWA Tools:** Workbox.js (via SvelteKit integration or standalone) for robust service worker generation and caching strategies.
- **Styling:** Tailwind CSS or standard CSS with scoped styles.
- **UX:** Design system (e.g., using Figma), Storybook for component development.
### Backend: FastAPI & PostgreSQL
- **Framework:** FastAPI (High performance, async support, auto-docs, Pydantic validation).
- **Database:** PostgreSQL (Reliable, supports JSONB for flexibility if needed). Schema designed to handle users, groups, lists, items, costs, and relationships. Basic indexing on foreign keys and frequently queried fields (user IDs, group IDs, list IDs).
- **ORM:** SQLAlchemy (async support with v2.0+) or Tortoise ORM (async-native). Alembic for migrations.
- **OCR Integration:** Use the official **Google Cloud Client Libraries for Python** to interact with the Vision API. Implement robust error handling, retries, and potentially rate limiting/cost control logic. Ensure API calls are `async` to avoid blocking.
- **Authentication:** JWT tokens for stateless session management.
- **Deployment:** Containerize using Docker/Docker Compose for development and deployment consistency. Deploy on a scalable cloud platform (e.g., Google Cloud Run, AWS Fargate, DigitalOcean App Platform).
- **Monitoring:** Logging (standard Python logging), Error Tracking (Sentry), Performance Monitoring (Prometheus/Grafana if needed later).
---
# Finalized User Stories, Flow Mapping, Sharing Model & Sync, Tech Stack & Initial Architecture Diagram
## 1. User Stories
### Authentication & User Management
- As a new user, I want to sign up with my email so I can create and manage shopping lists
- As a returning user, I want to log in securely to access my lists and groups
- As a user, I want to reset my password if I forget it
- As a user, I want to edit my profile information (name, avatar)
### Group Management
- As a user, I want to create a new group (e.g., "Household", "Roommates") to organize shared lists
- As a group creator, I want to invite others to join my group via a shareable link/code
- As an invitee, I want to easily join a group by clicking a link or entering a code
- As a group owner, I want to remove members if needed
- As a user, I want to leave a group I no longer wish to be part of
- As a user, I want to see all groups I belong to and switch between them
### List Management
- As a user, I want to create a personal shopping list with a title and optional description
- As a user, I want to share a list with a specific group so members can collaborate
- As a user, I want to view all my lists (personal and shared) from a central dashboard
- As a user, I want to archive or delete lists I no longer need
- As a user, I want to mark a list as "shopping complete" when finished
- As a user, I want to see which group a list is shared with
### Item Management
- As a user, I want to add items to a list with names and optional quantities
- As a user, I want to mark items as purchased when shopping
- As a user, I want to edit item details (name, quantity, notes)
- As a user, I want to delete items from a list
- As a user, I want to reorder items on my list for shopping efficiency
- As a user, I want to see who added or marked items as purchased in shared lists
### OCR Integration
- As a user, I want to capture a photo of a physical shopping list or receipt
- As a user, I want the app to extract text and convert it into list items
- As a user, I want to review and edit OCR results before adding to my list
- As a user, I want clear feedback on OCR processing status
- As a user, I want to retry OCR if the results aren't satisfactory
### Cost Splitting
- As a user, I want to add prices to items as I purchase them
- As a user, I want to see the total cost of all purchased items in a list
- As a user, I want to split costs equally among group members
- As a user, I want to see who owes what amount based on the split
- As a user, I want to mark expenses as settled
### PWA & Offline Experience
- As a user, I want to install the app on my home screen for quick access
- As a user, I want to view and edit my lists even when offline
- As a user, I want my changes to sync automatically when I'm back online
- As a user, I want to be notified if my offline changes conflict with others' changes
# docker-compose.yml (in project root)
version: '3.8'
services:
db:
image: postgres:15 # Use a specific PostgreSQL version
container_name: postgres_db
environment:
POSTGRES_USER: dev_user # Define DB user
POSTGRES_PASSWORD: dev_password # Define DB password
POSTGRES_DB: dev_db # Define Database name
volumes:
- postgres_data:/var/lib/postgresql/data # Persist data using a named volume
ports:
- "5432:5432" # Expose PostgreSQL port to host (optional, for direct access)
healthcheck:
test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
restart: unless-stopped
backend:
container_name: fastapi_backend
build:
context: ./be # Path to the directory containing the Dockerfile
dockerfile: Dockerfile
volumes:
# Mount local code into the container for development hot-reloading
# The code inside the container at /app will mirror your local ./be directory
- ./be:/app
ports:
- "8000:8000" # Map container port 8000 to host port 8000
environment:
# Pass the database URL to the backend container
# Uses the service name 'db' as the host, and credentials defined above
# IMPORTANT: Use the correct async driver prefix if your app needs it!
- DATABASE_URL=postgresql+asyncpg://dev_user:dev_password@db:5432/dev_db
# Add other environment variables needed by the backend here
# - SOME_OTHER_VAR=some_value
depends_on:
db: # Wait for the db service to be healthy before starting backend
condition: service_healthy
command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] # Override CMD for development reload
restart: unless-stopped
pgadmin: # Optional service for database administration
image: dpage/pgadmin4:latest
container_name: pgadmin4_server
environment:
PGADMIN_DEFAULT_EMAIL: admin@example.com # Change as needed
PGADMIN_DEFAULT_PASSWORD: admin_password # Change to a secure password
PGADMIN_CONFIG_SERVER_MODE: 'False' # Run in Desktop mode for easier local dev server setup
volumes:
- pgadmin_data:/var/lib/pgadmin # Persist pgAdmin configuration
ports:
- "5050:80" # Map container port 80 to host port 5050
depends_on:
- db # Depends on the database service
restart: unless-stopped
volumes: # Define named volumes for data persistence
postgres_data:
pgadmin_data:
[*.{js,jsx,mjs,cjs,ts,tsx,mts,cts,vue}]
charset = utf-8
indent_size = 2
indent_style = space
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
{
"$schema": "https://json.schemastore.org/prettierrc",
"singleQuote": true,
"printWidth": 100
}
{
"recommendations": [
"dbaeumer.vscode-eslint",
"esbenp.prettier-vscode",
"editorconfig.editorconfig",
"vue.volar",
"wayou.vscode-todo-highlight"
],
"unwantedRecommendations": [
"octref.vetur",
"hookyqr.beautify",
"dbaeumer.jshint",
"ms-vscode.vscode-typescript-tslint-plugin"
]
}
{
"editor.bracketPairColorization.enabled": true,
"editor.guides.bracketPairs": true,
"editor.formatOnSave": true,
"editor.defaultFormatter": "esbenp.prettier-vscode",
"editor.codeActionsOnSave": [
"source.fixAll.eslint"
],
"eslint.validate": [
"javascript",
"javascriptreact",
"typescript",
"vue"
],
"typescript.tsdk": "node_modules/typescript/lib"
}
import js from '@eslint/js'
import globals from 'globals'
import pluginVue from 'eslint-plugin-vue'
import pluginQuasar from '@quasar/app-vite/eslint'
import { defineConfigWithVueTs, vueTsConfigs } from '@vue/eslint-config-typescript'
import prettierSkipFormatting from '@vue/eslint-config-prettier/skip-formatting'
// ESLint flat config for the Quasar + Vue 3 + TypeScript frontend.
export default defineConfigWithVueTs(
  {
    /**
     * Ignore the following files.
     * Please note that pluginQuasar.configs.recommended() already ignores
     * the "node_modules" folder for you (and all other Quasar project
     * relevant folders and files).
     *
     * ESLint requires "ignores" key to be the only one in this object
     */
    // ignores: []
  },

  pluginQuasar.configs.recommended(),
  js.configs.recommended,

  /**
   * https://eslint.vuejs.org
   *
   * pluginVue.configs.base
   *   -> Settings and rules to enable correct ESLint parsing.
   * pluginVue.configs[ 'flat/essential']
   *   -> base, plus rules to prevent errors or unintended behavior.
   * pluginVue.configs["flat/strongly-recommended"]
   *   -> Above, plus rules to considerably improve code readability and/or dev experience.
   * pluginVue.configs["flat/recommended"]
   *   -> Above, plus rules to enforce subjective community defaults to ensure consistency.
   */
  pluginVue.configs[ 'flat/essential' ],

  {
    // Enforce `import type { ... }` for type-only imports in TS and Vue SFCs.
    files: ['**/*.ts', '**/*.vue'],
    rules: {
      '@typescript-eslint/consistent-type-imports': [
        'error',
        { prefer: 'type-imports' }
      ],
    }
  },
  // https://github.com/vuejs/eslint-config-typescript
  vueTsConfigs.recommendedTypeChecked,

  {
    languageOptions: {
      ecmaVersion: 'latest',
      sourceType: 'module',

      globals: {
        ...globals.browser,
        ...globals.node, // SSR, Electron, config files
        process: 'readonly', // process.env.*
        ga: 'readonly', // Google Analytics
        cordova: 'readonly',
        Capacitor: 'readonly',
        chrome: 'readonly', // BEX related
        browser: 'readonly' // BEX related
      }
    },

    // add your custom rules here
    rules: {
      'prefer-promise-reject-errors': 'off',

      // allow debugger during development only
      'no-debugger': process.env.NODE_ENV === 'production' ? 'error' : 'off'
    }
  },

  {
    // Service-worker source runs in a worker scope, not the browser window.
    files: [ 'src-pwa/custom-service-worker.ts' ],
    languageOptions: {
      globals: {
        ...globals.serviceworker
      }
    }
  },

  prettierSkipFormatting
)
<%= productName %>
// https://github.com/michael-ciniawsky/postcss-load-config
import autoprefixer from 'autoprefixer'
// import rtlcss from 'postcss-rtlcss'
// PostCSS configuration: runs Autoprefixer over the build's CSS output.
export default {
  plugins: [
    // https://github.com/postcss/autoprefixer
    autoprefixer({
      // Target the last 4 releases of each major browser/platform.
      overrideBrowserslist: [
        'last 4 Chrome versions',
        'last 4 Firefox versions',
        'last 4 Edge versions',
        'last 4 Safari versions',
        'last 4 Android versions',
        'last 4 ChromeAndroid versions',
        'last 4 FirefoxAndroid versions',
        'last 4 iOS versions'
      ]
    }),

    // https://github.com/elchininet/postcss-rtlcss
    // If you want to support RTL css, then
    // 1. yarn/pnpm/bun/npm install postcss-rtlcss
    // 2. optionally set quasar.config.js > framework > lang to an RTL language
    // 3. uncomment the following line (and its import statement above):
    // rtlcss()
  ]
}
{
"orientation": "portrait",
"background_color": "#ffffff",
"theme_color": "#027be3",
"icons": [
{
"src": "icons/icon-128x128.png",
"sizes": "128x128",
"type": "image/png"
},
{
"src": "icons/icon-192x192.png",
"sizes": "192x192",
"type": "image/png"
},
{
"src": "icons/icon-256x256.png",
"sizes": "256x256",
"type": "image/png"
},
{
"src": "icons/icon-384x384.png",
"sizes": "384x384",
"type": "image/png"
},
{
"src": "icons/icon-512x512.png",
"sizes": "512x512",
"type": "image/png"
}
]
}
// Ambient typings for PWA-related env vars Quasar injects at build time.
declare namespace NodeJS {
  interface ProcessEnv {
    SERVICE_WORKER_FILE: string;      // path of the generated service worker file
    PWA_FALLBACK_HTML: string;        // document served when offline
    PWA_SERVICE_WORKER_REGEX: string; // pattern used to match service-worker requests
  }
}
import { register } from 'register-service-worker';
// The ready(), registered(), cached(), updatefound() and updated()
// events passes a ServiceWorkerRegistration instance in their arguments.
// ServiceWorkerRegistration: https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerRegistration
// Register the Quasar-generated service worker; the callbacks below are
// lifecycle hooks (all currently no-ops, kept as documented extension points).
register(process.env.SERVICE_WORKER_FILE, {
  // The registrationOptions object will be passed as the second argument
  // to ServiceWorkerContainer.register()
  // https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerContainer/register#Parameter
  // registrationOptions: { scope: './' },

  ready (/* registration */) {
    // console.log('Service worker is active.')
  },

  registered (/* registration */) {
    // console.log('Service worker has been registered.')
  },

  cached (/* registration */) {
    // console.log('Content has been cached for offline use.')
  },

  updatefound (/* registration */) {
    // console.log('New content is downloading.')
  },

  updated (/* registration */) {
    // console.log('New content is available; please refresh.')
  },

  offline () {
    // console.log('No internet connection found. App is running in offline mode.')
  },

  error (/* err */) {
    // console.error('Error during service worker registration:', err)
  },
});
{
"extends": "../tsconfig.json",
"compilerOptions": {
"lib": ["WebWorker", "ESNext"]
},
"include": ["*.ts", "*.d.ts"]
}
import { defineBoot } from '#q-app/wrappers';
import { createI18n } from 'vue-i18n';
import messages from 'src/i18n';
export type MessageLanguages = keyof typeof messages;
// Type-define 'en-US' as the master schema for the resource
export type MessageSchema = typeof messages['en-US'];
// See https://vue-i18n.intlify.dev/guide/advanced/typescript.html#global-resource-schema-type-definition
/* eslint-disable @typescript-eslint/no-empty-object-type */
declare module 'vue-i18n' {
// define the locale messages schema
export interface DefineLocaleMessage extends MessageSchema {}
// define the datetime format schema
export interface DefineDateTimeFormat {}
// define the number format schema
export interface DefineNumberFormat {}
}
/* eslint-enable @typescript-eslint/no-empty-object-type */
// Boot file: install vue-i18n on the app with 'en-US' as the default locale.
export default defineBoot(({ app }) => {
  const i18n = createI18n<{ message: MessageSchema }, MessageLanguages>({
    locale: 'en-US',
    legacy: false, // Composition API mode (no legacy Options API i18n)
    messages,
  });

  // Set i18n instance on app
  app.use(i18n);
});
{{ title }}{{ caption }}
// Example models from the Quasar starter scaffolding.
export interface Todo {
  id: number;      // unique identifier
  content: string; // display text
}

export interface Meta {
  totalCount: number; // total number of Todo items available
}
import { api } from 'boot/axios';
import { API_BASE_URL, API_VERSION, API_ENDPOINTS } from './api-config';
// Helper function to get full API URL.
// Builds "<base>/api/<version><endpoint>"; `endpoint` is expected to start with "/".
export const getApiUrl = (endpoint: string): string => {
  return `${API_BASE_URL}/api/${API_VERSION}${endpoint}`;
};

// Thin wrapper around the shared axios instance that prefixes every
// request path with the versioned API base URL.
export const apiClient = {
  get: (endpoint: string, config = {}) => api.get(getApiUrl(endpoint), config),
  post: (endpoint: string, data = {}, config = {}) => api.post(getApiUrl(endpoint), data, config),
  put: (endpoint: string, data = {}, config = {}) => api.put(getApiUrl(endpoint), data, config),
  patch: (endpoint: string, data = {}, config = {}) => api.patch(getApiUrl(endpoint), data, config),
  delete: (endpoint: string, config = {}) => api.delete(getApiUrl(endpoint), config),
};
export { API_ENDPOINTS };
// app global css in SCSS form
// Quasar SCSS (& Sass) Variables
// --------------------------------------------------
// To customize the look and feel of this app, you can override
// the Sass/SCSS variables found in Quasar's source Sass/SCSS files.
// Check documentation for full list of Quasar variables
// Your own variables (that are declared here) and Quasar's own
// ones will be available out of the box in your .vue/.scss/.sass files
// It's highly recommended to change the default colors
// to match your app's branding.
// Tip: Use the "Theme Builder" on Quasar's documentation website.
$primary : #1976D2;
$secondary : #26A69A;
$accent : #9C27B0;
$dark : #1D1D1D;
$dark-page : #121212;
$positive : #21BA45;
$negative : #C10015;
$info : #31CCEC;
$warning : #F2C037;
// Ambient typings for env vars Quasar defines for the Vue Router setup.
declare namespace NodeJS {
  interface ProcessEnv {
    NODE_ENV: string; // build-time environment name
    VUE_ROUTER_MODE: 'hash' | 'history' | 'abstract' | undefined; // router history mode
    VUE_ROUTER_BASE: string | undefined; // base path passed to the history factory
  }
}
// This is just an example,
// so you can safely delete all default props below.
// Default English (US) UI strings, keyed by message id.
export default {
  failed: 'Action failed',
  success: 'Action was successful'
};
import enUS from './en-US';

// Locale registry: maps locale codes to their message bundles.
export default {
  'en-US': enUS
};
404
Oops. Nothing here...
Login
Don't have an account? Sign up
Sign Up
Already have an account? Login
import { defineRouter } from '#q-app/wrappers';
import {
createMemoryHistory,
createRouter,
createWebHashHistory,
createWebHistory,
} from 'vue-router';
import routes from './routes';
import { useAuthStore } from 'stores/auth';
/*
* If not building with SSR mode, you can
* directly export the Router instantiation;
*
* The function below can be async too; either use
* async/await or return a Promise which resolves
* with the Router instance.
*/
export default defineRouter(function (/* { store, ssrContext } */) {
  // Pick the history implementation: in-memory on the server (SSR),
  // otherwise HTML5 or hash history per the build-time router mode.
  const createHistory = process.env.SERVER
    ? createMemoryHistory
    : process.env.VUE_ROUTER_MODE === 'history'
      ? createWebHistory
      : createWebHashHistory;

  const Router = createRouter({
    scrollBehavior: () => ({ left: 0, top: 0 }),
    routes,

    // Leave this as is and make changes in quasar.conf.js instead!
    // quasar.conf.js -> build -> vueRouterMode
    // quasar.conf.js -> build -> publicPath
    history: createHistory(process.env.VUE_ROUTER_BASE),
  });

  // Navigation guard to check authentication.
  // NOTE(review): useAuthStore() is called per navigation, so Pinia must be
  // installed before the first navigation — confirm boot/install order.
  Router.beforeEach((to, from, next) => {
    const authStore = useAuthStore();
    const isAuthenticated = authStore.isAuthenticated;

    // Define public routes that don't require authentication
    const publicRoutes = ['/login', '/signup'];

    // Any route not explicitly public requires authentication.
    const requiresAuth = !publicRoutes.includes(to.path);

    if (requiresAuth && !isAuthenticated) {
      // Redirect to login, preserving the intended destination so the
      // login page can send the user back after authenticating.
      next({ path: '/login', query: { redirect: to.fullPath } });
    } else if (!requiresAuth && isAuthenticated) {
      // Already signed in: keep users off the login/signup pages.
      next({ path: '/' });
    } else {
      // Proceed with navigation
      next();
    }
  });

  return Router;
});
import { defineStore } from '#q-app/wrappers'
import { createPinia } from 'pinia'
/*
* When adding new properties to stores, you should also
* extend the `PiniaCustomProperties` interface.
* @see https://pinia.vuejs.org/core-concepts/plugins.html#typing-new-store-properties
*/
declare module 'pinia' {
// eslint-disable-next-line @typescript-eslint/no-empty-object-type
export interface PiniaCustomProperties {
// add your custom properties here, if any
}
}
/*
* If not building with SSR mode, you can
* directly export the Store instantiation;
*
* The function below can be async too; either use
* async/await or return a Promise which resolves
* with the Store instance.
*/
// Create the app-wide Pinia instance (one fresh instance per SSR request).
export default defineStore((/* { ssrContext } */) => {
  const pinia = createPinia()

  // You can add Pinia plugins here
  // pinia.use(SomePiniaPlugin)

  return pinia
})
# app/api/dependencies.py
import logging
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from jose import JWTError
from app.database import get_db
from app.core.security import verify_access_token
from app.crud import user as crud_user
from app.models import User as UserModel # Import the SQLAlchemy model
from app.config import settings
logger = logging.getLogger(__name__)
# Define the OAuth2 scheme
# tokenUrl should point to your login endpoint relative to the base path
# It's used by Swagger UI for the "Authorize" button flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL)
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db)
) -> UserModel:
    """
    Resolve the authenticated user from the request's bearer token.

    Verifies the JWT (signature and expiry), reads its 'sub' claim as the
    user's email, and loads the matching user row from the database.

    Raises:
        HTTPException: 401 if the token is invalid/expired, lacks a subject,
            or names a user that no longer exists.

    Returns:
        The authenticated user's database model instance.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=settings.AUTH_CREDENTIALS_ERROR,
        headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX},
    )

    claims = verify_access_token(token)
    if claims is None:
        logger.warning("Token verification failed (invalid, expired, or malformed).")
        raise unauthorized

    subject: Optional[str] = claims.get("sub")
    if subject is None:
        logger.error("Token payload missing 'sub' (subject/email).")
        raise unauthorized  # Token is malformed

    # A valid token may outlive its user (e.g. account deleted after issuance).
    account = await crud_user.get_user_by_email(db, email=subject)
    if account is None:
        logger.warning(f"User corresponding to token subject not found: {subject}")
        raise unauthorized  # Treat as invalid credentials

    logger.debug(f"Authenticated user retrieved: {account.email} (ID: {account.id})")
    return account
# Optional: Dependency for getting the *active* current user
# You might add an `is_active` flag to your User model later
# async def get_current_active_user(
# current_user: UserModel = Depends(get_current_user)
# ) -> UserModel:
# if not current_user.is_active: # Assuming an is_active attribute
# logger.warning(f"Authentication attempt by inactive user: {current_user.email}")
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
# return current_user
# app/api/v1/endpoints/groups.py
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, UserRoleEnum # Import model and enum
from app.schemas.group import GroupCreate, GroupPublic
from app.schemas.invite import InviteCodePublic
from app.schemas.message import Message # For simple responses
from app.crud import group as crud_group
from app.crud import invite as crud_invite
from app.core.exceptions import (
GroupNotFoundError,
GroupPermissionError,
GroupMembershipError,
GroupOperationError,
GroupValidationError,
InviteCreationError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "", # Route relative to prefix "/groups"
    response_model=GroupPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Create New Group",
    tags=["Groups"]
)
async def create_group(
    group_in: GroupCreate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Create a new group; the caller is recorded as its owner."""
    logger.info(f"User {current_user.email} creating group: {group_in.name}")
    # The CRUD layer inserts the group row and the owner membership together.
    # Members are not eagerly loaded here; re-fetch via get_group_by_id if the
    # response ever needs them.
    return await crud_group.create_group(db=db, group_in=group_in, creator_id=current_user.id)
@router.get(
    "", # Route relative to prefix "/groups"
    response_model=List[GroupPublic],
    summary="List User's Groups",
    tags=["Groups"]
)
async def read_user_groups(
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Return every group the authenticated user is a member of."""
    logger.info(f"Fetching groups for user: {current_user.email}")
    return await crud_group.get_user_groups(db=db, user_id=current_user.id)
@router.get(
    "/{group_id}",
    response_model=GroupPublic,
    summary="Get Group Details",
    tags=["Groups"]
)
async def read_group(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Return a group's details (including members) to one of its members."""
    logger.info(f"User {current_user.email} requesting details for group ID: {group_id}")

    # Membership gates access: non-members get a membership error before any
    # group data is fetched.
    if not await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id):
        logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}")
        raise GroupMembershipError(group_id, "view group details")

    fetched = await crud_group.get_group_by_id(db=db, group_id=group_id)
    if fetched is None:
        # A membership row exists but the group row is gone — inconsistency.
        logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)")
        raise GroupNotFoundError(group_id)

    return fetched
@router.post(
    "/{group_id}/invites",
    response_model=InviteCodePublic,
    summary="Create Group Invite",
    tags=["Groups", "Invites"]
)
async def create_group_invite(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Generates a new invite code for the group. Requires owner/admin role (MVP: owner only).

    Raises:
        GroupPermissionError: If the caller is not the group's owner.
        GroupNotFoundError: If the group does not exist.
        InviteCreationError: If a unique invite code could not be generated.
    """
    logger.info(f"User {current_user.email} attempting to create invite for group {group_id}")
    user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)

    # --- Permission Check (MVP: Owner only) ---
    # A None role (non-member) also fails this comparison, so non-members are
    # rejected with the same permission error.
    if user_role != UserRoleEnum.owner:
        logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}")
        raise GroupPermissionError(group_id, "create invites")

    # Check if group exists (implicitly done by role check, but good practice)
    group = await crud_group.get_group_by_id(db, group_id)
    if not group:
        raise GroupNotFoundError(group_id)

    invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id)
    if not invite:
        # Falsy return signals the CRUD layer could not produce a unique code.
        logger.error(f"Failed to generate unique invite code for group {group_id}")
        raise InviteCreationError(group_id)

    logger.info(f"User {current_user.email} created invite code for group {group_id}")
    return invite
@router.delete(
    "/{group_id}/leave",
    response_model=Message,
    summary="Leave Group",
    tags=["Groups"]
)
async def leave_group(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Removes the current user from the specified group.

    Raises:
        GroupMembershipError: If the caller is not a member of the group.
        GroupValidationError: If the caller is the owner and the last member.
        GroupOperationError: If the removal unexpectedly deletes nothing.
    """
    logger.info(f"User {current_user.email} attempting to leave group {group_id}")
    user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)

    # A None role means no membership row exists for this user/group pair.
    if user_role is None:
        raise GroupMembershipError(group_id, "leave (you are not a member)")

    # --- MVP: Prevent owner leaving if they are the last member/owner ---
    if user_role == UserRoleEnum.owner:
        member_count = await crud_group.get_group_member_count(db, group_id)
        # More robust check: count owners. For now, just check member count.
        if member_count <= 1:
            logger.warning(f"Owner {current_user.email} attempted to leave group {group_id} as last member.")
            raise GroupValidationError("Owner cannot leave the group as the last member. Delete the group or transfer ownership.")

    # Proceed with removal
    deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id)
    if not deleted:
        # Should not happen if role check passed, but handle defensively
        logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.")
        raise GroupOperationError("Failed to leave group")

    logger.info(f"User {current_user.email} successfully left group {group_id}")
    return Message(detail="Successfully left the group")
# --- Optional: Remove Member Endpoint ---
@router.delete(
    "/{group_id}/members/{user_id_to_remove}",
    response_model=Message,
    summary="Remove Member From Group (Owner Only)",
    tags=["Groups"]
)
async def remove_group_member(
    group_id: int,
    user_id_to_remove: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Removes a specified user from the group. Requires current user to be owner.

    Raises:
        GroupPermissionError: If the caller is not the group's owner.
        GroupValidationError: If the owner tries to remove themselves here.
        GroupMembershipError: If the target user is not a group member.
        GroupOperationError: If the removal unexpectedly deletes nothing.
    """
    logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}")
    owner_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)

    # --- Permission Check ---
    if owner_role != UserRoleEnum.owner:
        logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}")
        raise GroupPermissionError(group_id, "remove members")

    # Prevent owner removing themselves via this endpoint; self-removal goes
    # through the leave endpoint, which has last-member safeguards.
    if current_user.id == user_id_to_remove:
        raise GroupValidationError("Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.")

    # Check if target user is actually in the group
    target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove)
    if target_role is None:
        raise GroupMembershipError(group_id, "remove this user (they are not a member)")

    # Proceed with removal
    deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove)
    if not deleted:
        logger.error(f"Owner {current_user.email} failed to remove user {user_id_to_remove} from group {group_id}.")
        raise GroupOperationError("Failed to remove member")

    logger.info(f"Owner {current_user.email} successfully removed user {user_id_to_remove} from group {group_id}")
    return Message(detail="Successfully removed member from the group")
# app/api/v1/endpoints/health.py
import logging
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from app.database import get_db
from app.schemas.health import HealthStatus
from app.core.exceptions import DatabaseConnectionError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
    "/health",
    response_model=HealthStatus,
    summary="Perform a Health Check",
    description="Checks the operational status of the API and its connection to the database.",
    tags=["Health"]
)
async def check_health(db: AsyncSession = Depends(get_db)):
    """
    Health check endpoint. Verifies API reachability and database connection.

    Raises:
        DatabaseConnectionError: If the database is unreachable or the probe
            query returns an unexpected result.
    """
    try:
        # A trivial query proves we can reach the database and execute SQL.
        result = await db.execute(text("SELECT 1"))
        value = result.scalar_one()
    except Exception as e:
        logger.error(f"Health check failed: Database connection error - {e}", exc_info=True)
        raise DatabaseConnectionError(str(e))

    # Raised outside the try block so this error is not swallowed and
    # re-wrapped by the broad connection-error handler above (the original
    # code raised it inside the try, so it was always re-caught).
    if value != 1:
        # This case should ideally not happen with 'SELECT 1'
        logger.error("Health check failed: Database connection check returned unexpected result.")
        raise DatabaseConnectionError("Unexpected result from database connection check")

    logger.info("Health check successful: Database connection verified.")
    return HealthStatus(status="ok", database="connected")
# app/api/v1/endpoints/invites.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, UserRoleEnum
from app.schemas.invite import InviteAccept
from app.schemas.message import Message
from app.crud import invite as crud_invite
from app.crud import group as crud_group
from app.core.exceptions import (
InviteNotFoundError,
InviteExpiredError,
InviteAlreadyUsedError,
InviteCreationError,
GroupNotFoundError,
GroupMembershipError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "/accept", # Route relative to prefix "/invites"
    response_model=Message,
    summary="Accept Group Invite",
    tags=["Invites"]
)
async def accept_invite(
    invite_in: InviteAccept,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Accepts a group invite using the provided invite code.

    Validation order: code exists -> not expired -> not already used ->
    group still exists -> caller not already a member. Only then is the
    invite consumed and the membership created.

    Raises:
        InviteNotFoundError: Unknown invite code.
        InviteExpiredError: Invite past its expiry.
        InviteAlreadyUsedError: Invite was consumed earlier.
        GroupNotFoundError: The invite's group no longer exists.
        GroupMembershipError: Caller is already a member.
        InviteCreationError: The accept operation failed to persist.
    """
    logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.invite_code}")

    # Get the invite
    invite = await crud_invite.get_invite_by_code(db, invite_code=invite_in.invite_code)
    if not invite:
        logger.warning(f"Invalid invite code attempted by user {current_user.email}: {invite_in.invite_code}")
        raise InviteNotFoundError(invite_in.invite_code)

    # Check if invite is expired
    if invite.is_expired():
        logger.warning(f"Expired invite code attempted by user {current_user.email}: {invite_in.invite_code}")
        raise InviteExpiredError(invite_in.invite_code)

    # Check if invite has already been used (used_at is set on acceptance)
    if invite.used_at:
        logger.warning(f"Already used invite code attempted by user {current_user.email}: {invite_in.invite_code}")
        raise InviteAlreadyUsedError(invite_in.invite_code)

    # Check if group still exists
    group = await crud_group.get_group_by_id(db, group_id=invite.group_id)
    if not group:
        logger.error(f"Group {invite.group_id} not found for invite {invite_in.invite_code}")
        raise GroupNotFoundError(invite.group_id)

    # Check if user is already a member
    is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id)
    if is_member:
        logger.warning(f"User {current_user.email} already a member of group {invite.group_id}")
        raise GroupMembershipError(invite.group_id, "join (already a member)")

    # Add user to group and mark invite as used
    success = await crud_invite.accept_invite(db, invite=invite, user_id=current_user.id)
    if not success:
        # NOTE(review): InviteCreationError is an odd fit for an accept
        # failure — confirm whether a dedicated exception is warranted.
        logger.error(f"Failed to accept invite {invite_in.invite_code} for user {current_user.email}")
        raise InviteCreationError(invite.group_id)

    logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.invite_code}")
    return Message(detail="Successfully joined the group")
# app/core/security.py
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Union, Optional

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.config import settings # Import settings from config
# --- Password Hashing ---
# Configure passlib context
# Using bcrypt as the default hashing scheme
# 'deprecated="auto"' will automatically upgrade hashes if needed on verification
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password attempt against a stored hash.

    Args:
        plain_password: The password attempt.
        hashed_password: The stored hash from the database.

    Returns:
        True only when the password matches; False on mismatch or when
        verification itself fails (e.g. an unrecognized hash format).
    """
    matched = False
    try:
        matched = pwd_context.verify(plain_password, hashed_password)
    except Exception:
        # Any verification error is treated the same as a non-match.
        matched = False
    return matched
def hash_password(password: str) -> str:
    """
    Hash a plain-text password with the configured scheme (bcrypt).

    Args:
        password: The plain text password to hash.

    Returns:
        The resulting hash string, safe to persist.
    """
    digest = pwd_context.hash(password)
    return digest
# --- JSON Web Tokens (JWT) ---
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token for *subject*.

    Args:
        subject: The subject of the token (e.g. user ID or email); stored
            as the string 'sub' claim.
        expires_delta: Optional custom lifetime. Defaults to
            settings.ACCESS_TOKEN_EXPIRE_MINUTES when omitted.

    Returns:
        The encoded JWT access token string.
    """
    lifetime = expires_delta if expires_delta else timedelta(
        minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
    )
    # 'type' distinguishes access tokens from refresh tokens at verification.
    claims = {
        "exp": datetime.now(timezone.utc) + lifetime,
        "sub": str(subject),
        "type": "access",
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT refresh token for *subject*.

    Args:
        subject: The subject of the token (e.g. user ID or email); stored
            as the string 'sub' claim.
        expires_delta: Optional custom lifetime. Defaults to
            settings.REFRESH_TOKEN_EXPIRE_MINUTES when omitted.

    Returns:
        The encoded JWT refresh token string.
    """
    lifetime = expires_delta if expires_delta else timedelta(
        minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
    )
    # 'type' distinguishes refresh tokens from access tokens at verification.
    claims = {
        "exp": datetime.now(timezone.utc) + lifetime,
        "sub": str(subject),
        "type": "refresh",
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def verify_access_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT access token and returns its payload if valid.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded token payload (dict) if the token is a valid, unexpired
        access token, otherwise None.
    """
    try:
        # Decoding verifies the signature (SECRET_KEY/ALGORITHM) and the
        # 'exp' claim automatically.
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        if payload.get("type") != "access":
            # Reject refresh (or otherwise-typed) tokens presented here.
            raise JWTError("Invalid token type")
        return payload
    except JWTError as e:
        # Invalid signature, expired token, malformed token, or wrong type.
        # Fix: route through logging instead of print() so failures reach
        # the application's configured log handlers.
        logging.getLogger(__name__).warning("JWT verification failed: %s", e)
        return None
    except Exception as e:
        # Defensive catch-all for unexpected decode errors.
        logging.getLogger(__name__).error("Unexpected error decoding JWT: %s", e, exc_info=True)
        return None
def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT refresh token and returns its payload if valid.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded token payload (dict) if the token is valid, not expired,
        and is a refresh token, otherwise None.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        if payload.get("type") != "refresh":
            # An access token must not be usable to mint new tokens.
            raise JWTError("Invalid token type")
        return payload
    except JWTError as e:
        # Fix: route through logging instead of print() so failures reach
        # the application's configured log handlers.
        logging.getLogger(__name__).warning("JWT verification failed: %s", e)
        return None
    except Exception as e:
        logging.getLogger(__name__).error("Unexpected error decoding JWT: %s", e, exc_info=True)
        return None
# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]:
# payload = verify_access_token(token)
# if payload:
# return payload.get("sub")
# return None
# app/crud/user.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional
from app.models import User as UserModel # Alias to avoid name clash
from app.schemas.user import UserCreate
from app.core.security import hash_password
from app.core.exceptions import (
UserCreationError,
EmailAlreadyRegisteredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
)
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
    """Fetches a user from the database by email.

    Returns:
        The first matching user row, or None if no user has that email.

    Raises:
        DatabaseConnectionError: If the database cannot be reached.
        DatabaseQueryError: For any other SQLAlchemy failure.
    """
    try:
        # NOTE(review): opening a transaction (db.begin()) for a read-only
        # SELECT is unusual and raises if the session already has an active
        # transaction — confirm callers never share an in-progress
        # transaction on this session.
        async with db.begin():
            result = await db.execute(select(UserModel).filter(UserModel.email == email))
            return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
    """Creates a new user record in the database.

    The plain-text password is hashed before persisting; the raw password
    is never stored.

    Raises:
        EmailAlreadyRegisteredError: If the email hits a unique constraint.
        DatabaseIntegrityError: For other integrity violations.
        DatabaseConnectionError: If the database cannot be reached.
        DatabaseTransactionError: For any other SQLAlchemy failure.
    """
    try:
        async with db.begin():
            _hashed_password = hash_password(user_in.password)
            db_user = UserModel(
                email=user_in.email,
                password_hash=_hashed_password,
                name=user_in.name
            )
            db.add(db_user)
            await db.flush() # Flush to get DB-generated values
            await db.refresh(db_user)
            return db_user
    except IntegrityError as e:
        # Heuristic: a unique-constraint failure is assumed to be a duplicate
        # email (the unique column written here).
        if "unique constraint" in str(e).lower():
            raise EmailAlreadyRegisteredError()
        raise DatabaseIntegrityError(f"Failed to create user: {str(e)}")
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create user: {str(e)}")
# app/schemas/auth.py
from pydantic import BaseModel, EmailStr
from app.config import settings
class Token(BaseModel):
    """Response body for successful authentication: a JWT token pair."""
    access_token: str
    refresh_token: str # Added refresh token
    token_type: str = settings.TOKEN_TYPE # Use configured token type (typically "bearer")
# Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel):
# email: EmailStr
# password: str
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal
class UserCostShare(BaseModel):
    """Per-user line of a list cost summary under an equal split."""
    user_id: int
    user_identifier: str # Name or email
    items_added_value: Decimal = Decimal("0.00") # Total value of items this user added
    amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users)
    balance: Decimal # items_added_value - amount_due
    model_config = ConfigDict(from_attributes=True)
class ListCostSummary(BaseModel):
    """Cost breakdown for one list: total, equal share, and per-user balances."""
    list_id: int
    list_name: str
    total_list_cost: Decimal # sum of item prices on the list
    num_participating_users: int
    equal_share_per_user: Decimal # total_list_cost / num_participating_users
    user_balances: List[UserCostShare] # one entry per participating user
    model_config = ConfigDict(from_attributes=True)
class UserBalanceDetail(BaseModel):
    """A single user's financial position within a group."""
    user_id: int
    user_identifier: str  # Name or email
    total_paid_for_expenses: Decimal = Decimal("0.00")
    total_share_of_expenses: Decimal = Decimal("0.00")
    total_settlements_paid: Decimal = Decimal("0.00")
    total_settlements_received: Decimal = Decimal("0.00")
    net_balance: Decimal = Decimal("0.00")  # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)
    model_config = ConfigDict(from_attributes=True)
class SuggestedSettlement(BaseModel):
    """One suggested payment (payer -> payee) to move balances toward zero."""
    from_user_id: int
    from_user_identifier: str  # Name or email of payer
    to_user_id: int
    to_user_identifier: str  # Name or email of payee
    amount: Decimal
    model_config = ConfigDict(from_attributes=True)
class GroupBalanceSummary(BaseModel):
    """Group-wide financial summary: totals plus per-user balance details."""
    group_id: int
    group_name: str
    overall_total_expenses: Decimal = Decimal("0.00")
    overall_total_settlements: Decimal = Decimal("0.00")
    user_balances: List[UserBalanceDetail]
    # Optional: list of suggested settlements to zero out balances
    suggested_settlements: Optional[List[SuggestedSettlement]] = None
    model_config = ConfigDict(from_attributes=True)
# class SuggestedSettlement(BaseModel):
# from_user_id: int
# to_user_id: int
# amount: Decimal
# app/schemas/health.py
from pydantic import BaseModel
from app.config import settings
class HealthStatus(BaseModel):
    """
    Response model for the health check endpoint.
    """
    status: str = settings.HEALTH_STATUS_OK  # Use configured default value
    database: str  # Database connectivity indicator — set by the health endpoint
# app/schemas/item.py
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Optional
from decimal import Decimal
# Properties to return to client
# Properties to return to client
class ItemPublic(BaseModel):
    """Item representation returned to API clients."""
    id: int
    list_id: int
    name: str
    quantity: Optional[str] = None
    is_complete: bool
    price: Optional[Decimal] = None
    added_by_id: int
    completed_by_id: Optional[int] = None  # Set when the item is marked complete
    created_at: datetime
    updated_at: datetime
    version: int  # Optimistic-locking version counter
    model_config = ConfigDict(from_attributes=True)  # Allow construction from ORM objects
# Properties to receive via API on creation
class ItemCreate(BaseModel):
    """Creation payload; the list and author come from the URL path and auth context."""
    name: str
    quantity: Optional[str] = None
    # list_id will be from path param
    # added_by_id will be from current_user
# Properties to receive via API on update
class ItemUpdate(BaseModel):
    """Partial-update payload; the client must echo the item's current version."""
    name: Optional[str] = None
    quantity: Optional[str] = None
    is_complete: Optional[bool] = None
    price: Optional[Decimal] = None  # Price added here for update
    version: int  # Required for optimistic concurrency control
    # completed_by_id will be set internally if is_complete is true
# app/schemas/list.py
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Optional, List
from .item import ItemPublic # Import item schema for nesting
# Properties to receive via API on creation
class ListCreate(BaseModel):
    """Creation payload for a shopping list; group_id null means a personal list."""
    name: str
    description: Optional[str] = None
    group_id: Optional[int] = None  # Optional for sharing with a group
# Properties to receive via API on update
class ListUpdate(BaseModel):
    """Partial-update payload for a list."""
    name: Optional[str] = None
    description: Optional[str] = None
    is_complete: Optional[bool] = None
    version: int  # Client must provide the current version (optimistic locking)
    # Potentially add group_id update later if needed
# Base properties returned by API (common fields)
# Base properties returned by API (common fields)
class ListBase(BaseModel):
    """Common list fields shared by all list response schemas."""
    id: int
    name: str
    description: Optional[str] = None
    created_by_id: int
    group_id: Optional[int] = None
    is_complete: bool
    created_at: datetime
    updated_at: datetime
    version: int  # Include version in responses for optimistic locking
    model_config = ConfigDict(from_attributes=True)
# Properties returned when listing lists (no items)
class ListPublic(ListBase):
    """List summary returned when listing lists (no nested items)."""
    pass  # Inherits all from ListBase
# Properties returned for a single list detail (includes items)
class ListDetail(ListBase):
    """Detailed list response including its items."""
    items: List[ItemPublic] = []  # Nested items (safe: Pydantic copies defaults per-instance)
class ListStatus(BaseModel):
    """Lightweight polling payload used to decide whether a full refresh is needed."""
    list_updated_at: datetime
    latest_item_updated_at: Optional[datetime] = None  # Can be null if list has no items
    item_count: int
# app/schemas/user.py
from pydantic import BaseModel, EmailStr, ConfigDict
from datetime import datetime
from typing import Optional
# Shared properties
# Shared properties
class UserBase(BaseModel):
    """Fields common to all user schemas."""
    email: EmailStr
    name: Optional[str] = None
# Properties to receive via API on creation
class UserCreate(UserBase):
    """Registration payload; the plaintext password is hashed before storage."""
    password: str
# Properties to receive via API on update (optional, add later if needed)
# class UserUpdate(UserBase):
# password: Optional[str] = None
# Properties stored in DB
# Properties stored in DB
class UserInDBBase(UserBase):
    """DB-backed user fields, including the password hash (never expose to clients)."""
    id: int
    password_hash: str
    created_at: datetime
    model_config = ConfigDict(from_attributes=True)  # Use orm_mode in Pydantic v1
# Additional properties to return via API (excluding password)
# Additional properties to return via API (excluding password)
class UserPublic(UserBase):
    """Public user representation — deliberately omits password_hash."""
    id: int
    created_at: datetime
    model_config = ConfigDict(from_attributes=True)
# Full user model including hashed password (for internal use/reading from DB)
class User(UserInDBBase):
    """Full user model including hashed password (internal use only)."""
    pass
.DS_Store
.thumbs.db
node_modules
# Quasar core related directories
.quasar
/dist
/quasar.config.*.temporary.compiled*
# Cordova related directories and files
/src-cordova/node_modules
/src-cordova/platforms
/src-cordova/plugins
/src-cordova/www
# Capacitor related directories and files
/src-capacitor/www
/src-capacitor/node_modules
# Log files
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# Editor directories and files
.idea
*.suo
*.ntvs*
*.njsproj
*.sln
# local .env files
.env.local*
# pnpm-related options
shamefully-hoist=true
strict-peer-dependencies=false
# to get the latest compatible packages when creating the project https://github.com/pnpm/pnpm/issues/6463
resolution-mode=highest
// Configuration for your app
// https://v2.quasar.dev/quasar-cli-vite/quasar-config-file
import { defineConfig } from '#q-app/wrappers';
import { fileURLToPath } from 'node:url';
// ctx exposes the current mode (spa/ssr/pwa/...) and dev/prod flags.
export default defineConfig((ctx) => {
  return {
    // https://v2.quasar.dev/quasar-cli-vite/prefetch-feature
    // preFetch: true,

    // app boot file (/src/boot)
    // --> boot files are part of "main.js"
    // https://v2.quasar.dev/quasar-cli-vite/boot-files
    boot: ['i18n', 'axios'],

    // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#css
    css: ['app.scss'],

    // https://github.com/quasarframework/quasar/tree/dev/extras
    extras: [
      // 'ionicons-v4',
      // 'mdi-v7',
      // 'fontawesome-v6',
      // 'eva-icons',
      // 'themify',
      // 'line-awesome',
      // 'roboto-font-latin-ext', // this or either 'roboto-font', NEVER both!

      'roboto-font', // optional, you are not bound to it
      'material-icons', // optional, you are not bound to it
    ],

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#build
    build: {
      target: {
        browser: ['es2022', 'firefox115', 'chrome115', 'safari14'],
        node: 'node20',
      },

      typescript: {
        strict: true,
        vueShim: true,
        // extendTsConfig (tsConfig) {}
      },

      vueRouterMode: 'hash', // available values: 'hash', 'history'
      // vueRouterBase,
      // vueDevtools,
      // vueOptionsAPI: false,

      // rebuildCache: true, // rebuilds Vite/linter/etc cache on startup

      // publicPath: '/',
      // analyze: true,
      // env: {},
      // rawDefine: {}
      // ignorePublicFolder: true,
      // minify: false,
      // polyfillModulePreload: true,
      // distDir

      // extendViteConf (viteConf) {},
      // viteVuePluginOptions: {},

      vitePlugins: [
        [
          '@intlify/unplugin-vue-i18n/vite',
          {
            // if you want to use Vue I18n Legacy API, you need to set `compositionOnly: false`
            // compositionOnly: false,

            // if you want to use named tokens in your Vue I18n messages, such as 'Hello {name}',
            // you need to set `runtimeOnly: false`
            // runtimeOnly: false,

            ssr: ctx.modeName === 'ssr',

            // you need to set i18n resource including paths !
            include: [fileURLToPath(new URL('./src/i18n', import.meta.url))],
          },
        ],

        [
          'vite-plugin-checker',
          {
            vueTsc: true,
            eslint: {
              lintCommand: 'eslint -c ./eslint.config.js "./src*/**/*.{ts,js,mjs,cjs,vue}"',
              useFlatConfig: true,
            },
          },
          // Third tuple element: per-plugin run options (dev server only here).
          { server: false },
        ],
      ],
    },

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#devserver
    devServer: {
      // https: true,
      open: true, // opens browser window automatically
    },

    // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#framework
    framework: {
      config: {},

      // iconSet: 'material-icons', // Quasar icon set
      // lang: 'en-US', // Quasar language pack

      // For special cases outside of where the auto-import strategy can have an impact
      // (like functional components as one of the examples),
      // you can manually specify Quasar components/directives to be available everywhere:
      //
      // components: [],
      // directives: [],

      // Quasar plugins
      plugins: ['Notify'],
    },

    // animations: 'all', // --- includes all animations
    // https://v2.quasar.dev/options/animations
    animations: [],

    // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#sourcefiles
    // sourceFiles: {
    //   rootComponent: 'src/App.vue',
    //   router: 'src/router/index',
    //   store: 'src/store/index',
    //   pwaRegisterServiceWorker: 'src-pwa/register-service-worker',
    //   pwaServiceWorker: 'src-pwa/custom-service-worker',
    //   pwaManifestFile: 'src-pwa/manifest.json',
    //   electronMain: 'src-electron/electron-main',
    //   electronPreload: 'src-electron/electron-preload'
    //   bexManifestFile: 'src-bex/manifest.json
    // },

    // https://v2.quasar.dev/quasar-cli-vite/developing-ssr/configuring-ssr
    ssr: {
      prodPort: 3000, // The default port that the production server should use
      // (gets superseded if process.env.PORT is specified at runtime)

      middlewares: [
        'render', // keep this as last one
      ],

      // extendPackageJson (json) {},
      // extendSSRWebserverConf (esbuildConf) {},

      // manualStoreSerialization: true,
      // manualStoreSsrContextInjection: true,
      // manualStoreHydration: true,
      // manualPostHydrationTrigger: true,

      pwa: false,

      // pwaOfflineHtmlFilename: 'offline.html', // do NOT use index.html as name!

      // pwaExtendGenerateSWOptions (cfg) {},
      // pwaExtendInjectManifestOptions (cfg) {}
    },

    // https://v2.quasar.dev/quasar-cli-vite/developing-pwa/configuring-pwa
    pwa: {
      // InjectManifest lets us ship the custom service worker in src-pwa/custom-service-worker.ts
      workboxMode: 'InjectManifest', // Changed from 'GenerateSW' to 'InjectManifest'
      swFilename: 'sw.js',
      manifestFilename: 'manifest.json',
      injectPwaMetaTags: true,
      // extendManifestJson (json) {},
      // useCredentialsForManifestTag: true,
      // extendPWACustomSWConf (esbuildConf) {},
      // extendGenerateSWOptions (cfg) {},
      // extendInjectManifestOptions (cfg) {}
    },

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-cordova-apps/configuring-cordova
    cordova: {
      // noIosLegacyBuildFlag: true, // uncomment only if you know what you are doing
    },

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-capacitor-apps/configuring-capacitor
    capacitor: {
      hideSplashscreen: true,
    },

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-electron-apps/configuring-electron
    electron: {
      // extendElectronMainConf (esbuildConf) {},
      // extendElectronPreloadConf (esbuildConf) {},

      // extendPackageJson (json) {},

      // Electron preload scripts (if any) from /src-electron, WITHOUT file extension
      preloadScripts: ['electron-preload'],

      // specify the debugging port to use for the Electron app when running in development mode
      inspectPort: 5858,

      bundler: 'packager', // 'packager' or 'builder'

      packager: {
        // https://github.com/electron-userland/electron-packager/blob/master/docs/api.md#options

        // OS X / Mac App Store
        // appBundleId: '',
        // appCategoryType: '',
        // osxSign: '',
        // protocol: 'myapp://path',

        // Windows only
        // win32metadata: { ... }
      },

      builder: {
        // https://www.electron.build/configuration/configuration

        appId: 'mitlist',
      },
    },

    // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-browser-extensions/configuring-bex
    bex: {
      // extendBexScriptsConf (esbuildConf) {},
      // extendBexManifestJson (json) {},

      /**
       * The list of extra scripts (js/ts) not in your bex manifest that you want to
       * compile and use in your browser extension. Maybe dynamic use them?
       *
       * Each entry in the list should be a relative filename to /src-bex/
       *
       * @example [ 'my-script.ts', 'sub-folder/my-other-script.js' ]
       */
      extraScripts: [],
    },
  };
});
# mitlist (mitlist)
The mitlist progressive web app (PWA), built with the Quasar framework.
## Install the dependencies
```bash
yarn
# or
npm install
```
### Start the app in development mode (hot-code reloading, error reporting, etc.)
```bash
quasar dev
```
### Lint the files
```bash
yarn lint
# or
npm run lint
```
### Format the files
```bash
yarn format
# or
npm run format
```
### Build the app for production
```bash
quasar build
```
### Customize the configuration
See [Configuring quasar.config.js](https://v2.quasar.dev/quasar-cli-vite/quasar-config-js).
/*
* This file (which will be your service worker)
* is picked up by the build system ONLY if
* quasar.config file > pwa > workboxMode is set to "InjectManifest"
*/
// Inside a service worker, `self` is a ServiceWorkerGlobalScope.
// Fix: `skipWaiting` must return `Promise<void>` — a bare `Promise` has no
// type argument and does not compile under strict TypeScript.
declare const self: ServiceWorkerGlobalScope &
  typeof globalThis & { skipWaiting: () => Promise<void> };

import { clientsClaim } from 'workbox-core';
import {
  precacheAndRoute,
  cleanupOutdatedCaches,
  createHandlerBoundToURL,
} from 'workbox-precaching';
import { registerRoute, NavigationRoute } from 'workbox-routing';
import { CacheFirst, NetworkFirst } from 'workbox-strategies';
import { ExpirationPlugin } from 'workbox-expiration';
import { CacheableResponsePlugin } from 'workbox-cacheable-response';
import type { WorkboxPlugin } from 'workbox-core/types';

// Activate the updated service worker immediately rather than waiting for
// all open tabs to close.
self.skipWaiting().catch((error) => {
  console.error('Error during service worker activation:', error);
});
clientsClaim();

// Use with precache injection
precacheAndRoute(self.__WB_MANIFEST);

cleanupOutdatedCaches();

// Cache app shell and static assets with Cache First strategy
registerRoute(
  // Match static assets
  ({ request }) =>
    request.destination === 'style' ||
    request.destination === 'script' ||
    request.destination === 'image' ||
    request.destination === 'font',
  new CacheFirst({
    cacheName: 'static-assets',
    plugins: [
      new CacheableResponsePlugin({
        statuses: [0, 200], // 0 covers opaque cross-origin responses
      }) as WorkboxPlugin,
      new ExpirationPlugin({
        maxEntries: 60,
        maxAgeSeconds: 30 * 24 * 60 * 60, // 30 days
      }) as WorkboxPlugin,
    ],
  })
);

// Cache API calls with Network First strategy
registerRoute(
  // Match API calls
  ({ url }) => url.pathname.startsWith('/api/'),
  new NetworkFirst({
    cacheName: 'api-cache',
    plugins: [
      new CacheableResponsePlugin({
        statuses: [0, 200],
      }) as WorkboxPlugin,
      new ExpirationPlugin({
        maxEntries: 50,
        maxAgeSeconds: 24 * 60 * 60, // 24 hours
      }) as WorkboxPlugin,
    ],
  })
);

// Non-SSR fallbacks to index.html
// Production SSR fallbacks to offline.html (except for dev)
if (process.env.MODE !== 'ssr' || process.env.PROD) {
  registerRoute(
    new NavigationRoute(createHandlerBoundToURL(process.env.PWA_FALLBACK_HTML), {
      denylist: [new RegExp(process.env.PWA_SERVICE_WORKER_REGEX), /workbox-(.)*\.js$/],
    }),
  );
}
Conflict Resolution
This item was modified while you were offline. Please review the changes and choose how to resolve the conflict.
Your Version
Last modified: {{ formatDate(conflictData?.localVersion.timestamp ?? 0) }}
{{ formatKey(key) }}
{{ formatValue(value) }}
Server Version
Last modified: {{ formatDate(conflictData?.serverVersion.timestamp ?? 0) }}
You are currently offline. Changes will be saved locally.
Syncing {{ pendingActionCount }} pending {{ pendingActionCount === 1 ? 'change' : 'changes' }}...
Pending Changes
{{ getActionLabel(action) }}
{{ new Date(action.timestamp).toLocaleString() }}
Mooo Logout
import type { RouteRecordRaw } from 'vue-router';
// Application route table. Two sibling records share path '/':
// the first (MainLayout) hosts the authenticated app pages, the second
// (AuthLayout) hosts login/signup. NOTE(review): no navigation guards are
// defined here — confirm auth redirection is handled elsewhere.
const routes: RouteRecordRaw[] = [
  {
    path: '/',
    component: () => import('layouts/MainLayout.vue'),
    children: [
      // Root redirects to the personal lists view.
      { path: '', redirect: '/lists' },
      { path: 'lists', name: 'PersonalLists', component: () => import('pages/ListsPage.vue') },
      {
        path: 'lists/:id',
        name: 'ListDetail',
        component: () => import('pages/ListDetailPage.vue'),
        props: true, // pass :id as a component prop
      },
      { path: 'groups', name: 'GroupsList', component: () => import('pages/GroupsPage.vue') },
      {
        path: 'groups/:id',
        name: 'GroupDetail',
        component: () => import('pages/GroupDetailPage.vue'),
        props: true,
      },
      {
        // ListsPage is reused for a group's lists, parameterized by groupId.
        path: 'groups/:groupId/lists',
        name: 'GroupLists',
        component: () => import('pages/ListsPage.vue'),
        props: true,
      },
      { path: 'account', name: 'Account', component: () => import('pages/AccountPage.vue') },
    ],
  },
  {
    path: '/',
    component: () => import('layouts/AuthLayout.vue'),
    children: [
      { path: 'login', component: () => import('pages/LoginPage.vue') },
      { path: 'signup', component: () => import('pages/SignupPage.vue') },
    ],
  },

  // Always leave this as last one,
  // but you can also remove it
  {
    path: '/:catchAll(.*)*',
    component: () => import('pages/ErrorNotFound.vue'),
  },
];
import { defineStore } from 'pinia';
import { ref, computed } from 'vue';
import { apiClient, API_ENDPOINTS } from 'src/config/api';
// Shape of the auth store's state; `user` stays null until set after login.
interface AuthState {
  accessToken: string | null;
  refreshToken: string | null;
  user: {
    email: string;
    name: string;
  } | null;
}
/**
 * Authentication store: holds the JWT token pair and the current user,
 * persisting tokens to localStorage so sessions survive page reloads.
 */
export const useAuthStore = defineStore('auth', () => {
  // State — rehydrated from localStorage on store creation.
  const accessToken = ref(localStorage.getItem('token'));
  const refreshToken = ref(localStorage.getItem('refresh_token'));
  // Fix: explicit type parameter — a plain ref(null) infers Ref<null>,
  // which rejects the assignment in setUser below.
  const user = ref<AuthState['user']>(null);

  // Getters
  const isAuthenticated = computed(() => !!accessToken.value);
  const getUser = computed(() => user.value);

  // Actions

  /** Store a new token pair in reactive state and localStorage. */
  const setTokens = (tokens: { access_token: string; refresh_token: string }) => {
    accessToken.value = tokens.access_token;
    refreshToken.value = tokens.refresh_token;
    localStorage.setItem('token', tokens.access_token);
    localStorage.setItem('refresh_token', tokens.refresh_token);
  };

  /** Wipe all auth state (logout and failed refresh). */
  const clearTokens = () => {
    accessToken.value = null;
    refreshToken.value = null;
    user.value = null;
    localStorage.removeItem('token');
    localStorage.removeItem('refresh_token');
  };

  const setUser = (userData: AuthState['user']) => {
    user.value = userData;
  };

  /**
   * OAuth2 password-flow login; the backend reads the email from the
   * `username` form field. Fix: use URLSearchParams so the request body is
   * actually urlencoded — FormData would be serialized as
   * multipart/form-data, contradicting the Content-Type header below.
   */
  const login = async (email: string, password: string) => {
    const formData = new URLSearchParams();
    formData.append('username', email);
    formData.append('password', password);

    const response = await apiClient.post(API_ENDPOINTS.AUTH.LOGIN, formData, {
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
    });

    const { access_token, refresh_token } = response.data;
    setTokens({ access_token, refresh_token });
    return response.data;
  };

  const signup = async (userData: { name: string; email: string; password: string }) => {
    const response = await apiClient.post(API_ENDPOINTS.AUTH.SIGNUP, userData);
    return response.data;
  };

  /** Exchange the refresh token for a new pair; clears auth state on failure. */
  const refreshAccessToken = async () => {
    if (!refreshToken.value) {
      throw new Error('No refresh token available');
    }

    try {
      const response = await apiClient.post(API_ENDPOINTS.AUTH.REFRESH_TOKEN, {
        refresh_token: refreshToken.value,
      });
      const { access_token, refresh_token } = response.data;
      setTokens({ access_token, refresh_token });
      return response.data;
    } catch (error) {
      clearTokens();
      throw error;
    }
  };

  const logout = () => {
    clearTokens();
  };

  return {
    // State
    accessToken,
    refreshToken,
    user,
    // Getters
    isAuthenticated,
    getUser,
    // Actions
    setTokens,
    clearTokens,
    setUser,
    login,
    signup,
    refreshAccessToken,
    logout,
  };
});
import { defineStore } from 'pinia';
import { ref, computed } from 'vue';
import { useQuasar } from 'quasar';
import { LocalStorage } from 'quasar';
// A queued mutation performed while offline, replayed once connectivity returns.
export interface OfflineAction {
  id: string; // crypto.randomUUID(), assigned when queued
  type: 'add' | 'complete' | 'update' | 'delete';
  itemId?: string;
  data: unknown;
  timestamp: number; // Date.now() at queue time
  version?: number;
}
// User's choice when a queued action conflicts with a server-side change.
export interface ConflictResolution {
  version: 'local' | 'server' | 'merge';
  action: OfflineAction;
}
// Both sides of a detected conflict, shown in the resolution dialog.
// Fix: bare `Record` is invalid TypeScript (the generic requires two type
// arguments — they appear to have been stripped during packing).
export interface ConflictData {
  localVersion: {
    data: Record<string, unknown>;
    timestamp: number;
  };
  serverVersion: {
    data: Record<string, unknown>;
    timestamp: number;
  };
  action: OfflineAction;
}
export const useOfflineStore = defineStore('offline', () => {
const $q = useQuasar();
const isOnline = ref(navigator.onLine);
const pendingActions = ref([]);
const isProcessingQueue = ref(false);
const showConflictDialog = ref(false);
const currentConflict = ref(null);
// Initialize from IndexedDB
const init = () => {
try {
const stored = LocalStorage.getItem('offline-actions');
if (stored) {
pendingActions.value = JSON.parse(stored as string);
}
} catch (error) {
console.error('Failed to load offline actions:', error);
}
};
// Save to IndexedDB
const saveToStorage = () => {
try {
LocalStorage.set('offline-actions', JSON.stringify(pendingActions.value));
} catch (error) {
console.error('Failed to save offline actions:', error);
}
};
// Add a new offline action
const addAction = (action: Omit) => {
const newAction: OfflineAction = {
...action,
id: crypto.randomUUID(),
timestamp: Date.now(),
};
pendingActions.value.push(newAction);
saveToStorage();
};
// Process the queue when online
const processQueue = async () => {
if (isProcessingQueue.value || !isOnline.value) return;
isProcessingQueue.value = true;
const actions = [...pendingActions.value];
for (const action of actions) {
try {
await processAction(action);
pendingActions.value = pendingActions.value.filter(a => a.id !== action.id);
saveToStorage();
} catch (error) {
if (error instanceof Error && error.message.includes('409')) {
$q.notify({
type: 'warning',
message: 'Item was modified by someone else while you were offline. Please review.',
actions: [
{
label: 'Review',
color: 'white',
handler: () => {
// TODO: Implement conflict resolution UI
}
}
]
});
} else {
console.error('Failed to process offline action:', error);
}
}
}
isProcessingQueue.value = false;
};
// Process a single action
const processAction = async (action: OfflineAction) => {
TODO: Implement actual API calls
switch (action.type) {
case 'add':
// await api.addItem(action.data);
break;
case 'complete':
// await api.completeItem(action.itemId, action.data);
break;
case 'update':
// await api.updateItem(action.itemId, action.data);
break;
case 'delete':
// await api.deleteItem(action.itemId);
break;
}
};
// Listen for online/offline status changes
const setupNetworkListeners = () => {
window.addEventListener('online', () => {
(async () => {
isOnline.value = true;
await processQueue();
})().catch(error => {
console.error('Error processing queue:', error);
});
});
window.addEventListener('offline', () => {
isOnline.value = false;
});
};
// Computed properties
const hasPendingActions = computed(() => pendingActions.value.length > 0);
const pendingActionCount = computed(() => pendingActions.value.length);
// Initialize
init();
setupNetworkListeners();
const handleConflictResolution = (resolution: ConflictResolution) => {
// Implement the logic to handle the conflict resolution
console.log('Conflict resolution:', resolution);
};
return {
isOnline,
pendingActions,
hasPendingActions,
pendingActionCount,
showConflictDialog,
currentConflict,
addAction,
processQueue,
handleConflictResolution,
};
});
{
"extends": "./.quasar/tsconfig.json"
}
# app/api/v1/endpoints/auth.py
import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token
from app.crud import user as crud_user
from app.core.security import (
verify_password,
create_access_token,
create_refresh_token,
verify_refresh_token
)
from app.core.exceptions import (
EmailAlreadyRegisteredError,
InvalidCredentialsError,
UserCreationError
)
from app.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "/signup",
    response_model=UserPublic,
    status_code=201,
    summary="Register New User",
    description="Creates a new user account.",
    tags=["Authentication"]
)
async def signup(
    user_in: UserCreate,
    db: AsyncSession = Depends(get_db)
):
    """
    Handles user registration.

    - Pre-checks whether the email already exists (fast path).
    - Hashes the password and stores the new user; the DB unique constraint
      catches the race where the email is registered between the pre-check
      and the INSERT.

    Raises:
        EmailAlreadyRegisteredError: if the email is already taken.
        UserCreationError: for any other failure while persisting the user.
    """
    logger.info(f"Signup attempt for email: {user_in.email}")
    existing_user = await crud_user.get_user_by_email(db, email=user_in.email)
    if existing_user:
        logger.warning(f"Signup failed: Email already registered - {user_in.email}")
        raise EmailAlreadyRegisteredError()
    try:
        created_user = await crud_user.create_user(db=db, user_in=user_in)
    except EmailAlreadyRegisteredError:
        # Fix: create_user itself raises this on a unique-constraint hit
        # (race with the pre-check above); previously the broad handler below
        # swallowed it and misreported the duplicate as a creation error.
        logger.warning(f"Signup failed: Email already registered - {user_in.email}")
        raise
    except Exception as e:
        logger.error(f"Error during user creation for {user_in.email}: {e}", exc_info=True)
        raise UserCreationError() from e
    logger.info(f"User created successfully: {created_user.email} (ID: {created_user.id})")
    return created_user
@router.post(
    "/login",
    response_model=Token,
    summary="User Login",
    description="Authenticates a user and returns an access and refresh token.",
    tags=["Authentication"]
)
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: AsyncSession = Depends(get_db)
):
    """
    Authenticate a user and issue a JWT token pair.

    The OAuth2 form's ``username`` field carries the email address. A single
    indistinguishable credentials error covers both "unknown email" and
    "wrong password" so the response does not leak which part failed.
    """
    logger.info(f"Login attempt for user: {form_data.username}")
    user = await crud_user.get_user_by_email(db, email=form_data.username)

    # Guard clause: short-circuits before verify_password when no user exists.
    credentials_ok = bool(user) and verify_password(form_data.password, user.password_hash)
    if not credentials_ok:
        logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
        raise InvalidCredentialsError()

    issued_access = create_access_token(subject=user.email)
    issued_refresh = create_refresh_token(subject=user.email)
    logger.info(f"Login successful, tokens generated for user: {user.email}")
    return Token(
        access_token=issued_access,
        refresh_token=issued_refresh,
        token_type=settings.TOKEN_TYPE
    )
@router.post(
    "/refresh",
    response_model=Token,
    summary="Refresh Access Token",
    description="Refreshes an access token using a refresh token.",
    tags=["Authentication"]
)
async def refresh_token(
    refresh_token_str: str,
    db: AsyncSession = Depends(get_db)  # currently unused; kept for interface stability
):
    """
    Handles access token refresh.

    - Verifies the provided refresh token.
    - If valid, generates a new JWT access token and echoes back the same
      refresh token.

    NOTE(review): `refresh_token_str` is a bare `str` parameter, so FastAPI
    treats it as a *query* parameter; tokens in URLs can leak into access
    logs and browser history. Moving it to the request body would change the
    public API, so it is only flagged here.
    """
    logger.info("Access token refresh attempt")
    payload = verify_refresh_token(refresh_token_str)
    if not payload:
        logger.warning("Refresh token invalid or expired")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # "sub" is expected to carry the user's email (tokens are minted with
    # subject=user.email in login above).
    user_email = payload.get("sub")
    if not user_email:
        logger.error("User email not found in refresh token payload")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token payload",
            headers={"WWW-Authenticate": "Bearer"},
        )
    new_access_token = create_access_token(subject=user_email)
    logger.info(f"Access token refreshed for user: {user_email}")
    return Token(
        access_token=new_access_token,
        refresh_token=refresh_token_str,
        token_type=settings.TOKEN_TYPE
    )
# app/api/v1/endpoints/lists.py
import logging
from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel
from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail
from app.schemas.message import Message # For simple responses
from app.crud import list as crud_list
from app.crud import group as crud_group # Need for group membership check
from app.schemas.list import ListStatus
from app.core.exceptions import (
GroupMembershipError,
ListNotFoundError,
ListPermissionError,
ListStatusNotFoundError,
ConflictError # Added ConflictError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "", # Route relative to prefix "/lists"
    response_model=ListPublic, # Return basic list info on creation
    status_code=status.HTTP_201_CREATED,
    summary="Create New List",
    tags=["Lists"]
)
async def create_list(
    list_in: ListCreate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Creates a new shopping list.

    A null `group_id` makes this a personal list; a non-null `group_id`
    shares the list with that group, which requires the caller to be a
    member of it.
    """
    logger.info(f"User {current_user.email} creating list: {list_in.name}")
    group_id = list_in.group_id

    # Guard: sharing with a group demands membership in that group.
    if group_id and not await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id):
        logger.warning(f"User {current_user.email} attempted to create list in group {group_id} but is not a member.")
        raise GroupMembershipError(group_id, "create lists")

    created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id)
    logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.")
    return created_list
@router.get(
    "", # Route relative to prefix "/lists"
    response_model=PyList[ListPublic], # Return a list of basic list info
    summary="List Accessible Lists",
    tags=["Lists"]
)
async def read_lists(
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
    # Add pagination parameters later if needed: skip: int = 0, limit: int = 100
):
    """
    Returns every list the current user can see: their own personal lists
    plus the lists of every group they belong to.
    """
    logger.info(f"Fetching lists accessible to user: {current_user.email}")
    return await crud_list.get_lists_for_user(db=db, user_id=current_user.id)
@router.get(
    "/{list_id}",
    response_model=ListDetail, # Return detailed list info including items
    summary="Get List Details",
    tags=["Lists"]
)
async def read_list(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Returns one list with its items, provided the caller is its creator or
    a member of its group. Permission failures surface as exceptions from
    `check_list_permission` (not-found / forbidden).
    """
    logger.info(f"User {current_user.email} requesting details for list ID: {list_id}")
    # check_list_permission both authorizes and fetches the list.
    return await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
@router.put(
    "/{list_id}",
    response_model=ListPublic, # Return updated basic info
    summary="Update List",
    tags=["Lists"],
    responses={ # Add 409 to responses
        status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
    }
)
async def update_list(
    list_id: int,
    list_in: ListUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Updates a list's details (name, description, is_complete).

    Requires the user to be the creator or a member of the list's group.
    The client MUST provide the current `version` of the list in the
    `list_in` payload; a mismatch yields a 409 Conflict (optimistic locking).
    """
    logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
    list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    try:
        updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
    except ConflictError as e:
        # Version mismatch -> 409 so the client knows to refresh and retry.
        logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
    except HTTPException:
        # Fix: deliberate HTTP errors raised by lower layers must pass
        # through untouched instead of being masked as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred while updating the list."
        ) from e
    logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
    return updated_list
@router.delete(
    "/{list_id}",
    status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
    summary="Delete List",
    tags=["Lists"],
    responses={ # Add 409 to responses
        status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
    }
)
async def delete_list(
    list_id: int,
    expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Deletes a list. Requires user to be the creator of the list.

    If `expected_version` is provided and does not match the list's current
    version, a 409 Conflict is returned (optimistic locking); omitting it
    deletes unconditionally.
    """
    logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
    # require_creator=True: only the list's creator may delete it.
    list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
    # Optimistic-locking guard: refuse to delete a stale version.
    if expected_version is not None and list_db.version != expected_version:
        logger.warning(
            f"Conflict deleting list {list_id} for user {current_user.email}. "
            f"Expected version {expected_version}, actual version {list_db.version}."
        )
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
        )
    await crud_list.delete_list(db=db, list_db=list_db)
    # NOTE(review): list_db.version is read *after* deletion; if the session
    # expires attributes on flush this may lazy-load or fail — confirm.
    logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get(
    "/{list_id}/status",
    response_model=ListStatus,
    summary="Get List Status",
    tags=["Lists"]
)
async def read_list_status(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Return last-update timestamps and item count for a list.

    Lightweight polling endpoint used by clients to decide whether a full
    refresh is needed. The caller must have view access to the list.
    """
    # Permission gate; a falsy result is disambiguated into 404 vs 403 below.
    list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    if not list_db:
        # Distinguish "list does not exist" from "no permission".
        if not await crud_list.get_list_by_id(db, list_id):
            raise ListNotFoundError(list_id)
        raise ListPermissionError(list_id, "access this list's status")
    list_status = await crud_list.get_list_status(db=db, list_id=list_id)
    if not list_status:
        # Defensive: the permission check passed, so the list should exist.
        logger.error(f"Could not retrieve status for list {list_id} even though permission check passed.")
        raise ListStatusNotFoundError(list_id)
    return list_status
import logging
from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
from google.api_core import exceptions as google_exceptions
from app.api.dependencies import get_current_user
from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OCRProcessingError
)
from app.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Instantiated at import time; GeminiOCRService.__init__ raises
# OCRServiceConfigError if the Gemini client cannot be configured.
# NOTE(review): ocr_service is not referenced by the endpoint below (it calls
# extract_items_from_image_gemini directly) — confirm it is used elsewhere.
ocr_service = GeminiOCRService()
@router.post(
    "/extract-items",
    response_model=OcrExtractResponse,
    summary="Extract List Items via OCR (Gemini)",
    tags=["OCR"]
)
async def ocr_extract_items(
    current_user: UserModel = Depends(get_current_user),
    image_file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP) of the shopping list or receipt."),
):
    """
    Accepts an image upload, sends it to Gemini Flash with a prompt
    to extract shopping list items, and returns the parsed items.

    Raises:
        OCRServiceUnavailableError: Gemini never initialized or the call failed.
        InvalidFileTypeError: content type not in settings.ALLOWED_IMAGE_TYPES.
        FileTooLargeError: payload exceeds settings.MAX_FILE_SIZE_MB.
        OCRProcessingError: any other failure while processing the image.
    """
    # Fail fast if the Gemini client never initialized (e.g. missing API key).
    if gemini_initialization_error:
        logger.error("OCR endpoint called but Gemini client failed to initialize.")
        raise OCRServiceUnavailableError(gemini_initialization_error)
    logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
    # --- File validation ---
    if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
        logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
        raise InvalidFileTypeError()
    # Simple size check (whole file is read into memory).
    contents = await image_file.read()
    if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
        logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
        raise FileTooLargeError()
    try:
        extracted_items = await extract_items_from_image_gemini(
            image_bytes=contents,
            mime_type=image_file.content_type
        )
        logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
        return OcrExtractResponse(extracted_items=extracted_items)
    except (OCRServiceUnavailableError, OCRServiceConfigError, OCRQuotaExceededError):
        # Bug fix: previously these were replaced with bare new instances
        # (e.g. `raise OCRServiceUnavailableError()`), discarding the original
        # error detail. Re-raise the caught exception unchanged instead.
        raise
    except Exception as e:
        # Chain the cause so logs show what actually went wrong.
        raise OCRProcessingError(str(e)) from e
    finally:
        # Always release the upload's file handle.
        await image_file.close()
# app/core/gemini.py
import logging
from typing import List
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
from google.api_core import exceptions as google_exceptions
from app.config import settings
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
logger = logging.getLogger(__name__)
# --- Globals holding the initialized model client ---
# gemini_flash_client stays None until initialization succeeds;
# gemini_initialization_error records why initialization failed (if it did).
gemini_flash_client = None
gemini_initialization_error = None  # Store potential init error
# --- Configure and initialize the client at import time ---
# NOTE: failures here do not raise; they are recorded in
# gemini_initialization_error so endpoints can report a service error instead
# of the whole app failing to import.
try:
    if settings.GEMINI_API_KEY:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        # Initialize the specific model we want to use
        gemini_flash_client = genai.GenerativeModel(
            model_name=settings.GEMINI_MODEL_NAME,
            # Map config strings onto the SDK's HarmCategory/HarmBlockThreshold
            # enums; raises AttributeError (caught below) on unknown names.
            safety_settings={
                getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
                for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
            },
            # Generation config passed straight through from settings
            generation_config=genai.types.GenerationConfig(
                **settings.GEMINI_GENERATION_CONFIG
            )
        )
        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
    else:
        # Store error if API key is missing
        gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
        logger.error(gemini_initialization_error)
except Exception as e:
    # Catch any other unexpected errors during initialization
    gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}"
    logger.exception(gemini_initialization_error)  # Log full traceback
    gemini_flash_client = None  # Ensure client is None on error
# --- Function to get the client (optional, allows checking error) ---
def get_gemini_client():
    """Return the initialized Gemini model client.

    Raises:
        RuntimeError: if module-level initialization failed, or the client is
            unexpectedly missing despite no recorded error.
    """
    if gemini_initialization_error:
        raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}")
    client = gemini_flash_client
    if client is None:
        # Should be unreachable when no init error was recorded; kept as a safeguard.
        raise RuntimeError("Gemini client is not available (unknown initialization issue).")
    return client
# Prompt sent alongside the image for multimodal item extraction.
# NOTE(review): the functions below read settings.OCR_ITEM_EXTRACTION_PROMPT,
# not this module-level constant — confirm which one is authoritative.
OCR_ITEM_EXTRACTION_PROMPT = """
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
If the image does not appear to be a shopping list or receipt, state that clearly.
Example output for a grocery list:
Milk
Eggs
Bread
Apples
Organic Bananas
"""
async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
    """
    Uses Gemini Flash to extract shopping list items from image bytes.

    Args:
        image_bytes: The image content as bytes.
        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

    Returns:
        A list of extracted item strings (one per non-trivial response line).

    Raises:
        RuntimeError: If the Gemini client is not initialized.
        google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
        ValueError: If the response is blocked, empty, or any other processing error occurs.
    """
    client = get_gemini_client()  # Raises RuntimeError if not initialized
    # Prepare image part for multimodal input
    image_part = {
        "mime_type": mime_type,
        "data": image_bytes
    }
    # NOTE(review): this uses settings.OCR_ITEM_EXTRACTION_PROMPT while the
    # module-level OCR_ITEM_EXTRACTION_PROMPT constant above goes unused —
    # confirm which one is intended.
    prompt_parts = [
        settings.OCR_ITEM_EXTRACTION_PROMPT,  # Text prompt first
        image_part  # Then the image
    ]
    logger.info("Sending image to Gemini for item extraction...")
    try:
        # Async generation call, suitable for FastAPI's event loop.
        response = await client.generate_content_async(prompt_parts)
        # --- Process the response ---
        # Check for safety blocks or lack of content
        if not response.candidates or not response.candidates[0].content.parts:
            logger.warning("Gemini response blocked or empty.", extra={"response": response})
            finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
            safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
            # NOTE(review): the SDK typically returns finish_reason as an enum;
            # this string comparison may never match — verify against the
            # installed google-generativeai version.
            if finish_reason == 'SAFETY':
                raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
            else:
                raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")
        # Extract text - assumes the first part of the first candidate is the text response
        raw_text = response.text  # shortcut for response.candidates[0].content.parts[0].text
        logger.info("Received raw text from Gemini.")
        # logger.debug(f"Gemini Raw Text:\n{raw_text}")  # Optional: log full response text
        # Parse the text: one item per stripped, non-trivial line.
        items = []
        for line in raw_text.splitlines():
            cleaned_line = line.strip()
            # Ignore empty and single-character lines (likely noise)
            if cleaned_line and len(cleaned_line) > 1:
                # Add more sophisticated filtering if needed (e.g., regex, keyword check)
                items.append(cleaned_line)
        logger.info(f"Extracted {len(items)} potential items.")
        return items
    except google_exceptions.GoogleAPIError as e:
        logger.error(f"Gemini API Error: {e}", exc_info=True)
        # Re-raise specific Google API errors for the endpoint to handle (e.g. quota)
        raise e
    except Exception as e:
        # NOTE(review): this also catches the ValueErrors raised above and
        # re-wraps them with a different message — confirm that is intended.
        logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
        # Wrap in a generic ValueError so callers see a single failure type
        raise ValueError(f"Failed to process image with Gemini: {e}") from e
class GeminiOCRService:
    """Thin wrapper around a Gemini generative model for OCR item extraction."""

    def __init__(self):
        """Configure the Gemini client and model from settings.

        Raises:
            OCRServiceConfigError: if the client cannot be configured.
        """
        try:
            genai.configure(api_key=settings.GEMINI_API_KEY)
            self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME)
            self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS
            self.model.generation_config = settings.GEMINI_GENERATION_CONFIG
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}")
            # Chain the cause so the original failure is visible in tracebacks.
            raise OCRServiceConfigError() from e

    async def extract_items(self, image_data: bytes) -> List[str]:
        """
        Extract shopping list items from an image using Gemini Vision.

        Raises:
            OCRUnexpectedError: if Gemini returned an empty response.
            OCRQuotaExceededError: if the API quota appears exhausted.
            OCRServiceUnavailableError: for any other failure.
        """
        try:
            # JPEG is assumed here; callers validate/convert upstream.
            image_parts = [{"mime_type": "image/jpeg", "data": image_data}]
            response = await self.model.generate_content_async(
                contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
            )
            if not response.text:
                raise OCRUnexpectedError()
            # One item per non-empty line, dropping the prompt's "Example" echo.
            items = [
                item.strip()
                for item in response.text.split("\n")
                if item.strip() and not item.strip().startswith("Example")
            ]
            return items
        except OCRUnexpectedError:
            # Bug fix: this was previously swallowed by the generic handler
            # below and misreported as OCRServiceUnavailableError.
            raise
        except Exception as e:
            logger.error(f"Error during OCR extraction: {e}")
            if "quota" in str(e).lower():
                raise OCRQuotaExceededError() from e
            raise OCRServiceUnavailableError() from e
# be/Dockerfile
# Choose a suitable Python base image
FROM python:3.11-slim

# Set environment variables.
# Bug fix: with the legacy "ENV key value" form, an inline "# comment" becomes
# part of the VALUE (Dockerfile comments are only recognized at line start).
# Use the key=value form and keep comments on their own lines.
# Prevent python from writing .pyc files
ENV PYTHONDONTWRITEBYTECODE=1
# Keep stdout/stderr unbuffered
ENV PYTHONUNBUFFERED=1

# Set the working directory in the container
WORKDIR /app

# Install system dependencies if needed (e.g., for psycopg2 build)
# RUN apt-get update && apt-get install -y --no-install-recommends gcc build-essential libpq-dev && rm -rf /var/lib/apt/lists/*

# Upgrade pip first
RUN pip install --no-cache-dir --upgrade pip

# Copy only requirements first to leverage Docker layer caching
COPY requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code (app/, alembic.ini, etc.)
COPY . .

# Expose the port the app runs on
EXPOSE 8000

# Default production command (override in docker-compose for development).
# 'app.main:app' is resolved relative to WORKDIR /app.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
fastapi>=0.95.0
uvicorn[standard]>=0.20.0
sqlalchemy[asyncio]>=2.0.0 # Core ORM + Async support
asyncpg>=0.27.0 # Async PostgreSQL driver
psycopg2-binary>=2.9.0 # Often needed by Alembic even if app uses asyncpg
alembic>=1.9.0 # Database migrations
pydantic-settings>=2.0.0 # For loading settings from .env
python-dotenv>=1.0.0 # To load .env file for scripts/alembic
passlib[bcrypt]>=1.7.4
python-jose[cryptography]>=3.3.0
pydantic[email]
google-generativeai>=0.5.0
{
"name": "mitlist",
"version": "0.0.1",
"description": "mitlist pwa",
"productName": "mitlist",
"author": "Mohamad ",
"type": "module",
"private": true,
"scripts": {
"lint": "eslint -c ./eslint.config.js \"./src*/**/*.{ts,js,cjs,mjs,vue}\"",
"format": "prettier --write \"**/*.{js,ts,vue,scss,html,md,json}\" --ignore-path .gitignore",
"test": "echo \"No test specified\" && exit 0",
"dev": "quasar dev",
"build": "quasar build",
"postinstall": "quasar prepare"
},
"dependencies": {
"@quasar/extras": "^1.16.4",
"axios": "^1.2.1",
"pinia": "^3.0.1",
"quasar": "^2.16.0",
"register-service-worker": "^1.7.2",
"vue": "^3.4.18",
"vue-i18n": "^11.0.0",
"vue-router": "^4.0.12"
},
"devDependencies": {
"@eslint/js": "^9.14.0",
"@intlify/unplugin-vue-i18n": "^4.0.0",
"@quasar/app-vite": "^2.1.0",
"@types/node": "^20.5.9",
"@vue/eslint-config-prettier": "^10.1.0",
"@vue/eslint-config-typescript": "^14.4.0",
"autoprefixer": "^10.4.2",
"eslint": "^9.14.0",
"eslint-plugin-vue": "^9.30.0",
"globals": "^15.12.0",
"prettier": "^3.3.3",
"typescript": "~5.5.3",
"vite-plugin-checker": "^0.9.0",
"vue-tsc": "^2.0.29",
"workbox-build": "^7.3.0",
"workbox-cacheable-response": "^7.3.0",
"workbox-core": "^7.3.0",
"workbox-expiration": "^7.3.0",
"workbox-precaching": "^7.3.0",
"workbox-routing": "^7.3.0",
"workbox-strategies": "^7.3.0"
},
"engines": {
"node": "^28 || ^26 || ^24 || ^22 || ^20 || ^18",
"npm": ">= 6.13.4",
"yarn": ">= 1.21.1"
}
}
import { boot } from 'quasar/wrappers';
import axios from 'axios';
import { API_BASE_URL } from 'src/config/api-config';
// Shared axios instance for API calls; all requests default to JSON.
const api = axios.create({
  baseURL: API_BASE_URL,
  headers: { 'Content-Type': 'application/json' },
});

// Request interceptor: attach the persisted bearer token, if any.
api.interceptors.request.use(
  (config) => {
    const storedToken = localStorage.getItem('token');
    if (storedToken) {
      config.headers.Authorization = `Bearer ${storedToken}`;
    }
    return config;
  },
  (error) => Promise.reject(new Error(String(error)))
);
// Response interceptor: on a 401, attempt a one-shot token refresh and
// replay the original request with the new access token.
api.interceptors.response.use(
  (response) => response,
  async (error) => {
    const originalRequest = error.config;
    // Only attempt refresh once per request (guarded by _retry).
    if (error.response?.status === 401 && !originalRequest._retry) {
      originalRequest._retry = true;
      try {
        const refreshToken = localStorage.getItem('refreshToken');
        if (!refreshToken) {
          throw new Error('No refresh token available');
        }
        // Bug fix: use the bare axios client for the refresh call. Going
        // through `api` attached the expired access token via the request
        // interceptor, and a 401 from the refresh endpoint re-entered this
        // response interceptor (potential recursion).
        const response = await axios.post(`${API_BASE_URL}/api/v1/auth/refresh-token`, {
          refresh_token: refreshToken,
        });
        const { access_token } = response.data;
        localStorage.setItem('token', access_token);
        // Retry the original request with the fresh token.
        originalRequest.headers.Authorization = `Bearer ${access_token}`;
        return api(originalRequest);
      } catch (refreshError) {
        // Refresh failed: clear credentials and send the user to login.
        localStorage.removeItem('token');
        localStorage.removeItem('refreshToken');
        window.location.href = '/login';
        return Promise.reject(new Error(String(refreshError)));
      }
    }
    return Promise.reject(new Error(String(error)));
  }
);
// Quasar boot file entry: expose the raw axios library and the configured
// `api` instance as global properties ($axios / $api) for Options-API
// components; Composition-API code should import { api } directly.
export default boot(({ app }) => {
  app.config.globalProperties.$axios = axios;
  app.config.globalProperties.$api = api;
});
export { api };
Account Settings
Loading profile...
{{ error }}
Profile Information
Change Password
Notification Preferences
Email Notifications — Receive email notifications for important updates. List Updates — Get notified when lists are updated. Group Activities — Receive notifications for group activities.
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import (
User as UserModel,
Group as GroupModel,
List as ListModel,
Expense as ExpenseModel,
Item as ItemModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
    "/lists/{list_id}/cost-summary",
    response_model=ListCostSummary,
    summary="Get Cost Summary for a List",
    tags=["Costs"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
        status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
    }
)
async def get_list_cost_summary(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Retrieves a calculated cost summary for a specific list, detailing total costs,
    equal shares per user, and individual user balances based on their contributions.

    The user must have access to the list to view its cost summary.
    Costs are split among group members if the list belongs to a group, or just for
    the creator if it's a personal list. All users who added items with prices are
    included in the calculation.

    NOTE(review): when no Expense exists for the list yet, this GET creates one
    as a side effect — confirm a read endpoint mutating state is intended.
    """
    logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")
    # 1. Verify user has access to the target list (raises on failure).
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
        raise
    except ListNotFoundError as e:
        logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
        raise
    # 2. Load the list with its items (+ who added them), the group's members,
    #    and the creator — eagerly, to avoid lazy loads in this async context.
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
            selectinload(ListModel.creator)
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)
    # 3. Get or create an expense for this list
    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits))
    )
    db_expense = expense_result.scalars().first()
    if not db_expense:
        # Sum only items that carry a positive price.
        total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0"))
        if total_amount == Decimal("0"):
            # Nothing priced yet: return an empty summary without creating an expense.
            return ListCostSummary(
                list_id=db_list.id,
                list_name=db_list.name,
                total_list_cost=Decimal("0.00"),
                num_participating_users=0,
                equal_share_per_user=Decimal("0.00"),
                user_balances=[]
            )
        # Create expense with ITEM_BASED split type
        expense_in = ExpenseCreate(
            description=f"Cost summary for list {db_list.name}",
            total_amount=total_amount,
            list_id=list_id,
            split_type=SplitTypeEnum.ITEM_BASED,
            paid_by_user_id=current_user.id  # Use current user as payer for now
        )
        db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
    # 4. Calculate cost summary from items and expense splits
    participating_users = set()
    user_items_added_value = {}  # user_id -> total value of priced items they added
    total_list_cost = Decimal("0.00")
    # Every user who added a priced item participates in the split.
    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0") and item.added_by_user:
            participating_users.add(item.added_by_user)
            user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
            total_list_cost += item.price
    # Users referenced by expense splits participate too.
    # NOTE(review): split.user is not eagerly loaded by the expense query above —
    # confirm this attribute access does not trigger an async lazy load.
    for split in db_expense.splits:
        if split.user:
            participating_users.add(split.user)
    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=Decimal("0.00"),
            num_participating_users=0,
            equal_share_per_user=Decimal("0.00"),
            user_balances=[]
        )
    # Equal split rounded to cents; the cent-level rounding remainder is
    # assigned to exactly one user in the loop below.
    equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    remainder = total_list_cost - (equal_share_per_user * num_participating_users)
    user_balances = []
    first_user_processed = False
    # NOTE(review): participating_users is a set, so WHICH user absorbs the
    # remainder depends on iteration order and is effectively arbitrary.
    for user in participating_users:
        items_added = user_items_added_value.get(user.id, Decimal("0.00"))
        current_user_share = equal_share_per_user
        if not first_user_processed and remainder != Decimal("0"):
            current_user_share += remainder
            first_user_processed = True
        # balance > 0 means the user contributed more than their share.
        balance = items_added - current_user_share
        user_identifier = user.name if user.name else user.email
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user_identifier,
                items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
            )
        )
    # Stable, user-friendly ordering for the response.
    user_balances.sort(key=lambda x: x.user_identifier)
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        user_balances=user_balances
    )
@router.get(
    "/groups/{group_id}/balance-summary",
    response_model=GroupBalanceSummary,
    summary="Get Detailed Balance Summary for a Group",
    tags=["Costs", "Groups"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
        status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
    }
)
async def get_group_balance_summary(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Retrieves a detailed financial balance summary for all users within a specific group.

    It considers all expenses, their splits, and all settlements recorded for the group.
    The user must be a member of the group to view its balance summary.
    """
    logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
    # 1. Verify user is a member of the target group
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations))
        .where(GroupModel.id == group_id)
    )
    db_group_for_check = group_check.scalars().first()
    if not db_group_for_check:
        raise GroupNotFoundError(group_id)
    user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
    if not user_is_member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
    # 2. Load all expenses (with split users) and settlements (with both parties)
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()
    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(
            selectinload(SettlementModel.paid_by_user),
            selectinload(SettlementModel.paid_to_user)
        )
    )
    settlements = settlements_result.scalars().all()
    # 3. Seed one balance record per group member.
    # NOTE(review): assoc.user is accessed here, but the membership query above
    # only eager-loads member_associations (not .user) — confirm this does not
    # trigger an async lazy load.
    user_balances_data = {}
    for assoc in db_group_for_check.member_associations:
        if assoc.user:
            user_balances_data[assoc.user.id] = UserBalanceDetail(
                user_id=assoc.user.id,
                user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
            )
    # Expenses: credit the payer with what they fronted, debit each split owner.
    for expense in expenses:
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
    # Settlements: money moved from the payer to the recipient.
    for settlement in settlements:
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
    # Net = (expenses paid + settlements received) - (expense share + settlements paid);
    # all figures quantized to cents for the response.
    final_user_balances = []
    for user_id, data in user_balances_data.items():
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_received
        ) - (data.total_share_of_expenses + data.total_settlements_paid)
        data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        final_user_balances.append(data)
    # Sort by user identifier for a stable response order.
    final_user_balances.sort(key=lambda x: x.user_identifier)
    # Suggested transfers that would zero out the balances.
    # NOTE(review): calculate_suggested_settlements is not defined or imported
    # in the visible portion of this module — confirm it exists elsewhere.
    suggested_settlements = calculate_suggested_settlements(final_user_balances)
    return GroupBalanceSummary(
        group_id=db_group_for_check.id,
        group_name=db_group_for_check.name,
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )
# app/api/v1/endpoints/items.py
import logging
from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
# --- Import Models Correctly ---
from app.models import User as UserModel
from app.models import Item as ItemModel # <-- IMPORT Item and alias it
# --- End Import Models ---
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
from app.crud import item as crud_item
from app.crud import list as crud_list
from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
logger = logging.getLogger(__name__)
router = APIRouter()
# --- Helper Dependency for Item Permissions ---
# Now ItemModel is defined before being used as a type hint
async def get_item_and_verify_access(
    item_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user)
) -> ItemModel:
    """Dependency: load an item and verify the user can access its parent list.

    Raises:
        ItemNotFoundError: if no item with ``item_id`` exists.
        ListPermissionError: if the user cannot access the item's list.
    """
    item_db = await crud_item.get_item_by_id(db, item_id=item_id)
    if not item_db:
        raise ItemNotFoundError(item_id)
    # Check permission on the parent list
    try:
        await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id)
    except ListPermissionError as e:
        # Re-raise with a more specific action message, keeping the original
        # exception as the cause (the old code bound `e` but never used it).
        raise ListPermissionError(item_db.list_id, "access this item's list") from e
    return item_db
# --- Endpoints ---
@router.post(
    "/lists/{list_id}/items",  # Nested under lists
    response_model=ItemPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Add Item to List",
    tags=["Items"]
)
async def create_list_item(
    list_id: int,
    item_in: ItemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """Adds a new item to a specific list. User must have access to the list.

    Raises:
        ListPermissionError: if the user cannot add items to this list.
    """
    user_email = current_user.email  # Access email attribute before async operations
    logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}")
    # Verify user has access to the target list
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        # Re-raise with a more specific action message, preserving the cause
        # (the old code bound `e` but never used it).
        raise ListPermissionError(list_id, "add items to this list") from e
    created_item = await crud_item.create_item(
        db=db, item_in=item_in, list_id=list_id, user_id=current_user.id
    )
    logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {user_email}.")
    return created_item
@router.get(
    "/lists/{list_id}/items",  # Nested under lists
    response_model=PyList[ItemPublic],
    summary="List Items in List",
    tags=["Items"]
)
async def read_list_items(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
    # Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
):
    """Retrieves all items for a specific list if the user has access.

    Raises:
        ListPermissionError: if the user cannot view items in this list.
    """
    user_email = current_user.email  # Access email attribute before async operations
    logger.info(f"User {user_email} listing items for list {list_id}")
    # Verify user has access to the list
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        # Re-raise with a more specific action message, preserving the cause
        # (the old code bound `e` but never used it).
        raise ListPermissionError(list_id, "view items in this list") from e
    items = await crud_item.get_items_by_list_id(db=db, list_id=list_id)
    return items
@router.put(
    "/items/{item_id}",  # Operate directly on item ID
    response_model=ItemPublic,
    summary="Update Item",
    tags=["Items"],
    responses={
        status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
    }
)
async def update_item(
    item_id: int,
    item_in: ItemUpdate,
    item_db: ItemModel = Depends(get_item_and_verify_access),  # loads item + checks list access
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),  # needed for completed_by
):
    """Update an item's details (name, quantity, is_complete, price).

    Access to the item's list is enforced by the ``get_item_and_verify_access``
    dependency. The payload MUST carry the item's current ``version``; a stale
    version yields a 409 Conflict. ``completed_by_id`` is set or cleared from
    the ``is_complete`` flag.
    """
    user_email = current_user.email  # capture before any awaits
    logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}")
    try:
        refreshed = await crud_item.update_item(
            db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
        )
        logger.info(f"Item {item_id} updated successfully by user {user_email} to version {refreshed.version}.")
        return refreshed
    except ConflictError as e:
        # Optimistic-locking failure → 409 so the client can refresh.
        logger.warning(f"Conflict updating item {item_id} for user {user_email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
    except Exception as e:
        logger.error(f"Error updating item {item_id} for user {user_email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")
@router.delete(
    "/items/{item_id}", # Operate directly on item ID
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete Item",
    tags=["Items"],
    responses={
        status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
    }
)
async def delete_item(
    item_id: int, # Item ID from path
    expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
    item_db: ItemModel = Depends(get_item_and_verify_access), # Resolves the item and enforces list access
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user), # Only used for audit logging
):
    """
    Deletes an item. Access to the owning list is enforced by the
    `get_item_and_verify_access` dependency.

    Optimistic locking: when `expected_version` is supplied and differs from
    the item's current version, the delete is refused with 409 Conflict.
    """
    # Read the email eagerly so no lazy attribute load fires after an await.
    user_email = current_user.email
    logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
    stale = expected_version is not None and item_db.version != expected_version
    if stale:
        logger.warning(
            f"Conflict deleting item {item_id} for user {user_email}. "
            f"Expected version {expected_version}, actual version {item_db.version}."
        )
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
        )
    await crud_item.delete_item(db=db, item_db=item_db)
    # NOTE(review): reading item_db.version after the delete assumes the ORM
    # instance is not expired by the commit — confirm against crud_item.delete_item.
    logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {user_email}.")
    return Response(status_code=status.HTTP_204_NO_CONTENT)
# app/crud/group.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List
from sqlalchemy import func
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate
from app.models import UserRoleEnum # Import enum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError # Import GroupPermissionError
)
# --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
    """Creates a group and adds the creator as the owner.

    Runs in a single transaction: the group row and the creator's owner-role
    membership row are committed together, or not at all.

    Raises:
        DatabaseIntegrityError: on constraint violations.
        DatabaseConnectionError: on connectivity failures.
        DatabaseTransactionError: on any other database error.
    """
    try:
        async with db.begin():
            db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
            db.add(db_group)
            await db.flush()  # assigns db_group.id so the membership row can reference it
            db_user_group = UserGroupModel(
                user_id=creator_id,
                group_id=db_group.id,
                role=UserRoleEnum.owner
            )
            db.add(db_user_group)
            await db.flush()
            await db.refresh(db_group)
            return db_group
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to create group: {str(e)}")
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create group: {str(e)}")
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
    """Return every group the given user belongs to.

    Member associations are eagerly loaded to avoid lazy-load round trips.
    """
    stmt = (
        select(GroupModel)
        .join(UserGroupModel)
        .where(UserGroupModel.user_id == user_id)
        .options(selectinload(GroupModel.member_associations))
    )
    try:
        rows = await db.execute(stmt)
        return rows.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user groups: {str(e)}")
async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]:
    """Fetch one group by primary key, or None if it does not exist.

    Member associations and their user rows are eagerly loaded.
    """
    stmt = (
        select(GroupModel)
        .where(GroupModel.id == group_id)
        .options(selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user))
    )
    try:
        rows = await db.execute(stmt)
        return rows.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query group: {str(e)}")
async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Return True when the user has a membership row in the given group."""
    membership_stmt = (
        select(UserGroupModel.id)
        .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        .limit(1)  # existence check only — one row is enough
    )
    try:
        found = await db.execute(membership_stmt)
        return found.scalar_one_or_none() is not None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]:
    """Return the user's role in the group, or None when they are not a member."""
    try:
        row = await db.execute(
            select(UserGroupModel.role)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        )
        return row.scalar_one_or_none()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user role: {str(e)}")
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
    """Adds a user to a group if they aren't already a member.

    Returns:
        The new UserGroup association row, or None when the user was
        already a member (no change is made in that case).

    Raises:
        DatabaseIntegrityError / DatabaseConnectionError / DatabaseTransactionError:
            on the corresponding database failures.
    """
    try:
        async with db.begin():
            # Check-then-insert inside one transaction to avoid duplicates.
            existing = await db.execute(
                select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
            )
            if existing.scalar_one_or_none():
                return None
            db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
            db.add(db_user_group)
            await db.flush()
            await db.refresh(db_user_group)
            return db_user_group
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Removes a user from a group.

    Returns:
        True when a membership row was deleted, False when the user was
        not a member of the group.

    Raises:
        DatabaseConnectionError: on connectivity failures.
        DatabaseTransactionError: on any other database error.
    """
    # Fix: `delete` was used but never imported at module level in this file
    # (only `select`/`func` are), which made this function raise NameError.
    # Import it locally to keep the module's import block untouched.
    from sqlalchemy import delete
    try:
        async with db.begin():
            result = await db.execute(
                delete(UserGroupModel)
                .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
                .returning(UserGroupModel.id)  # lets us tell "deleted" from "not a member"
            )
            return result.scalar_one_or_none() is not None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
    """Return the number of users belonging to the given group."""
    count_stmt = select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id)
    try:
        row = await db.execute(count_stmt)
        return row.scalar_one()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to count group members: {str(e)}")
async def check_group_membership(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    action: str = "access this group"
) -> None:
    """Ensure `user_id` is a member of `group_id`; raise otherwise.

    Raises:
        GroupNotFoundError: if the group_id does not exist.
        GroupMembershipError: if the user_id is not a member of the group.
        DatabaseConnectionError / DatabaseQueryError: on database failures.
    """
    try:
        # Existence first, so a missing group surfaces as "not found"
        # rather than as a membership failure.
        if not await db.get(GroupModel, group_id):
            raise GroupNotFoundError(group_id)
        membership = await db.execute(
            select(UserGroupModel.id)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
            .limit(1)
        )
        if membership.scalar_one_or_none() is None:
            raise GroupMembershipError(group_id, action=action)
    except (GroupNotFoundError, GroupMembershipError):
        # Domain errors pass through untouched.
        raise
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
    return None
async def check_user_role_in_group(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    required_role: UserRoleEnum,
    action: str = "perform this action"
) -> None:
    """Ensure the user holds `required_role` (or a higher one) in the group.

    Raises:
        GroupNotFoundError: if the group_id does not exist.
        GroupMembershipError: if the user_id is not a member of the group.
        GroupPermissionError: if the user's role is below `required_role`.
    """
    # Membership check doubles as a group-existence check.
    await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")
    actual_role = await get_user_role_in_group(db, group_id, user_id)
    # Numeric ranks encode the hierarchy: owner outranks member.
    rank = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}
    if not actual_role or rank.get(actual_role, 0) < rank.get(required_role, 0):
        raise GroupPermissionError(
            group_id=group_id,
            action=f"{action} (requires at least '{required_role.value}' role)"
        )
    return None
# app/main.py
import logging
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
# Import database and models if needed for startup/shutdown events later
# from . import database, models
# --- Logging Setup ---
# Configure the root logger from settings (level name is resolved dynamically).
logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL),
    format=settings.LOG_FORMAT
)
logger = logging.getLogger(__name__)
# --- FastAPI App Instance ---
# Metadata (title, description, version) and tag docs come from app.core.api_config.
app = FastAPI(
    **API_METADATA,
    openapi_tags=API_TAGS
)
# --- CORS Middleware ---
# Define allowed origins. Be specific in production!
# Use ["*"] for wide open access during early development if needed,
# but restrict it as soon as possible.
# SvelteKit default dev port is 5173
origins = [
    "http://localhost:5174",
    "http://localhost:8000", # Allow requests from the API itself (e.g., Swagger UI)
    # Add your deployed frontend URL here later
    # "https://your-frontend-domain.com",
]
# NOTE(review): the `origins` list above is never referenced — the middleware
# below uses settings.CORS_ORIGINS instead. Confirm whether it can be removed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- End CORS Middleware ---
# --- Include API Routers ---
# All API endpoints will be prefixed with /api
app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers ---
# --- Root Endpoint (Optional - outside the main API structure) ---
@app.get("/", tags=["Root"])
async def read_root():
"""
Provides a simple welcome message at the root path.
Useful for basic reachability checks.
"""
logger.info("Root endpoint '/' accessed.")
return {"message": settings.ROOT_MESSAGE}
# --- End Root Endpoint ---
# --- Application Startup/Shutdown Events (Optional) ---
# @app.on_event("startup")
# async def startup_event():
# logger.info("Application startup: Connecting to database...")
# # You might perform initial checks or warm-up here
# # await database.engine.connect() # Example check (get_db handles sessions per request)
# logger.info("Application startup complete.")
# @app.on_event("shutdown")
# async def shutdown_event():
# logger.info("Application shutdown: Disconnecting from database...")
# # await database.engine.dispose() # Close connection pool
# logger.info("Application shutdown complete.")
# --- End Events ---
# --- Direct Run (for simple local testing if needed) ---
# It's better to use `uvicorn app.main:app --reload` from the terminal
# if __name__ == "__main__":
# logger.info("Starting Uvicorn server directly from main.py")
# uvicorn.run(app, host="0.0.0.0", port=8000)
# ------------------------------------------------------
Create New List
// API Version
export const API_VERSION = 'v1';
// API Base URL
// Falls back to the local dev server when VITE_API_URL is not set at build time.
export const API_BASE_URL = import.meta.env.VITE_API_URL || 'http://localhost:8000';
// API Endpoints
// Central registry of backend endpoint paths. Static routes are plain
// strings; parameterised routes are builder functions that interpolate
// their IDs into the path.
export const API_ENDPOINTS = {
  // Auth
  AUTH: {
    LOGIN: '/auth/login',
    SIGNUP: '/auth/signup',
    REFRESH_TOKEN: '/auth/refresh-token',
    LOGOUT: '/auth/logout',
    VERIFY_EMAIL: '/auth/verify-email',
    RESET_PASSWORD: '/auth/reset-password',
    FORGOT_PASSWORD: '/auth/forgot-password',
  },
  // Users
  USERS: {
    PROFILE: '/users/me',
    UPDATE_PROFILE: '/users/me',
    PASSWORD: '/users/password',
    AVATAR: '/users/avatar',
    SETTINGS: '/users/settings',
    NOTIFICATIONS: '/users/notifications',
    PREFERENCES: '/users/preferences',
  },
  // Lists
  LISTS: {
    BASE: '/lists',
    BY_ID: (id: string) => `/lists/${id}`,
    ITEMS: (listId: string) => `/lists/${listId}/items`,
    ITEM: (listId: string, itemId: string) => `/lists/${listId}/items/${itemId}`,
    SHARE: (listId: string) => `/lists/${listId}/share`,
    UNSHARE: (listId: string) => `/lists/${listId}/unshare`,
    COMPLETE: (listId: string) => `/lists/${listId}/complete`,
    REOPEN: (listId: string) => `/lists/${listId}/reopen`,
    ARCHIVE: (listId: string) => `/lists/${listId}/archive`,
    RESTORE: (listId: string) => `/lists/${listId}/restore`,
    DUPLICATE: (listId: string) => `/lists/${listId}/duplicate`,
    EXPORT: (listId: string) => `/lists/${listId}/export`,
    IMPORT: '/lists/import',
  },
  // Groups
  GROUPS: {
    BASE: '/groups',
    BY_ID: (id: string) => `/groups/${id}`,
    LISTS: (groupId: string) => `/groups/${groupId}/lists`,
    MEMBERS: (groupId: string) => `/groups/${groupId}/members`,
    MEMBER: (groupId: string, userId: string) => `/groups/${groupId}/members/${userId}`,
    LEAVE: (groupId: string) => `/groups/${groupId}/leave`,
    DELETE: (groupId: string) => `/groups/${groupId}`,
    SETTINGS: (groupId: string) => `/groups/${groupId}/settings`,
    ROLES: (groupId: string) => `/groups/${groupId}/roles`,
    ROLE: (groupId: string, roleId: string) => `/groups/${groupId}/roles/${roleId}`,
  },
  // Invites
  INVITES: {
    BASE: '/invites',
    BY_ID: (id: string) => `/invites/${id}`,
    ACCEPT: (id: string) => `/invites/${id}/accept`,
    DECLINE: (id: string) => `/invites/${id}/decline`,
    REVOKE: (id: string) => `/invites/${id}/revoke`,
    LIST: '/invites',
    PENDING: '/invites/pending',
    SENT: '/invites/sent',
  },
  // Items (for direct operations like update, get by ID)
  ITEMS: {
    BY_ID: (itemId: string) => `/items/${itemId}`,
  },
  // OCR
  OCR: {
    PROCESS: '/ocr/extract-items',
    STATUS: (jobId: string) => `/ocr/status/${jobId}`,
    RESULT: (jobId: string) => `/ocr/result/${jobId}`,
    BATCH: '/ocr/batch',
    CANCEL: (jobId: string) => `/ocr/cancel/${jobId}`,
    HISTORY: '/ocr/history',
  },
  // Costs
  COSTS: {
    BASE: '/costs',
    LIST_SUMMARY: (listId: string | number) => `/costs/lists/${listId}/cost-summary`,
    GROUP_BALANCE_SUMMARY: (groupId: string | number) => `/costs/groups/${groupId}/balance-summary`,
  },
  // Financials
  FINANCIALS: {
    EXPENSES: '/financials/expenses',
    EXPENSE: (id: string) => `/financials/expenses/${id}`,
    SETTLEMENTS: '/financials/settlements',
    SETTLEMENT: (id: string) => `/financials/settlements/${id}`,
    BALANCES: '/financials/balances',
    BALANCE: (userId: string) => `/financials/balances/${userId}`,
    REPORTS: '/financials/reports',
    REPORT: (id: string) => `/financials/reports/${id}`,
    CATEGORIES: '/financials/categories',
    CATEGORY: (id: string) => `/financials/categories/${id}`,
  },
  // Health
  HEALTH: {
    CHECK: '/health',
    VERSION: '/health/version',
    STATUS: '/health/status',
    METRICS: '/health/metrics',
    LOGS: '/health/logs',
  },
};
from fastapi import HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from typing import Optional
class ListNotFoundError(HTTPException):
    """404: no list exists with the requested ID."""

    def __init__(self, list_id: int) -> None:
        message = f"List {list_id} not found"
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class ListPermissionError(HTTPException):
    """403: the user may not perform `action` on the list."""

    def __init__(self, list_id: int, action: str = "access") -> None:
        message = f"You do not have permission to {action} list {list_id}"
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=message)
class ListCreatorRequiredError(HTTPException):
    """403: the action is reserved for the list's creator."""

    def __init__(self, list_id: int, action: str) -> None:
        message = f"Only the list creator can {action} list {list_id}"
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=message)
class GroupNotFoundError(HTTPException):
    """404: no group exists with the requested ID."""

    def __init__(self, group_id: int) -> None:
        message = f"Group {group_id} not found"
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class GroupPermissionError(HTTPException):
    """403: the user lacks the rights to perform `action` in the group."""

    def __init__(self, group_id: int, action: str) -> None:
        message = f"You do not have permission to {action} in group {group_id}"
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=message)
class GroupMembershipError(HTTPException):
    """403: the action requires membership in the group."""

    def __init__(self, group_id: int, action: str = "access") -> None:
        message = f"You must be a member of group {group_id} to {action}"
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=message)
class GroupOperationError(HTTPException):
    """500: a group operation failed; `detail` carries the reason."""

    def __init__(self, detail: str) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
class GroupValidationError(HTTPException):
    """400: a group operation was rejected as invalid; `detail` carries the reason."""

    def __init__(self, detail: str) -> None:
        super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
class ItemNotFoundError(HTTPException):
    """404: no item exists with the requested ID."""

    def __init__(self, item_id: int) -> None:
        message = f"Item {item_id} not found"
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class UserNotFoundError(HTTPException):
    """404: a user lookup failed, by numeric ID or by string identifier."""

    def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None) -> None:
        # Prefer the ID in the message; fall back to the identifier, then generic.
        if user_id:
            message = f"User with ID {user_id} not found."
        elif identifier:
            message = f"User with identifier '{identifier}' not found."
        else:
            message = "User not found."
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class InvalidOperationError(HTTPException):
    """Business-rule violation; defaults to 400 but the status is overridable."""

    def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST) -> None:
        super().__init__(status_code=status_code, detail=detail)
class DatabaseConnectionError(HTTPException):
    """503: the database could not be reached.

    Accepts an optional `detail` because the CRUD modules raise it with a
    contextual message (e.g. ``DatabaseConnectionError(f"...: {e}")``); the
    previous zero-argument signature made every such call site fail with a
    TypeError. Falls back to the configured generic message.
    """
    def __init__(self, detail: Optional[str] = None):
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=detail if detail is not None else settings.DB_CONNECTION_ERROR
        )
class DatabaseIntegrityError(HTTPException):
    """400: a database integrity constraint was violated.

    Accepts an optional `detail` because the CRUD modules raise it with a
    contextual message (e.g. ``DatabaseIntegrityError(f"...: {e}")``); the
    previous zero-argument signature made every such call site fail with a
    TypeError. Falls back to the configured generic message.
    """
    def __init__(self, detail: Optional[str] = None):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=detail if detail is not None else settings.DB_INTEGRITY_ERROR
        )
class DatabaseTransactionError(HTTPException):
    """500: a database transaction failed.

    NOTE: the default `detail` is evaluated once, at class-definition time.
    """

    def __init__(self, detail: str = settings.DB_TRANSACTION_ERROR) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
class DatabaseQueryError(HTTPException):
    """500: a database query failed.

    NOTE: the default `detail` is evaluated once, at class-definition time.
    """

    def __init__(self, detail: str = settings.DB_QUERY_ERROR) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
class OCRServiceUnavailableError(HTTPException):
    """503: the OCR backend is unreachable or down."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=settings.OCR_SERVICE_UNAVAILABLE)
class OCRServiceConfigError(HTTPException):
    """500: the OCR service configuration is invalid."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=settings.OCR_SERVICE_CONFIG_ERROR)
class OCRUnexpectedError(HTTPException):
    """500: an unexpected failure occurred inside the OCR service."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=settings.OCR_UNEXPECTED_ERROR)
class OCRQuotaExceededError(HTTPException):
    """429: the OCR service quota has been exhausted."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=settings.OCR_QUOTA_EXCEEDED)
class InvalidFileTypeError(HTTPException):
    """400: the uploaded file's type is not accepted for OCR."""

    def __init__(self) -> None:
        # Render the accepted types into the configured message template.
        accepted = ", ".join(settings.ALLOWED_IMAGE_TYPES)
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=settings.OCR_INVALID_FILE_TYPE.format(types=accepted)
        )
class FileTooLargeError(HTTPException):
    """400: the uploaded file exceeds the configured size limit."""

    def __init__(self) -> None:
        message = settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
        super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
class OCRProcessingError(HTTPException):
    """400: OCR failed to process the image; `detail` carries the reason."""

    def __init__(self, detail: str) -> None:
        message = settings.OCR_PROCESSING_ERROR.format(detail=detail)
        super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=message)
class EmailAlreadyRegisteredError(HTTPException):
    """400: signup attempted with an email that already has an account."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered.")
class UserCreationError(HTTPException):
    """500: persisting a new user account failed."""

    def __init__(self) -> None:
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred during user creation.")
class InviteNotFoundError(HTTPException):
    """404: no invite exists for the given code."""

    def __init__(self, invite_code: str) -> None:
        message = f"Invite code {invite_code} not found"
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class InviteExpiredError(HTTPException):
    """410: the invite existed but its validity window has passed."""

    def __init__(self, invite_code: str) -> None:
        message = f"Invite code {invite_code} has expired"
        super().__init__(status_code=status.HTTP_410_GONE, detail=message)
class InviteAlreadyUsedError(HTTPException):
    """410: the invite was already redeemed and cannot be reused."""

    def __init__(self, invite_code: str) -> None:
        message = f"Invite code {invite_code} has already been used"
        super().__init__(status_code=status.HTTP_410_GONE, detail=message)
class InviteCreationError(HTTPException):
    """500: generating an invite for the group failed."""

    def __init__(self, group_id: int) -> None:
        message = f"Failed to create invite for group {group_id}"
        super().__init__(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=message)
class ListStatusNotFoundError(HTTPException):
    """404: the status record for the list could not be retrieved."""

    def __init__(self, list_id: int) -> None:
        message = f"Status for list {list_id} not found"
        super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
class ConflictError(HTTPException):
    """409: optimistic-locking version conflict; `detail` carries the specifics."""

    def __init__(self, detail: str) -> None:
        super().__init__(status_code=status.HTTP_409_CONFLICT, detail=detail)
class InvalidCredentialsError(HTTPException):
    """401: login failed; includes a WWW-Authenticate-style challenge header."""

    def __init__(self) -> None:
        challenge = f'{settings.AUTH_HEADER_PREFIX} error="invalid_credentials"'
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_INVALID_CREDENTIALS,
            headers={settings.AUTH_HEADER_NAME: challenge}
        )
class NotAuthenticatedError(HTTPException):
    """401: the request carried no valid authentication."""

    def __init__(self) -> None:
        challenge = f'{settings.AUTH_HEADER_PREFIX} error="not_authenticated"'
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_NOT_AUTHENTICATED,
            headers={settings.AUTH_HEADER_NAME: challenge}
        )
class JWTError(HTTPException):
    """401: the JWT failed validation; `error` describes the failure."""

    def __init__(self, error: str) -> None:
        challenge = f'{settings.AUTH_HEADER_PREFIX} error="invalid_token"'
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.JWT_ERROR.format(error=error),
            headers={settings.AUTH_HEADER_NAME: challenge}
        )
class JWTUnexpectedError(HTTPException):
    """401: an unexpected error occurred while handling the JWT."""

    def __init__(self, error: str) -> None:
        challenge = f'{settings.AUTH_HEADER_PREFIX} error="invalid_token"'
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
            headers={settings.AUTH_HEADER_NAME: challenge}
        )
# app/crud/list.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
from app.schemas.list import ListCreate, ListUpdate
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
    """Persist a new list owned by `creator_id` and return the fresh row."""
    new_list = ListModel(
        name=list_in.name,
        description=list_in.description,
        group_id=list_in.group_id,
        created_by_id=creator_id,
        is_complete=False
    )
    try:
        async with db.begin():
            db.add(new_list)
            await db.flush()
            await db.refresh(new_list)
            return new_list
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
    """Return all lists visible to the user, newest-updated first.

    A list is visible when the user created it outside any group, or when it
    belongs to a group the user is a member of.
    """
    try:
        membership_rows = await db.execute(
            select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
        )
        group_ids = membership_rows.scalars().all()
        visibility = [
            and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
        ]
        # Only emit the IN clause when the user actually has group memberships.
        if group_ids:
            visibility.append(ListModel.group_id.in_(group_ids))
        rows = await db.execute(
            select(ListModel).where(or_(*visibility)).order_by(ListModel.updated_at.desc())
        )
        return rows.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user lists: {str(e)}")
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
    """Fetch a single list by primary key, or None if absent.

    When `load_items` is True the list's items are eagerly loaded along with
    each item's added-by and completed-by users.
    """
    stmt = select(ListModel).where(ListModel.id == list_id)
    if load_items:
        item_loader = selectinload(ListModel.items).options(
            joinedload(ItemModel.added_by_user),
            joinedload(ItemModel.completed_by_user)
        )
        stmt = stmt.options(item_loader)
    try:
        rows = await db.execute(stmt)
        return rows.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list: {str(e)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
    """Updates an existing list record, checking for version conflicts.

    Implements optimistic locking: `list_in.version` must equal the current
    `list_db.version`, otherwise ConflictError (HTTP 409) is raised and the
    transaction is rolled back.

    Raises:
        ConflictError: version mismatch (stale client data).
        DatabaseIntegrityError / DatabaseConnectionError / DatabaseTransactionError:
            on the corresponding database failures.
    """
    try:
        async with db.begin():
            if list_db.version != list_in.version:
                raise ConflictError(
                    f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
                    f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
                )
            # `version` is managed server-side; never copied from the payload.
            update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
            for key, value in update_data.items():
                setattr(list_db, key, value)
            list_db.version += 1
            db.add(list_db)
            await db.flush()
            await db.refresh(list_db)
            return list_db
    except IntegrityError as e:
        await db.rollback()
        raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
    except OperationalError as e:
        await db.rollback()
        raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
    except ConflictError:
        await db.rollback()
        raise
    except SQLAlchemyError as e:
        await db.rollback()
        raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
    """Delete a list row.

    Optimistic-version checks are the caller's responsibility (the API
    endpoint performs them before invoking this).
    """
    try:
        async with db.begin():
            await db.delete(list_db)
    except OperationalError as e:
        await db.rollback()
        raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
    except SQLAlchemyError as e:
        await db.rollback()
        raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
    return None
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
    """Fetch a list (items eagerly loaded) and verify the user may access it.

    Creators always pass. When `require_creator` is False, members of the
    list's group also pass; otherwise non-creators are rejected.

    Raises:
        ListNotFoundError, ListCreatorRequiredError, ListPermissionError,
        DatabaseConnectionError, DatabaseQueryError.
    """
    try:
        target = await get_list_by_id(db, list_id=list_id, load_items=True)
        if not target:
            raise ListNotFoundError(list_id)
        # The creator passes every check, including require_creator.
        if target.created_by_id == user_id:
            return target
        if require_creator:
            raise ListCreatorRequiredError(list_id, "access")
        if not target.group_id:
            # Personal list owned by someone else: no access.
            raise ListPermissionError(list_id)
        from app.crud.group import is_user_member  # local import avoids a circular dependency
        if not await is_user_member(db, group_id=target.group_id, user_id=user_id):
            raise ListPermissionError(list_id)
        return target
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}")
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
    """Return freshness metadata for a list: its own `updated_at`, the newest
    item's `updated_at` (None when the list is empty), and the item count.

    Raises:
        ListNotFoundError: when the list does not exist.
    """
    try:
        list_row = await db.execute(
            select(ListModel.updated_at).where(ListModel.id == list_id)
        )
        list_updated_at = list_row.scalar_one_or_none()
        if list_updated_at is None:
            raise ListNotFoundError(list_id)
        aggregate_stmt = (
            select(
                sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
                sql_func.count(ItemModel.id).label("item_count")
            )
            .where(ItemModel.list_id == list_id)
        )
        aggregate_row = (await db.execute(aggregate_stmt)).first()
        return ListStatus(
            list_updated_at=list_updated_at,
            latest_item_updated_at=aggregate_row.latest_item_updated_at if aggregate_row else None,
            item_count=aggregate_row.item_count if aggregate_row else 0
        )
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get list status: {str(e)}")
# app/models.py
import enum
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
Boolean,
Enum as SAEnum,
UniqueConstraint,
Index,
DDL,
event,
delete,
func,
text as sa_text,
Text, # <-- Add Text for description
Numeric # <-- Add Numeric for price
)
from sqlalchemy.orm import relationship, backref
from .database import Base
# --- Enums ---
class UserRoleEnum(enum.Enum):
    """Role a user can hold inside a group."""
    owner = "owner"    # full administrative rights over the group
    member = "member"  # regular participant
class SplitTypeEnum(enum.Enum):
    """How an expense is divided among the users involved."""
    EQUAL = "EQUAL"                  # Split equally among all involved users
    EXACT_AMOUNTS = "EXACT_AMOUNTS"  # Specific amounts for each user (defined in ExpenseSplit)
    PERCENTAGE = "PERCENTAGE"        # Percentage for each user (defined in ExpenseSplit)
    SHARES = "SHARES"                # Proportional to shares/units (defined in ExpenseSplit)
    ITEM_BASED = "ITEM_BASED"        # Derived directly from item prices and who added them
    # Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense)
# --- User Model ---
class User(Base):
    """A registered account.

    Hub of ownership: groups/lists/invites it created, items it added or
    completed, and both sides of the money ledger (expenses paid, splits
    owed, settlements made/received).
    """
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    email = Column(String, unique=True, index=True, nullable=False)
    # Hashed credential only; plaintext passwords are never stored.
    password_hash = Column(String, nullable=False)
    name = Column(String, index=True, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    # --- Relationships ---
    created_groups = relationship("Group", back_populates="creator")
    group_associations = relationship("UserGroup", back_populates="user", cascade="all, delete-orphan")
    created_invites = relationship("Invite", back_populates="creator")
    # --- NEW Relationships for Lists/Items ---
    created_lists = relationship("List", foreign_keys="List.created_by_id", back_populates="creator") # Link List.created_by_id -> User
    added_items = relationship("Item", foreign_keys="Item.added_by_id", back_populates="added_by_user") # Link Item.added_by_id -> User
    completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
    # --- End NEW Relationships ---
    # --- Relationships for Cost Splitting ---
    # NOTE(review): delete-orphan cascade here means deleting a user also deletes
    # the expenses they paid and the settlements they took part in -- confirm that
    # erasing financial history on account deletion is intended.
    expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
    expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
    settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
    settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
    # --- End Relationships for Cost Splitting ---
# --- Group Model ---
class Group(Base):
    """A sharing circle: owns member associations, invites, shared lists,
    and the group-scoped expenses/settlements."""
    __tablename__ = "groups"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True, nullable=False)
    created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    # --- Relationships ---
    creator = relationship("User", back_populates="created_groups")
    member_associations = relationship("UserGroup", back_populates="group", cascade="all, delete-orphan")
    invites = relationship("Invite", back_populates="group", cascade="all, delete-orphan")
    # --- NEW Relationship for Lists ---
    # Deleting a group cascades to its lists (and, via List, their items).
    lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group
    # --- End NEW Relationship ---
    # --- Relationships for Cost Splitting ---
    expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan")
    settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan")
    # --- End Relationships for Cost Splitting ---
# --- UserGroup Association Model ---
class UserGroup(Base):
    """Association row linking a user to a group with a role.

    The unique constraint guarantees a user appears at most once per group;
    DB-level CASCADE removes memberships when either side is deleted.
    """
    __tablename__ = "user_groups"
    __table_args__ = (UniqueConstraint('user_id', 'group_id', name='uq_user_group'),)
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False)
    # New members default to plain `member`; creators are expected to be `owner`.
    role = Column(SAEnum(UserRoleEnum, name="userroleenum", create_type=True), nullable=False, default=UserRoleEnum.member)
    joined_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    user = relationship("User", back_populates="group_associations")
    group = relationship("Group", back_populates="member_associations")
# --- Invite Model ---
class Invite(Base):
    """An invitation code for joining a group.

    Code uniqueness is enforced only among ACTIVE invites via the partial
    index below, so a code value can recur once an invite is deactivated.
    """
    __tablename__ = "invites"
    __table_args__ = (
        # PostgreSQL partial unique index: at most one active invite per code.
        Index('ix_invites_active_code', 'code', unique=True, postgresql_where=sa_text('is_active = true')),
    )
    id = Column(Integer, primary_key=True, index=True)
    # unique=False is deliberate (see partial index above).
    # NOTE(review): index=True here creates a second, non-unique index on `code`
    # in addition to ix_invites_active_code -- likely redundant; confirm before
    # the next migration.
    code = Column(String, unique=False, index=True, nullable=False, default=lambda: secrets.token_urlsafe(16))
    group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False)
    created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    # Invites expire 7 days after creation by default (computed client-side in UTC).
    expires_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) + timedelta(days=7))
    is_active = Column(Boolean, default=True, nullable=False)
    group = relationship("Group", back_populates="invites")
    creator = relationship("User", back_populates="created_invites")
# === NEW: List Model ===
class List(Base):
    """A shopping list; personal when group_id is NULL, shared otherwise.

    `version` supports optimistic concurrency control (clients must send the
    version they last saw; see the CRUD update path).
    """
    __tablename__ = "lists"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True, nullable=False)
    description = Column(Text, nullable=True)
    created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who created this list
    group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # Which group it belongs to (NULL if personal)
    is_complete = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    # onupdate=func.now() keeps this fresh on every UPDATE issued through the ORM.
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    version = Column(Integer, nullable=False, default=1, server_default='1')
    # --- Relationships ---
    creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
    group = relationship("Group", back_populates="lists") # Link to Group.lists
    items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes
    # --- Relationships for Cost Splitting ---
    expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan")
    # --- End Relationships for Cost Splitting ---
# === NEW: Item Model ===
class Item(Base):
    """A single entry on a List; tracks who added and who completed it.

    Carries its own `version` for optimistic locking, mirroring List.
    """
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) # Belongs to which list
    name = Column(String, index=True, nullable=False)
    # Free-form string rather than a number so "2 lbs" / "a bunch" are valid.
    quantity = Column(String, nullable=True) # Flexible quantity (e.g., "1", "2 lbs", "a bunch")
    is_complete = Column(Boolean, default=False, nullable=False)
    price = Column(Numeric(10, 2), nullable=True) # For cost splitting later (e.g., 12345678.99)
    added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who added this item
    completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    version = Column(Integer, nullable=False, default=1, server_default='1')
    # --- Relationships ---
    list = relationship("List", back_populates="items") # Link to List.items
    added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items
    completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items
    # --- Relationships for Cost Splitting ---
    # If an item directly results in an expense, or an expense can be tied to an item.
    expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses
    # --- End Relationships for Cost Splitting ---
# === NEW Models for Advanced Cost Splitting ===
class Expense(Base):
    """A cost to be divided among users, per `split_type`.

    Context is optional and layered: an expense may attach to an item, a
    list, and/or a group (all three FKs are nullable). The per-user shares
    live in ExpenseSplit rows.
    """
    __tablename__ = "expenses"
    id = Column(Integer, primary_key=True, index=True)
    description = Column(String, nullable=False)
    total_amount = Column(Numeric(10, 2), nullable=False)
    currency = Column(String, nullable=False, default="USD") # Consider making this an Enum too if few currencies
    expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False)
    # Foreign Keys
    list_id = Column(Integer, ForeignKey("lists.id"), nullable=True)
    group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # If not list-specific but group-specific
    item_id = Column(Integer, ForeignKey("items.id"), nullable=True) # If the expense is for a specific item
    paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    version = Column(Integer, nullable=False, default=1, server_default='1')
    # Relationships
    paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
    list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
    group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
    item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
    splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan")
    # Currently an empty tuple: the context check below is drafted but not enforced.
    __table_args__ = (
        # Example: Ensure either list_id or group_id is present if item_id is null
        # CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
    )
class ExpenseSplit(Base):
    """One user's share of an Expense.

    Exactly which value field is meaningful depends on the parent expense's
    split_type: owed_amount (EQUAL / EXACT_AMOUNTS), share_percentage
    (PERCENTAGE) or share_units (SHARES).
    """
    __tablename__ = "expense_splits"
    # At most one split row per (expense, user) pair.
    __table_args__ = (UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),)
    id = Column(Integer, primary_key=True, index=True)
    expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False)
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    owed_amount = Column(Numeric(10, 2), nullable=False) # For EQUAL or EXACT_AMOUNTS
    # For PERCENTAGE split (value from 0.00 to 100.00)
    share_percentage = Column(Numeric(5, 2), nullable=True)
    # For SHARES split (e.g., user A has 2 shares, user B has 3 shares)
    share_units = Column(Integer, nullable=True)
    # is_settled might be better tracked via actual Settlement records or a reconciliation process
    # is_settled = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    # NOTE(review): unlike List/Item/Expense/Settlement this model has no
    # `version` column -- confirm splits are never updated concurrently.
    # Relationships
    expense = relationship("Expense", back_populates="splits")
    user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits")
class Settlement(Base):
    """A direct payment from one user to another, recorded within a group,
    used to square up balances produced by expenses."""
    __tablename__ = "settlements"
    id = Column(Integer, primary_key=True, index=True)
    group_id = Column(Integer, ForeignKey("groups.id"), nullable=False) # Settlements usually within a group
    paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    amount = Column(Numeric(10, 2), nullable=False)
    settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    description = Column(Text, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    version = Column(Integer, nullable=False, default=1, server_default='1')
    # Relationships
    group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
    payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
    payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
    # Currently an empty tuple: payer != payee is drafted below but not enforced at DB level.
    __table_args__ = (
        # Ensure payer and payee are different users
        # CheckConstraint('paid_by_user_id <> paid_to_user_id', name='chk_settlement_payer_ne_payee'),
    )
# Potential future: PaymentMethod model, etc.
Your Groups
{{ group.name }}You are not a member of any groups yet.
Create New Group
{{ pageTitle }}
Loading lists...
{{ error }}
{{ noListsMessage }}
{{ list.name }}{{ list.description || 'No description' }} Personal List
Group List (ID: {{ list.group_id }})
Updated: {{ new Date(list.updated_at).toLocaleDateString() }}
# app/crud/item.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from datetime import datetime, timezone
from app.models import Item as ItemModel
from app.schemas.item import ItemCreate, ItemUpdate
from app.core.exceptions import (
ItemNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
    """Creates a new item record for a specific list.

    Args:
        db: Active async session.
        item_in: Validated payload (name, quantity).
        list_id: Parent list the item is attached to.
        user_id: Recorded as the item's creator (added_by_id).

    Returns:
        The persisted, refreshed ItemModel (id and timestamps populated).

    Raises:
        DatabaseIntegrityError: Constraint violation (e.g. bad foreign key).
        DatabaseConnectionError: Connection-level failure.
        DatabaseTransactionError: Any other SQLAlchemy failure.
    """
    try:
        db_item = ItemModel(
            name=item_in.name,
            quantity=item_in.quantity,
            list_id=list_id,
            added_by_id=user_id,
            is_complete=False # Default on creation
            # version is implicitly set to 1 by model default
        )
        db.add(db_item)
        # flush + refresh loads server-generated values (id, created_at)
        # while the transaction is still open.
        await db.flush()
        await db.refresh(db_item)
        await db.commit() # Explicitly commit here
        return db_item
    except IntegrityError as e:
        await db.rollback() # Rollback on integrity error
        raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
    except OperationalError as e:
        await db.rollback() # Rollback on operational error
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        await db.rollback() # Rollback on other SQLAlchemy errors
        raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
    except Exception as e: # Catch any other exception and attempt rollback
        await db.rollback()
        raise # Re-raise the original exception
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
    """Return every item belonging to ``list_id``, oldest first.

    Raises:
        DatabaseConnectionError: Connection-level failure.
        DatabaseQueryError: Any other SQLAlchemy failure.
    """
    stmt = (
        select(ItemModel)
        .where(ItemModel.list_id == list_id)
        .order_by(ItemModel.created_at.asc())  # ascending = insertion order
    )
    try:
        rows = await db.execute(stmt)
        return rows.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query items: {str(e)}")
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
    """Fetch a single item by primary key; None when no row matches.

    Raises:
        DatabaseConnectionError: Connection-level failure.
        DatabaseQueryError: Any other SQLAlchemy failure.
    """
    try:
        rows = await db.execute(select(ItemModel).where(ItemModel.id == item_id))
        return rows.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query item: {str(e)}")
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
    """Updates an existing item record, checking for version conflicts.

    Implements optimistic locking: the caller's payload must carry the
    version it last saw; a mismatch raises ConflictError without writing.

    Args:
        db: Active async session.
        item_db: The currently-loaded item row to mutate.
        item_in: Partial update payload (only set fields are applied).
        user_id: User performing the update; recorded as completer when the
            item is newly marked complete.

    Raises:
        ConflictError: Version in payload does not match the row's version.
        DatabaseIntegrityError: Constraint violation during the write.
        DatabaseConnectionError: Connection-level failure.
        DatabaseTransactionError: Any other SQLAlchemy failure.
    """
    try:
        # Check version conflict
        if item_db.version != item_in.version:
            raise ConflictError(
                f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
                f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
            )
        update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
        # Special handling for is_complete: completion stamps the acting user,
        # un-completion clears the stamp.
        if 'is_complete' in update_data:
            if update_data['is_complete'] is True:
                if item_db.completed_by_id is None: # Only set if not already completed by someone
                    update_data['completed_by_id'] = user_id
            else:
                update_data['completed_by_id'] = None # Clear if marked incomplete
        # Apply updates
        for key, value in update_data.items():
            setattr(item_db, key, value)
        item_db.version += 1 # Increment version
        db.add(item_db)
        await db.flush()
        await db.refresh(item_db)
        # Commit the transaction if not part of a larger transaction
        await db.commit()
        return item_db
    except IntegrityError as e:
        await db.rollback()
        raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
    except OperationalError as e:
        await db.rollback()
        raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
    except ConflictError: # Re-raise ConflictError
        await db.rollback()
        raise
    except SQLAlchemyError as e:
        await db.rollback()
        raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
    """Permanently remove an item row.

    The optimistic-lock version check is deliberately left to the caller
    (the API endpoint), not performed here.

    Raises:
        DatabaseConnectionError: Connection-level failure.
        DatabaseTransactionError: Any other SQLAlchemy failure.
    """
    try:
        await db.delete(item_db)
        await db.commit()
    except OperationalError as e:
        await db.rollback()
        raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
    except SQLAlchemyError as e:
        await db.rollback()
        raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
Group: {{ group.name }}
Invite Members
Invite code copied to clipboard!
Loading group details...
Group not found or an error occurred.
from fastapi import APIRouter
from app.api.v1.endpoints import health
from app.api.v1.endpoints import auth
from app.api.v1.endpoints import users
from app.api.v1.endpoints import groups
from app.api.v1.endpoints import invites
from app.api.v1.endpoints import lists
from app.api.v1.endpoints import items
from app.api.v1.endpoints import ocr
from app.api.v1.endpoints import costs
from app.api.v1.endpoints import financials
# Aggregate router for API v1: each feature module contributes its own router.
api_router_v1 = APIRouter()
api_router_v1.include_router(health.router)  # no prefix: health endpoints sit at the v1 root
api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"])
api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"])
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
# NOTE(review): items.router is mounted without a prefix -- presumably its routes
# declare full paths (e.g. under /lists/{list_id}/items); confirm in endpoints/items.py.
api_router_v1.include_router(items.router, tags=["Items"])
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"])
# NOTE(review): financials.router has neither prefix nor tags, unlike its
# siblings -- verify this inconsistency is intentional.
api_router_v1.include_router(financials.router)
# Add other v1 endpoint routers here later
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
# app/config.py
import os
from pydantic_settings import BaseSettings
from dotenv import load_dotenv
import logging
import secrets
load_dotenv()
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
    """Application configuration, read from the environment / .env file.

    Only DATABASE_URL, SECRET_KEY and GEMINI_API_KEY come from the
    environment in practice; everything else is an in-code default.
    Critical values are validated at import time just below this class.
    """
    DATABASE_URL: str | None = None
    GEMINI_API_KEY: str | None = None
    # --- JWT Settings ---
    SECRET_KEY: str # Must be set via environment variable
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
    REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days
    # --- OCR Settings ---
    MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
    ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
    # Prompt sent verbatim to the Gemini model for receipt/list extraction.
    OCR_ITEM_EXTRACTION_PROMPT: str = """
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
Add 2 underscores before and after the item name, if it is struck through.
If the image does not appear to be a shopping list or receipt, state that clearly.
Example output for a grocery list:
Milk
Eggs
Bread
__Apples__
Organic Bananas
"""
    # --- Gemini AI Settings ---
    GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
    GEMINI_SAFETY_SETTINGS: dict = {
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
        "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
    }
    GEMINI_GENERATION_CONFIG: dict = {
        "candidate_count": 1,
        "max_output_tokens": 2048,
        "temperature": 0.9,
        "top_p": 1,
        "top_k": 1
    }
    # --- API Settings ---
    API_PREFIX: str = "/api" # Base path for all API endpoints
    API_OPENAPI_URL: str = "/api/openapi.json"
    API_DOCS_URL: str = "/api/docs"
    API_REDOC_URL: str = "/api/redoc"
    CORS_ORIGINS: list[str] = [
        "http://localhost:5174",
        "http://localhost:8000",
        "http://localhost:9000",
        # Add your deployed frontend URL here later
        # "https://your-frontend-domain.com",
    ]
    # --- API Metadata ---
    API_TITLE: str = "Shared Lists API"
    API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
    API_VERSION: str = "0.1.0"
    ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
    # --- Logging Settings ---
    LOG_LEVEL: str = "INFO"
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    # --- Health Check Settings ---
    HEALTH_STATUS_OK: str = "ok"
    HEALTH_STATUS_ERROR: str = "error"
    # --- Auth Settings ---
    OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint
    TOKEN_TYPE: str = "bearer" # Default token type for OAuth2
    AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header
    AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header
    AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials"
    AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password"
    AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
    # --- HTTP Status Messages ---
    HTTP_400_DETAIL: str = "Bad Request"
    HTTP_401_DETAIL: str = "Unauthorized"
    HTTP_403_DETAIL: str = "Forbidden"
    HTTP_404_DETAIL: str = "Not Found"
    HTTP_422_DETAIL: str = "Unprocessable Entity"
    HTTP_429_DETAIL: str = "Too Many Requests"
    HTTP_500_DETAIL: str = "Internal Server Error"
    HTTP_503_DETAIL: str = "Service Unavailable"
    # --- Database Error Messages ---
    DB_CONNECTION_ERROR: str = "Database connection error"
    DB_INTEGRITY_ERROR: str = "Database integrity error"
    DB_TRANSACTION_ERROR: str = "Database transaction error"
    DB_QUERY_ERROR: str = "Database query error"
    class Config:
        # pydantic-settings v1-style config: read .env, ignore unknown vars.
        env_file = ".env"
        env_file_encoding = 'utf-8'
        extra = "ignore"
# Module-level singleton; the rest of the app imports this instance.
settings = Settings()
# Validation for critical settings -- fail fast at import time so
# misconfiguration surfaces at startup, not on the first request.
if settings.DATABASE_URL is None:
    raise ValueError("DATABASE_URL environment variable must be set.")
# Enforce secure secret key
if not settings.SECRET_KEY:
    raise ValueError("SECRET_KEY environment variable must be set. Generate a secure key using: openssl rand -hex 32")
# Validate secret key strength
if len(settings.SECRET_KEY) < 32:
    raise ValueError("SECRET_KEY must be at least 32 characters long for security")
if settings.GEMINI_API_KEY is None:
    # Soft failure: only Gemini/OCR features degrade; the rest of the API still runs.
    logger.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.")
else:
    # Optional: Log partial key for confirmation (avoid logging full key)
    logger.info(f"GEMINI_API_KEY loaded (starts with: {settings.GEMINI_API_KEY[:4]}...).")
Loading list details...
{{ error }}
{{ list.name }}
Add Items via OCR
Review Extracted Items
List Cost Summary
Loading cost summary...
{{ costSummaryError }}
Total List Cost: {{ formatCurrency(listCostSummary.total_list_cost) }}
Equal Share Per User: {{ formatCurrency(listCostSummary.equal_share_per_user) }}