This file is a merged representation of the entire codebase, combined into a single document by Repomix. This section contains a summary of this file. This file contains a packed representation of the entire repository's contents. It is designed to be easily consumable by AI systems for analysis, code review, or other automated processes. The content is organized as follows: 1. This summary section 2. Repository information 3. Directory structure 4. Repository files, each consisting of: - File path as an attribute - Full contents of the file - This file should be treated as read-only. Any changes should be made to the original repository files, not this packed version. - When processing this file, use the file path to distinguish between different files in the repository. - Be aware that this file may contain sensitive information. Handle it with the same level of security as you would the original repository. - Some files may have been excluded based on .gitignore rules and Repomix's configuration - Binary files are not included in this packed representation. 
Please refer to the Repository Structure section for a complete list of file paths, including binary files - Files matching patterns in .gitignore are excluded - Files matching default ignore patterns are excluded - Files are sorted by Git change count (files with more changes are at the bottom) .gitea/workflows/ci.yml be/.dockerignore be/.gitignore be/alembic.ini be/alembic/env.py be/alembic/README be/alembic/script.py.mako be/alembic/versions/bc37e9c7ae19_fresh_start.py be/app/api/api_router.py be/app/api/dependencies.py be/app/api/v1/api.py be/app/api/v1/endpoints/auth.py be/app/api/v1/endpoints/costs.py be/app/api/v1/endpoints/financials.py be/app/api/v1/endpoints/groups.py be/app/api/v1/endpoints/health.py be/app/api/v1/endpoints/invites.py be/app/api/v1/endpoints/items.py be/app/api/v1/endpoints/lists.py be/app/api/v1/endpoints/ocr.py be/app/api/v1/endpoints/users.py be/app/api/v1/test_auth.py be/app/api/v1/test_users.py be/app/config.py be/app/core/api_config.py be/app/core/exceptions.py be/app/core/gemini.py be/app/core/security.py be/app/crud/expense.py be/app/crud/group.py be/app/crud/invite.py be/app/crud/item.py be/app/crud/list.py be/app/crud/settlement.py be/app/crud/user.py be/app/database.py be/app/main.py be/app/models.py be/app/schemas/auth.py be/app/schemas/cost.py be/app/schemas/expense.py be/app/schemas/group.py be/app/schemas/health.py be/app/schemas/invite.py be/app/schemas/item.py be/app/schemas/list.py be/app/schemas/message.py be/app/schemas/ocr.py be/app/schemas/user.py be/Dockerfile be/requirements.txt be/tests/api/v1/endpoints/test_financials.py be/tests/core/__init__.py be/tests/core/test_exceptions.py be/tests/core/test_gemini.py be/tests/core/test_security.py be/tests/crud/__init__.py be/tests/crud/test_expense.py be/tests/crud/test_group.py be/tests/crud/test_invite.py be/tests/crud/test_item.py be/tests/crud/test_list.py be/tests/crud/test_settlement.py be/tests/crud/test_user.py be/Untitled-1.md docker-compose.yml fe/.editorconfig 
fe/.gitignore fe/.npmrc fe/.prettierrc.json fe/.vscode/extensions.json fe/.vscode/settings.json fe/eslint.config.js fe/index.html fe/package.json fe/postcss.config.js fe/public/icons/safari-pinned-tab.svg fe/quasar.config.ts fe/README.md fe/src-pwa/custom-service-worker.ts fe/src-pwa/manifest.json fe/src-pwa/pwa-env.d.ts fe/src-pwa/register-service-worker.ts fe/src-pwa/tsconfig.json fe/src/App.vue fe/src/assets/quasar-logo-vertical.svg fe/src/boot/axios.ts fe/src/boot/i18n.ts fe/src/components/ConflictResolutionDialog.vue fe/src/components/CreateListModal.vue fe/src/components/EssentialLink.vue fe/src/components/models.ts fe/src/components/OfflineIndicator.vue fe/src/config/api-config.ts fe/src/config/api.ts fe/src/css/app.scss fe/src/css/quasar.variables.scss fe/src/env.d.ts fe/src/i18n/en-US/index.ts fe/src/i18n/index.ts fe/src/layouts/AuthLayout.vue fe/src/layouts/MainLayout.vue fe/src/pages/AccountPage.vue fe/src/pages/ErrorNotFound.vue fe/src/pages/GroupDetailPage.vue fe/src/pages/GroupsPage.vue fe/src/pages/IndexPage.vue fe/src/pages/ListDetailPage.vue fe/src/pages/ListsPage.vue fe/src/pages/LoginPage.vue fe/src/pages/SignupPage.vue fe/src/router/index.ts fe/src/router/routes.ts fe/src/stores/auth.ts fe/src/stores/index.ts fe/src/stores/offline.ts fe/tsconfig.json This section contains the contents of the repository's files. # When you push to the develop branch or open/update a pull request targeting main, Gitea will: # Trigger the "CI Checks" workflow. # Execute the checks job on a runner. # Run each step sequentially. # If any of the linter/formatter check commands (black --check, ruff check, npm run lint) exit with a non-zero status code (indicating an error or check failure), the step and the entire job will fail. # You will see the status (success/failure) associated with your commit or pull request in the Gitea interface. 
name: CI Checks # Define triggers for the workflow on: push: branches: - develop # Run on pushes to the develop branch pull_request: branches: - main # Run on pull requests targeting the main branch jobs: checks: name: Linters and Formatters runs-on: ubuntu-latest # Use a standard Linux runner environment steps: - name: Checkout code uses: actions/checkout@v4 # Fetches the repository code # --- Backend Checks (Python/FastAPI) --- - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.11' # Match your project/Dockerfile version cache: 'pip' # Cache pip dependencies based on requirements.txt cache-dependency-path: 'be/requirements.txt' # Specify path for caching - name: Install Backend Dependencies and Tools working-directory: ./be # Run command within the 'be' directory run: | pip install --upgrade pip pip install -r requirements.txt pip install black ruff # Install formatters/linters for CI check - name: Run Black Formatter Check (Backend) working-directory: ./be run: black --check --diff . - name: Run Ruff Linter (Backend) working-directory: ./be run: ruff check . # --- Frontend Checks (Quasar/Vue on Node.js) --- - name: Set up Node.js uses: actions/setup-node@v4 with: node-version: '20.x' # Or specify your required Node.js version (e.g., 'lts/*') cache: 'npm' # Or 'pnpm' / 'yarn' depending on your package manager cache-dependency-path: 'fe/package-lock.json' # Adjust lockfile name if needed - name: Install Frontend Dependencies working-directory: ./fe # Run command within the 'fe' directory run: npm install # Or 'pnpm install' / 'yarn install' - name: Run ESLint and Prettier Check (Frontend) working-directory: ./fe # Assuming you have a 'lint' script in fe/package.json that runs both # Example package.json script: "lint": "prettier --check . && eslint ." run: npm run lint # If no combined script, run separately: # run: | # npm run format -- --check # Or 'npx prettier --check .' # npm run lint # Or 'npx eslint .' 
# - name: Run Frontend Type Check (Optional but recommended) # working-directory: ./fe # # Assuming you have a 'check' script: "check": "svelte-kit sync && svelte-check ..." # run: npm run check # - name: Run Placeholder Tests (Optional) # run: | # # Add commands to run backend tests if available # # Add commands to run frontend tests (e.g., npm test in ./fe) if available # echo "No tests configured yet." # Git files .git .gitignore # Virtual environment .venv venv/ env/ ENV/ *.env # Ignore local .env files within the backend directory if any # Python cache __pycache__/ *.py[cod] *$py.class # IDE files .idea/ .vscode/ # Test artifacts .pytest_cache/ htmlcov/ .coverage* # Other build/temp files *.egg-info/ dist/ build/ *.db # e.g., sqlite temp dbs # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # PEP 582; used by PDM, Flit and potentially other tools. 
__pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv/ env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static analysis results .pytype/ # alembic default temp file *.db # If using sqlite for alembic versions locally for instance # If you use alembic autogenerate, it might create temporary files # Depending on your DB, adjust if necessary # *.sql.tmp # IDE files .idea/ .vscode/ # OS generated files .DS_Store Thumbs.db # A generic, single database configuration. [alembic] # path to migration scripts # Use forward slashes (/) also on windows to provide an os agnostic path script_location = alembic # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # Uncomment the line below if you want the files to be prepended with date and time # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file # for all available tokens # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # sys.path path, will be prepended to sys.path if present. # defaults to the current working directory. prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. # If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. 
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements # string value is passed to ZoneInfo() # leave blank for localtime # timezone = # max length of characters to apply to the "slug" field # truncate_slug_length = 40 # set to 'true' to run the environment during # the 'revision' command, regardless of autogenerate # revision_environment = false # set to 'true' to allow .pyc and .pyo files without # a source .py file to be detected as revisions in the # versions/ directory # sourceless = false # version location specification; This defaults # to alembic/versions. When using multiple version # directories, initial revisions must be specified with --version-path. # The path separator used here should be the separator specified by "version_path_separator" below. # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions # version path separator; As mentioned above, this is the character used to split # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. # Valid values for version_path_separator are: # # version_path_separator = : # version_path_separator = ; # version_path_separator = space # version_path_separator = newline # # Use os.pathsep. Default configuration used for new projects. version_path_separator = os # set to 'true' to search source files recursively # in each "version_locations" directory # new in Alembic version 1.10 # recursive_version_locations = false # the output encoding used when revision files # are written from script.py.mako # output_encoding = utf-8 ; sqlalchemy.url = driver://user:pass@localhost/dbname ; (left unset here: alembic/env.py sets sqlalchemy.url at runtime from settings.DATABASE_URL) [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run # on newly generated revision scripts.
See the documentation for further # detail and examples # format using "black" - use the console_scripts runner, against the "black" entrypoint # hooks = black # black.type = console_scripts # black.entrypoint = black # black.options = -l 79 REVISION_SCRIPT_FILENAME # lint with attempts to fix using "ruff" - use the exec runner, execute a binary # hooks = ruff # ruff.type = exec # ruff.executable = %(here)s/.venv/bin/ruff # ruff.options = check --fix REVISION_SCRIPT_FILENAME # Logging configuration [loggers] keys = root,sqlalchemy,alembic [handlers] keys = console [formatters] keys = generic [logger_root] level = WARNING handlers = console qualname = [logger_sqlalchemy] level = WARNING handlers = qualname = sqlalchemy.engine [logger_alembic] level = INFO handlers = qualname = alembic [handler_console] class = StreamHandler args = (sys.stderr,) level = NOTSET formatter = generic [formatter_generic] format = %(levelname)-5.5s [%(name)s] %(message)s datefmt = %H:%M:%S from logging.config import fileConfig import os import sys from sqlalchemy import engine_from_config from sqlalchemy import pool from alembic import context # Ensure the 'app' directory is in the Python path # Adjust the path if your project structure is different sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))) # Import your app's Base and settings from app.models import Base # Import Base from your models module from app.config import settings # Import settings to get DATABASE_URL # this is the Alembic Config object, which provides # access to the values within the .ini file in use. 
config = context.config # Set the sqlalchemy.url from your application settings # Use a synchronous version of the URL for Alembic's operations sync_db_url = settings.DATABASE_URL.replace("+asyncpg", "") if settings.DATABASE_URL else None if not sync_db_url: raise ValueError("DATABASE_URL not found in settings for Alembic.") config.set_main_option('sqlalchemy.url', sync_db_url) # Interpret the config file for Python logging. # This line sets up loggers basically. if config.config_file_name is not None: fileConfig(config.config_file_name) # add your model's MetaData object here # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") # ... etc. def run_migrations_offline() -> None: """Run migrations in 'offline' mode. This configures the context with just a URL and not an Engine, though an Engine is acceptable here as well. By skipping the Engine creation we don't even need a DBAPI to be available. Calls to context.execute() here emit the given string to the script output. """ url = config.get_main_option("sqlalchemy.url") context.configure( url=url, target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, ) with context.begin_transaction(): context.run_migrations() def run_migrations_online() -> None: """Run migrations in 'online' mode. In this scenario we need to create an Engine and associate a connection with the context. 
""" connectable = engine_from_config( config.get_section(config.config_ini_section, {}), prefix="sqlalchemy.", poolclass=pool.NullPool, ) with connectable.connect() as connection: context.configure( connection=connection, target_metadata=target_metadata ) with context.begin_transaction(): context.run_migrations() if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() Generic single-database configuration. """${message} Revision ID: ${up_revision} Revises: ${down_revision | comma,n} Create Date: ${create_date} """ from typing import Sequence, Union from alembic import op import sqlalchemy as sa ${imports if imports else ""} # revision identifiers, used by Alembic. revision: str = ${repr(up_revision)} down_revision: Union[str, None] = ${repr(down_revision)} branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} def upgrade() -> None: """Upgrade schema.""" ${upgrades if upgrades else "pass"} def downgrade() -> None: """Downgrade schema.""" ${downgrades if downgrades else "pass"} """fresh start Revision ID: bc37e9c7ae19 Revises: Create Date: 2025-05-08 16:06:51.208542 """ from typing import Sequence, Union from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision: str = 'bc37e9c7ae19' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! 
### op.create_table('users', sa.Column('id', sa.Integer(), nullable=False), sa.Column('email', sa.String(), nullable=False), sa.Column('password_hash', sa.String(), nullable=False), sa.Column('name', sa.String(), nullable=True), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False) op.create_table('groups', sa.Column('id', sa.Integer(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('created_by_id', sa.Integer(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False) op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False) op.create_table('invites', sa.Column('id', sa.Integer(), nullable=False), sa.Column('code', sa.String(), nullable=False), sa.Column('group_id', sa.Integer(), nullable=False), sa.Column('created_by_id', sa.Integer(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), sa.Column('is_active', sa.Boolean(), nullable=False), sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'), sa.PrimaryKeyConstraint('id') ) op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true')) op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False) op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False) 
op.create_table('lists', sa.Column('id', sa.Integer(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('description', sa.Text(), nullable=True), sa.Column('created_by_id', sa.Integer(), nullable=False), sa.Column('group_id', sa.Integer(), nullable=True), sa.Column('is_complete', sa.Boolean(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('version', sa.Integer(), server_default='1', nullable=False), sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False) op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False) op.create_table('settlements', sa.Column('id', sa.Integer(), nullable=False), sa.Column('group_id', sa.Integer(), nullable=False), sa.Column('paid_by_user_id', sa.Integer(), nullable=False), sa.Column('paid_to_user_id', sa.Integer(), nullable=False), sa.Column('amount', sa.Numeric(precision=10, scale=2), nullable=False), sa.Column('settlement_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('description', sa.Text(), nullable=True), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('version', sa.Integer(), server_default='1', nullable=False), sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ), sa.ForeignKeyConstraint(['paid_to_user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_settlements_id'), 'settlements', ['id'], unique=False) op.create_table('user_groups', sa.Column('id', 
sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('group_id', sa.Integer(), nullable=False), sa.Column('role', sa.Enum('owner', 'member', name='userroleenum'), nullable=False), sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group') ) op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False) op.create_table('items', sa.Column('id', sa.Integer(), nullable=False), sa.Column('list_id', sa.Integer(), nullable=False), sa.Column('name', sa.String(), nullable=False), sa.Column('quantity', sa.String(), nullable=True), sa.Column('is_complete', sa.Boolean(), nullable=False), sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True), sa.Column('added_by_id', sa.Integer(), nullable=False), sa.Column('completed_by_id', sa.Integer(), nullable=True), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('version', sa.Integer(), server_default='1', nullable=False), sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ), sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ), sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False) op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False) op.create_table('expenses', sa.Column('id', sa.Integer(), nullable=False), sa.Column('description', sa.String(), nullable=False), sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False), sa.Column('currency', sa.String(), 
nullable=False), sa.Column('expense_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('split_type', sa.Enum('EQUAL', 'EXACT_AMOUNTS', 'PERCENTAGE', 'SHARES', 'ITEM_BASED', name='splittypeenum'), nullable=False), sa.Column('list_id', sa.Integer(), nullable=True), sa.Column('group_id', sa.Integer(), nullable=True), sa.Column('item_id', sa.Integer(), nullable=True), sa.Column('paid_by_user_id', sa.Integer(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('version', sa.Integer(), server_default='1', nullable=False), sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), sa.ForeignKeyConstraint(['item_id'], ['items.id'], ), sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ), sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('id') ) op.create_index(op.f('ix_expenses_id'), 'expenses', ['id'], unique=False) op.create_table('expense_splits', sa.Column('id', sa.Integer(), nullable=False), sa.Column('expense_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.Integer(), nullable=False), sa.Column('owed_amount', sa.Numeric(precision=10, scale=2), nullable=False), sa.Column('share_percentage', sa.Numeric(precision=5, scale=2), nullable=True), sa.Column('share_units', sa.Integer(), nullable=True), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.ForeignKeyConstraint(['expense_id'], ['expenses.id'], ondelete='CASCADE'), sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split') ) op.create_index(op.f('ix_expense_splits_id'), 'expense_splits', ['id'], 
unique=False) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### op.drop_index(op.f('ix_expense_splits_id'), table_name='expense_splits') op.drop_table('expense_splits') op.drop_index(op.f('ix_expenses_id'), table_name='expenses') op.drop_table('expenses') op.drop_index(op.f('ix_items_name'), table_name='items') op.drop_index(op.f('ix_items_id'), table_name='items') op.drop_table('items') op.drop_index(op.f('ix_user_groups_id'), table_name='user_groups') op.drop_table('user_groups') op.drop_index(op.f('ix_settlements_id'), table_name='settlements') op.drop_table('settlements') op.drop_index(op.f('ix_lists_name'), table_name='lists') op.drop_index(op.f('ix_lists_id'), table_name='lists') op.drop_table('lists') op.drop_index(op.f('ix_invites_id'), table_name='invites') op.drop_index(op.f('ix_invites_code'), table_name='invites') op.drop_index('ix_invites_active_code', table_name='invites', postgresql_where=sa.text('is_active = true')) op.drop_table('invites') op.drop_index(op.f('ix_groups_name'), table_name='groups') op.drop_index(op.f('ix_groups_id'), table_name='groups') op.drop_table('groups') op.drop_index(op.f('ix_users_name'), table_name='users') op.drop_index(op.f('ix_users_id'), table_name='users') op.drop_index(op.f('ix_users_email'), table_name='users') op.drop_table('users') # ### end Alembic commands ### # app/api/api_router.py from fastapi import APIRouter from app.api.v1.api import api_router_v1 # Import the v1 router api_router = APIRouter() # Include versioned routers here, adding the /api prefix api_router.include_router(api_router_v1, prefix="/v1") # Mounts v1 endpoints under /api/v1/... 
# Add other API versions later # e.g., api_router.include_router(api_router_v2, prefix="/v2") # app/api/v1/endpoints/financials.py import logging from fastapi import APIRouter, Depends, HTTPException, status, Query, Response from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from typing import List as PyList, Optional, Sequence from app.database import get_db from app.api.dependencies import get_current_user from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum from app.schemas.expense import ( ExpenseCreate, ExpensePublic, SettlementCreate, SettlementPublic, ExpenseUpdate, SettlementUpdate ) from app.crud import expense as crud_expense from app.crud import settlement as crud_settlement from app.crud import group as crud_group from app.crud import list as crud_list from app.core.exceptions import ( ListNotFoundError, GroupNotFoundError, UserNotFoundError, InvalidOperationError, GroupPermissionError, ListPermissionError, ItemNotFoundError, GroupMembershipError ) logger = logging.getLogger(__name__) router = APIRouter() # --- Helper for permissions --- async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"): try: await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_member=True) except ListPermissionError as e: logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}") raise ListPermissionError(list_id, action=action) except ListNotFoundError: raise # --- Expense Endpoints --- @router.post( "/expenses", response_model=ExpensePublic, status_code=status.HTTP_201_CREATED, summary="Create New Expense", tags=["Expenses"] ) async def create_new_expense( expense_in: ExpenseCreate, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): logger.info(f"User 
{current_user.email} creating expense: {expense_in.description}") effective_group_id = expense_in.group_id is_group_context = False if expense_in.list_id: # Check basic access to list (implies membership if list is in group) await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for") list_obj = await db.get(ListModel, expense_in.list_id) if not list_obj: raise ListNotFoundError(expense_in.list_id) if list_obj.group_id: if expense_in.group_id and list_obj.group_id != expense_in.group_id: raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.") effective_group_id = list_obj.group_id is_group_context = True # Expense is tied to a group via the list elif expense_in.group_id: raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.") # If list is personal, no group check needed yet, handled by payer check below. elif effective_group_id: # Only group_id provided for expense is_group_context = True # Ensure user is at least a member to create expense in group context await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for") else: # This case should ideally be caught by earlier checks if list_id was present but list was personal. # If somehow reached, it means no list_id and no group_id. 
raise InvalidOperationError("Expense must be linked to a list_id or group_id.") # Finalize expense payload with correct group_id if derived expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id}) # --- Granular Permission Check for Payer --- if expense_in_final.paid_by_user_id != current_user.id: logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}") # If creating expense paid by someone else, user MUST be owner IF in group context if is_group_context and effective_group_id: try: await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user") except GroupPermissionError as e: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}") else: # Cannot create expense paid by someone else for a personal list (no group context) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.") # If paying for self, basic list/group membership check above is sufficient. 
# NOTE(review): this range opens inside the expense-creation endpoint; its decorator,
# signature and the code that builds `expense_in_final` appear before this range.
# The try/except below is that endpoint's tail (it maps CRUD errors to HTTP codes).
    try:
        created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
        logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
        return created_expense
    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
        # Domain/validation failures raised by the CRUD layer -> 400 Bad Request.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except NotImplementedError as e:
        raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
    except Exception as e:
        # Anything else is logged with traceback and hidden behind a generic 500.
        logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")


@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
    expense_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Return one expense if the caller may see it. Access is checked in order:
    #   - expense linked to a list  -> check_list_access_for_financials for that list
    #   - else linked to a group    -> caller must be a member of that group
    #   - else (no list, no group)  -> only the payer may view it (403 otherwise)
    # Raises ItemNotFoundError when no expense has this ID.
    logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
    expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
    if not expense:
        raise ItemNotFoundError(item_id=expense_id)

    if expense.list_id:
        await check_list_access_for_financials(db, expense.list_id, current_user.id)
    elif expense.group_id:
        await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
    elif expense.paid_by_user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
    return expense


@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a List", tags=["Expenses", "Lists"])
async def list_list_expenses(
    list_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Paginated (skip/limit) expenses attached to one list. The access check runs
    # before any expense rows are fetched.
    logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
    await check_list_access_for_financials(db, list_id, current_user.id)
    expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
    return expenses


@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
async def list_group_expenses(
    group_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Paginated expenses for a group; only group members may list them.
    logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
    await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
    expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
    return expenses


@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
async def update_expense_details(
    expense_id: int,
    expense_in: ExpenseUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Updates an existing expense (description, currency, expense_date only).
    Requires the current version number for optimistic locking.
    User must have permission to modify the expense: be the payer, or hold the
    owner role in the expense's group.
    """
    logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
    expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
    if not expense_db:
        raise ItemNotFoundError(item_id=expense_id)

    # --- Granular Permission Check ---
    can_modify = False
    # 1. User paid for the expense
    if expense_db.paid_by_user_id == current_user.id:
        can_modify = True
    # 2. OR User is owner of the group the expense belongs to
    elif expense_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
            can_modify = True
            logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
        except GroupMembershipError: # User not even a member
            pass # Keep can_modify as False
        except GroupPermissionError: # User is member but not owner
            pass # Keep can_modify as False
        except GroupNotFoundError: # Group doesn't exist (data integrity issue)
            logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
            pass # Keep can_modify as False
    # Note: If expense is only linked to a personal list (no group), only payer can modify.

    if not can_modify:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")

    try:
        updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
        logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
        return updated_expense
    except InvalidOperationError as e:
        # Check if it's a version conflict (409) or other validation error (400)
        # NOTE(review): conflicts are detected by a substring match on the message;
        # a dedicated exception type from the CRUD layer would be more robust.
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")


@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
async def delete_expense_record(
    expense_id: int,
    expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Deletes an expense and its associated splits.
    Accepts an optional expected_version query parameter for optimistic locking.
    Deleting an expense that does not exist also returns 204.
    User must have permission to delete the expense: be the payer, or hold the
    owner role in the expense's group.
    """
    logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
    expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
    if not expense_db:
        # Return 204 even if not found, as the end state is achieved (item is gone)
        logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
        return Response(status_code=status.HTTP_204_NO_CONTENT)
        # Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404

    # --- Granular Permission Check ---
    can_delete = False
    # 1. User paid for the expense
    if expense_db.paid_by_user_id == current_user.id:
        can_delete = True
    # 2. OR User is owner of the group the expense belongs to
    elif expense_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
            can_delete = True
            logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
        except GroupMembershipError:
            pass
        except GroupPermissionError:
            pass
        except GroupNotFoundError:
            logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
            pass

    if not can_delete:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")

    try:
        await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
        logger.info(f"Expense ID {expense_id} deleted successfully.")
        # No need to return content on 204
    except InvalidOperationError as e:
        # Check if it's a version conflict (409) or other validation error (400)
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

    return Response(status_code=status.HTTP_204_NO_CONTENT)

# --- Settlement Endpoints ---
@router.post(
    "/settlements",
    response_model=SettlementPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Record New Settlement",
    tags=["Settlements"]
)
async def create_new_settlement(
    settlement_in: SettlementCreate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Record a settlement inside a group. The caller must be a group member; that
    # check sits outside the try below, so its exceptions propagate unmapped.
    # Payer and payee must also be members: membership failures for them -> 400,
    # a missing group -> 404.
    logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
    await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
    try:
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
    except GroupMembershipError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
    except GroupNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))

    try:
        created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
        logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
        return created_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")


@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
    settlement_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Group members may view any settlement of the group; a non-member may view it
    # only when they are its payer or payee. Raises ItemNotFoundError if missing.
    logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
    settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
    if not settlement:
        raise ItemNotFoundError(item_id=settlement_id)

    is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
    try:
        await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
    except GroupMembershipError:
        if not is_party_to_settlement:
            raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
        # Reached only for a non-member who is payer or payee.
        logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
    return settlement


@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
async def list_group_settlements(
    group_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    # Paginated settlements for a group; only group members may list them.
    logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
    await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
    settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
    return settlements


@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
async def update_settlement_details(
    settlement_id: int,
    settlement_in: SettlementUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Updates an existing settlement (description, settlement_date only).
    Requires the current version number for optimistic locking.
    User must have permission: be the payer or payee, or hold the owner role in
    the settlement's group.
    """
    logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
    settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
    if not settlement_db:
        raise ItemNotFoundError(item_id=settlement_id)

    # --- Granular Permission Check ---
    can_modify = False
    # 1. User is involved party (payer or payee)
    is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
    if is_party:
        can_modify = True
    # 2. OR User is owner of the group the settlement belongs to
    # Note: Settlements always have a group_id based on current model
    elif settlement_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
            can_modify = True
            logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
        except GroupMembershipError:
            pass
        except GroupPermissionError:
            pass
        except GroupNotFoundError:
            logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")
            pass

    if not can_modify:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")

    try:
        updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
        logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
        return updated_settlement
    except InvalidOperationError as e:
        # Messages mentioning "version" are treated as optimistic-lock conflicts (409).
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")


@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
async def delete_settlement_record(
    settlement_id: int,
    expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Deletes a settlement.
    Accepts an optional expected_version query parameter for optimistic locking.
    Deleting a settlement that does not exist also returns 204.
    User must have permission: be the payer or payee, or hold the owner role in
    the settlement's group.
    """
    logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
    settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
    if not settlement_db:
        # Idempotent delete: the desired end state (no such settlement) already holds.
        logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
        return Response(status_code=status.HTTP_204_NO_CONTENT)

    # --- Granular Permission Check ---
    can_delete = False
    # 1. User is involved party (payer or payee)
    is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
    if is_party:
        can_delete = True
    # 2. OR User is owner of the group the settlement belongs to
    elif settlement_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
            can_delete = True
            logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
        except GroupMembershipError:
            pass
        except GroupPermissionError:
            pass
        except GroupNotFoundError:
            logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")
            pass

    if not can_delete:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")

    try:
        await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
        logger.info(f"Settlement ID {settlement_id} deleted successfully.")
    except InvalidOperationError as e:
        # Messages mentioning "version" are treated as optimistic-lock conflicts (409).
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

    return Response(status_code=status.HTTP_204_NO_CONTENT)

# TODO (remaining from original list):
# (None - GET/POST/PUT/DELETE implemented for Expense/Settlement)

# app/api/v1/endpoints/users.py
import logging
from fastapi import APIRouter, Depends, HTTPException

from app.api.dependencies import get_current_user # Import the dependency
from app.schemas.user import UserPublic # Import the response schema
from app.models import User as UserModel # Import the DB model for type hinting

logger = logging.getLogger(__name__)
router = APIRouter()

@router.get(
    "/me",
    response_model=UserPublic, # Use the public schema to avoid exposing hash
    summary="Get Current User",
    description="Retrieves the details of the currently authenticated user.",
    tags=["Users"]
)
async def read_users_me(
    current_user: UserModel = Depends(get_current_user) # Apply the dependency
):
    """
    Returns the data for the user associated with the current valid access token.
    """
    logger.info(f"Fetching details for current user: {current_user.email}")
    # The 'current_user' object is the SQLAlchemy model instance returned by the dependency.
    # Pydantic's response_model will automatically convert it using UserPublic schema.
    return current_user

# Add other user-related endpoints here later (e.g., update user, list users (admin))

# Example: be/tests/api/v1/test_auth.py
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import verify_password
from app.crud.user import get_user_by_email
from app.schemas.user import UserPublic # Import for response validation
from app.schemas.auth import Token # Import for response validation

# Module-level marker: every test in this module is run as an asyncio test.
pytestmark = pytest.mark.asyncio

async def test_signup_success(client: AsyncClient, db: AsyncSession):
    # Signup returns 201 with public fields only, and persists a user whose
    # stored hash verifies against the submitted password.
    email = "testsignup@example.com"
    password = "testpassword123"
    response = await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": password, "name": "Test Signup"},
    )
    assert response.status_code == 201
    data = response.json()
    assert data["email"] == email
    assert data["name"] == "Test Signup"
    assert "id" in data
    assert "created_at" in data
    # Verify password hash is NOT returned
    assert "password" not in data
    assert "hashed_password" not in data

    # Verify user exists in DB
    user_db = await get_user_by_email(db, email=email)
    assert user_db is not None
    assert user_db.email == email
    assert verify_password(password, user_db.hashed_password)


async def test_signup_email_exists(client: AsyncClient, db: AsyncSession):
    # Create user first
    # NOTE(review): the first signup's status code is not asserted, so this test
    # would still reach the duplicate check if the first call failed.
    email = "testexists@example.com"
    password = "testpassword123"
    await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": password},
    )

    # Attempt signup again with same email
    response = await client.post(
        "/api/v1/auth/signup",
        json={"email": email, "password": "anotherpassword"},
    )
    assert response.status_code == 400
    assert "Email already registered" in response.json()["detail"]


async def test_login_success(client: AsyncClient, db: AsyncSession):
    email = "testlogin@example.com"
    password = "testpassword123"
    # Create user first via signup
    signup_res = await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": password}
    )
    assert signup_res.status_code == 201

    # Attempt login
    login_payload = {"username": email, "password": password}
    response = await client.post("/api/v1/auth/login", data=login_payload) # Use data for form encoding
    assert response.status_code == 200
    data = response.json()
    assert "access_token" in data
    assert data["token_type"] == "bearer"
    # Optionally verify the token itself here using verify_access_token


async def test_login_wrong_password(client: AsyncClient, db: AsyncSession):
    # A wrong password yields 401 with a Bearer WWW-Authenticate challenge.
    email = "testloginwrong@example.com"
    password = "testpassword123"
    await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": password}
    )

    login_payload = {"username": email, "password": "wrongpassword"}
    response = await client.post("/api/v1/auth/login", data=login_payload)
    assert response.status_code == 401
    assert "Incorrect email or password" in response.json()["detail"]
    assert "WWW-Authenticate" in response.headers
    assert response.headers["WWW-Authenticate"] == "Bearer"

async def test_login_user_not_found(client: AsyncClient, db: AsyncSession):
    # An unknown user gets the same 401 message as a wrong password.
    login_payload = {"username": "nosuchuser@example.com", "password": "anypassword"}
    response = await client.post("/api/v1/auth/login", data=login_payload)
    assert response.status_code == 401
    assert "Incorrect email or password" in response.json()["detail"]

# Example: be/tests/api/v1/test_users.py
import pytest
from httpx import AsyncClient

from app.schemas.user import UserPublic # For response validation
from app.core.security import create_access_token

pytestmark = pytest.mark.asyncio

# Helper function to get a valid token
async def get_auth_headers(client: AsyncClient, email: str, password: str) -> dict:
    """Logs in a user and returns authorization headers."""
    # The login endpoint takes form-encoded OAuth2 fields ("username"/"password").
    login_payload = {"username": email, "password": password}
    response = await client.post("/api/v1/auth/login", data=login_payload)
    response.raise_for_status() # Raise exception for non-2xx status
    token_data = response.json()
    return {"Authorization": f"Bearer {token_data['access_token']}"}
async def test_read_users_me_success(client: AsyncClient):
    # 1. Create user
    email = "testme@example.com"
    password = "password123"
    signup_res = await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": password, "name": "Test Me"}
    )
    assert signup_res.status_code == 201
    user_data = UserPublic(**signup_res.json()) # Validate signup response

    # 2. Get token
    headers = await get_auth_headers(client, email, password)

    # 3. Request /users/me
    response = await client.get("/api/v1/users/me", headers=headers)
    assert response.status_code == 200
    me_data = response.json()
    assert me_data["email"] == email
    assert me_data["name"] == "Test Me"
    assert me_data["id"] == user_data.id # Check ID matches signup
    assert "password" not in me_data
    assert "hashed_password" not in me_data

async def test_read_users_me_no_token(client: AsyncClient):
    response = await client.get("/api/v1/users/me") # No headers
    assert response.status_code == 401 # Handled by OAuth2PasswordBearer
    assert response.json()["detail"] == "Not authenticated" # Default detail from OAuth2PasswordBearer

async def test_read_users_me_invalid_token(client: AsyncClient):
    headers = {"Authorization": "Bearer invalid-token-string"}
    response = await client.get("/api/v1/users/me", headers=headers)
    assert response.status_code == 401
    assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency

async def test_read_users_me_expired_token(client: AsyncClient):
    """An already-expired token for a real, registered user is rejected with 401."""
    # Local import: this test module's header does not import timedelta, and the
    # original call below raised NameError before any request was made.
    from datetime import timedelta

    email = "testexpired@example.com"
    # Register the subject first so the 401 can only come from token expiry,
    # not from the user lookup failing for an unknown email.
    signup_res = await client.post(
        "/api/v1/auth/signup", json={"email": email, "password": "password123"}
    )
    assert signup_res.status_code == 201

    # Assume create_access_token allows timedelta override
    # A negative delta yields a token whose expiry is already in the past.
    expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
    headers = {"Authorization": f"Bearer {expired_token}"}
    response = await client.get("/api/v1/users/me", headers=headers)
    assert response.status_code == 401
    assert response.json()["detail"] == "Could not validate credentials"

# Add test case for valid token but user deleted from DB if
needed from typing import Dict, Any from app.config import settings # API Version API_VERSION = "v1" # API Prefix API_PREFIX = f"/api/{API_VERSION}" # API Endpoints class APIEndpoints: # Auth AUTH = { "LOGIN": "/auth/login", "SIGNUP": "/auth/signup", "REFRESH_TOKEN": "/auth/refresh-token", } # Users USERS = { "PROFILE": "/users/profile", "UPDATE_PROFILE": "/users/profile", } # Lists LISTS = { "BASE": "/lists", "BY_ID": "/lists/{id}", "ITEMS": "/lists/{list_id}/items", "ITEM": "/lists/{list_id}/items/{item_id}", } # Groups GROUPS = { "BASE": "/groups", "BY_ID": "/groups/{id}", "LISTS": "/groups/{group_id}/lists", "MEMBERS": "/groups/{group_id}/members", } # Invites INVITES = { "BASE": "/invites", "BY_ID": "/invites/{id}", "ACCEPT": "/invites/{id}/accept", "DECLINE": "/invites/{id}/decline", } # OCR OCR = { "PROCESS": "/ocr/process", } # Financials FINANCIALS = { "EXPENSES": "/financials/expenses", "EXPENSE": "/financials/expenses/{id}", "SETTLEMENTS": "/financials/settlements", "SETTLEMENT": "/financials/settlements/{id}", } # Health HEALTH = { "CHECK": "/health", } # API Metadata API_METADATA = { "title": settings.API_TITLE, "description": settings.API_DESCRIPTION, "version": settings.API_VERSION, "openapi_url": settings.API_OPENAPI_URL, "docs_url": settings.API_DOCS_URL, "redoc_url": settings.API_REDOC_URL, } # API Tags API_TAGS = [ {"name": "Authentication", "description": "Authentication and authorization endpoints"}, {"name": "Users", "description": "User management endpoints"}, {"name": "Lists", "description": "Shopping list management endpoints"}, {"name": "Groups", "description": "Group management endpoints"}, {"name": "Invites", "description": "Group invitation management endpoints"}, {"name": "OCR", "description": "Optical Character Recognition endpoints"}, {"name": "Financials", "description": "Financial management endpoints"}, {"name": "Health", "description": "Health check endpoints"}, ] # Helper function to get full API URL def get_api_url(endpoint: 
str, **kwargs) -> str: """ Get the full API URL for an endpoint. Args: endpoint: The endpoint path **kwargs: Path parameters to format the endpoint Returns: str: The full API URL """ formatted_endpoint = endpoint.format(**kwargs) return f"{API_PREFIX}{formatted_endpoint}" # app/crud/expense.py import logging # Add logging import from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict from datetime import datetime, timezone # Added timezone from app.models import ( Expense as ExpenseModel, ExpenseSplit as ExpenseSplitModel, User as UserModel, List as ListModel, Group as GroupModel, UserGroup as UserGroupModel, SplitTypeEnum, Item as ItemModel ) from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate from app.core.exceptions import ( # Using existing specific exceptions where possible ListNotFoundError, GroupNotFoundError, UserNotFoundError, InvalidOperationError # Import the new exception ) # Placeholder for InvalidOperationError if not defined in app.core.exceptions # This should be a proper HTTPException subclass if used in API layer # class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic # pass logger = logging.getLogger(__name__) # Initialize logger async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]: """ Determines the list of users an expense should be split amongst. Priority: Group members (if group_id), then List's group members or creator (if list_id). Fallback to only the payer if no other context yields users. 
""" users_to_split_with: PyList[UserModel] = [] processed_user_ids = set() async def _add_user(user: Optional[UserModel]): if user and user.id not in processed_user_ids: users_to_split_with.append(user) processed_user_ids.add(user.id) if expense_group_id: group_result = await db.execute( select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))) .where(GroupModel.id == expense_group_id) ) group = group_result.scalars().first() if not group: raise GroupNotFoundError(expense_group_id) for assoc in group.member_associations: await _add_user(assoc.user) elif expense_list_id: # Only if group_id was not primary context list_result = await db.execute( select(ListModel) .options( selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))), selectinload(ListModel.creator) ) .where(ListModel.id == expense_list_id) ) db_list = list_result.scalars().first() if not db_list: raise ListNotFoundError(expense_list_id) if db_list.group: for assoc in db_list.group.member_associations: await _add_user(assoc.user) elif db_list.creator: await _add_user(db_list.creator) if not users_to_split_with: payer_user = await db.get(UserModel, expense_paid_by_user_id) if not payer_user: # This should have been caught earlier if paid_by_user_id was validated before calling this helper raise UserNotFoundError(user_id=expense_paid_by_user_id) await _add_user(payer_user) if not users_to_split_with: # This should ideally not be reached if payer is always a fallback raise InvalidOperationError("Could not determine any users for splitting the expense.") return users_to_split_with async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel: """Creates a new expense and its associated splits. 
Args: db: Database session expense_in: Expense creation data current_user_id: ID of the user creating the expense Returns: The created expense with splits Raises: UserNotFoundError: If payer or split users don't exist ListNotFoundError: If specified list doesn't exist GroupNotFoundError: If specified group doesn't exist InvalidOperationError: For various validation failures """ # Helper function to round decimals consistently def round_money(amount: Decimal) -> Decimal: return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) # 1. Context Validation # Validate basic context requirements first if not expense_in.list_id and not expense_in.group_id: raise InvalidOperationError("Expense must be associated with a list or a group.") # 2. User Validation payer = await db.get(UserModel, expense_in.paid_by_user_id) if not payer: raise UserNotFoundError(user_id=expense_in.paid_by_user_id) # 3. List/Group Context Resolution final_group_id = await _resolve_expense_context(db, expense_in) # 4. Create the expense object db_expense = ExpenseModel( description=expense_in.description, total_amount=round_money(expense_in.total_amount), currency=expense_in.currency or "USD", expense_date=expense_in.expense_date or datetime.now(timezone.utc), split_type=expense_in.split_type, list_id=expense_in.list_id, group_id=final_group_id, item_id=expense_in.item_id, paid_by_user_id=expense_in.paid_by_user_id, created_by_user_id=current_user_id # Track who created this expense ) # 5. Generate splits based on split type splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money) # 6. 
Single transaction for expense and all splits try: db.add(db_expense) await db.flush() # Get expense ID without committing # Update all splits with the expense ID for split in splits_to_create: split.expense_id = db_expense.id db.add_all(splits_to_create) await db.commit() except Exception as e: await db.rollback() logger.error(f"Failed to save expense: {str(e)}", exc_info=True) raise InvalidOperationError(f"Failed to save expense: {str(e)}") # Refresh to get the splits relationship populated await db.refresh(db_expense, attribute_names=["splits"]) return db_expense async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]: """Resolves and validates the expense's context (list and group). Returns the final group_id for the expense after validation. """ final_group_id = expense_in.group_id # If list_id is provided, validate it and potentially derive group_id if expense_in.list_id: list_obj = await db.get(ListModel, expense_in.list_id) if not list_obj: raise ListNotFoundError(expense_in.list_id) # If list belongs to a group, verify consistency or inherit group_id if list_obj.group_id: if expense_in.group_id and list_obj.group_id != expense_in.group_id: raise InvalidOperationError( f"List {expense_in.list_id} belongs to group {list_obj.group_id}, " f"but expense was specified for group {expense_in.group_id}." 
) final_group_id = list_obj.group_id # Prioritize list's group # If only group_id is provided (no list_id), validate group_id elif final_group_id: group_obj = await db.get(GroupModel, final_group_id) if not group_obj: raise GroupNotFoundError(final_group_id) return final_group_id async def _generate_expense_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Generates appropriate expense splits based on split type.""" splits_to_create: PyList[ExpenseSplitModel] = [] # Create splits based on the split type if expense_in.split_type == SplitTypeEnum.EQUAL: splits_to_create = await _create_equal_splits( db, db_expense, expense_in, round_money ) elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS: splits_to_create = await _create_exact_amount_splits( db, db_expense, expense_in, round_money ) elif expense_in.split_type == SplitTypeEnum.PERCENTAGE: splits_to_create = await _create_percentage_splits( db, db_expense, expense_in, round_money ) elif expense_in.split_type == SplitTypeEnum.SHARES: splits_to_create = await _create_shares_splits( db, db_expense, expense_in, round_money ) elif expense_in.split_type == SplitTypeEnum.ITEM_BASED: splits_to_create = await _create_item_based_splits( db, db_expense, expense_in, round_money ) else: raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}") if not splits_to_create: raise InvalidOperationError("No expense splits were generated.") return splits_to_create async def _create_equal_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Creates equal splits among users.""" users_for_splitting = await get_users_for_splitting( db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id ) if not users_for_splitting: raise InvalidOperationError("No users found for EQUAL split.") 
num_users = len(users_for_splitting) amount_per_user = round_money(db_expense.total_amount / Decimal(num_users)) remainder = db_expense.total_amount - (amount_per_user * num_users) splits = [] for i, user in enumerate(users_for_splitting): split_amount = amount_per_user if i == 0 and remainder != Decimal('0'): split_amount = round_money(amount_per_user + remainder) splits.append(ExpenseSplitModel( user_id=user.id, owed_amount=split_amount )) return splits async def _create_exact_amount_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Creates splits with exact amounts.""" if not expense_in.splits_in: raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.") # Validate all users in splits exist await _validate_users_in_splits(db, expense_in.splits_in) current_total = Decimal("0.00") splits = [] for split_in in expense_in.splits_in: if split_in.owed_amount <= Decimal('0'): raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.") rounded_amount = round_money(split_in.owed_amount) current_total += rounded_amount splits.append(ExpenseSplitModel( user_id=split_in.user_id, owed_amount=rounded_amount )) if round_money(current_total) != db_expense.total_amount: raise InvalidOperationError( f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})." 
) return splits async def _create_percentage_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Creates splits based on percentages.""" if not expense_in.splits_in: raise InvalidOperationError("Splits data is required for PERCENTAGE split type.") # Validate all users in splits exist await _validate_users_in_splits(db, expense_in.splits_in) total_percentage = Decimal("0.00") current_total = Decimal("0.00") splits = [] for split_in in expense_in.splits_in: if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")): raise InvalidOperationError( f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}." ) total_percentage += split_in.share_percentage owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100"))) current_total += owed_amount splits.append(ExpenseSplitModel( user_id=split_in.user_id, owed_amount=owed_amount, share_percentage=split_in.share_percentage )) if round_money(total_percentage) != Decimal("100.00"): raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.") # Adjust for rounding differences if current_total != db_expense.total_amount and splits: diff = db_expense.total_amount - current_total splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) return splits async def _create_shares_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Creates splits based on shares.""" if not expense_in.splits_in: raise InvalidOperationError("Splits data is required for SHARES split type.") # Validate all users in splits exist await _validate_users_in_splits(db, expense_in.splits_in) # Calculate total shares total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0) if total_shares 
== 0: raise InvalidOperationError("Total shares cannot be zero for SHARES split.") splits = [] current_total = Decimal("0.00") for split_in in expense_in.splits_in: if not (split_in.share_units and split_in.share_units > 0): raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.") share_ratio = Decimal(split_in.share_units) / Decimal(total_shares) owed_amount = round_money(db_expense.total_amount * share_ratio) current_total += owed_amount splits.append(ExpenseSplitModel( user_id=split_in.user_id, owed_amount=owed_amount, share_units=split_in.share_units )) # Adjust for rounding differences if current_total != db_expense.total_amount and splits: diff = db_expense.total_amount - current_total splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) return splits async def _create_item_based_splits( db: AsyncSession, db_expense: ExpenseModel, expense_in: ExpenseCreate, round_money: Callable[[Decimal], Decimal] ) -> PyList[ExpenseSplitModel]: """Creates splits based on items in a shopping list.""" if not expense_in.list_id: raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.") if expense_in.splits_in: logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.") # Build query to fetch relevant items items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id) if expense_in.item_id: items_query = items_query.where(ItemModel.id == expense_in.item_id) else: items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))) # Load items with their adders items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user))) relevant_items = items_result.scalars().all() if not relevant_items: error_msg = ( f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}." if expense_in.item_id else f"List {expense_in.list_id} has no priced items to base the expense on." 
) raise InvalidOperationError(error_msg) # Aggregate owed amounts by user calculated_total = Decimal("0.00") user_owed_amounts = defaultdict(Decimal) processed_items = 0 for item in relevant_items: if item.price is None or item.price <= Decimal("0"): if expense_in.item_id: raise InvalidOperationError( f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense." ) continue if not item.added_by_user: logger.error(f"Item ID {item.id} is missing added_by_user relationship.") raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.") calculated_total += item.price user_owed_amounts[item.added_by_user.id] += item.price processed_items += 1 if processed_items == 0: raise InvalidOperationError( f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense." ) # Validate total matches calculated total if round_money(calculated_total) != db_expense.total_amount: raise InvalidOperationError( f"Expense total amount ({db_expense.total_amount}) does not match the " f"calculated total from item prices ({calculated_total})." 
) # Create splits based on aggregated amounts splits = [] for user_id, owed_amount in user_owed_amounts.items(): splits.append(ExpenseSplitModel( user_id=user_id, owed_amount=round_money(owed_amount) )) return splits async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None: """Validates that all users in the splits exist.""" user_ids_in_split = [s.user_id for s in splits_in] user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split))) found_user_ids = {row[0] for row in user_results} if len(found_user_ids) != len(user_ids_in_split): missing_user_ids = set(user_ids_in_split) - found_user_ids raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}") async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]: result = await db.execute( select(ExpenseModel) .options( selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)), selectinload(ExpenseModel.paid_by_user), selectinload(ExpenseModel.list), selectinload(ExpenseModel.group), selectinload(ExpenseModel.item) ) .where(ExpenseModel.id == expense_id) ) return result.scalars().first() async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]: result = await db.execute( select(ExpenseModel) .where(ExpenseModel.list_id == list_id) .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc()) .offset(skip).limit(limit) .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) # Also load user for each split ) return result.scalars().all() async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]: result = await db.execute( select(ExpenseModel) .where(ExpenseModel.group_id == group_id) .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc()) .offset(skip).limit(limit) 
.options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) ) return result.scalars().all() async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel: """ Updates an existing expense. Only allows updates to description, currency, and expense_date to avoid split complexities. Requires version matching for optimistic locking. """ if expense_db.version != expense_in.version: raise InvalidOperationError( f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. " f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.", # status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set ) update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data # Fields that are safe to update without affecting splits or core logic allowed_to_update = {"description", "currency", "expense_date"} updated_something = False for field, value in update_data.items(): if field in allowed_to_update: setattr(expense_db, field, value) updated_something = True else: # If any other field is present in the update payload, it's an invalid operation for this simple update raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.") if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update): # No actual updatable fields were provided in the payload, even if others (like version) were. # This could be a non-issue, or an indication of a misuse of the endpoint. # For now, if only version was sent, we still increment if it matched. 
pass # Or raise InvalidOperationError("No updatable fields provided.") expense_db.version += 1 expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp try: await db.commit() await db.refresh(expense_db) except Exception as e: await db.rollback() # Consider specific DB error types if needed raise InvalidOperationError(f"Failed to update expense: {str(e)}") return expense_db async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None: """ Deletes an expense. Requires version matching if expected_version is provided. Associated ExpenseSplits are cascade deleted by the database foreign key constraint. """ if expected_version is not None and expense_db.version != expected_version: raise InvalidOperationError( f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. " f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.", # status_code=status.HTTP_409_CONFLICT ) await db.delete(expense_db) try: await db.commit() except Exception as e: await db.rollback() raise InvalidOperationError(f"Failed to delete expense: {str(e)}") return None # Note: The InvalidOperationError is a simple ValueError placeholder. # For API endpoints, these should be translated to appropriate HTTPExceptions. # Ensure app.core.exceptions has proper HTTP error classes if needed. 
# app/crud/invite.py
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete  # Import delete statement
from typing import Optional

from app.models import Invite as InviteModel

# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5


async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
    """Create and commit a new invite code for a group.

    Args:
        db: Active async session; the invite is committed and refreshed.
        group_id: Group the invite grants access to.
        creator_id: User recorded as the invite's creator.
        expires_in_days: Lifetime of the invite, counted from now (UTC).

    Returns:
        The new invite, or ``None`` if every one of
        ``MAX_CODE_GENERATION_ATTEMPTS`` candidate codes collided with an
        existing active invite.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

    code: Optional[str] = None
    # Collisions are astronomically unlikely with 16 random bytes, but retry a few times anyway.
    for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
        candidate = secrets.token_urlsafe(16)
        # Only an *active* invite with the same code counts as a collision.
        existing = await db.execute(
            select(InviteModel.id)
            .where(InviteModel.code == candidate, InviteModel.is_active == True)  # noqa: E712 (SQL expression)
            .limit(1)
        )
        if existing.scalar_one_or_none() is None:
            code = candidate
            break

    if code is None:
        return None

    db_invite = InviteModel(
        code=code,
        group_id=group_id,
        created_by_id=creator_id,
        expires_at=expires_at,
        is_active=True,
    )
    db.add(db_invite)
    await db.commit()
    await db.refresh(db_invite)
    return db_invite


async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active invite with ``code`` whose ``expires_at`` is still in the future, else ``None``."""
    now = datetime.now(timezone.utc)
    result = await db.execute(
        select(InviteModel).where(
            InviteModel.code == code,
            InviteModel.is_active == True,  # noqa: E712 (SQL expression)
            InviteModel.expires_at > now,
        )
    )
    return result.scalars().first()


async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite as inactive (used), commit, and return the refreshed invite."""
    invite.is_active = False
    db.add(invite)  # Add to session to track change
    await db.commit()
    await db.refresh(invite)
    return invite

# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...


# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone

from app.models import (
    Settlement as SettlementModel,
    User as UserModel,
    Group as GroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError


async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Validate and commit a new settlement record.

    The amount is quantized to two decimal places (half-up), and
    ``settlement_date`` defaults to now (UTC).

    Raises:
        UserNotFoundError: payer or payee does not exist.
        InvalidOperationError: payer and payee are the same user, or the commit
            failed (the session is rolled back).
        GroupNotFoundError: the group does not exist.
    """
    # Validate Payer, Payee, and Group exist
    payer = await db.get(UserModel, settlement_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")

    payee = await db.get(UserModel, settlement_in.paid_to_user_id)
    if not payee:
        raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")

    if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
        raise InvalidOperationError("Payer and Payee cannot be the same user.")

    group = await db.get(GroupModel, settlement_in.group_id)
    if not group:
        raise GroupNotFoundError(settlement_in.group_id)

    # Optional: Check if current_user_id is part of the group or is one of the parties involved.
    # This is more of an API-level permission check, but could be added here if strict. For example:
    # if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
    #     is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
    #     if not is_in_group.first():
    #         raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")

    db_settlement = SettlementModel(
        group_id=settlement_in.group_id,
        paid_by_user_id=settlement_in.paid_by_user_id,
        paid_to_user_id=settlement_in.paid_to_user_id,
        amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
        description=settlement_in.description,
        # created_by_user_id = current_user_id  # Optional: Who recorded this settlement
    )
    db.add(db_settlement)
    try:
        await db.commit()
        await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to save settlement: {str(e)}") from e

    return db_settlement


async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    """Fetch one settlement with payer, payee and group eagerly loaded, or ``None``."""
    result = await db.execute(
        select(SettlementModel)
        .options(
            selectinload(SettlementModel.payer),
            selectinload(SettlementModel.payee),
            selectinload(SettlementModel.group),
        )
        .where(SettlementModel.id == settlement_id)
    )
    return result.scalars().first()


async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    """Page through a group's settlements, newest first, with payer and payee loaded."""
    result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
    )
    return result.scalars().all()


async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100
) -> Sequence[SettlementModel]:
    """Page through settlements where ``user_id`` is payer or payee, newest first.

    When ``group_id`` is given, results are restricted to that group.
    """
    query = select(SettlementModel).where(
        or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id)
    )
    # Apply every filter before ordering/paging so the query reads in SQL order.
    if group_id is not None:
        query = query.where(SettlementModel.group_id == group_id)
    query = (
        query
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
    )
    result = await db.execute(query)
    return result.scalars().all()


async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """Update a settlement's description and/or settlement_date.

    Uses optimistic locking: ``settlement_in.version`` must equal the stored
    version, which is then incremented.
    NOTE(review): this assumes SettlementModel has a ``version`` column; confirm.

    Raises:
        InvalidOperationError: version mismatch, a disallowed field (e.g.
            ``amount``) was supplied, or the commit failed (the session is
            rolled back).
    """
    if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
        raise InvalidOperationError(
            f"Settlement (ID: {settlement_db.id}) has been modified. "
            f"Your version does not match current version {settlement_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
    allowed_to_update = {"description", "settlement_date"}

    # Validate the whole payload before mutating, so a rejected request leaves no pending change.
    for field in update_data:
        if field not in allowed_to_update:
            raise InvalidOperationError(
                f"Field '{field}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed for settlements."
            )

    for field, value in update_data.items():
        setattr(settlement_db, field, value)

    settlement_db.version += 1
    settlement_db.updated_at = datetime.now(timezone.utc)
    try:
        await db.commit()
        await db.refresh(settlement_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update settlement: {str(e)}") from e
    return settlement_db


async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """Delete a settlement, optionally guarded by an optimistic-locking version check.

    Raises:
        InvalidOperationError: version mismatch (or no ``version`` attribute when
            a version is expected), or the commit failed (the session is rolled back).
    """
    if expected_version is not None:
        if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
            raise InvalidOperationError(
                f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                f"Expected version {expected_version} does not match current version. Please refresh.",
                # status_code=status.HTTP_409_CONFLICT
            )

    await db.delete(settlement_db)
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") from e

# TODO: consider how editing/deleting settlements should affect computed balances.


# app/database.py
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base

from app.config import settings

# Ensure DATABASE_URL is set before proceeding
if not settings.DATABASE_URL:
    raise ValueError("DATABASE_URL is not configured in settings.")

# Create the SQLAlchemy async engine.
# NOTE(review): echo=True logs every SQL statement; consider driving it from settings in production.
engine = create_async_engine(
    settings.DATABASE_URL,
    echo=True,  # Log SQL queries (useful for debugging)
    future=True,  # Use SQLAlchemy 2.0 style features
    pool_recycle=3600,  # Recycle connections after 1 hour to avoid stale connections
)

# Create a configured "Session" class.
# expire_on_commit=False prevents attributes from expiring after commit.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)

# Base class for our ORM models
Base = declarative_base()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields a request-scoped AsyncSession.

    If an exception propagates while the session is in use, the open
    transaction is rolled back and the exception re-raised. The ``async with``
    block closes the session afterwards. Commits are left to endpoint/CRUD code.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise


# app/schemas/expense.py
from pydantic import BaseModel, ConfigDict, validator
from typing import List, Optional
from decimal import Decimal
from datetime import datetime

# SplitTypeEnum is imported from app.models rather than redefined here, so the
# API schema and the SQLAlchemy enum cannot drift apart.
from app.models import SplitTypeEnum # Try importing directly # --- ExpenseSplit Schemas --- class ExpenseSplitBase(BaseModel): user_id: int owed_amount: Decimal share_percentage: Optional[Decimal] = None share_units: Optional[int] = None class ExpenseSplitCreate(ExpenseSplitBase): pass # All fields from base are needed for creation class ExpenseSplitPublic(ExpenseSplitBase): id: int expense_id: int # user: Optional[UserPublic] # If we want to nest user details created_at: datetime updated_at: datetime model_config = ConfigDict(from_attributes=True) # --- Expense Schemas --- class ExpenseBase(BaseModel): description: str total_amount: Decimal currency: Optional[str] = "USD" expense_date: Optional[datetime] = None split_type: SplitTypeEnum list_id: Optional[int] = None group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa item_id: Optional[int] = None paid_by_user_id: int class ExpenseCreate(ExpenseBase): # For EQUAL split, splits are generated. For others, they might be provided. # This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES, # then 'splits_in' should be provided. splits_in: Optional[List[ExpenseSplitCreate]] = None @validator('total_amount') def total_amount_must_be_positive(cls, v): if v <= Decimal('0'): raise ValueError('Total amount must be positive') return v # Basic validation: if list_id is None, group_id must be provided. # More complex cross-field validation might be needed. 
@validator('group_id', always=True) def check_list_or_group_id(cls, v, values): if values.get('list_id') is None and v is None: raise ValueError('Either list_id or group_id must be provided for an expense') return v class ExpenseUpdate(BaseModel): description: Optional[str] = None total_amount: Optional[Decimal] = None currency: Optional[str] = None expense_date: Optional[datetime] = None split_type: Optional[SplitTypeEnum] = None list_id: Optional[int] = None group_id: Optional[int] = None item_id: Optional[int] = None # paid_by_user_id is usually not updatable directly to maintain integrity. # Updating splits would be a more complex operation, potentially a separate endpoint or careful logic. version: int # For optimistic locking class ExpensePublic(ExpenseBase): id: int created_at: datetime updated_at: datetime version: int splits: List[ExpenseSplitPublic] = [] # paid_by_user: Optional[UserPublic] # If nesting user details # list: Optional[ListPublic] # If nesting list details # group: Optional[GroupPublic] # If nesting group details # item: Optional[ItemPublic] # If nesting item details model_config = ConfigDict(from_attributes=True) # --- Settlement Schemas --- class SettlementBase(BaseModel): group_id: int paid_by_user_id: int paid_to_user_id: int amount: Decimal settlement_date: Optional[datetime] = None description: Optional[str] = None class SettlementCreate(SettlementBase): @validator('amount') def amount_must_be_positive(cls, v): if v <= Decimal('0'): raise ValueError('Settlement amount must be positive') return v @validator('paid_to_user_id') def payer_and_payee_must_be_different(cls, v, values): if 'paid_by_user_id' in values and v == values['paid_by_user_id']: raise ValueError('Payer and payee cannot be the same user') return v class SettlementUpdate(BaseModel): amount: Optional[Decimal] = None settlement_date: Optional[datetime] = None description: Optional[str] = None # group_id, paid_by_user_id, paid_to_user_id are typically not updatable. 
version: int # For optimistic locking class SettlementPublic(SettlementBase): id: int created_at: datetime updated_at: datetime # payer: Optional[UserPublic] # payee: Optional[UserPublic] # group: Optional[GroupPublic] model_config = ConfigDict(from_attributes=True) # Placeholder for nested schemas (e.g., UserPublic) if needed # from app.schemas.user import UserPublic # from app.schemas.list import ListPublic # from app.schemas.group import GroupPublic # from app.schemas.item import ItemPublic # app/schemas/group.py from pydantic import BaseModel, ConfigDict from datetime import datetime from typing import Optional, List from .user import UserPublic # Import UserPublic to represent members # Properties to receive via API on creation class GroupCreate(BaseModel): name: str # Properties to return to client class GroupPublic(BaseModel): id: int name: str created_by_id: int created_at: datetime members: Optional[List[UserPublic]] = None # Include members only in detailed view model_config = ConfigDict(from_attributes=True) # Properties stored in DB (if needed, often GroupPublic is sufficient) # class GroupInDB(GroupPublic): # pass # app/schemas/invite.py from pydantic import BaseModel from datetime import datetime # Properties to receive when accepting an invite class InviteAccept(BaseModel): code: str # Properties to return when an invite is created class InviteCodePublic(BaseModel): code: str expires_at: datetime group_id: int # Properties for internal use/DB (optional) # class Invite(InviteCodePublic): # id: int # created_by_id: int # is_active: bool = True # model_config = ConfigDict(from_attributes=True) # app/schemas/message.py from pydantic import BaseModel class Message(BaseModel): detail: str # app/schemas/ocr.py from pydantic import BaseModel from typing import List class OcrExtractResponse(BaseModel): extracted_items: List[str] # A list of potential item names import pytest from fastapi import status from httpx import AsyncClient from sqlalchemy.ext.asyncio 
import AsyncSession from typing import Callable, Dict, Any from app.models import User as UserModel, Group as GroupModel, List as ListModel from app.schemas.expense import ExpenseCreate from app.core.config import settings # Helper to create a URL for an endpoint API_V1_STR = settings.API_V1_STR def expense_url(endpoint: str = "") -> str: return f"{API_V1_STR}/financials/expenses{endpoint}" def settlement_url(endpoint: str = "") -> str: return f"{API_V1_STR}/financials/settlements{endpoint}" @pytest.mark.asyncio async def test_create_new_expense_success_list_context( client: AsyncClient, db_session: AsyncSession, # Assuming a fixture for db session normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth test_user: UserModel, # Assuming a fixture for a test user test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of ) -> None: """ Test successful creation of a new expense linked to a list. """ expense_data = ExpenseCreate( description="Test Expense for List", amount=100.00, currency="USD", paid_by_user_id=test_user.id, list_id=test_list_user_is_member.id, group_id=None, # group_id should be derived from list if list is in a group # category_id: Optional[int] = None # Assuming category is optional # expense_date: Optional[date] = None # Assuming date is optional # splits: Optional[List[SplitCreate]] = [] # Assuming splits are optional for now ) response = await client.post( expense_url(), headers=normal_user_token_headers, json=expense_data.model_dump(exclude_unset=True) ) assert response.status_code == status.HTTP_201_CREATED content = response.json() assert content["description"] == expense_data.description assert content["amount"] == expense_data.amount assert content["currency"] == expense_data.currency assert content["paid_by_user_id"] == test_user.id assert content["list_id"] == test_list_user_is_member.id # If test_list_user_is_member has a group_id, it should be set in the response if 
test_list_user_is_member.group_id: assert content["group_id"] == test_list_user_is_member.group_id else: assert content["group_id"] is None assert "id" in content assert "created_at" in content assert "updated_at" in content assert "version" in content assert content["version"] == 1 @pytest.mark.asyncio async def test_create_new_expense_success_group_context( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_user: UserModel, test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of ) -> None: """ Test successful creation of a new expense linked directly to a group. """ expense_data = ExpenseCreate( description="Test Expense for Group", amount=50.00, currency="EUR", paid_by_user_id=test_user.id, group_id=test_group_user_is_member.id, list_id=None, ) response = await client.post( expense_url(), headers=normal_user_token_headers, json=expense_data.model_dump(exclude_unset=True) ) assert response.status_code == status.HTTP_201_CREATED content = response.json() assert content["description"] == expense_data.description assert content["paid_by_user_id"] == test_user.id assert content["group_id"] == test_group_user_is_member.id assert content["list_id"] is None assert content["version"] == 1 @pytest.mark.asyncio async def test_create_new_expense_fail_no_list_or_group( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_user: UserModel, ) -> None: """ Test expense creation fails if neither list_id nor group_id is provided. 
""" expense_data = ExpenseCreate( description="Test Invalid Expense", amount=10.00, currency="USD", paid_by_user_id=test_user.id, list_id=None, group_id=None, ) response = await client.post( expense_url(), headers=normal_user_token_headers, json=expense_data.model_dump(exclude_unset=True) ) assert response.status_code == status.HTTP_400_BAD_REQUEST content = response.json() assert "Expense must be linked to a list_id or group_id" in content["detail"] @pytest.mark.asyncio async def test_create_new_expense_fail_paid_by_other_not_owner( client: AsyncClient, normal_user_token_headers: Dict[str, str], # User is member, not owner test_user: UserModel, # This is the current_user (member) test_group_user_is_member: GroupModel, # Group the current_user is a member of another_user_in_group: UserModel, # Another user in the same group # Ensure test_user is NOT an owner of test_group_user_is_member for this test ) -> None: """ Test creation fails if paid_by_user_id is another user, and current_user is not a group owner. Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member. 
""" expense_data = ExpenseCreate( description="Expense paid by other", amount=75.00, currency="GBP", paid_by_user_id=another_user_in_group.id, # Paid by someone else group_id=test_group_user_is_member.id, list_id=None, ) response = await client.post( expense_url(), headers=normal_user_token_headers, # Current user is a member, not owner json=expense_data.model_dump(exclude_unset=True) ) assert response.status_code == status.HTTP_403_FORBIDDEN content = response.json() assert "Only group owners can create expenses paid by others" in content["detail"] # --- Add tests for other endpoints below --- # GET /expenses/{expense_id} @pytest.mark.asyncio async def test_get_expense_success( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_user: UserModel, # Assume an existing expense created by test_user or in a group/list they have access to # This would typically be created by another test or a fixture created_expense: ExpensePublic, # Assuming a fixture that provides a created expense ) -> None: """ Test successfully retrieving an existing expense. User has access either by being the payer, or via list/group membership. 
""" response = await client.get( expense_url(f"/{created_expense.id}"), headers=normal_user_token_headers ) assert response.status_code == status.HTTP_200_OK content = response.json() assert content["id"] == created_expense.id assert content["description"] == created_expense.description assert content["amount"] == created_expense.amount assert content["paid_by_user_id"] == created_expense.paid_by_user_id if created_expense.list_id: assert content["list_id"] == created_expense.list_id if created_expense.group_id: assert content["group_id"] == created_expense.group_id # TODO: Add more tests for get_expense: # - expense not found -> 404 # - user has no access (not payer, not in list, not in group if applicable) -> 403 # - expense in list, user has list access # - expense in group, user has group access # - expense personal (no list, no group), user is payer # - expense personal (no list, no group), user is NOT payer -> 403 @pytest.mark.asyncio async def test_get_expense_not_found( client: AsyncClient, normal_user_token_headers: Dict[str, str], ) -> None: """ Test retrieving a non-existent expense results in 404. """ non_existent_expense_id = 9999999 response = await client.get( expense_url(f"/{non_existent_expense_id}"), headers=normal_user_token_headers ) assert response.status_code == status.HTTP_404_NOT_FOUND content = response.json() assert "not found" in content["detail"].lower() @pytest.mark.asyncio async def test_get_expense_forbidden_personal_expense_other_user( client: AsyncClient, normal_user_token_headers: Dict[str, str], # Belongs to test_user # Fixture for an expense paid by another_user, not linked to any list/group test_user has access to personal_expense_of_another_user: ExpensePublic ) -> None: """ Test retrieving a personal expense of another user (no shared list/group) results in 403. 
""" response = await client.get( expense_url(f"/{personal_expense_of_another_user.id}"), headers=normal_user_token_headers # Current user querying ) assert response.status_code == status.HTTP_403_FORBIDDEN content = response.json() assert "Not authorized to view this expense" in content["detail"] # GET /lists/{list_id}/expenses @pytest.mark.asyncio async def test_list_list_expenses_success( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_user: UserModel, test_list_user_is_member: ListModel, # List the user is a member of # Assume some expenses have been created for this list by a fixture or previous tests ) -> None: """ Test successfully listing expenses for a list the user has access to. """ response = await client.get( f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses", headers=normal_user_token_headers ) assert response.status_code == status.HTTP_200_OK content = response.json() assert isinstance(content, list) for expense_item in content: # Renamed from expense to avoid conflict if a fixture is named expense assert expense_item["list_id"] == test_list_user_is_member.id # TODO: Add more tests for list_list_expenses: # - list not found -> 404 (ListNotFoundError from check_list_access_for_financials) # - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials) # - list exists but has no expenses -> empty list, 200 OK # - test pagination (skip, limit) @pytest.mark.asyncio async def test_list_list_expenses_list_not_found( client: AsyncClient, normal_user_token_headers: Dict[str, str], ) -> None: """ Test listing expenses for a non-existent list results in 404 (or appropriate error from permission check). The check_list_access_for_financials raises ListNotFoundError, which might be caught and raised as 404. The endpoint itself also has a get for ListModel, which would 404 first if permission check passed (not possible here). 
Based on financials.py, ListNotFoundError is raised by check_list_access_for_financials. This should translate to a 404 or a 403 if ListPermissionError wraps it with an action. The current ListPermissionError in check_list_access_for_financials re-raises ListNotFoundError if that's the cause. ListNotFoundError is a custom exception often mapped to 404. Let's assume ListNotFoundError results in a 404 response from an exception handler. """ non_existent_list_id = 99999 response = await client.get( f"{API_V1_STR}/financials/lists/{non_existent_list_id}/expenses", headers=normal_user_token_headers ) # The ListNotFoundError is raised by the check_list_access_for_financials helper, # which is then re-raised. FastAPI default exception handlers or custom ones # would convert this to an HTTP response. Typically NotFoundError -> 404. # If ListPermissionError catches it and re-raises it specifically, it might be 403. # From the code: `except ListNotFoundError: raise` means it propagates. # Let's assume a global handler for NotFoundError derived exceptions leads to 404. assert response.status_code == status.HTTP_404_NOT_FOUND # The actual detail might vary based on how ListNotFoundError is handled by FastAPI # For now, we check the status code. If financials.py maps it differently, this will need adjustment. # Based on `raise ListNotFoundError(expense_in.list_id)` in create_new_expense, and if that leads to 400, # this might be inconsistent. However, `check_list_access_for_financials` just re-raises ListNotFoundError. # Let's stick to expecting 404 for a direct not found error from a path parameter. 
content = response.json() assert "list not found" in content["detail"].lower() # Common detail for not found errors @pytest.mark.asyncio async def test_list_list_expenses_no_access( client: AsyncClient, normal_user_token_headers: Dict[str, str], # User who will attempt access test_list_user_not_member: ListModel, # A list current user is NOT a member of ) -> None: """ Test listing expenses for a list the user does not have access to (403 Forbidden). """ response = await client.get( f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses", headers=normal_user_token_headers ) assert response.status_code == status.HTTP_403_FORBIDDEN content = response.json() assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in content["detail"] @pytest.mark.asyncio async def test_list_list_expenses_empty( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_list_user_is_member_no_expenses: ListModel, # List user is member of, but has no expenses ) -> None: """ Test listing expenses for an accessible list that has no expenses (empty list, 200 OK). """ response = await client.get( f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses", headers=normal_user_token_headers ) assert response.status_code == status.HTTP_200_OK content = response.json() assert isinstance(content, list) assert len(content) == 0 # GET /groups/{group_id}/expenses @pytest.mark.asyncio async def test_list_group_expenses_success( client: AsyncClient, normal_user_token_headers: Dict[str, str], test_user: UserModel, test_group_user_is_member: GroupModel, # Group the user is a member of # Assume some expenses have been created for this group by a fixture or previous tests ) -> None: """ Test successfully listing expenses for a group the user has access to. 
""" response = await client.get( f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses", headers=normal_user_token_headers ) assert response.status_code == status.HTTP_200_OK content = response.json() assert isinstance(content, list) # Further assertions can be made here, e.g., checking if all expenses belong to the group for expense_item in content: assert expense_item["group_id"] == test_group_user_is_member.id # Expenses in a group might also have a list_id if they were added via a list belonging to that group # TODO: Add more tests for list_group_expenses: # - group not found -> 404 (GroupNotFoundError from check_group_membership) # - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership) # - group exists but has no expenses -> empty list, 200 OK # - test pagination (skip, limit) # PUT /expenses/{expense_id} # DELETE /expenses/{expense_id} # GET /settlements/{settlement_id} # POST /settlements # GET /groups/{group_id}/settlements # PUT /settlements/{settlement_id} # DELETE /settlements/{settlement_id} pytest.skip("Still implementing other tests", allow_module_level=True) import pytest from fastapi import status from app.core.exceptions import ( ListNotFoundError, ListPermissionError, ListCreatorRequiredError, GroupNotFoundError, GroupPermissionError, GroupMembershipError, GroupOperationError, GroupValidationError, ItemNotFoundError, UserNotFoundError, InvalidOperationError, DatabaseConnectionError, DatabaseIntegrityError, DatabaseTransactionError, DatabaseQueryError, OCRServiceUnavailableError, OCRServiceConfigError, OCRUnexpectedError, OCRQuotaExceededError, InvalidFileTypeError, FileTooLargeError, OCRProcessingError, EmailAlreadyRegisteredError, UserCreationError, InviteNotFoundError, InviteExpiredError, InviteAlreadyUsedError, InviteCreationError, ListStatusNotFoundError, ConflictError, InvalidCredentialsError, NotAuthenticatedError, JWTError, JWTUnexpectedError ) # TODO: It seems like settings 
are used in some exceptions. # You will need to mock app.config.settings for these tests to pass. # Consider using pytest-mock or unittest.mock.patch. # Example: from app.config import settings def test_list_not_found_error(): list_id = 123 with pytest.raises(ListNotFoundError) as excinfo: raise ListNotFoundError(list_id=list_id) assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == f"List {list_id} not found" def test_list_permission_error(): list_id = 456 action = "delete" with pytest.raises(ListPermissionError) as excinfo: raise ListPermissionError(list_id=list_id, action=action) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}" def test_list_permission_error_default_action(): list_id = 789 with pytest.raises(ListPermissionError) as excinfo: raise ListPermissionError(list_id=list_id) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"You do not have permission to access list {list_id}" def test_list_creator_required_error(): list_id = 101 action = "update" with pytest.raises(ListCreatorRequiredError) as excinfo: raise ListCreatorRequiredError(list_id=list_id, action=action) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}" def test_group_not_found_error(): group_id = 202 with pytest.raises(GroupNotFoundError) as excinfo: raise GroupNotFoundError(group_id=group_id) assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == f"Group {group_id} not found" def test_group_permission_error(): group_id = 303 action = "invite" with pytest.raises(GroupPermissionError) as excinfo: raise GroupPermissionError(group_id=group_id, action=action) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"You do not have permission to 
{action} in group {group_id}" def test_group_membership_error(): group_id = 404 action = "post" with pytest.raises(GroupMembershipError) as excinfo: raise GroupMembershipError(group_id=group_id, action=action) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}" def test_group_membership_error_default_action(): group_id = 505 with pytest.raises(GroupMembershipError) as excinfo: raise GroupMembershipError(group_id=group_id) assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN assert excinfo.value.detail == f"You must be a member of group {group_id} to access" def test_group_operation_error(): detail_msg = "Failed to perform group operation." with pytest.raises(GroupOperationError) as excinfo: raise GroupOperationError(detail=detail_msg) assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert excinfo.value.detail == detail_msg def test_group_validation_error(): detail_msg = "Invalid group data." with pytest.raises(GroupValidationError) as excinfo: raise GroupValidationError(detail=detail_msg) assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST assert excinfo.value.detail == detail_msg def test_item_not_found_error(): item_id = 606 with pytest.raises(ItemNotFoundError) as excinfo: raise ItemNotFoundError(item_id=item_id) assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == f"Item {item_id} not found" def test_user_not_found_error_no_identifier(): with pytest.raises(UserNotFoundError) as excinfo: raise UserNotFoundError() assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == "User not found." 
def test_user_not_found_error_with_id():
    """UserNotFoundError built from a numeric ID reports 404 and names that ID."""
    missing_id = 707
    with pytest.raises(UserNotFoundError) as raised:
        raise UserNotFoundError(user_id=missing_id)
    err = raised.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"User with ID {missing_id} not found."


def test_user_not_found_error_with_identifier_string():
    """UserNotFoundError built from a string identifier quotes it in the detail."""
    lookup_key = "test_user"
    with pytest.raises(UserNotFoundError) as raised:
        raise UserNotFoundError(identifier=lookup_key)
    err = raised.value
    assert err.status_code == status.HTTP_404_NOT_FOUND
    assert err.detail == f"User with identifier '{lookup_key}' not found."


def test_invalid_operation_error():
    """Without an explicit status_code, InvalidOperationError is a 400 carrying the detail."""
    message = "This operation is not allowed."
    with pytest.raises(InvalidOperationError) as raised:
        raise InvalidOperationError(detail=message)
    err = raised.value
    assert err.status_code == status.HTTP_400_BAD_REQUEST
    assert err.detail == message


def test_invalid_operation_error_custom_status():
    """An explicit status_code passed to InvalidOperationError is kept as-is."""
    message = "This operation is forbidden."
    forbidden = status.HTTP_403_FORBIDDEN
    with pytest.raises(InvalidOperationError) as raised:
        raise InvalidOperationError(detail=message, status_code=forbidden)
    err = raised.value
    assert err.status_code == forbidden
    assert err.detail == message


# The exceptions exercised below read their messages from `settings`, so their
# tests patch the relevant attributes on `app.core.exceptions.settings`
# (e.g. via pytest's `monkeypatch`, pytest-mock's `mocker`, or
# `unittest.mock.patch`) before raising them.
# Tests for exceptions whose messages are taken from `settings`.
#
# They override attributes on `app.core.exceptions.settings` with pytest's
# built-in `monkeypatch` fixture. This removes the need for pytest-mock, and
# every override is undone automatically after each test. `monkeypatch.setattr`
# with a dotted string target raises if the attribute does not exist, so a
# misspelled setting name fails the test instead of silently patching nothing.
# NOTE(review): these tests assume the exceptions read `settings` when they are
# instantiated (not at class-definition time). The expected status codes were
# written against app/core/exceptions.py; confirm both against that module.
_EXCEPTIONS_SETTINGS = "app.core.exceptions.settings"


def _patch_settings(monkeypatch, **overrides):
    """Temporarily override attributes of the `settings` used by app.core.exceptions.

    Args:
        monkeypatch: pytest's built-in ``monkeypatch`` fixture.
        **overrides: setting name -> value to install for the current test.
    """
    for name, value in overrides.items():
        monkeypatch.setattr(f"{_EXCEPTIONS_SETTINGS}.{name}", value)


def test_database_connection_error(monkeypatch):
    """Check that DatabaseConnectionError is a 503 using settings.DB_CONNECTION_ERROR."""
    _patch_settings(monkeypatch, DB_CONNECTION_ERROR="Test DB connection error")
    with pytest.raises(DatabaseConnectionError) as excinfo:
        raise DatabaseConnectionError()
    assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
    assert excinfo.value.detail == "Test DB connection error"


def test_database_integrity_error(monkeypatch):
    """Check that DatabaseIntegrityError is a 400 using settings.DB_INTEGRITY_ERROR."""
    _patch_settings(monkeypatch, DB_INTEGRITY_ERROR="Test DB integrity error")
    with pytest.raises(DatabaseIntegrityError) as excinfo:
        raise DatabaseIntegrityError()
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == "Test DB integrity error"


def test_database_transaction_error(monkeypatch):
    """Check that DatabaseTransactionError is a 500 using settings.DB_TRANSACTION_ERROR."""
    _patch_settings(monkeypatch, DB_TRANSACTION_ERROR="Test DB transaction error")
    with pytest.raises(DatabaseTransactionError) as excinfo:
        raise DatabaseTransactionError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "Test DB transaction error"


def test_database_query_error(monkeypatch):
    """Check that DatabaseQueryError is a 500 using settings.DB_QUERY_ERROR."""
    _patch_settings(monkeypatch, DB_QUERY_ERROR="Test DB query error")
    with pytest.raises(DatabaseQueryError) as excinfo:
        raise DatabaseQueryError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "Test DB query error"


def test_ocr_service_unavailable_error(monkeypatch):
    """Check that OCRServiceUnavailableError is a 503 using settings.OCR_SERVICE_UNAVAILABLE."""
    _patch_settings(monkeypatch, OCR_SERVICE_UNAVAILABLE="Test OCR unavailable")
    with pytest.raises(OCRServiceUnavailableError) as excinfo:
        raise OCRServiceUnavailableError()
    assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
    assert excinfo.value.detail == "Test OCR unavailable"


def test_ocr_service_config_error(monkeypatch):
    """Check that OCRServiceConfigError is a 500 using settings.OCR_SERVICE_CONFIG_ERROR."""
    _patch_settings(monkeypatch, OCR_SERVICE_CONFIG_ERROR="Test OCR config error")
    with pytest.raises(OCRServiceConfigError) as excinfo:
        raise OCRServiceConfigError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "Test OCR config error"


def test_ocr_unexpected_error(monkeypatch):
    """Check that OCRUnexpectedError is a 500 using settings.OCR_UNEXPECTED_ERROR."""
    _patch_settings(monkeypatch, OCR_UNEXPECTED_ERROR="Test OCR unexpected error")
    with pytest.raises(OCRUnexpectedError) as excinfo:
        raise OCRUnexpectedError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "Test OCR unexpected error"


def test_ocr_quota_exceeded_error(monkeypatch):
    """Check that OCRQuotaExceededError is a 429 using settings.OCR_QUOTA_EXCEEDED."""
    _patch_settings(monkeypatch, OCR_QUOTA_EXCEEDED="Test OCR quota exceeded")
    with pytest.raises(OCRQuotaExceededError) as excinfo:
        raise OCRQuotaExceededError()
    assert excinfo.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
    assert excinfo.value.detail == "Test OCR quota exceeded"


def test_invalid_file_type_error(monkeypatch):
    """Check that InvalidFileTypeError is a 400 listing settings.ALLOWED_IMAGE_TYPES."""
    test_types = ["png", "jpg"]
    _patch_settings(
        monkeypatch,
        ALLOWED_IMAGE_TYPES=test_types,
        OCR_INVALID_FILE_TYPE="Invalid type: {types}",
    )
    with pytest.raises(InvalidFileTypeError) as excinfo:
        raise InvalidFileTypeError()
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == f"Invalid type: {', '.join(test_types)}"


def test_file_too_large_error(monkeypatch):
    """Check that FileTooLargeError is a 400 quoting settings.MAX_FILE_SIZE_MB."""
    max_size = 10
    _patch_settings(
        monkeypatch,
        MAX_FILE_SIZE_MB=max_size,
        OCR_FILE_TOO_LARGE="File too large: {size}MB",
    )
    with pytest.raises(FileTooLargeError) as excinfo:
        raise FileTooLargeError()
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == f"File too large: {max_size}MB"


def test_ocr_processing_error(monkeypatch):
    """Check that OCRProcessingError is a 400 formatting its detail into settings.OCR_PROCESSING_ERROR."""
    error_detail = "Specific OCR error"
    _patch_settings(monkeypatch, OCR_PROCESSING_ERROR="OCR processing failed: {detail}")
    with pytest.raises(OCRProcessingError) as excinfo:
        raise OCRProcessingError(detail=error_detail)
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == f"OCR processing failed: {error_detail}"


def test_email_already_registered_error():
    """Check that EmailAlreadyRegisteredError is a 400 with a fixed message."""
    with pytest.raises(EmailAlreadyRegisteredError) as excinfo:
        raise EmailAlreadyRegisteredError()
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == "Email already registered."


def test_user_creation_error():
    """Check that UserCreationError is a 500 with a fixed message."""
    with pytest.raises(UserCreationError) as excinfo:
        raise UserCreationError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "An error occurred during user creation."
def test_invite_not_found_error(): invite_code = "TESTCODE123" with pytest.raises(InviteNotFoundError) as excinfo: raise InviteNotFoundError(invite_code=invite_code) assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == f"Invite code {invite_code} not found" def test_invite_expired_error(): invite_code = "EXPIREDCODE" with pytest.raises(InviteExpiredError) as excinfo: raise InviteExpiredError(invite_code=invite_code) assert excinfo.value.status_code == status.HTTP_410_GONE assert excinfo.value.detail == f"Invite code {invite_code} has expired" def test_invite_already_used_error(): invite_code = "USEDCODE" with pytest.raises(InviteAlreadyUsedError) as excinfo: raise InviteAlreadyUsedError(invite_code=invite_code) assert excinfo.value.status_code == status.HTTP_410_GONE assert excinfo.value.detail == f"Invite code {invite_code} has already been used" def test_invite_creation_error(): group_id = 909 with pytest.raises(InviteCreationError) as excinfo: raise InviteCreationError(group_id=group_id) assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert excinfo.value.detail == f"Failed to create invite for group {group_id}" def test_list_status_not_found_error(): list_id = 808 with pytest.raises(ListStatusNotFoundError) as excinfo: raise ListStatusNotFoundError(list_id=list_id) assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND assert excinfo.value.detail == f"Status for list {list_id} not found" def test_conflict_error(): detail_msg = "Resource version mismatch." 
with pytest.raises(ConflictError) as excinfo: raise ConflictError(detail=detail_msg) assert excinfo.value.status_code == status.HTTP_409_CONFLICT assert excinfo.value.detail == detail_msg # Tests for auth-related exceptions that likely require mocking app.config.settings # def test_invalid_credentials_error(mocker): # mocker.patch("app.core.exceptions.settings.AUTH_INVALID_CREDENTIALS", "Invalid test credentials") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") # with pytest.raises(InvalidCredentialsError) as excinfo: # raise InvalidCredentialsError() # assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED # assert excinfo.value.detail == "Invalid test credentials" # assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_credentials\""} # def test_not_authenticated_error(mocker): # mocker.patch("app.core.exceptions.settings.AUTH_NOT_AUTHENTICATED", "Not authenticated test") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") # with pytest.raises(NotAuthenticatedError) as excinfo: # raise NotAuthenticatedError() # assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED # assert excinfo.value.detail == "Not authenticated test" # assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"not_authenticated\""} # def test_jwt_error(mocker): # error_msg = "Test JWT issue" # mocker.patch("app.core.exceptions.settings.JWT_ERROR", "JWT error: {error}") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") # mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") # with pytest.raises(JWTError) as excinfo: # raise JWTError(error=error_msg) # assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED # assert excinfo.value.detail == f"JWT error: {error_msg}" # 
assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}

# def test_jwt_unexpected_error(mocker):
#     error_msg = "Unexpected test JWT issue"
#     mocker.patch("app.core.exceptions.settings.JWT_UNEXPECTED_ERROR", "Unexpected JWT error: {error}")
#     mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
#     mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
#     with pytest.raises(JWTUnexpectedError) as excinfo:
#         raise JWTUnexpectedError(error=error_msg)
#     assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
#     assert excinfo.value.detail == f"Unexpected JWT error: {error_msg}"
#     assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions

# Modules to test
from app.core import gemini
from app.core.exceptions import (
    OCRServiceUnavailableError,
    OCRServiceConfigError,
    OCRUnexpectedError,
    OCRQuotaExceededError
)


# Default Mock Settings
@pytest.fixture
def mock_gemini_settings():
    """Return a settings stand-in carrying the Gemini/OCR values these tests read."""
    settings_mock = MagicMock()
    settings_mock.GEMINI_API_KEY = "test_api_key"
    settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision"
    settings_mock.GEMINI_SAFETY_SETTINGS = {
        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    }
    settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
    settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
    return settings_mock


@pytest.fixture
def mock_generative_model_instance():
    """Return a GenerativeModel-spec mock whose generate_content_async is awaitable."""
    model_instance = MagicMock(spec=genai.GenerativeModel)
    model_instance.generate_content_async = AsyncMock()
    return model_instance


@pytest.fixture
def patch_google_ai_client(mock_generative_model_instance):
    """Patch google.generativeai.configure / GenerativeModel for the whole test.

    ``@patch`` decorators on a fixture function only apply while the fixture
    function itself runs, so the patches were already undone by the time the
    test body executed. A ``with`` block around ``yield`` keeps them active
    until fixture teardown.

    Yields:
        (mock_configure, mock_generative_model, mock_generative_model_instance)
    """
    with patch('google.generativeai.configure') as mock_configure, \
            patch('google.generativeai.GenerativeModel') as mock_generative_model:
        mock_generative_model.return_value = mock_generative_model_instance
        yield mock_configure, mock_generative_model, mock_generative_model_instance


# --- Test Gemini Client Initialization (Global Client) ---
# Parametrize to test different scenarios for the global client init
@pytest.mark.parametrize(
    "api_key_present, configure_raises, model_init_raises, expected_error_message_part",
    [
        (True, None, None, None),  # Success
        (False, None, None, "GEMINI_API_KEY not configured"),  # API key missing
        (True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"),  # genai.configure error
        (True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"),  # GenerativeModel error
    ]
)
def test_global_gemini_client_initialization(
    mock_gemini_settings, api_key_present, configure_raises, model_init_raises, expected_error_message_part
):
    """Tests the global gemini_flash_client initialization logic in app.core.gemini.

    The client is created by module-level code, so the test re-runs that code
    with ``importlib.reload``. A reload re-executes the whole module, including
    its import statements, so any name the module imported is rebound. Patches
    on module attributes such as ``app.core.gemini.genai`` or
    ``app.core.gemini.settings`` are therefore discarded by the reload. The
    patches are applied at their sources instead.
    """
    import importlib

    # NOTE(review): assumes gemini.py binds `genai` via
    # `import google.generativeai as genai` and `settings` via
    # `from app.config import settings` — confirm against the module.
    if not api_key_present:
        mock_gemini_settings.GEMINI_API_KEY = None

    try:
        with patch('app.config.settings', mock_gemini_settings), \
                patch.object(genai, 'configure') as mock_configure, \
                patch.object(genai, 'GenerativeModel') as mock_generative_model:
            if configure_raises:
                mock_configure.side_effect = configure_raises
            if model_init_raises:
                mock_generative_model.side_effect = model_init_raises

            importlib.reload(gemini)  # This re-runs the try-except block at the top of gemini.py

            if expected_error_message_part:
                assert gemini.gemini_initialization_error is not None
                assert expected_error_message_part in gemini.gemini_initialization_error
                assert gemini.gemini_flash_client is None
            else:
                assert gemini.gemini_initialization_error is None
                assert gemini.gemini_flash_client is not None
                mock_configure.assert_called_once_with(api_key="test_api_key")
                mock_generative_model.assert_called_once()
                # Could add more assertions about safety_settings and generation_config here
    finally:
        # Clean up after reload for other tests: reload once more outside the
        # patches, even when an assertion above failed.
        importlib.reload(gemini)


# --- Test get_gemini_client ---
# The module globals are patched with context managers. `patch(target, new)`
# with an explicit replacement injects no argument into a decorated test, so
# extra mock parameters would be resolved as (non-existent) pytest fixtures.

def test_get_gemini_client_success():
    fake_client = MagicMock(spec=genai.GenerativeModel)  # Simulate an initialized client
    with patch('app.core.gemini.gemini_flash_client', fake_client), \
            patch('app.core.gemini.gemini_initialization_error', None):
        client = gemini.get_gemini_client()
    assert client is not None


def test_get_gemini_client_init_error():
    with patch('app.core.gemini.gemini_flash_client', None), \
            patch('app.core.gemini.gemini_initialization_error', "Test init error"):
        with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
            gemini.get_gemini_client()


def test_get_gemini_client_none_client_unknown_issue():
    # No init error, but client is None
    with patch('app.core.gemini.gemini_flash_client', None), \
            patch('app.core.gemini.gemini_initialization_error', None):
        # Raw string: "\(" in a normal literal is an invalid escape sequence.
        with pytest.raises(RuntimeError, match=r"Gemini client is not available \(unknown initialization issue\)."):
            gemini.get_gemini_client()


# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_success(
    mock_gemini_settings,
    mock_generative_model_instance,
    patch_google_ai_client  # This fixture patches google.generativeai for the module
):
    """ Test successful item extraction """
    # Ensure the global client is mocked to be the one we control
    with patch('app.core.gemini.settings', mock_gemini_settings), \
            patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
            patch('app.core.gemini.gemini_initialization_error', None):
        mock_response = MagicMock()
        mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
        # Simulate the structure for safety checks if needed
        mock_candidate = MagicMock()
        mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
        mock_candidate.finish_reason = 'STOP'  # Or whatever is appropriate for success
        mock_candidate.safety_ratings = []
        mock_response.candidates = [mock_candidate]
        mock_generative_model_instance.generate_content_async.return_value = mock_response

        image_bytes = b"dummy_image_bytes"
        mime_type = "image/png"
        items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)

        mock_generative_model_instance.generate_content_async.assert_called_once_with([
            mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
            {"mime_type": mime_type, "data": image_bytes}
        ])
        assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]


@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(
    mock_gemini_settings
):
    with patch('app.core.gemini.settings', mock_gemini_settings), \
            patch('app.core.gemini.gemini_flash_client', None), \
            patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
        image_bytes = b"dummy_image_bytes"
        with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
            await gemini.extract_items_from_image_gemini(image_bytes)


@pytest.mark.asyncio
@patch('app.core.gemini.get_gemini_client')  # Mock the getter to control the client directly
async def test_extract_items_from_image_gemini_api_quota_error(
    mock_get_client,
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_get_client.return_value = mock_generative_model_instance
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")

    with patch('app.core.gemini.settings', mock_gemini_settings):
        image_bytes = b"dummy_image_bytes"
        with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
            await gemini.extract_items_from_image_gemini(image_bytes)


# --- Tests for GeminiOCRService --- (Example tests, more needed)

@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel')
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
    MockGenerativeModel.return_value = mock_generative_model_instance
    with patch('app.core.gemini.settings', mock_gemini_settings):
        service = gemini.GeminiOCRService()
        MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
        MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
        assert service.model == mock_generative_model_instance
        # Could add assertions for safety_settings and generation_config if they are set directly on model


@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
    with patch('app.core.gemini.settings', mock_gemini_settings):
        with pytest.raises(OCRServiceConfigError):
            gemini.GeminiOCRService()


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance):
    mock_response = MagicMock()
    mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored"
    mock_generative_model_instance.generate_content_async.return_value = mock_response

    with patch('app.core.gemini.settings', mock_gemini_settings):
        # Patch the model instance within the service for this test
        with patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
                patch.object(genai, 'configure'):
            service = gemini.GeminiOCRService()  # Re-init to use the patched model
            items = await service.extract_items(b"dummy_image")

            expected_call_args = [
                mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
                {"mime_type": "image/jpeg", "data": b"dummy_image"}
            ]
            service.model.generate_content_async.assert_called_once_with(contents=expected_call_args)
            assert items == ["Apple", "Banana", "Orange"]


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance):
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.")

    with patch('app.core.gemini.settings', mock_gemini_settings), \
            patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
            patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRQuotaExceededError):
            await service.extract_items(b"dummy_image")


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance):
    # Simulate a generic API error that isn't quota related
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable")

    with patch('app.core.gemini.settings', mock_gemini_settings), \
            patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
            patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRServiceUnavailableError):
            await service.extract_items(b"dummy_image")


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance):
    mock_response = MagicMock()
    mock_response.text = None  # Simulate no text in response
    mock_generative_model_instance.generate_content_async.return_value = mock_response

    with patch('app.core.gemini.settings', mock_gemini_settings), \
            patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
            patch.object(genai, 'configure'):
        service = gemini.GeminiOCRService()
        with pytest.raises(OCRUnexpectedError):
            await service.extract_items(b"dummy_image")
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError
from passlib.context import CryptContext

from app.core.security import (
    verify_password,
    hash_password,
    create_access_token,
    create_refresh_token,
    verify_access_token,
    verify_refresh_token,
    pwd_context,  # Import for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
# from app.config import settings


# --- Tests for Password Hashing ---
def test_hash_password():
    password = "securepassword123"
    hashed = hash_password(password)
    assert isinstance(hashed, str)
    assert hashed != password
    # Check that the default scheme (bcrypt) is used by verifying the hash prefix
    # bcrypt hashes typically start with $2b$ or $2a$ or $2y$
    assert hashed.startswith("$2b$") or hashed.startswith("$2a$") or hashed.startswith("$2y$")


def test_verify_password_correct():
    password = "testpassword"
    hashed_password = pwd_context.hash(password)  # Use the same context for consistency
    assert verify_password(password, hashed_password) is True


def test_verify_password_incorrect():
    password = "testpassword"
    wrong_password = "wrongpassword"
    hashed_password = pwd_context.hash(password)
    assert verify_password(wrong_password, hashed_password) is False


def test_verify_password_invalid_hash_format():
    password = "testpassword"
    invalid_hash = "notarealhash"
    assert verify_password(password, invalid_hash) is False


# --- Tests for JWT Creation ---
# Mock settings for JWT tests
@pytest.fixture(scope="module")
def mock_jwt_settings():
    mock_settings = MagicMock()
    mock_settings.SECRET_KEY = "testsecretkey"
    mock_settings.ALGORITHM = "HS256"
    mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
    mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080  # 7 days
    return mock_settings


@patch('app.core.security.settings')
def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

    subject = "user@example.com"
    token = create_access_token(subject)
    assert isinstance(token, str)

    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == subject
    assert decoded_payload["type"] == "access"
    assert "exp" in decoded_payload
    # Check if expiry is roughly correct (within a small delta)
    expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)


@patch('app.core.security.settings')
def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    # ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta

    subject = 123  # Subject can be int
    custom_delta = timedelta(hours=1)
    token = create_access_token(subject, expires_delta=custom_delta)
    assert isinstance(token, str)

    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == str(subject)
    assert decoded_payload["type"] == "access"
    expected_expiry = datetime.now(timezone.utc) + custom_delta
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)


@patch('app.core.security.settings')
def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES

    subject = "refresh_subject"
    token = create_refresh_token(subject)
    assert isinstance(token, str)

    decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
    assert decoded_payload["sub"] == subject
    assert decoded_payload["type"] == "refresh"
    expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
    assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)


# --- Tests for JWT Verification --- (More tests to be added here)

@patch('app.core.security.settings')
def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

    subject = "test_user_valid_access"
    token = create_access_token(subject)
    payload = verify_access_token(token)
    assert payload is not None
    assert payload["sub"] == subject
    assert payload["type"] == "access"


@patch('app.core.security.settings')
def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

    subject = "test_user_invalid_sig"
    # Create token with correct key
    token = create_access_token(subject)

    # Try to verify with wrong key
    mock_settings_global.SECRET_KEY = "wrongsecretkey"
    payload = verify_access_token(token)
    assert payload is None


@patch('app.core.security.settings')
def test_verify_access_token_expired(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

    # python-jose validates "exp" inside its own module, which a patch of
    # app.core.security.datetime does not reach, so "advancing time" there
    # never made the token expire. Issue a token that is already expired.
    subject = "test_user_expired"
    token = create_access_token(subject, expires_delta=timedelta(minutes=-5))

    payload = verify_access_token(token)
    assert payload is None


@patch('app.core.security.settings')
def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES  # For refresh token creation

    subject = "test_user_wrong_type"
    # Create a refresh token
    refresh_token = create_refresh_token(subject)

    # Try to verify it as an access token
    payload = verify_access_token(refresh_token)
    assert payload is None


@patch('app.core.security.settings')
def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES

    subject = "test_user_valid_refresh"
    token = create_refresh_token(subject)
    payload = verify_refresh_token(token)
    assert payload is not None
    assert payload["sub"] == subject
    assert payload["type"] == "refresh"


@patch('app.core.security.settings')
def test_verify_refresh_token_expired(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES

    # As in the access-token case, patching app.core.security.datetime cannot
    # expire a token for python-jose. Mint an already-expired refresh token
    # carrying the claims test_create_refresh_token_default_expiry asserts on.
    expired_at = datetime.now(timezone.utc) - timedelta(minutes=5)
    token = jwt.encode(
        {"sub": "test_user_expired_refresh", "type": "refresh", "exp": expired_at},
        mock_jwt_settings.SECRET_KEY,
        algorithm=mock_jwt_settings.ALGORITHM,
    )

    payload = verify_refresh_token(token)
    assert payload is None


@patch('app.core.security.settings')
def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
    mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
    mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
    mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

    subject = "test_user_wrong_type_refresh"
    access_token = create_access_token(subject)

    payload = verify_refresh_token(access_token)
    assert payload is None
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense,  # Assuming update_expense exists
    delete_expense,  # Assuming delete_expense exists
    get_users_for_splitting  # Helper, might test indirectly
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError
)

# NOTE: `await session.execute(...)` returns a synchronous Result object. The
# result mocks below are therefore MagicMock: children of an AsyncMock are
# AsyncMocks, so `result.scalars()` would return a coroutine and
# `.first()` / `.all()` on it would raise AttributeError.


# General Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()  # create_expense uses flush
    return session


@pytest.fixture
def basic_user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com")


@pytest.fixture
def basic_group_model():
    group = GroupModel(id=1, name="Test Group")
    # Simulate member_associations for get_users_for_splitting if needed directly
    # group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model()), UserGroupModel(user_id=2, group_id=1, user=another_user_model())]
    return group


@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, creator_id=basic_user_model.id, creator=basic_user_model)


@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,  # Derived from list
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None
    )


@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None
    )


@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[
            ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
            ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
        ]
    )


@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    return ExpenseModel(
        id=1,
        description=expense_create_data_equal_split_group_ctx.description,
        total_amount=expense_create_data_equal_split_group_ctx.total_amount,
        currency=expense_create_data_equal_split_group_ctx.currency,
        expense_date=expense_create_data_equal_split_group_ctx.expense_date,
        split_type=expense_create_data_equal_split_group_ctx.split_type,
        list_id=expense_create_data_equal_split_group_ctx.list_id,
        group_id=expense_create_data_equal_split_group_ctx.group_id,
        item_id=expense_create_data_equal_split_group_ctx.item_id,
        paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
        paid_by=basic_user_model,  # Assuming paid_by relation is loaded
        # splits would be populated after creation usually
        version=1
    )


# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed)
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    # Setup group with members
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]

    mock_execute = MagicMock()  # Sync Result: see module NOTE above
    mock_execute.scalars.return_value.first.return_value = basic_group_model
    mock_db_session.execute.return_value = mock_execute

    users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1)
    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # Payer, Group

    # Mock get_users_for_splitting call within create_expense
    # This is a bit tricky as it's an internal call. Patching is an option.
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]

        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    # mock_db_session.commit.assert_called_once()  # create_expense does not commit itself
    # mock_db_session.refresh.assert_called_once()  # create_expense does not refresh itself

    assert created_expense is not None
    assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    assert len(created_expense.splits) == 2  # Expect splits to be added to the model instance

    # Check split amounts
    expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    for split in created_expense.splits:
        assert split.owed_amount == expected_amount_per_user


@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # Payer, Group

    # The split users are validated with a select on User.id. The result is
    # presumably consumed row by row (e.g. `{row[0] for row in user_results}`),
    # so the mocked result only needs to iterate over one-element tuples.
    mock_user_ids_result_proxy = MagicMock()
    mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)])
    mock_db_session.execute.return_value = mock_user_ids_result_proxy

    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")


@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    mock_db_session.get.return_value = None  # Payer not found
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 1)


@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    mock_db_session.get.return_value = basic_user_model  # Payer found
    expense_data = expense_create_data_equal_split_group_ctx.model_copy()
    expense_data.list_id = None
    expense_data.group_id = None
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, 1)


# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    mock_result = MagicMock()  # Sync Result: see module NOTE above
    mock_result.scalars.return_value.first.return_value = db_expense_model
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 1)
    assert expense is not None
    assert expense.id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    mock_result = MagicMock()  # Sync Result: see module NOTE above
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 999)
    assert expense is None


# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    mock_result = MagicMock()  # Sync Result: see module NOTE above
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, list_id=1)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()


# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    mock_result = MagicMock()  # Sync Result: see module NOTE above
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, group_id=1)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()


# --- Stubs for update_expense and delete_expense ---
# These will need more details once the actual implementation of update/delete is clear
# For example, how splits are handled on update, versioning, etc.
# Placeholder tests are marked skip rather than passing silently. A `pass`-only
# test shows up green and hides the fact that update/delete are untested.
@pytest.mark.asyncio
@pytest.mark.skip(reason="Placeholder: update_expense test not implemented yet")
async def test_update_expense_stub(mock_db_session):
    # Placeholder: Test logic for update_expense will be more complex
    # Needs ExpenseUpdate schema, existing expense object, and mocking of commit/refresh
    # Also depends on what fields are updatable and how splits are managed.
    expense_to_update = MagicMock(spec=ExpenseModel)
    expense_to_update.version = 1
    update_payload = ExpenseUpdate(description="New description", version=1)  # Add other fields as per schema definition

    # Simulate the update_expense function behavior
    # For example, if it loads the expense, modifies, commits, refreshes:
    # mock_db_session.get.return_value = expense_to_update
    # updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
    # assert updated_expense.description == "New description"
    # mock_db_session.commit.assert_called_once()
    # mock_db_session.refresh.assert_called_once()
    pass  # Replace with actual test logic


@pytest.mark.asyncio
@pytest.mark.skip(reason="Placeholder: delete_expense test not implemented yet")
async def test_delete_expense_stub(mock_db_session):
    # Placeholder: Test logic for delete_expense
    # Needs an existing expense object and mocking of delete/commit
    # Also, consider implications (e.g., are splits deleted?)
    expense_to_delete = MagicMock(spec=ExpenseModel)
    expense_to_delete.id = 1
    expense_to_delete.version = 1

    # Simulate delete_expense behavior
    # mock_db_session.get.return_value = expense_to_delete  # If it re-fetches
    # await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
    # mock_db_session.delete.assert_called_once_with(expense_to_delete)
    # mock_db_session.commit.assert_called_once()
    pass  # Replace with actual test logic

# TODO: Add more tests for create_expense covering:
# - List context success
# - Percentage, Shares, Item-based splits
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
# - Validation of list_id/group_id consistency
# - User not found in splits_in
# - Item not found for ITEM_BASED split

# TODO: Flesh out update_expense tests:
# - Success case
# - Version mismatch
# - Trying to update immutable fields
# - How splits are handled (recalculated, deleted/recreated, or not changeable)

# TODO: Flesh out delete_expense tests:
# - Success case
# - Version mismatch (if applicable)
# - Ensure associated splits are also deleted (cascade behavior)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import delete, func  # For remove_user_from_group and get_group_member_count

from app.crud.group import (
    create_group,
    get_user_groups,
    get_group_by_id,
    is_user_member,
    get_user_role_in_group,
    add_user_to_group,
    remove_user_from_group,
    get_group_member_count,
    check_group_membership,
    check_user_role_in_group
)
from app.schemas.group import GroupCreate
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
from app.core.exceptions import (
    GroupOperationError,
    GroupNotFoundError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    GroupMembershipError,
    GroupPermissionError
)

# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the group CRUD tests.

    Coroutine methods (commit, rollback, refresh, execute, get, flush, delete)
    are AsyncMocks; synchronous ones (add, begin) are MagicMocks. Note that the
    object returned by the awaited ``execute()`` is a synchronous SQLAlchemy
    ``Result``, so tests build their result doubles with MagicMock: an
    AsyncMock's children are AsyncMocks too, which would make
    ``.scalars()`` / ``.scalar_one_or_none()`` return coroutines instead of the
    configured values.
    """
    session = AsyncMock()
    # AsyncSession.begin() is a plain method returning an async context manager;
    # a MagicMock supports `async with` (__aenter__/__aexit__), an AsyncMock does not.
    session.begin = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete() is a coroutine (it may load relationships for cascades).
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    return session


@pytest.fixture
def group_create_data():
    return GroupCreate(name="Test Group")


@pytest.fixture
def creator_user_model():
    return UserModel(id=1, name="Creator User", email="creator@example.com")


@pytest.fixture
def member_user_model():
    return UserModel(id=2, name="Member User", email="member@example.com")


@pytest.fixture
def db_group_model(creator_user_model):
    return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model)


@pytest.fixture
def db_user_group_owner_assoc(db_group_model, creator_user_model):
    return UserGroupModel(
        user_id=creator_user_model.id,
        group_id=db_group_model.id,
        role=UserRoleEnum.owner,
        user=creator_user_model,
        group=db_group_model,
    )


@pytest.fixture
def db_user_group_member_assoc(db_group_model, member_user_model):
    return UserGroupModel(
        user_id=member_user_model.id,
        group_id=db_group_model.id,
        role=UserRoleEnum.member,
        user=member_user_model,
        group=db_group_model,
    )


# --- create_group Tests ---
@pytest.mark.asyncio
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
    async def mock_refresh(instance):
        instance.id = 1  # Simulate ID assignment by DB
        return None

    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)

    created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id)

    assert mock_db_session.add.call_count == 2  # Group and UserGroup
    mock_db_session.flush.assert_called()  # Called multiple times
    mock_db_session.refresh.assert_called_once_with(created_group)
    assert created_group is not None
    assert created_group.name == group_create_data.name
    assert created_group.created_by_id == creator_user_model.id
    # Further check if UserGroup was created correctly by inspecting mock_db_session.add calls or by fetching


@pytest.mark.asyncio
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_group(mock_db_session, group_create_data, creator_user_model.id)
    # Checked after the `with`: statements after the raising call inside it never run.
    mock_db_session.rollback.assert_called_once()  # Assuming rollback within the except block of create_group


# --- get_user_groups Tests ---
@pytest.mark.asyncio
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
    mock_result = MagicMock()  # Result of the awaited execute() is synchronous
    mock_result.scalars.return_value.all.return_value = [db_group_model]
    mock_db_session.execute.return_value = mock_result

    groups = await get_user_groups(mock_db_session, creator_user_model.id)

    assert len(groups) == 1
    assert groups[0].name == db_group_model.name
    mock_db_session.execute.assert_called_once()


# --- get_group_by_id Tests ---
@pytest.mark.asyncio
async def test_get_group_by_id_found(mock_db_session, db_group_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_group_model
    mock_db_session.execute.return_value = mock_result

    group = await get_group_by_id(mock_db_session, db_group_model.id)

    assert group is not None
    assert group.id == db_group_model.id
    # Add assertions for eager loaded members if applicable and mocked


@pytest.mark.asyncio
async def test_get_group_by_id_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    group = await get_group_by_id(mock_db_session, 999)

    assert group is None


# --- is_user_member Tests ---
@pytest.mark.asyncio
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = 1  # Simulate UserGroup.id found
    mock_db_session.execute.return_value = mock_result

    is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id)

    assert is_member is True


@pytest.mark.asyncio
async def test_is_user_member_false(mock_db_session, db_group_model, member_user_model):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = None  # Simulate no UserGroup.id found
    mock_db_session.execute.return_value = mock_result

    is_member = await is_user_member(mock_db_session, db_group_model.id, member_user_model.id + 1)  # Non-member

    assert is_member is False


# --- get_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner
    mock_db_session.execute.return_value = mock_result

    role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)

    assert role == UserRoleEnum.owner


# --- add_user_to_group Tests ---
@pytest.mark.asyncio
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model):
    # First execute call for checking existing membership returns None
    mock_existing_check_result = MagicMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_existing_check_result

    async def mock_refresh_user_group(instance):
        instance.id = 100  # Simulate ID for UserGroupModel
        return None

    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh_user_group)

    user_group_assoc = await add_user_to_group(
        mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.member
    )

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert user_group_assoc is not None
    assert user_group_assoc.user_id == member_user_model.id
    assert user_group_assoc.group_id == db_group_model.id
    assert user_group_assoc.role == UserRoleEnum.member


@pytest.mark.asyncio
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
    mock_existing_check_result = MagicMock()
    mock_existing_check_result.scalar_one_or_none.return_value = db_user_group_owner_assoc  # User is already a member
    mock_db_session.execute.return_value = mock_existing_check_result

    user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, creator_user_model.id)

    assert user_group_assoc is None
    mock_db_session.add.assert_not_called()


# --- remove_user_from_group Tests ---
@pytest.mark.asyncio
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model):
    mock_delete_result = MagicMock()
    mock_delete_result.scalar_one_or_none.return_value = 1  # Simulate a row was deleted (returning ID)
    mock_db_session.execute.return_value = mock_delete_result

    removed = await remove_user_from_group(mock_db_session, db_group_model.id, member_user_model.id)

    assert removed is True
    # Assert that db.execute was called with a delete statement
    # This requires inspecting the call args of mock_db_session.execute
    # For simplicity, we check it was called. A deeper check would validate the SQL query itself.
    mock_db_session.execute.assert_called_once()


# --- get_group_member_count Tests ---
@pytest.mark.asyncio
async def test_get_group_member_count_success(mock_db_session, db_group_model):
    mock_count_result = MagicMock()
    mock_count_result.scalar_one.return_value = 5
    mock_db_session.execute.return_value = mock_count_result

    count = await get_group_member_count(mock_db_session, db_group_model.id)

    assert count == 5


# --- check_group_membership Tests ---
@pytest.mark.asyncio
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
    mock_db_session.get.return_value = db_group_model  # Group exists
    mock_membership_result = MagicMock()
    mock_membership_result.scalar_one_or_none.return_value = 1  # User is a member
    mock_db_session.execute.return_value = mock_membership_result

    await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
    # No exception means success


@pytest.mark.asyncio
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
    mock_db_session.get.return_value = None  # Group does not exist
    with pytest.raises(GroupNotFoundError):
        await check_group_membership(mock_db_session, 999, creator_user_model.id)


@pytest.mark.asyncio
async def test_check_group_membership_not_member(mock_db_session, db_group_model, member_user_model):
    mock_db_session.get.return_value = db_group_model  # Group exists
    mock_membership_result = MagicMock()
    mock_membership_result.scalar_one_or_none.return_value = None  # User is not a member
    mock_db_session.execute.return_value = mock_membership_result

    with pytest.raises(GroupMembershipError):
        await check_group_membership(mock_db_session, db_group_model.id, member_user_model.id)


# --- check_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
    # Mock check_group_membership (implicitly called)
    mock_db_session.get.return_value = db_group_model
    mock_membership_check = MagicMock()
    mock_membership_check.scalar_one_or_none.return_value = 1  # User is member

    # Mock get_user_role_in_group
    mock_role_check = MagicMock()
    mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.owner

    mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]

    await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
    # No exception means success


@pytest.mark.asyncio
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
    mock_db_session.get.return_value = db_group_model  # Group exists
    mock_membership_check = MagicMock()
    mock_membership_check.scalar_one_or_none.return_value = 1  # User is member (for check_group_membership call)
    mock_role_check = MagicMock()
    mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.member  # User's actual role

    mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]

    with pytest.raises(GroupPermissionError):
        await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)


# TODO: Add tests for DB operational/SQLAlchemy errors for each function similar to create_group_integrity_error
# TODO: Test edge cases like trying to add user to non-existent group (should be caught by FK constraints or prior checks)

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError  # Assuming these might be raised
from datetime import datetime, timedelta, timezone
import secrets

from app.crud.invite import (
    create_invite,
    get_active_invite_by_code,
    deactivate_invite,
    MAX_CODE_GENERATION_ATTEMPTS,
)
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel  # For context
# No specific schemas for invite CRUD usually, but models are used.
# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the invite CRUD tests.

    The object returned by the awaited ``execute()`` is a synchronous
    SQLAlchemy ``Result``, so tests build their result doubles with MagicMock;
    an AsyncMock result would make ``.scalar_one_or_none()`` return a
    (never-None) coroutine.
    """
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.execute = AsyncMock()
    return session


@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def user_model():  # Creator
    return UserModel(id=1, name="Creator User")


@pytest.fixture
def db_invite_model(group_model, user_model):
    return InviteModel(
        id=1,
        code="test_invite_code_123",
        group_id=group_model.id,
        created_by_id=user_model.id,
        expires_at=datetime.now(timezone.utc) + timedelta(days=7),
        is_active=True,
    )


# --- create_invite Tests ---
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')  # Patch secrets.token_urlsafe
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
    generated_code = "unique_code_123"
    mock_token_urlsafe.return_value = generated_code

    # Mock DB execute for checking existing code (first attempt, no existing code)
    mock_existing_check_result = MagicMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_existing_check_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)

    mock_token_urlsafe.assert_called_once_with(16)
    mock_db_session.execute.assert_called_once()  # For the uniqueness check
    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(invite)
    assert invite is not None
    assert invite.code == generated_code
    assert invite.group_id == group_model.id
    assert invite.created_by_id == user_model.id
    assert invite.is_active is True
    assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4)  # Check expiry is roughly correct


@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
    colliding_code = "colliding_code"
    unique_code = "finally_unique_code"
    mock_token_urlsafe.side_effect = [colliding_code, unique_code]  # First call collides, second is unique

    # Mock DB execute for checking existing code
    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # Simulate collision (ID found)
    mock_no_collision_check_result = MagicMock()
    mock_no_collision_check_result.scalar_one_or_none.return_value = None  # No collision
    mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert mock_token_urlsafe.call_count == 2
    assert mock_db_session.execute.call_count == 2
    assert invite is not None
    assert invite.code == unique_code


@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
    mock_token_urlsafe.return_value = "always_colliding_code"
    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # Always collide
    mock_db_session.execute.return_value = mock_collision_check_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert invite is None
    assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
    assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
    mock_db_session.add.assert_not_called()


# --- get_active_invite_by_code Tests ---
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_invite_model
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is not None
    assert invite.code == db_invite_model.code
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")

    assert invite is None


@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
    db_invite_model.is_active = False  # Inactive
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # Should not be found by query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is None


@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1)  # Expired

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # Should not be found by query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is None


# --- deactivate_invite Tests ---
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
    db_invite_model.is_active = True  # Ensure it starts active

    deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model)

    mock_db_session.add.assert_called_once_with(db_invite_model)
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_invite_model)
    assert deactivated_invite.is_active is False


# It might be useful to test DB error cases (OperationalError, etc.) for each function
# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that.
# create_invite has its own DB interaction within the loop, so that's covered.
# get_active_invite_by_code and deactivate_invite are simpler DB ops.

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from datetime import datetime, timezone

from app.crud.item import (
    create_item,
    get_items_by_list_id,
    get_item_by_id,
    update_item,
    delete_item,
)
from app.schemas.item import ItemCreate, ItemUpdate
from app.models import Item as ItemModel, User as UserModel, List as ListModel
from app.core.exceptions import (
    ItemNotFoundError,  # Not directly raised by CRUD but good for API layer tests
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
)


# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the item CRUD tests.

    Coroutine methods are AsyncMocks; synchronous ones (add, begin) are
    MagicMocks. Result doubles returned from ``execute()`` are MagicMocks
    because SQLAlchemy's ``Result`` is synchronous.
    """
    session = AsyncMock()
    # AsyncSession.begin() returns an async context manager; MagicMock supports `async with`.
    session.begin = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete() is a coroutine; a side_effect set on it fires when awaited.
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Though not directly used in item.py, good for consistency
    session.flush = AsyncMock()
    return session


@pytest.fixture
def item_create_data():
    return ItemCreate(name="Test Item", quantity="1 pack")


@pytest.fixture
def item_update_data():
    return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)


@pytest.fixture
def user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def list_model():
    return ListModel(id=1, name="Test List")


@pytest.fixture
def db_item_model(list_model, user_model):
    return ItemModel(
        id=1,
        name="Existing Item",
        quantity="1 unit",
        list_id=list_model.id,
        added_by_id=user_model.id,
        is_complete=False,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )


# --- create_item Tests ---
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
    async def mock_refresh(instance):
        instance.id = 10  # Simulate ID assignment
        instance.version = 1  # Simulate version init
        instance.created_at = datetime.now(timezone.utc)
        instance.updated_at = datetime.now(timezone.utc)
        return None

    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)

    created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(created_item)
    assert created_item is not None
    assert created_item.name == item_create_data.name
    assert created_item.list_id == list_model.id
    assert created_item.added_by_id == user_model.id
    assert created_item.is_complete is False
    assert created_item.version == 1


@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
    mock_db_session.rollback.assert_called_once()


# --- get_items_by_list_id Tests ---
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
    mock_result = MagicMock()  # Result of the awaited execute() is synchronous
    mock_result.scalars.return_value.all.return_value = [db_item_model]
    mock_db_session.execute.return_value = mock_result

    items = await get_items_by_list_id(mock_db_session, list_model.id)

    assert len(items) == 1
    assert items[0].id == db_item_model.id
    mock_db_session.execute.assert_called_once()


# --- get_item_by_id Tests ---
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_item_model
    mock_db_session.execute.return_value = mock_result

    item = await get_item_by_id(mock_db_session, db_item_model.id)

    assert item is not None
    assert item.id == db_item_model.id


@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    item = await get_item_by_id(mock_db_session, 999)

    assert item is None


# --- update_item Tests ---
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
    item_update_data.version = db_item_model.version  # Match versions for successful update
    item_update_data.name = "Newly Updated Name"
    item_update_data.is_complete = True  # Test completion logic
    # Capture before the call: update_item is handed db_item_model itself (see the
    # add/refresh assertions), so comparing against db_item_model.version afterwards
    # would compare the object with itself.
    original_version = db_item_model.version

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    mock_db_session.add.assert_called_once_with(db_item_model)  # add is used for existing objects too
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_item_model)
    assert updated_item.name == "Newly Updated Name"
    assert updated_item.version == original_version + 1  # Same increment as test_update_item_set_incomplete
    assert updated_item.is_complete is True
    assert updated_item.completed_by_id == user_model.id


@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
    item_update_data.version = db_item_model.version + 1  # Create a version mismatch
    with pytest.raises(ConflictError):
        await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
    mock_db_session.rollback.assert_called_once()


@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
    db_item_model.is_complete = True  # Start as complete
    db_item_model.completed_by_id = user_model.id
    db_item_model.version = 1

    item_update_data.version = 1
    item_update_data.is_complete = False
    item_update_data.name = db_item_model.name  # No name change for this test
    item_update_data.quantity = db_item_model.quantity

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    assert updated_item.is_complete is False
    assert updated_item.completed_by_id is None
    assert updated_item.version == 2


# --- delete_item Tests ---
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
    result = await delete_item(mock_db_session, db_item_model)

    assert result is None
    mock_db_session.delete.assert_called_once_with(db_item_model)
    # NOTE(review): if delete_item commits via `async with db.begin()`, the mocked
    # context manager does not call session.commit — confirm against app/crud/item.py.
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await delete_item(mock_db_session, db_item_model)
    mock_db_session.rollback.assert_called_once()


# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function.
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import func as sql_func  # For get_list_status
from datetime import datetime, timezone

from app.crud.list import (
    create_list,
    get_lists_for_user,
    get_list_by_id,
    update_list,
    delete_list,
    check_list_permission,
    get_list_status,
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus
from app.models import (
    List as ListModel,
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
)
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
)


# Fixtures
@pytest.fixture
def mock_db_session():
    """AsyncSession stand-in for the list CRUD tests.

    Coroutine methods are AsyncMocks; synchronous ones (add, begin) are
    MagicMocks. Result doubles returned from ``execute()`` should be
    MagicMocks because SQLAlchemy's ``Result`` is synchronous.
    """
    session = AsyncMock()
    # AsyncSession.begin() returns an async context manager; MagicMock supports `async with`.
    session.begin = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete() is a coroutine.
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Used by check_list_permission via get_list_by_id
    session.flush = AsyncMock()
    return session


@pytest.fixture
def list_create_data():
    return ListCreate(name="New Shopping List", description="Groceries for the week")


@pytest.fixture
def list_update_data():
    return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)


@pytest.fixture
def user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com")


@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def db_list_personal_model(user_model):
    return ListModel(
        id=1,
        name="Personal List",
        created_by_id=user_model.id,
        creator=user_model,
        version=1,
        updated_at=datetime.now(timezone.utc),
        items=[],
    )
@pytest.fixture
def db_list_group_model(user_model, group_model):
    """A group-shared list created by ``user_model`` in ``group_model``."""
    return ListModel(
        id=2,
        name="Group List",
        created_by_id=user_model.id,
        creator=user_model,
        group_id=group_model.id,
        group=group_model,
        version=1,
        updated_at=datetime.now(timezone.utc),
        items=[],
    )


# NOTE: `AsyncSession.execute` is awaited, but the `Result` it resolves to has
# synchronous accessors (`scalars()`, `first()`, `all()`, `scalar_one_or_none()`).
# Result doubles are therefore `MagicMock`s. Every child of an `AsyncMock` is
# itself an `AsyncMock`, so `AsyncMock().scalars()` would return an un-awaited
# coroutine rather than the value configured on `.return_value`.


# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
    async def mock_refresh(instance):
        # Simulate the DB populating generated columns on refresh.
        instance.id = 100
        instance.version = 1
        instance.updated_at = datetime.now(timezone.utc)
        return None

    mock_db_session.refresh.return_value = None
    mock_db_session.refresh.side_effect = mock_refresh

    created_list = await create_list(mock_db_session, list_create_data, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_list.name == list_create_data.name
    assert created_list.created_by_id == user_model.id


# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
    # Simulate user is part of group for db_list_group_model
    mock_group_ids_result = MagicMock()
    mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]

    mock_lists_result = MagicMock()
    # Order should be personal list (created by user_id) then group list
    mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]

    # Consecutive execute() calls yield the group-id result, then the list result.
    mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result]

    lists = await get_lists_for_user(mock_db_session, user_model.id)

    assert len(lists) == 2
    assert db_list_personal_model in lists
    assert db_list_group_model in lists
    assert mock_db_session.execute.call_count == 2


# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = mock_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)

    assert found_list is not None
    assert found_list.id == db_list_personal_model.id
    # query options should not include selectinload for items
    # (difficult to assert directly without inspecting query object in detail)


@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
    # Simulate items loaded for the list
    db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = mock_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)

    assert found_list is not None
    assert len(found_list.items) == 1
    # query options should include selectinload for items


# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
    list_update_data.version = db_list_personal_model.version  # Match version

    updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)

    assert updated_list.name == list_update_data.name
    # NOTE(review): if update_list mutates and returns the same instance (as the
    # add/refresh assertions below suggest), this compares the object's version
    # with itself and cannot detect a missing version bump. Confirm update_list's
    # versioning behaviour and assert against a pre-update snapshot instead.
    assert updated_list.version == db_list_personal_model.version
    mock_db_session.add.assert_called_once_with(db_list_personal_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_list_personal_model)


@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
    list_update_data.version = db_list_personal_model.version + 1  # Version mismatch

    with pytest.raises(ConflictError):
        await update_list(mock_db_session, db_list_personal_model, list_update_data)
    mock_db_session.rollback.assert_called_once()


# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
    await delete_list(mock_db_session, db_list_personal_model)
    mock_db_session.delete.assert_called_once_with(db_list_personal_model)
    mock_db_session.commit.assert_called_once()  # from async with db.begin()


# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
    # get_list_by_id (called by check_list_permission) will mock execute
    mock_list_fetch_result = MagicMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = mock_list_fetch_result

    ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
    assert ret_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(
    mock_db_session, db_list_group_model, user_model, another_user_model, group_model
):
    # User `another_user_model` is not creator but member of the group.
    # `user_model` must be requested as a fixture argument: the bare module-level
    # name is the fixture definition, not a User instance.
    db_list_group_model.created_by_id = user_model.id  # Original creator is user_model
    db_list_group_model.creator = user_model

    # Mock get_list_by_id internal call
    mock_list_fetch_result = MagicMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model

    # Mock is_user_member call
    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True  # another_user_model is a member
        mock_db_session.execute.return_value = mock_list_fetch_result

        ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
        assert ret_list.id == db_list_group_model.id
        mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(
    mock_db_session, db_list_group_model, user_model, another_user_model
):
    db_list_group_model.created_by_id = user_model.id  # Creator is not another_user_model
    mock_list_fetch_result = MagicMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False  # another_user_model is NOT a member
        mock_db_session.execute.return_value = mock_list_fetch_result

        with pytest.raises(ListPermissionError):
            await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
    mock_list_fetch_result = MagicMock()
    mock_list_fetch_result.scalars.return_value.first.return_value = None  # List not found
    mock_db_session.execute.return_value = mock_list_fetch_result

    with pytest.raises(ListNotFoundError):
        await check_list_permission(mock_db_session, 999, user_model.id)


# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
    # `timedelta` lives in the datetime module; `timezone.timedelta` does not exist.
    from datetime import timedelta

    list_updated_at = datetime.now(timezone.utc) - timedelta(hours=1)
    item_updated_at = datetime.now(timezone.utc)
    item_count = 5

    db_list_personal_model.updated_at = list_updated_at

    # Mock for ListModel.updated_at query
    mock_list_updated_result = MagicMock()
    mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at

    # Mock for ItemModel status query
    mock_item_status_result = MagicMock()
    # SQLAlchemy query for func.max and func.count returns a Row-like object or None
    mock_item_status_row = MagicMock()
    mock_item_status_row.latest_item_updated_at = item_updated_at
    mock_item_status_row.item_count = item_count
    mock_item_status_result.first.return_value = mock_item_status_row

    mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result]

    status = await get_list_status(mock_db_session, db_list_personal_model.id)

    assert status.list_updated_at == list_updated_at
    assert status.latest_item_updated_at == item_updated_at
    assert status.item_count == item_count
    assert mock_db_session.execute.call_count == 2


@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
    mock_list_updated_result = MagicMock()
    mock_list_updated_result.scalar_one_or_none.return_value = None  # List not found
    mock_db_session.execute.return_value = mock_list_updated_result

    with pytest.raises(ListNotFoundError):
        await get_list_status(mock_db_session, 999)

# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
# TODO: Test check_list_permission with require_creator=True cases.


import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList

from app.crud.settlement import (
    create_settlement,
    get_settlement_by_id,
    get_settlements_for_group,
    get_settlements_involving_user,
    update_settlement,
    delete_settlement,
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError


# Fixtures
@pytest.fixture
def mock_db_session():
    """Return a mocked async DB session for the settlement CRUD tests."""
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    # execute() is awaited, but the Result it yields has synchronous accessors
    # (scalars().first()/all()), so the awaited value must be a MagicMock. The
    # tests configure it via `mock_db_session.execute.return_value...`.
    session.execute = AsyncMock(return_value=MagicMock())
    session.get = AsyncMock()
    return session


@pytest.fixture
def settlement_create_data():
    """Payload for a 10.50 settlement from user 1 to user 2 in group 1."""
    return SettlementCreate(
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Test settlement",
    )


@pytest.fixture
def settlement_update_data():
    """Payload updating description/date; version 1 matches db_settlement_model."""
    return SettlementUpdate(
        description="Updated settlement description",
        settlement_date=datetime.now(timezone.utc),
        version=1,  # Assuming version is required for update
    )


@pytest.fixture
def db_settlement_model():
    """A persisted settlement (id=1, version=1) with payer, payee and group loaded."""
    return SettlementModel(
        id=1,
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Original settlement",
        version=1,  # Initial version
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        payer=UserModel(id=1, name="Payer User"),
        payee=UserModel(id=2, name="Payee User"),
        group=GroupModel(id=1, name="Test Group"),
    )


@pytest.fixture
def payer_user_model():
    """The paying user (id=1)."""
    return UserModel(id=1, name="Payer User", email="payer@example.com")


@pytest.fixture
def payee_user_model():
    """The receiving user (id=2)."""
    return UserModel(id=2, name="Payee User", email="payee@example.com")


@pytest.fixture
def group_model():
    """The settlement's group (id=1)."""
    return GroupModel(id=1, name="Test Group")


# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]  # Order of gets

    created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)

    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_settlement is not None
    assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
    assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id


@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data, payee_user_model, group_model):
    # The payee/group fixtures are requested explicitly; the bare module-level
    # names are fixture definitions, not model instances.
    mock_db_session.get.side_effect = [None, payee_user_model, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, None, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payee" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
    with pytest.raises(GroupNotFoundError):
        await create_settlement(mock_db_session, settlement_create_data, 1)


@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
    settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
    mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer and Payee cannot be the same user" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Failed to save settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = db_settlement_model
    settlement = await get_settlement_by_id(mock_db_session, 1)
    assert settlement is not None
    assert settlement.id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
    settlement = await get_settlement_by_id(mock_db_session, 999)
    assert settlement is None


# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
    mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
    settlements = await get_settlements_for_group(mock_db_session, group_id=1)
    assert len(settlements) == 1
    assert settlements[0].group_id == 1
    mock_db_session.execute.assert_called_once()


# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
    mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
    settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
    assert len(settlements) == 1
    assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
    mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
    settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
    assert len(settlements) == 1
    # More specific assertions about the query would require deeper mocking of SQLAlchemy query construction
    mock_db_session.execute.assert_called_once()


# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
    # Ensure settlement_update_data.version matches db_settlement_model.version
    settlement_update_data.version = db_settlement_model.version
    # Snapshot the version before the call. update_settlement may mutate and
    # return the same instance, and then `db_settlement_model.version + 1` would
    # be compared against the object's own (already bumped) version.
    original_version = db_settlement_model.version

    # Mock datetime.now()
    fixed_datetime_now = datetime.now(timezone.utc)
    with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime:
        mock_datetime.now.return_value = fixed_datetime_now
        updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)

    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert updated_settlement.description == settlement_update_data.description
    assert updated_settlement.settlement_date == settlement_update_data.settlement_date
    assert updated_settlement.version == original_version + 1  # Version incremented
    assert updated_settlement.updated_at == fixed_datetime_now


@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
    settlement_update_data.version = db_settlement_model.version + 1  # Mismatched version
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "version does not match" in str(excinfo.value)


@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
    # Create an update payload with a field not allowed to be updated, e.g., 'amount'
    invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
    assert "Field 'amount' cannot be updated" in str(excinfo.value)


@pytest.mark.asyncio
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
    settlement_update_data.version = db_settlement_model.version
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "Failed to update settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for delete_settlement
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
    await delete_settlement(mock_db_session, db_settlement_model)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
    await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
    with pytest.raises(InvalidOperationError) as excinfo:
        await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version + 1)
    assert "Expected version" in str(excinfo.value)
    assert "does not match current version" in str(excinfo.value)
    mock_db_session.delete.assert_not_called()


@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await delete_settlement(mock_db_session, db_settlement_model)
    assert "Failed to delete settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.exc import IntegrityError, OperationalError

from app.crud.user import get_user_by_email, create_user
from app.schemas.user import UserCreate
from app.models import User as UserModel
from app.core.exceptions import (
    UserCreationError,
    EmailAlreadyRegisteredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
)


# Fixtures
@pytest.fixture
def mock_db_session():
    """Return a mocked async DB session for the user CRUD tests."""
    session = AsyncMock()
    # Awaiting execute() must yield a MagicMock so the synchronous Result
    # accessors (scalars().first()) return the configured values, not coroutines.
    session.execute.return_value = MagicMock()
    return session


@pytest.fixture
def user_create_data():
    """Signup payload for a new user."""
    return UserCreate(email="test@example.com", password="password123", name="Test User")


@pytest.fixture
def existing_user_data():
    """An already-registered user row."""
    return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")


# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = existing_user_data
    user = await get_user_by_email(mock_db_session, "exists@example.com")
    assert user is not None
    assert user.email == "exists@example.com"
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
    user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
    assert user is None
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_by_email_db_connection_error(mock_db_session):
    mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await get_user_by_email(mock_db_session, "test@example.com")


@pytest.mark.asyncio
async def test_get_user_by_email_db_query_error(mock_db_session):
    # Simulate a generic SQLAlchemyError that is not OperationalError
    mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig")  # Using IntegrityError as an example of SQLAlchemyError
    with pytest.raises(DatabaseQueryError):
        await get_user_by_email(mock_db_session, "test@example.com")


# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
    # The actual user object returned would be created by SQLAlchemy based on db_user
    # We mock the process: db.add is called, then db.flush, then db.refresh updates db_user
    async def mock_refresh(user_model_instance):
        user_model_instance.id = 1  # Simulate DB assigning an ID
        # Simulate other db-generated fields if necessary
        return None

    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
    mock_db_session.flush = AsyncMock()
    mock_db_session.add = MagicMock()

    created_user = await create_user(mock_db_session, user_create_data)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_user is not None
    assert created_user.email == user_create_data.email
    assert created_user.name == user_create_data.name
    # `hasattr(created_user, 'id')` is always true for a mapped model (the column
    # attribute exists even when None); check the value set by mock_refresh instead.
    assert created_user.id == 1
    # Password hash check would be more involved, ensure hash_password was called correctly
    # For now, we assume hash_password works as intended and is tested elsewhere.
@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
    mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig")
    with pytest.raises(EmailAlreadyRegisteredError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
    # Simulate an IntegrityError that is not related to a unique constraint
    mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
    mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)

    # also test OperationalError on flush
    mock_db_session.begin.side_effect = None  # reset side effect
    mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
    # Local import so this test does not depend on the module's import list.
    from sqlalchemy.exc import SQLAlchemyError

    # Simulate a generic SQLAlchemyError on flush that is not IntegrityError or
    # OperationalError. The previous UserCreationError is an application
    # exception, not a SQLAlchemyError, so it did not exercise the generic
    # database-error path this test is named for.
    mock_db_session.flush.side_effect = SQLAlchemyError("Simulated non-specific SQLAlchemyError")
    with pytest.raises(DatabaseTransactionError):
        await create_user(mock_db_session, user_create_data)

## Polished PWA Plan: Shared Lists & Household Management ## 1. Product Overview **Concept:** Develop a Progressive Web App (PWA) focused on simplifying household coordination. Users can: - Create, manage, and **share** shopping lists within defined groups (e.g., households, trip members).
- Capture images of receipts or shopping lists via the browser and extract items using **Google Cloud Vision API** for OCR. - Track item costs on shared lists and easily split expenses among group participants. - (Future) Manage and assign household chores. **Target Audience:** Households, roommates, families, groups organizing shared purchases. **UX Philosophy:** - **User-Centered & Collaborative:** Design intuitive flows for both individual use and group collaboration with minimal friction. - **Native-like PWA Experience:** Leverage service workers, caching, and manifest files for reliable offline use, installability, and smooth performance. - **Clarity & Accessibility:** Prioritize high contrast, legible typography, straightforward navigation, and adherence to accessibility standards (WCAG). - **Informative Feedback:** Provide clear visual feedback for actions (animations, loading states), OCR processing status, and data synchronization, including handling potential offline conflicts gracefully. --- ## 2. MVP Scope (Refined & Focused) The MVP will focus on delivering a robust, shareable shopping list experience with integrated OCR and cost splitting, built as a high-quality PWA. **Chore management is deferred post-MVP** to ensure a polished core experience at launch. 1. **Shared Shopping List Management:** * **Core Features:** Create, update, delete lists and items. Mark items as complete. Basic item sorting/reordering (e.g., manual drag-and-drop). * **Collaboration:** Share lists within user-defined groups. Real-time (or near real-time) updates visible to group members (via polling or simple WebSocket for MVP). * **PWA/UX:** Responsive design, offline access to cached lists, basic conflict indication if offline edits clash (e.g., "Item updated by another user, refresh needed"). 2. **OCR Integration (Google Cloud Vision):** * **Core Features:** Capture images via browser (`<input type="file" accept="image/*" capture>` or `getUserMedia`). Upload images to the FastAPI backend.
Backend securely calls **Google Cloud Vision API (Text Detection / Document Text Detection)**. Process results, suggest items to add to the list. * **PWA/UX:** Clear instructions for image capture. Progress indicators during upload/processing. Display editable OCR results for user review and confirmation before adding to the list. Handle potential API errors or low-confidence results gracefully. 3. **Cost Splitting (Integrated with Lists):** * **Core Features:** Assign prices to items *on the shopping list* as they are purchased. Add participants (from the shared group) to a list's expense split. Calculate totals per list and simple equal splits per participant. * **PWA/UX:** Clear display of totals and individual shares. Easy interface for marking items as bought and adding their price. 4. **User Authentication & Group Management:** * **Core Features:** Secure email/password signup & login (JWT-based). Ability to create simple groups (e.g., "Household"). Mechanism to invite/add users to a group (e.g., unique invite code/link). Basic role distinction (e.g., group owner/admin, member) if necessary for managing participants. * **PWA/UX:** Minimalist forms, clear inline validation, smooth onboarding explaining the group concept. 5. **Core PWA Functionality:** * **Core Features:** Installable via `manifest.json`. Offline access via service worker caching (app shell, static assets, user data). Basic background sync strategy for offline actions (e.g., "last write wins" for simple edits, potentially queueing adds/deletes). --- ## 3. Feature Breakdown & UX Enhancements (MVP Focus) ### A. Shared Shopping Lists - **Screens:** Dashboard (list overview), List Detail (items), Group Management. - **Flows:** Create list -> (Optional) Share with group -> Add/edit/check items -> See updates from others -> Mark list complete. - **UX Focus:** Smooth transitions, clear indication of shared status, offline caching, simple conflict notification (not full resolution in MVP). ### B. 
OCR with Google Cloud Vision - **Flow:** Tap "Add via OCR" -> Capture/Select Image -> Upload -> Show Progress -> Display Review Screen (editable text boxes for potential items) -> User confirms/edits -> Items added to list. - **UX Focus:** Clear instructions, robust error handling (API errors, poor image quality feedback if possible), easy correction interface, manage user expectations regarding OCR accuracy. Monitor API costs/quotas. ### C. Integrated Cost Splitting - **Flow:** Open shared list -> Mark item "bought" -> Input price -> View updated list total -> Go to "Split Costs" view for the list -> Confirm participants (group members) -> See calculated equal split. - **UX Focus:** Seamless transition from shopping to cost entry. Clear, real-time calculation display. Simple participant management within the list context. ### D. User Auth & Groups - **Flow:** Sign up/Login -> Create a group -> Invite members (e.g., share code) -> Member joins group -> Access shared lists. - **UX Focus:** Secure and straightforward auth. Simple group creation and joining process. Clear visibility of group members. ### E. PWA Essentials - **Manifest:** Define app name, icons, theme, display mode. - **Service Worker:** Cache app shell, assets, API responses (user data). Implement basic offline sync queue for actions performed offline (e.g., adding/checking items). Define a clear sync conflict strategy (e.g., last-write-wins, notify user on conflict). --- ## 4. Architecture & Technology Stack ### Frontend: Svelte PWA - **Framework:** Svelte/SvelteKit (Excellent for performant, component-based PWAs). - **State Management:** Svelte Stores for managing UI state and cached data. - **PWA Tools:** Workbox.js (via SvelteKit integration or standalone) for robust service worker generation and caching strategies. - **Styling:** Tailwind CSS or standard CSS with scoped styles. - **UX:** Design system (e.g., using Figma), Storybook for component development. 
### Backend: FastAPI & PostgreSQL - **Framework:** FastAPI (High performance, async support, auto-docs, Pydantic validation). - **Database:** PostgreSQL (Reliable, supports JSONB for flexibility if needed). Schema designed to handle users, groups, lists, items, costs, and relationships. Basic indexing on foreign keys and frequently queried fields (user IDs, group IDs, list IDs). - **ORM:** SQLAlchemy (native asyncio support since v1.4; v2.0+ recommended) or Tortoise ORM (async-native). Alembic for migrations. - **OCR Integration:** Use the official **Google Cloud Client Libraries for Python** to interact with the Vision API. Implement robust error handling, retries, and potentially rate limiting/cost control logic. Ensure Vision API calls do not block the event loop: use the library's async client variant, or run the synchronous client in a worker thread (e.g., `asyncio.to_thread`). - **Authentication:** JWT tokens for stateless session management. - **Deployment:** Containerize using Docker/Docker Compose for development and deployment consistency. Deploy on a scalable cloud platform (e.g., Google Cloud Run, AWS Fargate, DigitalOcean App Platform). - **Monitoring:** Logging (standard Python logging), Error Tracking (Sentry), Performance Monitoring (Prometheus/Grafana if needed later). --- # Finalized User Stories, Flow Mapping, Sharing Model & Sync, Tech Stack & Initial Architecture Diagram ## 1.
User Stories ### Authentication & User Management - As a new user, I want to sign up with my email so I can create and manage shopping lists - As a returning user, I want to log in securely to access my lists and groups - As a user, I want to reset my password if I forget it - As a user, I want to edit my profile information (name, avatar) ### Group Management - As a user, I want to create a new group (e.g., "Household", "Roommates") to organize shared lists - As a group creator, I want to invite others to join my group via a shareable link/code - As an invitee, I want to easily join a group by clicking a link or entering a code - As a group owner, I want to remove members if needed - As a user, I want to leave a group I no longer wish to be part of - As a user, I want to see all groups I belong to and switch between them ### List Management - As a user, I want to create a personal shopping list with a title and optional description - As a user, I want to share a list with a specific group so members can collaborate - As a user, I want to view all my lists (personal and shared) from a central dashboard - As a user, I want to archive or delete lists I no longer need - As a user, I want to mark a list as "shopping complete" when finished - As a user, I want to see which group a list is shared with ### Item Management - As a user, I want to add items to a list with names and optional quantities - As a user, I want to mark items as purchased when shopping - As a user, I want to edit item details (name, quantity, notes) - As a user, I want to delete items from a list - As a user, I want to reorder items on my list for shopping efficiency - As a user, I want to see who added or marked items as purchased in shared lists ### OCR Integration - As a user, I want to capture a photo of a physical shopping list or receipt - As a user, I want the app to extract text and convert it into list items - As a user, I want to review and edit OCR results before adding to my list - As a 
user, I want clear feedback on OCR processing status - As a user, I want to retry OCR if the results aren't satisfactory ### Cost Splitting - As a user, I want to add prices to items as I purchase them - As a user, I want to see the total cost of all purchased items in a list - As a user, I want to split costs equally among group members - As a user, I want to see who owes what amount based on the split - As a user, I want to mark expenses as settled ### PWA & Offline Experience - As a user, I want to install the app on my home screen for quick access - As a user, I want to view and edit my lists even when offline - As a user, I want my changes to sync automatically when I'm back online - As a user, I want to be notified if my offline changes conflict with others' changes # docker-compose.yml (in project root) version: '3.8' services: db: image: postgres:15 # Use a specific PostgreSQL version container_name: postgres_db environment: POSTGRES_USER: dev_user # Define DB user POSTGRES_PASSWORD: dev_password # Define DB password POSTGRES_DB: dev_db # Define Database name volumes: - postgres_data:/var/lib/postgresql/data # Persist data using a named volume ports: - "5432:5432" # Expose PostgreSQL port to host (optional, for direct access) healthcheck: test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"] interval: 10s timeout: 5s retries: 5 start_period: 10s restart: unless-stopped backend: container_name: fastapi_backend build: context: ./be # Path to the directory containing the Dockerfile dockerfile: Dockerfile volumes: # Mount local code into the container for development hot-reloading # The code inside the container at /app will mirror your local ./be directory - ./be:/app ports: - "8000:8000" # Map container port 8000 to host port 8000 environment: # Pass the database URL to the backend container # Uses the service name 'db' as the host, and credentials defined above # IMPORTANT: Use the correct async driver prefix if your app needs it! 
- DATABASE_URL=postgresql+asyncpg://dev_user:dev_password@db:5432/dev_db # Add other environment variables needed by the backend here # - SOME_OTHER_VAR=some_value depends_on: db: # Wait for the db service to be healthy before starting backend condition: service_healthy command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] # Override CMD for development reload restart: unless-stopped pgadmin: # Optional service for database administration image: dpage/pgadmin4:latest container_name: pgadmin4_server environment: PGADMIN_DEFAULT_EMAIL: admin@example.com # Change as needed PGADMIN_DEFAULT_PASSWORD: admin_password # Change to a secure password PGADMIN_CONFIG_SERVER_MODE: 'False' # Run in Desktop mode for easier local dev server setup volumes: - pgadmin_data:/var/lib/pgadmin # Persist pgAdmin configuration ports: - "5050:80" # Map container port 80 to host port 5050 depends_on: - db # Depends on the database service restart: unless-stopped volumes: # Define named volumes for data persistence postgres_data: pgadmin_data: [*.{js,jsx,mjs,cjs,ts,tsx,mts,cts,vue}] charset = utf-8 indent_size = 2 indent_style = space end_of_line = lf insert_final_newline = true trim_trailing_whitespace = true { "$schema": "https://json.schemastore.org/prettierrc", "singleQuote": true, "printWidth": 100 } { "recommendations": [ "dbaeumer.vscode-eslint", "esbenp.prettier-vscode", "editorconfig.editorconfig", "vue.volar", "wayou.vscode-todo-highlight" ], "unwantedRecommendations": [ "octref.vetur", "hookyqr.beautify", "dbaeumer.jshint", "ms-vscode.vscode-typescript-tslint-plugin" ] } { "editor.bracketPairColorization.enabled": true, "editor.guides.bracketPairs": true, "editor.formatOnSave": true, "editor.defaultFormatter": "esbenp.prettier-vscode", "editor.codeActionsOnSave": [ "source.fixAll.eslint" ], "eslint.validate": [ "javascript", "javascriptreact", "typescript", "vue" ], "typescript.tsdk": "node_modules/typescript/lib" } import js from '@eslint/js' 
import globals from 'globals' import pluginVue from 'eslint-plugin-vue' import pluginQuasar from '@quasar/app-vite/eslint' import { defineConfigWithVueTs, vueTsConfigs } from '@vue/eslint-config-typescript' import prettierSkipFormatting from '@vue/eslint-config-prettier/skip-formatting' export default defineConfigWithVueTs( { /** * Ignore the following files. * Please note that pluginQuasar.configs.recommended() already ignores * the "node_modules" folder for you (and all other Quasar project * relevant folders and files). * * ESLint requires "ignores" key to be the only one in this object */ // ignores: [] }, pluginQuasar.configs.recommended(), js.configs.recommended, /** * https://eslint.vuejs.org * * pluginVue.configs.base * -> Settings and rules to enable correct ESLint parsing. * pluginVue.configs[ 'flat/essential'] * -> base, plus rules to prevent errors or unintended behavior. * pluginVue.configs["flat/strongly-recommended"] * -> Above, plus rules to considerably improve code readability and/or dev experience. * pluginVue.configs["flat/recommended"] * -> Above, plus rules to enforce subjective community defaults to ensure consistency. */ pluginVue.configs[ 'flat/essential' ], { files: ['**/*.ts', '**/*.vue'], rules: { '@typescript-eslint/consistent-type-imports': [ 'error', { prefer: 'type-imports' } ], } }, // https://github.com/vuejs/eslint-config-typescript vueTsConfigs.recommendedTypeChecked, { languageOptions: { ecmaVersion: 'latest', sourceType: 'module', globals: { ...globals.browser, ...globals.node, // SSR, Electron, config files process: 'readonly', // process.env.* ga: 'readonly', // Google Analytics cordova: 'readonly', Capacitor: 'readonly', chrome: 'readonly', // BEX related browser: 'readonly' // BEX related } }, // add your custom rules here rules: { 'prefer-promise-reject-errors': 'off', // allow debugger during development only 'no-debugger': process.env.NODE_ENV === 'production' ? 
'error' : 'off' } }, { files: [ 'src-pwa/custom-service-worker.ts' ], languageOptions: { globals: { ...globals.serviceworker } } }, prettierSkipFormatting ) <%= productName %> // https://github.com/michael-ciniawsky/postcss-load-config import autoprefixer from 'autoprefixer' // import rtlcss from 'postcss-rtlcss' export default { plugins: [ // https://github.com/postcss/autoprefixer autoprefixer({ overrideBrowserslist: [ 'last 4 Chrome versions', 'last 4 Firefox versions', 'last 4 Edge versions', 'last 4 Safari versions', 'last 4 Android versions', 'last 4 ChromeAndroid versions', 'last 4 FirefoxAndroid versions', 'last 4 iOS versions' ] }), // https://github.com/elchininet/postcss-rtlcss // If you want to support RTL css, then // 1. yarn/pnpm/bun/npm install postcss-rtlcss // 2. optionally set quasar.config.js > framework > lang to an RTL language // 3. uncomment the following line (and its import statement above): // rtlcss() ] } { "orientation": "portrait", "background_color": "#ffffff", "theme_color": "#027be3", "icons": [ { "src": "icons/icon-128x128.png", "sizes": "128x128", "type": "image/png" }, { "src": "icons/icon-192x192.png", "sizes": "192x192", "type": "image/png" }, { "src": "icons/icon-256x256.png", "sizes": "256x256", "type": "image/png" }, { "src": "icons/icon-384x384.png", "sizes": "384x384", "type": "image/png" }, { "src": "icons/icon-512x512.png", "sizes": "512x512", "type": "image/png" } ] } declare namespace NodeJS { interface ProcessEnv { SERVICE_WORKER_FILE: string; PWA_FALLBACK_HTML: string; PWA_SERVICE_WORKER_REGEX: string; } } import { register } from 'register-service-worker'; // The ready(), registered(), cached(), updatefound() and updated() // events passes a ServiceWorkerRegistration instance in their arguments. 
// ServiceWorkerRegistration: https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerRegistration register(process.env.SERVICE_WORKER_FILE, { // The registrationOptions object will be passed as the second argument // to ServiceWorkerContainer.register() // https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerContainer/register#Parameter // registrationOptions: { scope: './' }, ready (/* registration */) { // console.log('Service worker is active.') }, registered (/* registration */) { // console.log('Service worker has been registered.') }, cached (/* registration */) { // console.log('Content has been cached for offline use.') }, updatefound (/* registration */) { // console.log('New content is downloading.') }, updated (/* registration */) { // console.log('New content is available; please refresh.') }, offline () { // console.log('No internet connection found. App is running in offline mode.') }, error (/* err */) { // console.error('Error during service worker registration:', err) }, }); { "extends": "../tsconfig.json", "compilerOptions": { "lib": ["WebWorker", "ESNext"] }, "include": ["*.ts", "*.d.ts"] } import { defineBoot } from '#q-app/wrappers'; import { createI18n } from 'vue-i18n'; import messages from 'src/i18n'; export type MessageLanguages = keyof typeof messages; // Type-define 'en-US' as the master schema for the resource export type MessageSchema = typeof messages['en-US']; // See https://vue-i18n.intlify.dev/guide/advanced/typescript.html#global-resource-schema-type-definition /* eslint-disable @typescript-eslint/no-empty-object-type */ declare module 'vue-i18n' { // define the locale messages schema export interface DefineLocaleMessage extends MessageSchema {} // define the datetime format schema export interface DefineDateTimeFormat {} // define the number format schema export interface DefineNumberFormat {} } /* eslint-enable @typescript-eslint/no-empty-object-type */ export default defineBoot(({ app }) => { const i18n = 
createI18n<{ message: MessageSchema }, MessageLanguages>({ locale: 'en-US', legacy: false, messages, }); // Set i18n instance on app app.use(i18n); }); export interface Todo { id: number; content: string; } export interface Meta { totalCount: number; } import { api } from 'boot/axios'; import { API_BASE_URL, API_VERSION, API_ENDPOINTS } from './api-config'; // Helper function to get full API URL export const getApiUrl = (endpoint: string): string => { return `${API_BASE_URL}/api/${API_VERSION}${endpoint}`; }; // Helper function to make API calls export const apiClient = { get: (endpoint: string, config = {}) => api.get(getApiUrl(endpoint), config), post: (endpoint: string, data = {}, config = {}) => api.post(getApiUrl(endpoint), data, config), put: (endpoint: string, data = {}, config = {}) => api.put(getApiUrl(endpoint), data, config), patch: (endpoint: string, data = {}, config = {}) => api.patch(getApiUrl(endpoint), data, config), delete: (endpoint: string, config = {}) => api.delete(getApiUrl(endpoint), config), }; export { API_ENDPOINTS }; // app global css in SCSS form // Quasar SCSS (& Sass) Variables // -------------------------------------------------- // To customize the look and feel of this app, you can override // the Sass/SCSS variables found in Quasar's source Sass/SCSS files. // Check documentation for full list of Quasar variables // Your own variables (that are declared here) and Quasar's own // ones will be available out of the box in your .vue/.scss/.sass files // It's highly recommended to change the default colors // to match your app's branding. // Tip: Use the "Theme Builder" on Quasar's documentation website. 
$primary : #1976D2; $secondary : #26A69A; $accent : #9C27B0; $dark : #1D1D1D; $dark-page : #121212; $positive : #21BA45; $negative : #C10015; $info : #31CCEC; $warning : #F2C037; declare namespace NodeJS { interface ProcessEnv { NODE_ENV: string; VUE_ROUTER_MODE: 'hash' | 'history' | 'abstract' | undefined; VUE_ROUTER_BASE: string | undefined; } } // This is just an example, // so you can safely delete all default props below export default { failed: 'Action failed', success: 'Action was successful' }; import enUS from './en-US'; export default { 'en-US': enUS }; import { defineRouter } from '#q-app/wrappers'; import { createMemoryHistory, createRouter, createWebHashHistory, createWebHistory, } from 'vue-router'; import routes from './routes'; import { useAuthStore } from 'stores/auth'; /* * If not building with SSR mode, you can * directly export the Router instantiation; * * The function below can be async too; either use * async/await or return a Promise which resolves * with the Router instance. */ export default defineRouter(function (/* { store, ssrContext } */) { const createHistory = process.env.SERVER ? createMemoryHistory : process.env.VUE_ROUTER_MODE === 'history' ? createWebHistory : createWebHashHistory; const Router = createRouter({ scrollBehavior: () => ({ left: 0, top: 0 }), routes, // Leave this as is and make changes in quasar.conf.js instead! 
// quasar.conf.js -> build -> vueRouterMode // quasar.conf.js -> build -> publicPath history: createHistory(process.env.VUE_ROUTER_BASE), }); // Navigation guard to check authentication Router.beforeEach((to, from, next) => { const authStore = useAuthStore(); const isAuthenticated = authStore.isAuthenticated; // Define public routes that don't require authentication const publicRoutes = ['/login', '/signup']; // Check if the route requires authentication const requiresAuth = !publicRoutes.includes(to.path); if (requiresAuth && !isAuthenticated) { // Redirect to login if trying to access protected route without authentication next({ path: '/login', query: { redirect: to.fullPath } }); } else if (!requiresAuth && isAuthenticated) { // Redirect to home if trying to access login/signup while authenticated next({ path: '/' }); } else { // Proceed with navigation next(); } }); return Router; }); import { defineStore } from '#q-app/wrappers' import { createPinia } from 'pinia' /* * When adding new properties to stores, you should also * extend the `PiniaCustomProperties` interface. * @see https://pinia.vuejs.org/core-concepts/plugins.html#typing-new-store-properties */ declare module 'pinia' { // eslint-disable-next-line @typescript-eslint/no-empty-object-type export interface PiniaCustomProperties { // add your custom properties here, if any } } /* * If not building with SSR mode, you can * directly export the Store instantiation; * * The function below can be async too; either use * async/await or return a Promise which resolves * with the Store instance. 
*/ export default defineStore((/* { ssrContext } */) => { const pinia = createPinia() // You can add Pinia plugins here // pinia.use(SomePiniaPlugin) return pinia }) # app/api/dependencies.py import logging from typing import Optional from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from jose import JWTError from app.database import get_db from app.core.security import verify_access_token from app.crud import user as crud_user from app.models import User as UserModel # Import the SQLAlchemy model from app.config import settings logger = logging.getLogger(__name__) # Define the OAuth2 scheme # tokenUrl should point to your login endpoint relative to the base path # It's used by Swagger UI for the "Authorize" button flow. oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL) async def get_current_user( token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db) ) -> UserModel: """ Dependency to get the current user based on the JWT token. - Extracts token using OAuth2PasswordBearer. - Verifies the token (signature, expiry). - Fetches the user from the database based on the token's subject (email). - Raises HTTPException 401 if any step fails. Returns: The authenticated user's database model instance. 
""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.AUTH_CREDENTIALS_ERROR, headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX}, ) payload = verify_access_token(token) if payload is None: logger.warning("Token verification failed (invalid, expired, or malformed).") raise credentials_exception email: Optional[str] = payload.get("sub") if email is None: logger.error("Token payload missing 'sub' (subject/email).") raise credentials_exception # Token is malformed # Fetch user from database user = await crud_user.get_user_by_email(db, email=email) if user is None: logger.warning(f"User corresponding to token subject not found: {email}") # Could happen if user deleted after token issuance raise credentials_exception # Treat as invalid credentials logger.debug(f"Authenticated user retrieved: {user.email} (ID: {user.id})") return user # Optional: Dependency for getting the *active* current user # You might add an `is_active` flag to your User model later # async def get_current_active_user( # current_user: UserModel = Depends(get_current_user) # ) -> UserModel: # if not current_user.is_active: # Assuming an is_active attribute # logger.warning(f"Authentication attempt by inactive user: {current_user.email}") # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user") # return current_user # app/api/v1/endpoints/groups.py import logging from typing import List from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.api.dependencies import get_current_user from app.models import User as UserModel, UserRoleEnum # Import model and enum from app.schemas.group import GroupCreate, GroupPublic from app.schemas.invite import InviteCodePublic from app.schemas.message import Message # For simple responses from app.crud import group as crud_group from app.crud import invite as crud_invite from 
app.core.exceptions import ( GroupNotFoundError, GroupPermissionError, GroupMembershipError, GroupOperationError, GroupValidationError, InviteCreationError ) logger = logging.getLogger(__name__) router = APIRouter() @router.post( "", # Route relative to prefix "/groups" response_model=GroupPublic, status_code=status.HTTP_201_CREATED, summary="Create New Group", tags=["Groups"] ) async def create_group( group_in: GroupCreate, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Creates a new group, adding the creator as the owner.""" logger.info(f"User {current_user.email} creating group: {group_in.name}") created_group = await crud_group.create_group(db=db, group_in=group_in, creator_id=current_user.id) # Load members explicitly if needed for the response (optional here) # created_group = await crud_group.get_group_by_id(db, created_group.id) return created_group @router.get( "", # Route relative to prefix "/groups" response_model=List[GroupPublic], summary="List User's Groups", tags=["Groups"] ) async def read_user_groups( db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Retrieves all groups the current user is a member of.""" logger.info(f"Fetching groups for user: {current_user.email}") groups = await crud_group.get_user_groups(db=db, user_id=current_user.id) return groups @router.get( "/{group_id}", response_model=GroupPublic, summary="Get Group Details", tags=["Groups"] ) async def read_group( group_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Retrieves details for a specific group, including members, if the user is part of it.""" logger.info(f"User {current_user.email} requesting details for group ID: {group_id}") # Check if user is a member first is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id) if not is_member: logger.warning(f"Access denied: User {current_user.email} 
not member of group {group_id}") raise GroupMembershipError(group_id, "view group details") group = await crud_group.get_group_by_id(db=db, group_id=group_id) if not group: logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)") raise GroupNotFoundError(group_id) return group @router.post( "/{group_id}/invites", response_model=InviteCodePublic, summary="Create Group Invite", tags=["Groups", "Invites"] ) async def create_group_invite( group_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Generates a new invite code for the group. Requires owner/admin role (MVP: owner only).""" logger.info(f"User {current_user.email} attempting to create invite for group {group_id}") user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) # --- Permission Check (MVP: Owner only) --- if user_role != UserRoleEnum.owner: logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}") raise GroupPermissionError(group_id, "create invites") # Check if group exists (implicitly done by role check, but good practice) group = await crud_group.get_group_by_id(db, group_id) if not group: raise GroupNotFoundError(group_id) invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id) if not invite: logger.error(f"Failed to generate unique invite code for group {group_id}") raise InviteCreationError(group_id) logger.info(f"User {current_user.email} created invite code for group {group_id}") return invite @router.delete( "/{group_id}/leave", response_model=Message, summary="Leave Group", tags=["Groups"] ) async def leave_group( group_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Removes the current user from the specified group.""" logger.info(f"User {current_user.email} attempting to leave group 
{group_id}") user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) if user_role is None: raise GroupMembershipError(group_id, "leave (you are not a member)") # --- MVP: Prevent owner leaving if they are the last member/owner --- if user_role == UserRoleEnum.owner: member_count = await crud_group.get_group_member_count(db, group_id) # More robust check: count owners. For now, just check member count. if member_count <= 1: logger.warning(f"Owner {current_user.email} attempted to leave group {group_id} as last member.") raise GroupValidationError("Owner cannot leave the group as the last member. Delete the group or transfer ownership.") # Proceed with removal deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id) if not deleted: # Should not happen if role check passed, but handle defensively logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.") raise GroupOperationError("Failed to leave group") logger.info(f"User {current_user.email} successfully left group {group_id}") return Message(detail="Successfully left the group") # --- Optional: Remove Member Endpoint --- @router.delete( "/{group_id}/members/{user_id_to_remove}", response_model=Message, summary="Remove Member From Group (Owner Only)", tags=["Groups"] ) async def remove_group_member( group_id: int, user_id_to_remove: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Removes a specified user from the group. 
Requires current user to be owner.""" logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}") owner_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) # --- Permission Check --- if owner_role != UserRoleEnum.owner: logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}") raise GroupPermissionError(group_id, "remove members") # Prevent owner removing themselves via this endpoint if current_user.id == user_id_to_remove: raise GroupValidationError("Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.") # Check if target user is actually in the group target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove) if target_role is None: raise GroupMembershipError(group_id, "remove this user (they are not a member)") # Proceed with removal deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove) if not deleted: logger.error(f"Owner {current_user.email} failed to remove user {user_id_to_remove} from group {group_id}.") raise GroupOperationError("Failed to remove member") logger.info(f"Owner {current_user.email} successfully removed user {user_id_to_remove} from group {group_id}") return Message(detail="Successfully removed member from the group") # app/api/v1/endpoints/health.py import logging from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import text from app.database import get_db from app.schemas.health import HealthStatus from app.core.exceptions import DatabaseConnectionError logger = logging.getLogger(__name__) router = APIRouter() @router.get( "/health", response_model=HealthStatus, summary="Perform a Health Check", description="Checks the operational status of the API and its connection to the database.", tags=["Health"] ) async 
def check_health(db: AsyncSession = Depends(get_db)): """ Health check endpoint. Verifies API reachability and database connection. """ try: # Try executing a simple query to check DB connection result = await db.execute(text("SELECT 1")) if result.scalar_one() == 1: logger.info("Health check successful: Database connection verified.") return HealthStatus(status="ok", database="connected") else: # This case should ideally not happen with 'SELECT 1' logger.error("Health check failed: Database connection check returned unexpected result.") raise DatabaseConnectionError("Unexpected result from database connection check") except Exception as e: logger.error(f"Health check failed: Database connection error - {e}", exc_info=True) raise DatabaseConnectionError(str(e)) # app/api/v1/endpoints/invites.py import logging from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.api.dependencies import get_current_user from app.models import User as UserModel, UserRoleEnum from app.schemas.invite import InviteAccept from app.schemas.message import Message from app.crud import invite as crud_invite from app.crud import group as crud_group from app.core.exceptions import ( InviteNotFoundError, InviteExpiredError, InviteAlreadyUsedError, InviteCreationError, GroupNotFoundError, GroupMembershipError ) logger = logging.getLogger(__name__) router = APIRouter() @router.post( "/accept", # Route relative to prefix "/invites" response_model=Message, summary="Accept Group Invite", tags=["Invites"] ) async def accept_invite( invite_in: InviteAccept, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Accepts a group invite using the provided invite code.""" logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.invite_code}") # Get the invite invite = await crud_invite.get_invite_by_code(db, invite_code=invite_in.invite_code) 
if not invite: logger.warning(f"Invalid invite code attempted by user {current_user.email}: {invite_in.invite_code}") raise InviteNotFoundError(invite_in.invite_code) # Check if invite is expired if invite.is_expired(): logger.warning(f"Expired invite code attempted by user {current_user.email}: {invite_in.invite_code}") raise InviteExpiredError(invite_in.invite_code) # Check if invite has already been used if invite.used_at: logger.warning(f"Already used invite code attempted by user {current_user.email}: {invite_in.invite_code}") raise InviteAlreadyUsedError(invite_in.invite_code) # Check if group still exists group = await crud_group.get_group_by_id(db, group_id=invite.group_id) if not group: logger.error(f"Group {invite.group_id} not found for invite {invite_in.invite_code}") raise GroupNotFoundError(invite.group_id) # Check if user is already a member is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id) if is_member: logger.warning(f"User {current_user.email} already a member of group {invite.group_id}") raise GroupMembershipError(invite.group_id, "join (already a member)") # Add user to group and mark invite as used success = await crud_invite.accept_invite(db, invite=invite, user_id=current_user.id) if not success: logger.error(f"Failed to accept invite {invite_in.invite_code} for user {current_user.email}") raise InviteCreationError(invite.group_id) logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.invite_code}") return Message(detail="Successfully joined the group") # app/core/security.py from datetime import datetime, timedelta, timezone from typing import Any, Union, Optional from jose import JWTError, jwt from passlib.context import CryptContext from app.config import settings # Import settings from config # --- Password Hashing --- # Configure passlib context # Using bcrypt as the default hashing scheme # 'deprecated="auto"' will automatically upgrade 
# hashes if needed on verification.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verifies a plain text password against a hashed password.

    Args:
        plain_password: The password attempt.
        hashed_password: The stored hash from the database.

    Returns:
        True if the password matches the hash. False otherwise, including
        when passlib raises while verifying (e.g. a malformed stored hash).
    """
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except Exception:
        # Handle potential errors during verification (e.g., invalid hash format)
        return False


def hash_password(password: str) -> str:
    """
    Hashes a plain text password using the configured context (bcrypt).

    Args:
        password: The plain text password to hash.

    Returns:
        The resulting hash string.
    """
    return pwd_context.hash(password)


# --- JSON Web Tokens (JWT) ---
import logging  # noqa: E402 - used by the token helpers below

logger = logging.getLogger(__name__)


def _create_token(
    subject: Union[str, Any],
    token_type: str,
    expires_delta: Optional[timedelta],
    default_minutes: int,
) -> str:
    """Encode a signed JWT carrying ``exp``, ``sub`` (stringified) and ``type`` claims."""
    # Compare against None explicitly: a caller-supplied timedelta(0) is falsy
    # and must not silently fall back to the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=default_minutes)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {"exp": expire, "sub": str(subject), "type": token_type}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def _decode_token(token: str, expected_type: str) -> Optional[dict]:
    """Decode *token* and return its payload only if its ``type`` claim equals *expected_type*."""
    try:
        # jwt.decode verifies the signature (SECRET_KEY / ALGORITHM) and the
        # 'exp' claim, raising a JWTError subclass on failure.
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError as e:
        logger.warning(f"JWT Error: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error decoding JWT: {e}", exc_info=True)
        return None
    if payload.get("type") != expected_type:
        logger.warning("JWT Error: Invalid token type")
        return None
    return payload


def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Creates a JWT access token.

    Args:
        subject: The subject of the token (e.g., user ID or email); stored as str.
        expires_delta: Optional timedelta for token expiry. If None, uses
            ACCESS_TOKEN_EXPIRE_MINUTES from settings.

    Returns:
        The encoded JWT access token string (``type`` claim "access").
    """
    return _create_token(
        subject, "access", expires_delta, settings.ACCESS_TOKEN_EXPIRE_MINUTES
    )


def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Creates a JWT refresh token.

    Args:
        subject: The subject of the token (e.g., user ID or email); stored as str.
        expires_delta: Optional timedelta for token expiry. If None, uses
            REFRESH_TOKEN_EXPIRE_MINUTES from settings.

    Returns:
        The encoded JWT refresh token string (``type`` claim "refresh").
    """
    return _create_token(
        subject, "refresh", expires_delta, settings.REFRESH_TOKEN_EXPIRE_MINUTES
    )


def verify_access_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT access token and returns its payload if valid.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded payload (dict) if the signature is valid, the token is not
        expired and its ``type`` claim is "access"; otherwise None.
    """
    return _decode_token(token, "access")


def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT refresh token and returns its payload if valid.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded payload (dict) if the signature is valid, the token is not
        expired and its ``type`` claim is "refresh"; otherwise None.
    """
    return _decode_token(token, "refresh")

# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]:
#     payload = verify_access_token(token)
#     if payload:
#         return payload.get("sub")
#     return None

# app/crud/user.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional

from app.models import User as UserModel  # Alias to avoid name clash
from app.schemas.user import UserCreate
from app.core.security import hash_password
from app.core.exceptions import (
    UserCreationError,
    EmailAlreadyRegisteredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError
)


async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
    """
    Fetches a user from the database by email.

    Returns:
        The first matching user, or None.

    Raises:
        DatabaseConnectionError: on OperationalError.
        DatabaseQueryError: on any other SQLAlchemyError.
    """
    try:
        # NOTE(review): db.begin() raises if the session already has a transaction
        # in progress (that error is a SQLAlchemyError and surfaces as
        # DatabaseQueryError) — confirm callers pass a session with no open transaction.
        async with db.begin():
            result = await db.execute(select(UserModel).filter(UserModel.email == email))
            return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user: {str(e)}") from e


async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
    """
    Creates a new user record in the database, storing a bcrypt hash of the password.

    Raises:
        EmailAlreadyRegisteredError: on an IntegrityError mentioning a unique constraint.
        DatabaseIntegrityError: on any other IntegrityError.
        DatabaseConnectionError: on OperationalError.
        DatabaseTransactionError: on any other SQLAlchemyError.
    """
    try:
        async with db.begin():
            _hashed_password = hash_password(user_in.password)
            db_user = UserModel(
                email=user_in.email,
                password_hash=_hashed_password,
                name=user_in.name
            )
            db.add(db_user)
            await db.flush()  # Flush to get DB-generated values
            await db.refresh(db_user)
            return db_user
    except IntegrityError as e:
        if "unique constraint" in str(e).lower():
            raise EmailAlreadyRegisteredError() from e
        raise DatabaseIntegrityError(f"Failed to create user: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create user: {str(e)}") from e

# app/schemas/auth.py
from pydantic import BaseModel, EmailStr
from app.config import settings


class Token(BaseModel):
    access_token: str
    refresh_token: str  # Added refresh token
    token_type: str = settings.TOKEN_TYPE  # Use configured token type

# Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel):
#     email: EmailStr
#     password: str

from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal


class UserCostShare(BaseModel):
    user_id: int
    user_identifier: str  # Name or email
    items_added_value: Decimal = Decimal("0.00")  # Total value of items this user added
    amount_due: Decimal  # The user's share of the total cost (for equal split, this is total_cost / num_users)
    balance: Decimal  # items_added_value - amount_due

    model_config = ConfigDict(from_attributes=True)


class ListCostSummary(BaseModel):
    list_id: int
    list_name: str
    total_list_cost: Decimal
    num_participating_users: int
    equal_share_per_user: Decimal
    user_balances: List[UserCostShare]

    model_config = ConfigDict(from_attributes=True)


class UserBalanceDetail(BaseModel):
    user_id: int
    user_identifier: str  # Name or email
    total_paid_for_expenses: Decimal = Decimal("0.00")
    total_share_of_expenses: Decimal = Decimal("0.00")
    total_settlements_paid: Decimal = Decimal("0.00")
    total_settlements_received: Decimal = Decimal("0.00")
    net_balance: Decimal = Decimal("0.00")  # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)

    model_config = ConfigDict(from_attributes=True)


class SuggestedSettlement(BaseModel):
    from_user_id: int
    from_user_identifier: str  # Name or email of payer
    to_user_id: int
    to_user_identifier: str  # Name or email of payee
    amount: Decimal

    model_config = ConfigDict(from_attributes=True)


class GroupBalanceSummary(BaseModel):
    group_id: int
    group_name: str
    overall_total_expenses: Decimal = Decimal("0.00")
    overall_total_settlements: Decimal = Decimal("0.00")
    user_balances: List[UserBalanceDetail]
    # Optional: Could add a list of suggested settlements to zero out balances
    suggested_settlements: Optional[List[SuggestedSettlement]] = None

    model_config = ConfigDict(from_attributes=True)

# app/schemas/health.py
from pydantic import BaseModel
from app.config import settings


class HealthStatus(BaseModel):
    """
    Response model for the health check endpoint.
    """
    status: str = settings.HEALTH_STATUS_OK  # Use configured default value
    database: str

# app/schemas/item.py
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Optional
from decimal import Decimal


# Properties to return to client
class ItemPublic(BaseModel):
    id: int
    list_id: int
    name: str
    quantity: Optional[str] = None
    is_complete: bool
    price: Optional[Decimal] = None
    added_by_id: int
    completed_by_id: Optional[int] = None
    created_at: datetime
    updated_at: datetime
    version: int

    model_config = ConfigDict(from_attributes=True)


# Properties to receive via API on creation
class ItemCreate(BaseModel):
    name: str
    quantity: Optional[str] = None
    # list_id will be from path param
    # added_by_id will be from current_user


# Properties to receive via API on update
class ItemUpdate(BaseModel):
    name: Optional[str] = None
    quantity: Optional[str] = None
    is_complete: Optional[bool] = None
    price: Optional[Decimal] = None  # Price added here for update
    version: int
    # completed_by_id will be set internally if is_complete is true

# app/schemas/list.py
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Optional, List

from .item import ItemPublic  # Import item schema for nesting


# Properties to receive via API on creation
class ListCreate(BaseModel):
    name: str
    description: Optional[str] = None
    group_id: Optional[int] = None  # Optional for sharing


# Properties to receive via API on update
class ListUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    is_complete: Optional[bool] = None
    version: int  # Client must provide the version for updates
    # Potentially add group_id update later if needed


# Base properties returned by API (common fields)
class ListBase(BaseModel):
    id: int
    name: str
    description: Optional[str] = None
    created_by_id: int
    group_id: Optional[int] = None
    is_complete: bool
    created_at: datetime
    updated_at: datetime
    version: int  # Include version in responses

    model_config = ConfigDict(from_attributes=True)


# Properties returned when listing lists (no items)
class ListPublic(ListBase):
    pass  # Inherits all from ListBase


# Properties returned for a single list detail (includes items)
class ListDetail(ListBase):
    items: List[ItemPublic] = []  # Include list of items


class ListStatus(BaseModel):
    list_updated_at: datetime
    latest_item_updated_at: Optional[datetime] = None  # Can be null if list has no items
    item_count: int

# app/schemas/user.py
from pydantic import BaseModel, EmailStr, ConfigDict
from datetime import datetime
from typing import Optional


# Shared properties
class UserBase(BaseModel):
    email: EmailStr
    name: Optional[str] = None


# Properties to receive via API on creation
class UserCreate(UserBase):
    password: str

# Properties to receive via API on update (optional, add later if needed)
# class UserUpdate(UserBase):
#     password: Optional[str] = None


# Properties stored in DB
class UserInDBBase(UserBase):
    id: int
    password_hash: str
    created_at: datetime

    model_config = ConfigDict(from_attributes=True)  # Use orm_mode in Pydantic v1

# Additional properties to
return via API (excluding password) class UserPublic(UserBase): id: int created_at: datetime model_config = ConfigDict(from_attributes=True) # Full user model including hashed password (for internal use/reading from DB) class User(UserInDBBase): pass .DS_Store .thumbs.db node_modules # Quasar core related directories .quasar /dist /quasar.config.*.temporary.compiled* # Cordova related directories and files /src-cordova/node_modules /src-cordova/platforms /src-cordova/plugins /src-cordova/www # Capacitor related directories and files /src-capacitor/www /src-capacitor/node_modules # Log files npm-debug.log* yarn-debug.log* yarn-error.log* # Editor directories and files .idea *.suo *.ntvs* *.njsproj *.sln # local .env files .env.local* # pnpm-related options shamefully-hoist=true strict-peer-dependencies=false # to get the latest compatible packages when creating the project https://github.com/pnpm/pnpm/issues/6463 resolution-mode=highest // Configuration for your app // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file import { defineConfig } from '#q-app/wrappers'; import { fileURLToPath } from 'node:url'; export default defineConfig((ctx) => { return { // https://v2.quasar.dev/quasar-cli-vite/prefetch-feature // preFetch: true, // app boot file (/src/boot) // --> boot files are part of "main.js" // https://v2.quasar.dev/quasar-cli-vite/boot-files boot: ['i18n', 'axios'], // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#css css: ['app.scss'], // https://github.com/quasarframework/quasar/tree/dev/extras extras: [ // 'ionicons-v4', // 'mdi-v7', // 'fontawesome-v6', // 'eva-icons', // 'themify', // 'line-awesome', // 'roboto-font-latin-ext', // this or either 'roboto-font', NEVER both! 
'roboto-font', // optional, you are not bound to it 'material-icons', // optional, you are not bound to it ], // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#build build: { target: { browser: ['es2022', 'firefox115', 'chrome115', 'safari14'], node: 'node20', }, typescript: { strict: true, vueShim: true, // extendTsConfig (tsConfig) {} }, vueRouterMode: 'hash', // available values: 'hash', 'history' // vueRouterBase, // vueDevtools, // vueOptionsAPI: false, // rebuildCache: true, // rebuilds Vite/linter/etc cache on startup // publicPath: '/', // analyze: true, // env: {}, // rawDefine: {} // ignorePublicFolder: true, // minify: false, // polyfillModulePreload: true, // distDir // extendViteConf (viteConf) {}, // viteVuePluginOptions: {}, vitePlugins: [ [ '@intlify/unplugin-vue-i18n/vite', { // if you want to use Vue I18n Legacy API, you need to set `compositionOnly: false` // compositionOnly: false, // if you want to use named tokens in your Vue I18n messages, such as 'Hello {name}', // you need to set `runtimeOnly: false` // runtimeOnly: false, ssr: ctx.modeName === 'ssr', // you need to set i18n resource including paths ! 
include: [fileURLToPath(new URL('./src/i18n', import.meta.url))], }, ], [ 'vite-plugin-checker', { vueTsc: true, eslint: { lintCommand: 'eslint -c ./eslint.config.js "./src*/**/*.{ts,js,mjs,cjs,vue}"', useFlatConfig: true, }, }, { server: false }, ], ], }, // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#devserver devServer: { // https: true, open: true, // opens browser window automatically }, // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#framework framework: { config: {}, // iconSet: 'material-icons', // Quasar icon set // lang: 'en-US', // Quasar language pack // For special cases outside of where the auto-import strategy can have an impact // (like functional components as one of the examples), // you can manually specify Quasar components/directives to be available everywhere: // // components: [], // directives: [], // Quasar plugins plugins: ['Notify'], }, // animations: 'all', // --- includes all animations // https://v2.quasar.dev/options/animations animations: [], // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#sourcefiles // sourceFiles: { // rootComponent: 'src/App.vue', // router: 'src/router/index', // store: 'src/store/index', // pwaRegisterServiceWorker: 'src-pwa/register-service-worker', // pwaServiceWorker: 'src-pwa/custom-service-worker', // pwaManifestFile: 'src-pwa/manifest.json', // electronMain: 'src-electron/electron-main', // electronPreload: 'src-electron/electron-preload' // bexManifestFile: 'src-bex/manifest.json // }, // https://v2.quasar.dev/quasar-cli-vite/developing-ssr/configuring-ssr ssr: { prodPort: 3000, // The default port that the production server should use // (gets superseded if process.env.PORT is specified at runtime) middlewares: [ 'render', // keep this as last one ], // extendPackageJson (json) {}, // extendSSRWebserverConf (esbuildConf) {}, // manualStoreSerialization: true, // manualStoreSsrContextInjection: true, // manualStoreHydration: true, // 
manualPostHydrationTrigger: true, pwa: false, // pwaOfflineHtmlFilename: 'offline.html', // do NOT use index.html as name! // pwaExtendGenerateSWOptions (cfg) {}, // pwaExtendInjectManifestOptions (cfg) {} }, // https://v2.quasar.dev/quasar-cli-vite/developing-pwa/configuring-pwa pwa: { workboxMode: 'InjectManifest', // Changed from 'GenerateSW' to 'InjectManifest' swFilename: 'sw.js', manifestFilename: 'manifest.json', injectPwaMetaTags: true, // extendManifestJson (json) {}, // useCredentialsForManifestTag: true, // extendPWACustomSWConf (esbuildConf) {}, // extendGenerateSWOptions (cfg) {}, // extendInjectManifestOptions (cfg) {} }, // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-cordova-apps/configuring-cordova cordova: { // noIosLegacyBuildFlag: true, // uncomment only if you know what you are doing }, // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-capacitor-apps/configuring-capacitor capacitor: { hideSplashscreen: true, }, // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-electron-apps/configuring-electron electron: { // extendElectronMainConf (esbuildConf) {}, // extendElectronPreloadConf (esbuildConf) {}, // extendPackageJson (json) {}, // Electron preload scripts (if any) from /src-electron, WITHOUT file extension preloadScripts: ['electron-preload'], // specify the debugging port to use for the Electron app when running in development mode inspectPort: 5858, bundler: 'packager', // 'packager' or 'builder' packager: { // https://github.com/electron-userland/electron-packager/blob/master/docs/api.md#options // OS X / Mac App Store // appBundleId: '', // appCategoryType: '', // osxSign: '', // protocol: 'myapp://path', // Windows only // win32metadata: { ... 
} }, builder: { // https://www.electron.build/configuration/configuration appId: 'mitlist', }, }, // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-browser-extensions/configuring-bex bex: { // extendBexScriptsConf (esbuildConf) {}, // extendBexManifestJson (json) {}, /** * The list of extra scripts (js/ts) not in your bex manifest that you want to * compile and use in your browser extension. Maybe dynamic use them? * * Each entry in the list should be a relative filename to /src-bex/ * * @example [ 'my-script.ts', 'sub-folder/my-other-script.js' ] */ extraScripts: [], }, }; }); # mitlist (mitlist) mitlist pwa ## Install the dependencies ```bash yarn # or npm install ``` ### Start the app in development mode (hot-code reloading, error reporting, etc.) ```bash quasar dev ``` ### Lint the files ```bash yarn lint # or npm run lint ``` ### Format the files ```bash yarn format # or npm run format ``` ### Build the app for production ```bash quasar build ``` ### Customize the configuration See [Configuring quasar.config.js](https://v2.quasar.dev/quasar-cli-vite/quasar-config-js). 
/*
 * This file (which will be your service worker)
 * is picked up by the build system ONLY if
 * quasar.config file > pwa > workboxMode is set to "InjectManifest"
 */

declare const self: ServiceWorkerGlobalScope &
  typeof globalThis & { skipWaiting: () => Promise<void> };

import { clientsClaim } from 'workbox-core';
import {
  precacheAndRoute,
  cleanupOutdatedCaches,
  createHandlerBoundToURL,
} from 'workbox-precaching';
import { registerRoute, NavigationRoute } from 'workbox-routing';
import { CacheFirst, NetworkFirst } from 'workbox-strategies';
import { ExpirationPlugin } from 'workbox-expiration';
import { CacheableResponsePlugin } from 'workbox-cacheable-response';
import type { WorkboxPlugin } from 'workbox-core/types';

self.skipWaiting().catch((error) => {
  console.error('Error during service worker activation:', error);
});
clientsClaim();

// Use with precache injection
precacheAndRoute(self.__WB_MANIFEST);

cleanupOutdatedCaches();

// Cache app shell and static assets with Cache First strategy
registerRoute(
  // Match static assets
  ({ request }) =>
    request.destination === 'style' ||
    request.destination === 'script' ||
    request.destination === 'image' ||
    request.destination === 'font',
  new CacheFirst({
    cacheName: 'static-assets',
    plugins: [
      new CacheableResponsePlugin({
        statuses: [0, 200],
      }) as WorkboxPlugin,
      new ExpirationPlugin({
        maxEntries: 60,
        maxAgeSeconds: 30 * 24 * 60 * 60, // 30 days
      }) as WorkboxPlugin,
    ],
  })
);

// Cache API calls with Network First strategy.
// NOTE(review): only same-origin paths under /api/ match; confirm the axios
// base URL is served that way. Cached responses are user-specific and remain in
// 'api-cache' after logout — consider clearing it on logout.
registerRoute(
  // Match API calls
  ({ url }) => url.pathname.startsWith('/api/'),
  new NetworkFirst({
    cacheName: 'api-cache',
    plugins: [
      new CacheableResponsePlugin({
        statuses: [0, 200],
      }) as WorkboxPlugin,
      new ExpirationPlugin({
        maxEntries: 50,
        maxAgeSeconds: 24 * 60 * 60, // 24 hours
      }) as WorkboxPlugin,
    ],
  })
);

// Non-SSR fallbacks to index.html
// Production SSR fallbacks to offline.html (except for dev)
if (process.env.MODE !== 'ssr' || process.env.PROD) {
  registerRoute(
    new NavigationRoute(createHandlerBoundToURL(process.env.PWA_FALLBACK_HTML), {
      denylist: [new RegExp(process.env.PWA_SERVICE_WORKER_REGEX), /workbox-(.)*\.js$/],
    }),
  );
}

import type { RouteRecordRaw } from 'vue-router';

const routes: RouteRecordRaw[] = [
  {
    path: '/',
    component: () => import('layouts/MainLayout.vue'),
    children: [
      { path: '', redirect: '/lists' },
      { path: 'lists', name: 'PersonalLists', component: () => import('pages/ListsPage.vue') },
      {
        path: 'lists/:id',
        name: 'ListDetail',
        component: () => import('pages/ListDetailPage.vue'),
        props: true,
      },
      { path: 'groups', name: 'GroupsList', component: () => import('pages/GroupsPage.vue') },
      {
        path: 'groups/:id',
        name: 'GroupDetail',
        component: () => import('pages/GroupDetailPage.vue'),
        props: true,
      },
      {
        path: 'groups/:groupId/lists',
        name: 'GroupLists',
        component: () => import('pages/ListsPage.vue'),
        props: true,
      },
      { path: 'account', name: 'Account', component: () => import('pages/AccountPage.vue') },
    ],
  },
  {
    path: '/',
    component: () => import('layouts/AuthLayout.vue'),
    children: [
      { path: 'login', component: () => import('pages/LoginPage.vue') },
      { path: 'signup', component: () => import('pages/SignupPage.vue') },
    ],
  },

  // Always leave this as last one,
  // but you can also remove it
  {
    path: '/:catchAll(.*)*',
    component: () => import('pages/ErrorNotFound.vue'),
  },
];

export default routes;

import { defineStore } from 'pinia';
import { ref, computed } from 'vue';
import { apiClient, API_ENDPOINTS } from 'src/config/api';

interface AuthState {
  accessToken: string | null;
  refreshToken: string | null;
  user: {
    email: string;
    name: string;
  } | null;
}

export const useAuthStore = defineStore('auth', () => {
  // State (tokens are mirrored in localStorage under 'token' / 'refresh_token')
  const accessToken = ref<string | null>(localStorage.getItem('token'));
  const refreshToken = ref<string | null>(localStorage.getItem('refresh_token'));
  const user = ref<AuthState['user']>(null);

  // Getters
  const isAuthenticated = computed(() => !!accessToken.value);
  const getUser = computed(() => user.value);

  // Actions
  const setTokens = (tokens: { access_token: string; refresh_token: string }) => {
    accessToken.value = tokens.access_token;
    refreshToken.value = tokens.refresh_token;
    localStorage.setItem('token', tokens.access_token);
    localStorage.setItem('refresh_token', tokens.refresh_token);
  };

  const clearTokens = () => {
    accessToken.value = null;
    refreshToken.value = null;
    user.value = null;
    localStorage.removeItem('token');
    localStorage.removeItem('refresh_token');
  };

  const setUser = (userData: AuthState['user']) => {
    user.value = userData;
  };

  const login = async (email: string, password: string) => {
    // The backend reads an OAuth2 password form: the email goes in 'username'.
    // URLSearchParams serializes as application/x-www-form-urlencoded, so the
    // body matches the declared Content-Type. FormData would be multipart.
    const formData = new URLSearchParams();
    formData.append('username', email);
    formData.append('password', password);

    const response = await apiClient.post(API_ENDPOINTS.AUTH.LOGIN, formData, {
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
    });

    const { access_token, refresh_token } = response.data;
    setTokens({ access_token, refresh_token });
    return response.data;
  };

  const signup = async (userData: { name: string; email: string; password: string }) => {
    const response = await apiClient.post(API_ENDPOINTS.AUTH.SIGNUP, userData);
    return response.data;
  };

  // Exchange the stored refresh token for a new token pair; clears all tokens on failure.
  // NOTE(review): this sends the token as a JSON body field 'refresh_token', while the
  // backend /refresh handler declares a plain `refresh_token_str: str` parameter
  // (a query parameter in FastAPI) — confirm the two sides agree.
  const refreshAccessToken = async () => {
    if (!refreshToken.value) {
      throw new Error('No refresh token available');
    }

    try {
      const response = await apiClient.post(API_ENDPOINTS.AUTH.REFRESH_TOKEN, {
        refresh_token: refreshToken.value,
      });

      const { access_token, refresh_token } = response.data;
      setTokens({ access_token, refresh_token });
      return response.data;
    } catch (error) {
      clearTokens();
      throw error;
    }
  };

  const logout = () => {
    clearTokens();
  };

  return {
    // State
    accessToken,
    refreshToken,
    user,

    // Getters
    isAuthenticated,
    getUser,

    // Actions
    setTokens,
    clearTokens,
    setUser,
    login,
    signup,
    refreshAccessToken,
    logout,
  };
});

import { defineStore } from 'pinia';
import { ref, computed } from 'vue';
import { useQuasar } from 'quasar';
import { LocalStorage } from 'quasar';

export interface OfflineAction {
  id: string;
  type: 'add' | 'complete' | 'update' | 'delete';
  itemId?: string;
  data: unknown;
  timestamp: number;
  version?: number;
}

export interface ConflictResolution {
  version: 'local' | 'server' | 'merge';
  action: OfflineAction;
}

export interface ConflictData {
  localVersion: {
    data: Record<string, unknown>;
    timestamp: number;
  };
  serverVersion: {
    data: Record<string, unknown>;
    timestamp: number;
  };
  action: OfflineAction;
}

export const useOfflineStore = defineStore('offline', () => {
  const $q = useQuasar();
  const isOnline = ref(navigator.onLine);
  const pendingActions = ref<OfflineAction[]>([]);
  const isProcessingQueue = ref(false);
  const showConflictDialog = ref(false);
  const currentConflict = ref<ConflictData | null>(null);

  // Restore the queue written by saveToStorage(). Storage is Quasar's
  // LocalStorage wrapper (not IndexedDB); the queue is kept as a JSON string.
  const init = () => {
    try {
      const stored = LocalStorage.getItem('offline-actions');
      if (stored) {
        pendingActions.value = JSON.parse(stored as string);
      }
    } catch (error) {
      console.error('Failed to load offline actions:', error);
    }
  };

  // Persist the current queue, as a JSON string, to Quasar LocalStorage.
  const saveToStorage = () => {
    try {
      LocalStorage.set('offline-actions', JSON.stringify(pendingActions.value));
    } catch (error) {
      console.error('Failed to save offline actions:', error);
    }
  };

  // Queue a new offline action, stamping it with a fresh id and the current time.
  const addAction = (action: Omit<OfflineAction, 'id' | 'timestamp'>) => {
    const newAction: OfflineAction = {
      ...action,
      id: crypto.randomUUID(),
      timestamp: Date.now(),
    };
    pendingActions.value.push(newAction);
    saveToStorage();
  };

  // Replay queued actions in order while online. A successful action is removed
  // from the queue and storage; a failed one stays queued for the next run.
  const processQueue = async () => {
    if (isProcessingQueue.value || !isOnline.value) return;

    isProcessingQueue.value = true;
    try {
      const actions = [...pendingActions.value];

      for (const action of actions) {
        try {
          await processAction(action);
          pendingActions.value = pendingActions.value.filter((a) => a.id !== action.id);
          saveToStorage();
        } catch (error) {
          if (error instanceof Error && error.message.includes('409')) {
            $q.notify({
              type: 'warning',
              message: 'Item was modified by someone else while you were offline. Please review.',
              actions: [
                {
                  label: 'Review',
                  color: 'white',
                  handler: () => {
                    // TODO: Implement conflict resolution UI
                  },
                },
              ],
            });
          } else {
            console.error('Failed to process offline action:', error);
          }
        }
      }
    } finally {
      // Always release the guard; otherwise one unexpected throw would make
      // every later processQueue() call return immediately.
      isProcessingQueue.value = false;
    }
  };

  // Process a single action.
  // TODO: Implement actual API calls
  const processAction = async (action: OfflineAction) => {
    switch (action.type) {
      case 'add':
        // await api.addItem(action.data);
        break;
      case 'complete':
        // await api.completeItem(action.itemId, action.data);
        break;
      case 'update':
        // await api.updateItem(action.itemId, action.data);
        break;
      case 'delete':
        // await api.deleteItem(action.itemId);
        break;
    }
  };

  // Listen for online/offline status changes; going online replays the queue.
  const setupNetworkListeners = () => {
    window.addEventListener('online', () => {
      (async () => {
        isOnline.value = true;
        await processQueue();
      })().catch((error) => {
        console.error('Error processing queue:', error);
      });
    });

    window.addEventListener('offline', () => {
      isOnline.value = false;
    });
  };

  // Computed properties
  const hasPendingActions = computed(() => pendingActions.value.length > 0);
  const pendingActionCount = computed(() => pendingActions.value.length);

  // Initialize
  init();
  setupNetworkListeners();

  const handleConflictResolution = (resolution: ConflictResolution) => {
    // Implement the logic to handle the conflict resolution
    console.log('Conflict resolution:', resolution);
  };

  return {
    isOnline,
    pendingActions,
    hasPendingActions,
    pendingActionCount,
    showConflictDialog,
    currentConflict,
    addAction,
    processQueue,
    handleConflictResolution,
  };
});

{
  "extends": "./.quasar/tsconfig.json"
}

# app/api/v1/endpoints/auth.py
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token
from app.crud import user as crud_user
from app.core.security import ( verify_password, create_access_token, create_refresh_token, verify_refresh_token ) from app.core.exceptions import ( EmailAlreadyRegisteredError, InvalidCredentialsError, UserCreationError ) from app.config import settings logger = logging.getLogger(__name__) router = APIRouter() @router.post( "/signup", response_model=UserPublic, status_code=201, summary="Register New User", description="Creates a new user account.", tags=["Authentication"] ) async def signup( user_in: UserCreate, db: AsyncSession = Depends(get_db) ): """ Handles user registration. - Validates input data. - Checks if email already exists. - Hashes the password. - Stores the new user in the database. """ logger.info(f"Signup attempt for email: {user_in.email}") existing_user = await crud_user.get_user_by_email(db, email=user_in.email) if existing_user: logger.warning(f"Signup failed: Email already registered - {user_in.email}") raise EmailAlreadyRegisteredError() try: created_user = await crud_user.create_user(db=db, user_in=user_in) logger.info(f"User created successfully: {created_user.email} (ID: {created_user.id})") return created_user except Exception as e: logger.error(f"Error during user creation for {user_in.email}: {e}", exc_info=True) raise UserCreationError() @router.post( "/login", response_model=Token, summary="User Login", description="Authenticates a user and returns an access and refresh token.", tags=["Authentication"] ) async def login( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: AsyncSession = Depends(get_db) ): """ Handles user login. - Finds user by email (provided in 'username' field of form). - Verifies the provided password against the stored hash. - Generates and returns JWT access and refresh tokens upon successful authentication. 
""" logger.info(f"Login attempt for user: {form_data.username}") user = await crud_user.get_user_by_email(db, email=form_data.username) if not user or not verify_password(form_data.password, user.password_hash): logger.warning(f"Login failed: Invalid credentials for user {form_data.username}") raise InvalidCredentialsError() access_token = create_access_token(subject=user.email) refresh_token = create_refresh_token(subject=user.email) logger.info(f"Login successful, tokens generated for user: {user.email}") return Token( access_token=access_token, refresh_token=refresh_token, token_type=settings.TOKEN_TYPE ) @router.post( "/refresh", response_model=Token, summary="Refresh Access Token", description="Refreshes an access token using a refresh token.", tags=["Authentication"] ) async def refresh_token( refresh_token_str: str, db: AsyncSession = Depends(get_db) ): """ Handles access token refresh. - Verifies the provided refresh token. - If valid, generates and returns a new JWT access token and the same refresh token. 
""" logger.info("Access token refresh attempt") payload = verify_refresh_token(refresh_token_str) if not payload: logger.warning("Refresh token invalid or expired") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired refresh token", headers={"WWW-Authenticate": "Bearer"}, ) user_email = payload.get("sub") if not user_email: logger.error("User email not found in refresh token payload") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token payload", headers={"WWW-Authenticate": "Bearer"}, ) new_access_token = create_access_token(subject=user_email) logger.info(f"Access token refreshed for user: {user_email}") return Token( access_token=new_access_token, refresh_token=refresh_token_str, token_type=settings.TOKEN_TYPE ) # app/api/v1/endpoints/lists.py import logging from typing import List as PyList, Optional # Alias for Python List type hint from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.api.dependencies import get_current_user from app.models import User as UserModel from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail from app.schemas.message import Message # For simple responses from app.crud import list as crud_list from app.crud import group as crud_group # Need for group membership check from app.schemas.list import ListStatus from app.core.exceptions import ( GroupMembershipError, ListNotFoundError, ListPermissionError, ListStatusNotFoundError, ConflictError # Added ConflictError ) logger = logging.getLogger(__name__) router = APIRouter() @router.post( "", # Route relative to prefix "/lists" response_model=ListPublic, # Return basic list info on creation status_code=status.HTTP_201_CREATED, summary="Create New List", tags=["Lists"] ) async def create_list( list_in: ListCreate, db: AsyncSession = Depends(get_db), current_user: 
UserModel = Depends(get_current_user), ): """ Creates a new shopping list. - If `group_id` is provided, the user must be a member of that group. - If `group_id` is null, it's a personal list. """ logger.info(f"User {current_user.email} creating list: {list_in.name}") group_id = list_in.group_id # Permission Check: If sharing with a group, verify membership if group_id: is_member = await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id) if not is_member: logger.warning(f"User {current_user.email} attempted to create list in group {group_id} but is not a member.") raise GroupMembershipError(group_id, "create lists") created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id) logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.") return created_list @router.get( "", # Route relative to prefix "/lists" response_model=PyList[ListPublic], # Return a list of basic list info summary="List Accessible Lists", tags=["Lists"] ) async def read_lists( db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), # Add pagination parameters later if needed: skip: int = 0, limit: int = 100 ): """ Retrieves lists accessible to the current user: - Personal lists created by the user. - Lists belonging to groups the user is a member of. """ logger.info(f"Fetching lists accessible to user: {current_user.email}") lists = await crud_list.get_lists_for_user(db=db, user_id=current_user.id) return lists @router.get( "/{list_id}", response_model=ListDetail, # Return detailed list info including items summary="Get List Details", tags=["Lists"] ) async def read_list( list_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """ Retrieves details for a specific list, including its items, if the user has permission (creator or group member). 
""" logger.info(f"User {current_user.email} requesting details for list ID: {list_id}") # The check_list_permission function will raise appropriate exceptions list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) return list_db @router.put( "/{list_id}", response_model=ListPublic, # Return updated basic info summary="Update List", tags=["Lists"], responses={ # Add 409 to responses status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"} } ) async def update_list( list_id: int, list_in: ListUpdate, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """ Updates a list's details (name, description, is_complete). Requires user to be the creator or a member of the list's group. The client MUST provide the current `version` of the list in the `list_in` payload. If the version does not match, a 409 Conflict is returned. """ logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}") list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) try: updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.") return updated_list except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) except Exception as e: # Catch other potential errors from crud operation logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}") # Consider a more generic error, but for now, let's keep it specific if possible # Re-raising might be better if crud layer already raises appropriate HTTPExceptions raise 
HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.") @router.delete( "/{list_id}", status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body summary="Delete List", tags=["Lists"], responses={ # Add 409 to responses status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"} } ) async def delete_list( list_id: int, expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."), db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """ Deletes a list. Requires user to be the creator of the list. If `expected_version` is provided and does not match the list's current version, a 409 Conflict is returned. """ logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}") # Use the helper, requiring creator permission list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True) if expected_version is not None and list_db.version != expected_version: logger.warning( f"Conflict deleting list {list_id} for user {current_user.email}. " f"Expected version {expected_version}, actual version {list_db.version}." ) raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh." 
) await crud_list.delete_list(db=db, list_db=list_db) logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.") return Response(status_code=status.HTTP_204_NO_CONTENT) @router.get( "/{list_id}/status", response_model=ListStatus, summary="Get List Status", tags=["Lists"] ) async def read_list_status( list_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """ Retrieves the last update time for the list and its items, plus item count. Used for polling to check if a full refresh is needed. Requires user to have permission to view the list. """ # Verify user has access to the list first list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) if not list_db: # Check if list exists at all for correct error code exists = await crud_list.get_list_by_id(db, list_id) if not exists: raise ListNotFoundError(list_id) raise ListPermissionError(list_id, "access this list's status") # Fetch the status details list_status = await crud_list.get_list_status(db=db, list_id=list_id) if not list_status: # Should not happen if check_list_permission passed, but handle defensively logger.error(f"Could not retrieve status for list {list_id} even though permission check passed.") raise ListStatusNotFoundError(list_id) return list_status import logging from typing import List from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status from google.api_core import exceptions as google_exceptions from app.api.dependencies import get_current_user from app.models import User as UserModel from app.schemas.ocr import OcrExtractResponse from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService from app.core.exceptions import ( OCRServiceUnavailableError, OCRServiceConfigError, OCRUnexpectedError, OCRQuotaExceededError, InvalidFileTypeError, FileTooLargeError, OCRProcessingError ) from 
app.config import settings logger = logging.getLogger(__name__) router = APIRouter() ocr_service = GeminiOCRService() @router.post( "/extract-items", response_model=OcrExtractResponse, summary="Extract List Items via OCR (Gemini)", tags=["OCR"] ) async def ocr_extract_items( current_user: UserModel = Depends(get_current_user), image_file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP) of the shopping list or receipt."), ): """ Accepts an image upload, sends it to Gemini Flash with a prompt to extract shopping list items, and returns the parsed items. """ # Check if Gemini client initialized correctly if gemini_initialization_error: logger.error("OCR endpoint called but Gemini client failed to initialize.") raise OCRServiceUnavailableError(gemini_initialization_error) logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.") # --- File Validation --- if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES: logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}") raise InvalidFileTypeError() # Simple size check contents = await image_file.read() if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024: logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes") raise FileTooLargeError() try: # Call the Gemini helper function extracted_items = await extract_items_from_image_gemini( image_bytes=contents, mime_type=image_file.content_type ) logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") return OcrExtractResponse(extracted_items=extracted_items) except OCRServiceUnavailableError: raise OCRServiceUnavailableError() except OCRServiceConfigError: raise OCRServiceConfigError() except OCRQuotaExceededError: raise OCRQuotaExceededError() except Exception as e: raise OCRProcessingError(str(e)) finally: # Ensure file handle is closed await image_file.close() # app/core/gemini.py import 
logging
from typing import List

import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold  # For safety settings
from google.api_core import exceptions as google_exceptions

from app.config import settings
from app.core.exceptions import (
    OCRServiceUnavailableError,
    OCRServiceConfigError,
    OCRUnexpectedError,
    OCRQuotaExceededError
)

logger = logging.getLogger(__name__)


def _build_safety_settings():
    """Map settings.GEMINI_SAFETY_SETTINGS (attribute names) to SDK enum values.

    Keys are resolved as `HarmCategory` attribute names and values as
    `HarmBlockThreshold` attribute names. Shared by the module-level client and
    GeminiOCRService so both are configured identically.
    """
    return {
        getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
        for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
    }


def _build_generation_config():
    """Build a GenerationConfig from the settings.GEMINI_GENERATION_CONFIG mapping."""
    return genai.types.GenerationConfig(**settings.GEMINI_GENERATION_CONFIG)


# --- Global variable to hold the initialized model client ---
gemini_flash_client = None
gemini_initialization_error = None  # Store potential init error

# --- Configure and Initialize ---
try:
    if settings.GEMINI_API_KEY:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        # Initialize the specific model we want to use
        gemini_flash_client = genai.GenerativeModel(
            model_name=settings.GEMINI_MODEL_NAME,
            safety_settings=_build_safety_settings(),
            generation_config=_build_generation_config(),
        )
        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
    else:
        # Store error if API key is missing
        gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
        logger.error(gemini_initialization_error)
except Exception as e:
    # Catch any other unexpected errors during initialization
    gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}"
    logger.exception(gemini_initialization_error)  # Log full traceback
    gemini_flash_client = None  # Ensure client is None on error


# --- Function to get the client (optional, allows checking error) ---
def get_gemini_client():
    """
    Returns the initialized Gemini client instance.

    Raises:
        RuntimeError: if module-level initialization failed or left no client.
    """
    if gemini_initialization_error:
        raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}")
    if gemini_flash_client is None:
        # This case should ideally be covered by the check above, but as a safeguard:
        raise RuntimeError("Gemini client is not available (unknown initialization issue).")
    return gemini_flash_client


# Define the prompt as a constant (also used as a fallback when settings does
# not define OCR_ITEM_EXTRACTION_PROMPT).
OCR_ITEM_EXTRACTION_PROMPT = """ Extract the shopping list items from this image. List each distinct item on a new line. Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. Focus only on the names of the products or items to be purchased. If the image does not appear to be a shopping list or receipt, state that clearly. Example output for a grocery list: Milk Eggs Bread Apples Organic Bananas """


def _extraction_prompt() -> str:
    """Return the configured extraction prompt, falling back to the module constant."""
    return getattr(settings, "OCR_ITEM_EXTRACTION_PROMPT", OCR_ITEM_EXTRACTION_PROMPT)


def _ensure_usable_response(response) -> None:
    """Raise ValueError when a Gemini response has no candidate content.

    Distinguishes a safety block from other empty/incomplete responses.
    """
    if response.candidates and response.candidates[0].content.parts:
        return

    logger.warning("Gemini response blocked or empty.", extra={"response": response})
    candidate = response.candidates[0] if response.candidates else None
    finish_reason = candidate.finish_reason if candidate is not None else 'UNKNOWN'
    safety_ratings = candidate.safety_ratings if candidate is not None else 'N/A'
    # finish_reason comes back as an SDK enum, not a str; compare by member name
    # (plain strings, like the 'UNKNOWN' fallback, are compared as-is).
    if getattr(finish_reason, "name", finish_reason) == 'SAFETY':
        raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
    raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")


def _parse_item_lines(raw_text: str) -> List[str]:
    """Split model output into stripped lines, dropping empty and 1-character lines."""
    # Add more sophisticated filtering if needed (e.g., regex, keyword check)
    return [line.strip() for line in raw_text.splitlines() if len(line.strip()) > 1]


async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
    """
    Uses Gemini Flash to extract shopping list items from image bytes.

    Args:
        image_bytes: The image content as bytes.
        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

    Returns:
        A list of extracted item strings.

    Raises:
        RuntimeError: If the Gemini client is not initialized.
        google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
        ValueError: If the response is blocked or contains no usable text.
    """
    client = get_gemini_client()  # Raises RuntimeError if not initialized

    # Text prompt first, then the image part (multimodal input)
    prompt_parts = [
        _extraction_prompt(),
        {"mime_type": mime_type, "data": image_bytes},
    ]

    logger.info("Sending image to Gemini for item extraction...")
    try:
        response = await client.generate_content_async(prompt_parts)
        _ensure_usable_response(response)
        # response.text is a shortcut for response.candidates[0].content.parts[0].text
        raw_text = response.text
    except google_exceptions.GoogleAPIError as e:
        logger.error(f"Gemini API Error: {e}", exc_info=True)
        # Re-raise specific Google API errors for endpoint to handle (e.g., quota)
        raise
    except ValueError:
        # Our own "blocked/empty" errors (or the SDK's own ValueError from
        # response.text) already describe the problem; don't double-wrap them.
        raise
    except Exception as e:
        logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
        raise ValueError(f"Failed to process image with Gemini: {e}") from e

    logger.info("Received raw text from Gemini.")
    items = _parse_item_lines(raw_text)
    logger.info(f"Extracted {len(items)} potential items.")
    return items


class GeminiOCRService:
    """Object-style wrapper around a Gemini model for OCR item extraction."""

    def __init__(self):
        """Configure the SDK and build the model.

        Raises:
            OCRServiceConfigError: if configuration or model construction fails.
        """
        try:
            genai.configure(api_key=settings.GEMINI_API_KEY)
            # Safety/generation config are constructor arguments of
            # GenerativeModel; the settings hold attribute *names*, so they are
            # mapped to SDK values the same way as the module-level client.
            self.model = genai.GenerativeModel(
                settings.GEMINI_MODEL_NAME,
                safety_settings=_build_safety_settings(),
                generation_config=_build_generation_config(),
            )
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}")
            raise OCRServiceConfigError() from e

    async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
        """
        Extract shopping list items from an image using Gemini Vision.

        Args:
            image_data: Raw image bytes.
            mime_type: MIME type of the image (defaults to JPEG, as before).

        Returns:
            Non-empty, stripped response lines, excluding lines that start with "Example".

        Raises:
            OCRUnexpectedError: the model returned no text.
            OCRQuotaExceededError: the API reported an exhausted quota.
            OCRServiceUnavailableError: any other failure.
        """
        try:
            response = await self.model.generate_content_async(
                contents=[_extraction_prompt(), {"mime_type": mime_type, "data": image_data}]
            )
            if not response.text:
                raise OCRUnexpectedError()

            # Split response into lines and clean up
            return [
                item.strip()
                for item in response.text.split("\n")
                if item.strip() and not item.strip().startswith("Example")
            ]
        except OCRUnexpectedError:
            # Raised deliberately above; don't relabel it as "unavailable".
            raise
        except google_exceptions.ResourceExhausted as e:
            logger.error(f"Error during OCR extraction: {e}")
            raise OCRQuotaExceededError() from e
        except Exception as e:
            logger.error(f"Error during OCR extraction: {e}")
            if "quota" in str(e).lower():
                raise OCRQuotaExceededError() from e
            raise OCRServiceUnavailableError() from e

# be/Dockerfile # Choose a suitable Python base image FROM python:3.11-slim # Set environment variables ENV PYTHONDONTWRITEBYTECODE 1 # Prevent python from writing pyc files ENV PYTHONUNBUFFERED 1 # Keep stdout/stderr unbuffered # Set the working directory in the container WORKDIR /app # Install system dependencies if needed (e.g., for psycopg2 build) # RUN apt-get update && apt-get install -y --no-install-recommends gcc build-essential libpq-dev && rm -rf /var/lib/apt/lists/* # Install Python dependencies # Upgrade pip first RUN pip install --no-cache-dir --upgrade pip # Copy only requirements first to leverage Docker cache COPY requirements.txt requirements.txt # Install dependencies RUN pip install --no-cache-dir -r requirements.txt # Copy the rest of the application code into the working directory COPY . . # This includes your 'app/' directory, alembic.ini, etc. # Expose the port the app runs on EXPOSE 8000 # Command to run the application using uvicorn # The default command for production (can be overridden in docker-compose for development) # Note: Make sure 'app.main:app' correctly points to your FastAPI app instance # relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] fastapi>=0.95.0 uvicorn[standard]>=0.20.0 sqlalchemy[asyncio]>=2.0.0 # Core ORM + Async support asyncpg>=0.27.0 # Async PostgreSQL driver psycopg2-binary>=2.9.0 # Often needed by Alembic even if app uses asyncpg alembic>=1.9.0 # Database migrations pydantic-settings>=2.0.0 # For loading settings from .env python-dotenv>=1.0.0 # To load .env file for scripts/alembic passlib[bcrypt]>=1.7.4 python-jose[cryptography]>=3.3.0 pydantic[email] google-generativeai>=0.5.0 { "name": "mitlist", "version": "0.0.1", "description": "mitlist pwa", "productName": "mitlist", "author": "Mohamad ", "type": "module", "private": true, "scripts": { "lint": "eslint -c ./eslint.config.js \"./src*/**/*.{ts,js,cjs,mjs,vue}\"", "format": "prettier --write \"**/*.{js,ts,vue,scss,html,md,json}\" --ignore-path .gitignore", "test": "echo \"No test specified\" && exit 0", "dev": "quasar dev", "build": "quasar build", "postinstall": "quasar prepare" }, "dependencies": { "@quasar/extras": "^1.16.4", "axios": "^1.2.1", "pinia": "^3.0.1", "quasar": "^2.16.0", "register-service-worker": "^1.7.2", "vue": "^3.4.18", "vue-i18n": "^11.0.0", "vue-router": "^4.0.12" }, "devDependencies": { "@eslint/js": "^9.14.0", "@intlify/unplugin-vue-i18n": "^4.0.0", "@quasar/app-vite": "^2.1.0", "@types/node": "^20.5.9", "@vue/eslint-config-prettier": "^10.1.0", "@vue/eslint-config-typescript": "^14.4.0", "autoprefixer": "^10.4.2", "eslint": "^9.14.0", "eslint-plugin-vue": "^9.30.0", "globals": "^15.12.0", "prettier": "^3.3.3", "typescript": "~5.5.3", "vite-plugin-checker": "^0.9.0", "vue-tsc": "^2.0.29", "workbox-build": "^7.3.0", "workbox-cacheable-response": "^7.3.0", "workbox-core": "^7.3.0", "workbox-expiration": "^7.3.0", "workbox-precaching": "^7.3.0", "workbox-routing": "^7.3.0", "workbox-strategies": "^7.3.0" }, "engines": { "node": "^28 || ^26 || ^24 || ^22 || ^20 || ^18", "npm": ">= 6.13.4", "yarn": ">= 1.21.1" } } import { boot } from 
'quasar/wrappers';
import axios from 'axios';
import { API_BASE_URL } from 'src/config/api-config';

// Token-refresh endpoint, relative to API_BASE_URL.
// NOTE(review): the backend auth router in this repo declares POST "/refresh"
// (token as query param `refresh_token_str`) — confirm this path and payload
// are actually served.
const REFRESH_TOKEN_PATH = '/api/v1/auth/refresh-token';

// Create axios instance
const api = axios.create({
  baseURL: API_BASE_URL,
  headers: {
    'Content-Type': 'application/json',
  },
});

/**
 * Normalise a rejection reason to an Error without discarding it.
 * Values that are already Errors (including AxiosError) pass through
 * unchanged, so callers can still read `error.response`, `error.config`, etc.
 */
const toError = (reason: unknown): Error =>
  reason instanceof Error ? reason : new Error(String(reason));

// Request interceptor: attach the stored access token, if any.
api.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem('token');
    if (token) {
      config.headers.Authorization = `Bearer ${token}`;
    }
    return config;
  },
  (error) => Promise.reject(toError(error)),
);

// Response interceptor: on the first 401 for a request, try one token refresh
// and replay the request; on refresh failure, clear tokens and go to /login.
api.interceptors.response.use(
  (response) => response,
  async (error) => {
    // `config` is absent for some failures (e.g. errors thrown before a request
    // was built), so guard before touching `_retry`.
    const originalRequest = error?.config;

    if (error?.response?.status === 401 && originalRequest && !originalRequest._retry) {
      originalRequest._retry = true;

      try {
        const refreshToken = localStorage.getItem('refreshToken');
        if (!refreshToken) {
          throw new Error('No refresh token available');
        }

        // Use the default axios instance (no interceptors are registered on it
        // in this file), not `api`: through `api`, a 401 from the refresh call
        // itself would re-enter this handler and recurse without bound.
        const response = await axios.post(
          REFRESH_TOKEN_PATH,
          { refresh_token: refreshToken },
          { baseURL: API_BASE_URL, headers: { 'Content-Type': 'application/json' } },
        );

        const { access_token } = response.data;
        localStorage.setItem('token', access_token);

        // Retry the original request with new token
        originalRequest.headers.Authorization = `Bearer ${access_token}`;
        return api(originalRequest);
      } catch (refreshError) {
        // If refresh token fails, clear storage and redirect to login
        localStorage.removeItem('token');
        localStorage.removeItem('refreshToken');
        window.location.href = '/login';
        return Promise.reject(toError(refreshError));
      }
    }

    return Promise.reject(toError(error));
  },
);

export default boot(({ app }) => {
  app.config.globalProperties.$axios = axios;
  app.config.globalProperties.$api = api;
});

export { api };
# app/api/v1/endpoints/costs.py import logging from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, selectinload from decimal import Decimal, ROUND_HALF_UP
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import (
    User as UserModel,
    Group as GroupModel,
    List as ListModel,
    Expense as ExpenseModel,
    Item as ItemModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    ExpenseSplit as ExpenseSplitModel,
    Settlement as SettlementModel
)
# UserCostShare and UserBalanceDetail are constructed below but were never
# imported. NOTE(review): assumes both are defined in app.schemas.cost next to
# ListCostSummary / GroupBalanceSummary — confirm.
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError

logger = logging.getLogger(__name__)
router = APIRouter()

_CENTS = Decimal("0.01")


def _money(value: Decimal) -> Decimal:
    """Round a Decimal amount to cents using ROUND_HALF_UP."""
    return value.quantize(_CENTS, rounding=ROUND_HALF_UP)


def _empty_list_summary(db_list) -> ListCostSummary:
    """Build the all-zero summary returned when a list has nothing to split."""
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=Decimal("0.00"),
        num_participating_users=0,
        equal_share_per_user=Decimal("0.00"),
        user_balances=[]
    )


@router.get(
    "/lists/{list_id}/cost-summary",
    response_model=ListCostSummary,
    summary="Get Cost Summary for a List",
    tags=["Costs"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
        status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
    }
)
async def get_list_cost_summary(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Retrieves a calculated cost summary for a specific list, detailing total costs,
    equal shares per user, and individual user balances based on their contributions.

    The user must have access to the list to view its cost summary.
    Costs are split among group members if the list belongs to a group, or just for
    the creator if it's a personal list. All users who added items with prices are
    included in the calculation.
    """
    logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")

    # 1. Verify user has access to the target list
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
        raise
    except ListNotFoundError as e:
        logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
        raise

    # 2. Get the list with its items and users
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
            selectinload(ListModel.creator)
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    # 3. Get or create an expense for this list.
    # split.user is eager-loaded as well because step 4 reads it; implicit lazy
    # loads are not supported under an AsyncSession.
    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    db_expense = expense_result.scalars().first()

    if not db_expense:
        # Start from Decimal so an empty sum is still a Decimal, not int 0.
        total_amount = sum(
            (item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")),
            Decimal("0"),
        )
        if total_amount == Decimal("0"):
            return _empty_list_summary(db_list)

        # Create expense with ITEM_BASED split type
        expense_in = ExpenseCreate(
            description=f"Cost summary for list {db_list.name}",
            total_amount=total_amount,
            list_id=list_id,
            split_type=SplitTypeEnum.ITEM_BASED,
            paid_by_user_id=current_user.id  # Use current user as payer for now
        )
        # NOTE(review): this GET persists an Expense as a side effect, and step 4
        # reads db_expense.splits (and split.user) — confirm create_expense
        # returns them already loaded, since lazy loading fails on AsyncSession.
        db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)

    # 4. Calculate cost summary from expense splits
    participating_users = set()
    user_items_added_value = {}
    total_list_cost = Decimal("0.00")

    # Get all users who added items
    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0") and item.added_by_user:
            participating_users.add(item.added_by_user)
            user_items_added_value[item.added_by_user.id] = (
                user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
            )
            total_list_cost += item.price

    # Get all users from expense splits
    for split in db_expense.splits:
        if split.user:
            participating_users.add(split.user)

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        return _empty_list_summary(db_list)

    equal_share_per_user = _money(total_list_cost / Decimal(num_participating_users))
    remainder = total_list_cost - (equal_share_per_user * num_participating_users)

    user_balances = []
    # Visit users in a stable order (by id) so the rounding remainder always goes
    # to the same person; iterating the raw set made that choice arbitrary.
    for index, user in enumerate(sorted(participating_users, key=lambda u: u.id)):
        items_added = user_items_added_value.get(user.id, Decimal("0.00"))
        current_user_share = equal_share_per_user
        if index == 0:
            current_user_share += remainder

        balance = items_added - current_user_share
        user_identifier = user.name if user.name else user.email
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user_identifier,
                items_added_value=_money(items_added),
                amount_due=_money(current_user_share),
                balance=_money(balance)
            )
        )

    user_balances.sort(key=lambda x: x.user_identifier)
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=_money(total_list_cost),
        num_participating_users=num_participating_users,
        equal_share_per_user=_money(equal_share_per_user),
        user_balances=user_balances
    )


@router.get(
    "/groups/{group_id}/balance-summary",
    response_model=GroupBalanceSummary,
    summary="Get Detailed Balance Summary for a Group",
    tags=["Costs", "Groups"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
        status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
    }
)
async def get_group_balance_summary(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Retrieves a detailed financial balance summary for all users within a specific group.
    It considers all expenses, their splits, and all settlements recorded for the group.
    The user must be a member of the group to view its balance summary.
    """
    logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")

    # 1. Verify user is a member of the target group.
    # assoc.user is eager-loaded too: step 3 reads it, and lazy loading is not
    # available on an AsyncSession.
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user))
        .where(GroupModel.id == group_id)
    )
    db_group_for_check = group_check.scalars().first()

    if not db_group_for_check:
        raise GroupNotFoundError(group_id)

    user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
    if not user_is_member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")

    # 2. Get all expenses and settlements for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()

    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(
            selectinload(SettlementModel.paid_by_user),
            selectinload(SettlementModel.paid_to_user)
        )
    )
    settlements = settlements_result.scalars().all()

    # 3. Calculate user balances (one entry per current group member)
    user_balances_data = {}
    for assoc in db_group_for_check.member_associations:
        if assoc.user:
            user_balances_data[assoc.user.id] = UserBalanceDetail(
                user_id=assoc.user.id,
                user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
            )

    # Process expenses
    for expense in expenses:
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount

    # Process settlements
    for settlement in settlements:
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount

    # Calculate net balances: (paid + received) - (owed + settled out)
    final_user_balances = []
    for data in user_balances_data.values():
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_received
        ) - (data.total_share_of_expenses + data.total_settlements_paid)

        data.total_paid_for_expenses = _money(data.total_paid_for_expenses)
        data.total_share_of_expenses = _money(data.total_share_of_expenses)
        data.total_settlements_paid = _money(data.total_settlements_paid)
        data.total_settlements_received = _money(data.total_settlements_received)
        data.net_balance = _money(data.net_balance)

        final_user_balances.append(data)

    # Sort by user identifier
    final_user_balances.sort(key=lambda x: x.user_identifier)

    # Calculate suggested settlements.
    # NOTE(review): calculate_suggested_settlements is neither defined nor
    # imported in this module — confirm where it lives and import it.
    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    return GroupBalanceSummary(
        group_id=db_group_for_check.id,
        group_name=db_group_for_check.name,
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )

# app/api/v1/endpoints/items.py import logging from typing import List as PyList, Optional from fastapi import APIRouter, Depends, HTTPException, status, Response, Query from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db from app.api.dependencies import get_current_user # --- Import Models Correctly --- from app.models import User as UserModel from app.models import Item as ItemModel # <-- IMPORT Item and alias it # --- End Import Models --- from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic from app.crud import item as crud_item from app.crud import list as crud_list from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError logger = logging.getLogger(__name__) router = APIRouter() # --- Helper Dependency for Item Permissions --- # Now ItemModel is defined before being used as a type hint async def get_item_and_verify_access( item_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user) ) -> ItemModel: """Dependency to get an item and verify the user has access to its list.""" item_db = await crud_item.get_item_by_id(db, item_id=item_id) if not item_db: raise ItemNotFoundError(item_id) # Check permission on the parent list try: await crud_list.check_list_permission(db=db,
list_id=item_db.list_id, user_id=current_user.id) except ListPermissionError as e: # Re-raise with a more specific message raise ListPermissionError(item_db.list_id, "access this item's list") return item_db # --- Endpoints --- @router.post( "/lists/{list_id}/items", # Nested under lists response_model=ItemPublic, status_code=status.HTTP_201_CREATED, summary="Add Item to List", tags=["Items"] ) async def create_list_item( list_id: int, item_in: ItemCreate, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), ): """Adds a new item to a specific list. User must have access to the list.""" user_email = current_user.email # Access email attribute before async operations logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}") # Verify user has access to the target list try: await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) except ListPermissionError as e: # Re-raise with a more specific message raise ListPermissionError(list_id, "add items to this list") created_item = await crud_item.create_item( db=db, item_in=item_in, list_id=list_id, user_id=current_user.id ) logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {user_email}.") return created_item @router.get( "/lists/{list_id}/items", # Nested under lists response_model=PyList[ItemPublic], summary="List Items in List", tags=["Items"] ) async def read_list_items( list_id: int, db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), # Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc' ): """Retrieves all items for a specific list if the user has access.""" user_email = current_user.email # Access email attribute before async operations logger.info(f"User {user_email} listing items for list {list_id}") # Verify user has access to the list try: await crud_list.check_list_permission(db=db, list_id=list_id, 
user_id=current_user.id) except ListPermissionError as e: # Re-raise with a more specific message raise ListPermissionError(list_id, "view items in this list") items = await crud_item.get_items_by_list_id(db=db, list_id=list_id) return items @router.put( "/items/{item_id}", # Operate directly on item ID response_model=ItemPublic, summary="Update Item", tags=["Items"], responses={ status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"} } ) async def update_item( item_id: int, # Item ID from path item_in: ItemUpdate, item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), # Need user ID for completed_by ): """ Updates an item's details (name, quantity, is_complete, price). User must have access to the list the item belongs to. The client MUST provide the current `version` of the item in the `item_in` payload. If the version does not match, a 409 Conflict is returned. Sets/unsets `completed_by_id` based on `is_complete` flag. 
""" user_email = current_user.email # Access email attribute before async operations logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}") # Permission check is handled by get_item_and_verify_access dependency try: updated_item = await crud_item.update_item( db=db, item_db=item_db, item_in=item_in, user_id=current_user.id ) logger.info(f"Item {item_id} updated successfully by user {user_email} to version {updated_item.version}.") return updated_item except ConflictError as e: logger.warning(f"Conflict updating item {item_id} for user {user_email}: {str(e)}") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) except Exception as e: logger.error(f"Error updating item {item_id} for user {user_email}: {str(e)}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.") @router.delete( "/items/{item_id}", # Operate directly on item ID status_code=status.HTTP_204_NO_CONTENT, summary="Delete Item", tags=["Items"], responses={ status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"} } ) async def delete_item( item_id: int, # Item ID from path expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."), item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access db: AsyncSession = Depends(get_db), current_user: UserModel = Depends(get_current_user), # Log who deleted it ): """ Deletes an item. User must have access to the list the item belongs to. If `expected_version` is provided and does not match the item's current version, a 409 Conflict is returned. 
""" user_email = current_user.email # Access email attribute before async operations logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}") # Permission check is handled by get_item_and_verify_access dependency if expected_version is not None and item_db.version != expected_version: logger.warning( f"Conflict deleting item {item_id} for user {user_email}. " f"Expected version {expected_version}, actual version {item_db.version}." ) raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh." ) await crud_item.delete_item(db=db, item_db=item_db) logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {user_email}.") return Response(status_code=status.HTTP_204_NO_CONTENT) # app/crud/group.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload # For eager loading members from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List from sqlalchemy import func from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel from app.schemas.group import GroupCreate from app.models import UserRoleEnum # Import enum from app.core.exceptions import ( GroupOperationError, GroupNotFoundError, DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, GroupMembershipError, GroupPermissionError # Import GroupPermissionError ) # --- Group CRUD --- async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel: """Creates a group and adds the creator as the owner.""" try: async with db.begin(): db_group = GroupModel(name=group_in.name, created_by_id=creator_id) db.add(db_group) await db.flush() db_user_group = UserGroupModel( user_id=creator_id, 
group_id=db_group.id, role=UserRoleEnum.owner ) db.add(db_user_group) await db.flush() await db.refresh(db_group) return db_group except IntegrityError as e: raise DatabaseIntegrityError(f"Failed to create group: {str(e)}") except OperationalError as e: raise DatabaseConnectionError(f"Database connection error: {str(e)}") except SQLAlchemyError as e: raise DatabaseTransactionError(f"Failed to create group: {str(e)}") async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: """Gets all groups a user is a member of.""" try: result = await db.execute( select(GroupModel) .join(UserGroupModel) .where(UserGroupModel.user_id == user_id) .options(selectinload(GroupModel.member_associations)) ) return result.scalars().all() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query user groups: {str(e)}") async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]: """Gets a single group by its ID, optionally loading members.""" try: result = await db.execute( select(GroupModel) .where(GroupModel.id == group_id) .options( selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user) ) ) return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query group: {str(e)}") async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool: """Checks if a user is a member of a specific group.""" try: result = await db.execute( select(UserGroupModel.id) .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) .limit(1) ) return result.scalar_one_or_none() is not None except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed 
to check group membership: {str(e)}") async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]: """Gets the role of a user in a specific group.""" try: result = await db.execute( select(UserGroupModel.role) .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) ) return result.scalar_one_or_none() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query user role: {str(e)}") async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]: """Adds a user to a group if they aren't already a member.""" try: async with db.begin(): existing = await db.execute( select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) ) if existing.scalar_one_or_none(): return None db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role) db.add(db_user_group) await db.flush() await db.refresh(db_user_group) return db_user_group except IntegrityError as e: raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") except OperationalError as e: raise DatabaseConnectionError(f"Database connection error: {str(e)}") except SQLAlchemyError as e: raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}") async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool: """Removes a user from a group.""" try: async with db.begin(): result = await db.execute( delete(UserGroupModel) .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) .returning(UserGroupModel.id) ) return result.scalar_one_or_none() is not None except OperationalError as e: raise DatabaseConnectionError(f"Database connection error: {str(e)}") except SQLAlchemyError as e: raise DatabaseTransactionError(f"Failed to remove user 
from group: {str(e)}") async def get_group_member_count(db: AsyncSession, group_id: int) -> int: """Counts the number of members in a group.""" try: result = await db.execute( select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id) ) return result.scalar_one() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to count group members: {str(e)}") async def check_group_membership( db: AsyncSession, group_id: int, user_id: int, action: str = "access this group" ) -> None: """ Checks if a user is a member of a group. Raises exceptions if not found or not a member. Raises: GroupNotFoundError: If the group_id does not exist. GroupMembershipError: If the user_id is not a member of the group. """ try: # Check group existence first group_exists = await db.get(GroupModel, group_id) if not group_exists: raise GroupNotFoundError(group_id) # Check membership membership = await db.execute( select(UserGroupModel.id) .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) .limit(1) ) if membership.scalar_one_or_none() is None: raise GroupMembershipError(group_id, action=action) # If we reach here, the user is a member return None except GroupNotFoundError: # Re-raise specific errors raise except GroupMembershipError: raise except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to check group membership: {str(e)}") async def check_user_role_in_group( db: AsyncSession, group_id: int, user_id: int, required_role: UserRoleEnum, action: str = "perform this action" ) -> None: """ Checks if a user is a member of a group and has the required role (or higher). Raises: GroupNotFoundError: If the group_id does not exist. GroupMembershipError: If the user_id is not a member of the group. 
GroupPermissionError: If the user does not have the required role. """ # First, ensure user is a member (this also checks group existence) await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}") # Get the user's actual role actual_role = await get_user_role_in_group(db, group_id, user_id) # Define role hierarchy (assuming owner > member) role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1} if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0): raise GroupPermissionError( group_id=group_id, action=f"{action} (requires at least '{required_role.value}' role)" ) # If role is sufficient, return None return None # app/main.py import logging import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from app.api.api_router import api_router from app.config import settings from app.core.api_config import API_METADATA, API_TAGS # Import database and models if needed for startup/shutdown events later # from . import database, models # --- Logging Setup --- logging.basicConfig( level=getattr(logging, settings.LOG_LEVEL), format=settings.LOG_FORMAT ) logger = logging.getLogger(__name__) # --- FastAPI App Instance --- app = FastAPI( **API_METADATA, openapi_tags=API_TAGS ) # --- CORS Middleware --- # Define allowed origins. Be specific in production! # Use ["*"] for wide open access during early development if needed, # but restrict it as soon as possible. 
# SvelteKit default dev port is 5173 origins = [ "http://localhost:5174", "http://localhost:8000", # Allow requests from the API itself (e.g., Swagger UI) # Add your deployed frontend URL here later # "https://your-frontend-domain.com", ] app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # --- End CORS Middleware --- # --- Include API Routers --- # All API endpoints will be prefixed with /api app.include_router(api_router, prefix=settings.API_PREFIX) # --- End Include API Routers --- # --- Root Endpoint (Optional - outside the main API structure) --- @app.get("/", tags=["Root"]) async def read_root(): """ Provides a simple welcome message at the root path. Useful for basic reachability checks. """ logger.info("Root endpoint '/' accessed.") return {"message": settings.ROOT_MESSAGE} # --- End Root Endpoint --- # --- Application Startup/Shutdown Events (Optional) --- # @app.on_event("startup") # async def startup_event(): # logger.info("Application startup: Connecting to database...") # # You might perform initial checks or warm-up here # # await database.engine.connect() # Example check (get_db handles sessions per request) # logger.info("Application startup complete.") # @app.on_event("shutdown") # async def shutdown_event(): # logger.info("Application shutdown: Disconnecting from database...") # # await database.engine.dispose() # Close connection pool # logger.info("Application shutdown complete.") # --- End Events --- # --- Direct Run (for simple local testing if needed) --- # It's better to use `uvicorn app.main:app --reload` from the terminal # if __name__ == "__main__": # logger.info("Starting Uvicorn server directly from main.py") # uvicorn.run(app, host="0.0.0.0", port=8000) # ------------------------------------------------------ // API Version export const API_VERSION = 'v1'; // API Base URL export const API_BASE_URL = import.meta.env.VITE_API_URL || 
'http://localhost:8000'; // API Endpoints export const API_ENDPOINTS = { // Auth AUTH: { LOGIN: '/auth/login', SIGNUP: '/auth/signup', REFRESH_TOKEN: '/auth/refresh-token', LOGOUT: '/auth/logout', VERIFY_EMAIL: '/auth/verify-email', RESET_PASSWORD: '/auth/reset-password', FORGOT_PASSWORD: '/auth/forgot-password', }, // Users USERS: { PROFILE: '/users/me', UPDATE_PROFILE: '/users/me', PASSWORD: '/users/password', AVATAR: '/users/avatar', SETTINGS: '/users/settings', NOTIFICATIONS: '/users/notifications', PREFERENCES: '/users/preferences', }, // Lists LISTS: { BASE: '/lists', BY_ID: (id: string) => `/lists/${id}`, ITEMS: (listId: string) => `/lists/${listId}/items`, ITEM: (listId: string, itemId: string) => `/lists/${listId}/items/${itemId}`, SHARE: (listId: string) => `/lists/${listId}/share`, UNSHARE: (listId: string) => `/lists/${listId}/unshare`, COMPLETE: (listId: string) => `/lists/${listId}/complete`, REOPEN: (listId: string) => `/lists/${listId}/reopen`, ARCHIVE: (listId: string) => `/lists/${listId}/archive`, RESTORE: (listId: string) => `/lists/${listId}/restore`, DUPLICATE: (listId: string) => `/lists/${listId}/duplicate`, EXPORT: (listId: string) => `/lists/${listId}/export`, IMPORT: '/lists/import', }, // Groups GROUPS: { BASE: '/groups', BY_ID: (id: string) => `/groups/${id}`, LISTS: (groupId: string) => `/groups/${groupId}/lists`, MEMBERS: (groupId: string) => `/groups/${groupId}/members`, MEMBER: (groupId: string, userId: string) => `/groups/${groupId}/members/${userId}`, LEAVE: (groupId: string) => `/groups/${groupId}/leave`, DELETE: (groupId: string) => `/groups/${groupId}`, SETTINGS: (groupId: string) => `/groups/${groupId}/settings`, ROLES: (groupId: string) => `/groups/${groupId}/roles`, ROLE: (groupId: string, roleId: string) => `/groups/${groupId}/roles/${roleId}`, }, // Invites INVITES: { BASE: '/invites', BY_ID: (id: string) => `/invites/${id}`, ACCEPT: (id: string) => `/invites/${id}/accept`, DECLINE: (id: string) => 
`/invites/${id}/decline`, REVOKE: (id: string) => `/invites/${id}/revoke`, LIST: '/invites', PENDING: '/invites/pending', SENT: '/invites/sent', }, // Items (for direct operations like update, get by ID) ITEMS: { BY_ID: (itemId: string) => `/items/${itemId}`, }, // OCR OCR: { PROCESS: '/ocr/extract-items', STATUS: (jobId: string) => `/ocr/status/${jobId}`, RESULT: (jobId: string) => `/ocr/result/${jobId}`, BATCH: '/ocr/batch', CANCEL: (jobId: string) => `/ocr/cancel/${jobId}`, HISTORY: '/ocr/history', }, // Costs COSTS: { BASE: '/costs', LIST_SUMMARY: (listId: string | number) => `/costs/lists/${listId}/cost-summary`, GROUP_BALANCE_SUMMARY: (groupId: string | number) => `/costs/groups/${groupId}/balance-summary`, }, // Financials FINANCIALS: { EXPENSES: '/financials/expenses', EXPENSE: (id: string) => `/financials/expenses/${id}`, SETTLEMENTS: '/financials/settlements', SETTLEMENT: (id: string) => `/financials/settlements/${id}`, BALANCES: '/financials/balances', BALANCE: (userId: string) => `/financials/balances/${userId}`, REPORTS: '/financials/reports', REPORT: (id: string) => `/financials/reports/${id}`, CATEGORIES: '/financials/categories', CATEGORY: (id: string) => `/financials/categories/${id}`, }, // Health HEALTH: { CHECK: '/health', VERSION: '/health/version', STATUS: '/health/status', METRICS: '/health/metrics', LOGS: '/health/logs', }, }; from fastapi import HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.config import settings from typing import Optional class ListNotFoundError(HTTPException): """Raised when a list is not found.""" def __init__(self, list_id: int): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"List {list_id} not found" ) class ListPermissionError(HTTPException): """Raised when a user doesn't have permission to access a list.""" def __init__(self, list_id: int, action: str = "access"): super().__init__( status_code=status.HTTP_403_FORBIDDEN, detail=f"You do not 
have permission to {action} list {list_id}" ) class ListCreatorRequiredError(HTTPException): """Raised when an action requires the list creator but the user is not the creator.""" def __init__(self, list_id: int, action: str): super().__init__( status_code=status.HTTP_403_FORBIDDEN, detail=f"Only the list creator can {action} list {list_id}" ) class GroupNotFoundError(HTTPException): """Raised when a group is not found.""" def __init__(self, group_id: int): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found" ) class GroupPermissionError(HTTPException): """Raised when a user doesn't have permission to perform an action in a group.""" def __init__(self, group_id: int, action: str): super().__init__( status_code=status.HTTP_403_FORBIDDEN, detail=f"You do not have permission to {action} in group {group_id}" ) class GroupMembershipError(HTTPException): """Raised when a user attempts to perform an action that requires group membership.""" def __init__(self, group_id: int, action: str = "access"): super().__init__( status_code=status.HTTP_403_FORBIDDEN, detail=f"You must be a member of group {group_id} to {action}" ) class GroupOperationError(HTTPException): """Raised when a group operation fails.""" def __init__(self, detail: str): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail ) class GroupValidationError(HTTPException): """Raised when a group operation is invalid.""" def __init__(self, detail: str): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail=detail ) class ItemNotFoundError(HTTPException): """Raised when an item is not found.""" def __init__(self, item_id: int): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"Item {item_id} not found" ) class UserNotFoundError(HTTPException): """Raised when a user is not found.""" def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None): detail_msg = "User not found." 
if user_id: detail_msg = f"User with ID {user_id} not found." elif identifier: detail_msg = f"User with identifier '{identifier}' not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail_msg ) class InvalidOperationError(HTTPException): """Raised when an operation is invalid or disallowed by business logic.""" def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): super().__init__( status_code=status_code, detail=detail ) class DatabaseConnectionError(HTTPException): """Raised when there is an error connecting to the database.""" def __init__(self): super().__init__( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=settings.DB_CONNECTION_ERROR ) class DatabaseIntegrityError(HTTPException): """Raised when a database integrity constraint is violated.""" def __init__(self): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail=settings.DB_INTEGRITY_ERROR ) class DatabaseTransactionError(HTTPException): """Raised when a database transaction fails.""" def __init__(self, detail: str = settings.DB_TRANSACTION_ERROR): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail ) class DatabaseQueryError(HTTPException): """Raised when a database query fails.""" def __init__(self, detail: str = settings.DB_QUERY_ERROR): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=detail ) class OCRServiceUnavailableError(HTTPException): """Raised when the OCR service is unavailable.""" def __init__(self): super().__init__( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=settings.OCR_SERVICE_UNAVAILABLE ) class OCRServiceConfigError(HTTPException): """Raised when there is an error in the OCR service configuration.""" def __init__(self): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=settings.OCR_SERVICE_CONFIG_ERROR ) class OCRUnexpectedError(HTTPException): """Raised when there is an unexpected error in the OCR service.""" def 
__init__(self): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=settings.OCR_UNEXPECTED_ERROR ) class OCRQuotaExceededError(HTTPException): """Raised when the OCR service quota is exceeded.""" def __init__(self): super().__init__( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=settings.OCR_QUOTA_EXCEEDED ) class InvalidFileTypeError(HTTPException): """Raised when an invalid file type is uploaded for OCR.""" def __init__(self): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES)) ) class FileTooLargeError(HTTPException): """Raised when an uploaded file exceeds the size limit.""" def __init__(self): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB) ) class OCRProcessingError(HTTPException): """Raised when there is an error processing the image with OCR.""" def __init__(self, detail: str): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail=settings.OCR_PROCESSING_ERROR.format(detail=detail) ) class EmailAlreadyRegisteredError(HTTPException): """Raised when attempting to register with an email that is already in use.""" def __init__(self): super().__init__( status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered." ) class UserCreationError(HTTPException): """Raised when there is an error creating a new user.""" def __init__(self): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An error occurred during user creation." 
) class InviteNotFoundError(HTTPException): """Raised when an invite is not found.""" def __init__(self, invite_code: str): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"Invite code {invite_code} not found" ) class InviteExpiredError(HTTPException): """Raised when an invite has expired.""" def __init__(self, invite_code: str): super().__init__( status_code=status.HTTP_410_GONE, detail=f"Invite code {invite_code} has expired" ) class InviteAlreadyUsedError(HTTPException): """Raised when an invite has already been used.""" def __init__(self, invite_code: str): super().__init__( status_code=status.HTTP_410_GONE, detail=f"Invite code {invite_code} has already been used" ) class InviteCreationError(HTTPException): """Raised when an invite cannot be created.""" def __init__(self, group_id: int): super().__init__( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create invite for group {group_id}" ) class ListStatusNotFoundError(HTTPException): """Raised when a list's status cannot be retrieved.""" def __init__(self, list_id: int): super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=f"Status for list {list_id} not found" ) class ConflictError(HTTPException): """Raised when an optimistic lock version conflict occurs.""" def __init__(self, detail: str): super().__init__( status_code=status.HTTP_409_CONFLICT, detail=detail ) class InvalidCredentialsError(HTTPException): """Raised when login credentials are invalid.""" def __init__(self): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.AUTH_INVALID_CREDENTIALS, headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""} ) class NotAuthenticatedError(HTTPException): """Raised when the user is not authenticated.""" def __init__(self): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.AUTH_NOT_AUTHENTICATED, headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} 
error=\"not_authenticated\""} ) class JWTError(HTTPException): """Raised when there is an error with the JWT token.""" def __init__(self, error: str): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.JWT_ERROR.format(error=error), headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} ) class JWTUnexpectedError(HTTPException): """Raised when there is an unexpected error with the JWT token.""" def __init__(self, error: str): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} ) # app/crud/list.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List as PyList from app.schemas.list import ListStatus from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel from app.schemas.list import ListCreate, ListUpdate from app.core.exceptions import ( ListNotFoundError, ListPermissionError, ListCreatorRequiredError, DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, ConflictError ) async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: """Creates a new list record.""" try: async with db.begin(): db_list = ListModel( name=list_in.name, description=list_in.description, group_id=list_in.group_id, created_by_id=creator_id, is_complete=False ) db.add(db_list) await db.flush() await db.refresh(db_list) return db_list except IntegrityError as e: raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") except OperationalError as e: raise DatabaseConnectionError(f"Database 
connection error: {str(e)}") except SQLAlchemyError as e: raise DatabaseTransactionError(f"Failed to create list: {str(e)}") async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]: """Gets all lists accessible by a user.""" try: group_ids_result = await db.execute( select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id) ) user_group_ids = group_ids_result.scalars().all() # Build conditions for the OR clause dynamically conditions = [ and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)) ] if user_group_ids: # Only add the IN clause if there are group IDs conditions.append(ListModel.group_id.in_(user_group_ids)) query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc()) result = await db.execute(query) return result.scalars().all() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query user lists: {str(e)}") async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]: """Gets a single list by ID, optionally loading its items.""" try: query = select(ListModel).where(ListModel.id == list_id) if load_items: query = query.options( selectinload(ListModel.items) .options( joinedload(ItemModel.added_by_user), joinedload(ItemModel.completed_by_user) ) ) result = await db.execute(query) return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query list: {str(e)}") async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: """Updates an existing list record, checking for version conflicts.""" try: async with db.begin(): if list_db.version != list_in.version: raise ConflictError( f"List '{list_db.name}' (ID: {list_db.id}) has been modified. 
" f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." ) update_data = list_in.model_dump(exclude_unset=True, exclude={'version'}) for key, value in update_data.items(): setattr(list_db, key, value) list_db.version += 1 db.add(list_db) await db.flush() await db.refresh(list_db) return list_db except IntegrityError as e: await db.rollback() raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") except OperationalError as e: await db.rollback() raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") except ConflictError: await db.rollback() raise except SQLAlchemyError as e: await db.rollback() raise DatabaseTransactionError(f"Failed to update list: {str(e)}") async def delete_list(db: AsyncSession, list_db: ListModel) -> None: """Deletes a list record. Version check should be done by the caller (API endpoint).""" try: async with db.begin(): await db.delete(list_db) return None except OperationalError as e: await db.rollback() raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") except SQLAlchemyError as e: await db.rollback() raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: """Fetches a list and verifies user permission.""" try: list_db = await get_list_by_id(db, list_id=list_id, load_items=True) if not list_db: raise ListNotFoundError(list_id) is_creator = list_db.created_by_id == user_id if require_creator: if not is_creator: raise ListCreatorRequiredError(list_id, "access") return list_db if is_creator: return list_db if list_db.group_id: from app.crud.group import is_user_member is_member = await is_user_member(db, group_id=list_db.group_id, user_id=user_id) if not is_member: raise ListPermissionError(list_id) return list_db else: raise ListPermissionError(list_id) except 
OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}") async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus: """Gets the update timestamps and item count for a list.""" try: list_query = select(ListModel.updated_at).where(ListModel.id == list_id) list_result = await db.execute(list_query) list_updated_at = list_result.scalar_one_or_none() if list_updated_at is None: raise ListNotFoundError(list_id) item_status_query = ( select( sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"), sql_func.count(ItemModel.id).label("item_count") ) .where(ItemModel.list_id == list_id) ) item_result = await db.execute(item_status_query) item_status = item_result.first() return ListStatus( list_updated_at=list_updated_at, latest_item_updated_at=item_status.latest_item_updated_at if item_status else None, item_count=item_status.item_count if item_status else 0 ) except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to get list status: {str(e)}") # app/models.py import enum import secrets from datetime import datetime, timedelta, timezone from sqlalchemy import ( Column, Integer, String, DateTime, ForeignKey, Boolean, Enum as SAEnum, UniqueConstraint, Index, DDL, event, delete, func, text as sa_text, Text, # <-- Add Text for description Numeric # <-- Add Numeric for price ) from sqlalchemy.orm import relationship, backref from .database import Base # --- Enums --- class UserRoleEnum(enum.Enum): owner = "owner" member = "member" class SplitTypeEnum(enum.Enum): EQUAL = "EQUAL" # Split equally among all involved users EXACT_AMOUNTS = "EXACT_AMOUNTS" # Specific amounts for each user (defined in ExpenseSplit) PERCENTAGE = "PERCENTAGE" # Percentage for each user (defined in ExpenseSplit) SHARES = "SHARES" # 
Proportional to shares/units (defined in ExpenseSplit) ITEM_BASED = "ITEM_BASED" # If an expense is derived directly from item prices and who added them # Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense) # --- User Model --- class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) email = Column(String, unique=True, index=True, nullable=False) password_hash = Column(String, nullable=False) name = Column(String, index=True, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # --- Relationships --- created_groups = relationship("Group", back_populates="creator") group_associations = relationship("UserGroup", back_populates="user", cascade="all, delete-orphan") created_invites = relationship("Invite", back_populates="creator") # --- NEW Relationships for Lists/Items --- created_lists = relationship("List", foreign_keys="List.created_by_id", back_populates="creator") # Link List.created_by_id -> User added_items = relationship("Item", foreign_keys="Item.added_by_id", back_populates="added_by_user") # Link Item.added_by_id -> User completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User # --- End NEW Relationships --- # --- Relationships for Cost Splitting --- expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan") expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan") settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan") settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan") # --- End Relationships for Cost 
Splitting --- # --- Group Model --- class Group(Base): __tablename__ = "groups" id = Column(Integer, primary_key=True, index=True) name = Column(String, index=True, nullable=False) created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # --- Relationships --- creator = relationship("User", back_populates="created_groups") member_associations = relationship("UserGroup", back_populates="group", cascade="all, delete-orphan") invites = relationship("Invite", back_populates="group", cascade="all, delete-orphan") # --- NEW Relationship for Lists --- lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group # --- End NEW Relationship --- # --- Relationships for Cost Splitting --- expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan") settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan") # --- End Relationships for Cost Splitting --- # --- UserGroup Association Model --- class UserGroup(Base): __tablename__ = "user_groups" __table_args__ = (UniqueConstraint('user_id', 'group_id', name='uq_user_group'),) id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) role = Column(SAEnum(UserRoleEnum, name="userroleenum", create_type=True), nullable=False, default=UserRoleEnum.member) joined_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) user = relationship("User", back_populates="group_associations") group = relationship("Group", back_populates="member_associations") # --- Invite Model --- class Invite(Base): __tablename__ = "invites" __table_args__ = ( Index('ix_invites_active_code', 
'code', unique=True, postgresql_where=sa_text('is_active = true')), ) id = Column(Integer, primary_key=True, index=True) code = Column(String, unique=False, index=True, nullable=False, default=lambda: secrets.token_urlsafe(16)) group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) expires_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) + timedelta(days=7)) is_active = Column(Boolean, default=True, nullable=False) group = relationship("Group", back_populates="invites") creator = relationship("User", back_populates="created_invites") # === NEW: List Model === class List(Base): __tablename__ = "lists" id = Column(Integer, primary_key=True, index=True) name = Column(String, index=True, nullable=False) description = Column(Text, nullable=True) created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who created this list group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # Which group it belongs to (NULL if personal) is_complete = Column(Boolean, default=False, nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') # --- Relationships --- creator = relationship("User", back_populates="created_lists") # Link to User.created_lists group = relationship("Group", back_populates="lists") # Link to Group.lists items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes # --- Relationships for Cost Splitting --- expenses = relationship("Expense", foreign_keys="Expense.list_id", 
back_populates="list", cascade="all, delete-orphan") # --- End Relationships for Cost Splitting --- # === NEW: Item Model === class Item(Base): __tablename__ = "items" id = Column(Integer, primary_key=True, index=True) list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) # Belongs to which list name = Column(String, index=True, nullable=False) quantity = Column(String, nullable=True) # Flexible quantity (e.g., "1", "2 lbs", "a bunch") is_complete = Column(Boolean, default=False, nullable=False) price = Column(Numeric(10, 2), nullable=True) # For cost splitting later (e.g., 12345678.99) added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who added this item completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') # --- Relationships --- list = relationship("List", back_populates="items") # Link to List.items added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items # --- Relationships for Cost Splitting --- # If an item directly results in an expense, or an expense can be tied to an item. 
expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses # --- End Relationships for Cost Splitting --- # === NEW Models for Advanced Cost Splitting === class Expense(Base): __tablename__ = "expenses" id = Column(Integer, primary_key=True, index=True) description = Column(String, nullable=False) total_amount = Column(Numeric(10, 2), nullable=False) currency = Column(String, nullable=False, default="USD") # Consider making this an Enum too if few currencies expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False) # Foreign Keys list_id = Column(Integer, ForeignKey("lists.id"), nullable=True) group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # If not list-specific but group-specific item_id = Column(Integer, ForeignKey("items.id"), nullable=True) # If the expense is for a specific item paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') # Relationships paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid") list = relationship("List", foreign_keys=[list_id], back_populates="expenses") group = relationship("Group", foreign_keys=[group_id], back_populates="expenses") item = relationship("Item", foreign_keys=[item_id], back_populates="expenses") splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan") __table_args__ = ( # Example: Ensure either list_id or group_id is present if item_id is null # CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', 
name='chk_expense_context'), ) class ExpenseSplit(Base): __tablename__ = "expense_splits" __table_args__ = (UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),) id = Column(Integer, primary_key=True, index=True) expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) owed_amount = Column(Numeric(10, 2), nullable=False) # For EQUAL or EXACT_AMOUNTS # For PERCENTAGE split (value from 0.00 to 100.00) share_percentage = Column(Numeric(5, 2), nullable=True) # For SHARES split (e.g., user A has 2 shares, user B has 3 shares) share_units = Column(Integer, nullable=True) # is_settled might be better tracked via actual Settlement records or a reconciliation process # is_settled = Column(Boolean, default=False, nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) # Relationships expense = relationship("Expense", back_populates="splits") user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits") class Settlement(Base): __tablename__ = "settlements" id = Column(Integer, primary_key=True, index=True) group_id = Column(Integer, ForeignKey("groups.id"), nullable=False) # Settlements usually within a group paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) amount = Column(Numeric(10, 2), nullable=False) settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) description = Column(Text, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, 
server_default='1') # Relationships group = relationship("Group", foreign_keys=[group_id], back_populates="settlements") payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made") payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received") __table_args__ = ( # Ensure payer and payee are different users # CheckConstraint('paid_by_user_id <> paid_to_user_id', name='chk_settlement_payer_ne_payee'), ) # Potential future: PaymentMethod model, etc. # app/crud/item.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List as PyList from datetime import datetime, timezone from app.models import Item as ItemModel from app.schemas.item import ItemCreate, ItemUpdate from app.core.exceptions import ( ItemNotFoundError, DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, ConflictError ) async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: """Creates a new item record for a specific list.""" try: db_item = ItemModel( name=item_in.name, quantity=item_in.quantity, list_id=list_id, added_by_id=user_id, is_complete=False # Default on creation # version is implicitly set to 1 by model default ) db.add(db_item) await db.flush() await db.refresh(db_item) await db.commit() # Explicitly commit here return db_item except IntegrityError as e: await db.rollback() # Rollback on integrity error raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") except OperationalError as e: await db.rollback() # Rollback on operational error raise DatabaseConnectionError(f"Database connection error: {str(e)}") except SQLAlchemyError as e: await db.rollback() # Rollback on other SQLAlchemy errors raise 
DatabaseTransactionError(f"Failed to create item: {str(e)}") except Exception as e: # Catch any other exception and attempt rollback await db.rollback() raise # Re-raise the original exception async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]: """Gets all items belonging to a specific list, ordered by creation time.""" try: result = await db.execute( select(ItemModel) .where(ItemModel.list_id == list_id) .order_by(ItemModel.created_at.asc()) # Or desc() if preferred ) return result.scalars().all() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query items: {str(e)}") async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: """Gets a single item by its ID.""" try: result = await db.execute(select(ItemModel).where(ItemModel.id == item_id)) return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: raise DatabaseQueryError(f"Failed to query item: {str(e)}") async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: """Updates an existing item record, checking for version conflicts.""" try: # Check version conflict if item_db.version != item_in.version: raise ConflictError( f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." 
) update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version # Special handling for is_complete if 'is_complete' in update_data: if update_data['is_complete'] is True: if item_db.completed_by_id is None: # Only set if not already completed by someone update_data['completed_by_id'] = user_id else: update_data['completed_by_id'] = None # Clear if marked incomplete # Apply updates for key, value in update_data.items(): setattr(item_db, key, value) item_db.version += 1 # Increment version db.add(item_db) await db.flush() await db.refresh(item_db) # Commit the transaction if not part of a larger transaction await db.commit() return item_db except IntegrityError as e: await db.rollback() raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") except OperationalError as e: await db.rollback() raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") except ConflictError: # Re-raise ConflictError await db.rollback() raise except SQLAlchemyError as e: await db.rollback() raise DatabaseTransactionError(f"Failed to update item: {str(e)}") async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: """Deletes an item record. 
Version check should be done by the caller (API endpoint).""" try: await db.delete(item_db) await db.commit() return None except OperationalError as e: await db.rollback() raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") except SQLAlchemyError as e: await db.rollback() raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") from fastapi import APIRouter from app.api.v1.endpoints import health from app.api.v1.endpoints import auth from app.api.v1.endpoints import users from app.api.v1.endpoints import groups from app.api.v1.endpoints import invites from app.api.v1.endpoints import lists from app.api.v1.endpoints import items from app.api.v1.endpoints import ocr from app.api.v1.endpoints import costs from app.api.v1.endpoints import financials api_router_v1 = APIRouter() api_router_v1.include_router(health.router) api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"]) api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"]) api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"]) api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"]) api_router_v1.include_router(items.router, tags=["Items"]) api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"]) api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"]) api_router_v1.include_router(financials.router) # Add other v1 endpoint routers here later # e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) # app/config.py import os from pydantic_settings import BaseSettings from dotenv import load_dotenv import logging import secrets load_dotenv() logger = logging.getLogger(__name__) class Settings(BaseSettings): DATABASE_URL: str | None = None GEMINI_API_KEY: str | None = None # --- JWT Settings --- SECRET_KEY: str # Must be set via environment 
variable ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days # --- OCR Settings --- MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats OCR_ITEM_EXTRACTION_PROMPT: str = """ Extract the shopping list items from this image. List each distinct item on a new line. Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. Focus only on the names of the products or items to be purchased. Add 2 underscores before and after the item name, if it is struck through. If the image does not appear to be a shopping list or receipt, state that clearly. Example output for a grocery list: Milk Eggs Bread __Apples__ Organic Bananas """ # --- Gemini AI Settings --- GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR GEMINI_SAFETY_SETTINGS: dict = { "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", } GEMINI_GENERATION_CONFIG: dict = { "candidate_count": 1, "max_output_tokens": 2048, "temperature": 0.9, "top_p": 1, "top_k": 1 } # --- API Settings --- API_PREFIX: str = "/api" # Base path for all API endpoints API_OPENAPI_URL: str = "/api/openapi.json" API_DOCS_URL: str = "/api/docs" API_REDOC_URL: str = "/api/redoc" CORS_ORIGINS: list[str] = [ "http://localhost:5174", "http://localhost:8000", "http://localhost:9000", # Add your deployed frontend URL here later # "https://your-frontend-domain.com", ] # --- API Metadata --- API_TITLE: str = "Shared Lists API" API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting." 
API_VERSION: str = "0.1.0" ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs" # --- Logging Settings --- LOG_LEVEL: str = "INFO" LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" # --- Health Check Settings --- HEALTH_STATUS_OK: str = "ok" HEALTH_STATUS_ERROR: str = "error" # --- Auth Settings --- OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint TOKEN_TYPE: str = "bearer" # Default token type for OAuth2 AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials" AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password" AUTH_NOT_AUTHENTICATED: str = "Not authenticated" # --- HTTP Status Messages --- HTTP_400_DETAIL: str = "Bad Request" HTTP_401_DETAIL: str = "Unauthorized" HTTP_403_DETAIL: str = "Forbidden" HTTP_404_DETAIL: str = "Not Found" HTTP_422_DETAIL: str = "Unprocessable Entity" HTTP_429_DETAIL: str = "Too Many Requests" HTTP_500_DETAIL: str = "Internal Server Error" HTTP_503_DETAIL: str = "Service Unavailable" # --- Database Error Messages --- DB_CONNECTION_ERROR: str = "Database connection error" DB_INTEGRITY_ERROR: str = "Database integrity error" DB_TRANSACTION_ERROR: str = "Database transaction error" DB_QUERY_ERROR: str = "Database query error" class Config: env_file = ".env" env_file_encoding = 'utf-8' extra = "ignore" settings = Settings() # Validation for critical settings if settings.DATABASE_URL is None: raise ValueError("DATABASE_URL environment variable must be set.") # Enforce secure secret key if not settings.SECRET_KEY: raise ValueError("SECRET_KEY environment variable must be set. 
Generate a secure key using: openssl rand -hex 32") # Validate secret key strength if len(settings.SECRET_KEY) < 32: raise ValueError("SECRET_KEY must be at least 32 characters long for security") if settings.GEMINI_API_KEY is None: logger.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.") else: # Optional: Log partial key for confirmation (avoid logging full key) logger.info(f"GEMINI_API_KEY loaded (starts with: {settings.GEMINI_API_KEY[:4]}...).")