diff --git a/fe/package-lock.json b/fe/package-lock.json index 230d70b..0b09d9d 100644 --- a/fe/package-lock.json +++ b/fe/package-lock.json @@ -8,6 +8,8 @@ "name": "fe", "version": "0.0.0", "dependencies": { + "@supabase/auth-js": "^2.69.1", + "@supabase/supabase-js": "^2.49.4", "@vueuse/core": "^13.1.0", "axios": "^1.9.0", "pinia": "^3.0.2", @@ -3836,6 +3838,102 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/@supabase/auth-js": { + "version": "2.69.1", + "resolved": "https://registry.npmjs.org/@supabase/auth-js/-/auth-js-2.69.1.tgz", + "integrity": "sha512-FILtt5WjCNzmReeRLq5wRs3iShwmnWgBvxHfqapC/VoljJl+W8hDAyFmf1NVw3zH+ZjZ05AKxiKxVeb0HNWRMQ==", + "license": "MIT", + "dependencies": { + "@supabase/node-fetch": "^2.6.14" + } + }, + "node_modules/@supabase/functions-js": { + "version": "2.4.4", + "resolved": "https://registry.npmjs.org/@supabase/functions-js/-/functions-js-2.4.4.tgz", + "integrity": "sha512-WL2p6r4AXNGwop7iwvul2BvOtuJ1YQy8EbOd0dhG1oN1q8el/BIRSFCFnWAMM/vJJlHWLi4ad22sKbKr9mvjoA==", + "license": "MIT", + "dependencies": { + "@supabase/node-fetch": "^2.6.14" + } + }, + "node_modules/@supabase/node-fetch": { + "version": "2.6.15", + "resolved": "https://registry.npmjs.org/@supabase/node-fetch/-/node-fetch-2.6.15.tgz", + "integrity": "sha512-1ibVeYUacxWYi9i0cf5efil6adJ9WRyZBLivgjs+AUpewx1F3xPi7gLgaASI2SmIQxPoCEjAsLAzKPgMJVgOUQ==", + "license": "MIT", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + } + }, + "node_modules/@supabase/node-fetch/node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "license": "MIT" + }, + "node_modules/@supabase/node-fetch/node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", + "license": "BSD-2-Clause" + }, + "node_modules/@supabase/node-fetch/node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "license": "MIT", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, + "node_modules/@supabase/postgrest-js": { + "version": "1.19.4", + "resolved": "https://registry.npmjs.org/@supabase/postgrest-js/-/postgrest-js-1.19.4.tgz", + "integrity": "sha512-O4soKqKtZIW3olqmbXXbKugUtByD2jPa8kL2m2c1oozAO11uCcGrRhkZL0kVxjBLrXHE0mdSkFsMj7jDSfyNpw==", + "license": "MIT", + "dependencies": { + "@supabase/node-fetch": "^2.6.14" + } + }, + "node_modules/@supabase/realtime-js": { + "version": "2.11.2", + "resolved": "https://registry.npmjs.org/@supabase/realtime-js/-/realtime-js-2.11.2.tgz", + "integrity": "sha512-u/XeuL2Y0QEhXSoIPZZwR6wMXgB+RQbJzG9VErA3VghVt7uRfSVsjeqd7m5GhX3JR6dM/WRmLbVR8URpDWG4+w==", + "license": "MIT", + "dependencies": { + "@supabase/node-fetch": "^2.6.14", + "@types/phoenix": "^1.5.4", + "@types/ws": "^8.5.10", + "ws": "^8.18.0" + } + }, + "node_modules/@supabase/storage-js": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/@supabase/storage-js/-/storage-js-2.7.1.tgz", + "integrity": "sha512-asYHcyDR1fKqrMpytAS1zjyEfvxuOIp1CIXX7ji4lHHcJKqyk+sLl/Vxgm4sN6u8zvuUtae9e4kDxQP2qrwWBA==", + "license": "MIT", + "dependencies": { + "@supabase/node-fetch": "^2.6.14" + } + }, + "node_modules/@supabase/supabase-js": { + "version": "2.49.4", + "resolved": "https://registry.npmjs.org/@supabase/supabase-js/-/supabase-js-2.49.4.tgz", + "integrity": "sha512-jUF0uRUmS8BKt37t01qaZ88H9yV1mbGYnqLeuFWLcdV+x1P4fl0yP9DGtaEhFPZcwSom7u16GkLEH9QJZOqOkw==", + "license": "MIT", + "dependencies": { + "@supabase/auth-js": "2.69.1", + "@supabase/functions-js": "2.4.4", + "@supabase/node-fetch": "2.6.15", + "@supabase/postgrest-js": "1.19.4", + "@supabase/realtime-js": "2.11.2", + "@supabase/storage-js": "2.7.1" + } + }, "node_modules/@surma/rollup-plugin-off-main-thread": { "version": "2.2.3", "resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz", @@ -3896,12 +3994,17 @@ "version": "22.15.17", "resolved": "https://registry.npmjs.org/@types/node/-/node-22.15.17.tgz", "integrity": "sha512-wIX2aSZL5FE+MR0JlvF87BNVrtFWf6AE6rxSE9X7OwnVvoyCQjpzSRJ+M87se/4QCkCiebQAqrJ0y6fwIyi7nw==", - "dev": true, "license": "MIT", "dependencies": { "undici-types": "~6.21.0" } }, + "node_modules/@types/phoenix": { + "version": "1.6.6", + "resolved": "https://registry.npmjs.org/@types/phoenix/-/phoenix-1.6.6.tgz", + "integrity": "sha512-PIzZZlEppgrpoT2QgbnDU+MMzuR6BbCjllj0bM70lWoejMeNJAxCchxnv7J3XFkI8MpygtRpzXrIlmWUBclP5A==", + "license": "MIT" + }, "node_modules/@types/resolve": { "version": "1.20.2", "resolved": "https://registry.npmjs.org/@types/resolve/-/resolve-1.20.2.tgz", @@ -3929,6 +4032,15 @@ "integrity": "sha512-oIQLCGWtcFZy2JW77j9k8nHzAOpqMHLQejDA48XXMWH6tjCQHz5RCFz1bzsmROyL6PUm+LLnUiI4BCn221inxA==", "license": "MIT" }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.32.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.32.1.tgz", @@ -10916,7 +11028,6 @@ "version": "6.21.0", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", - "dev": true, "license": "MIT" }, "node_modules/unicode-canonical-property-names-ecmascript": { @@ -12338,7 +12449,6 @@ "version": "8.18.2", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", - "dev": true, "license": "MIT", "engines": { "node": ">=10.0.0" diff --git a/fe/package.json b/fe/package.json index 63d22bb..7b1374f 100644 --- a/fe/package.json +++ b/fe/package.json @@ -17,6 +17,8 @@ "format": "prettier --write src/" }, "dependencies": { + "@supabase/auth-js": "^2.69.1", + "@supabase/supabase-js": "^2.49.4", "@vueuse/core": "^13.1.0", "axios": "^1.9.0", "pinia": "^3.0.2", diff --git a/repomix-output.xml b/repomix-output.xml deleted file mode 100644 index c55ed7b..0000000 --- a/repomix-output.xml +++ /dev/null @@ -1,13219 +0,0 @@ -This file is a merged representation of the entire codebase, combined into a single document by Repomix. - - -This section contains a summary of this file. - - -This file contains a packed representation of the entire repository's contents. -It is designed to be easily consumable by AI systems for analysis, code review, -or other automated processes. - - - -The content is organized as follows: -1. This summary section -2. Repository information -3. Directory structure -4. Repository files, each consisting of: - - File path as an attribute - - Full contents of the file - - - -- This file should be treated as read-only. Any changes should be made to the - original repository files, not this packed version. -- When processing this file, use the file path to distinguish - between different files in the repository. -- Be aware that this file may contain sensitive information. Handle it with - the same level of security as you would the original repository. - - - -- Some files may have been excluded based on .gitignore rules and Repomix's configuration -- Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files -- Files matching patterns in .gitignore are excluded -- Files matching default ignore patterns are excluded -- Files are sorted by Git change count (files with more changes are at the bottom) - - - - - - - - - -.gitea/workflows/ci.yml -be/.dockerignore -be/.gitignore -be/alembic.ini -be/alembic/env.py -be/alembic/README -be/alembic/script.py.mako -be/alembic/versions/bc37e9c7ae19_fresh_start.py -be/app/api/api_router.py -be/app/api/dependencies.py -be/app/api/v1/api.py -be/app/api/v1/endpoints/auth.py -be/app/api/v1/endpoints/costs.py -be/app/api/v1/endpoints/financials.py -be/app/api/v1/endpoints/groups.py -be/app/api/v1/endpoints/health.py -be/app/api/v1/endpoints/invites.py -be/app/api/v1/endpoints/items.py -be/app/api/v1/endpoints/lists.py -be/app/api/v1/endpoints/ocr.py -be/app/api/v1/endpoints/users.py -be/app/api/v1/test_auth.py -be/app/api/v1/test_users.py -be/app/config.py -be/app/core/api_config.py -be/app/core/exceptions.py -be/app/core/gemini.py -be/app/core/security.py -be/app/crud/expense.py -be/app/crud/group.py -be/app/crud/invite.py -be/app/crud/item.py -be/app/crud/list.py -be/app/crud/settlement.py -be/app/crud/user.py -be/app/database.py -be/app/main.py -be/app/models.py -be/app/schemas/auth.py -be/app/schemas/cost.py -be/app/schemas/expense.py -be/app/schemas/group.py -be/app/schemas/health.py -be/app/schemas/invite.py -be/app/schemas/item.py -be/app/schemas/list.py -be/app/schemas/message.py -be/app/schemas/ocr.py -be/app/schemas/user.py -be/Dockerfile -be/requirements.txt -be/tests/api/v1/endpoints/test_financials.py -be/tests/core/__init__.py -be/tests/core/test_exceptions.py -be/tests/core/test_gemini.py -be/tests/core/test_security.py -be/tests/crud/__init__.py -be/tests/crud/test_expense.py -be/tests/crud/test_group.py -be/tests/crud/test_invite.py -be/tests/crud/test_item.py -be/tests/crud/test_list.py -be/tests/crud/test_settlement.py -be/tests/crud/test_user.py -be/Untitled-1.md -docker-compose.yml -fe/.editorconfig -fe/.gitignore -fe/.npmrc -fe/.prettierrc.json -fe/.vscode/extensions.json -fe/.vscode/settings.json -fe/eslint.config.js -fe/index.html -fe/package.json -fe/postcss.config.js -fe/public/icons/safari-pinned-tab.svg -fe/quasar.config.ts -fe/README.md -fe/src-pwa/custom-service-worker.ts -fe/src-pwa/manifest.json -fe/src-pwa/pwa-env.d.ts -fe/src-pwa/register-service-worker.ts -fe/src-pwa/tsconfig.json -fe/src/App.vue -fe/src/assets/quasar-logo-vertical.svg -fe/src/boot/axios.ts -fe/src/boot/i18n.ts -fe/src/components/ConflictResolutionDialog.vue -fe/src/components/CreateListModal.vue -fe/src/components/EssentialLink.vue -fe/src/components/models.ts -fe/src/components/OfflineIndicator.vue -fe/src/config/api-config.ts -fe/src/config/api.ts -fe/src/css/app.scss -fe/src/css/quasar.variables.scss -fe/src/env.d.ts -fe/src/i18n/en-US/index.ts -fe/src/i18n/index.ts -fe/src/layouts/AuthLayout.vue -fe/src/layouts/MainLayout.vue -fe/src/pages/AccountPage.vue -fe/src/pages/ErrorNotFound.vue -fe/src/pages/GroupDetailPage.vue -fe/src/pages/GroupsPage.vue -fe/src/pages/IndexPage.vue -fe/src/pages/ListDetailPage.vue -fe/src/pages/ListsPage.vue -fe/src/pages/LoginPage.vue -fe/src/pages/SignupPage.vue -fe/src/router/index.ts -fe/src/router/routes.ts -fe/src/stores/auth.ts -fe/src/stores/index.ts -fe/src/stores/offline.ts -fe/tsconfig.json - - - -This section contains the contents of the repository's files. - - -# When you push to the develop branch or open/update a pull request targeting main, Gitea will: -# Trigger the "CI Checks" workflow. -# Execute the checks job on a runner. -# Run each step sequentially. -# If any of the linter/formatter check commands (black --check, ruff check, npm run lint) exit with a non-zero status code (indicating an error or check failure), the step and the entire job will fail. -# You will see the status (success/failure) associated with your commit or pull request in the Gitea interface. - -name: CI Checks - -# Define triggers for the workflow -on: - push: - branches: - - develop # Run on pushes to the develop branch - pull_request: - branches: - - main # Run on pull requests targeting the main branch - -jobs: - checks: - name: Linters and Formatters - runs-on: ubuntu-latest # Use a standard Linux runner environment - - steps: - - name: Checkout code - uses: actions/checkout@v4 # Fetches the repository code - - # --- Backend Checks (Python/FastAPI) --- - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' # Match your project/Dockerfile version - cache: 'pip' # Cache pip dependencies based on requirements.txt - cache-dependency-path: 'be/requirements.txt' # Specify path for caching - - - name: Install Backend Dependencies and Tools - working-directory: ./be # Run command within the 'be' directory - run: | - pip install --upgrade pip - pip install -r requirements.txt - pip install black ruff # Install formatters/linters for CI check - - - name: Run Black Formatter Check (Backend) - working-directory: ./be - run: black --check --diff . - - - name: Run Ruff Linter (Backend) - working-directory: ./be - run: ruff check . - - # --- Frontend Checks (SvelteKit/Node.js) --- - - - name: Set up Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' # Or specify your required Node.js version (e.g., 'lts/*') - cache: 'npm' # Or 'pnpm' / 'yarn' depending on your package manager - cache-dependency-path: 'fe/package-lock.json' # Adjust lockfile name if needed - - - name: Install Frontend Dependencies - working-directory: ./fe # Run command within the 'fe' directory - run: npm install # Or 'pnpm install' / 'yarn install' - - - name: Run ESLint and Prettier Check (Frontend) - working-directory: ./fe - # Assuming you have a 'lint' script in fe/package.json that runs both - # Example package.json script: "lint": "prettier --check . && eslint ." - run: npm run lint - # If no combined script, run separately: - # run: | - # npm run format -- --check # Or 'npx prettier --check .' - # npm run lint # Or 'npx eslint .' - - # - name: Run Frontend Type Check (Optional but recommended) - # working-directory: ./fe - # # Assuming you have a 'check' script: "check": "svelte-kit sync && svelte-check ..." - # run: npm run check - - # - name: Run Placeholder Tests (Optional) - # run: | - # # Add commands to run backend tests if available - # # Add commands to run frontend tests (e.g., npm test in ./fe) if available - # echo "No tests configured yet." - - - -# Git files -.git -.gitignore - -# Virtual environment -.venv -venv/ -env/ -ENV/ -*.env # Ignore local .env files within the backend directory if any - -# Python cache -__pycache__/ -*.py[cod] -*$py.class - -# IDE files -.idea/ -.vscode/ - -# Test artifacts -.pytest_cache/ -htmlcov/ -.coverage* - -# Other build/temp files -*.egg-info/ -dist/ -build/ -*.db # e.g., sqlite temp dbs - - - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# PEP 582; used by PDM, Flit and potentially other tools. -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv/ -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static analysis results -.pytype/ - -# alembic default temp file -*.db # If using sqlite for alembic versions locally for instance - -# If you use alembic autogenerate, it might create temporary files -# Depending on your DB, adjust if necessary -# *.sql.tmp - -# IDE files -.idea/ -.vscode/ - -# OS generated files -.DS_Store -Thumbs.db - - - -# A generic, single database configuration. - -[alembic] -# path to migration scripts -# Use forward slashes (/) also on windows to provide an os agnostic path -script_location = alembic - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s - -# sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. -prepend_sys_path = . - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. -# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "version_path_separator" below. -# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions - -# version path separator; As mentioned above, this is the character used to split -# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. -# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. -# Valid values for version_path_separator are: -# -# version_path_separator = : -# version_path_separator = ; -# version_path_separator = space -# version_path_separator = newline -# -# Use os.pathsep. Default configuration used for new projects. -version_path_separator = os - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -; sqlalchemy.url = driver://user:pass@localhost/dbname - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME - -# lint with attempts to fix using "ruff" - use the exec runner, execute a binary -# hooks = ruff -# ruff.type = exec -# ruff.executable = %(here)s/.venv/bin/ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARNING -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARNING -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S - - - -from logging.config import fileConfig -import os -import sys - -from sqlalchemy import engine_from_config -from sqlalchemy import pool - -from alembic import context - - -# Ensure the 'app' directory is in the Python path -# Adjust the path if your project structure is different -sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..'))) - -# Import your app's Base and settings -from app.models import Base # Import Base from your models module -from app.config import settings # Import settings to get DATABASE_URL - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Set the sqlalchemy.url from your application settings -# Use a synchronous version of the URL for Alembic's operations -sync_db_url = settings.DATABASE_URL.replace("+asyncpg", "") if settings.DATABASE_URL else None -if not sync_db_url: - raise ValueError("DATABASE_URL not found in settings for Alembic.") -config.set_main_option('sqlalchemy.url', sync_db_url) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = Base.metadata - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() - - - -Generic single-database configuration. - - - -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - """Upgrade schema.""" - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - """Downgrade schema.""" - ${downgrades if downgrades else "pass"} - - - -"""fresh start - -Revision ID: bc37e9c7ae19 -Revises: -Create Date: 2025-05-08 16:06:51.208542 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'bc37e9c7ae19' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('users', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('email', sa.String(), nullable=False), - sa.Column('password_hash', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) - op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False) - op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False) - op.create_table('groups', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('created_by_id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False) - op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False) - op.create_table('invites', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('code', sa.String(), nullable=False), - sa.Column('group_id', sa.Integer(), nullable=False), - sa.Column('created_by_id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) - op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true')) - op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False) - op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False) - op.create_table('lists', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('created_by_id', sa.Integer(), nullable=False), - sa.Column('group_id', sa.Integer(), nullable=True), - sa.Column('is_complete', sa.Boolean(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('version', sa.Integer(), server_default='1', nullable=False), - sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False) - op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False) - op.create_table('settlements', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('group_id', sa.Integer(), nullable=False), - sa.Column('paid_by_user_id', sa.Integer(), nullable=False), - sa.Column('paid_to_user_id', sa.Integer(), nullable=False), - sa.Column('amount', sa.Numeric(precision=10, scale=2), nullable=False), - sa.Column('settlement_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('description', sa.Text(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('version', sa.Integer(), server_default='1', nullable=False), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), - sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['paid_to_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_settlements_id'), 'settlements', ['id'], unique=False) - op.create_table('user_groups', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('group_id', sa.Integer(), nullable=False), - sa.Column('role', sa.Enum('owner', 'member', name='userroleenum'), nullable=False), - sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group') - ) - op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False) - op.create_table('items', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('list_id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(), nullable=False), - sa.Column('quantity', sa.String(), nullable=True), - sa.Column('is_complete', sa.Boolean(), nullable=False), - sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True), - sa.Column('added_by_id', sa.Integer(), nullable=False), - sa.Column('completed_by_id', sa.Integer(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('version', sa.Integer(), server_default='1', nullable=False), - sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ), - sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False) - op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False) - op.create_table('expenses', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('description', sa.String(), nullable=False), - sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False), - sa.Column('currency', sa.String(), nullable=False), - sa.Column('expense_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('split_type', sa.Enum('EQUAL', 'EXACT_AMOUNTS', 'PERCENTAGE', 'SHARES', 'ITEM_BASED', name='splittypeenum'), nullable=False), - sa.Column('list_id', sa.Integer(), nullable=True), - sa.Column('group_id', sa.Integer(), nullable=True), - sa.Column('item_id', sa.Integer(), nullable=True), - sa.Column('paid_by_user_id', sa.Integer(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('version', sa.Integer(), server_default='1', nullable=False), - sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ), - sa.ForeignKeyConstraint(['item_id'], ['items.id'], ), - sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ), - sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_index(op.f('ix_expenses_id'), 'expenses', ['id'], unique=False) - op.create_table('expense_splits', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('expense_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('owed_amount', sa.Numeric(precision=10, scale=2), nullable=False), - sa.Column('share_percentage', sa.Numeric(precision=5, scale=2), nullable=True), - sa.Column('share_units', sa.Integer(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['expense_id'], ['expenses.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split') - ) - op.create_index(op.f('ix_expense_splits_id'), 'expense_splits', ['id'], unique=False) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f('ix_expense_splits_id'), table_name='expense_splits') - op.drop_table('expense_splits') - op.drop_index(op.f('ix_expenses_id'), table_name='expenses') - op.drop_table('expenses') - op.drop_index(op.f('ix_items_name'), table_name='items') - op.drop_index(op.f('ix_items_id'), table_name='items') - op.drop_table('items') - op.drop_index(op.f('ix_user_groups_id'), table_name='user_groups') - op.drop_table('user_groups') - op.drop_index(op.f('ix_settlements_id'), table_name='settlements') - op.drop_table('settlements') - op.drop_index(op.f('ix_lists_name'), table_name='lists') - op.drop_index(op.f('ix_lists_id'), table_name='lists') - op.drop_table('lists') - op.drop_index(op.f('ix_invites_id'), table_name='invites') - op.drop_index(op.f('ix_invites_code'), table_name='invites') - op.drop_index('ix_invites_active_code', table_name='invites', postgresql_where=sa.text('is_active = true')) - op.drop_table('invites') - op.drop_index(op.f('ix_groups_name'), table_name='groups') - op.drop_index(op.f('ix_groups_id'), table_name='groups') - op.drop_table('groups') - op.drop_index(op.f('ix_users_name'), table_name='users') - op.drop_index(op.f('ix_users_id'), table_name='users') - op.drop_index(op.f('ix_users_email'), table_name='users') - op.drop_table('users') - # ### end Alembic commands ### - - - -# app/api/api_router.py -from fastapi import APIRouter - -from app.api.v1.api import api_router_v1 # Import the v1 router - -api_router = APIRouter() - -# Include versioned routers here, adding the /api prefix -api_router.include_router(api_router_v1, prefix="/v1") # Mounts v1 endpoints under /api/v1/... - -# Add other API versions later -# e.g., api_router.include_router(api_router_v2, prefix="/v2") - - - -# app/api/v1/endpoints/financials.py -import logging -from fastapi import APIRouter, Depends, HTTPException, status, Query, Response -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select -from typing import List as PyList, Optional, Sequence - -from app.database import get_db -from app.api.dependencies import get_current_user -from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum -from app.schemas.expense import ( - ExpenseCreate, ExpensePublic, - SettlementCreate, SettlementPublic, - ExpenseUpdate, SettlementUpdate -) -from app.crud import expense as crud_expense -from app.crud import settlement as crud_settlement -from app.crud import group as crud_group -from app.crud import list as crud_list -from app.core.exceptions import ( - ListNotFoundError, GroupNotFoundError, UserNotFoundError, - InvalidOperationError, GroupPermissionError, ListPermissionError, - ItemNotFoundError, GroupMembershipError -) - -logger = logging.getLogger(__name__) -router = APIRouter() - -# --- Helper for permissions --- -async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"): - try: - await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_member=True) - except ListPermissionError as e: - logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}") - raise ListPermissionError(list_id, action=action) - except ListNotFoundError: - raise - -# --- Expense Endpoints --- -@router.post( - "/expenses", - response_model=ExpensePublic, - status_code=status.HTTP_201_CREATED, - summary="Create New Expense", - tags=["Expenses"] -) -async def create_new_expense( - expense_in: ExpenseCreate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} creating expense: {expense_in.description}") - effective_group_id = expense_in.group_id - is_group_context = False - - if expense_in.list_id: - # Check basic access to list (implies membership if list is in group) - await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for") - list_obj = await db.get(ListModel, expense_in.list_id) - if not list_obj: - raise ListNotFoundError(expense_in.list_id) - if list_obj.group_id: - if expense_in.group_id and list_obj.group_id != expense_in.group_id: - raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.") - effective_group_id = list_obj.group_id - is_group_context = True # Expense is tied to a group via the list - elif expense_in.group_id: - raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.") - # If list is personal, no group check needed yet, handled by payer check below. - - elif effective_group_id: # Only group_id provided for expense - is_group_context = True - # Ensure user is at least a member to create expense in group context - await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for") - else: - # This case should ideally be caught by earlier checks if list_id was present but list was personal. - # If somehow reached, it means no list_id and no group_id. - raise InvalidOperationError("Expense must be linked to a list_id or group_id.") - - # Finalize expense payload with correct group_id if derived - expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id}) - - # --- Granular Permission Check for Payer --- - if expense_in_final.paid_by_user_id != current_user.id: - logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}") - # If creating expense paid by someone else, user MUST be owner IF in group context - if is_group_context and effective_group_id: - try: - await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user") - except GroupPermissionError as e: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}") - else: - # Cannot create expense paid by someone else for a personal list (no group context) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.") - # If paying for self, basic list/group membership check above is sufficient. - - try: - created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id) - logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.") - return created_expense - except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - except NotImplementedError as e: - raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - -@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"]) -async def get_expense( - expense_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} requesting expense ID {expense_id}") - expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id) - if not expense: - raise ItemNotFoundError(item_id=expense_id) - - if expense.list_id: - await check_list_access_for_financials(db, expense.list_id, current_user.id) - elif expense.group_id: - await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id) - elif expense.paid_by_user_id != current_user.id: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense") - return expense - -@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a List", tags=["Expenses", "Lists"]) -async def list_list_expenses( - list_id: int, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=200), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} listing expenses for list ID {list_id}") - await check_list_access_for_financials(db, list_id, current_user.id) - expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit) - return expenses - -@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"]) -async def list_group_expenses( - group_id: int, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=200), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} listing expenses for group ID {group_id}") - await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for") - expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit) - return expenses - -@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"]) -async def update_expense_details( - expense_id: int, - expense_in: ExpenseUpdate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Updates an existing expense (description, currency, expense_date only). - Requires the current version number for optimistic locking. - User must have permission to modify the expense (e.g., be the payer or group admin). - """ - logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})") - expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id) - if not expense_db: - raise ItemNotFoundError(item_id=expense_id) - - # --- Granular Permission Check --- - can_modify = False - # 1. User paid for the expense - if expense_db.paid_by_user_id == current_user.id: - can_modify = True - # 2. OR User is owner of the group the expense belongs to - elif expense_db.group_id: - try: - await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses") - can_modify = True - logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}") - except GroupMembershipError: # User not even a member - pass # Keep can_modify as False - except GroupPermissionError: # User is member but not owner - pass # Keep can_modify as False - except GroupNotFoundError: # Group doesn't exist (data integrity issue) - logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.") - pass # Keep can_modify as False - # Note: If expense is only linked to a personal list (no group), only payer can modify. - - if not can_modify: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)") - - try: - updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in) - logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.") - return updated_expense - except InvalidOperationError as e: - # Check if it's a version conflict (409) or other validation error (400) - status_code = status.HTTP_400_BAD_REQUEST - if "version" in str(e).lower(): - status_code = status.HTTP_409_CONFLICT - raise HTTPException(status_code=status_code, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - -@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"]) -async def delete_expense_record( - expense_id: int, - expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Deletes an expense and its associated splits. - Requires expected_version query parameter for optimistic locking. - User must have permission to delete the expense (e.g., be the payer or group admin). - """ - logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})") - expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id) - if not expense_db: - # Return 204 even if not found, as the end state is achieved (item is gone) - logger.warning(f"Attempt to delete non-existent expense ID {expense_id}") - return Response(status_code=status.HTTP_204_NO_CONTENT) - # Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404 - - # --- Granular Permission Check --- - can_delete = False - # 1. User paid for the expense - if expense_db.paid_by_user_id == current_user.id: - can_delete = True - # 2. OR User is owner of the group the expense belongs to - elif expense_db.group_id: - try: - await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses") - can_delete = True - logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}") - except GroupMembershipError: - pass - except GroupPermissionError: - pass - except GroupNotFoundError: - logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.") - pass - - if not can_delete: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)") - - try: - await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version) - logger.info(f"Expense ID {expense_id} deleted successfully.") - # No need to return content on 204 - except InvalidOperationError as e: - # Check if it's a version conflict (409) or other validation error (400) - status_code = status.HTTP_400_BAD_REQUEST - if "version" in str(e).lower(): - status_code = status.HTTP_409_CONFLICT - raise HTTPException(status_code=status_code, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - - return Response(status_code=status.HTTP_204_NO_CONTENT) - -# --- Settlement Endpoints --- -@router.post( - "/settlements", - response_model=SettlementPublic, - status_code=status.HTTP_201_CREATED, - summary="Record New Settlement", - tags=["Settlements"] -) -async def create_new_settlement( - settlement_in: SettlementCreate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}") - await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in") - try: - await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement") - await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement") - except GroupMembershipError as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}") - except GroupNotFoundError as e: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - - try: - created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id) - logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.") - return created_settlement - except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - -@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"]) -async def get_settlement( - settlement_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}") - settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id) - if not settlement: - raise ItemNotFoundError(item_id=settlement_id) - - is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id] - try: - await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id) - except GroupMembershipError: - if not is_party_to_settlement: - raise GroupMembershipError(settlement.group_id, action="view this settlement's details") - logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.") - return settlement - -@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"]) -async def list_group_settlements( - group_id: int, - skip: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=200), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - logger.info(f"User {current_user.email} listing settlements for group ID {group_id}") - await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group") - settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit) - return settlements - -@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"]) -async def update_settlement_details( - settlement_id: int, - settlement_in: SettlementUpdate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Updates an existing settlement (description, settlement_date only). - Requires the current version number for optimistic locking. - User must have permission (e.g., be involved party or group admin). - """ - logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})") - settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id) - if not settlement_db: - raise ItemNotFoundError(item_id=settlement_id) - - # --- Granular Permission Check --- - can_modify = False - # 1. User is involved party (payer or payee) - is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id] - if is_party: - can_modify = True - # 2. OR User is owner of the group the settlement belongs to - # Note: Settlements always have a group_id based on current model - elif settlement_db.group_id: - try: - await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements") - can_modify = True - logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}") - except GroupMembershipError: - pass - except GroupPermissionError: - pass - except GroupNotFoundError: - logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.") - pass - - if not can_modify: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)") - - try: - updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in) - logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.") - return updated_settlement - except InvalidOperationError as e: - status_code = status.HTTP_400_BAD_REQUEST - if "version" in str(e).lower(): - status_code = status.HTTP_409_CONFLICT - raise HTTPException(status_code=status_code, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - -@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"]) -async def delete_settlement_record( - settlement_id: int, - expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Deletes a settlement. - Requires expected_version query parameter for optimistic locking. - User must have permission (e.g., be involved party or group admin). - """ - logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})") - settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id) - if not settlement_db: - logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}") - return Response(status_code=status.HTTP_204_NO_CONTENT) - - # --- Granular Permission Check --- - can_delete = False - # 1. User is involved party (payer or payee) - is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id] - if is_party: - can_delete = True - # 2. OR User is owner of the group the settlement belongs to - elif settlement_db.group_id: - try: - await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements") - can_delete = True - logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}") - except GroupMembershipError: - pass - except GroupPermissionError: - pass - except GroupNotFoundError: - logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.") - pass - - if not can_delete: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)") - - try: - await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version) - logger.info(f"Settlement ID {settlement_id} deleted successfully.") - except InvalidOperationError as e: - status_code = status.HTTP_400_BAD_REQUEST - if "version" in str(e).lower(): - status_code = status.HTTP_409_CONFLICT - raise HTTPException(status_code=status_code, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") - - return Response(status_code=status.HTTP_204_NO_CONTENT) - -# TODO (remaining from original list): -# (None - GET/POST/PUT/DELETE implemented for Expense/Settlement) - - - -# app/api/v1/endpoints/users.py -import logging -from fastapi import APIRouter, Depends, HTTPException - -from app.api.dependencies import get_current_user # Import the dependency -from app.schemas.user import UserPublic # Import the response schema -from app.models import User as UserModel # Import the DB model for type hinting - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.get( - "/me", - response_model=UserPublic, # Use the public schema to avoid exposing hash - summary="Get Current User", - description="Retrieves the details of the currently authenticated user.", - tags=["Users"] -) -async def read_users_me( - current_user: UserModel = Depends(get_current_user) # Apply the dependency -): - """ - Returns the data for the user associated with the current valid access token. - """ - logger.info(f"Fetching details for current user: {current_user.email}") - # The 'current_user' object is the SQLAlchemy model instance returned by the dependency. - # Pydantic's response_model will automatically convert it using UserPublic schema. - return current_user - -# Add other user-related endpoints here later (e.g., update user, list users (admin)) - - - -# Example: be/tests/api/v1/test_auth.py -import pytest -from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.security import verify_password -from app.crud.user import get_user_by_email -from app.schemas.user import UserPublic # Import for response validation -from app.schemas.auth import Token # Import for response validation - -pytestmark = pytest.mark.asyncio - -async def test_signup_success(client: AsyncClient, db: AsyncSession): - email = "testsignup@example.com" - password = "testpassword123" - response = await client.post( - "/api/v1/auth/signup", - json={"email": email, "password": password, "name": "Test Signup"}, - ) - assert response.status_code == 201 - data = response.json() - assert data["email"] == email - assert data["name"] == "Test Signup" - assert "id" in data - assert "created_at" in data - # Verify password hash is NOT returned - assert "password" not in data - assert "hashed_password" not in data - - # Verify user exists in DB - user_db = await get_user_by_email(db, email=email) - assert user_db is not None - assert user_db.email == email - assert verify_password(password, user_db.hashed_password) - - -async def test_signup_email_exists(client: AsyncClient, db: AsyncSession): - # Create user first - email = "testexists@example.com" - password = "testpassword123" - await client.post( - "/api/v1/auth/signup", - json={"email": email, "password": password}, - ) - - # Attempt signup again with same email - response = await client.post( - "/api/v1/auth/signup", - json={"email": email, "password": "anotherpassword"}, - ) - assert response.status_code == 400 - assert "Email already registered" in response.json()["detail"] - - -async def test_login_success(client: AsyncClient, db: AsyncSession): - email = "testlogin@example.com" - password = "testpassword123" - # Create user first via signup - signup_res = await client.post( - "/api/v1/auth/signup", json={"email": email, "password": password} - ) - assert signup_res.status_code == 201 - - # Attempt login - login_payload = {"username": email, "password": password} - response = await client.post("/api/v1/auth/login", data=login_payload) # Use data for form encoding - assert response.status_code == 200 - data = response.json() - assert "access_token" in data - assert data["token_type"] == "bearer" - # Optionally verify the token itself here using verify_access_token - - -async def test_login_wrong_password(client: AsyncClient, db: AsyncSession): - email = "testloginwrong@example.com" - password = "testpassword123" - await client.post( - "/api/v1/auth/signup", json={"email": email, "password": password} - ) - - login_payload = {"username": email, "password": "wrongpassword"} - response = await client.post("/api/v1/auth/login", data=login_payload) - assert response.status_code == 401 - assert "Incorrect email or password" in response.json()["detail"] - assert "WWW-Authenticate" in response.headers - assert response.headers["WWW-Authenticate"] == "Bearer" - -async def test_login_user_not_found(client: AsyncClient, db: AsyncSession): - login_payload = {"username": "nosuchuser@example.com", "password": "anypassword"} - response = await client.post("/api/v1/auth/login", data=login_payload) - assert response.status_code == 401 - assert "Incorrect email or password" in response.json()["detail"] - - - -# Example: be/tests/api/v1/test_users.py -import pytest -from httpx import AsyncClient - -from app.schemas.user import UserPublic # For response validation -from app.core.security import create_access_token - -pytestmark = pytest.mark.asyncio - -# Helper function to get a valid token -async def get_auth_headers(client: AsyncClient, email: str, password: str) -> dict: - """Logs in a user and returns authorization headers.""" - login_payload = {"username": email, "password": password} - response = await client.post("/api/v1/auth/login", data=login_payload) - response.raise_for_status() # Raise exception for non-2xx status - token_data = response.json() - return {"Authorization": f"Bearer {token_data['access_token']}"} - -async def test_read_users_me_success(client: AsyncClient): - # 1. Create user - email = "testme@example.com" - password = "password123" - signup_res = await client.post( - "/api/v1/auth/signup", json={"email": email, "password": password, "name": "Test Me"} - ) - assert signup_res.status_code == 201 - user_data = UserPublic(**signup_res.json()) # Validate signup response - - # 2. Get token - headers = await get_auth_headers(client, email, password) - - # 3. Request /users/me - response = await client.get("/api/v1/users/me", headers=headers) - assert response.status_code == 200 - me_data = response.json() - assert me_data["email"] == email - assert me_data["name"] == "Test Me" - assert me_data["id"] == user_data.id # Check ID matches signup - assert "password" not in me_data - assert "hashed_password" not in me_data - - -async def test_read_users_me_no_token(client: AsyncClient): - response = await client.get("/api/v1/users/me") # No headers - assert response.status_code == 401 # Handled by OAuth2PasswordBearer - assert response.json()["detail"] == "Not authenticated" # Default detail from OAuth2PasswordBearer - -async def test_read_users_me_invalid_token(client: AsyncClient): - headers = {"Authorization": "Bearer invalid-token-string"} - response = await client.get("/api/v1/users/me", headers=headers) - assert response.status_code == 401 - assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency - -async def test_read_users_me_expired_token(client: AsyncClient): - # Create a short-lived token manually (or adjust settings temporarily) - email = "testexpired@example.com" - # Assume create_access_token allows timedelta override - expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10)) - headers = {"Authorization": f"Bearer {expired_token}"} - - response = await client.get("/api/v1/users/me", headers=headers) - assert response.status_code == 401 - assert response.json()["detail"] == "Could not validate credentials" - -# Add test case for valid token but user deleted from DB if needed - - - -from typing import Dict, Any -from app.config import settings - -# API Version -API_VERSION = "v1" - -# API Prefix -API_PREFIX = f"/api/{API_VERSION}" - -# API Endpoints -class APIEndpoints: - # Auth - AUTH = { - "LOGIN": "/auth/login", - "SIGNUP": "/auth/signup", - "REFRESH_TOKEN": "/auth/refresh-token", - } - - # Users - USERS = { - "PROFILE": "/users/profile", - "UPDATE_PROFILE": "/users/profile", - } - - # Lists - LISTS = { - "BASE": "/lists", - "BY_ID": "/lists/{id}", - "ITEMS": "/lists/{list_id}/items", - "ITEM": "/lists/{list_id}/items/{item_id}", - } - - # Groups - GROUPS = { - "BASE": "/groups", - "BY_ID": "/groups/{id}", - "LISTS": "/groups/{group_id}/lists", - "MEMBERS": "/groups/{group_id}/members", - } - - # Invites - INVITES = { - "BASE": "/invites", - "BY_ID": "/invites/{id}", - "ACCEPT": "/invites/{id}/accept", - "DECLINE": "/invites/{id}/decline", - } - - # OCR - OCR = { - "PROCESS": "/ocr/process", - } - - # Financials - FINANCIALS = { - "EXPENSES": "/financials/expenses", - "EXPENSE": "/financials/expenses/{id}", - "SETTLEMENTS": "/financials/settlements", - "SETTLEMENT": "/financials/settlements/{id}", - } - - # Health - HEALTH = { - "CHECK": "/health", - } - -# API Metadata -API_METADATA = { - "title": settings.API_TITLE, - "description": settings.API_DESCRIPTION, - "version": settings.API_VERSION, - "openapi_url": settings.API_OPENAPI_URL, - "docs_url": settings.API_DOCS_URL, - "redoc_url": settings.API_REDOC_URL, -} - -# API Tags -API_TAGS = [ - {"name": "Authentication", "description": "Authentication and authorization endpoints"}, - {"name": "Users", "description": "User management endpoints"}, - {"name": "Lists", "description": "Shopping list management endpoints"}, - {"name": "Groups", "description": "Group management endpoints"}, - {"name": "Invites", "description": "Group invitation management endpoints"}, - {"name": "OCR", "description": "Optical Character Recognition endpoints"}, - {"name": "Financials", "description": "Financial management endpoints"}, - {"name": "Health", "description": "Health check endpoints"}, -] - -# Helper function to get full API URL -def get_api_url(endpoint: str, **kwargs) -> str: - """ - Get the full API URL for an endpoint. - - Args: - endpoint: The endpoint path - **kwargs: Path parameters to format the endpoint - - Returns: - str: The full API URL - """ - formatted_endpoint = endpoint.format(**kwargs) - return f"{API_PREFIX}{formatted_endpoint}" - - - -# app/crud/expense.py -import logging # Add logging import -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload, joinedload -from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation -from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict -from datetime import datetime, timezone # Added timezone - -from app.models import ( - Expense as ExpenseModel, - ExpenseSplit as ExpenseSplitModel, - User as UserModel, - List as ListModel, - Group as GroupModel, - UserGroup as UserGroupModel, - SplitTypeEnum, - Item as ItemModel -) -from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate -from app.core.exceptions import ( - # Using existing specific exceptions where possible - ListNotFoundError, - GroupNotFoundError, - UserNotFoundError, - InvalidOperationError # Import the new exception -) - -# Placeholder for InvalidOperationError if not defined in app.core.exceptions -# This should be a proper HTTPException subclass if used in API layer -# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic -# pass - -logger = logging.getLogger(__name__) # Initialize logger - -async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]: - """ - Determines the list of users an expense should be split amongst. - Priority: Group members (if group_id), then List's group members or creator (if list_id). - Fallback to only the payer if no other context yields users. - """ - users_to_split_with: PyList[UserModel] = [] - processed_user_ids = set() - - async def _add_user(user: Optional[UserModel]): - if user and user.id not in processed_user_ids: - users_to_split_with.append(user) - processed_user_ids.add(user.id) - - if expense_group_id: - group_result = await db.execute( - select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))) - .where(GroupModel.id == expense_group_id) - ) - group = group_result.scalars().first() - if not group: - raise GroupNotFoundError(expense_group_id) - for assoc in group.member_associations: - await _add_user(assoc.user) - - elif expense_list_id: # Only if group_id was not primary context - list_result = await db.execute( - select(ListModel) - .options( - selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))), - selectinload(ListModel.creator) - ) - .where(ListModel.id == expense_list_id) - ) - db_list = list_result.scalars().first() - if not db_list: - raise ListNotFoundError(expense_list_id) - - if db_list.group: - for assoc in db_list.group.member_associations: - await _add_user(assoc.user) - elif db_list.creator: - await _add_user(db_list.creator) - - if not users_to_split_with: - payer_user = await db.get(UserModel, expense_paid_by_user_id) - if not payer_user: - # This should have been caught earlier if paid_by_user_id was validated before calling this helper - raise UserNotFoundError(user_id=expense_paid_by_user_id) - await _add_user(payer_user) - - if not users_to_split_with: - # This should ideally not be reached if payer is always a fallback - raise InvalidOperationError("Could not determine any users for splitting the expense.") - - return users_to_split_with - - -async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel: - """Creates a new expense and its associated splits. - - Args: - db: Database session - expense_in: Expense creation data - current_user_id: ID of the user creating the expense - - Returns: - The created expense with splits - - Raises: - UserNotFoundError: If payer or split users don't exist - ListNotFoundError: If specified list doesn't exist - GroupNotFoundError: If specified group doesn't exist - InvalidOperationError: For various validation failures - """ - # Helper function to round decimals consistently - def round_money(amount: Decimal) -> Decimal: - return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - - # 1. Context Validation - # Validate basic context requirements first - if not expense_in.list_id and not expense_in.group_id: - raise InvalidOperationError("Expense must be associated with a list or a group.") - - # 2. User Validation - payer = await db.get(UserModel, expense_in.paid_by_user_id) - if not payer: - raise UserNotFoundError(user_id=expense_in.paid_by_user_id) - - # 3. List/Group Context Resolution - final_group_id = await _resolve_expense_context(db, expense_in) - - # 4. Create the expense object - db_expense = ExpenseModel( - description=expense_in.description, - total_amount=round_money(expense_in.total_amount), - currency=expense_in.currency or "USD", - expense_date=expense_in.expense_date or datetime.now(timezone.utc), - split_type=expense_in.split_type, - list_id=expense_in.list_id, - group_id=final_group_id, - item_id=expense_in.item_id, - paid_by_user_id=expense_in.paid_by_user_id, - created_by_user_id=current_user_id # Track who created this expense - ) - - # 5. Generate splits based on split type - splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money) - - # 6. Single transaction for expense and all splits - try: - db.add(db_expense) - await db.flush() # Get expense ID without committing - - # Update all splits with the expense ID - for split in splits_to_create: - split.expense_id = db_expense.id - - db.add_all(splits_to_create) - await db.commit() - - except Exception as e: - await db.rollback() - logger.error(f"Failed to save expense: {str(e)}", exc_info=True) - raise InvalidOperationError(f"Failed to save expense: {str(e)}") - - # Refresh to get the splits relationship populated - await db.refresh(db_expense, attribute_names=["splits"]) - return db_expense - - -async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]: - """Resolves and validates the expense's context (list and group). - - Returns the final group_id for the expense after validation. - """ - final_group_id = expense_in.group_id - - # If list_id is provided, validate it and potentially derive group_id - if expense_in.list_id: - list_obj = await db.get(ListModel, expense_in.list_id) - if not list_obj: - raise ListNotFoundError(expense_in.list_id) - - # If list belongs to a group, verify consistency or inherit group_id - if list_obj.group_id: - if expense_in.group_id and list_obj.group_id != expense_in.group_id: - raise InvalidOperationError( - f"List {expense_in.list_id} belongs to group {list_obj.group_id}, " - f"but expense was specified for group {expense_in.group_id}." - ) - final_group_id = list_obj.group_id # Prioritize list's group - - # If only group_id is provided (no list_id), validate group_id - elif final_group_id: - group_obj = await db.get(GroupModel, final_group_id) - if not group_obj: - raise GroupNotFoundError(final_group_id) - - return final_group_id - - -async def _generate_expense_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Generates appropriate expense splits based on split type.""" - - splits_to_create: PyList[ExpenseSplitModel] = [] - - # Create splits based on the split type - if expense_in.split_type == SplitTypeEnum.EQUAL: - splits_to_create = await _create_equal_splits( - db, db_expense, expense_in, round_money - ) - - elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS: - splits_to_create = await _create_exact_amount_splits( - db, db_expense, expense_in, round_money - ) - - elif expense_in.split_type == SplitTypeEnum.PERCENTAGE: - splits_to_create = await _create_percentage_splits( - db, db_expense, expense_in, round_money - ) - - elif expense_in.split_type == SplitTypeEnum.SHARES: - splits_to_create = await _create_shares_splits( - db, db_expense, expense_in, round_money - ) - - elif expense_in.split_type == SplitTypeEnum.ITEM_BASED: - splits_to_create = await _create_item_based_splits( - db, db_expense, expense_in, round_money - ) - - else: - raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}") - - if not splits_to_create: - raise InvalidOperationError("No expense splits were generated.") - - return splits_to_create - - -async def _create_equal_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Creates equal splits among users.""" - - users_for_splitting = await get_users_for_splitting( - db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id - ) - if not users_for_splitting: - raise InvalidOperationError("No users found for EQUAL split.") - - num_users = len(users_for_splitting) - amount_per_user = round_money(db_expense.total_amount / Decimal(num_users)) - remainder = db_expense.total_amount - (amount_per_user * num_users) - - splits = [] - for i, user in enumerate(users_for_splitting): - split_amount = amount_per_user - if i == 0 and remainder != Decimal('0'): - split_amount = round_money(amount_per_user + remainder) - - splits.append(ExpenseSplitModel( - user_id=user.id, - owed_amount=split_amount - )) - - return splits - - -async def _create_exact_amount_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Creates splits with exact amounts.""" - - if not expense_in.splits_in: - raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.") - - # Validate all users in splits exist - await _validate_users_in_splits(db, expense_in.splits_in) - - current_total = Decimal("0.00") - splits = [] - - for split_in in expense_in.splits_in: - if split_in.owed_amount <= Decimal('0'): - raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.") - - rounded_amount = round_money(split_in.owed_amount) - current_total += rounded_amount - - splits.append(ExpenseSplitModel( - user_id=split_in.user_id, - owed_amount=rounded_amount - )) - - if round_money(current_total) != db_expense.total_amount: - raise InvalidOperationError( - f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})." - ) - - return splits - - -async def _create_percentage_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Creates splits based on percentages.""" - - if not expense_in.splits_in: - raise InvalidOperationError("Splits data is required for PERCENTAGE split type.") - - # Validate all users in splits exist - await _validate_users_in_splits(db, expense_in.splits_in) - - total_percentage = Decimal("0.00") - current_total = Decimal("0.00") - splits = [] - - for split_in in expense_in.splits_in: - if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")): - raise InvalidOperationError( - f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}." - ) - - total_percentage += split_in.share_percentage - owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100"))) - current_total += owed_amount - - splits.append(ExpenseSplitModel( - user_id=split_in.user_id, - owed_amount=owed_amount, - share_percentage=split_in.share_percentage - )) - - if round_money(total_percentage) != Decimal("100.00"): - raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.") - - # Adjust for rounding differences - if current_total != db_expense.total_amount and splits: - diff = db_expense.total_amount - current_total - splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) - - return splits - - -async def _create_shares_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Creates splits based on shares.""" - - if not expense_in.splits_in: - raise InvalidOperationError("Splits data is required for SHARES split type.") - - # Validate all users in splits exist - await _validate_users_in_splits(db, expense_in.splits_in) - - # Calculate total shares - total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0) - if total_shares == 0: - raise InvalidOperationError("Total shares cannot be zero for SHARES split.") - - splits = [] - current_total = Decimal("0.00") - - for split_in in expense_in.splits_in: - if not (split_in.share_units and split_in.share_units > 0): - raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.") - - share_ratio = Decimal(split_in.share_units) / Decimal(total_shares) - owed_amount = round_money(db_expense.total_amount * share_ratio) - current_total += owed_amount - - splits.append(ExpenseSplitModel( - user_id=split_in.user_id, - owed_amount=owed_amount, - share_units=split_in.share_units - )) - - # Adjust for rounding differences - if current_total != db_expense.total_amount and splits: - diff = db_expense.total_amount - current_total - splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) - - return splits - - -async def _create_item_based_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: - """Creates splits based on items in a shopping list.""" - - if not expense_in.list_id: - raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.") - - if expense_in.splits_in: - logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.") - - # Build query to fetch relevant items - items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id) - if expense_in.item_id: - items_query = items_query.where(ItemModel.id == expense_in.item_id) - else: - items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))) - - # Load items with their adders - items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user))) - relevant_items = items_result.scalars().all() - - if not relevant_items: - error_msg = ( - f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}." - if expense_in.item_id else - f"List {expense_in.list_id} has no priced items to base the expense on." - ) - raise InvalidOperationError(error_msg) - - # Aggregate owed amounts by user - calculated_total = Decimal("0.00") - user_owed_amounts = defaultdict(Decimal) - processed_items = 0 - - for item in relevant_items: - if item.price is None or item.price <= Decimal("0"): - if expense_in.item_id: - raise InvalidOperationError( - f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense." - ) - continue - - if not item.added_by_user: - logger.error(f"Item ID {item.id} is missing added_by_user relationship.") - raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.") - - calculated_total += item.price - user_owed_amounts[item.added_by_user.id] += item.price - processed_items += 1 - - if processed_items == 0: - raise InvalidOperationError( - f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense." - ) - - # Validate total matches calculated total - if round_money(calculated_total) != db_expense.total_amount: - raise InvalidOperationError( - f"Expense total amount ({db_expense.total_amount}) does not match the " - f"calculated total from item prices ({calculated_total})." - ) - - # Create splits based on aggregated amounts - splits = [] - for user_id, owed_amount in user_owed_amounts.items(): - splits.append(ExpenseSplitModel( - user_id=user_id, - owed_amount=round_money(owed_amount) - )) - - return splits - - -async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None: - """Validates that all users in the splits exist.""" - - user_ids_in_split = [s.user_id for s in splits_in] - user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split))) - found_user_ids = {row[0] for row in user_results} - - if len(found_user_ids) != len(user_ids_in_split): - missing_user_ids = set(user_ids_in_split) - found_user_ids - raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}") - - -async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]: - result = await db.execute( - select(ExpenseModel) - .options( - selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)), - selectinload(ExpenseModel.paid_by_user), - selectinload(ExpenseModel.list), - selectinload(ExpenseModel.group), - selectinload(ExpenseModel.item) - ) - .where(ExpenseModel.id == expense_id) - ) - return result.scalars().first() - -async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]: - result = await db.execute( - select(ExpenseModel) - .where(ExpenseModel.list_id == list_id) - .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) # Also load user for each split - ) - return result.scalars().all() - -async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]: - result = await db.execute( - select(ExpenseModel) - .where(ExpenseModel.group_id == group_id) - .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) - ) - return result.scalars().all() - -async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel: - """ - Updates an existing expense. - Only allows updates to description, currency, and expense_date to avoid split complexities. - Requires version matching for optimistic locking. - """ - if expense_db.version != expense_in.version: - raise InvalidOperationError( - f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. " - f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.", - # status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set - ) - - update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data - - # Fields that are safe to update without affecting splits or core logic - allowed_to_update = {"description", "currency", "expense_date"} - - updated_something = False - for field, value in update_data.items(): - if field in allowed_to_update: - setattr(expense_db, field, value) - updated_something = True - else: - # If any other field is present in the update payload, it's an invalid operation for this simple update - raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.") - - if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update): - # No actual updatable fields were provided in the payload, even if others (like version) were. - # This could be a non-issue, or an indication of a misuse of the endpoint. - # For now, if only version was sent, we still increment if it matched. - pass # Or raise InvalidOperationError("No updatable fields provided.") - - expense_db.version += 1 - expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp - - try: - await db.commit() - await db.refresh(expense_db) - except Exception as e: - await db.rollback() - # Consider specific DB error types if needed - raise InvalidOperationError(f"Failed to update expense: {str(e)}") - - return expense_db - -async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None: - """ - Deletes an expense. Requires version matching if expected_version is provided. - Associated ExpenseSplits are cascade deleted by the database foreign key constraint. - """ - if expected_version is not None and expense_db.version != expected_version: - raise InvalidOperationError( - f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. " - f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.", - # status_code=status.HTTP_409_CONFLICT - ) - - await db.delete(expense_db) - try: - await db.commit() - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to delete expense: {str(e)}") - return None - -# Note: The InvalidOperationError is a simple ValueError placeholder. -# For API endpoints, these should be translated to appropriate HTTPExceptions. -# Ensure app.core.exceptions has proper HTTP error classes if needed. - - - -# app/crud/invite.py -import secrets -from datetime import datetime, timedelta, timezone -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy import delete # Import delete statement -from typing import Optional - -from app.models import Invite as InviteModel - -# Invite codes should be reasonably unique, but handle potential collision -MAX_CODE_GENERATION_ATTEMPTS = 5 - -async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]: - """Creates a new invite code for a group.""" - expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days) - code = None - attempts = 0 - - # Generate a unique code, retrying if a collision occurs (highly unlikely but safe) - while attempts < MAX_CODE_GENERATION_ATTEMPTS: - attempts += 1 - potential_code = secrets.token_urlsafe(16) - # Check if an *active* invite with this code already exists - existing = await db.execute( - select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) - ) - if existing.scalar_one_or_none() is None: - code = potential_code - break - - if code is None: - # Failed to generate a unique code after several attempts - return None - - db_invite = InviteModel( - code=code, - group_id=group_id, - created_by_id=creator_id, - expires_at=expires_at, - is_active=True - ) - db.add(db_invite) - await db.commit() - await db.refresh(db_invite) - return db_invite - -async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]: - """Gets an active and non-expired invite by its code.""" - now = datetime.now(timezone.utc) - result = await db.execute( - select(InviteModel).where( - InviteModel.code == code, - InviteModel.is_active == True, - InviteModel.expires_at > now - ) - ) - return result.scalars().first() - -async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel: - """Marks an invite as inactive (used).""" - invite.is_active = False - db.add(invite) # Add to session to track change - await db.commit() - await db.refresh(invite) - return invite - -# Optional: Function to periodically delete old, inactive invites -# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ... - - - -# app/crud/settlement.py -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload, joinedload -from sqlalchemy import or_ -from decimal import Decimal -from typing import List as PyList, Optional, Sequence -from datetime import datetime, timezone - -from app.models import ( - Settlement as SettlementModel, - User as UserModel, - Group as GroupModel -) -from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet -from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError - -async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel: - """Creates a new settlement record.""" - - # Validate Payer, Payee, and Group exist - payer = await db.get(UserModel, settlement_in.paid_by_user_id) - if not payer: - raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer") - - payee = await db.get(UserModel, settlement_in.paid_to_user_id) - if not payee: - raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee") - - if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id: - raise InvalidOperationError("Payer and Payee cannot be the same user.") - - group = await db.get(GroupModel, settlement_in.group_id) - if not group: - raise GroupNotFoundError(settlement_in.group_id) - - # Optional: Check if current_user_id is part of the group or is one of the parties involved - # This is more of an API-level permission check but could be added here if strict. - # For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]: - # is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id)) - # if not is_in_group.first(): - # raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.") - - db_settlement = SettlementModel( - group_id=settlement_in.group_id, - paid_by_user_id=settlement_in.paid_by_user_id, - paid_to_user_id=settlement_in.paid_to_user_id, - amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc), - description=settlement_in.description - # created_by_user_id = current_user_id # Optional: Who recorded this settlement - ) - db.add(db_settlement) - try: - await db.commit() - await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"]) - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to save settlement: {str(e)}") - - return db_settlement - -async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]: - result = await db.execute( - select(SettlementModel) - .options( - selectinload(SettlementModel.payer), - selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) - ) - .where(SettlementModel.id == settlement_id) - ) - return result.scalars().first() - -async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]: - result = await db.execute( - select(SettlementModel) - .where(SettlementModel.group_id == group_id) - .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee)) - ) - return result.scalars().all() - -async def get_settlements_involving_user( - db: AsyncSession, - user_id: int, - group_id: Optional[int] = None, - skip: int = 0, - limit: int = 100 -) -> Sequence[SettlementModel]: - query = ( - select(SettlementModel) - .where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id)) - .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group)) - ) - if group_id: - query = query.where(SettlementModel.group_id == group_id) - - result = await db.execute(query) - return result.scalars().all() - -async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel: - """ - Updates an existing settlement. - Only allows updates to description and settlement_date. - Requires version matching for optimistic locking. - Assumes SettlementUpdate schema includes a version field. - """ - # Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently. - if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version: - raise InvalidOperationError( - f"Settlement (ID: {settlement_db.id}) has been modified. " - f"Your version does not match current version {settlement_db.version}. Please refresh.", - # status_code=status.HTTP_409_CONFLICT - ) - - update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"}) - allowed_to_update = {"description", "settlement_date"} - updated_something = False - - for field, value in update_data.items(): - if field in allowed_to_update: - setattr(settlement_db, field, value) - updated_something = True - else: - raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.") - - if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update): - pass # No actual updatable fields provided, but version matched. - - settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing - settlement_db.updated_at = datetime.now(timezone.utc) - - try: - await db.commit() - await db.refresh(settlement_db) - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to update settlement: {str(e)}") - - return settlement_db - -async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None: - """ - Deletes a settlement. Requires version matching if expected_version is provided. - Assumes SettlementModel has a version field. - """ - if expected_version is not None: - if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version: - raise InvalidOperationError( - f"Settlement (ID: {settlement_db.id}) cannot be deleted. " - f"Expected version {expected_version} does not match current version. Please refresh.", - # status_code=status.HTTP_409_CONFLICT - ) - - await db.delete(settlement_db) - try: - await db.commit() - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") - return None - -# TODO: Implement update_settlement (consider restrictions, versioning) -# TODO: Implement delete_settlement (consider implications on balances) - - - -# app/database.py -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession -from sqlalchemy.orm import sessionmaker, declarative_base -from app.config import settings - -# Ensure DATABASE_URL is set before proceeding -if not settings.DATABASE_URL: - raise ValueError("DATABASE_URL is not configured in settings.") - -# Create the SQLAlchemy async engine -# pool_recycle=3600 helps prevent stale connections on some DBs -engine = create_async_engine( - settings.DATABASE_URL, - echo=True, # Log SQL queries (useful for debugging) - future=True, # Use SQLAlchemy 2.0 style features - pool_recycle=3600 # Optional: recycle connections after 1 hour -) - -# Create a configured "Session" class -# expire_on_commit=False prevents attributes from expiring after commit -AsyncSessionLocal = sessionmaker( - bind=engine, - class_=AsyncSession, - expire_on_commit=False, - autoflush=False, - autocommit=False, -) - -# Base class for our ORM models -Base = declarative_base() - -# Dependency to get DB session in path operations -async def get_db() -> AsyncSession: # type: ignore - """ - Dependency function that yields an AsyncSession. - Ensures the session is closed after the request. - """ - async with AsyncSessionLocal() as session: - try: - yield session - # Optionally commit if your endpoints modify data directly - # await session.commit() # Usually commit happens within endpoint logic - except Exception: - await session.rollback() - raise - finally: - await session.close() # Not strictly necessary with async context manager, but explicit - - - -# app/schemas/expense.py -from pydantic import BaseModel, ConfigDict, validator -from typing import List, Optional -from decimal import Decimal -from datetime import datetime - -# Assuming SplitTypeEnum is accessible here, e.g., from app.models or app.core.enums -# For now, let's redefine it or import it if models.py is parsable by Pydantic directly -# If it's from app.models, you might need to make app.models.SplitTypeEnum Pydantic-compatible or map it. -# For simplicity during schema definition, I'll redefine a string enum here. -# In a real setup, ensure this aligns with the SQLAlchemy enum in models.py. -from app.models import SplitTypeEnum # Try importing directly - -# --- ExpenseSplit Schemas --- -class ExpenseSplitBase(BaseModel): - user_id: int - owed_amount: Decimal - share_percentage: Optional[Decimal] = None - share_units: Optional[int] = None - -class ExpenseSplitCreate(ExpenseSplitBase): - pass # All fields from base are needed for creation - -class ExpenseSplitPublic(ExpenseSplitBase): - id: int - expense_id: int - # user: Optional[UserPublic] # If we want to nest user details - created_at: datetime - updated_at: datetime - model_config = ConfigDict(from_attributes=True) - -# --- Expense Schemas --- -class ExpenseBase(BaseModel): - description: str - total_amount: Decimal - currency: Optional[str] = "USD" - expense_date: Optional[datetime] = None - split_type: SplitTypeEnum - list_id: Optional[int] = None - group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa - item_id: Optional[int] = None - paid_by_user_id: int - -class ExpenseCreate(ExpenseBase): - # For EQUAL split, splits are generated. For others, they might be provided. - # This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES, - # then 'splits_in' should be provided. - splits_in: Optional[List[ExpenseSplitCreate]] = None - - @validator('total_amount') - def total_amount_must_be_positive(cls, v): - if v <= Decimal('0'): - raise ValueError('Total amount must be positive') - return v - - # Basic validation: if list_id is None, group_id must be provided. - # More complex cross-field validation might be needed. - @validator('group_id', always=True) - def check_list_or_group_id(cls, v, values): - if values.get('list_id') is None and v is None: - raise ValueError('Either list_id or group_id must be provided for an expense') - return v - -class ExpenseUpdate(BaseModel): - description: Optional[str] = None - total_amount: Optional[Decimal] = None - currency: Optional[str] = None - expense_date: Optional[datetime] = None - split_type: Optional[SplitTypeEnum] = None - list_id: Optional[int] = None - group_id: Optional[int] = None - item_id: Optional[int] = None - # paid_by_user_id is usually not updatable directly to maintain integrity. - # Updating splits would be a more complex operation, potentially a separate endpoint or careful logic. - version: int # For optimistic locking - -class ExpensePublic(ExpenseBase): - id: int - created_at: datetime - updated_at: datetime - version: int - splits: List[ExpenseSplitPublic] = [] - # paid_by_user: Optional[UserPublic] # If nesting user details - # list: Optional[ListPublic] # If nesting list details - # group: Optional[GroupPublic] # If nesting group details - # item: Optional[ItemPublic] # If nesting item details - model_config = ConfigDict(from_attributes=True) - -# --- Settlement Schemas --- -class SettlementBase(BaseModel): - group_id: int - paid_by_user_id: int - paid_to_user_id: int - amount: Decimal - settlement_date: Optional[datetime] = None - description: Optional[str] = None - -class SettlementCreate(SettlementBase): - @validator('amount') - def amount_must_be_positive(cls, v): - if v <= Decimal('0'): - raise ValueError('Settlement amount must be positive') - return v - - @validator('paid_to_user_id') - def payer_and_payee_must_be_different(cls, v, values): - if 'paid_by_user_id' in values and v == values['paid_by_user_id']: - raise ValueError('Payer and payee cannot be the same user') - return v - -class SettlementUpdate(BaseModel): - amount: Optional[Decimal] = None - settlement_date: Optional[datetime] = None - description: Optional[str] = None - # group_id, paid_by_user_id, paid_to_user_id are typically not updatable. - version: int # For optimistic locking - -class SettlementPublic(SettlementBase): - id: int - created_at: datetime - updated_at: datetime - # payer: Optional[UserPublic] - # payee: Optional[UserPublic] - # group: Optional[GroupPublic] - model_config = ConfigDict(from_attributes=True) - -# Placeholder for nested schemas (e.g., UserPublic) if needed -# from app.schemas.user import UserPublic -# from app.schemas.list import ListPublic -# from app.schemas.group import GroupPublic -# from app.schemas.item import ItemPublic - - - -# app/schemas/group.py -from pydantic import BaseModel, ConfigDict -from datetime import datetime -from typing import Optional, List - -from .user import UserPublic # Import UserPublic to represent members - -# Properties to receive via API on creation -class GroupCreate(BaseModel): - name: str - -# Properties to return to client -class GroupPublic(BaseModel): - id: int - name: str - created_by_id: int - created_at: datetime - members: Optional[List[UserPublic]] = None # Include members only in detailed view - - model_config = ConfigDict(from_attributes=True) - -# Properties stored in DB (if needed, often GroupPublic is sufficient) -# class GroupInDB(GroupPublic): -# pass - - - -# app/schemas/invite.py -from pydantic import BaseModel -from datetime import datetime - -# Properties to receive when accepting an invite -class InviteAccept(BaseModel): - code: str - -# Properties to return when an invite is created -class InviteCodePublic(BaseModel): - code: str - expires_at: datetime - group_id: int - -# Properties for internal use/DB (optional) -# class Invite(InviteCodePublic): -# id: int -# created_by_id: int -# is_active: bool = True -# model_config = ConfigDict(from_attributes=True) - - - -# app/schemas/message.py -from pydantic import BaseModel - -class Message(BaseModel): - detail: str - - - -# app/schemas/ocr.py -from pydantic import BaseModel -from typing import List - -class OcrExtractResponse(BaseModel): - extracted_items: List[str] # A list of potential item names - - - -import pytest -from fastapi import status -from httpx import AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession -from typing import Callable, Dict, Any - -from app.models import User as UserModel, Group as GroupModel, List as ListModel -from app.schemas.expense import ExpenseCreate -from app.core.config import settings - -# Helper to create a URL for an endpoint -API_V1_STR = settings.API_V1_STR - -def expense_url(endpoint: str = "") -> str: - return f"{API_V1_STR}/financials/expenses{endpoint}" - -def settlement_url(endpoint: str = "") -> str: - return f"{API_V1_STR}/financials/settlements{endpoint}" - -@pytest.mark.asyncio -async def test_create_new_expense_success_list_context( - client: AsyncClient, - db_session: AsyncSession, # Assuming a fixture for db session - normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth - test_user: UserModel, # Assuming a fixture for a test user - test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of -) -> None: - """ - Test successful creation of a new expense linked to a list. - """ - expense_data = ExpenseCreate( - description="Test Expense for List", - amount=100.00, - currency="USD", - paid_by_user_id=test_user.id, - list_id=test_list_user_is_member.id, - group_id=None, # group_id should be derived from list if list is in a group - # category_id: Optional[int] = None # Assuming category is optional - # expense_date: Optional[date] = None # Assuming date is optional - # splits: Optional[List[SplitCreate]] = [] # Assuming splits are optional for now - ) - - response = await client.post( - expense_url(), - headers=normal_user_token_headers, - json=expense_data.model_dump(exclude_unset=True) - ) - - assert response.status_code == status.HTTP_201_CREATED - content = response.json() - assert content["description"] == expense_data.description - assert content["amount"] == expense_data.amount - assert content["currency"] == expense_data.currency - assert content["paid_by_user_id"] == test_user.id - assert content["list_id"] == test_list_user_is_member.id - # If test_list_user_is_member has a group_id, it should be set in the response - if test_list_user_is_member.group_id: - assert content["group_id"] == test_list_user_is_member.group_id - else: - assert content["group_id"] is None - assert "id" in content - assert "created_at" in content - assert "updated_at" in content - assert "version" in content - assert content["version"] == 1 - -@pytest.mark.asyncio -async def test_create_new_expense_success_group_context( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_user: UserModel, - test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of -) -> None: - """ - Test successful creation of a new expense linked directly to a group. - """ - expense_data = ExpenseCreate( - description="Test Expense for Group", - amount=50.00, - currency="EUR", - paid_by_user_id=test_user.id, - group_id=test_group_user_is_member.id, - list_id=None, - ) - - response = await client.post( - expense_url(), - headers=normal_user_token_headers, - json=expense_data.model_dump(exclude_unset=True) - ) - - assert response.status_code == status.HTTP_201_CREATED - content = response.json() - assert content["description"] == expense_data.description - assert content["paid_by_user_id"] == test_user.id - assert content["group_id"] == test_group_user_is_member.id - assert content["list_id"] is None - assert content["version"] == 1 - -@pytest.mark.asyncio -async def test_create_new_expense_fail_no_list_or_group( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_user: UserModel, -) -> None: - """ - Test expense creation fails if neither list_id nor group_id is provided. - """ - expense_data = ExpenseCreate( - description="Test Invalid Expense", - amount=10.00, - currency="USD", - paid_by_user_id=test_user.id, - list_id=None, - group_id=None, - ) - - response = await client.post( - expense_url(), - headers=normal_user_token_headers, - json=expense_data.model_dump(exclude_unset=True) - ) - - assert response.status_code == status.HTTP_400_BAD_REQUEST - content = response.json() - assert "Expense must be linked to a list_id or group_id" in content["detail"] - -@pytest.mark.asyncio -async def test_create_new_expense_fail_paid_by_other_not_owner( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], # User is member, not owner - test_user: UserModel, # This is the current_user (member) - test_group_user_is_member: GroupModel, # Group the current_user is a member of - another_user_in_group: UserModel, # Another user in the same group - # Ensure test_user is NOT an owner of test_group_user_is_member for this test -) -> None: - """ - Test creation fails if paid_by_user_id is another user, and current_user is not a group owner. - Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member. - """ - expense_data = ExpenseCreate( - description="Expense paid by other", - amount=75.00, - currency="GBP", - paid_by_user_id=another_user_in_group.id, # Paid by someone else - group_id=test_group_user_is_member.id, - list_id=None, - ) - - response = await client.post( - expense_url(), - headers=normal_user_token_headers, # Current user is a member, not owner - json=expense_data.model_dump(exclude_unset=True) - ) - - assert response.status_code == status.HTTP_403_FORBIDDEN - content = response.json() - assert "Only group owners can create expenses paid by others" in content["detail"] - -# --- Add tests for other endpoints below --- -# GET /expenses/{expense_id} - -@pytest.mark.asyncio -async def test_get_expense_success( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_user: UserModel, - # Assume an existing expense created by test_user or in a group/list they have access to - # This would typically be created by another test or a fixture - created_expense: ExpensePublic, # Assuming a fixture that provides a created expense -) -> None: - """ - Test successfully retrieving an existing expense. - User has access either by being the payer, or via list/group membership. - """ - response = await client.get( - expense_url(f"/{created_expense.id}"), - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert content["id"] == created_expense.id - assert content["description"] == created_expense.description - assert content["amount"] == created_expense.amount - assert content["paid_by_user_id"] == created_expense.paid_by_user_id - if created_expense.list_id: - assert content["list_id"] == created_expense.list_id - if created_expense.group_id: - assert content["group_id"] == created_expense.group_id - -# TODO: Add more tests for get_expense: -# - expense not found -> 404 -# - user has no access (not payer, not in list, not in group if applicable) -> 403 -# - expense in list, user has list access -# - expense in group, user has group access -# - expense personal (no list, no group), user is payer -# - expense personal (no list, no group), user is NOT payer -> 403 - -@pytest.mark.asyncio -async def test_get_expense_not_found( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], -) -> None: - """ - Test retrieving a non-existent expense results in 404. - """ - non_existent_expense_id = 9999999 - response = await client.get( - expense_url(f"/{non_existent_expense_id}"), - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_404_NOT_FOUND - content = response.json() - assert "not found" in content["detail"].lower() - -@pytest.mark.asyncio -async def test_get_expense_forbidden_personal_expense_other_user( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], # Belongs to test_user - # Fixture for an expense paid by another_user, not linked to any list/group test_user has access to - personal_expense_of_another_user: ExpensePublic -) -> None: - """ - Test retrieving a personal expense of another user (no shared list/group) results in 403. - """ - response = await client.get( - expense_url(f"/{personal_expense_of_another_user.id}"), - headers=normal_user_token_headers # Current user querying - ) - assert response.status_code == status.HTTP_403_FORBIDDEN - content = response.json() - assert "Not authorized to view this expense" in content["detail"] - -# GET /lists/{list_id}/expenses - -@pytest.mark.asyncio -async def test_list_list_expenses_success( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_user: UserModel, - test_list_user_is_member: ListModel, # List the user is a member of - # Assume some expenses have been created for this list by a fixture or previous tests -) -> None: - """ - Test successfully listing expenses for a list the user has access to. - """ - response = await client.get( - f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses", - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert isinstance(content, list) - for expense_item in content: # Renamed from expense to avoid conflict if a fixture is named expense - assert expense_item["list_id"] == test_list_user_is_member.id - -# TODO: Add more tests for list_list_expenses: -# - list not found -> 404 (ListNotFoundError from check_list_access_for_financials) -# - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials) -# - list exists but has no expenses -> empty list, 200 OK -# - test pagination (skip, limit) - -@pytest.mark.asyncio -async def test_list_list_expenses_list_not_found( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], -) -> None: - """ - Test listing expenses for a non-existent list results in 404 (or appropriate error from permission check). - The check_list_access_for_financials raises ListNotFoundError, which might be caught and raised as 404. - The endpoint itself also has a get for ListModel, which would 404 first if permission check passed (not possible here). - Based on financials.py, ListNotFoundError is raised by check_list_access_for_financials. - This should translate to a 404 or a 403 if ListPermissionError wraps it with an action. - The current ListPermissionError in check_list_access_for_financials re-raises ListNotFoundError if that's the cause. - ListNotFoundError is a custom exception often mapped to 404. - Let's assume ListNotFoundError results in a 404 response from an exception handler. - """ - non_existent_list_id = 99999 - response = await client.get( - f"{API_V1_STR}/financials/lists/{non_existent_list_id}/expenses", - headers=normal_user_token_headers - ) - # The ListNotFoundError is raised by the check_list_access_for_financials helper, - # which is then re-raised. FastAPI default exception handlers or custom ones - # would convert this to an HTTP response. Typically NotFoundError -> 404. - # If ListPermissionError catches it and re-raises it specifically, it might be 403. - # From the code: `except ListNotFoundError: raise` means it propagates. - # Let's assume a global handler for NotFoundError derived exceptions leads to 404. - assert response.status_code == status.HTTP_404_NOT_FOUND - # The actual detail might vary based on how ListNotFoundError is handled by FastAPI - # For now, we check the status code. If financials.py maps it differently, this will need adjustment. - # Based on `raise ListNotFoundError(expense_in.list_id)` in create_new_expense, and if that leads to 400, - # this might be inconsistent. However, `check_list_access_for_financials` just re-raises ListNotFoundError. - # Let's stick to expecting 404 for a direct not found error from a path parameter. - content = response.json() - assert "list not found" in content["detail"].lower() # Common detail for not found errors - -@pytest.mark.asyncio -async def test_list_list_expenses_no_access( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], # User who will attempt access - test_list_user_not_member: ListModel, # A list current user is NOT a member of -) -> None: - """ - Test listing expenses for a list the user does not have access to (403 Forbidden). - """ - response = await client.get( - f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses", - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_403_FORBIDDEN - content = response.json() - assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in content["detail"] - -@pytest.mark.asyncio -async def test_list_list_expenses_empty( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_list_user_is_member_no_expenses: ListModel, # List user is member of, but has no expenses -) -> None: - """ - Test listing expenses for an accessible list that has no expenses (empty list, 200 OK). - """ - response = await client.get( - f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses", - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert isinstance(content, list) - assert len(content) == 0 - -# GET /groups/{group_id}/expenses - -@pytest.mark.asyncio -async def test_list_group_expenses_success( - client: AsyncClient, - normal_user_token_headers: Dict[str, str], - test_user: UserModel, - test_group_user_is_member: GroupModel, # Group the user is a member of - # Assume some expenses have been created for this group by a fixture or previous tests -) -> None: - """ - Test successfully listing expenses for a group the user has access to. - """ - response = await client.get( - f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses", - headers=normal_user_token_headers - ) - assert response.status_code == status.HTTP_200_OK - content = response.json() - assert isinstance(content, list) - # Further assertions can be made here, e.g., checking if all expenses belong to the group - for expense_item in content: - assert expense_item["group_id"] == test_group_user_is_member.id - # Expenses in a group might also have a list_id if they were added via a list belonging to that group - -# TODO: Add more tests for list_group_expenses: -# - group not found -> 404 (GroupNotFoundError from check_group_membership) -# - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership) -# - group exists but has no expenses -> empty list, 200 OK -# - test pagination (skip, limit) - -# PUT /expenses/{expense_id} -# DELETE /expenses/{expense_id} - -# GET /settlements/{settlement_id} -# POST /settlements -# GET /groups/{group_id}/settlements -# PUT /settlements/{settlement_id} -# DELETE /settlements/{settlement_id} - -pytest.skip("Still implementing other tests", allow_module_level=True) - - - - - - - -import pytest -from fastapi import status -from app.core.exceptions import ( - ListNotFoundError, - ListPermissionError, - ListCreatorRequiredError, - GroupNotFoundError, - GroupPermissionError, - GroupMembershipError, - GroupOperationError, - GroupValidationError, - ItemNotFoundError, - UserNotFoundError, - InvalidOperationError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseTransactionError, - DatabaseQueryError, - OCRServiceUnavailableError, - OCRServiceConfigError, - OCRUnexpectedError, - OCRQuotaExceededError, - InvalidFileTypeError, - FileTooLargeError, - OCRProcessingError, - EmailAlreadyRegisteredError, - UserCreationError, - InviteNotFoundError, - InviteExpiredError, - InviteAlreadyUsedError, - InviteCreationError, - ListStatusNotFoundError, - ConflictError, - InvalidCredentialsError, - NotAuthenticatedError, - JWTError, - JWTUnexpectedError -) -# TODO: It seems like settings are used in some exceptions. -# You will need to mock app.config.settings for these tests to pass. -# Consider using pytest-mock or unittest.mock.patch. -# Example: from app.config import settings - - -def test_list_not_found_error(): - list_id = 123 - with pytest.raises(ListNotFoundError) as excinfo: - raise ListNotFoundError(list_id=list_id) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"List {list_id} not found" - -def test_list_permission_error(): - list_id = 456 - action = "delete" - with pytest.raises(ListPermissionError) as excinfo: - raise ListPermissionError(list_id=list_id, action=action) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}" - -def test_list_permission_error_default_action(): - list_id = 789 - with pytest.raises(ListPermissionError) as excinfo: - raise ListPermissionError(list_id=list_id) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"You do not have permission to access list {list_id}" - -def test_list_creator_required_error(): - list_id = 101 - action = "update" - with pytest.raises(ListCreatorRequiredError) as excinfo: - raise ListCreatorRequiredError(list_id=list_id, action=action) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}" - -def test_group_not_found_error(): - group_id = 202 - with pytest.raises(GroupNotFoundError) as excinfo: - raise GroupNotFoundError(group_id=group_id) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"Group {group_id} not found" - -def test_group_permission_error(): - group_id = 303 - action = "invite" - with pytest.raises(GroupPermissionError) as excinfo: - raise GroupPermissionError(group_id=group_id, action=action) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"You do not have permission to {action} in group {group_id}" - -def test_group_membership_error(): - group_id = 404 - action = "post" - with pytest.raises(GroupMembershipError) as excinfo: - raise GroupMembershipError(group_id=group_id, action=action) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}" - -def test_group_membership_error_default_action(): - group_id = 505 - with pytest.raises(GroupMembershipError) as excinfo: - raise GroupMembershipError(group_id=group_id) - assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN - assert excinfo.value.detail == f"You must be a member of group {group_id} to access" - -def test_group_operation_error(): - detail_msg = "Failed to perform group operation." - with pytest.raises(GroupOperationError) as excinfo: - raise GroupOperationError(detail=detail_msg) - assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert excinfo.value.detail == detail_msg - -def test_group_validation_error(): - detail_msg = "Invalid group data." - with pytest.raises(GroupValidationError) as excinfo: - raise GroupValidationError(detail=detail_msg) - assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST - assert excinfo.value.detail == detail_msg - -def test_item_not_found_error(): - item_id = 606 - with pytest.raises(ItemNotFoundError) as excinfo: - raise ItemNotFoundError(item_id=item_id) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"Item {item_id} not found" - -def test_user_not_found_error_no_identifier(): - with pytest.raises(UserNotFoundError) as excinfo: - raise UserNotFoundError() - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == "User not found." - -def test_user_not_found_error_with_id(): - user_id = 707 - with pytest.raises(UserNotFoundError) as excinfo: - raise UserNotFoundError(user_id=user_id) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"User with ID {user_id} not found." - -def test_user_not_found_error_with_identifier_string(): - identifier = "test_user" - with pytest.raises(UserNotFoundError) as excinfo: - raise UserNotFoundError(identifier=identifier) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"User with identifier '{identifier}' not found." - -def test_invalid_operation_error(): - detail_msg = "This operation is not allowed." - with pytest.raises(InvalidOperationError) as excinfo: - raise InvalidOperationError(detail=detail_msg) - assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST - assert excinfo.value.detail == detail_msg - -def test_invalid_operation_error_custom_status(): - detail_msg = "This operation is forbidden." - custom_status = status.HTTP_403_FORBIDDEN - with pytest.raises(InvalidOperationError) as excinfo: - raise InvalidOperationError(detail=detail_msg, status_code=custom_status) - assert excinfo.value.status_code == custom_status - assert excinfo.value.detail == detail_msg - -# The following exceptions depend on `settings` -# We need to mock `app.config.settings` for these tests. -# For now, I will add placeholder tests that would fail without mocking. -# Consider using pytest-mock or unittest.mock.patch for this. - -# def test_database_connection_error(mocker): -# mocker.patch("app.core.exceptions.settings.DB_CONNECTION_ERROR", "Test DB connection error") -# with pytest.raises(DatabaseConnectionError) as excinfo: -# raise DatabaseConnectionError() -# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE -# assert excinfo.value.detail == "Test DB connection error" # settings.DB_CONNECTION_ERROR - -# def test_database_integrity_error(mocker): -# mocker.patch("app.core.exceptions.settings.DB_INTEGRITY_ERROR", "Test DB integrity error") -# with pytest.raises(DatabaseIntegrityError) as excinfo: -# raise DatabaseIntegrityError() -# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST -# assert excinfo.value.detail == "Test DB integrity error" # settings.DB_INTEGRITY_ERROR - -# def test_database_transaction_error(mocker): -# mocker.patch("app.core.exceptions.settings.DB_TRANSACTION_ERROR", "Test DB transaction error") -# with pytest.raises(DatabaseTransactionError) as excinfo: -# raise DatabaseTransactionError() -# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR -# assert excinfo.value.detail == "Test DB transaction error" # settings.DB_TRANSACTION_ERROR - -# def test_database_query_error(mocker): -# mocker.patch("app.core.exceptions.settings.DB_QUERY_ERROR", "Test DB query error") -# with pytest.raises(DatabaseQueryError) as excinfo: -# raise DatabaseQueryError() -# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR -# assert excinfo.value.detail == "Test DB query error" # settings.DB_QUERY_ERROR - -# def test_ocr_service_unavailable_error(mocker): -# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_UNAVAILABLE", "Test OCR unavailable") -# with pytest.raises(OCRServiceUnavailableError) as excinfo: -# raise OCRServiceUnavailableError() -# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE -# assert excinfo.value.detail == "Test OCR unavailable" # settings.OCR_SERVICE_UNAVAILABLE - -# def test_ocr_service_config_error(mocker): -# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_CONFIG_ERROR", "Test OCR config error") -# with pytest.raises(OCRServiceConfigError) as excinfo: -# raise OCRServiceConfigError() -# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR -# assert excinfo.value.detail == "Test OCR config error" # settings.OCR_SERVICE_CONFIG_ERROR - -# def test_ocr_unexpected_error(mocker): -# mocker.patch("app.core.exceptions.settings.OCR_UNEXPECTED_ERROR", "Test OCR unexpected error") -# with pytest.raises(OCRUnexpectedError) as excinfo: -# raise OCRUnexpectedError() -# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR -# assert excinfo.value.detail == "Test OCR unexpected error" # settings.OCR_UNEXPECTED_ERROR - -# def test_ocr_quota_exceeded_error(mocker): -# mocker.patch("app.core.exceptions.settings.OCR_QUOTA_EXCEEDED", "Test OCR quota exceeded") -# with pytest.raises(OCRQuotaExceededError) as excinfo: -# raise OCRQuotaExceededError() -# assert excinfo.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS -# assert excinfo.value.detail == "Test OCR quota exceeded" # settings.OCR_QUOTA_EXCEEDED - -# def test_invalid_file_type_error(mocker): -# test_types = ["png", "jpg"] -# mocker.patch("app.core.exceptions.settings.ALLOWED_IMAGE_TYPES", test_types) -# mocker.patch("app.core.exceptions.settings.OCR_INVALID_FILE_TYPE", "Invalid type: {types}") -# with pytest.raises(InvalidFileTypeError) as excinfo: -# raise InvalidFileTypeError() -# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST -# assert excinfo.value.detail == f"Invalid type: {', '.join(test_types)}" # settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES)) - -# def test_file_too_large_error(mocker): -# max_size = 10 -# mocker.patch("app.core.exceptions.settings.MAX_FILE_SIZE_MB", max_size) -# mocker.patch("app.core.exceptions.settings.OCR_FILE_TOO_LARGE", "File too large: {size}MB") -# with pytest.raises(FileTooLargeError) as excinfo: -# raise FileTooLargeError() -# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST -# assert excinfo.value.detail == f"File too large: {max_size}MB" # settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB) - -# def test_ocr_processing_error(mocker): -# error_detail = "Specific OCR error" -# mocker.patch("app.core.exceptions.settings.OCR_PROCESSING_ERROR", "OCR processing failed: {detail}") -# with pytest.raises(OCRProcessingError) as excinfo: -# raise OCRProcessingError(detail=error_detail) -# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST -# assert excinfo.value.detail == f"OCR processing failed: {error_detail}" # settings.OCR_PROCESSING_ERROR.format(detail=detail) - - -def test_email_already_registered_error(): - with pytest.raises(EmailAlreadyRegisteredError) as excinfo: - raise EmailAlreadyRegisteredError() - assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST - assert excinfo.value.detail == "Email already registered." - -def test_user_creation_error(): - with pytest.raises(UserCreationError) as excinfo: - raise UserCreationError() - assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert excinfo.value.detail == "An error occurred during user creation." - -def test_invite_not_found_error(): - invite_code = "TESTCODE123" - with pytest.raises(InviteNotFoundError) as excinfo: - raise InviteNotFoundError(invite_code=invite_code) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"Invite code {invite_code} not found" - -def test_invite_expired_error(): - invite_code = "EXPIREDCODE" - with pytest.raises(InviteExpiredError) as excinfo: - raise InviteExpiredError(invite_code=invite_code) - assert excinfo.value.status_code == status.HTTP_410_GONE - assert excinfo.value.detail == f"Invite code {invite_code} has expired" - -def test_invite_already_used_error(): - invite_code = "USEDCODE" - with pytest.raises(InviteAlreadyUsedError) as excinfo: - raise InviteAlreadyUsedError(invite_code=invite_code) - assert excinfo.value.status_code == status.HTTP_410_GONE - assert excinfo.value.detail == f"Invite code {invite_code} has already been used" - -def test_invite_creation_error(): - group_id = 909 - with pytest.raises(InviteCreationError) as excinfo: - raise InviteCreationError(group_id=group_id) - assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert excinfo.value.detail == f"Failed to create invite for group {group_id}" - -def test_list_status_not_found_error(): - list_id = 808 - with pytest.raises(ListStatusNotFoundError) as excinfo: - raise ListStatusNotFoundError(list_id=list_id) - assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND - assert excinfo.value.detail == f"Status for list {list_id} not found" - -def test_conflict_error(): - detail_msg = "Resource version mismatch." - with pytest.raises(ConflictError) as excinfo: - raise ConflictError(detail=detail_msg) - assert excinfo.value.status_code == status.HTTP_409_CONFLICT - assert excinfo.value.detail == detail_msg - -# Tests for auth-related exceptions that likely require mocking app.config.settings - -# def test_invalid_credentials_error(mocker): -# mocker.patch("app.core.exceptions.settings.AUTH_INVALID_CREDENTIALS", "Invalid test credentials") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") -# with pytest.raises(InvalidCredentialsError) as excinfo: -# raise InvalidCredentialsError() -# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED -# assert excinfo.value.detail == "Invalid test credentials" -# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_credentials\""} - -# def test_not_authenticated_error(mocker): -# mocker.patch("app.core.exceptions.settings.AUTH_NOT_AUTHENTICATED", "Not authenticated test") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") -# with pytest.raises(NotAuthenticatedError) as excinfo: -# raise NotAuthenticatedError() -# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED -# assert excinfo.value.detail == "Not authenticated test" -# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"not_authenticated\""} - -# def test_jwt_error(mocker): -# error_msg = "Test JWT issue" -# mocker.patch("app.core.exceptions.settings.JWT_ERROR", "JWT error: {error}") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") -# with pytest.raises(JWTError) as excinfo: -# raise JWTError(error=error_msg) -# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED -# assert excinfo.value.detail == f"JWT error: {error_msg}" -# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""} - -# def test_jwt_unexpected_error(mocker): -# error_msg = "Unexpected test JWT issue" -# mocker.patch("app.core.exceptions.settings.JWT_UNEXPECTED_ERROR", "Unexpected JWT error: {error}") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth") -# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer") -# with pytest.raises(JWTUnexpectedError) as excinfo: -# raise JWTUnexpectedError(error=error_msg) -# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED -# assert excinfo.value.detail == f"Unexpected JWT error: {error_msg}" -# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""} - - - -import pytest -import asyncio -from unittest.mock import patch, MagicMock, AsyncMock - -import google.generativeai as genai -from google.api_core import exceptions as google_exceptions - -# Modules to test -from app.core import gemini -from app.core.exceptions import ( - OCRServiceUnavailableError, - OCRServiceConfigError, - OCRUnexpectedError, - OCRQuotaExceededError -) - -# Default Mock Settings -@pytest.fixture -def mock_gemini_settings(): - settings_mock = MagicMock() - settings_mock.GEMINI_API_KEY = "test_api_key" - settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision" - settings_mock.GEMINI_SAFETY_SETTINGS = { - "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE", - "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE", - } - settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7} - settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:" - return settings_mock - -@pytest.fixture -def mock_generative_model_instance(): - model_instance = MagicMock(spec=genai.GenerativeModel) - model_instance.generate_content_async = AsyncMock() - return model_instance - -@pytest.fixture -@patch('google.generativeai.GenerativeModel') -@patch('google.generativeai.configure') -def patch_google_ai_client(mock_configure, mock_generative_model, mock_generative_model_instance): - mock_generative_model.return_value = mock_generative_model_instance - return mock_configure, mock_generative_model, mock_generative_model_instance - - -# --- Test Gemini Client Initialization (Global Client) --- - -# Parametrize to test different scenarios for the global client init -@pytest.mark.parametrize( - "api_key_present, configure_raises, model_init_raises, expected_error_message_part", - [ - (True, None, None, None), # Success - (False, None, None, "GEMINI_API_KEY not configured"), # API key missing - (True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"), # genai.configure error - (True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"), # GenerativeModel error - ] -) -@patch('app.core.gemini.genai') # Patch genai within the gemini module -def test_global_gemini_client_initialization( - mock_genai_module, - mock_gemini_settings, - api_key_present, - configure_raises, - model_init_raises, - expected_error_message_part -): - """Tests the global gemini_flash_client initialization logic in app.core.gemini.""" - # We need to reload the module to re-trigger its top-level initialization code. - # This is a bit tricky. A common pattern is to put init logic in a function. - # For now, we'll try to simulate it by controlling mocks before module access. - - with patch('app.core.gemini.settings', mock_gemini_settings): - if not api_key_present: - mock_gemini_settings.GEMINI_API_KEY = None - - mock_genai_module.configure = MagicMock() - mock_genai_module.GenerativeModel = MagicMock() - mock_genai_module.types = genai.types # Keep original types - mock_genai_module.HarmCategory = genai.HarmCategory - mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold - - if configure_raises: - mock_genai_module.configure.side_effect = configure_raises - if model_init_raises: - mock_genai_module.GenerativeModel.side_effect = model_init_raises - - # Python modules are singletons. To re-run top-level code, we need to unload and reload. - # This is generally discouraged. It's better to have an explicit init function. - # For this test, we'll check the state variables set by the module's import-time code. - import importlib - importlib.reload(gemini) # This re-runs the try-except block at the top of gemini.py - - if expected_error_message_part: - assert gemini.gemini_initialization_error is not None - assert expected_error_message_part in gemini.gemini_initialization_error - assert gemini.gemini_flash_client is None - else: - assert gemini.gemini_initialization_error is None - assert gemini.gemini_flash_client is not None - mock_genai_module.configure.assert_called_once_with(api_key="test_api_key") - mock_genai_module.GenerativeModel.assert_called_once() - # Could add more assertions about safety_settings and generation_config here - - # Clean up after reload for other tests - importlib.reload(gemini) - -# --- Test get_gemini_client --- -# Assuming the global client tests above set the stage for these - -@patch('app.core.gemini.gemini_flash_client', new_callable=MagicMock) -@patch('app.core.gemini.gemini_initialization_error', None) -def test_get_gemini_client_success(mock_client_var, mock_error_var): - mock_client_var.return_value = MagicMock(spec=genai.GenerativeModel) # Simulate an initialized client - gemini.gemini_flash_client = mock_client_var # Assign the mock - gemini.gemini_initialization_error = None - client = gemini.get_gemini_client() - assert client is not None - -@patch('app.core.gemini.gemini_flash_client', None) -@patch('app.core.gemini.gemini_initialization_error', "Test init error") -def test_get_gemini_client_init_error(mock_client_var, mock_error_var): - gemini.gemini_flash_client = None - gemini.gemini_initialization_error = "Test init error" - with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"): - gemini.get_gemini_client() - -@patch('app.core.gemini.gemini_flash_client', None) -@patch('app.core.gemini.gemini_initialization_error', None) # No init error, but client is None -def test_get_gemini_client_none_client_unknown_issue(mock_client_var, mock_error_var): - gemini.gemini_flash_client = None - gemini.gemini_initialization_error = None - with pytest.raises(RuntimeError, match="Gemini client is not available \(unknown initialization issue\)."): - gemini.get_gemini_client() - - -# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases) -@pytest.mark.asyncio -async def test_extract_items_from_image_gemini_success( - mock_gemini_settings, - mock_generative_model_instance, - patch_google_ai_client # This fixture patches google.generativeai for the module -): - """ Test successful item extraction """ - # Ensure the global client is mocked to be the one we control - with patch('app.core.gemini.settings', mock_gemini_settings), \ - patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \ - patch('app.core.gemini.gemini_initialization_error', None): - - mock_response = MagicMock() - mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item" - # Simulate the structure for safety checks if needed - mock_candidate = MagicMock() - mock_candidate.content.parts = [MagicMock(text=mock_response.text)] - mock_candidate.finish_reason = 'STOP' # Or whatever is appropriate for success - mock_candidate.safety_ratings = [] - mock_response.candidates = [mock_candidate] - - mock_generative_model_instance.generate_content_async.return_value = mock_response - - image_bytes = b"dummy_image_bytes" - mime_type = "image/png" - - items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type) - - mock_generative_model_instance.generate_content_async.assert_called_once_with([ - mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT, - {"mime_type": mime_type, "data": image_bytes} - ]) - assert items == ["Item 1", "Item 2", "Item 3", "Another Item"] - -@pytest.mark.asyncio -async def test_extract_items_from_image_gemini_client_not_init( - mock_gemini_settings -): - with patch('app.core.gemini.settings', mock_gemini_settings), \ - patch('app.core.gemini.gemini_flash_client', None), \ - patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"): - - image_bytes = b"dummy_image_bytes" - with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"): - await gemini.extract_items_from_image_gemini(image_bytes) - -@pytest.mark.asyncio -@patch('app.core.gemini.get_gemini_client') # Mock the getter to control the client directly -async def test_extract_items_from_image_gemini_api_quota_error( - mock_get_client, - mock_gemini_settings, - mock_generative_model_instance -): - mock_get_client.return_value = mock_generative_model_instance - mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded") - - with patch('app.core.gemini.settings', mock_gemini_settings): - image_bytes = b"dummy_image_bytes" - with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"): - await gemini.extract_items_from_image_gemini(image_bytes) - - -# --- Tests for GeminiOCRService --- (Example tests, more needed) - -@patch('app.core.gemini.genai.configure') -@patch('app.core.gemini.genai.GenerativeModel') -def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance): - MockGenerativeModel.return_value = mock_generative_model_instance - with patch('app.core.gemini.settings', mock_gemini_settings): - service = gemini.GeminiOCRService() - MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY) - MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME) - assert service.model == mock_generative_model_instance - # Could add assertions for safety_settings and generation_config if they are set directly on model - -@patch('app.core.gemini.genai.configure') -@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed")) -def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings): - with patch('app.core.gemini.settings', mock_gemini_settings): - with pytest.raises(OCRServiceConfigError): - gemini.GeminiOCRService() - -@pytest.mark.asyncio -async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance): - mock_response = MagicMock() - mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored" - mock_generative_model_instance.generate_content_async.return_value = mock_response - - with patch('app.core.gemini.settings', mock_gemini_settings): - # Patch the model instance within the service for this test - with patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance) as patched_model_class, - patch.object(genai, 'configure') as patched_configure: - - service = gemini.GeminiOCRService() # Re-init to use the patched model - items = await service.extract_items(b"dummy_image") - - expected_call_args = [ - mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT, - {"mime_type": "image/jpeg", "data": b"dummy_image"} - ] - service.model.generate_content_async.assert_called_once_with(contents=expected_call_args) - assert items == ["Apple", "Banana", "Orange"] - -@pytest.mark.asyncio -async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance): - mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.") - - with patch('app.core.gemini.settings', mock_gemini_settings), \ - patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \ - patch.object(genai, 'configure'): - - service = gemini.GeminiOCRService() - with pytest.raises(OCRQuotaExceededError): - await service.extract_items(b"dummy_image") - -@pytest.mark.asyncio -async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance): - # Simulate a generic API error that isn't quota related - mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable") - - with patch('app.core.gemini.settings', mock_gemini_settings), \ - patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \ - patch.object(genai, 'configure'): - - service = gemini.GeminiOCRService() - with pytest.raises(OCRServiceUnavailableError): - await service.extract_items(b"dummy_image") - -@pytest.mark.asyncio -async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance): - mock_response = MagicMock() - mock_response.text = None # Simulate no text in response - mock_generative_model_instance.generate_content_async.return_value = mock_response - - with patch('app.core.gemini.settings', mock_gemini_settings), \ - patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \ - patch.object(genai, 'configure'): - - service = gemini.GeminiOCRService() - with pytest.raises(OCRUnexpectedError): - await service.extract_items(b"dummy_image") - - - -import pytest -from unittest.mock import patch, MagicMock -from datetime import datetime, timedelta, timezone - -from jose import jwt, JWTError -from passlib.context import CryptContext - -from app.core.security import ( - verify_password, - hash_password, - create_access_token, - create_refresh_token, - verify_access_token, - verify_refresh_token, - pwd_context, # Import for direct testing if needed, or to check its config -) -# Assuming app.config.settings will be mocked -# from app.config import settings - -# --- Tests for Password Hashing --- - -def test_hash_password(): - password = "securepassword123" - hashed = hash_password(password) - assert isinstance(hashed, str) - assert hashed != password - # Check that the default scheme (bcrypt) is used by verifying the hash prefix - # bcrypt hashes typically start with $2b$ or $2a$ or $2y$ - assert hashed.startswith("$2b$") or hashed.startswith("$2a$") or hashed.startswith("$2y$") - -def test_verify_password_correct(): - password = "testpassword" - hashed_password = pwd_context.hash(password) # Use the same context for consistency - assert verify_password(password, hashed_password) is True - -def test_verify_password_incorrect(): - password = "testpassword" - wrong_password = "wrongpassword" - hashed_password = pwd_context.hash(password) - assert verify_password(wrong_password, hashed_password) is False - -def test_verify_password_invalid_hash_format(): - password = "testpassword" - invalid_hash = "notarealhash" - assert verify_password(password, invalid_hash) is False - -# --- Tests for JWT Creation --- -# Mock settings for JWT tests -@pytest.fixture(scope="module") -def mock_jwt_settings(): - mock_settings = MagicMock() - mock_settings.SECRET_KEY = "testsecretkey" - mock_settings.ALGORITHM = "HS256" - mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30 - mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days - return mock_settings - -@patch('app.core.security.settings') -def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES - - subject = "user@example.com" - token = create_access_token(subject) - assert isinstance(token, str) - - decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM]) - assert decoded_payload["sub"] == subject - assert decoded_payload["type"] == "access" - assert "exp" in decoded_payload - # Check if expiry is roughly correct (within a small delta) - expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES) - assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5) - -@patch('app.core.security.settings') -def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - # ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta - - subject = 123 # Subject can be int - custom_delta = timedelta(hours=1) - token = create_access_token(subject, expires_delta=custom_delta) - assert isinstance(token, str) - - decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM]) - assert decoded_payload["sub"] == str(subject) - assert decoded_payload["type"] == "access" - expected_expiry = datetime.now(timezone.utc) + custom_delta - assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5) - -@patch('app.core.security.settings') -def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES - - subject = "refresh_subject" - token = create_refresh_token(subject) - assert isinstance(token, str) - - decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM]) - assert decoded_payload["sub"] == subject - assert decoded_payload["type"] == "refresh" - expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES) - assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5) - -# --- Tests for JWT Verification --- (More tests to be added here) - -@patch('app.core.security.settings') -def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES - - subject = "test_user_valid_access" - token = create_access_token(subject) - payload = verify_access_token(token) - assert payload is not None - assert payload["sub"] == subject - assert payload["type"] == "access" - -@patch('app.core.security.settings') -def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES - - subject = "test_user_invalid_sig" - # Create token with correct key - token = create_access_token(subject) - - # Try to verify with wrong key - mock_settings_global.SECRET_KEY = "wrongsecretkey" - payload = verify_access_token(token) - assert payload is None - -@patch('app.core.security.settings') -@patch('app.core.security.datetime') # Mock datetime to control token expiry -def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute - - # Set current time for token creation - now = datetime.now(timezone.utc) - mock_datetime.now.return_value = now - mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode - mock_datetime.timedelta = timedelta # Ensure original timedelta is used - - subject = "test_user_expired" - token = create_access_token(subject) - - # Advance time beyond expiry for verification - mock_datetime.now.return_value = now + timedelta(minutes=5) - payload = verify_access_token(token) - assert payload is None - -@patch('app.core.security.settings') -def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation - - subject = "test_user_wrong_type" - # Create a refresh token - refresh_token = create_refresh_token(subject) - - # Try to verify it as an access token - payload = verify_access_token(refresh_token) - assert payload is None - -@patch('app.core.security.settings') -def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES - - subject = "test_user_valid_refresh" - token = create_refresh_token(subject) - payload = verify_refresh_token(token) - assert payload is not None - assert payload["sub"] == subject - assert payload["type"] == "refresh" - -@patch('app.core.security.settings') -@patch('app.core.security.datetime') -def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute - - now = datetime.now(timezone.utc) - mock_datetime.now.return_value = now - mock_datetime.fromtimestamp = datetime.fromtimestamp - mock_datetime.timedelta = timedelta - - subject = "test_user_expired_refresh" - token = create_refresh_token(subject) - - mock_datetime.now.return_value = now + timedelta(minutes=5) - payload = verify_refresh_token(token) - assert payload is None - -@patch('app.core.security.settings') -def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings): - mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY - mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM - mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES - - subject = "test_user_wrong_type_refresh" - access_token = create_access_token(subject) - - payload = verify_refresh_token(access_token) - assert payload is None - - - - - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError -from decimal import Decimal, ROUND_HALF_UP -from datetime import datetime, timezone -from typing import List as PyList, Optional - -from app.crud.expense import ( - create_expense, - get_expense_by_id, - get_expenses_for_list, - get_expenses_for_group, - update_expense, # Assuming update_expense exists - delete_expense, # Assuming delete_expense exists - get_users_for_splitting # Helper, might test indirectly -) -from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate -from app.models import ( - Expense as ExpenseModel, - ExpenseSplit as ExpenseSplitModel, - User as UserModel, - List as ListModel, - Group as GroupModel, - UserGroup as UserGroupModel, - Item as ItemModel, - SplitTypeEnum -) -from app.core.exceptions import ( - ListNotFoundError, - GroupNotFoundError, - UserNotFoundError, - InvalidOperationError -) - -# General Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.delete = MagicMock() - session.execute = AsyncMock() - session.get = AsyncMock() - session.flush = AsyncMock() # create_expense uses flush - return session - -@pytest.fixture -def basic_user_model(): - return UserModel(id=1, name="Test User", email="test@example.com") - -@pytest.fixture -def another_user_model(): - return UserModel(id=2, name="Another User", email="another@example.com") - -@pytest.fixture -def basic_group_model(): - group = GroupModel(id=1, name="Test Group") - # Simulate member_associations for get_users_for_splitting if needed directly - # group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model()), UserGroupModel(user_id=2, group_id=1, user=another_user_model())] - return group - -@pytest.fixture -def basic_list_model(basic_group_model, basic_user_model): - return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, creator_id=basic_user_model.id, creator=basic_user_model) - -@pytest.fixture -def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model): - return ExpenseCreate( - description="Grocery run", - total_amount=Decimal("30.00"), - currency="USD", - expense_date=datetime.now(timezone.utc), - split_type=SplitTypeEnum.EQUAL, - list_id=basic_list_model.id, - group_id=None, # Derived from list - item_id=None, - paid_by_user_id=basic_user_model.id, - splits_in=None - ) - -@pytest.fixture -def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model): - return ExpenseCreate( - description="Movies", - total_amount=Decimal("50.00"), - currency="USD", - expense_date=datetime.now(timezone.utc), - split_type=SplitTypeEnum.EQUAL, - list_id=None, - group_id=basic_group_model.id, - item_id=None, - paid_by_user_id=basic_user_model.id, - splits_in=None - ) - -@pytest.fixture -def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model): - return ExpenseCreate( - description="Dinner", - total_amount=Decimal("100.00"), - split_type=SplitTypeEnum.EXACT_AMOUNTS, - group_id=basic_group_model.id, - paid_by_user_id=basic_user_model.id, - splits_in=[ - ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")), - ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")), - ] - ) - -@pytest.fixture -def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model): - return ExpenseModel( - id=1, - description=expense_create_data_equal_split_group_ctx.description, - total_amount=expense_create_data_equal_split_group_ctx.total_amount, - currency=expense_create_data_equal_split_group_ctx.currency, - expense_date=expense_create_data_equal_split_group_ctx.expense_date, - split_type=expense_create_data_equal_split_group_ctx.split_type, - list_id=expense_create_data_equal_split_group_ctx.list_id, - group_id=expense_create_data_equal_split_group_ctx.group_id, - item_id=expense_create_data_equal_split_group_ctx.item_id, - paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id, - paid_by=basic_user_model, # Assuming paid_by relation is loaded - # splits would be populated after creation usually - version=1 - ) - -# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed) -@pytest.mark.asyncio -async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model): - # Setup group with members - user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id) - user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id) - basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2] - - mock_execute = AsyncMock() - mock_execute.scalars.return_value.first.return_value = basic_group_model - mock_db_session.execute.return_value = mock_execute - - users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1) - assert len(users) == 2 - assert basic_user_model in users - assert another_user_model in users - -# --- create_expense Tests --- -@pytest.mark.asyncio -async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model): - mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group - - # Mock get_users_for_splitting call within create_expense - # This is a bit tricky as it's an internal call. Patching is an option. - with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users: - mock_get_users.return_value = [basic_user_model, another_user_model] - - created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1) - - mock_db_session.add.assert_called() - mock_db_session.flush.assert_called_once() - # mock_db_session.commit.assert_called_once() # create_expense does not commit itself - # mock_db_session.refresh.assert_called_once() # create_expense does not refresh itself - - assert created_expense is not None - assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount - assert created_expense.split_type == SplitTypeEnum.EQUAL - assert len(created_expense.splits) == 2 # Expect splits to be added to the model instance - - # Check split amounts - expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - for split in created_expense.splits: - assert split.owed_amount == expected_amount_per_user - -@pytest.mark.asyncio -async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model): - mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group - - # Mock the select for user validation in exact splits - mock_user_select_result = AsyncMock() - mock_user_select_result.all.return_value = [(basic_user_model.id,), (another_user_model.id,)] # Simulate (id,) tuples - # To make it behave like scalars().all() that returns a list of IDs: - # We need to mock the scalars().all() part, or the whole execute chain for user validation. - # A simpler way for this specific case might be to mock the select for User.id - mock_execute_user_ids = AsyncMock() - # Construct a mock result that mimics what `await db.execute(select(UserModel.id)...)` would give then process - # It's a bit involved, usually `found_user_ids = {row[0] for row in user_results}` - # Let's assume the select returns a list of Row objects or tuples with one element - mock_user_ids_result_proxy = MagicMock() - mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)]) - mock_db_session.execute.return_value = mock_user_ids_result_proxy - - created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1) - - mock_db_session.add.assert_called() - mock_db_session.flush.assert_called_once() - assert created_expense is not None - assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS - assert len(created_expense.splits) == 2 - assert created_expense.splits[0].owed_amount == Decimal("60.00") - assert created_expense.splits[1].owed_amount == Decimal("40.00") - -@pytest.mark.asyncio -async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx): - mock_db_session.get.return_value = None # Payer not found - with pytest.raises(UserNotFoundError): - await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 1) - -@pytest.mark.asyncio -async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model): - mock_db_session.get.return_value = basic_user_model # Payer found - expense_data = expense_create_data_equal_split_group_ctx.model_copy() - expense_data.list_id = None - expense_data.group_id = None - with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"): - await create_expense(mock_db_session, expense_data, 1) - -# --- get_expense_by_id Tests --- -@pytest.mark.asyncio -async def test_get_expense_by_id_found(mock_db_session, db_expense_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_expense_model - mock_db_session.execute.return_value = mock_result - - expense = await get_expense_by_id(mock_db_session, 1) - assert expense is not None - assert expense.id == 1 - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_expense_by_id_not_found(mock_db_session): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None - mock_db_session.execute.return_value = mock_result - - expense = await get_expense_by_id(mock_db_session, 999) - assert expense is None - -# --- get_expenses_for_list Tests --- -@pytest.mark.asyncio -async def test_get_expenses_for_list_success(mock_db_session, db_expense_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.all.return_value = [db_expense_model] - mock_db_session.execute.return_value = mock_result - - expenses = await get_expenses_for_list(mock_db_session, list_id=1) - assert len(expenses) == 1 - assert expenses[0].id == db_expense_model.id - mock_db_session.execute.assert_called_once() - -# --- get_expenses_for_group Tests --- -@pytest.mark.asyncio -async def test_get_expenses_for_group_success(mock_db_session, db_expense_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.all.return_value = [db_expense_model] - mock_db_session.execute.return_value = mock_result - - expenses = await get_expenses_for_group(mock_db_session, group_id=1) - assert len(expenses) == 1 - assert expenses[0].id == db_expense_model.id - mock_db_session.execute.assert_called_once() - -# --- Stubs for update_expense and delete_expense --- -# These will need more details once the actual implementation of update/delete is clear -# For example, how splits are handled on update, versioning, etc. - -@pytest.mark.asyncio -async def test_update_expense_stub(mock_db_session): - # Placeholder: Test logic for update_expense will be more complex - # Needs ExpenseUpdate schema, existing expense object, and mocking of commit/refresh - # Also depends on what fields are updatable and how splits are managed. - expense_to_update = MagicMock(spec=ExpenseModel) - expense_to_update.version = 1 - update_payload = ExpenseUpdate(description="New description", version=1) # Add other fields as per schema definition - - # Simulate the update_expense function behavior - # For example, if it loads the expense, modifies, commits, refreshes: - # mock_db_session.get.return_value = expense_to_update - # updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload) - # assert updated_expense.description == "New description" - # mock_db_session.commit.assert_called_once() - # mock_db_session.refresh.assert_called_once() - pass # Replace with actual test logic - -@pytest.mark.asyncio -async def test_delete_expense_stub(mock_db_session): - # Placeholder: Test logic for delete_expense - # Needs an existing expense object and mocking of delete/commit - # Also, consider implications (e.g., are splits deleted?) - expense_to_delete = MagicMock(spec=ExpenseModel) - expense_to_delete.id = 1 - expense_to_delete.version = 1 - - # Simulate delete_expense behavior - # mock_db_session.get.return_value = expense_to_delete # If it re-fetches - # await delete_expense(mock_db_session, expense_to_delete, expected_version=1) - # mock_db_session.delete.assert_called_once_with(expense_to_delete) - # mock_db_session.commit.assert_called_once() - pass # Replace with actual test logic - -# TODO: Add more tests for create_expense covering: -# - List context success -# - Percentage, Shares, Item-based splits -# - Error cases for each split type (e.g., total mismatch, invalid inputs) -# - Validation of list_id/group_id consistency -# - User not found in splits_in -# - Item not found for ITEM_BASED split - -# TODO: Flesh out update_expense tests: -# - Success case -# - Version mismatch -# - Trying to update immutable fields -# - How splits are handled (recalculated, deleted/recreated, or not changeable) - -# TODO: Flesh out delete_expense tests: -# - Success case -# - Version mismatch (if applicable) -# - Ensure associated splits are also deleted (cascade behavior) - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError -from sqlalchemy.future import select -from sqlalchemy import delete, func # For remove_user_from_group and get_group_member_count - -from app.crud.group import ( - create_group, - get_user_groups, - get_group_by_id, - is_user_member, - get_user_role_in_group, - add_user_to_group, - remove_user_from_group, - get_group_member_count, - check_group_membership, - check_user_role_in_group -) -from app.schemas.group import GroupCreate -from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum -from app.core.exceptions import ( - GroupOperationError, - GroupNotFoundError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - GroupMembershipError, - GroupPermissionError -) - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - # Patch begin_nested for SQLAlchemy 1.4+ if used, or just begin() if that's the pattern - # For simplicity, assuming `async with db.begin():` translates to db.begin() and db.commit()/rollback() - session.begin = AsyncMock() # Mock the begin call used in async with db.begin() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.delete = MagicMock() # For remove_user_from_group (if it uses session.delete) - session.execute = AsyncMock() - session.get = AsyncMock() - session.flush = AsyncMock() - return session - -@pytest.fixture -def group_create_data(): - return GroupCreate(name="Test Group") - -@pytest.fixture -def creator_user_model(): - return UserModel(id=1, name="Creator User", email="creator@example.com") - -@pytest.fixture -def member_user_model(): - return UserModel(id=2, name="Member User", email="member@example.com") - -@pytest.fixture -def db_group_model(creator_user_model): - return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model) - -@pytest.fixture -def db_user_group_owner_assoc(db_group_model, creator_user_model): - return UserGroupModel(user_id=creator_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.owner, user=creator_user_model, group=db_group_model) - -@pytest.fixture -def db_user_group_member_assoc(db_group_model, member_user_model): - return UserGroupModel(user_id=member_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.member, user=member_user_model, group=db_group_model) - -# --- create_group Tests --- -@pytest.mark.asyncio -async def test_create_group_success(mock_db_session, group_create_data, creator_user_model): - async def mock_refresh(instance): - instance.id = 1 # Simulate ID assignment by DB - return None - mock_db_session.refresh = AsyncMock(side_effect=mock_refresh) - - created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id) - - assert mock_db_session.add.call_count == 2 # Group and UserGroup - mock_db_session.flush.assert_called() # Called multiple times - mock_db_session.refresh.assert_called_once_with(created_group) - assert created_group is not None - assert created_group.name == group_create_data.name - assert created_group.created_by_id == creator_user_model.id - # Further check if UserGroup was created correctly by inspecting mock_db_session.add calls or by fetching - -@pytest.mark.asyncio -async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model): - mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig") - with pytest.raises(DatabaseIntegrityError): - await create_group(mock_db_session, group_create_data, creator_user_model.id) - mock_db_session.rollback.assert_called_once() # Assuming rollback within the except block of create_group - -# --- get_user_groups Tests --- -@pytest.mark.asyncio -async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.all.return_value = [db_group_model] - mock_db_session.execute.return_value = mock_result - - groups = await get_user_groups(mock_db_session, creator_user_model.id) - assert len(groups) == 1 - assert groups[0].name == db_group_model.name - mock_db_session.execute.assert_called_once() - -# --- get_group_by_id Tests --- -@pytest.mark.asyncio -async def test_get_group_by_id_found(mock_db_session, db_group_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_group_model - mock_db_session.execute.return_value = mock_result - - group = await get_group_by_id(mock_db_session, db_group_model.id) - assert group is not None - assert group.id == db_group_model.id - # Add assertions for eager loaded members if applicable and mocked - -@pytest.mark.asyncio -async def test_get_group_by_id_not_found(mock_db_session): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None - mock_db_session.execute.return_value = mock_result - group = await get_group_by_id(mock_db_session, 999) - assert group is None - -# --- is_user_member Tests --- -@pytest.mark.asyncio -async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model): - mock_result = AsyncMock() - mock_result.scalar_one_or_none.return_value = 1 # Simulate UserGroup.id found - mock_db_session.execute.return_value = mock_result - is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id) - assert is_member is True - -@pytest.mark.asyncio -async def test_is_user_member_false(mock_db_session, db_group_model, member_user_model): - mock_result = AsyncMock() - mock_result.scalar_one_or_none.return_value = None # Simulate no UserGroup.id found - mock_db_session.execute.return_value = mock_result - is_member = await is_user_member(mock_db_session, db_group_model.id, member_user_model.id + 1) # Non-member - assert is_member is False - -# --- get_user_role_in_group Tests --- -@pytest.mark.asyncio -async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model): - mock_result = AsyncMock() - mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner - mock_db_session.execute.return_value = mock_result - role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id) - assert role == UserRoleEnum.owner - -# --- add_user_to_group Tests --- -@pytest.mark.asyncio -async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model): - # First execute call for checking existing membership returns None - mock_existing_check_result = AsyncMock() - mock_existing_check_result.scalar_one_or_none.return_value = None - mock_db_session.execute.return_value = mock_existing_check_result - - async def mock_refresh_user_group(instance): - instance.id = 100 # Simulate ID for UserGroupModel - return None - mock_db_session.refresh = AsyncMock(side_effect=mock_refresh_user_group) - - user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.member) - - mock_db_session.add.assert_called_once() - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once() - assert user_group_assoc is not None - assert user_group_assoc.user_id == member_user_model.id - assert user_group_assoc.group_id == db_group_model.id - assert user_group_assoc.role == UserRoleEnum.member - -@pytest.mark.asyncio -async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc): - mock_existing_check_result = AsyncMock() - mock_existing_check_result.scalar_one_or_none.return_value = db_user_group_owner_assoc # User is already a member - mock_db_session.execute.return_value = mock_existing_check_result - - user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, creator_user_model.id) - assert user_group_assoc is None - mock_db_session.add.assert_not_called() - -# --- remove_user_from_group Tests --- -@pytest.mark.asyncio -async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model): - mock_delete_result = AsyncMock() - mock_delete_result.scalar_one_or_none.return_value = 1 # Simulate a row was deleted (returning ID) - mock_db_session.execute.return_value = mock_delete_result - - removed = await remove_user_from_group(mock_db_session, db_group_model.id, member_user_model.id) - assert removed is True - # Assert that db.execute was called with a delete statement - # This requires inspecting the call args of mock_db_session.execute - # For simplicity, we check it was called. A deeper check would validate the SQL query itself. - mock_db_session.execute.assert_called_once() - -# --- get_group_member_count Tests --- -@pytest.mark.asyncio -async def test_get_group_member_count_success(mock_db_session, db_group_model): - mock_count_result = AsyncMock() - mock_count_result.scalar_one.return_value = 5 - mock_db_session.execute.return_value = mock_count_result - count = await get_group_member_count(mock_db_session, db_group_model.id) - assert count == 5 - -# --- check_group_membership Tests --- -@pytest.mark.asyncio -async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model): - mock_db_session.get.return_value = db_group_model # Group exists - mock_membership_result = AsyncMock() - mock_membership_result.scalar_one_or_none.return_value = 1 # User is a member - mock_db_session.execute.return_value = mock_membership_result - - await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id) - # No exception means success - -@pytest.mark.asyncio -async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model): - mock_db_session.get.return_value = None # Group does not exist - with pytest.raises(GroupNotFoundError): - await check_group_membership(mock_db_session, 999, creator_user_model.id) - -@pytest.mark.asyncio -async def test_check_group_membership_not_member(mock_db_session, db_group_model, member_user_model): - mock_db_session.get.return_value = db_group_model # Group exists - mock_membership_result = AsyncMock() - mock_membership_result.scalar_one_or_none.return_value = None # User is not a member - mock_db_session.execute.return_value = mock_membership_result - with pytest.raises(GroupMembershipError): - await check_group_membership(mock_db_session, db_group_model.id, member_user_model.id) - -# --- check_user_role_in_group Tests --- -@pytest.mark.asyncio -async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model): - # Mock check_group_membership (implicitly called) - mock_db_session.get.return_value = db_group_model - mock_membership_check = AsyncMock() - mock_membership_check.scalar_one_or_none.return_value = 1 # User is member - - # Mock get_user_role_in_group - mock_role_check = AsyncMock() - mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.owner - - mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check] - - await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member) - # No exception means success - -@pytest.mark.asyncio -async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model): - mock_db_session.get.return_value = db_group_model # Group exists - mock_membership_check = AsyncMock() - mock_membership_check.scalar_one_or_none.return_value = 1 # User is member (for check_group_membership call) - - mock_role_check = AsyncMock() - mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.member # User's actual role - - mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check] - - with pytest.raises(GroupPermissionError): - await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner) - -# TODO: Add tests for DB operational/SQLAlchemy errors for each function similar to create_group_integrity_error -# TODO: Test edge cases like trying to add user to non-existent group (should be caught by FK constraints or prior checks) - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised -from datetime import datetime, timedelta, timezone -import secrets - -from app.crud.invite import ( - create_invite, - get_active_invite_by_code, - deactivate_invite, - MAX_CODE_GENERATION_ATTEMPTS -) -from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context -# No specific schemas for invite CRUD usually, but models are used. - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.execute = AsyncMock() - return session - -@pytest.fixture -def group_model(): - return GroupModel(id=1, name="Test Group") - -@pytest.fixture -def user_model(): # Creator - return UserModel(id=1, name="Creator User") - -@pytest.fixture -def db_invite_model(group_model, user_model): - return InviteModel( - id=1, - code="test_invite_code_123", - group_id=group_model.id, - created_by_id=user_model.id, - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - is_active=True - ) - -# --- create_invite Tests --- -@pytest.mark.asyncio -@patch('app.crud.invite.secrets.token_urlsafe') # Patch secrets.token_urlsafe -async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model): - generated_code = "unique_code_123" - mock_token_urlsafe.return_value = generated_code - - # Mock DB execute for checking existing code (first attempt, no existing code) - mock_existing_check_result = AsyncMock() - mock_existing_check_result.scalar_one_or_none.return_value = None - mock_db_session.execute.return_value = mock_existing_check_result - - invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5) - - mock_token_urlsafe.assert_called_once_with(16) - mock_db_session.execute.assert_called_once() # For the uniqueness check - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_db_session.refresh.assert_called_once_with(invite) - - assert invite is not None - assert invite.code == generated_code - assert invite.group_id == group_model.id - assert invite.created_by_id == user_model.id - assert invite.is_active is True - assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4) # Check expiry is roughly correct - -@pytest.mark.asyncio -@patch('app.crud.invite.secrets.token_urlsafe') -async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model): - colliding_code = "colliding_code" - unique_code = "finally_unique_code" - mock_token_urlsafe.side_effect = [colliding_code, unique_code] # First call collides, second is unique - - # Mock DB execute for checking existing code - mock_collision_check_result = AsyncMock() - mock_collision_check_result.scalar_one_or_none.return_value = 1 # Simulate collision (ID found) - - mock_no_collision_check_result = AsyncMock() - mock_no_collision_check_result.scalar_one_or_none.return_value = None # No collision - - mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result] - - invite = await create_invite(mock_db_session, group_model.id, user_model.id) - - assert mock_token_urlsafe.call_count == 2 - assert mock_db_session.execute.call_count == 2 - assert invite is not None - assert invite.code == unique_code - -@pytest.mark.asyncio -@patch('app.crud.invite.secrets.token_urlsafe') -async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model): - mock_token_urlsafe.return_value = "always_colliding_code" - - mock_collision_check_result = AsyncMock() - mock_collision_check_result.scalar_one_or_none.return_value = 1 # Always collide - mock_db_session.execute.return_value = mock_collision_check_result - - invite = await create_invite(mock_db_session, group_model.id, user_model.id) - - assert invite is None - assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS - assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS - mock_db_session.add.assert_not_called() - -# --- get_active_invite_by_code Tests --- -@pytest.mark.asyncio -async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model): - db_invite_model.is_active = True - db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1) - - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_invite_model - mock_db_session.execute.return_value = mock_result - - invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code) - assert invite is not None - assert invite.code == db_invite_model.code - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_active_invite_by_code_not_found(mock_db_session): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None - mock_db_session.execute.return_value = mock_result - invite = await get_active_invite_by_code(mock_db_session, "non_existent_code") - assert invite is None - -@pytest.mark.asyncio -async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model): - db_invite_model.is_active = False # Inactive - db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1) - - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None # Should not be found by query - mock_db_session.execute.return_value = mock_result - - invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code) - assert invite is None - -@pytest.mark.asyncio -async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model): - db_invite_model.is_active = True - db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1) # Expired - - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None # Should not be found by query - mock_db_session.execute.return_value = mock_result - - invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code) - assert invite is None - -# --- deactivate_invite Tests --- -@pytest.mark.asyncio -async def test_deactivate_invite_success(mock_db_session, db_invite_model): - db_invite_model.is_active = True # Ensure it starts active - - deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model) - - mock_db_session.add.assert_called_once_with(db_invite_model) - mock_db_session.commit.assert_called_once() - mock_db_session.refresh.assert_called_once_with(db_invite_model) - assert deactivated_invite.is_active is False - -# It might be useful to test DB error cases (OperationalError, etc.) for each function -# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that. -# create_invite has its own DB interaction within the loop, so that's covered. -# get_active_invite_by_code and deactivate_invite are simpler DB ops. - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError -from datetime import datetime, timezone - -from app.crud.item import ( - create_item, - get_items_by_list_id, - get_item_by_id, - update_item, - delete_item -) -from app.schemas.item import ItemCreate, ItemUpdate -from app.models import Item as ItemModel, User as UserModel, List as ListModel -from app.core.exceptions import ( - ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - ConflictError -) - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.begin = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.delete = MagicMock() - session.execute = AsyncMock() - session.get = AsyncMock() # Though not directly used in item.py, good for consistency - session.flush = AsyncMock() - return session - -@pytest.fixture -def item_create_data(): - return ItemCreate(name="Test Item", quantity="1 pack") - -@pytest.fixture -def item_update_data(): - return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False) - -@pytest.fixture -def user_model(): - return UserModel(id=1, name="Test User", email="test@example.com") - -@pytest.fixture -def list_model(): - return ListModel(id=1, name="Test List") - -@pytest.fixture -def db_item_model(list_model, user_model): - return ItemModel( - id=1, - name="Existing Item", - quantity="1 unit", - list_id=list_model.id, - added_by_id=user_model.id, - is_complete=False, - version=1, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc) - ) - -# --- create_item Tests --- -@pytest.mark.asyncio -async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model): - async def mock_refresh(instance): - instance.id = 10 # Simulate ID assignment - instance.version = 1 # Simulate version init - instance.created_at = datetime.now(timezone.utc) - instance.updated_at = datetime.now(timezone.utc) - return None - mock_db_session.refresh = AsyncMock(side_effect=mock_refresh) - - created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id) - - mock_db_session.add.assert_called_once() - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once_with(created_item) - assert created_item is not None - assert created_item.name == item_create_data.name - assert created_item.list_id == list_model.id - assert created_item.added_by_id == user_model.id - assert created_item.is_complete is False - assert created_item.version == 1 - -@pytest.mark.asyncio -async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model): - mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig") - with pytest.raises(DatabaseIntegrityError): - await create_item(mock_db_session, item_create_data, list_model.id, user_model.id) - mock_db_session.rollback.assert_called_once() - -# --- get_items_by_list_id Tests --- -@pytest.mark.asyncio -async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.all.return_value = [db_item_model] - mock_db_session.execute.return_value = mock_result - - items = await get_items_by_list_id(mock_db_session, list_model.id) - assert len(items) == 1 - assert items[0].id == db_item_model.id - mock_db_session.execute.assert_called_once() - -# --- get_item_by_id Tests --- -@pytest.mark.asyncio -async def test_get_item_by_id_found(mock_db_session, db_item_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_item_model - mock_db_session.execute.return_value = mock_result - item = await get_item_by_id(mock_db_session, db_item_model.id) - assert item is not None - assert item.id == db_item_model.id - -@pytest.mark.asyncio -async def test_get_item_by_id_not_found(mock_db_session): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = None - mock_db_session.execute.return_value = mock_result - item = await get_item_by_id(mock_db_session, 999) - assert item is None - -# --- update_item Tests --- -@pytest.mark.asyncio -async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model): - item_update_data.version = db_item_model.version # Match versions for successful update - item_update_data.name = "Newly Updated Name" - item_update_data.is_complete = True # Test completion logic - - updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id) - - mock_db_session.add.assert_called_once_with(db_item_model) # add is used for existing objects too - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once_with(db_item_model) - assert updated_item.name == "Newly Updated Name" - assert updated_item.version == db_item_model.version # Check version increment logic in test - assert updated_item.is_complete is True - assert updated_item.completed_by_id == user_model.id - -@pytest.mark.asyncio -async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model): - item_update_data.version = db_item_model.version + 1 # Create a version mismatch - with pytest.raises(ConflictError): - await update_item(mock_db_session, db_item_model, item_update_data, user_model.id) - mock_db_session.rollback.assert_called_once() - -@pytest.mark.asyncio -async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model): - db_item_model.is_complete = True # Start as complete - db_item_model.completed_by_id = user_model.id - db_item_model.version = 1 - - item_update_data.version = 1 - item_update_data.is_complete = False - item_update_data.name = db_item_model.name # No name change for this test - item_update_data.quantity = db_item_model.quantity - - updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id) - assert updated_item.is_complete is False - assert updated_item.completed_by_id is None - assert updated_item.version == 2 - -# --- delete_item Tests --- -@pytest.mark.asyncio -async def test_delete_item_success(mock_db_session, db_item_model): - result = await delete_item(mock_db_session, db_item_model) - assert result is None - mock_db_session.delete.assert_called_once_with(db_item_model) - mock_db_session.commit.assert_called_once() # Commit happens in the `async with db.begin()` context manager - -@pytest.mark.asyncio -async def test_delete_item_db_error(mock_db_session, db_item_model): - mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig") - with pytest.raises(DatabaseConnectionError): - await delete_item(mock_db_session, db_item_model) - mock_db_session.rollback.assert_called_once() - -# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function. - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError -from sqlalchemy.future import select -from sqlalchemy import func as sql_func # For get_list_status -from datetime import datetime, timezone - -from app.crud.list import ( - create_list, - get_lists_for_user, - get_list_by_id, - update_list, - delete_list, - check_list_permission, - get_list_status -) -from app.schemas.list import ListCreate, ListUpdate, ListStatus -from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel -from app.core.exceptions import ( - ListNotFoundError, - ListPermissionError, - ListCreatorRequiredError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - ConflictError -) - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.begin = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.delete = MagicMock() - session.execute = AsyncMock() - session.get = AsyncMock() # Used by check_list_permission via get_list_by_id - session.flush = AsyncMock() - return session - -@pytest.fixture -def list_create_data(): - return ListCreate(name="New Shopping List", description="Groceries for the week") - -@pytest.fixture -def list_update_data(): - return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1) - -@pytest.fixture -def user_model(): - return UserModel(id=1, name="Test User", email="test@example.com") - -@pytest.fixture -def another_user_model(): - return UserModel(id=2, name="Another User", email="another@example.com") - -@pytest.fixture -def group_model(): - return GroupModel(id=1, name="Test Group") - -@pytest.fixture -def db_list_personal_model(user_model): - return ListModel( - id=1, name="Personal List", created_by_id=user_model.id, creator=user_model, - version=1, updated_at=datetime.now(timezone.utc), items=[] - ) - -@pytest.fixture -def db_list_group_model(user_model, group_model): - return ListModel( - id=2, name="Group List", created_by_id=user_model.id, creator=user_model, - group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[] - ) - -# --- create_list Tests --- -@pytest.mark.asyncio -async def test_create_list_success(mock_db_session, list_create_data, user_model): - async def mock_refresh(instance): - instance.id = 100 - instance.version = 1 - instance.updated_at = datetime.now(timezone.utc) - return None - mock_db_session.refresh.return_value = None - mock_db_session.refresh.side_effect = mock_refresh - - created_list = await create_list(mock_db_session, list_create_data, user_model.id) - mock_db_session.add.assert_called_once() - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once() - assert created_list.name == list_create_data.name - assert created_list.created_by_id == user_model.id - -# --- get_lists_for_user Tests --- -@pytest.mark.asyncio -async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model): - # Simulate user is part of group for db_list_group_model - mock_group_ids_result = AsyncMock() - mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id] - - mock_lists_result = AsyncMock() - # Order should be personal list (created by user_id) then group list - mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model] - - mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result] - - lists = await get_lists_for_user(mock_db_session, user_model.id) - assert len(lists) == 2 - assert db_list_personal_model in lists - assert db_list_group_model in lists - assert mock_db_session.execute.call_count == 2 - -# --- get_list_by_id Tests --- -@pytest.mark.asyncio -async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model): - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_list_personal_model - mock_db_session.execute.return_value = mock_result - - found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False) - assert found_list is not None - assert found_list.id == db_list_personal_model.id - # query options should not include selectinload for items - # (difficult to assert directly without inspecting query object in detail) - -@pytest.mark.asyncio -async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model): - # Simulate items loaded for the list - db_list_personal_model.items = [ItemModel(id=1, name="Test Item")] - mock_result = AsyncMock() - mock_result.scalars.return_value.first.return_value = db_list_personal_model - mock_db_session.execute.return_value = mock_result - - found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True) - assert found_list is not None - assert len(found_list.items) == 1 - # query options should include selectinload for items - -# --- update_list Tests --- -@pytest.mark.asyncio -async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data): - list_update_data.version = db_list_personal_model.version # Match version - - updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data) - assert updated_list.name == list_update_data.name - assert updated_list.version == db_list_personal_model.version # version incremented in db_list_personal_model - mock_db_session.add.assert_called_once_with(db_list_personal_model) - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once_with(db_list_personal_model) - -@pytest.mark.asyncio -async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data): - list_update_data.version = db_list_personal_model.version + 1 # Version mismatch - with pytest.raises(ConflictError): - await update_list(mock_db_session, db_list_personal_model, list_update_data) - mock_db_session.rollback.assert_called_once() - -# --- delete_list Tests --- -@pytest.mark.asyncio -async def test_delete_list_success(mock_db_session, db_list_personal_model): - await delete_list(mock_db_session, db_list_personal_model) - mock_db_session.delete.assert_called_once_with(db_list_personal_model) - mock_db_session.commit.assert_called_once() # from async with db.begin() - -# --- check_list_permission Tests --- -@pytest.mark.asyncio -async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model): - # get_list_by_id (called by check_list_permission) will mock execute - mock_list_fetch_result = AsyncMock() - mock_list_fetch_result.scalars.return_value.first.return_value = db_list_personal_model - mock_db_session.execute.return_value = mock_list_fetch_result - - ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id) - assert ret_list.id == db_list_personal_model.id - -@pytest.mark.asyncio -async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model): - # User `another_user_model` is not creator but member of the group - db_list_group_model.creator_id = user_model.id # Original creator is user_model - db_list_group_model.creator = user_model - - # Mock get_list_by_id internal call - mock_list_fetch_result = AsyncMock() - mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model - - # Mock is_user_member call - with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member: - mock_is_member.return_value = True # another_user_model is a member - mock_db_session.execute.return_value = mock_list_fetch_result - - ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id) - assert ret_list.id == db_list_group_model.id - mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id) - -@pytest.mark.asyncio -async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model): - db_list_group_model.creator_id = user_model.id # Creator is not another_user_model - - mock_list_fetch_result = AsyncMock() - mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model - - with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member: - mock_is_member.return_value = False # another_user_model is NOT a member - mock_db_session.execute.return_value = mock_list_fetch_result - - with pytest.raises(ListPermissionError): - await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id) - -@pytest.mark.asyncio -async def test_check_list_permission_list_not_found(mock_db_session, user_model): - mock_list_fetch_result = AsyncMock() - mock_list_fetch_result.scalars.return_value.first.return_value = None # List not found - mock_db_session.execute.return_value = mock_list_fetch_result - - with pytest.raises(ListNotFoundError): - await check_list_permission(mock_db_session, 999, user_model.id) - -# --- get_list_status Tests --- -@pytest.mark.asyncio -async def test_get_list_status_success(mock_db_session, db_list_personal_model): - list_updated_at = datetime.now(timezone.utc) - timezone.timedelta(hours=1) - item_updated_at = datetime.now(timezone.utc) - item_count = 5 - - db_list_personal_model.updated_at = list_updated_at - - # Mock for ListModel.updated_at query - mock_list_updated_result = AsyncMock() - mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at - - # Mock for ItemModel status query - mock_item_status_result = AsyncMock() - # SQLAlchemy query for func.max and func.count returns a Row-like object or None - mock_item_status_row = MagicMock() - mock_item_status_row.latest_item_updated_at = item_updated_at - mock_item_status_row.item_count = item_count - mock_item_status_result.first.return_value = mock_item_status_row - - mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result] - - status = await get_list_status(mock_db_session, db_list_personal_model.id) - assert status.list_updated_at == list_updated_at - assert status.latest_item_updated_at == item_updated_at - assert status.item_count == item_count - assert mock_db_session.execute.call_count == 2 - -@pytest.mark.asyncio -async def test_get_list_status_list_not_found(mock_db_session): - mock_list_updated_result = AsyncMock() - mock_list_updated_result.scalar_one_or_none.return_value = None # List not found - mock_db_session.execute.return_value = mock_list_updated_result - with pytest.raises(ListNotFoundError): - await get_list_status(mock_db_session, 999) - -# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function. -# TODO: Test check_list_permission with require_creator=True cases. - - - -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.future import select -from decimal import Decimal, ROUND_HALF_UP -from datetime import datetime, timezone -from typing import List as PyList - -from app.crud.settlement import ( - create_settlement, - get_settlement_by_id, - get_settlements_for_group, - get_settlements_involving_user, - update_settlement, - delete_settlement -) -from app.schemas.expense import SettlementCreate, SettlementUpdate -from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel -from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - session.refresh = AsyncMock() - session.add = MagicMock() - session.delete = MagicMock() - session.execute = AsyncMock() - session.get = AsyncMock() - return session - -@pytest.fixture -def settlement_create_data(): - return SettlementCreate( - group_id=1, - paid_by_user_id=1, - paid_to_user_id=2, - amount=Decimal("10.50"), - settlement_date=datetime.now(timezone.utc), - description="Test settlement" - ) - -@pytest.fixture -def settlement_update_data(): - return SettlementUpdate( - description="Updated settlement description", - settlement_date=datetime.now(timezone.utc), - version=1 # Assuming version is required for update - ) - -@pytest.fixture -def db_settlement_model(): - return SettlementModel( - id=1, - group_id=1, - paid_by_user_id=1, - paid_to_user_id=2, - amount=Decimal("10.50"), - settlement_date=datetime.now(timezone.utc), - description="Original settlement", - version=1, # Initial version - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - payer=UserModel(id=1, name="Payer User"), - payee=UserModel(id=2, name="Payee User"), - group=GroupModel(id=1, name="Test Group") - ) - -@pytest.fixture -def payer_user_model(): - return UserModel(id=1, name="Payer User", email="payer@example.com") - -@pytest.fixture -def payee_user_model(): - return UserModel(id=2, name="Payee User", email="payee@example.com") - -@pytest.fixture -def group_model(): - return GroupModel(id=1, name="Test Group") - -# Tests for create_settlement -@pytest.mark.asyncio -async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model): - mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model] # Order of gets - - created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1) - - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - mock_db_session.refresh.assert_called_once() - assert created_settlement is not None - assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id - assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id - - -@pytest.mark.asyncio -async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data): - mock_db_session.get.side_effect = [None, payee_user_model, group_model] - with pytest.raises(UserNotFoundError) as excinfo: - await create_settlement(mock_db_session, settlement_create_data, 1) - assert "Payer" in str(excinfo.value) - -@pytest.mark.asyncio -async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model): - mock_db_session.get.side_effect = [payer_user_model, None, group_model] - with pytest.raises(UserNotFoundError) as excinfo: - await create_settlement(mock_db_session, settlement_create_data, 1) - assert "Payee" in str(excinfo.value) - -@pytest.mark.asyncio -async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model): - mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None] - with pytest.raises(GroupNotFoundError): - await create_settlement(mock_db_session, settlement_create_data, 1) - -@pytest.mark.asyncio -async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model): - settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id - mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model] - with pytest.raises(InvalidOperationError) as excinfo: - await create_settlement(mock_db_session, settlement_create_data, 1) - assert "Payer and Payee cannot be the same user" in str(excinfo.value) - -@pytest.mark.asyncio -async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model): - mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model] - mock_db_session.commit.side_effect = Exception("DB commit failed") - with pytest.raises(InvalidOperationError) as excinfo: - await create_settlement(mock_db_session, settlement_create_data, 1) - assert "Failed to save settlement" in str(excinfo.value) - mock_db_session.rollback.assert_called_once() - - -# Tests for get_settlement_by_id -@pytest.mark.asyncio -async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = db_settlement_model - settlement = await get_settlement_by_id(mock_db_session, 1) - assert settlement is not None - assert settlement.id == 1 - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_settlement_by_id_not_found(mock_db_session): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = None - settlement = await get_settlement_by_id(mock_db_session, 999) - assert settlement is None - -# Tests for get_settlements_for_group -@pytest.mark.asyncio -async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model): - mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model] - settlements = await get_settlements_for_group(mock_db_session, group_id=1) - assert len(settlements) == 1 - assert settlements[0].group_id == 1 - mock_db_session.execute.assert_called_once() - -# Tests for get_settlements_involving_user -@pytest.mark.asyncio -async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model): - mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model] - settlements = await get_settlements_involving_user(mock_db_session, user_id=1) - assert len(settlements) == 1 - assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1 - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model): - mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model] - settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1) - assert len(settlements) == 1 - # More specific assertions about the query would require deeper mocking of SQLAlchemy query construction - mock_db_session.execute.assert_called_once() - - -# Tests for update_settlement -@pytest.mark.asyncio -async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data): - # Ensure settlement_update_data.version matches db_settlement_model.version - settlement_update_data.version = db_settlement_model.version - - # Mock datetime.now() - fixed_datetime_now = datetime.now(timezone.utc) - with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime: - mock_datetime.now.return_value = fixed_datetime_now - - updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data) - - mock_db_session.commit.assert_called_once() - mock_db_session.refresh.assert_called_once() - assert updated_settlement.description == settlement_update_data.description - assert updated_settlement.settlement_date == settlement_update_data.settlement_date - assert updated_settlement.version == db_settlement_model.version + 1 # Version incremented - assert updated_settlement.updated_at == fixed_datetime_now - -@pytest.mark.asyncio -async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data): - settlement_update_data.version = db_settlement_model.version + 1 # Mismatched version - with pytest.raises(InvalidOperationError) as excinfo: - await update_settlement(mock_db_session, db_settlement_model, settlement_update_data) - assert "version does not match" in str(excinfo.value) - -@pytest.mark.asyncio -async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model): - # Create an update payload with a field not allowed to be updated, e.g., 'amount' - invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version) - with pytest.raises(InvalidOperationError) as excinfo: - await update_settlement(mock_db_session, db_settlement_model, invalid_update_data) - assert "Field 'amount' cannot be updated" in str(excinfo.value) - -@pytest.mark.asyncio -async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data): - settlement_update_data.version = db_settlement_model.version - mock_db_session.commit.side_effect = Exception("DB commit failed") - with pytest.raises(InvalidOperationError) as excinfo: - await update_settlement(mock_db_session, db_settlement_model, settlement_update_data) - assert "Failed to update settlement" in str(excinfo.value) - mock_db_session.rollback.assert_called_once() - -# Tests for delete_settlement -@pytest.mark.asyncio -async def test_delete_settlement_success(mock_db_session, db_settlement_model): - await delete_settlement(mock_db_session, db_settlement_model) - mock_db_session.delete.assert_called_once_with(db_settlement_model) - mock_db_session.commit.assert_called_once() - -@pytest.mark.asyncio -async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model): - await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version) - mock_db_session.delete.assert_called_once_with(db_settlement_model) - mock_db_session.commit.assert_called_once() - -@pytest.mark.asyncio -async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model): - with pytest.raises(InvalidOperationError) as excinfo: - await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version + 1) - assert "Expected version" in str(excinfo.value) - assert "does not match current version" in str(excinfo.value) - mock_db_session.delete.assert_not_called() - -@pytest.mark.asyncio -async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model): - mock_db_session.commit.side_effect = Exception("DB commit failed") - with pytest.raises(InvalidOperationError) as excinfo: - await delete_settlement(mock_db_session, db_settlement_model) - assert "Failed to delete settlement" in str(excinfo.value) - mock_db_session.rollback.assert_called_once() - - - -import pytest -from unittest.mock import AsyncMock, MagicMock -from sqlalchemy.exc import IntegrityError, OperationalError - -from app.crud.user import get_user_by_email, create_user -from app.schemas.user import UserCreate -from app.models import User as UserModel -from app.core.exceptions import ( - UserCreationError, - EmailAlreadyRegisteredError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError -) - -# Fixtures -@pytest.fixture -def mock_db_session(): - return AsyncMock() - -@pytest.fixture -def user_create_data(): - return UserCreate(email="test@example.com", password="password123", name="Test User") - -@pytest.fixture -def existing_user_data(): - return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User") - -# Tests for get_user_by_email -@pytest.mark.asyncio -async def test_get_user_by_email_found(mock_db_session, existing_user_data): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = existing_user_data - user = await get_user_by_email(mock_db_session, "exists@example.com") - assert user is not None - assert user.email == "exists@example.com" - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_user_by_email_not_found(mock_db_session): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = None - user = await get_user_by_email(mock_db_session, "nonexistent@example.com") - assert user is None - mock_db_session.execute.assert_called_once() - -@pytest.mark.asyncio -async def test_get_user_by_email_db_connection_error(mock_db_session): - mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig") - with pytest.raises(DatabaseConnectionError): - await get_user_by_email(mock_db_session, "test@example.com") - -@pytest.mark.asyncio -async def test_get_user_by_email_db_query_error(mock_db_session): - # Simulate a generic SQLAlchemyError that is not OperationalError - mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig") # Using IntegrityError as an example of SQLAlchemyError - with pytest.raises(DatabaseQueryError): - await get_user_by_email(mock_db_session, "test@example.com") - - -# Tests for create_user -@pytest.mark.asyncio -async def test_create_user_success(mock_db_session, user_create_data): - # The actual user object returned would be created by SQLAlchemy based on db_user - # We mock the process: db.add is called, then db.flush, then db.refresh updates db_user - async def mock_refresh(user_model_instance): - user_model_instance.id = 1 # Simulate DB assigning an ID - # Simulate other db-generated fields if necessary - return None - - mock_db_session.refresh = AsyncMock(side_effect=mock_refresh) - mock_db_session.flush = AsyncMock() - mock_db_session.add = MagicMock() - - created_user = await create_user(mock_db_session, user_create_data) - - mock_db_session.add.assert_called_once() - mock_db_session.flush.assert_called_once() - mock_db_session.refresh.assert_called_once() - - assert created_user is not None - assert created_user.email == user_create_data.email - assert created_user.name == user_create_data.name - assert hasattr(created_user, 'id') # Check if ID was assigned (simulated by mock_refresh) - # Password hash check would be more involved, ensure hash_password was called correctly - # For now, we assume hash_password works as intended and is tested elsewhere. - -@pytest.mark.asyncio -async def test_create_user_email_already_registered(mock_db_session, user_create_data): - mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig") - with pytest.raises(EmailAlreadyRegisteredError): - await create_user(mock_db_session, user_create_data) - -@pytest.mark.asyncio -async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data): - # Simulate an IntegrityError that is not related to a unique constraint - mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig") - with pytest.raises(DatabaseIntegrityError): - await create_user(mock_db_session, user_create_data) - -@pytest.mark.asyncio -async def test_create_user_db_connection_error(mock_db_session, user_create_data): - mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig") - with pytest.raises(DatabaseConnectionError): - await create_user(mock_db_session, user_create_data) - # also test OperationalError on flush - mock_db_session.begin.side_effect = None # reset side effect - mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig") - with pytest.raises(DatabaseConnectionError): - await create_user(mock_db_session, user_create_data) - - -@pytest.mark.asyncio -async def test_create_user_db_transaction_error(mock_db_session, user_create_data): - # Simulate a generic SQLAlchemyError on flush that is not IntegrityError or OperationalError - mock_db_session.flush.side_effect = UserCreationError("Simulated non-specific SQLAlchemyError") # Or any other SQLAlchemyError - with pytest.raises(DatabaseTransactionError): - await create_user(mock_db_session, user_create_data) - - - -## Polished PWA Plan: Shared Lists & Household Management - -## 1. Product Overview - -**Concept:** -Develop a Progressive Web App (PWA) focused on simplifying household coordination. Users can: -- Create, manage, and **share** shopping lists within defined groups (e.g., households, trip members). -- Capture images of receipts or shopping lists via the browser and extract items using **Google Cloud Vision API** for OCR. -- Track item costs on shared lists and easily split expenses among group participants. -- (Future) Manage and assign household chores. - -**Target Audience:** Households, roommates, families, groups organizing shared purchases. - -**UX Philosophy:** -- **User-Centered & Collaborative:** Design intuitive flows for both individual use and group collaboration with minimal friction. -- **Native-like PWA Experience:** Leverage service workers, caching, and manifest files for reliable offline use, installability, and smooth performance. -- **Clarity & Accessibility:** Prioritize high contrast, legible typography, straightforward navigation, and adherence to accessibility standards (WCAG). -- **Informative Feedback:** Provide clear visual feedback for actions (animations, loading states), OCR processing status, and data synchronization, including handling potential offline conflicts gracefully. - ---- - -## 2. MVP Scope (Refined & Focused) - -The MVP will focus on delivering a robust, shareable shopping list experience with integrated OCR and cost splitting, built as a high-quality PWA. **Chore management is deferred post-MVP** to ensure a polished core experience at launch. - -1. **Shared Shopping List Management:** - * **Core Features:** Create, update, delete lists and items. Mark items as complete. Basic item sorting/reordering (e.g., manual drag-and-drop). - * **Collaboration:** Share lists within user-defined groups. Real-time (or near real-time) updates visible to group members (via polling or simple WebSocket for MVP). - * **PWA/UX:** Responsive design, offline access to cached lists, basic conflict indication if offline edits clash (e.g., "Item updated by another user, refresh needed"). - -2. **OCR Integration (Google Cloud Vision):** - * **Core Features:** Capture images via browser (`` or `getUserMedia`). Upload images to the FastAPI backend. Backend securely calls **Google Cloud Vision API (Text Detection / Document Text Detection)**. Process results, suggest items to add to the list. - * **PWA/UX:** Clear instructions for image capture. Progress indicators during upload/processing. Display editable OCR results for user review and confirmation before adding to the list. Handle potential API errors or low-confidence results gracefully. - -3. **Cost Splitting (Integrated with Lists):** - * **Core Features:** Assign prices to items *on the shopping list* as they are purchased. Add participants (from the shared group) to a list's expense split. Calculate totals per list and simple equal splits per participant. - * **PWA/UX:** Clear display of totals and individual shares. Easy interface for marking items as bought and adding their price. - -4. **User Authentication & Group Management:** - * **Core Features:** Secure email/password signup & login (JWT-based). Ability to create simple groups (e.g., "Household"). Mechanism to invite/add users to a group (e.g., unique invite code/link). Basic role distinction (e.g., group owner/admin, member) if necessary for managing participants. - * **PWA/UX:** Minimalist forms, clear inline validation, smooth onboarding explaining the group concept. - -5. **Core PWA Functionality:** - * **Core Features:** Installable via `manifest.json`. Offline access via service worker caching (app shell, static assets, user data). Basic background sync strategy for offline actions (e.g., "last write wins" for simple edits, potentially queueing adds/deletes). - ---- - -## 3. Feature Breakdown & UX Enhancements (MVP Focus) - -### A. Shared Shopping Lists -- **Screens:** Dashboard (list overview), List Detail (items), Group Management. -- **Flows:** Create list -> (Optional) Share with group -> Add/edit/check items -> See updates from others -> Mark list complete. -- **UX Focus:** Smooth transitions, clear indication of shared status, offline caching, simple conflict notification (not full resolution in MVP). - -### B. OCR with Google Cloud Vision -- **Flow:** Tap "Add via OCR" -> Capture/Select Image -> Upload -> Show Progress -> Display Review Screen (editable text boxes for potential items) -> User confirms/edits -> Items added to list. -- **UX Focus:** Clear instructions, robust error handling (API errors, poor image quality feedback if possible), easy correction interface, manage user expectations regarding OCR accuracy. Monitor API costs/quotas. - -### C. Integrated Cost Splitting -- **Flow:** Open shared list -> Mark item "bought" -> Input price -> View updated list total -> Go to "Split Costs" view for the list -> Confirm participants (group members) -> See calculated equal split. -- **UX Focus:** Seamless transition from shopping to cost entry. Clear, real-time calculation display. Simple participant management within the list context. - -### D. User Auth & Groups -- **Flow:** Sign up/Login -> Create a group -> Invite members (e.g., share code) -> Member joins group -> Access shared lists. -- **UX Focus:** Secure and straightforward auth. Simple group creation and joining process. Clear visibility of group members. - -### E. PWA Essentials -- **Manifest:** Define app name, icons, theme, display mode. -- **Service Worker:** Cache app shell, assets, API responses (user data). Implement basic offline sync queue for actions performed offline (e.g., adding/checking items). Define a clear sync conflict strategy (e.g., last-write-wins, notify user on conflict). - ---- - -## 4. Architecture & Technology Stack - -### Frontend: Svelte PWA -- **Framework:** Svelte/SvelteKit (Excellent for performant, component-based PWAs). -- **State Management:** Svelte Stores for managing UI state and cached data. -- **PWA Tools:** Workbox.js (via SvelteKit integration or standalone) for robust service worker generation and caching strategies. -- **Styling:** Tailwind CSS or standard CSS with scoped styles. -- **UX:** Design system (e.g., using Figma), Storybook for component development. - -### Backend: FastAPI & PostgreSQL -- **Framework:** FastAPI (High performance, async support, auto-docs, Pydantic validation). -- **Database:** PostgreSQL (Reliable, supports JSONB for flexibility if needed). Schema designed to handle users, groups, lists, items, costs, and relationships. Basic indexing on foreign keys and frequently queried fields (user IDs, group IDs, list IDs). -- **ORM:** SQLAlchemy (async support with v2.0+) or Tortoise ORM (async-native). Alembic for migrations. -- **OCR Integration:** Use the official **Google Cloud Client Libraries for Python** to interact with the Vision API. Implement robust error handling, retries, and potentially rate limiting/cost control logic. Ensure API calls are `async` to avoid blocking. -- **Authentication:** JWT tokens for stateless session management. -- **Deployment:** Containerize using Docker/Docker Compose for development and deployment consistency. Deploy on a scalable cloud platform (e.g., Google Cloud Run, AWS Fargate, DigitalOcean App Platform). -- **Monitoring:** Logging (standard Python logging), Error Tracking (Sentry), Performance Monitoring (Prometheus/Grafana if needed later). ---- -# Finalized User Stories, Flow Mapping, Sharing Model & Sync, Tech Stack & Initial Architecture Diagram - -## 1. User Stories - -### Authentication & User Management -- As a new user, I want to sign up with my email so I can create and manage shopping lists -- As a returning user, I want to log in securely to access my lists and groups -- As a user, I want to reset my password if I forget it -- As a user, I want to edit my profile information (name, avatar) - -### Group Management -- As a user, I want to create a new group (e.g., "Household", "Roommates") to organize shared lists -- As a group creator, I want to invite others to join my group via a shareable link/code -- As an invitee, I want to easily join a group by clicking a link or entering a code -- As a group owner, I want to remove members if needed -- As a user, I want to leave a group I no longer wish to be part of -- As a user, I want to see all groups I belong to and switch between them - -### List Management -- As a user, I want to create a personal shopping list with a title and optional description -- As a user, I want to share a list with a specific group so members can collaborate -- As a user, I want to view all my lists (personal and shared) from a central dashboard -- As a user, I want to archive or delete lists I no longer need -- As a user, I want to mark a list as "shopping complete" when finished -- As a user, I want to see which group a list is shared with - -### Item Management -- As a user, I want to add items to a list with names and optional quantities -- As a user, I want to mark items as purchased when shopping -- As a user, I want to edit item details (name, quantity, notes) -- As a user, I want to delete items from a list -- As a user, I want to reorder items on my list for shopping efficiency -- As a user, I want to see who added or marked items as purchased in shared lists - -### OCR Integration -- As a user, I want to capture a photo of a physical shopping list or receipt -- As a user, I want the app to extract text and convert it into list items -- As a user, I want to review and edit OCR results before adding to my list -- As a user, I want clear feedback on OCR processing status -- As a user, I want to retry OCR if the results aren't satisfactory - -### Cost Splitting -- As a user, I want to add prices to items as I purchase them -- As a user, I want to see the total cost of all purchased items in a list -- As a user, I want to split costs equally among group members -- As a user, I want to see who owes what amount based on the split -- As a user, I want to mark expenses as settled - -### PWA & Offline Experience -- As a user, I want to install the app on my home screen for quick access -- As a user, I want to view and edit my lists even when offline -- As a user, I want my changes to sync automatically when I'm back online -- As a user, I want to be notified if my offline changes conflict with others' changes - - - -# docker-compose.yml (in project root) -version: '3.8' - -services: - db: - image: postgres:15 # Use a specific PostgreSQL version - container_name: postgres_db - environment: - POSTGRES_USER: dev_user # Define DB user - POSTGRES_PASSWORD: dev_password # Define DB password - POSTGRES_DB: dev_db # Define Database name - volumes: - - postgres_data:/var/lib/postgresql/data # Persist data using a named volume - ports: - - "5432:5432" # Expose PostgreSQL port to host (optional, for direct access) - healthcheck: - test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 10s - restart: unless-stopped - - backend: - container_name: fastapi_backend - build: - context: ./be # Path to the directory containing the Dockerfile - dockerfile: Dockerfile - volumes: - # Mount local code into the container for development hot-reloading - # The code inside the container at /app will mirror your local ./be directory - - ./be:/app - ports: - - "8000:8000" # Map container port 8000 to host port 8000 - environment: - # Pass the database URL to the backend container - # Uses the service name 'db' as the host, and credentials defined above - # IMPORTANT: Use the correct async driver prefix if your app needs it! - - DATABASE_URL=postgresql+asyncpg://dev_user:dev_password@db:5432/dev_db - # Add other environment variables needed by the backend here - # - SOME_OTHER_VAR=some_value - depends_on: - db: # Wait for the db service to be healthy before starting backend - condition: service_healthy - command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] # Override CMD for development reload - restart: unless-stopped - - pgadmin: # Optional service for database administration - image: dpage/pgadmin4:latest - container_name: pgadmin4_server - environment: - PGADMIN_DEFAULT_EMAIL: admin@example.com # Change as needed - PGADMIN_DEFAULT_PASSWORD: admin_password # Change to a secure password - PGADMIN_CONFIG_SERVER_MODE: 'False' # Run in Desktop mode for easier local dev server setup - volumes: - - pgadmin_data:/var/lib/pgadmin # Persist pgAdmin configuration - ports: - - "5050:80" # Map container port 80 to host port 5050 - depends_on: - - db # Depends on the database service - restart: unless-stopped - -volumes: # Define named volumes for data persistence - postgres_data: - pgadmin_data: - - - -[*.{js,jsx,mjs,cjs,ts,tsx,mts,cts,vue}] -charset = utf-8 -indent_size = 2 -indent_style = space -end_of_line = lf -insert_final_newline = true -trim_trailing_whitespace = true - - - -{ - "$schema": "https://json.schemastore.org/prettierrc", - "singleQuote": true, - "printWidth": 100 -} - - - -{ - "recommendations": [ - "dbaeumer.vscode-eslint", - "esbenp.prettier-vscode", - "editorconfig.editorconfig", - "vue.volar", - "wayou.vscode-todo-highlight" - ], - "unwantedRecommendations": [ - "octref.vetur", - "hookyqr.beautify", - "dbaeumer.jshint", - "ms-vscode.vscode-typescript-tslint-plugin" - ] -} - - - -{ - "editor.bracketPairColorization.enabled": true, - "editor.guides.bracketPairs": true, - "editor.formatOnSave": true, - "editor.defaultFormatter": "esbenp.prettier-vscode", - "editor.codeActionsOnSave": [ - "source.fixAll.eslint" - ], - "eslint.validate": [ - "javascript", - "javascriptreact", - "typescript", - "vue" - ], - "typescript.tsdk": "node_modules/typescript/lib" -} - - - -import js from '@eslint/js' -import globals from 'globals' -import pluginVue from 'eslint-plugin-vue' -import pluginQuasar from '@quasar/app-vite/eslint' -import { defineConfigWithVueTs, vueTsConfigs } from '@vue/eslint-config-typescript' -import prettierSkipFormatting from '@vue/eslint-config-prettier/skip-formatting' - -export default defineConfigWithVueTs( - { - /** - * Ignore the following files. - * Please note that pluginQuasar.configs.recommended() already ignores - * the "node_modules" folder for you (and all other Quasar project - * relevant folders and files). - * - * ESLint requires "ignores" key to be the only one in this object - */ - // ignores: [] - }, - - pluginQuasar.configs.recommended(), - js.configs.recommended, - - /** - * https://eslint.vuejs.org - * - * pluginVue.configs.base - * -> Settings and rules to enable correct ESLint parsing. - * pluginVue.configs[ 'flat/essential'] - * -> base, plus rules to prevent errors or unintended behavior. - * pluginVue.configs["flat/strongly-recommended"] - * -> Above, plus rules to considerably improve code readability and/or dev experience. - * pluginVue.configs["flat/recommended"] - * -> Above, plus rules to enforce subjective community defaults to ensure consistency. - */ - pluginVue.configs[ 'flat/essential' ], - - { - files: ['**/*.ts', '**/*.vue'], - rules: { - '@typescript-eslint/consistent-type-imports': [ - 'error', - { prefer: 'type-imports' } - ], - } - }, - // https://github.com/vuejs/eslint-config-typescript - vueTsConfigs.recommendedTypeChecked, - - { - languageOptions: { - ecmaVersion: 'latest', - sourceType: 'module', - - globals: { - ...globals.browser, - ...globals.node, // SSR, Electron, config files - process: 'readonly', // process.env.* - ga: 'readonly', // Google Analytics - cordova: 'readonly', - Capacitor: 'readonly', - chrome: 'readonly', // BEX related - browser: 'readonly' // BEX related - } - }, - - // add your custom rules here - rules: { - 'prefer-promise-reject-errors': 'off', - - // allow debugger during development only - 'no-debugger': process.env.NODE_ENV === 'production' ? 'error' : 'off' - } - }, - - { - files: [ 'src-pwa/custom-service-worker.ts' ], - languageOptions: { - globals: { - ...globals.serviceworker - } - } - }, - - prettierSkipFormatting -) - - - - - - - <%= productName %> - - - - - - - - - - - - - - - - - - - - -// https://github.com/michael-ciniawsky/postcss-load-config - -import autoprefixer from 'autoprefixer' -// import rtlcss from 'postcss-rtlcss' - -export default { - plugins: [ - // https://github.com/postcss/autoprefixer - autoprefixer({ - overrideBrowserslist: [ - 'last 4 Chrome versions', - 'last 4 Firefox versions', - 'last 4 Edge versions', - 'last 4 Safari versions', - 'last 4 Android versions', - 'last 4 ChromeAndroid versions', - 'last 4 FirefoxAndroid versions', - 'last 4 iOS versions' - ] - }), - - // https://github.com/elchininet/postcss-rtlcss - // If you want to support RTL css, then - // 1. yarn/pnpm/bun/npm install postcss-rtlcss - // 2. optionally set quasar.config.js > framework > lang to an RTL language - // 3. uncomment the following line (and its import statement above): - // rtlcss() - ] -} - - - - - - - -{ - "orientation": "portrait", - "background_color": "#ffffff", - "theme_color": "#027be3", - "icons": [ - { - "src": "icons/icon-128x128.png", - "sizes": "128x128", - "type": "image/png" - }, - { - "src": "icons/icon-192x192.png", - "sizes": "192x192", - "type": "image/png" - }, - { - "src": "icons/icon-256x256.png", - "sizes": "256x256", - "type": "image/png" - }, - { - "src": "icons/icon-384x384.png", - "sizes": "384x384", - "type": "image/png" - }, - { - "src": "icons/icon-512x512.png", - "sizes": "512x512", - "type": "image/png" - } - ] -} - - - -declare namespace NodeJS { - interface ProcessEnv { - SERVICE_WORKER_FILE: string; - PWA_FALLBACK_HTML: string; - PWA_SERVICE_WORKER_REGEX: string; - } -} - - - -import { register } from 'register-service-worker'; - -// The ready(), registered(), cached(), updatefound() and updated() -// events passes a ServiceWorkerRegistration instance in their arguments. -// ServiceWorkerRegistration: https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerRegistration - -register(process.env.SERVICE_WORKER_FILE, { - // The registrationOptions object will be passed as the second argument - // to ServiceWorkerContainer.register() - // https://developer.mozilla.org/en-US/docs/Web/API/ServiceWorkerContainer/register#Parameter - - // registrationOptions: { scope: './' }, - - ready (/* registration */) { - // console.log('Service worker is active.') - }, - - registered (/* registration */) { - // console.log('Service worker has been registered.') - }, - - cached (/* registration */) { - // console.log('Content has been cached for offline use.') - }, - - updatefound (/* registration */) { - // console.log('New content is downloading.') - }, - - updated (/* registration */) { - // console.log('New content is available; please refresh.') - }, - - offline () { - // console.log('No internet connection found. App is running in offline mode.') - }, - - error (/* err */) { - // console.error('Error during service worker registration:', err) - }, -}); - - - -{ - "extends": "../tsconfig.json", - "compilerOptions": { - "lib": ["WebWorker", "ESNext"] - }, - "include": ["*.ts", "*.d.ts"] -} - - - - - - - - - - - - - - - - - - - - - -import { defineBoot } from '#q-app/wrappers'; -import { createI18n } from 'vue-i18n'; - -import messages from 'src/i18n'; - -export type MessageLanguages = keyof typeof messages; -// Type-define 'en-US' as the master schema for the resource -export type MessageSchema = typeof messages['en-US']; - -// See https://vue-i18n.intlify.dev/guide/advanced/typescript.html#global-resource-schema-type-definition -/* eslint-disable @typescript-eslint/no-empty-object-type */ -declare module 'vue-i18n' { - // define the locale messages schema - export interface DefineLocaleMessage extends MessageSchema {} - - // define the datetime format schema - export interface DefineDateTimeFormat {} - - // define the number format schema - export interface DefineNumberFormat {} -} -/* eslint-enable @typescript-eslint/no-empty-object-type */ - -export default defineBoot(({ app }) => { - const i18n = createI18n<{ message: MessageSchema }, MessageLanguages>({ - locale: 'en-US', - legacy: false, - messages, - }); - - // Set i18n instance on app - app.use(i18n); -}); - - - - - - - - - -export interface Todo { - id: number; - content: string; -} - -export interface Meta { - totalCount: number; -} - - - -import { api } from 'boot/axios'; -import { API_BASE_URL, API_VERSION, API_ENDPOINTS } from './api-config'; - -// Helper function to get full API URL -export const getApiUrl = (endpoint: string): string => { - return `${API_BASE_URL}/api/${API_VERSION}${endpoint}`; -}; - -// Helper function to make API calls -export const apiClient = { - get: (endpoint: string, config = {}) => api.get(getApiUrl(endpoint), config), - post: (endpoint: string, data = {}, config = {}) => api.post(getApiUrl(endpoint), data, config), - put: (endpoint: string, data = {}, config = {}) => api.put(getApiUrl(endpoint), data, config), - patch: (endpoint: string, data = {}, config = {}) => api.patch(getApiUrl(endpoint), data, config), - delete: (endpoint: string, config = {}) => api.delete(getApiUrl(endpoint), config), -}; - -export { API_ENDPOINTS }; - - - -// app global css in SCSS form - - - -// Quasar SCSS (& Sass) Variables -// -------------------------------------------------- -// To customize the look and feel of this app, you can override -// the Sass/SCSS variables found in Quasar's source Sass/SCSS files. - -// Check documentation for full list of Quasar variables - -// Your own variables (that are declared here) and Quasar's own -// ones will be available out of the box in your .vue/.scss/.sass files - -// It's highly recommended to change the default colors -// to match your app's branding. -// Tip: Use the "Theme Builder" on Quasar's documentation website. - -$primary : #1976D2; -$secondary : #26A69A; -$accent : #9C27B0; - -$dark : #1D1D1D; -$dark-page : #121212; - -$positive : #21BA45; -$negative : #C10015; -$info : #31CCEC; -$warning : #F2C037; - - - -declare namespace NodeJS { - interface ProcessEnv { - NODE_ENV: string; - VUE_ROUTER_MODE: 'hash' | 'history' | 'abstract' | undefined; - VUE_ROUTER_BASE: string | undefined; - } -} - - - -// This is just an example, -// so you can safely delete all default props below - -export default { - failed: 'Action failed', - success: 'Action was successful' -}; - - - -import enUS from './en-US'; - -export default { - 'en-US': enUS -}; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -import { defineRouter } from '#q-app/wrappers'; -import { - createMemoryHistory, - createRouter, - createWebHashHistory, - createWebHistory, -} from 'vue-router'; -import routes from './routes'; -import { useAuthStore } from 'stores/auth'; - -/* - * If not building with SSR mode, you can - * directly export the Router instantiation; - * - * The function below can be async too; either use - * async/await or return a Promise which resolves - * with the Router instance. - */ - -export default defineRouter(function (/* { store, ssrContext } */) { - const createHistory = process.env.SERVER - ? createMemoryHistory - : process.env.VUE_ROUTER_MODE === 'history' - ? createWebHistory - : createWebHashHistory; - - const Router = createRouter({ - scrollBehavior: () => ({ left: 0, top: 0 }), - routes, - - // Leave this as is and make changes in quasar.conf.js instead! - // quasar.conf.js -> build -> vueRouterMode - // quasar.conf.js -> build -> publicPath - history: createHistory(process.env.VUE_ROUTER_BASE), - }); - - // Navigation guard to check authentication - Router.beforeEach((to, from, next) => { - const authStore = useAuthStore(); - const isAuthenticated = authStore.isAuthenticated; - - // Define public routes that don't require authentication - const publicRoutes = ['/login', '/signup']; - - // Check if the route requires authentication - const requiresAuth = !publicRoutes.includes(to.path); - - if (requiresAuth && !isAuthenticated) { - // Redirect to login if trying to access protected route without authentication - next({ path: '/login', query: { redirect: to.fullPath } }); - } else if (!requiresAuth && isAuthenticated) { - // Redirect to home if trying to access login/signup while authenticated - next({ path: '/' }); - } else { - // Proceed with navigation - next(); - } - }); - - return Router; -}); - - - -import { defineStore } from '#q-app/wrappers' -import { createPinia } from 'pinia' - -/* - * When adding new properties to stores, you should also - * extend the `PiniaCustomProperties` interface. - * @see https://pinia.vuejs.org/core-concepts/plugins.html#typing-new-store-properties - */ -declare module 'pinia' { - // eslint-disable-next-line @typescript-eslint/no-empty-object-type - export interface PiniaCustomProperties { - // add your custom properties here, if any - } -} - -/* - * If not building with SSR mode, you can - * directly export the Store instantiation; - * - * The function below can be async too; either use - * async/await or return a Promise which resolves - * with the Store instance. - */ - -export default defineStore((/* { ssrContext } */) => { - const pinia = createPinia() - - // You can add Pinia plugins here - // pinia.use(SomePiniaPlugin) - - return pinia -}) - - - -# app/api/dependencies.py -import logging -from typing import Optional - -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.ext.asyncio import AsyncSession -from jose import JWTError - -from app.database import get_db -from app.core.security import verify_access_token -from app.crud import user as crud_user -from app.models import User as UserModel # Import the SQLAlchemy model -from app.config import settings - -logger = logging.getLogger(__name__) - -# Define the OAuth2 scheme -# tokenUrl should point to your login endpoint relative to the base path -# It's used by Swagger UI for the "Authorize" button flow. -oauth2_scheme = OAuth2PasswordBearer(tokenUrl=settings.OAUTH2_TOKEN_URL) - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_db) -) -> UserModel: - """ - Dependency to get the current user based on the JWT token. - - - Extracts token using OAuth2PasswordBearer. - - Verifies the token (signature, expiry). - - Fetches the user from the database based on the token's subject (email). - - Raises HTTPException 401 if any step fails. - - Returns: - The authenticated user's database model instance. - """ - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.AUTH_CREDENTIALS_ERROR, - headers={settings.AUTH_HEADER_NAME: settings.AUTH_HEADER_PREFIX}, - ) - - payload = verify_access_token(token) - if payload is None: - logger.warning("Token verification failed (invalid, expired, or malformed).") - raise credentials_exception - - email: Optional[str] = payload.get("sub") - if email is None: - logger.error("Token payload missing 'sub' (subject/email).") - raise credentials_exception # Token is malformed - - # Fetch user from database - user = await crud_user.get_user_by_email(db, email=email) - if user is None: - logger.warning(f"User corresponding to token subject not found: {email}") - # Could happen if user deleted after token issuance - raise credentials_exception # Treat as invalid credentials - - logger.debug(f"Authenticated user retrieved: {user.email} (ID: {user.id})") - return user - -# Optional: Dependency for getting the *active* current user -# You might add an `is_active` flag to your User model later -# async def get_current_active_user( -# current_user: UserModel = Depends(get_current_user) -# ) -> UserModel: -# if not current_user.is_active: # Assuming an is_active attribute -# logger.warning(f"Authentication attempt by inactive user: {current_user.email}") -# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user") -# return current_user - - - -# app/api/v1/endpoints/groups.py -import logging -from typing import List - -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.ext.asyncio import AsyncSession - -from app.database import get_db -from app.api.dependencies import get_current_user -from app.models import User as UserModel, UserRoleEnum # Import model and enum -from app.schemas.group import GroupCreate, GroupPublic -from app.schemas.invite import InviteCodePublic -from app.schemas.message import Message # For simple responses -from app.crud import group as crud_group -from app.crud import invite as crud_invite -from app.core.exceptions import ( - GroupNotFoundError, - GroupPermissionError, - GroupMembershipError, - GroupOperationError, - GroupValidationError, - InviteCreationError -) - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.post( - "", # Route relative to prefix "/groups" - response_model=GroupPublic, - status_code=status.HTTP_201_CREATED, - summary="Create New Group", - tags=["Groups"] -) -async def create_group( - group_in: GroupCreate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Creates a new group, adding the creator as the owner.""" - logger.info(f"User {current_user.email} creating group: {group_in.name}") - created_group = await crud_group.create_group(db=db, group_in=group_in, creator_id=current_user.id) - # Load members explicitly if needed for the response (optional here) - # created_group = await crud_group.get_group_by_id(db, created_group.id) - return created_group - - -@router.get( - "", # Route relative to prefix "/groups" - response_model=List[GroupPublic], - summary="List User's Groups", - tags=["Groups"] -) -async def read_user_groups( - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Retrieves all groups the current user is a member of.""" - logger.info(f"Fetching groups for user: {current_user.email}") - groups = await crud_group.get_user_groups(db=db, user_id=current_user.id) - return groups - - -@router.get( - "/{group_id}", - response_model=GroupPublic, - summary="Get Group Details", - tags=["Groups"] -) -async def read_group( - group_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Retrieves details for a specific group, including members, if the user is part of it.""" - logger.info(f"User {current_user.email} requesting details for group ID: {group_id}") - # Check if user is a member first - is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id) - if not is_member: - logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}") - raise GroupMembershipError(group_id, "view group details") - - group = await crud_group.get_group_by_id(db=db, group_id=group_id) - if not group: - logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)") - raise GroupNotFoundError(group_id) - - return group - - -@router.post( - "/{group_id}/invites", - response_model=InviteCodePublic, - summary="Create Group Invite", - tags=["Groups", "Invites"] -) -async def create_group_invite( - group_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Generates a new invite code for the group. Requires owner/admin role (MVP: owner only).""" - logger.info(f"User {current_user.email} attempting to create invite for group {group_id}") - user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) - - # --- Permission Check (MVP: Owner only) --- - if user_role != UserRoleEnum.owner: - logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}") - raise GroupPermissionError(group_id, "create invites") - - # Check if group exists (implicitly done by role check, but good practice) - group = await crud_group.get_group_by_id(db, group_id) - if not group: - raise GroupNotFoundError(group_id) - - invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id) - if not invite: - logger.error(f"Failed to generate unique invite code for group {group_id}") - raise InviteCreationError(group_id) - - logger.info(f"User {current_user.email} created invite code for group {group_id}") - return invite - -@router.delete( - "/{group_id}/leave", - response_model=Message, - summary="Leave Group", - tags=["Groups"] -) -async def leave_group( - group_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Removes the current user from the specified group.""" - logger.info(f"User {current_user.email} attempting to leave group {group_id}") - user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) - - if user_role is None: - raise GroupMembershipError(group_id, "leave (you are not a member)") - - # --- MVP: Prevent owner leaving if they are the last member/owner --- - if user_role == UserRoleEnum.owner: - member_count = await crud_group.get_group_member_count(db, group_id) - # More robust check: count owners. For now, just check member count. - if member_count <= 1: - logger.warning(f"Owner {current_user.email} attempted to leave group {group_id} as last member.") - raise GroupValidationError("Owner cannot leave the group as the last member. Delete the group or transfer ownership.") - - # Proceed with removal - deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id) - - if not deleted: - # Should not happen if role check passed, but handle defensively - logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.") - raise GroupOperationError("Failed to leave group") - - logger.info(f"User {current_user.email} successfully left group {group_id}") - return Message(detail="Successfully left the group") - -# --- Optional: Remove Member Endpoint --- -@router.delete( - "/{group_id}/members/{user_id_to_remove}", - response_model=Message, - summary="Remove Member From Group (Owner Only)", - tags=["Groups"] -) -async def remove_group_member( - group_id: int, - user_id_to_remove: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Removes a specified user from the group. Requires current user to be owner.""" - logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}") - owner_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) - - # --- Permission Check --- - if owner_role != UserRoleEnum.owner: - logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}") - raise GroupPermissionError(group_id, "remove members") - - # Prevent owner removing themselves via this endpoint - if current_user.id == user_id_to_remove: - raise GroupValidationError("Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.") - - # Check if target user is actually in the group - target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove) - if target_role is None: - raise GroupMembershipError(group_id, "remove this user (they are not a member)") - - # Proceed with removal - deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove) - - if not deleted: - logger.error(f"Owner {current_user.email} failed to remove user {user_id_to_remove} from group {group_id}.") - raise GroupOperationError("Failed to remove member") - - logger.info(f"Owner {current_user.email} successfully removed user {user_id_to_remove} from group {group_id}") - return Message(detail="Successfully removed member from the group") - - - -# app/api/v1/endpoints/health.py -import logging -from fastapi import APIRouter, Depends -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import text - -from app.database import get_db -from app.schemas.health import HealthStatus -from app.core.exceptions import DatabaseConnectionError - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.get( - "/health", - response_model=HealthStatus, - summary="Perform a Health Check", - description="Checks the operational status of the API and its connection to the database.", - tags=["Health"] -) -async def check_health(db: AsyncSession = Depends(get_db)): - """ - Health check endpoint. Verifies API reachability and database connection. - """ - try: - # Try executing a simple query to check DB connection - result = await db.execute(text("SELECT 1")) - if result.scalar_one() == 1: - logger.info("Health check successful: Database connection verified.") - return HealthStatus(status="ok", database="connected") - else: - # This case should ideally not happen with 'SELECT 1' - logger.error("Health check failed: Database connection check returned unexpected result.") - raise DatabaseConnectionError("Unexpected result from database connection check") - - except Exception as e: - logger.error(f"Health check failed: Database connection error - {e}", exc_info=True) - raise DatabaseConnectionError(str(e)) - - - -# app/api/v1/endpoints/invites.py -import logging -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.ext.asyncio import AsyncSession - -from app.database import get_db -from app.api.dependencies import get_current_user -from app.models import User as UserModel, UserRoleEnum -from app.schemas.invite import InviteAccept -from app.schemas.message import Message -from app.crud import invite as crud_invite -from app.crud import group as crud_group -from app.core.exceptions import ( - InviteNotFoundError, - InviteExpiredError, - InviteAlreadyUsedError, - InviteCreationError, - GroupNotFoundError, - GroupMembershipError -) - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.post( - "/accept", # Route relative to prefix "/invites" - response_model=Message, - summary="Accept Group Invite", - tags=["Invites"] -) -async def accept_invite( - invite_in: InviteAccept, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Accepts a group invite using the provided invite code.""" - logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.invite_code}") - - # Get the invite - invite = await crud_invite.get_invite_by_code(db, invite_code=invite_in.invite_code) - if not invite: - logger.warning(f"Invalid invite code attempted by user {current_user.email}: {invite_in.invite_code}") - raise InviteNotFoundError(invite_in.invite_code) - - # Check if invite is expired - if invite.is_expired(): - logger.warning(f"Expired invite code attempted by user {current_user.email}: {invite_in.invite_code}") - raise InviteExpiredError(invite_in.invite_code) - - # Check if invite has already been used - if invite.used_at: - logger.warning(f"Already used invite code attempted by user {current_user.email}: {invite_in.invite_code}") - raise InviteAlreadyUsedError(invite_in.invite_code) - - # Check if group still exists - group = await crud_group.get_group_by_id(db, group_id=invite.group_id) - if not group: - logger.error(f"Group {invite.group_id} not found for invite {invite_in.invite_code}") - raise GroupNotFoundError(invite.group_id) - - # Check if user is already a member - is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id) - if is_member: - logger.warning(f"User {current_user.email} already a member of group {invite.group_id}") - raise GroupMembershipError(invite.group_id, "join (already a member)") - - # Add user to group and mark invite as used - success = await crud_invite.accept_invite(db, invite=invite, user_id=current_user.id) - if not success: - logger.error(f"Failed to accept invite {invite_in.invite_code} for user {current_user.email}") - raise InviteCreationError(invite.group_id) - - logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.invite_code}") - return Message(detail="Successfully joined the group") - - - -# app/core/security.py -from datetime import datetime, timedelta, timezone -from typing import Any, Union, Optional - -from jose import JWTError, jwt -from passlib.context import CryptContext - -from app.config import settings # Import settings from config - -# --- Password Hashing --- - -# Configure passlib context -# Using bcrypt as the default hashing scheme -# 'deprecated="auto"' will automatically upgrade hashes if needed on verification -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """ - Verifies a plain text password against a hashed password. - - Args: - plain_password: The password attempt. - hashed_password: The stored hash from the database. - - Returns: - True if the password matches the hash, False otherwise. - """ - try: - return pwd_context.verify(plain_password, hashed_password) - except Exception: - # Handle potential errors during verification (e.g., invalid hash format) - return False - -def hash_password(password: str) -> str: - """ - Hashes a plain text password using the configured context (bcrypt). - - Args: - password: The plain text password to hash. - - Returns: - The resulting hash string. - """ - return pwd_context.hash(password) - - -# --- JSON Web Tokens (JWT) --- - -def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: - """ - Creates a JWT access token. - - Args: - subject: The subject of the token (e.g., user ID or email). - expires_delta: Optional timedelta object for token expiry. If None, - uses ACCESS_TOKEN_EXPIRE_MINUTES from settings. - - Returns: - The encoded JWT access token string. - """ - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES - ) - - # Data to encode in the token payload - to_encode = {"exp": expire, "sub": str(subject), "type": "access"} - - encoded_jwt = jwt.encode( - to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM - ) - return encoded_jwt - -def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: - """ - Creates a JWT refresh token. - - Args: - subject: The subject of the token (e.g., user ID or email). - expires_delta: Optional timedelta object for token expiry. If None, - uses REFRESH_TOKEN_EXPIRE_MINUTES from settings. - - Returns: - The encoded JWT refresh token string. - """ - if expires_delta: - expire = datetime.now(timezone.utc) + expires_delta - else: - expire = datetime.now(timezone.utc) + timedelta( - minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES - ) - - # Data to encode in the token payload - to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"} - - encoded_jwt = jwt.encode( - to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM - ) - return encoded_jwt - -def verify_access_token(token: str) -> Optional[dict]: - """ - Verifies a JWT access token and returns its payload if valid. - - Args: - token: The JWT token string to verify. - - Returns: - The decoded token payload (dict) if the token is valid and not expired, - otherwise None. - """ - try: - # Decode the token. This also automatically verifies: - # - Signature (using SECRET_KEY and ALGORITHM) - # - Expiration ('exp' claim) - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - if payload.get("type") != "access": - raise JWTError("Invalid token type") - return payload - except JWTError as e: - # Handles InvalidSignatureError, ExpiredSignatureError, etc. - print(f"JWT Error: {e}") # Log the error for debugging - return None - except Exception as e: - # Handle other potential unexpected errors during decoding - print(f"Unexpected error decoding JWT: {e}") - return None - -def verify_refresh_token(token: str) -> Optional[dict]: - """ - Verifies a JWT refresh token and returns its payload if valid. - - Args: - token: The JWT token string to verify. - - Returns: - The decoded token payload (dict) if the token is valid, not expired, - and is a refresh token, otherwise None. - """ - try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] - ) - if payload.get("type") != "refresh": - raise JWTError("Invalid token type") - return payload - except JWTError as e: - print(f"JWT Error: {e}") # Log the error for debugging - return None - except Exception as e: - print(f"Unexpected error decoding JWT: {e}") - return None - -# You might add a function here later to extract the 'sub' (subject/user id) -# specifically, often used in dependency injection for authentication. -# def get_subject_from_token(token: str) -> Optional[str]: -# payload = verify_access_token(token) -# if payload: -# return payload.get("sub") -# return None - - - -# app/crud/user.py -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError -from typing import Optional - -from app.models import User as UserModel # Alias to avoid name clash -from app.schemas.user import UserCreate -from app.core.security import hash_password -from app.core.exceptions import ( - UserCreationError, - EmailAlreadyRegisteredError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError -) - -async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]: - """Fetches a user from the database by email.""" - try: - async with db.begin(): - result = await db.execute(select(UserModel).filter(UserModel.email == email)) - return result.scalars().first() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query user: {str(e)}") - -async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: - """Creates a new user record in the database.""" - try: - async with db.begin(): - _hashed_password = hash_password(user_in.password) - db_user = UserModel( - email=user_in.email, - password_hash=_hashed_password, - name=user_in.name - ) - db.add(db_user) - await db.flush() # Flush to get DB-generated values - await db.refresh(db_user) - return db_user - except IntegrityError as e: - if "unique constraint" in str(e).lower(): - raise EmailAlreadyRegisteredError() - raise DatabaseIntegrityError(f"Failed to create user: {str(e)}") - except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to create user: {str(e)}") - - - -# app/schemas/auth.py -from pydantic import BaseModel, EmailStr -from app.config import settings - -class Token(BaseModel): - access_token: str - refresh_token: str # Added refresh token - token_type: str = settings.TOKEN_TYPE # Use configured token type - -# Optional: If you preferred not to use OAuth2PasswordRequestForm -# class UserLogin(BaseModel): -# email: EmailStr -# password: str - - - -from pydantic import BaseModel, ConfigDict -from typing import List, Optional -from decimal import Decimal - -class UserCostShare(BaseModel): - user_id: int - user_identifier: str # Name or email - items_added_value: Decimal = Decimal("0.00") # Total value of items this user added - amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users) - balance: Decimal # items_added_value - amount_due - - model_config = ConfigDict(from_attributes=True) - -class ListCostSummary(BaseModel): - list_id: int - list_name: str - total_list_cost: Decimal - num_participating_users: int - equal_share_per_user: Decimal - user_balances: List[UserCostShare] - - model_config = ConfigDict(from_attributes=True) - -class UserBalanceDetail(BaseModel): - user_id: int - user_identifier: str # Name or email - total_paid_for_expenses: Decimal = Decimal("0.00") - total_share_of_expenses: Decimal = Decimal("0.00") - total_settlements_paid: Decimal = Decimal("0.00") - total_settlements_received: Decimal = Decimal("0.00") - net_balance: Decimal = Decimal("0.00") # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid) - model_config = ConfigDict(from_attributes=True) - -class SuggestedSettlement(BaseModel): - from_user_id: int - from_user_identifier: str # Name or email of payer - to_user_id: int - to_user_identifier: str # Name or email of payee - amount: Decimal - model_config = ConfigDict(from_attributes=True) - -class GroupBalanceSummary(BaseModel): - group_id: int - group_name: str - overall_total_expenses: Decimal = Decimal("0.00") - overall_total_settlements: Decimal = Decimal("0.00") - user_balances: List[UserBalanceDetail] - # Optional: Could add a list of suggested settlements to zero out balances - suggested_settlements: Optional[List[SuggestedSettlement]] = None - model_config = ConfigDict(from_attributes=True) - -# class SuggestedSettlement(BaseModel): -# from_user_id: int -# to_user_id: int -# amount: Decimal - - - -# app/schemas/health.py -from pydantic import BaseModel -from app.config import settings - -class HealthStatus(BaseModel): - """ - Response model for the health check endpoint. - """ - status: str = settings.HEALTH_STATUS_OK # Use configured default value - database: str - - - -# app/schemas/item.py -from pydantic import BaseModel, ConfigDict -from datetime import datetime -from typing import Optional -from decimal import Decimal - -# Properties to return to client -class ItemPublic(BaseModel): - id: int - list_id: int - name: str - quantity: Optional[str] = None - is_complete: bool - price: Optional[Decimal] = None - added_by_id: int - completed_by_id: Optional[int] = None - created_at: datetime - updated_at: datetime - version: int - model_config = ConfigDict(from_attributes=True) - -# Properties to receive via API on creation -class ItemCreate(BaseModel): - name: str - quantity: Optional[str] = None - # list_id will be from path param - # added_by_id will be from current_user - -# Properties to receive via API on update -class ItemUpdate(BaseModel): - name: Optional[str] = None - quantity: Optional[str] = None - is_complete: Optional[bool] = None - price: Optional[Decimal] = None # Price added here for update - version: int - # completed_by_id will be set internally if is_complete is true - - - -# app/schemas/list.py -from pydantic import BaseModel, ConfigDict -from datetime import datetime -from typing import Optional, List - -from .item import ItemPublic # Import item schema for nesting - -# Properties to receive via API on creation -class ListCreate(BaseModel): - name: str - description: Optional[str] = None - group_id: Optional[int] = None # Optional for sharing - -# Properties to receive via API on update -class ListUpdate(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - is_complete: Optional[bool] = None - version: int # Client must provide the version for updates - # Potentially add group_id update later if needed - -# Base properties returned by API (common fields) -class ListBase(BaseModel): - id: int - name: str - description: Optional[str] = None - created_by_id: int - group_id: Optional[int] = None - is_complete: bool - created_at: datetime - updated_at: datetime - version: int # Include version in responses - - model_config = ConfigDict(from_attributes=True) - -# Properties returned when listing lists (no items) -class ListPublic(ListBase): - pass # Inherits all from ListBase - -# Properties returned for a single list detail (includes items) -class ListDetail(ListBase): - items: List[ItemPublic] = [] # Include list of items - -class ListStatus(BaseModel): - list_updated_at: datetime - latest_item_updated_at: Optional[datetime] = None # Can be null if list has no items - item_count: int - - - -# app/schemas/user.py -from pydantic import BaseModel, EmailStr, ConfigDict -from datetime import datetime -from typing import Optional - -# Shared properties -class UserBase(BaseModel): - email: EmailStr - name: Optional[str] = None - -# Properties to receive via API on creation -class UserCreate(UserBase): - password: str - -# Properties to receive via API on update (optional, add later if needed) -# class UserUpdate(UserBase): -# password: Optional[str] = None - -# Properties stored in DB -class UserInDBBase(UserBase): - id: int - password_hash: str - created_at: datetime - model_config = ConfigDict(from_attributes=True) # Use orm_mode in Pydantic v1 - -# Additional properties to return via API (excluding password) -class UserPublic(UserBase): - id: int - created_at: datetime - model_config = ConfigDict(from_attributes=True) - -# Full user model including hashed password (for internal use/reading from DB) -class User(UserInDBBase): - pass - - - -.DS_Store -.thumbs.db -node_modules - -# Quasar core related directories -.quasar -/dist -/quasar.config.*.temporary.compiled* - -# Cordova related directories and files -/src-cordova/node_modules -/src-cordova/platforms -/src-cordova/plugins -/src-cordova/www - -# Capacitor related directories and files -/src-capacitor/www -/src-capacitor/node_modules - -# Log files -npm-debug.log* -yarn-debug.log* -yarn-error.log* - -# Editor directories and files -.idea -*.suo -*.ntvs* -*.njsproj -*.sln - -# local .env files -.env.local* - - - -# pnpm-related options -shamefully-hoist=true -strict-peer-dependencies=false -# to get the latest compatible packages when creating the project https://github.com/pnpm/pnpm/issues/6463 -resolution-mode=highest - - - -// Configuration for your app -// https://v2.quasar.dev/quasar-cli-vite/quasar-config-file - -import { defineConfig } from '#q-app/wrappers'; -import { fileURLToPath } from 'node:url'; - -export default defineConfig((ctx) => { - return { - // https://v2.quasar.dev/quasar-cli-vite/prefetch-feature - // preFetch: true, - - // app boot file (/src/boot) - // --> boot files are part of "main.js" - // https://v2.quasar.dev/quasar-cli-vite/boot-files - boot: ['i18n', 'axios'], - - // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#css - css: ['app.scss'], - - // https://github.com/quasarframework/quasar/tree/dev/extras - extras: [ - // 'ionicons-v4', - // 'mdi-v7', - // 'fontawesome-v6', - // 'eva-icons', - // 'themify', - // 'line-awesome', - // 'roboto-font-latin-ext', // this or either 'roboto-font', NEVER both! - - 'roboto-font', // optional, you are not bound to it - 'material-icons', // optional, you are not bound to it - ], - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#build - build: { - target: { - browser: ['es2022', 'firefox115', 'chrome115', 'safari14'], - node: 'node20', - }, - - typescript: { - strict: true, - vueShim: true, - // extendTsConfig (tsConfig) {} - }, - - vueRouterMode: 'hash', // available values: 'hash', 'history' - // vueRouterBase, - // vueDevtools, - // vueOptionsAPI: false, - - // rebuildCache: true, // rebuilds Vite/linter/etc cache on startup - - // publicPath: '/', - // analyze: true, - // env: {}, - // rawDefine: {} - // ignorePublicFolder: true, - // minify: false, - // polyfillModulePreload: true, - // distDir - - // extendViteConf (viteConf) {}, - // viteVuePluginOptions: {}, - - vitePlugins: [ - [ - '@intlify/unplugin-vue-i18n/vite', - { - // if you want to use Vue I18n Legacy API, you need to set `compositionOnly: false` - // compositionOnly: false, - - // if you want to use named tokens in your Vue I18n messages, such as 'Hello {name}', - // you need to set `runtimeOnly: false` - // runtimeOnly: false, - - ssr: ctx.modeName === 'ssr', - - // you need to set i18n resource including paths ! - include: [fileURLToPath(new URL('./src/i18n', import.meta.url))], - }, - ], - - [ - 'vite-plugin-checker', - { - vueTsc: true, - eslint: { - lintCommand: 'eslint -c ./eslint.config.js "./src*/**/*.{ts,js,mjs,cjs,vue}"', - useFlatConfig: true, - }, - }, - { server: false }, - ], - ], - }, - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#devserver - devServer: { - // https: true, - open: true, // opens browser window automatically - }, - - // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#framework - framework: { - config: {}, - - // iconSet: 'material-icons', // Quasar icon set - // lang: 'en-US', // Quasar language pack - - // For special cases outside of where the auto-import strategy can have an impact - // (like functional components as one of the examples), - // you can manually specify Quasar components/directives to be available everywhere: - // - // components: [], - // directives: [], - - // Quasar plugins - plugins: ['Notify'], - }, - - // animations: 'all', // --- includes all animations - // https://v2.quasar.dev/options/animations - animations: [], - - // https://v2.quasar.dev/quasar-cli-vite/quasar-config-file#sourcefiles - // sourceFiles: { - // rootComponent: 'src/App.vue', - // router: 'src/router/index', - // store: 'src/store/index', - // pwaRegisterServiceWorker: 'src-pwa/register-service-worker', - // pwaServiceWorker: 'src-pwa/custom-service-worker', - // pwaManifestFile: 'src-pwa/manifest.json', - // electronMain: 'src-electron/electron-main', - // electronPreload: 'src-electron/electron-preload' - // bexManifestFile: 'src-bex/manifest.json - // }, - - // https://v2.quasar.dev/quasar-cli-vite/developing-ssr/configuring-ssr - ssr: { - prodPort: 3000, // The default port that the production server should use - // (gets superseded if process.env.PORT is specified at runtime) - - middlewares: [ - 'render', // keep this as last one - ], - - // extendPackageJson (json) {}, - // extendSSRWebserverConf (esbuildConf) {}, - - // manualStoreSerialization: true, - // manualStoreSsrContextInjection: true, - // manualStoreHydration: true, - // manualPostHydrationTrigger: true, - - pwa: false, - // pwaOfflineHtmlFilename: 'offline.html', // do NOT use index.html as name! - - // pwaExtendGenerateSWOptions (cfg) {}, - // pwaExtendInjectManifestOptions (cfg) {} - }, - - // https://v2.quasar.dev/quasar-cli-vite/developing-pwa/configuring-pwa - pwa: { - workboxMode: 'InjectManifest', // Changed from 'GenerateSW' to 'InjectManifest' - swFilename: 'sw.js', - manifestFilename: 'manifest.json', - injectPwaMetaTags: true, - // extendManifestJson (json) {}, - // useCredentialsForManifestTag: true, - // extendPWACustomSWConf (esbuildConf) {}, - // extendGenerateSWOptions (cfg) {}, - // extendInjectManifestOptions (cfg) {} - }, - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-cordova-apps/configuring-cordova - cordova: { - // noIosLegacyBuildFlag: true, // uncomment only if you know what you are doing - }, - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-capacitor-apps/configuring-capacitor - capacitor: { - hideSplashscreen: true, - }, - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-electron-apps/configuring-electron - electron: { - // extendElectronMainConf (esbuildConf) {}, - // extendElectronPreloadConf (esbuildConf) {}, - - // extendPackageJson (json) {}, - - // Electron preload scripts (if any) from /src-electron, WITHOUT file extension - preloadScripts: ['electron-preload'], - - // specify the debugging port to use for the Electron app when running in development mode - inspectPort: 5858, - - bundler: 'packager', // 'packager' or 'builder' - - packager: { - // https://github.com/electron-userland/electron-packager/blob/master/docs/api.md#options - // OS X / Mac App Store - // appBundleId: '', - // appCategoryType: '', - // osxSign: '', - // protocol: 'myapp://path', - // Windows only - // win32metadata: { ... } - }, - - builder: { - // https://www.electron.build/configuration/configuration - - appId: 'mitlist', - }, - }, - - // Full list of options: https://v2.quasar.dev/quasar-cli-vite/developing-browser-extensions/configuring-bex - bex: { - // extendBexScriptsConf (esbuildConf) {}, - // extendBexManifestJson (json) {}, - - /** - * The list of extra scripts (js/ts) not in your bex manifest that you want to - * compile and use in your browser extension. Maybe dynamic use them? - * - * Each entry in the list should be a relative filename to /src-bex/ - * - * @example [ 'my-script.ts', 'sub-folder/my-other-script.js' ] - */ - extraScripts: [], - }, - }; -}); - - - -# mitlist (mitlist) - -mitlist pwa - -## Install the dependencies -```bash -yarn -# or -npm install -``` - -### Start the app in development mode (hot-code reloading, error reporting, etc.) -```bash -quasar dev -``` - - -### Lint the files -```bash -yarn lint -# or -npm run lint -``` - - -### Format the files -```bash -yarn format -# or -npm run format -``` - - -### Build the app for production -```bash -quasar build -``` - -### Customize the configuration -See [Configuring quasar.config.js](https://v2.quasar.dev/quasar-cli-vite/quasar-config-js). - - - -/* - * This file (which will be your service worker) - * is picked up by the build system ONLY if - * quasar.config file > pwa > workboxMode is set to "InjectManifest" - */ - -declare const self: ServiceWorkerGlobalScope & - typeof globalThis & { skipWaiting: () => Promise }; - -import { clientsClaim } from 'workbox-core'; -import { - precacheAndRoute, - cleanupOutdatedCaches, - createHandlerBoundToURL, -} from 'workbox-precaching'; -import { registerRoute, NavigationRoute } from 'workbox-routing'; -import { CacheFirst, NetworkFirst } from 'workbox-strategies'; -import { ExpirationPlugin } from 'workbox-expiration'; -import { CacheableResponsePlugin } from 'workbox-cacheable-response'; -import type { WorkboxPlugin } from 'workbox-core/types'; - -self.skipWaiting().catch((error) => { - console.error('Error during service worker activation:', error); -}); -clientsClaim(); - -// Use with precache injection -precacheAndRoute(self.__WB_MANIFEST); - -cleanupOutdatedCaches(); - -// Cache app shell and static assets with Cache First strategy -registerRoute( - // Match static assets - ({ request }) => - request.destination === 'style' || - request.destination === 'script' || - request.destination === 'image' || - request.destination === 'font', - new CacheFirst({ - cacheName: 'static-assets', - plugins: [ - new CacheableResponsePlugin({ - statuses: [0, 200], - }) as WorkboxPlugin, - new ExpirationPlugin({ - maxEntries: 60, - maxAgeSeconds: 30 * 24 * 60 * 60, // 30 days - }) as WorkboxPlugin, - ], - }) -); - -// Cache API calls with Network First strategy -registerRoute( - // Match API calls - ({ url }) => url.pathname.startsWith('/api/'), - new NetworkFirst({ - cacheName: 'api-cache', - plugins: [ - new CacheableResponsePlugin({ - statuses: [0, 200], - }) as WorkboxPlugin, - new ExpirationPlugin({ - maxEntries: 50, - maxAgeSeconds: 24 * 60 * 60, // 24 hours - }) as WorkboxPlugin, - ], - }) -); - -// Non-SSR fallbacks to index.html -// Production SSR fallbacks to offline.html (except for dev) -if (process.env.MODE !== 'ssr' || process.env.PROD) { - registerRoute( - new NavigationRoute(createHandlerBoundToURL(process.env.PWA_FALLBACK_HTML), { - denylist: [new RegExp(process.env.PWA_SERVICE_WORKER_REGEX), /workbox-(.)*\.js$/], - }), - ); -} - - - - - - - - - - - - - - - - - - - - - - - - - - - -import type { RouteRecordRaw } from 'vue-router'; - -const routes: RouteRecordRaw[] = [ - { - path: '/', - component: () => import('layouts/MainLayout.vue'), - children: [ - { path: '', redirect: '/lists' }, - { path: 'lists', name: 'PersonalLists', component: () => import('pages/ListsPage.vue') }, - { - path: 'lists/:id', - name: 'ListDetail', - component: () => import('pages/ListDetailPage.vue'), - props: true, - }, - { path: 'groups', name: 'GroupsList', component: () => import('pages/GroupsPage.vue') }, - { - path: 'groups/:id', - name: 'GroupDetail', - component: () => import('pages/GroupDetailPage.vue'), - props: true, - }, - { - path: 'groups/:groupId/lists', - name: 'GroupLists', - component: () => import('pages/ListsPage.vue'), - props: true, - }, - { path: 'account', name: 'Account', component: () => import('pages/AccountPage.vue') }, - ], - }, - - { - path: '/', - component: () => import('layouts/AuthLayout.vue'), - children: [ - { path: 'login', component: () => import('pages/LoginPage.vue') }, - { path: 'signup', component: () => import('pages/SignupPage.vue') }, - ], - }, - - // Always leave this as last one, - // but you can also remove it - { - path: '/:catchAll(.*)*', - component: () => import('pages/ErrorNotFound.vue'), - }, -]; - -export default routes; - - - -import { defineStore } from 'pinia'; -import { ref, computed } from 'vue'; -import { apiClient, API_ENDPOINTS } from 'src/config/api'; - -interface AuthState { - accessToken: string | null; - refreshToken: string | null; - user: { - email: string; - name: string; - } | null; -} - -export const useAuthStore = defineStore('auth', () => { - // State - const accessToken = ref(localStorage.getItem('token')); - const refreshToken = ref(localStorage.getItem('refresh_token')); - const user = ref(null); - - // Getters - const isAuthenticated = computed(() => !!accessToken.value); - const getUser = computed(() => user.value); - - // Actions - const setTokens = (tokens: { access_token: string; refresh_token: string }) => { - accessToken.value = tokens.access_token; - refreshToken.value = tokens.refresh_token; - localStorage.setItem('token', tokens.access_token); - localStorage.setItem('refresh_token', tokens.refresh_token); - }; - - const clearTokens = () => { - accessToken.value = null; - refreshToken.value = null; - user.value = null; - localStorage.removeItem('token'); - localStorage.removeItem('refresh_token'); - }; - - const setUser = (userData: AuthState['user']) => { - user.value = userData; - }; - - const login = async (email: string, password: string) => { - const formData = new FormData(); - formData.append('username', email); - formData.append('password', password); - - const response = await apiClient.post(API_ENDPOINTS.AUTH.LOGIN, formData, { - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - }, - }); - - const { access_token, refresh_token } = response.data; - setTokens({ access_token, refresh_token }); - return response.data; - }; - - const signup = async (userData: { name: string; email: string; password: string }) => { - const response = await apiClient.post(API_ENDPOINTS.AUTH.SIGNUP, userData); - return response.data; - }; - - const refreshAccessToken = async () => { - if (!refreshToken.value) { - throw new Error('No refresh token available'); - } - - try { - const response = await apiClient.post(API_ENDPOINTS.AUTH.REFRESH_TOKEN, { - refresh_token: refreshToken.value, - }); - - const { access_token, refresh_token } = response.data; - setTokens({ access_token, refresh_token }); - return response.data; - } catch (error) { - clearTokens(); - throw error; - } - }; - - const logout = () => { - clearTokens(); - }; - - return { - // State - accessToken, - refreshToken, - user, - // Getters - isAuthenticated, - getUser, - // Actions - setTokens, - clearTokens, - setUser, - login, - signup, - refreshAccessToken, - logout, - }; -}); - - - -import { defineStore } from 'pinia'; -import { ref, computed } from 'vue'; -import { useQuasar } from 'quasar'; -import { LocalStorage } from 'quasar'; - -export interface OfflineAction { - id: string; - type: 'add' | 'complete' | 'update' | 'delete'; - itemId?: string; - data: unknown; - timestamp: number; - version?: number; -} - -export interface ConflictResolution { - version: 'local' | 'server' | 'merge'; - action: OfflineAction; -} - -export interface ConflictData { - localVersion: { - data: Record; - timestamp: number; - }; - serverVersion: { - data: Record; - timestamp: number; - }; - action: OfflineAction; -} - -export const useOfflineStore = defineStore('offline', () => { - const $q = useQuasar(); - const isOnline = ref(navigator.onLine); - const pendingActions = ref([]); - const isProcessingQueue = ref(false); - const showConflictDialog = ref(false); - const currentConflict = ref(null); - - // Initialize from IndexedDB - const init = () => { - try { - const stored = LocalStorage.getItem('offline-actions'); - if (stored) { - pendingActions.value = JSON.parse(stored as string); - } - } catch (error) { - console.error('Failed to load offline actions:', error); - } - }; - - // Save to IndexedDB - const saveToStorage = () => { - try { - LocalStorage.set('offline-actions', JSON.stringify(pendingActions.value)); - } catch (error) { - console.error('Failed to save offline actions:', error); - } - }; - - // Add a new offline action - const addAction = (action: Omit) => { - const newAction: OfflineAction = { - ...action, - id: crypto.randomUUID(), - timestamp: Date.now(), - }; - pendingActions.value.push(newAction); - saveToStorage(); - }; - - // Process the queue when online - const processQueue = async () => { - if (isProcessingQueue.value || !isOnline.value) return; - - isProcessingQueue.value = true; - const actions = [...pendingActions.value]; - - for (const action of actions) { - try { - await processAction(action); - pendingActions.value = pendingActions.value.filter(a => a.id !== action.id); - saveToStorage(); - } catch (error) { - if (error instanceof Error && error.message.includes('409')) { - $q.notify({ - type: 'warning', - message: 'Item was modified by someone else while you were offline. Please review.', - actions: [ - { - label: 'Review', - color: 'white', - handler: () => { - // TODO: Implement conflict resolution UI - } - } - ] - }); - } else { - console.error('Failed to process offline action:', error); - } - } - } - - isProcessingQueue.value = false; - }; - - // Process a single action - const processAction = async (action: OfflineAction) => { - TODO: Implement actual API calls - switch (action.type) { - case 'add': - // await api.addItem(action.data); - break; - case 'complete': - // await api.completeItem(action.itemId, action.data); - break; - case 'update': - // await api.updateItem(action.itemId, action.data); - break; - case 'delete': - // await api.deleteItem(action.itemId); - break; - } - }; - - // Listen for online/offline status changes - const setupNetworkListeners = () => { - window.addEventListener('online', () => { - (async () => { - isOnline.value = true; - await processQueue(); - })().catch(error => { - console.error('Error processing queue:', error); - }); - }); - - window.addEventListener('offline', () => { - isOnline.value = false; - }); - }; - - // Computed properties - const hasPendingActions = computed(() => pendingActions.value.length > 0); - const pendingActionCount = computed(() => pendingActions.value.length); - - // Initialize - init(); - setupNetworkListeners(); - - const handleConflictResolution = (resolution: ConflictResolution) => { - // Implement the logic to handle the conflict resolution - console.log('Conflict resolution:', resolution); - }; - - return { - isOnline, - pendingActions, - hasPendingActions, - pendingActionCount, - showConflictDialog, - currentConflict, - addAction, - processQueue, - handleConflictResolution, - }; -}); - - - -{ - "extends": "./.quasar/tsconfig.json" -} - - - -# app/api/v1/endpoints/auth.py -import logging -from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.ext.asyncio import AsyncSession - -from app.database import get_db -from app.schemas.user import UserCreate, UserPublic -from app.schemas.auth import Token -from app.crud import user as crud_user -from app.core.security import ( - verify_password, - create_access_token, - create_refresh_token, - verify_refresh_token -) -from app.core.exceptions import ( - EmailAlreadyRegisteredError, - InvalidCredentialsError, - UserCreationError -) -from app.config import settings - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.post( - "/signup", - response_model=UserPublic, - status_code=201, - summary="Register New User", - description="Creates a new user account.", - tags=["Authentication"] -) -async def signup( - user_in: UserCreate, - db: AsyncSession = Depends(get_db) -): - """ - Handles user registration. - - Validates input data. - - Checks if email already exists. - - Hashes the password. - - Stores the new user in the database. - """ - logger.info(f"Signup attempt for email: {user_in.email}") - existing_user = await crud_user.get_user_by_email(db, email=user_in.email) - if existing_user: - logger.warning(f"Signup failed: Email already registered - {user_in.email}") - raise EmailAlreadyRegisteredError() - - try: - created_user = await crud_user.create_user(db=db, user_in=user_in) - logger.info(f"User created successfully: {created_user.email} (ID: {created_user.id})") - return created_user - except Exception as e: - logger.error(f"Error during user creation for {user_in.email}: {e}", exc_info=True) - raise UserCreationError() - -@router.post( - "/login", - response_model=Token, - summary="User Login", - description="Authenticates a user and returns an access and refresh token.", - tags=["Authentication"] -) -async def login( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: AsyncSession = Depends(get_db) -): - """ - Handles user login. - - Finds user by email (provided in 'username' field of form). - - Verifies the provided password against the stored hash. - - Generates and returns JWT access and refresh tokens upon successful authentication. - """ - logger.info(f"Login attempt for user: {form_data.username}") - user = await crud_user.get_user_by_email(db, email=form_data.username) - - if not user or not verify_password(form_data.password, user.password_hash): - logger.warning(f"Login failed: Invalid credentials for user {form_data.username}") - raise InvalidCredentialsError() - - access_token = create_access_token(subject=user.email) - refresh_token = create_refresh_token(subject=user.email) - logger.info(f"Login successful, tokens generated for user: {user.email}") - return Token( - access_token=access_token, - refresh_token=refresh_token, - token_type=settings.TOKEN_TYPE - ) - -@router.post( - "/refresh", - response_model=Token, - summary="Refresh Access Token", - description="Refreshes an access token using a refresh token.", - tags=["Authentication"] -) -async def refresh_token( - refresh_token_str: str, - db: AsyncSession = Depends(get_db) -): - """ - Handles access token refresh. - - Verifies the provided refresh token. - - If valid, generates and returns a new JWT access token and the same refresh token. - """ - logger.info("Access token refresh attempt") - payload = verify_refresh_token(refresh_token_str) - if not payload: - logger.warning("Refresh token invalid or expired") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired refresh token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - user_email = payload.get("sub") - if not user_email: - logger.error("User email not found in refresh token payload") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token payload", - headers={"WWW-Authenticate": "Bearer"}, - ) - - new_access_token = create_access_token(subject=user_email) - logger.info(f"Access token refreshed for user: {user_email}") - return Token( - access_token=new_access_token, - refresh_token=refresh_token_str, - token_type=settings.TOKEN_TYPE - ) - - - -# app/api/v1/endpoints/lists.py -import logging -from typing import List as PyList, Optional # Alias for Python List type hint - -from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query -from sqlalchemy.ext.asyncio import AsyncSession - -from app.database import get_db -from app.api.dependencies import get_current_user -from app.models import User as UserModel -from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail -from app.schemas.message import Message # For simple responses -from app.crud import list as crud_list -from app.crud import group as crud_group # Need for group membership check -from app.schemas.list import ListStatus -from app.core.exceptions import ( - GroupMembershipError, - ListNotFoundError, - ListPermissionError, - ListStatusNotFoundError, - ConflictError # Added ConflictError -) - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.post( - "", # Route relative to prefix "/lists" - response_model=ListPublic, # Return basic list info on creation - status_code=status.HTTP_201_CREATED, - summary="Create New List", - tags=["Lists"] -) -async def create_list( - list_in: ListCreate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Creates a new shopping list. - - If `group_id` is provided, the user must be a member of that group. - - If `group_id` is null, it's a personal list. - """ - logger.info(f"User {current_user.email} creating list: {list_in.name}") - group_id = list_in.group_id - - # Permission Check: If sharing with a group, verify membership - if group_id: - is_member = await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id) - if not is_member: - logger.warning(f"User {current_user.email} attempted to create list in group {group_id} but is not a member.") - raise GroupMembershipError(group_id, "create lists") - - created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id) - logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.") - return created_list - - -@router.get( - "", # Route relative to prefix "/lists" - response_model=PyList[ListPublic], # Return a list of basic list info - summary="List Accessible Lists", - tags=["Lists"] -) -async def read_lists( - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), - # Add pagination parameters later if needed: skip: int = 0, limit: int = 100 -): - """ - Retrieves lists accessible to the current user: - - Personal lists created by the user. - - Lists belonging to groups the user is a member of. - """ - logger.info(f"Fetching lists accessible to user: {current_user.email}") - lists = await crud_list.get_lists_for_user(db=db, user_id=current_user.id) - return lists - - -@router.get( - "/{list_id}", - response_model=ListDetail, # Return detailed list info including items - summary="Get List Details", - tags=["Lists"] -) -async def read_list( - list_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Retrieves details for a specific list, including its items, - if the user has permission (creator or group member). - """ - logger.info(f"User {current_user.email} requesting details for list ID: {list_id}") - # The check_list_permission function will raise appropriate exceptions - list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - return list_db - - -@router.put( - "/{list_id}", - response_model=ListPublic, # Return updated basic info - summary="Update List", - tags=["Lists"], - responses={ # Add 409 to responses - status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"} - } -) -async def update_list( - list_id: int, - list_in: ListUpdate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Updates a list's details (name, description, is_complete). - Requires user to be the creator or a member of the list's group. - The client MUST provide the current `version` of the list in the `list_in` payload. - If the version does not match, a 409 Conflict is returned. - """ - logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}") - list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - - try: - updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) - logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.") - return updated_list - except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response - logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}") - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) - except Exception as e: # Catch other potential errors from crud operation - logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}") - # Consider a more generic error, but for now, let's keep it specific if possible - # Re-raising might be better if crud layer already raises appropriate HTTPExceptions - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.") - - -@router.delete( - "/{list_id}", - status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body - summary="Delete List", - tags=["Lists"], - responses={ # Add 409 to responses - status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"} - } -) -async def delete_list( - list_id: int, - expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."), - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Deletes a list. Requires user to be the creator of the list. - If `expected_version` is provided and does not match the list's current version, - a 409 Conflict is returned. - """ - logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}") - # Use the helper, requiring creator permission - list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True) - - if expected_version is not None and list_db.version != expected_version: - logger.warning( - f"Conflict deleting list {list_id} for user {current_user.email}. " - f"Expected version {expected_version}, actual version {list_db.version}." - ) - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh." - ) - - await crud_list.delete_list(db=db, list_db=list_db) - logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.") - return Response(status_code=status.HTTP_204_NO_CONTENT) - - -@router.get( - "/{list_id}/status", - response_model=ListStatus, - summary="Get List Status", - tags=["Lists"] -) -async def read_list_status( - list_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Retrieves the last update time for the list and its items, plus item count. - Used for polling to check if a full refresh is needed. - Requires user to have permission to view the list. - """ - # Verify user has access to the list first - list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - if not list_db: - # Check if list exists at all for correct error code - exists = await crud_list.get_list_by_id(db, list_id) - if not exists: - raise ListNotFoundError(list_id) - raise ListPermissionError(list_id, "access this list's status") - - # Fetch the status details - list_status = await crud_list.get_list_status(db=db, list_id=list_id) - if not list_status: - # Should not happen if check_list_permission passed, but handle defensively - logger.error(f"Could not retrieve status for list {list_id} even though permission check passed.") - raise ListStatusNotFoundError(list_id) - - return list_status - - - -import logging -from typing import List - -from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status -from google.api_core import exceptions as google_exceptions - -from app.api.dependencies import get_current_user -from app.models import User as UserModel -from app.schemas.ocr import OcrExtractResponse -from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error, GeminiOCRService -from app.core.exceptions import ( - OCRServiceUnavailableError, - OCRServiceConfigError, - OCRUnexpectedError, - OCRQuotaExceededError, - InvalidFileTypeError, - FileTooLargeError, - OCRProcessingError -) -from app.config import settings - -logger = logging.getLogger(__name__) -router = APIRouter() -ocr_service = GeminiOCRService() - -@router.post( - "/extract-items", - response_model=OcrExtractResponse, - summary="Extract List Items via OCR (Gemini)", - tags=["OCR"] -) -async def ocr_extract_items( - current_user: UserModel = Depends(get_current_user), - image_file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP) of the shopping list or receipt."), -): - """ - Accepts an image upload, sends it to Gemini Flash with a prompt - to extract shopping list items, and returns the parsed items. - """ - # Check if Gemini client initialized correctly - if gemini_initialization_error: - logger.error("OCR endpoint called but Gemini client failed to initialize.") - raise OCRServiceUnavailableError(gemini_initialization_error) - - logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.") - - # --- File Validation --- - if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES: - logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}") - raise InvalidFileTypeError() - - # Simple size check - contents = await image_file.read() - if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024: - logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes") - raise FileTooLargeError() - - try: - # Call the Gemini helper function - extracted_items = await extract_items_from_image_gemini( - image_bytes=contents, - mime_type=image_file.content_type - ) - - logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") - return OcrExtractResponse(extracted_items=extracted_items) - - except OCRServiceUnavailableError: - raise OCRServiceUnavailableError() - except OCRServiceConfigError: - raise OCRServiceConfigError() - except OCRQuotaExceededError: - raise OCRQuotaExceededError() - except Exception as e: - raise OCRProcessingError(str(e)) - - finally: - # Ensure file handle is closed - await image_file.close() - - - -# app/core/gemini.py -import logging -from typing import List -import google.generativeai as genai -from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings -from google.api_core import exceptions as google_exceptions -from app.config import settings -from app.core.exceptions import ( - OCRServiceUnavailableError, - OCRServiceConfigError, - OCRUnexpectedError, - OCRQuotaExceededError -) - -logger = logging.getLogger(__name__) - -# --- Global variable to hold the initialized model client --- -gemini_flash_client = None -gemini_initialization_error = None # Store potential init error - -# --- Configure and Initialize --- -try: - if settings.GEMINI_API_KEY: - genai.configure(api_key=settings.GEMINI_API_KEY) - # Initialize the specific model we want to use - gemini_flash_client = genai.GenerativeModel( - model_name=settings.GEMINI_MODEL_NAME, - # Safety settings from config - safety_settings={ - getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold) - for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items() - }, - # Generation config from settings - generation_config=genai.types.GenerationConfig( - **settings.GEMINI_GENERATION_CONFIG - ) - ) - logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.") - else: - # Store error if API key is missing - gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized." - logger.error(gemini_initialization_error) - -except Exception as e: - # Catch any other unexpected errors during initialization - gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}" - logger.exception(gemini_initialization_error) # Log full traceback - gemini_flash_client = None # Ensure client is None on error - - -# --- Function to get the client (optional, allows checking error) --- -def get_gemini_client(): - """ - Returns the initialized Gemini client instance. - Raises an exception if initialization failed. - """ - if gemini_initialization_error: - raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}") - if gemini_flash_client is None: - # This case should ideally be covered by the check above, but as a safeguard: - raise RuntimeError("Gemini client is not available (unknown initialization issue).") - return gemini_flash_client - -# Define the prompt as a constant -OCR_ITEM_EXTRACTION_PROMPT = """ -Extract the shopping list items from this image. -List each distinct item on a new line. -Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. -Focus only on the names of the products or items to be purchased. -If the image does not appear to be a shopping list or receipt, state that clearly. -Example output for a grocery list: -Milk -Eggs -Bread -Apples -Organic Bananas -""" - -async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]: - """ - Uses Gemini Flash to extract shopping list items from image bytes. - - Args: - image_bytes: The image content as bytes. - mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp"). - - Returns: - A list of extracted item strings. - - Raises: - RuntimeError: If the Gemini client is not initialized. - google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.). - ValueError: If the response is blocked or contains no usable text. - """ - client = get_gemini_client() # Raises RuntimeError if not initialized - - # Prepare image part for multimodal input - image_part = { - "mime_type": mime_type, - "data": image_bytes - } - - # Prepare the full prompt content - prompt_parts = [ - settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first - image_part # Then the image - ] - - logger.info("Sending image to Gemini for item extraction...") - try: - # Make the API call - # Use generate_content_async for async FastAPI - response = await client.generate_content_async(prompt_parts) - - # --- Process the response --- - # Check for safety blocks or lack of content - if not response.candidates or not response.candidates[0].content.parts: - logger.warning("Gemini response blocked or empty.", extra={"response": response}) - # Check finish_reason if available - finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN' - safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A' - if finish_reason == 'SAFETY': - raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") - else: - raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}") - - # Extract text - assumes the first part of the first candidate is the text response - raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text - logger.info("Received raw text from Gemini.") - # logger.debug(f"Gemini Raw Text:\n{raw_text}") # Optional: Log full response text - - # Parse the text response - items = [] - for line in raw_text.splitlines(): # Split by newline - cleaned_line = line.strip() # Remove leading/trailing whitespace - # Basic filtering: ignore empty lines and potential non-item lines - if cleaned_line and len(cleaned_line) > 1: # Ignore very short lines too? - # Add more sophisticated filtering if needed (e.g., regex, keyword check) - items.append(cleaned_line) - - logger.info(f"Extracted {len(items)} potential items.") - return items - - except google_exceptions.GoogleAPIError as e: - logger.error(f"Gemini API Error: {e}", exc_info=True) - # Re-raise specific Google API errors for endpoint to handle (e.g., quota) - raise e - except Exception as e: - # Catch other unexpected errors during generation or processing - logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) - # Wrap in a generic ValueError or re-raise - raise ValueError(f"Failed to process image with Gemini: {e}") from e - -class GeminiOCRService: - def __init__(self): - try: - genai.configure(api_key=settings.GEMINI_API_KEY) - self.model = genai.GenerativeModel(settings.GEMINI_MODEL_NAME) - self.model.safety_settings = settings.GEMINI_SAFETY_SETTINGS - self.model.generation_config = settings.GEMINI_GENERATION_CONFIG - except Exception as e: - logger.error(f"Failed to initialize Gemini client: {e}") - raise OCRServiceConfigError() - - async def extract_items(self, image_data: bytes) -> List[str]: - """ - Extract shopping list items from an image using Gemini Vision. - """ - try: - # Create image part - image_parts = [{"mime_type": "image/jpeg", "data": image_data}] - - # Generate content - response = await self.model.generate_content_async( - contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts] - ) - - # Process response - if not response.text: - raise OCRUnexpectedError() - - # Split response into lines and clean up - items = [ - item.strip() - for item in response.text.split("\n") - if item.strip() and not item.strip().startswith("Example") - ] - - return items - - except Exception as e: - logger.error(f"Error during OCR extraction: {e}") - if "quota" in str(e).lower(): - raise OCRQuotaExceededError() - raise OCRServiceUnavailableError() - - - -# be/Dockerfile - -# Choose a suitable Python base image -FROM python:3.11-slim - -# Set environment variables -ENV PYTHONDONTWRITEBYTECODE 1 # Prevent python from writing pyc files -ENV PYTHONUNBUFFERED 1 # Keep stdout/stderr unbuffered - -# Set the working directory in the container -WORKDIR /app - -# Install system dependencies if needed (e.g., for psycopg2 build) -# RUN apt-get update && apt-get install -y --no-install-recommends gcc build-essential libpq-dev && rm -rf /var/lib/apt/lists/* - -# Install Python dependencies -# Upgrade pip first -RUN pip install --no-cache-dir --upgrade pip -# Copy only requirements first to leverage Docker cache -COPY requirements.txt requirements.txt -# Install dependencies -RUN pip install --no-cache-dir -r requirements.txt - -# Copy the rest of the application code into the working directory -COPY . . -# This includes your 'app/' directory, alembic.ini, etc. - -# Expose the port the app runs on -EXPOSE 8000 - -# Command to run the application using uvicorn -# The default command for production (can be overridden in docker-compose for development) -# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance -# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct. -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] - - - -fastapi>=0.95.0 -uvicorn[standard]>=0.20.0 -sqlalchemy[asyncio]>=2.0.0 # Core ORM + Async support -asyncpg>=0.27.0 # Async PostgreSQL driver -psycopg2-binary>=2.9.0 # Often needed by Alembic even if app uses asyncpg -alembic>=1.9.0 # Database migrations -pydantic-settings>=2.0.0 # For loading settings from .env -python-dotenv>=1.0.0 # To load .env file for scripts/alembic -passlib[bcrypt]>=1.7.4 -python-jose[cryptography]>=3.3.0 -pydantic[email] -google-generativeai>=0.5.0 - - - -{ - "name": "mitlist", - "version": "0.0.1", - "description": "mitlist pwa", - "productName": "mitlist", - "author": "Mohamad ", - "type": "module", - "private": true, - "scripts": { - "lint": "eslint -c ./eslint.config.js \"./src*/**/*.{ts,js,cjs,mjs,vue}\"", - "format": "prettier --write \"**/*.{js,ts,vue,scss,html,md,json}\" --ignore-path .gitignore", - "test": "echo \"No test specified\" && exit 0", - "dev": "quasar dev", - "build": "quasar build", - "postinstall": "quasar prepare" - }, - "dependencies": { - "@quasar/extras": "^1.16.4", - "axios": "^1.2.1", - "pinia": "^3.0.1", - "quasar": "^2.16.0", - "register-service-worker": "^1.7.2", - "vue": "^3.4.18", - "vue-i18n": "^11.0.0", - "vue-router": "^4.0.12" - }, - "devDependencies": { - "@eslint/js": "^9.14.0", - "@intlify/unplugin-vue-i18n": "^4.0.0", - "@quasar/app-vite": "^2.1.0", - "@types/node": "^20.5.9", - "@vue/eslint-config-prettier": "^10.1.0", - "@vue/eslint-config-typescript": "^14.4.0", - "autoprefixer": "^10.4.2", - "eslint": "^9.14.0", - "eslint-plugin-vue": "^9.30.0", - "globals": "^15.12.0", - "prettier": "^3.3.3", - "typescript": "~5.5.3", - "vite-plugin-checker": "^0.9.0", - "vue-tsc": "^2.0.29", - "workbox-build": "^7.3.0", - "workbox-cacheable-response": "^7.3.0", - "workbox-core": "^7.3.0", - "workbox-expiration": "^7.3.0", - "workbox-precaching": "^7.3.0", - "workbox-routing": "^7.3.0", - "workbox-strategies": "^7.3.0" - }, - "engines": { - "node": "^28 || ^26 || ^24 || ^22 || ^20 || ^18", - "npm": ">= 6.13.4", - "yarn": ">= 1.21.1" - } -} - - - -import { boot } from 'quasar/wrappers'; -import axios from 'axios'; -import { API_BASE_URL } from 'src/config/api-config'; - -// Create axios instance -const api = axios.create({ - baseURL: API_BASE_URL, - headers: { - 'Content-Type': 'application/json', - }, -}); - -// Request interceptor -api.interceptors.request.use( - (config) => { - const token = localStorage.getItem('token'); - if (token) { - config.headers.Authorization = `Bearer ${token}`; - } - return config; - }, - (error) => { - return Promise.reject(new Error(String(error))); - } -); - -// Response interceptor -api.interceptors.response.use( - (response) => response, - async (error) => { - const originalRequest = error.config; - - // If error is 401 and we haven't tried to refresh token yet - if (error.response?.status === 401 && !originalRequest._retry) { - originalRequest._retry = true; - - try { - const refreshToken = localStorage.getItem('refreshToken'); - if (!refreshToken) { - throw new Error('No refresh token available'); - } - - // Call refresh token endpoint - const response = await api.post('/api/v1/auth/refresh-token', { - refresh_token: refreshToken, - }); - - const { access_token } = response.data; - localStorage.setItem('token', access_token); - - // Retry the original request with new token - originalRequest.headers.Authorization = `Bearer ${access_token}`; - return api(originalRequest); - } catch (refreshError) { - // If refresh token fails, clear storage and redirect to login - localStorage.removeItem('token'); - localStorage.removeItem('refreshToken'); - window.location.href = '/login'; - return Promise.reject(new Error(String(refreshError))); - } - } - - return Promise.reject(new Error(String(error))); - } -); - -export default boot(({ app }) => { - app.config.globalProperties.$axios = axios; - app.config.globalProperties.$api = api; -}); - -export { api }; - - - - - - - - - -# app/api/v1/endpoints/costs.py -import logging -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session, selectinload -from decimal import Decimal, ROUND_HALF_UP - -from app.database import get_db -from app.api.dependencies import get_current_user -from app.models import ( - User as UserModel, - Group as GroupModel, - List as ListModel, - Expense as ExpenseModel, - Item as ItemModel, - UserGroup as UserGroupModel, - SplitTypeEnum, - ExpenseSplit as ExpenseSplitModel, - Settlement as SettlementModel -) -from app.schemas.cost import ListCostSummary, GroupBalanceSummary -from app.schemas.expense import ExpenseCreate -from app.crud import list as crud_list -from app.crud import expense as crud_expense -from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError - -logger = logging.getLogger(__name__) -router = APIRouter() - -@router.get( - "/lists/{list_id}/cost-summary", - response_model=ListCostSummary, - summary="Get Cost Summary for a List", - tags=["Costs"], - responses={ - status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"}, - status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"} - } -) -async def get_list_cost_summary( - list_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Retrieves a calculated cost summary for a specific list, detailing total costs, - equal shares per user, and individual user balances based on their contributions. - - The user must have access to the list to view its cost summary. - Costs are split among group members if the list belongs to a group, or just for - the creator if it's a personal list. All users who added items with prices are - included in the calculation. - """ - logger.info(f"User {current_user.email} requesting cost summary for list {list_id}") - - # 1. Verify user has access to the target list - try: - await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - except ListPermissionError as e: - logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}") - raise - except ListNotFoundError as e: - logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}") - raise - - # 2. Get the list with its items and users - list_result = await db.execute( - select(ListModel) - .options( - selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)), - selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))), - selectinload(ListModel.creator) - ) - .where(ListModel.id == list_id) - ) - db_list = list_result.scalars().first() - if not db_list: - raise ListNotFoundError(list_id) - - # 3. Get or create an expense for this list - expense_result = await db.execute( - select(ExpenseModel) - .where(ExpenseModel.list_id == list_id) - .options(selectinload(ExpenseModel.splits)) - ) - db_expense = expense_result.scalars().first() - - if not db_expense: - # Create a new expense for this list - total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")) - if total_amount == Decimal("0"): - return ListCostSummary( - list_id=db_list.id, - list_name=db_list.name, - total_list_cost=Decimal("0.00"), - num_participating_users=0, - equal_share_per_user=Decimal("0.00"), - user_balances=[] - ) - - # Create expense with ITEM_BASED split type - expense_in = ExpenseCreate( - description=f"Cost summary for list {db_list.name}", - total_amount=total_amount, - list_id=list_id, - split_type=SplitTypeEnum.ITEM_BASED, - paid_by_user_id=current_user.id # Use current user as payer for now - ) - db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in) - - # 4. Calculate cost summary from expense splits - participating_users = set() - user_items_added_value = {} - total_list_cost = Decimal("0.00") - - # Get all users who added items - for item in db_list.items: - if item.price is not None and item.price > Decimal("0") and item.added_by_user: - participating_users.add(item.added_by_user) - user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price - total_list_cost += item.price - - # Get all users from expense splits - for split in db_expense.splits: - if split.user: - participating_users.add(split.user) - - num_participating_users = len(participating_users) - if num_participating_users == 0: - return ListCostSummary( - list_id=db_list.id, - list_name=db_list.name, - total_list_cost=Decimal("0.00"), - num_participating_users=0, - equal_share_per_user=Decimal("0.00"), - user_balances=[] - ) - - equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - remainder = total_list_cost - (equal_share_per_user * num_participating_users) - - user_balances = [] - first_user_processed = False - for user in participating_users: - items_added = user_items_added_value.get(user.id, Decimal("0.00")) - current_user_share = equal_share_per_user - if not first_user_processed and remainder != Decimal("0"): - current_user_share += remainder - first_user_processed = True - - balance = items_added - current_user_share - user_identifier = user.name if user.name else user.email - user_balances.append( - UserCostShare( - user_id=user.id, - user_identifier=user_identifier, - items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - ) - ) - - user_balances.sort(key=lambda x: x.user_identifier) - return ListCostSummary( - list_id=db_list.id, - list_name=db_list.name, - total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - num_participating_users=num_participating_users, - equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - user_balances=user_balances - ) - -@router.get( - "/groups/{group_id}/balance-summary", - response_model=GroupBalanceSummary, - summary="Get Detailed Balance Summary for a Group", - tags=["Costs", "Groups"], - responses={ - status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"}, - status.HTTP_404_NOT_FOUND: {"description": "Group not found"} - } -) -async def get_group_balance_summary( - group_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """ - Retrieves a detailed financial balance summary for all users within a specific group. - It considers all expenses, their splits, and all settlements recorded for the group. - The user must be a member of the group to view its balance summary. - """ - logger.info(f"User {current_user.email} requesting balance summary for group {group_id}") - - # 1. Verify user is a member of the target group - group_check = await db.execute( - select(GroupModel) - .options(selectinload(GroupModel.member_associations)) - .where(GroupModel.id == group_id) - ) - db_group_for_check = group_check.scalars().first() - - if not db_group_for_check: - raise GroupNotFoundError(group_id) - - user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations) - if not user_is_member: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}") - - # 2. Get all expenses and settlements for the group - expenses_result = await db.execute( - select(ExpenseModel) - .where(ExpenseModel.group_id == group_id) - .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user)) - ) - expenses = expenses_result.scalars().all() - - settlements_result = await db.execute( - select(SettlementModel) - .where(SettlementModel.group_id == group_id) - .options( - selectinload(SettlementModel.paid_by_user), - selectinload(SettlementModel.paid_to_user) - ) - ) - settlements = settlements_result.scalars().all() - - # 3. Calculate user balances - user_balances_data = {} - for assoc in db_group_for_check.member_associations: - if assoc.user: - user_balances_data[assoc.user.id] = UserBalanceDetail( - user_id=assoc.user.id, - user_identifier=assoc.user.name if assoc.user.name else assoc.user.email - ) - - # Process expenses - for expense in expenses: - if expense.paid_by_user_id in user_balances_data: - user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount - - for split in expense.splits: - if split.user_id in user_balances_data: - user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount - - # Process settlements - for settlement in settlements: - if settlement.paid_by_user_id in user_balances_data: - user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount - if settlement.paid_to_user_id in user_balances_data: - user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount - - # Calculate net balances - final_user_balances = [] - for user_id, data in user_balances_data.items(): - data.net_balance = ( - data.total_paid_for_expenses + data.total_settlements_received - ) - (data.total_share_of_expenses + data.total_settlements_paid) - - data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - - final_user_balances.append(data) - - # Sort by user identifier - final_user_balances.sort(key=lambda x: x.user_identifier) - - # Calculate suggested settlements - suggested_settlements = calculate_suggested_settlements(final_user_balances) - - return GroupBalanceSummary( - group_id=db_group_for_check.id, - group_name=db_group_for_check.name, - user_balances=final_user_balances, - suggested_settlements=suggested_settlements - ) - - - -# app/api/v1/endpoints/items.py -import logging -from typing import List as PyList, Optional - -from fastapi import APIRouter, Depends, HTTPException, status, Response, Query -from sqlalchemy.ext.asyncio import AsyncSession - -from app.database import get_db -from app.api.dependencies import get_current_user -# --- Import Models Correctly --- -from app.models import User as UserModel -from app.models import Item as ItemModel # <-- IMPORT Item and alias it -# --- End Import Models --- -from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic -from app.crud import item as crud_item -from app.crud import list as crud_list -from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError - -logger = logging.getLogger(__name__) -router = APIRouter() - -# --- Helper Dependency for Item Permissions --- -# Now ItemModel is defined before being used as a type hint -async def get_item_and_verify_access( - item_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user) -) -> ItemModel: - """Dependency to get an item and verify the user has access to its list.""" - item_db = await crud_item.get_item_by_id(db, item_id=item_id) - if not item_db: - raise ItemNotFoundError(item_id) - - # Check permission on the parent list - try: - await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id) - except ListPermissionError as e: - # Re-raise with a more specific message - raise ListPermissionError(item_db.list_id, "access this item's list") - return item_db - - -# --- Endpoints --- - -@router.post( - "/lists/{list_id}/items", # Nested under lists - response_model=ItemPublic, - status_code=status.HTTP_201_CREATED, - summary="Add Item to List", - tags=["Items"] -) -async def create_list_item( - list_id: int, - item_in: ItemCreate, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), -): - """Adds a new item to a specific list. User must have access to the list.""" - user_email = current_user.email # Access email attribute before async operations - logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}") - # Verify user has access to the target list - try: - await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - except ListPermissionError as e: - # Re-raise with a more specific message - raise ListPermissionError(list_id, "add items to this list") - - created_item = await crud_item.create_item( - db=db, item_in=item_in, list_id=list_id, user_id=current_user.id - ) - logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {user_email}.") - return created_item - - -@router.get( - "/lists/{list_id}/items", # Nested under lists - response_model=PyList[ItemPublic], - summary="List Items in List", - tags=["Items"] -) -async def read_list_items( - list_id: int, - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), - # Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc' -): - """Retrieves all items for a specific list if the user has access.""" - user_email = current_user.email # Access email attribute before async operations - logger.info(f"User {user_email} listing items for list {list_id}") - # Verify user has access to the list - try: - await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - except ListPermissionError as e: - # Re-raise with a more specific message - raise ListPermissionError(list_id, "view items in this list") - - items = await crud_item.get_items_by_list_id(db=db, list_id=list_id) - return items - - -@router.put( - "/items/{item_id}", # Operate directly on item ID - response_model=ItemPublic, - summary="Update Item", - tags=["Items"], - responses={ - status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"} - } -) -async def update_item( - item_id: int, # Item ID from path - item_in: ItemUpdate, - item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), # Need user ID for completed_by -): - """ - Updates an item's details (name, quantity, is_complete, price). - User must have access to the list the item belongs to. - The client MUST provide the current `version` of the item in the `item_in` payload. - If the version does not match, a 409 Conflict is returned. - Sets/unsets `completed_by_id` based on `is_complete` flag. - """ - user_email = current_user.email # Access email attribute before async operations - logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}") - # Permission check is handled by get_item_and_verify_access dependency - - try: - updated_item = await crud_item.update_item( - db=db, item_db=item_db, item_in=item_in, user_id=current_user.id - ) - logger.info(f"Item {item_id} updated successfully by user {user_email} to version {updated_item.version}.") - return updated_item - except ConflictError as e: - logger.warning(f"Conflict updating item {item_id} for user {user_email}: {str(e)}") - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) - except Exception as e: - logger.error(f"Error updating item {item_id} for user {user_email}: {str(e)}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.") - - -@router.delete( - "/items/{item_id}", # Operate directly on item ID - status_code=status.HTTP_204_NO_CONTENT, - summary="Delete Item", - tags=["Items"], - responses={ - status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"} - } -) -async def delete_item( - item_id: int, # Item ID from path - expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."), - item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access - db: AsyncSession = Depends(get_db), - current_user: UserModel = Depends(get_current_user), # Log who deleted it -): - """ - Deletes an item. User must have access to the list the item belongs to. - If `expected_version` is provided and does not match the item's current version, - a 409 Conflict is returned. - """ - user_email = current_user.email # Access email attribute before async operations - logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}") - # Permission check is handled by get_item_and_verify_access dependency - - if expected_version is not None and item_db.version != expected_version: - logger.warning( - f"Conflict deleting item {item_id} for user {user_email}. " - f"Expected version {expected_version}, actual version {item_db.version}." - ) - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh." - ) - - await crud_item.delete_item(db=db, item_db=item_db) - logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {user_email}.") - return Response(status_code=status.HTTP_204_NO_CONTENT) - - - -# app/crud/group.py -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload # For eager loading members -from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError -from typing import Optional, List -from sqlalchemy import func - -from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel -from app.schemas.group import GroupCreate -from app.models import UserRoleEnum # Import enum -from app.core.exceptions import ( - GroupOperationError, - GroupNotFoundError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - GroupMembershipError, - GroupPermissionError # Import GroupPermissionError -) - -# --- Group CRUD --- -async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel: - """Creates a group and adds the creator as the owner.""" - try: - async with db.begin(): - db_group = GroupModel(name=group_in.name, created_by_id=creator_id) - db.add(db_group) - await db.flush() - - db_user_group = UserGroupModel( - user_id=creator_id, - group_id=db_group.id, - role=UserRoleEnum.owner - ) - db.add(db_user_group) - await db.flush() - await db.refresh(db_group) - return db_group - except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to create group: {str(e)}") - except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to create group: {str(e)}") - -async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: - """Gets all groups a user is a member of.""" - try: - result = await db.execute( - select(GroupModel) - .join(UserGroupModel) - .where(UserGroupModel.user_id == user_id) - .options(selectinload(GroupModel.member_associations)) - ) - return result.scalars().all() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query user groups: {str(e)}") - -async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]: - """Gets a single group by its ID, optionally loading members.""" - try: - result = await db.execute( - select(GroupModel) - .where(GroupModel.id == group_id) - .options( - selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user) - ) - ) - return result.scalars().first() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query group: {str(e)}") - -async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool: - """Checks if a user is a member of a specific group.""" - try: - result = await db.execute( - select(UserGroupModel.id) - .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - .limit(1) - ) - return result.scalar_one_or_none() is not None - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to check group membership: {str(e)}") - -async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]: - """Gets the role of a user in a specific group.""" - try: - result = await db.execute( - select(UserGroupModel.role) - .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - ) - return result.scalar_one_or_none() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query user role: {str(e)}") - -async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]: - """Adds a user to a group if they aren't already a member.""" - try: - async with db.begin(): - existing = await db.execute( - select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - ) - if existing.scalar_one_or_none(): - return None - - db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role) - db.add(db_user_group) - await db.flush() - await db.refresh(db_user_group) - return db_user_group - except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") - except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}") - -async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool: - """Removes a user from a group.""" - try: - async with db.begin(): - result = await db.execute( - delete(UserGroupModel) - .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - .returning(UserGroupModel.id) - ) - return result.scalar_one_or_none() is not None - except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}") - -async def get_group_member_count(db: AsyncSession, group_id: int) -> int: - """Counts the number of members in a group.""" - try: - result = await db.execute( - select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id) - ) - return result.scalar_one() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to count group members: {str(e)}") - -async def check_group_membership( - db: AsyncSession, - group_id: int, - user_id: int, - action: str = "access this group" -) -> None: - """ - Checks if a user is a member of a group. Raises exceptions if not found or not a member. - - Raises: - GroupNotFoundError: If the group_id does not exist. - GroupMembershipError: If the user_id is not a member of the group. - """ - try: - # Check group existence first - group_exists = await db.get(GroupModel, group_id) - if not group_exists: - raise GroupNotFoundError(group_id) - - # Check membership - membership = await db.execute( - select(UserGroupModel.id) - .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - .limit(1) - ) - if membership.scalar_one_or_none() is None: - raise GroupMembershipError(group_id, action=action) - # If we reach here, the user is a member - return None - except GroupNotFoundError: # Re-raise specific errors - raise - except GroupMembershipError: - raise - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to check group membership: {str(e)}") - -async def check_user_role_in_group( - db: AsyncSession, - group_id: int, - user_id: int, - required_role: UserRoleEnum, - action: str = "perform this action" -) -> None: - """ - Checks if a user is a member of a group and has the required role (or higher). - - Raises: - GroupNotFoundError: If the group_id does not exist. - GroupMembershipError: If the user_id is not a member of the group. - GroupPermissionError: If the user does not have the required role. - """ - # First, ensure user is a member (this also checks group existence) - await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}") - - # Get the user's actual role - actual_role = await get_user_role_in_group(db, group_id, user_id) - - # Define role hierarchy (assuming owner > member) - role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1} - - if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0): - raise GroupPermissionError( - group_id=group_id, - action=f"{action} (requires at least '{required_role.value}' role)" - ) - # If role is sufficient, return None - return None - - - -# app/main.py -import logging -import uvicorn -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -from app.api.api_router import api_router -from app.config import settings -from app.core.api_config import API_METADATA, API_TAGS -# Import database and models if needed for startup/shutdown events later -# from . import database, models - -# --- Logging Setup --- -logging.basicConfig( - level=getattr(logging, settings.LOG_LEVEL), - format=settings.LOG_FORMAT -) -logger = logging.getLogger(__name__) - -# --- FastAPI App Instance --- -app = FastAPI( - **API_METADATA, - openapi_tags=API_TAGS -) - -# --- CORS Middleware --- -# Define allowed origins. Be specific in production! -# Use ["*"] for wide open access during early development if needed, -# but restrict it as soon as possible. -# SvelteKit default dev port is 5173 -origins = [ - "http://localhost:5174", - "http://localhost:8000", # Allow requests from the API itself (e.g., Swagger UI) - # Add your deployed frontend URL here later - # "https://your-frontend-domain.com", -] - -app.add_middleware( - CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -# --- End CORS Middleware --- - - -# --- Include API Routers --- -# All API endpoints will be prefixed with /api -app.include_router(api_router, prefix=settings.API_PREFIX) -# --- End Include API Routers --- - - -# --- Root Endpoint (Optional - outside the main API structure) --- -@app.get("/", tags=["Root"]) -async def read_root(): - """ - Provides a simple welcome message at the root path. - Useful for basic reachability checks. - """ - logger.info("Root endpoint '/' accessed.") - return {"message": settings.ROOT_MESSAGE} -# --- End Root Endpoint --- - - -# --- Application Startup/Shutdown Events (Optional) --- -# @app.on_event("startup") -# async def startup_event(): -# logger.info("Application startup: Connecting to database...") -# # You might perform initial checks or warm-up here -# # await database.engine.connect() # Example check (get_db handles sessions per request) -# logger.info("Application startup complete.") - -# @app.on_event("shutdown") -# async def shutdown_event(): -# logger.info("Application shutdown: Disconnecting from database...") -# # await database.engine.dispose() # Close connection pool -# logger.info("Application shutdown complete.") -# --- End Events --- - - -# --- Direct Run (for simple local testing if needed) --- -# It's better to use `uvicorn app.main:app --reload` from the terminal -# if __name__ == "__main__": -# logger.info("Starting Uvicorn server directly from main.py") -# uvicorn.run(app, host="0.0.0.0", port=8000) -# ------------------------------------------------------ - - - - - - - - - -// API Version -export const API_VERSION = 'v1'; - -// API Base URL -export const API_BASE_URL = import.meta.env.VITE_API_URL || 'http://localhost:8000'; - -// API Endpoints -export const API_ENDPOINTS = { - // Auth - AUTH: { - LOGIN: '/auth/login', - SIGNUP: '/auth/signup', - REFRESH_TOKEN: '/auth/refresh-token', - LOGOUT: '/auth/logout', - VERIFY_EMAIL: '/auth/verify-email', - RESET_PASSWORD: '/auth/reset-password', - FORGOT_PASSWORD: '/auth/forgot-password', - }, - - // Users - USERS: { - PROFILE: '/users/me', - UPDATE_PROFILE: '/users/me', - PASSWORD: '/users/password', - AVATAR: '/users/avatar', - SETTINGS: '/users/settings', - NOTIFICATIONS: '/users/notifications', - PREFERENCES: '/users/preferences', - }, - - // Lists - LISTS: { - BASE: '/lists', - BY_ID: (id: string) => `/lists/${id}`, - ITEMS: (listId: string) => `/lists/${listId}/items`, - ITEM: (listId: string, itemId: string) => `/lists/${listId}/items/${itemId}`, - SHARE: (listId: string) => `/lists/${listId}/share`, - UNSHARE: (listId: string) => `/lists/${listId}/unshare`, - COMPLETE: (listId: string) => `/lists/${listId}/complete`, - REOPEN: (listId: string) => `/lists/${listId}/reopen`, - ARCHIVE: (listId: string) => `/lists/${listId}/archive`, - RESTORE: (listId: string) => `/lists/${listId}/restore`, - DUPLICATE: (listId: string) => `/lists/${listId}/duplicate`, - EXPORT: (listId: string) => `/lists/${listId}/export`, - IMPORT: '/lists/import', - }, - - // Groups - GROUPS: { - BASE: '/groups', - BY_ID: (id: string) => `/groups/${id}`, - LISTS: (groupId: string) => `/groups/${groupId}/lists`, - MEMBERS: (groupId: string) => `/groups/${groupId}/members`, - MEMBER: (groupId: string, userId: string) => `/groups/${groupId}/members/${userId}`, - LEAVE: (groupId: string) => `/groups/${groupId}/leave`, - DELETE: (groupId: string) => `/groups/${groupId}`, - SETTINGS: (groupId: string) => `/groups/${groupId}/settings`, - ROLES: (groupId: string) => `/groups/${groupId}/roles`, - ROLE: (groupId: string, roleId: string) => `/groups/${groupId}/roles/${roleId}`, - }, - - // Invites - INVITES: { - BASE: '/invites', - BY_ID: (id: string) => `/invites/${id}`, - ACCEPT: (id: string) => `/invites/${id}/accept`, - DECLINE: (id: string) => `/invites/${id}/decline`, - REVOKE: (id: string) => `/invites/${id}/revoke`, - LIST: '/invites', - PENDING: '/invites/pending', - SENT: '/invites/sent', - }, - - // Items (for direct operations like update, get by ID) - ITEMS: { - BY_ID: (itemId: string) => `/items/${itemId}`, - }, - - // OCR - OCR: { - PROCESS: '/ocr/extract-items', - STATUS: (jobId: string) => `/ocr/status/${jobId}`, - RESULT: (jobId: string) => `/ocr/result/${jobId}`, - BATCH: '/ocr/batch', - CANCEL: (jobId: string) => `/ocr/cancel/${jobId}`, - HISTORY: '/ocr/history', - }, - - // Costs - COSTS: { - BASE: '/costs', - LIST_SUMMARY: (listId: string | number) => `/costs/lists/${listId}/cost-summary`, - GROUP_BALANCE_SUMMARY: (groupId: string | number) => `/costs/groups/${groupId}/balance-summary`, - }, - - // Financials - FINANCIALS: { - EXPENSES: '/financials/expenses', - EXPENSE: (id: string) => `/financials/expenses/${id}`, - SETTLEMENTS: '/financials/settlements', - SETTLEMENT: (id: string) => `/financials/settlements/${id}`, - BALANCES: '/financials/balances', - BALANCE: (userId: string) => `/financials/balances/${userId}`, - REPORTS: '/financials/reports', - REPORT: (id: string) => `/financials/reports/${id}`, - CATEGORIES: '/financials/categories', - CATEGORY: (id: string) => `/financials/categories/${id}`, - }, - - // Health - HEALTH: { - CHECK: '/health', - VERSION: '/health/version', - STATUS: '/health/status', - METRICS: '/health/metrics', - LOGS: '/health/logs', - }, -}; - - - -from fastapi import HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from app.config import settings -from typing import Optional - -class ListNotFoundError(HTTPException): - """Raised when a list is not found.""" - def __init__(self, list_id: int): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"List {list_id} not found" - ) - -class ListPermissionError(HTTPException): - """Raised when a user doesn't have permission to access a list.""" - def __init__(self, list_id: int, action: str = "access"): - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"You do not have permission to {action} list {list_id}" - ) - -class ListCreatorRequiredError(HTTPException): - """Raised when an action requires the list creator but the user is not the creator.""" - def __init__(self, list_id: int, action: str): - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Only the list creator can {action} list {list_id}" - ) - -class GroupNotFoundError(HTTPException): - """Raised when a group is not found.""" - def __init__(self, group_id: int): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Group {group_id} not found" - ) - -class GroupPermissionError(HTTPException): - """Raised when a user doesn't have permission to perform an action in a group.""" - def __init__(self, group_id: int, action: str): - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"You do not have permission to {action} in group {group_id}" - ) - -class GroupMembershipError(HTTPException): - """Raised when a user attempts to perform an action that requires group membership.""" - def __init__(self, group_id: int, action: str = "access"): - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"You must be a member of group {group_id} to {action}" - ) - -class GroupOperationError(HTTPException): - """Raised when a group operation fails.""" - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail - ) - -class GroupValidationError(HTTPException): - """Raised when a group operation is invalid.""" - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=detail - ) - -class ItemNotFoundError(HTTPException): - """Raised when an item is not found.""" - def __init__(self, item_id: int): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Item {item_id} not found" - ) - -class UserNotFoundError(HTTPException): - """Raised when a user is not found.""" - def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None): - detail_msg = "User not found." - if user_id: - detail_msg = f"User with ID {user_id} not found." - elif identifier: - detail_msg = f"User with identifier '{identifier}' not found." - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail_msg - ) - -class InvalidOperationError(HTTPException): - """Raised when an operation is invalid or disallowed by business logic.""" - def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST): - super().__init__( - status_code=status_code, - detail=detail - ) - -class DatabaseConnectionError(HTTPException): - """Raised when there is an error connecting to the database.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=settings.DB_CONNECTION_ERROR - ) - -class DatabaseIntegrityError(HTTPException): - """Raised when a database integrity constraint is violated.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=settings.DB_INTEGRITY_ERROR - ) - -class DatabaseTransactionError(HTTPException): - """Raised when a database transaction fails.""" - def __init__(self, detail: str = settings.DB_TRANSACTION_ERROR): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail - ) - -class DatabaseQueryError(HTTPException): - """Raised when a database query fails.""" - def __init__(self, detail: str = settings.DB_QUERY_ERROR): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=detail - ) - -class OCRServiceUnavailableError(HTTPException): - """Raised when the OCR service is unavailable.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=settings.OCR_SERVICE_UNAVAILABLE - ) - -class OCRServiceConfigError(HTTPException): - """Raised when there is an error in the OCR service configuration.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=settings.OCR_SERVICE_CONFIG_ERROR - ) - -class OCRUnexpectedError(HTTPException): - """Raised when there is an unexpected error in the OCR service.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=settings.OCR_UNEXPECTED_ERROR - ) - -class OCRQuotaExceededError(HTTPException): - """Raised when the OCR service quota is exceeded.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail=settings.OCR_QUOTA_EXCEEDED - ) - -class InvalidFileTypeError(HTTPException): - """Raised when an invalid file type is uploaded for OCR.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES)) - ) - -class FileTooLargeError(HTTPException): - """Raised when an uploaded file exceeds the size limit.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB) - ) - -class OCRProcessingError(HTTPException): - """Raised when there is an error processing the image with OCR.""" - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail=settings.OCR_PROCESSING_ERROR.format(detail=detail) - ) - -class EmailAlreadyRegisteredError(HTTPException): - """Raised when attempting to register with an email that is already in use.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered." - ) - -class UserCreationError(HTTPException): - """Raised when there is an error creating a new user.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An error occurred during user creation." - ) - -class InviteNotFoundError(HTTPException): - """Raised when an invite is not found.""" - def __init__(self, invite_code: str): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Invite code {invite_code} not found" - ) - -class InviteExpiredError(HTTPException): - """Raised when an invite has expired.""" - def __init__(self, invite_code: str): - super().__init__( - status_code=status.HTTP_410_GONE, - detail=f"Invite code {invite_code} has expired" - ) - -class InviteAlreadyUsedError(HTTPException): - """Raised when an invite has already been used.""" - def __init__(self, invite_code: str): - super().__init__( - status_code=status.HTTP_410_GONE, - detail=f"Invite code {invite_code} has already been used" - ) - -class InviteCreationError(HTTPException): - """Raised when an invite cannot be created.""" - def __init__(self, group_id: int): - super().__init__( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create invite for group {group_id}" - ) - -class ListStatusNotFoundError(HTTPException): - """Raised when a list's status cannot be retrieved.""" - def __init__(self, list_id: int): - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Status for list {list_id} not found" - ) - -class ConflictError(HTTPException): - """Raised when an optimistic lock version conflict occurs.""" - def __init__(self, detail: str): - super().__init__( - status_code=status.HTTP_409_CONFLICT, - detail=detail - ) - -class InvalidCredentialsError(HTTPException): - """Raised when login credentials are invalid.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.AUTH_INVALID_CREDENTIALS, - headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""} - ) - -class NotAuthenticatedError(HTTPException): - """Raised when the user is not authenticated.""" - def __init__(self): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.AUTH_NOT_AUTHENTICATED, - headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""} - ) - -class JWTError(HTTPException): - """Raised when there is an error with the JWT token.""" - def __init__(self, error: str): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.JWT_ERROR.format(error=error), - headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} - ) - -class JWTUnexpectedError(HTTPException): - """Raised when there is an unexpected error with the JWT token.""" - def __init__(self, error: str): - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), - headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} - ) - - - -# app/crud/list.py -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload, joinedload -from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc -from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError -from typing import Optional, List as PyList - -from app.schemas.list import ListStatus -from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel -from app.schemas.list import ListCreate, ListUpdate -from app.core.exceptions import ( - ListNotFoundError, - ListPermissionError, - ListCreatorRequiredError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - ConflictError -) - -async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: - """Creates a new list record.""" - try: - async with db.begin(): - db_list = ListModel( - name=list_in.name, - description=list_in.description, - group_id=list_in.group_id, - created_by_id=creator_id, - is_complete=False - ) - db.add(db_list) - await db.flush() - await db.refresh(db_list) - return db_list - except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") - except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to create list: {str(e)}") - -async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]: - """Gets all lists accessible by a user.""" - try: - group_ids_result = await db.execute( - select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id) - ) - user_group_ids = group_ids_result.scalars().all() - - # Build conditions for the OR clause dynamically - conditions = [ - and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)) - ] - if user_group_ids: # Only add the IN clause if there are group IDs - conditions.append(ListModel.group_id.in_(user_group_ids)) - - query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc()) - - result = await db.execute(query) - return result.scalars().all() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query user lists: {str(e)}") - -async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]: - """Gets a single list by ID, optionally loading its items.""" - try: - query = select(ListModel).where(ListModel.id == list_id) - if load_items: - query = query.options( - selectinload(ListModel.items) - .options( - joinedload(ItemModel.added_by_user), - joinedload(ItemModel.completed_by_user) - ) - ) - result = await db.execute(query) - return result.scalars().first() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query list: {str(e)}") - -async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: - """Updates an existing list record, checking for version conflicts.""" - try: - async with db.begin(): - if list_db.version != list_in.version: - raise ConflictError( - f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " - f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." - ) - - update_data = list_in.model_dump(exclude_unset=True, exclude={'version'}) - - for key, value in update_data.items(): - setattr(list_db, key, value) - - list_db.version += 1 - - db.add(list_db) - await db.flush() - await db.refresh(list_db) - return list_db - except IntegrityError as e: - await db.rollback() - raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") - except OperationalError as e: - await db.rollback() - raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") - except ConflictError: - await db.rollback() - raise - except SQLAlchemyError as e: - await db.rollback() - raise DatabaseTransactionError(f"Failed to update list: {str(e)}") - -async def delete_list(db: AsyncSession, list_db: ListModel) -> None: - """Deletes a list record. Version check should be done by the caller (API endpoint).""" - try: - async with db.begin(): - await db.delete(list_db) - return None - except OperationalError as e: - await db.rollback() - raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") - except SQLAlchemyError as e: - await db.rollback() - raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") - -async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: - """Fetches a list and verifies user permission.""" - try: - list_db = await get_list_by_id(db, list_id=list_id, load_items=True) - if not list_db: - raise ListNotFoundError(list_id) - - is_creator = list_db.created_by_id == user_id - - if require_creator: - if not is_creator: - raise ListCreatorRequiredError(list_id, "access") - return list_db - - if is_creator: - return list_db - - if list_db.group_id: - from app.crud.group import is_user_member - is_member = await is_user_member(db, group_id=list_db.group_id, user_id=user_id) - if not is_member: - raise ListPermissionError(list_id) - return list_db - else: - raise ListPermissionError(list_id) - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}") - -async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus: - """Gets the update timestamps and item count for a list.""" - try: - list_query = select(ListModel.updated_at).where(ListModel.id == list_id) - list_result = await db.execute(list_query) - list_updated_at = list_result.scalar_one_or_none() - - if list_updated_at is None: - raise ListNotFoundError(list_id) - - item_status_query = ( - select( - sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"), - sql_func.count(ItemModel.id).label("item_count") - ) - .where(ItemModel.list_id == list_id) - ) - item_result = await db.execute(item_status_query) - item_status = item_result.first() - - return ListStatus( - list_updated_at=list_updated_at, - latest_item_updated_at=item_status.latest_item_updated_at if item_status else None, - item_count=item_status.item_count if item_status else 0 - ) - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to get list status: {str(e)}") - - - -# app/models.py -import enum -import secrets -from datetime import datetime, timedelta, timezone - -from sqlalchemy import ( - Column, - Integer, - String, - DateTime, - ForeignKey, - Boolean, - Enum as SAEnum, - UniqueConstraint, - Index, - DDL, - event, - delete, - func, - text as sa_text, - Text, # <-- Add Text for description - Numeric # <-- Add Numeric for price -) -from sqlalchemy.orm import relationship, backref - -from .database import Base - -# --- Enums --- -class UserRoleEnum(enum.Enum): - owner = "owner" - member = "member" - -class SplitTypeEnum(enum.Enum): - EQUAL = "EQUAL" # Split equally among all involved users - EXACT_AMOUNTS = "EXACT_AMOUNTS" # Specific amounts for each user (defined in ExpenseSplit) - PERCENTAGE = "PERCENTAGE" # Percentage for each user (defined in ExpenseSplit) - SHARES = "SHARES" # Proportional to shares/units (defined in ExpenseSplit) - ITEM_BASED = "ITEM_BASED" # If an expense is derived directly from item prices and who added them - # Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense) - -# --- User Model --- -class User(Base): - __tablename__ = "users" - - id = Column(Integer, primary_key=True, index=True) - email = Column(String, unique=True, index=True, nullable=False) - password_hash = Column(String, nullable=False) - name = Column(String, index=True, nullable=True) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - - # --- Relationships --- - created_groups = relationship("Group", back_populates="creator") - group_associations = relationship("UserGroup", back_populates="user", cascade="all, delete-orphan") - created_invites = relationship("Invite", back_populates="creator") - - # --- NEW Relationships for Lists/Items --- - created_lists = relationship("List", foreign_keys="List.created_by_id", back_populates="creator") # Link List.created_by_id -> User - added_items = relationship("Item", foreign_keys="Item.added_by_id", back_populates="added_by_user") # Link Item.added_by_id -> User - completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User - # --- End NEW Relationships --- - - # --- Relationships for Cost Splitting --- - expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan") - expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan") - settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan") - settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - - -# --- Group Model --- -class Group(Base): - __tablename__ = "groups" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True, nullable=False) - created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - - # --- Relationships --- - creator = relationship("User", back_populates="created_groups") - member_associations = relationship("UserGroup", back_populates="group", cascade="all, delete-orphan") - invites = relationship("Invite", back_populates="group", cascade="all, delete-orphan") - - # --- NEW Relationship for Lists --- - lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group - # --- End NEW Relationship --- - - # --- Relationships for Cost Splitting --- - expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan") - settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - - -# --- UserGroup Association Model --- -class UserGroup(Base): - __tablename__ = "user_groups" - __table_args__ = (UniqueConstraint('user_id', 'group_id', name='uq_user_group'),) - - id = Column(Integer, primary_key=True, index=True) - user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) - group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) - role = Column(SAEnum(UserRoleEnum, name="userroleenum", create_type=True), nullable=False, default=UserRoleEnum.member) - joined_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - - user = relationship("User", back_populates="group_associations") - group = relationship("Group", back_populates="member_associations") - - -# --- Invite Model --- -class Invite(Base): - __tablename__ = "invites" - __table_args__ = ( - Index('ix_invites_active_code', 'code', unique=True, postgresql_where=sa_text('is_active = true')), - ) - - id = Column(Integer, primary_key=True, index=True) - code = Column(String, unique=False, index=True, nullable=False, default=lambda: secrets.token_urlsafe(16)) - group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) - created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - expires_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) + timedelta(days=7)) - is_active = Column(Boolean, default=True, nullable=False) - - group = relationship("Group", back_populates="invites") - creator = relationship("User", back_populates="created_invites") - - -# === NEW: List Model === -class List(Base): - __tablename__ = "lists" - - id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True, nullable=False) - description = Column(Text, nullable=True) - created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who created this list - group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # Which group it belongs to (NULL if personal) - is_complete = Column(Boolean, default=False, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - version = Column(Integer, nullable=False, default=1, server_default='1') - - # --- Relationships --- - creator = relationship("User", back_populates="created_lists") # Link to User.created_lists - group = relationship("Group", back_populates="lists") # Link to Group.lists - items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes - - # --- Relationships for Cost Splitting --- - expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - - -# === NEW: Item Model === -class Item(Base): - __tablename__ = "items" - - id = Column(Integer, primary_key=True, index=True) - list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) # Belongs to which list - name = Column(String, index=True, nullable=False) - quantity = Column(String, nullable=True) # Flexible quantity (e.g., "1", "2 lbs", "a bunch") - is_complete = Column(Boolean, default=False, nullable=False) - price = Column(Numeric(10, 2), nullable=True) # For cost splitting later (e.g., 12345678.99) - added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who added this item - completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - version = Column(Integer, nullable=False, default=1, server_default='1') - - # --- Relationships --- - list = relationship("List", back_populates="items") # Link to List.items - added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items - completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items - - # --- Relationships for Cost Splitting --- - # If an item directly results in an expense, or an expense can be tied to an item. - expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses - # --- End Relationships for Cost Splitting --- - - -# === NEW Models for Advanced Cost Splitting === - -class Expense(Base): - __tablename__ = "expenses" - - id = Column(Integer, primary_key=True, index=True) - description = Column(String, nullable=False) - total_amount = Column(Numeric(10, 2), nullable=False) - currency = Column(String, nullable=False, default="USD") # Consider making this an Enum too if few currencies - expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False) - - # Foreign Keys - list_id = Column(Integer, ForeignKey("lists.id"), nullable=True) - group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # If not list-specific but group-specific - item_id = Column(Integer, ForeignKey("items.id"), nullable=True) # If the expense is for a specific item - paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - version = Column(Integer, nullable=False, default=1, server_default='1') - - # Relationships - paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid") - list = relationship("List", foreign_keys=[list_id], back_populates="expenses") - group = relationship("Group", foreign_keys=[group_id], back_populates="expenses") - item = relationship("Item", foreign_keys=[item_id], back_populates="expenses") - splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan") - - __table_args__ = ( - # Example: Ensure either list_id or group_id is present if item_id is null - # CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'), - ) - -class ExpenseSplit(Base): - __tablename__ = "expense_splits" - __table_args__ = (UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),) - - id = Column(Integer, primary_key=True, index=True) - expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False) - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - - owed_amount = Column(Numeric(10, 2), nullable=False) # For EQUAL or EXACT_AMOUNTS - # For PERCENTAGE split (value from 0.00 to 100.00) - share_percentage = Column(Numeric(5, 2), nullable=True) - # For SHARES split (e.g., user A has 2 shares, user B has 3 shares) - share_units = Column(Integer, nullable=True) - - # is_settled might be better tracked via actual Settlement records or a reconciliation process - # is_settled = Column(Boolean, default=False, nullable=False) - - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - - # Relationships - expense = relationship("Expense", back_populates="splits") - user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits") - - -class Settlement(Base): - __tablename__ = "settlements" - - id = Column(Integer, primary_key=True, index=True) - group_id = Column(Integer, ForeignKey("groups.id"), nullable=False) # Settlements usually within a group - paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False) - amount = Column(Numeric(10, 2), nullable=False) - settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - description = Column(Text, nullable=True) - - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - version = Column(Integer, nullable=False, default=1, server_default='1') - - # Relationships - group = relationship("Group", foreign_keys=[group_id], back_populates="settlements") - payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made") - payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received") - - __table_args__ = ( - # Ensure payer and payee are different users - # CheckConstraint('paid_by_user_id <> paid_to_user_id', name='chk_settlement_payer_ne_payee'), - ) - -# Potential future: PaymentMethod model, etc. - - - - - - - - - - - - - - - - - - - -# app/crud/item.py -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases -from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError -from typing import Optional, List as PyList -from datetime import datetime, timezone - -from app.models import Item as ItemModel -from app.schemas.item import ItemCreate, ItemUpdate -from app.core.exceptions import ( - ItemNotFoundError, - DatabaseConnectionError, - DatabaseIntegrityError, - DatabaseQueryError, - DatabaseTransactionError, - ConflictError -) - -async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: - """Creates a new item record for a specific list.""" - try: - db_item = ItemModel( - name=item_in.name, - quantity=item_in.quantity, - list_id=list_id, - added_by_id=user_id, - is_complete=False # Default on creation - # version is implicitly set to 1 by model default - ) - db.add(db_item) - await db.flush() - await db.refresh(db_item) - await db.commit() # Explicitly commit here - return db_item - except IntegrityError as e: - await db.rollback() # Rollback on integrity error - raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") - except OperationalError as e: - await db.rollback() # Rollback on operational error - raise DatabaseConnectionError(f"Database connection error: {str(e)}") - except SQLAlchemyError as e: - await db.rollback() # Rollback on other SQLAlchemy errors - raise DatabaseTransactionError(f"Failed to create item: {str(e)}") - except Exception as e: # Catch any other exception and attempt rollback - await db.rollback() - raise # Re-raise the original exception - -async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]: - """Gets all items belonging to a specific list, ordered by creation time.""" - try: - result = await db.execute( - select(ItemModel) - .where(ItemModel.list_id == list_id) - .order_by(ItemModel.created_at.asc()) # Or desc() if preferred - ) - return result.scalars().all() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query items: {str(e)}") - -async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: - """Gets a single item by its ID.""" - try: - result = await db.execute(select(ItemModel).where(ItemModel.id == item_id)) - return result.scalars().first() - except OperationalError as e: - raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") - except SQLAlchemyError as e: - raise DatabaseQueryError(f"Failed to query item: {str(e)}") - -async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: - """Updates an existing item record, checking for version conflicts.""" - try: - # Check version conflict - if item_db.version != item_in.version: - raise ConflictError( - f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " - f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." - ) - - update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version - - # Special handling for is_complete - if 'is_complete' in update_data: - if update_data['is_complete'] is True: - if item_db.completed_by_id is None: # Only set if not already completed by someone - update_data['completed_by_id'] = user_id - else: - update_data['completed_by_id'] = None # Clear if marked incomplete - - # Apply updates - for key, value in update_data.items(): - setattr(item_db, key, value) - - item_db.version += 1 # Increment version - - db.add(item_db) - await db.flush() - await db.refresh(item_db) - - # Commit the transaction if not part of a larger transaction - await db.commit() - - return item_db - except IntegrityError as e: - await db.rollback() - raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") - except OperationalError as e: - await db.rollback() - raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") - except ConflictError: # Re-raise ConflictError - await db.rollback() - raise - except SQLAlchemyError as e: - await db.rollback() - raise DatabaseTransactionError(f"Failed to update item: {str(e)}") - -async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: - """Deletes an item record. Version check should be done by the caller (API endpoint).""" - try: - await db.delete(item_db) - await db.commit() - return None - except OperationalError as e: - await db.rollback() - raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") - except SQLAlchemyError as e: - await db.rollback() - raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") - - - - - - - - - - - -from fastapi import APIRouter - -from app.api.v1.endpoints import health -from app.api.v1.endpoints import auth -from app.api.v1.endpoints import users -from app.api.v1.endpoints import groups -from app.api.v1.endpoints import invites -from app.api.v1.endpoints import lists -from app.api.v1.endpoints import items -from app.api.v1.endpoints import ocr -from app.api.v1.endpoints import costs -from app.api.v1.endpoints import financials - -api_router_v1 = APIRouter() - -api_router_v1.include_router(health.router) -api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"]) -api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) -api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"]) -api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"]) -api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"]) -api_router_v1.include_router(items.router, tags=["Items"]) -api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"]) -api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"]) -api_router_v1.include_router(financials.router) -# Add other v1 endpoint routers here later -# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) - - - -# app/config.py -import os -from pydantic_settings import BaseSettings -from dotenv import load_dotenv -import logging -import secrets - -load_dotenv() -logger = logging.getLogger(__name__) - -class Settings(BaseSettings): - DATABASE_URL: str | None = None - GEMINI_API_KEY: str | None = None - - # --- JWT Settings --- - SECRET_KEY: str # Must be set via environment variable - ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes - REFRESH_TOKEN_EXPIRE_MINUTES: int = 10080 # Default refresh token lifetime: 7 days - - # --- OCR Settings --- - MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing - ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats - OCR_ITEM_EXTRACTION_PROMPT: str = """ -Extract the shopping list items from this image. -List each distinct item on a new line. -Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. -Focus only on the names of the products or items to be purchased. -Add 2 underscores before and after the item name, if it is struck through. -If the image does not appear to be a shopping list or receipt, state that clearly. -Example output for a grocery list: -Milk -Eggs -Bread -__Apples__ -Organic Bananas -""" - - # --- Gemini AI Settings --- - GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR - GEMINI_SAFETY_SETTINGS: dict = { - "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", - "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE", - } - GEMINI_GENERATION_CONFIG: dict = { - "candidate_count": 1, - "max_output_tokens": 2048, - "temperature": 0.9, - "top_p": 1, - "top_k": 1 - } - - # --- API Settings --- - API_PREFIX: str = "/api" # Base path for all API endpoints - API_OPENAPI_URL: str = "/api/openapi.json" - API_DOCS_URL: str = "/api/docs" - API_REDOC_URL: str = "/api/redoc" - CORS_ORIGINS: list[str] = [ - "http://localhost:5174", - "http://localhost:8000", - "http://localhost:9000", - # Add your deployed frontend URL here later - # "https://your-frontend-domain.com", - ] - - # --- API Metadata --- - API_TITLE: str = "Shared Lists API" - API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting." - API_VERSION: str = "0.1.0" - ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs" - - # --- Logging Settings --- - LOG_LEVEL: str = "INFO" - LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - - # --- Health Check Settings --- - HEALTH_STATUS_OK: str = "ok" - HEALTH_STATUS_ERROR: str = "error" - - # --- Auth Settings --- - OAUTH2_TOKEN_URL: str = "/api/v1/auth/login" # Path to login endpoint - TOKEN_TYPE: str = "bearer" # Default token type for OAuth2 - AUTH_HEADER_PREFIX: str = "Bearer" # Prefix for Authorization header - AUTH_HEADER_NAME: str = "WWW-Authenticate" # Name of auth header - AUTH_CREDENTIALS_ERROR: str = "Could not validate credentials" - AUTH_INVALID_CREDENTIALS: str = "Incorrect email or password" - AUTH_NOT_AUTHENTICATED: str = "Not authenticated" - - # --- HTTP Status Messages --- - HTTP_400_DETAIL: str = "Bad Request" - HTTP_401_DETAIL: str = "Unauthorized" - HTTP_403_DETAIL: str = "Forbidden" - HTTP_404_DETAIL: str = "Not Found" - HTTP_422_DETAIL: str = "Unprocessable Entity" - HTTP_429_DETAIL: str = "Too Many Requests" - HTTP_500_DETAIL: str = "Internal Server Error" - HTTP_503_DETAIL: str = "Service Unavailable" - - # --- Database Error Messages --- - DB_CONNECTION_ERROR: str = "Database connection error" - DB_INTEGRITY_ERROR: str = "Database integrity error" - DB_TRANSACTION_ERROR: str = "Database transaction error" - DB_QUERY_ERROR: str = "Database query error" - - class Config: - env_file = ".env" - env_file_encoding = 'utf-8' - extra = "ignore" - -settings = Settings() - -# Validation for critical settings -if settings.DATABASE_URL is None: - raise ValueError("DATABASE_URL environment variable must be set.") - -# Enforce secure secret key -if not settings.SECRET_KEY: - raise ValueError("SECRET_KEY environment variable must be set. Generate a secure key using: openssl rand -hex 32") - -# Validate secret key strength -if len(settings.SECRET_KEY) < 32: - raise ValueError("SECRET_KEY must be at least 32 characters long for security") - -if settings.GEMINI_API_KEY is None: - logger.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.") -else: - # Optional: Log partial key for confirmation (avoid logging full key) - logger.info(f"GEMINI_API_KEY loaded (starts with: {settings.GEMINI_API_KEY[:4]}...).") - - - - - - - - - - -