Compare commits


188 Commits
main ... ph4

Author SHA1 Message Date
mohamad
f8788ee42d feat: Add Dutch translations and update de, es, fr 2025-06-08 00:03:38 +02:00
whtvrboo
b0ec84b8ca
Merge pull request #16 from whtvrboo/i18n-pages-partial
feat: Add missing i18n translations for page components (partial)
2025-06-07 22:41:18 +02:00
google-labs-jules[bot]
198222c3ff feat: Add missing i18n translations for page components (partial)
This commit introduces internationalization for several page components by identifying hardcoded strings, adding them to translation files, and updating the components to use translation keys.

Processed pages:
- fe/src/pages/AuthCallbackPage.vue: I internationalized an error message.
- fe/src/pages/ChoresPage.vue: I internationalized console error messages and an input placeholder.
- fe/src/pages/ErrorNotFound.vue: I found no missing translations.
- fe/src/pages/GroupDetailPage.vue: I internationalized various UI elements (ARIA labels, button text, fallback user display names) and console/error messages.
- fe/src/pages/GroupsPage.vue: I internationalized error messages and console logs.
- fe/src/pages/IndexPage.vue: I found no missing user-facing translations.
- fe/src/pages/ListDetailPage.vue: My analysis is complete, and I identified a console message and a fallback string for translation (implementation of changes for this page is pending).

For each processed page where changes were needed:
- I added new keys to `fe/src/i18n/en.json`.
- I added corresponding placeholder keys `"[TRANSLATE] Original Text"` to `fe/src/i18n/de.json`, `fe/src/i18n/es.json`, and `fe/src/i18n/fr.json`.
- I updated the Vue component to use the `t()` function with the new keys.

Further pages in `fe/src/pages/` are pending analysis and internationalization as per our original plan.
2025-06-07 20:40:49 +00:00
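The placeholder workflow described in the commit above can be sketched as follows. This is an illustrative helper, not code from the repository: new keys get their real text in `en.json`, while the other locale files receive `"[TRANSLATE] Original Text"` markers until a human translation lands. The function name `addTranslationKey` is hypothetical.

```typescript
type Messages = Record<string, string>;

// Add a key with its English text, and stub the other locales with a
// "[TRANSLATE] ..." placeholder so missing translations are easy to spot.
function addTranslationKey(
    en: Messages,
    otherLocales: Messages[],
    key: string,
    englishText: string,
): void {
    en[key] = englishText;
    for (const locale of otherLocales) {
        // Only stub out locales that do not already carry a translation.
        if (!(key in locale)) {
            locale[key] = `[TRANSLATE] ${englishText}`;
        }
    }
}

// Usage: one key propagated to en.json and a de.json-style object.
const en: Messages = {};
const de: Messages = {};
addTranslationKey(en, [de], "authCallbackPage.authError", "Authentication failed.");
// en now holds the real text; de holds "[TRANSLATE] Authentication failed."
```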
whtvrboo
7ef225daec
Merge pull request #15 from whtvrboo/fix/offline-logout-on-startup
Fix: Prevent automatic logout when starting app offline
2025-06-07 22:31:23 +02:00
google-labs-jules[bot]
6e56e164df Fix: Prevent automatic logout when starting app offline
Problem:
The application would inadvertently log you out if it was started while offline.
This occurred because the `fetchCurrentUser` action in the `authStore` would attempt to fetch your profile, and if this network request failed (as it does when offline), the catch block would unconditionally call `clearTokens()`. This removed the authentication token, effectively logging you out and preventing access to any cached data or offline functionality.

Solution:
I modified the `fetchCurrentUser` action in `fe/src/stores/auth.ts`:
- The `catch` block now inspects the error.
- `clearTokens()` is only called if the error is a specific HTTP authentication error from the server (401 Unauthorized or 403 Forbidden) when online.
- For network errors (indicating offline status) or other non-auth HTTP errors, tokens are preserved. The user object (`user.value`) might remain null if no cached profile is available, but the authentication token itself is kept.

This change allows the application to remain in a logged-in state when started offline. The service worker can then serve cached API responses, and you can view previously accessed data. Navigation guards rely on `isAuthenticated` (which now remains true offline as long as a token exists), so you are not incorrectly redirected to the login page.
2025-06-07 20:30:52 +00:00
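The catch-block policy described above can be reduced to one predicate. This is a minimal sketch under the assumption of Axios-style errors (a server reply carries `response.status`; a network failure carries no `response` at all); the name `shouldClearTokens` is illustrative, and the real logic lives inside `fetchCurrentUser` in `fe/src/stores/auth.ts`.

```typescript
interface HttpishError {
    response?: { status: number };
}

// Decide whether a failed profile fetch should log the user out.
function shouldClearTokens(error: HttpishError): boolean {
    // No response object => network failure (likely offline): keep the token
    // so cached data and offline functionality stay available.
    if (!error.response) return false;
    // Only a definitive auth rejection from the server invalidates the session.
    return error.response.status === 401 || error.response.status === 403;
}

shouldClearTokens({});                            // false: offline, stay logged in
shouldClearTokens({ response: { status: 500 } }); // false: server error, keep token
shouldClearTokens({ response: { status: 401 } }); // true: token is genuinely invalid
```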
mohamad
550fac1c0c Update API base URL to production environment in api-config.ts
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m23s
2025-06-07 22:14:06 +02:00
mohamad
944976b1cc Update logging level to INFO, refine chore update logic, and enhance invite acceptance flow
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m47s
- Changed logging level from WARNING to INFO in config.py for better visibility during development.
- Adjusted chore update logic in chores.py to ensure correct payload structure.
- Improved invite acceptance process in invites.py by refining error handling and updating response models for better clarity.
- Updated API endpoint configurations in api-config.ts for consistency and added new endpoints for list statuses.
- Enhanced UI components in ChoresPage.vue and GroupsPage.vue for improved user experience and accessibility.
2025-06-07 22:07:35 +02:00
mohamad
92c919785a Update package dependencies to include 'qs'
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m22s
2025-06-07 18:56:29 +02:00
mohamad
a1acee6e59 Implement bulk status retrieval for lists and refine list status handling 2025-06-07 18:55:35 +02:00
mohamad
331eaf7c35 Refine layout and styling in GroupDetailPage for improved UX 2025-06-07 18:23:57 +02:00
mohamad
b9b2bfb469 Adjust footer button indentation and refine CSS positioning for improved layout in ListDetailPage
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m23s
2025-06-07 18:07:47 +02:00
mohamad
5f05cd9377 Fix indentation in ListDetailPage footer buttons for improved readability 2025-06-07 18:05:26 +02:00
mohamad
ddaa20af3c Remove deprecated task management files and enhance group management functionality
- Deleted obsolete task management files: `tasks.mdc` and `notes.md`.
- Introduced a new `groupStore` for managing group data, including fetching user groups and handling loading states.
- Updated `MainLayout.vue` to navigate to groups with improved loading checks.
- Enhanced `GroupsPage.vue` to support a tabbed interface for creating and joining groups, improving user experience.
- Refined `GroupDetailPage.vue` to display recent expenses with a more interactive layout and added functionality for settling shares.
2025-06-07 18:05:08 +02:00
mohamad
cef359238b Fix ChoresPage frequency option access and clean up auto-save comments 2025-06-07 17:02:40 +02:00
mohamad
5fffd4d2f5 Refactor MainLayout.vue to improve component rendering logic 2025-06-07 17:02:29 +02:00
mohamad
397cf28673 Refactor logging and clean up unused console statements across multiple files 2025-06-07 17:02:19 +02:00
mohamad
d6c7fde40c Refactor ChoresPage and GroupDetailPage for improved UI and functionality
- Enhanced the ChoresPage by refining button attributes for accessibility and improving layout consistency.
- Updated the GroupDetailPage to include a more interactive member avatar list and streamlined invite member functionality.
- Introduced new styles for better visual hierarchy and user experience across both pages.
- Implemented click-outside functionality for member menus and invite UI to enhance usability.
2025-06-07 16:50:39 +02:00
mohamad
77178cc67e Refactor VBadge and GroupDetailPage for enhanced badge variants and UI improvements
- Updated VBadge component to include additional badge variants: 'primary', 'success', 'danger', 'warning', 'info', and 'neutral'.
- Modified the GroupDetailPage to utilize the new badge variants for member roles and chore frequencies.
- Improved layout and styling of sections within GroupDetailPage for better user experience.
- Enhanced error handling and notification messages for invite code generation and clipboard actions.
2025-06-07 16:08:59 +02:00
mohamad
0aa88d0af7 Enhance ListDetailPage with collapsible expense items and improved UI 2025-06-07 15:26:46 +02:00
mohamad
fc09848a33 Add position attribute to Item model for reordering functionality
- Introduced a new 'position' column in the Item model to facilitate item ordering.
- Updated the List model's relationship to order items by position and creation date.
- Enhanced CRUD operations to handle item creation and updates with position management.
- Implemented drag-and-drop reordering in the frontend, ensuring proper position updates on item movement.
- Adjusted item update logic to accommodate reordering and version control.
2025-06-07 15:04:49 +02:00
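The drag-and-drop reordering described above boils down to rewriting each item's `position` after a move so the backend can order by it. A minimal frontend-side sketch, with the `Item` shape and `moveItem` name chosen for illustration:

```typescript
interface Item {
    id: number;
    position: number;
}

// Move an item from one index to another and reassign dense, 1-based
// positions so the persisted order matches what the user sees.
function moveItem(items: Item[], fromIndex: number, toIndex: number): Item[] {
    const reordered = [...items];
    const [moved] = reordered.splice(fromIndex, 1);
    reordered.splice(toIndex, 0, moved);
    return reordered.map((item, index) => ({ ...item, position: index + 1 }));
}

// Usage: drag the first item to the end of a three-item list.
const items = [
    { id: 10, position: 1 },
    { id: 11, position: 2 },
    { id: 12, position: 3 },
];
const result = moveItem(items, 0, 2);
// result order: ids 11, 12, 10 with positions 1, 2, 3
```

Each changed item would then be sent to the backend's update endpoint, which (per the commit) also checks the item's version for conflict control.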
mohamad
b9aace0c4e Update dependencies and refactor ListDetailPage for drag-and-drop functionality
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m23s
- Updated `vue-i18n` and related dependencies to version 9.14.4 for improved localization support.
- Added `vuedraggable` to enable drag-and-drop functionality for list items in `ListDetailPage.vue`.
- Refactored the item list structure to accommodate drag handles and improved item actions.
- Enhanced styling for drag-and-drop interactions and item actions for better user experience.
2025-06-05 01:04:34 +02:00
mohamad
d8db5721f4 Refactor GroupsPage and ListDetailPage for improved loading and error handling 2025-06-05 00:46:23 +02:00
Mohamad
6e79fbfa04 Update API base URL to production environment
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m23s
2025-06-04 17:55:33 +02:00
Mohamad
5c882996a9 Enhance financials API and list expense retrieval
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m23s
- Updated the `check_list_access_for_financials` function to allow access for list creators and members.
- Refactored the `list_expenses` endpoint to support filtering by `list_id`, `group_id`, and `isRecurring`, providing more flexible expense retrieval options.
- Introduced a new `read_list_expenses` endpoint to fetch expenses associated with a specific list, ensuring proper permission checks.
- Enhanced expense retrieval logic in the `get_expenses_for_list` and `get_user_accessible_expenses` functions to include settlement activities.
- Updated frontend API configuration to reflect new endpoint paths and ensure consistency across the application.
2025-06-04 17:50:19 +02:00
Mohamad
6306e70df7 Merge branch 'ph4' of https://github.com/whtvrboo/mitlist into ph4 2025-06-03 12:07:28 +02:00
whtvrboo
dbfbe7922e
Merge pull request #14 from whtvrboo/fix/expense-api-pathing
Fix: Correct API endpoint pathing for expenses to resolve 404 errors
2025-06-03 12:05:08 +02:00
google-labs-jules[bot]
57b913d135 Fix: Correct API endpoint pathing for expenses to resolve 404 errors
The expenses frontend was encountering 404 errors due to mismatched API paths
between the frontend calls and backend routing.

This commit addresses the issue by:

1. Modifying backend API routing in `be/app/api/v1/api.py`:
   - Added a `/financials` prefix to the `financials.router`. Expense endpoints are now served under `/api/v1/financials/expenses`.

2. Updating frontend API configuration in `fe/src/config/api-config.ts`:
   - Prepended `/api/v1` to all paths within the `API_ENDPOINTS.FINANCIALS` object to match the new backend structure (e.g., `API_ENDPOINTS.FINANCIALS.EXPENSES` is now `/api/v1/financials/expenses`).

3. Updating frontend expense service in `fe/src/services/expenseService.ts`:
   - Replaced hardcoded relative URLs with the updated constants from `API_ENDPOINTS.FINANCIALS`.
   - Ensured `API_ENDPOINTS` is correctly imported.

These changes align the frontend API calls with the backend endpoint definitions,
resolving the 404 errors.
2025-06-03 10:04:42 +00:00
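The pathing alignment described above can be sketched as a single shared prefix on the frontend constants. Constant names follow the commit message, but the exact shape of `fe/src/config/api-config.ts` may differ, and the `SETTLEMENTS` entry is a hypothetical sibling added for illustration:

```typescript
const API_VERSION_PREFIX = "/api/v1";

// Frontend endpoint constants mirroring the backend's new
// /api/v1/financials/... routing.
const API_ENDPOINTS = {
    FINANCIALS: {
        EXPENSES: `${API_VERSION_PREFIX}/financials/expenses`,
        SETTLEMENTS: `${API_VERSION_PREFIX}/financials/settlements`, // hypothetical
    },
};

// A service call then builds URLs from the constant instead of a
// hardcoded relative path, which is what caused the 404s.
function expenseUrl(expenseId: number): string {
    return `${API_ENDPOINTS.FINANCIALS.EXPENSES}/${expenseId}`;
}
// expenseUrl(7) yields "/api/v1/financials/expenses/7"
```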
mohamad
588abb1217 Refactor i18n message imports and update PWA configuration.
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-02 19:08:30 +02:00
mohamad
d150dd28c9 Update OAuth redirect URIs and API routing structure
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-02 19:08:30 +02:00
mohamad
6b54566cef Refactor API routing and update login URLs
- Updated the OAuth routes to be included under the main API prefix for better organization.
- Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure.

These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
2025-06-02 19:08:12 +02:00
mohamad
d623c4b27c Enhance i18n support and PWA configuration
- Increased the maximum file size for caching in PWA settings from 5MB to 15MB in vite.config.ts.
- Updated import paths for i18n messages in main.ts for consistency.
- Simplified the i18n index by removing unnecessary keys and using shorthand for language imports.
- Added debug output in LoginPage.vue to log current locale and available messages for easier troubleshooting.
2025-06-02 18:07:41 +02:00
mohamad
fc49e830fc Update Dockerfile to use npm install and modify PWA theme and background colors in vite.config.ts 2025-06-02 18:07:41 +02:00
mohamad
af6324ddef Update vue-i18n dependency to version 9.9.1 in package.json 2025-06-02 18:07:41 +02:00
mohamad
6924a016c8 Add project documentation and production deployment guide
- Introduced comprehensive project documentation for the Shared Household Management PWA, detailing project overview, goals, features, user experience philosophy, technology stack, and development roadmap.
- Added a production deployment guide using Docker Compose and Gitea Actions, outlining setup, configuration, and deployment processes.
- Updated favicon and icon assets for improved branding and user experience across devices.
2025-06-02 18:07:41 +02:00
mohamad
0fcc94ae8d Update OAuth redirect URIs to production environment
- Changed the Google and Apple redirect URIs in the configuration to point to the production URLs.
- This update ensures that the application correctly redirects users to the appropriate authentication endpoints in the live environment.
2025-06-02 18:07:41 +02:00
mohamad
c0aa654e83 Update OAuth redirect URIs and API routing structure
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-02 18:07:41 +02:00
mohamad
ec361fe9ab Refactor API routing and update login URLs
- Updated the OAuth routes to be included under the main API prefix for better organization.
- Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure.

These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
2025-06-02 18:07:41 +02:00
mohamad
9d404d04d5 Update OAuth redirect URIs and API routing structure
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-02 18:07:41 +02:00
mohamad
92c70813fb Refactor API routing and update login URLs
- Updated the OAuth routes to be included under the main API prefix for better organization.
- Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure.

These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
2025-06-02 18:07:28 +02:00
mohamad
2d16116716 Enhance i18n support and PWA configuration
- Increased the maximum file size for caching in PWA settings from 5MB to 15MB in vite.config.ts.
- Updated import paths for i18n messages in main.ts for consistency.
- Simplified the i18n index by removing unnecessary keys and using shorthand for language imports.
- Added debug output in LoginPage.vue to log current locale and available messages for easier troubleshooting.
2025-06-02 18:06:21 +02:00
mohamad
3e328c2902 Update Dockerfile to use npm install and modify PWA theme and background colors in vite.config.ts
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m30s
2025-06-02 00:29:01 +02:00
mohamad
effaef7d08 Update vue-i18n dependency to version 9.9.1 in package.json
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 1m1s
2025-06-02 00:25:21 +02:00
mohamad
12e2890a4a Add project documentation and production deployment guide
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 1m6s
- Introduced comprehensive project documentation for the Shared Household Management PWA, detailing project overview, goals, features, user experience philosophy, technology stack, and development roadmap.
- Added a production deployment guide using Docker Compose and Gitea Actions, outlining setup, configuration, and deployment processes.
- Updated favicon and icon assets for improved branding and user experience across devices.
2025-06-02 00:19:54 +02:00
mohamad
f98bdb6b11 Update OAuth redirect URIs to production environment
- Changed the Google and Apple redirect URIs in the configuration to point to the production URLs.
- This update ensures that the application correctly redirects users to the appropriate authentication endpoints in the live environment.
2025-06-02 00:19:54 +02:00
mohamad
5d50606fc2 Update OAuth redirect URIs and API routing structure
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-02 00:19:54 +02:00
mohamad
30af7ab692 Refactor API routing and update login URLs
- Updated the OAuth routes to be included under the main API prefix for better organization.
- Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure.

These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
2025-06-02 00:19:53 +02:00
google-labs-jules[bot]
4effbf5c03 feat(i18n): Internationalize remaining app pages
This commit completes the internationalization (i18n) for several key pages
within the frontend application.

The following pages have been updated to support multiple languages:
- AccountPage.vue
- SignupPage.vue
- ListDetailPage.vue (including items, OCR, expenses, and cost summary)
- MyChoresPage.vue
- PersonalChoresPage.vue
- IndexPage.vue

Key changes include:
- Extraction of all user-facing strings from these Vue components.
- Addition of new translation keys and their English values to `fe/src/i18n/en.json`.
- Modification of the Vue components to use the Vue I18n plugin's `$t()` (template)
  and `t()` (script) functions for displaying translated strings.
- Dynamic messages, notifications, and form validation messages are now also
  internationalized.
- The language files `de.json`, `es.json`, and `fr.json` have been updated
  with the new keys, using the English text as placeholders for future
  translation.

This effort significantly expands the i18n coverage of the application,
making it more accessible to a wider audience.
2025-06-02 00:19:26 +02:00
google-labs-jules[bot]
5c9ba3f38c feat: Internationalize AuthCallback, Chores, ErrorNotFound, GroupDetail pages
This commit introduces internationalization for several pages:
- AuthCallbackPage.vue
- ChoresPage.vue (a comprehensive page with many elements)
- ErrorNotFound.vue
- GroupDetailPage.vue (including sub-sections for members, invites, chores summary, and expenses summary)

Key changes:
- Integrated `useI18n` in each listed page to handle translatable strings.
- Replaced hardcoded text in templates and relevant script sections (notifications, dynamic messages, fallbacks, etc.) with `t('key')` calls.
- Added new translation keys, organized under page-specific namespaces (e.g., `authCallbackPage`, `choresPage`, `errorNotFoundPage`, `groupDetailPage`), to `fe/src/i18n/en.json`.
- Added corresponding keys with placeholder translations (prefixed with DE:, FR:, ES:) to `fe/src/i18n/de.json`, `fe/src/i18n/fr.json`, and `fe/src/i18n/es.json`.
- Reused existing translation keys (e.g., for chore frequency options) where applicable.
2025-06-02 00:19:26 +02:00
google-labs-jules[bot]
8034824c97 Fix: Resolve Google OAuth redirection issue
This commit addresses an issue where clicking the "Continue with Google"
button redirected the user back to the login page instead of to Google's
authentication page.

The following changes were made:

1.  **Frontend Redirect:**
    *   Modified `fe/src/components/SocialLoginButtons.vue` to make the "Continue with Google" button redirect to the correct backend API endpoint (`/auth/google/login`) using the configured `API_BASE_URL`.

2.  **Backend Route Confirmation:**
    *   Verified that the backend OAuth routes in `be/app/api/auth/oauth.py` are correctly included in `be/app/main.py` under the `/auth` prefix, making them accessible.

3.  **OAuth Credentials Configuration:**
    *   Added `GOOGLE_CLIENT_ID` and `GOOGLE_CLIENT_SECRET` placeholders to `env.production.template` to guide you in setting up your OAuth credentials.
    *   Added instructional comments in `be/app/config.py` regarding the necessity of these environment variables and the correct configuration of `GOOGLE_REDIRECT_URI`.

With these changes, and assuming the necessary Google Cloud OAuth credentials
(Client ID, Client Secret) and redirect URIs are correctly configured in the
environment, the Google OAuth flow should now function as expected.
2025-06-02 00:19:26 +02:00
whtvrboo
82205f6158
Merge pull request #13 from whtvrboo/i18n-feature-pages
feat(i18n): Internationalize remaining app pages
2025-06-02 00:13:53 +02:00
google-labs-jules[bot]
2a2045c24a feat(i18n): Internationalize remaining app pages
This commit completes the internationalization (i18n) for several key pages
within the frontend application.

The following pages have been updated to support multiple languages:
- AccountPage.vue
- SignupPage.vue
- ListDetailPage.vue (including items, OCR, expenses, and cost summary)
- MyChoresPage.vue
- PersonalChoresPage.vue
- IndexPage.vue

Key changes include:
- Extraction of all user-facing strings from these Vue components.
- Addition of new translation keys and their English values to `fe/src/i18n/en.json`.
- Modification of the Vue components to use the Vue I18n plugin's `$t()` (template)
  and `t()` (script) functions for displaying translated strings.
- Dynamic messages, notifications, and form validation messages are now also
  internationalized.
- The language files `de.json`, `es.json`, and `fr.json` have been updated
  with the new keys, using the English text as placeholders for future
  translation.

This effort significantly expands the i18n coverage of the application,
making it more accessible to a wider audience.
2025-06-01 22:13:36 +00:00
whtvrboo
c1ebd16e5a
Merge pull request #12 from whtvrboo/feat/i18n-more-pages
feat: Internationalize AuthCallback, Chores, ErrorNotFound, GroupDeta…
2025-06-01 23:52:54 +02:00
google-labs-jules[bot]
554814ad63 feat: Internationalize AuthCallback, Chores, ErrorNotFound, GroupDetail pages
This commit introduces internationalization for several pages:
- AuthCallbackPage.vue
- ChoresPage.vue (a comprehensive page with many elements)
- ErrorNotFound.vue
- GroupDetailPage.vue (including sub-sections for members, invites, chores summary, and expenses summary)

Key changes:
- Integrated `useI18n` in each listed page to handle translatable strings.
- Replaced hardcoded text in templates and relevant script sections (notifications, dynamic messages, fallbacks, etc.) with `t('key')` calls.
- Added new translation keys, organized under page-specific namespaces (e.g., `authCallbackPage`, `choresPage`, `errorNotFoundPage`, `groupDetailPage`), to `fe/src/i18n/en.json`.
- Added corresponding keys with placeholder translations (prefixed with DE:, FR:, ES:) to `fe/src/i18n/de.json`, `fe/src/i18n/fr.json`, and `fe/src/i18n/es.json`.
- Reused existing translation keys (e.g., for chore frequency options) where applicable.
2025-06-01 21:51:01 +00:00
whtvrboo
f2609f53ec
Merge pull request #11 from whtvrboo/fix/google-oauth-redirect
Fix: Resolve Google OAuth redirection issue
2025-06-01 23:14:10 +02:00
google-labs-jules[bot]
4fef642970 Fix: Resolve Google OAuth redirection issue
This commit addresses an issue where clicking the "Continue with Google"
button redirected the user back to the login page instead of to Google's
authentication page.

The following changes were made:

1.  **Frontend Redirect:**
    *   Modified `fe/src/components/SocialLoginButtons.vue` to make the "Continue with Google" button redirect to the correct backend API endpoint (`/auth/google/login`) using the configured `API_BASE_URL`.

2.  **Backend Route Confirmation:**
    *   Verified that the backend OAuth routes in `be/app/api/auth/oauth.py` are correctly included in `be/app/main.py` under the `/auth` prefix, making them accessible.

3.  **OAuth Credentials Configuration:**
    *   Added `GOOGLE_CLIENT_ID` and `GOOGLE_CLIENT_SECRET` placeholders to `env.production.template` to guide you in setting up your OAuth credentials.
    *   Added instructional comments in `be/app/config.py` regarding the necessity of these environment variables and the correct configuration of `GOOGLE_REDIRECT_URI`.

With these changes, and assuming the necessary Google Cloud OAuth credentials
(Client ID, Client Secret) and redirect URIs are correctly configured in the
environment, the Google OAuth flow should now function as expected.
2025-06-01 21:13:48 +00:00
mohamad
dda39532d6 Update OAuth redirect URIs and API routing structure
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
- Changed the Google and Apple redirect URIs in the configuration to include the API version in the path.
- Reorganized the inclusion of OAuth routes in the main application to ensure they are properly prefixed and accessible.

These updates aim to enhance the API structure and ensure consistency in the authentication flow.
2025-06-01 22:43:02 +02:00
mohamad
6d5e950918 Refactor API routing and update login URLs
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
- Updated the OAuth routes to be included under the main API prefix for better organization.
- Changed the Google login URL in the SocialLoginButtons component to reflect the new API structure.

These changes aim to improve the clarity and consistency of the API routing and enhance the login flow for users.
2025-06-01 22:37:44 +02:00
mohamad
e6c15210c1 Update API base URL in api-config.ts to point to the new production environment
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 22:16:36 +02:00
mohamad
b07ab09f88 Enhance styling and animations in ListsPage component
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
- Updated transition effects for checkbox elements to use cubic-bezier for smoother animations.
- Added scale transformation on hover and checked states for improved visual feedback.
- Adjusted animation durations for checkbox interactions to enhance user experience.
- Refined text styling for checked checkboxes to include color change and line-through effect.

These changes aim to improve the overall user interface and interaction dynamics of the ListsPage component.
2025-06-01 22:03:08 +02:00
mohamad
5cb13862ef Enhance ChoresPage accessibility and functionality
- Added ARIA roles and attributes to buttons and modals for improved accessibility.
- Updated chore types and properties in the Chore interface to allow for null values.
- Refactored chore loading and filtering logic to handle edge cases and improve performance.
- Enhanced calendar and list views with better user feedback for empty states and loading indicators.
- Improved styling for mobile responsiveness and dark mode support.

These changes aim to enhance user experience, accessibility, and maintainability of the ChoresPage component.
2025-06-01 22:00:11 +02:00
mohamad
843b3411e4 Update package dependencies and refactor ListDetailPage and ListsPage components
- Added `motion` and `framer-motion` packages to `package.json` and `package-lock.json`.
- Updated API base URL in `api-config.ts` to point to the local development environment.
- Refactored `ListDetailPage.vue` to enhance item rendering and interaction, replacing `VListItem` with a custom list structure.
- Improved `ListsPage.vue` to handle loading states and item addition more effectively, including better handling of temporary item IDs.

These changes aim to improve the user experience and maintainability of the application.
2025-06-01 21:57:03 +02:00
mohamad
7da93d1fe9 Update API base URL in api-config.ts to point to the production environment
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
2025-06-01 21:06:40 +02:00
mohamad
02238974aa Refactor: Update styling and functionality in various components
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
This commit includes several enhancements across multiple files:
- In `valerie-ui.scss`, improved formatting of CSS properties and adjusted selectors for better readability and consistency.
- In `CreateListModal.vue`, introduced a sentinel value for group selection and refined the logic for handling group options.
- In `VModal.vue`, streamlined modal structure and removed unnecessary styles, centralizing modal styling in `valerie-ui.scss`.
- In `VTextarea.vue`, adjusted aria attributes for better accessibility and improved code clarity.
- Updated `api-config.ts` to switch the API base URL to a local development environment.

These changes aim to enhance maintainability, accessibility, and overall user experience.
2025-06-01 20:41:04 +02:00
mohamad
ca1ac94b57 Refactor GroupsPage: Replace VButton and VIcon components with standard HTML button and SVG for improved compatibility and maintainability. Added console logs for better debugging during the create list dialog flow. 2025-06-01 19:59:45 +02:00
mohamad
e52ab871bc Enhance group selection flow by ensuring latest groups data is fetched before opening the create list dialog. Additionally, refresh the groups list after a new list is created to reflect updates. This improves data consistency and user experience on the GroupsPage. 2025-06-01 19:56:11 +02:00
mohamad
c6c204f64a Refactor: Replace button elements with VButton and VIcon components in GroupsPage
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Has been cancelled
This commit updates the GroupsPage.vue file by replacing standard button elements with custom VButton and VIcon components for improved consistency and styling. This change enhances the UI component structure and aligns with the project's design system.
2025-06-01 19:51:05 +02:00
mohamad
a059768d8a Refactor: Simplify docker-compose configuration by removing unused services and optimizing backend settings
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m20s
This commit cleans up the `docker-compose.prod.yml` file by removing the database, Redis, and frontend service configurations, which were deemed unnecessary. The backend service configuration has been streamlined, including the update of the `VITE_API_URL` environment variable. This refactor aims to enhance clarity and maintainability of the Docker setup.
2025-06-01 19:18:42 +02:00
whtvrboo
09c3160fbb
Merge pull request #10 from whtvrboo/fix/docker-migrations
Fix(alembic): Resolve TypeError in migration script and remove redund…
2025-06-01 19:17:07 +02:00
google-labs-jules[bot]
287155a783 Fix(alembic): Resolve TypeError in migration script and remove redundant migration call
This commit addresses two issues:
1. A `TypeError` during Alembic migrations (`upgrade() takes 0 positional
   arguments but 1 was given`). This was caused by the `upgrade` and
   `downgrade` functions in the initial migration script not accepting
   any arguments, while the custom migration runner in `migrations.py`
   was passing a context argument.
   - Modified `be/alembic/versions/0001_initial_schema.py` to ensure
     `upgrade` and `downgrade` functions accept a `context` argument.

2. Redundant execution of migrations. Migrations were being triggered
   both by the `entrypoint.sh` script and within the FastAPI application's
   startup event in `app/main.py`.
   - Commented out the `await run_migrations()` call in `app/main.py`
     to ensure migrations are only handled by the `entrypoint.sh` script.

These changes should ensure that database migrations run correctly and only
once when the backend container starts.
2025-06-01 17:16:41 +00:00
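The `TypeError` the commit describes can be sketched in a few lines. This is a hypothetical reduction, not the repo's actual migration script: a custom runner passes a context argument, so the migration's `upgrade`/`downgrade` must accept one; making the parameter optional keeps plain `alembic upgrade head` (which calls with no arguments) working as well.

```python
# Hypothetical sketch of the fix: the custom runner in migrations.py
# passes a context object, so the functions in 0001_initial_schema.py
# must accept one. An optional parameter supports both call styles.

calls = []  # records invocations so the behavior is observable

def upgrade(context=None):
    # the real script would issue op.create_table(...) etc.
    calls.append(("upgrade", context))

def downgrade(context=None):
    calls.append(("downgrade", context))

class MigrationContext:  # stand-in for the runner's context object
    pass

upgrade(MigrationContext())  # custom runner in migrations.py
upgrade()                    # stock `alembic upgrade head`
```

Before the fix, a zero-argument `def upgrade():` would raise `TypeError: upgrade() takes 0 positional arguments but 1 was given` on the first call.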
whtvrboo
c50395ae86
Merge pull request #9 from whtvrboo/fix/docker-migrations
Fix(docker): Run Alembic migrations on container startup
2025-06-01 19:10:53 +02:00
google-labs-jules[bot]
4540ad359e Fix(docker): Run Alembic migrations on container startup
This commit introduces changes to ensure that Alembic database migrations
are automatically applied when the backend Docker container starts.

Key changes:
- Added `be/entrypoint.sh`: This script first runs `alembic upgrade head`
  to apply any pending migrations and then executes the main container
  command (e.g., starting Uvicorn).
- Modified `be/Dockerfile`:
    - The `entrypoint.sh` script is copied into the image and made executable.
    - The Docker `ENTRYPOINT` is set to this script, ensuring migrations
      run before the application starts.
- Updated `docker-compose.yml`:
    - The `DATABASE_URL` for the `backend` service has been set to the
      Neon database URL you provided.
- Verified `be/alembic/env.py`: Confirmed that it correctly sources the
  `DATABASE_URL` from environment variables for Alembic to use.

These changes address the issue where migrations were not being run,
preventing the application from starting correctly.
2025-06-01 17:10:40 +00:00
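The commit notes that `be/alembic/env.py` sources `DATABASE_URL` from the environment. A minimal sketch of that lookup, with a hypothetical helper name (the repo's exact handling may differ):

```python
# Hypothetical sketch of how env.py can source the database URL from the
# environment before handing it to Alembic, failing fast when it is unset.
import os

def get_database_url(environ=os.environ):
    url = environ.get("DATABASE_URL")
    if not url:
        raise RuntimeError("DATABASE_URL must be set for Alembic migrations")
    return url

# In env.py this value would typically be injected via:
#   config.set_main_option("sqlalchemy.url", get_database_url())
print(get_database_url({"DATABASE_URL": "postgresql://example/db"}))
```

With this in place, the `entrypoint.sh` described above only needs to run `alembic upgrade head` and then exec the main container command.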
mohamad
3738819065 refactor: Simplify upgrade function by directly creating enums and adding new tables for chores and chore assignments in the initial schema
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m31s
2025-06-01 18:20:31 +02:00
mohamad
c14b432082 refactor: Encapsulate enum creation logic within a dedicated function in the upgrade process for improved readability and maintainability
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m21s
2025-06-01 18:15:22 +02:00
mohamad
c204c25314 refactor: Modify upgrade function to accept context parameter for enhanced migration flexibility
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m21s
2025-06-01 18:12:58 +02:00
mohamad
02ab812ef0 refactor: Clarify access to revision strings in migration function by referencing Script object within RevisionStep
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m19s
2025-06-01 18:09:31 +02:00
mohamad
20daadc112 refactor: Update migration function to access revision strings from RevisionStep objects for improved clarity
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 17:50:02 +02:00
mohamad
5dcabd51f7 refactor: Introduce migration function to streamline upgrade steps in Alembic migrations
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
2025-06-01 17:45:32 +02:00
mohamad
8f1da5d440 refactor: Improve Alembic migration functions by integrating configuration and script directory handling for enhanced migration context management
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
2025-06-01 17:42:17 +02:00
mohamad
0f9d83a233 refactor: Update migration functions to accept connection parameter for improved flexibility and consistency
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 17:39:07 +02:00
mohamad
cb5bfcf7b5 refactor: Separate async migration logic into dedicated module and streamline migration functions for improved clarity and maintainability
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
2025-06-01 17:33:04 +02:00
mohamad
e16c749019 refactor: Enhance Alembic migration functions to support direct execution and improve error handling for database URL configuration
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m18s
2025-06-01 17:29:48 +02:00
mohamad
7223606fdc refactor: Update Alembic migration functions to support asynchronous execution and streamline migration handling in application startup
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m20s
2025-06-01 17:20:28 +02:00
mohamad
f4eeb00acf fix: Add Alembic directory and configuration file to production Dockerfile for migration support
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 17:16:26 +02:00
mohamad
43e2d88ffe fix: Update Alembic configuration to use absolute paths for ini file and script location in migration process
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m15s
2025-06-01 17:13:09 +02:00
mohamad
32841ea727 fix: Enhance Alembic configuration by setting script location and database URL validation in migration process
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m17s
2025-06-01 17:09:49 +02:00
mohamad
26e06ddeaa refactor: Simplify Dockerfile by reorganizing Alembic file copying and enhance migration handling in application startup
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m21s
2025-06-01 17:03:13 +02:00
mohamad
f2df1c50dd fix: Update Alembic configuration in startup event to set script location and database URL
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m26s
2025-06-01 16:57:07 +02:00
mohamad
411c3c91b2 feat: Add Alembic configuration and migration command to application startup
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 16:54:16 +02:00
mohamad
5a2b311a4f fix ig
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m14s
2025-06-01 16:49:21 +02:00
mohamad
9b09b461bd refactor: Update production Dockerfile to use Node.js for serving built assets and enhance environment variable injection
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m24s
2025-06-01 16:46:00 +02:00
mohamad
9f8de46d06 refactor: Transition production Dockerfile to use Nginx for serving built assets and streamline environment variable handling
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m20s
2025-06-01 16:38:29 +02:00
mohamad
b1a74edb6a refactor: Update environment variable handling in Dockerfile for production
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m16s
2025-06-01 16:32:51 +02:00
mohamad
161292ff3b refactor: Optimize Dockerfiles and deployment workflow for improved performance and reliability
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 1m54s
- Updated Dockerfiles to use `python:3.11-slim` for reduced image size and enhanced build efficiency.
- Implemented multi-stage builds with selective file copying and non-root user creation for better security.
- Enhanced deployment workflow with retry logic for image pushes and added cleanup steps for Docker resources.
- Improved build commands with BuildKit optimizations for both backend and frontend images.
2025-06-01 16:26:49 +02:00
mohamad
55d08d36e0 refactor: Revise .dockerignore and Dockerfile for enhanced build efficiency and organization
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Has been cancelled
- Updated .dockerignore to categorize ignored files, including logs and local development configurations.
- Implemented a multi-stage build in Dockerfile to optimize image size and dependency management.
- Added build dependencies and created a virtual environment for better isolation of Python packages.
2025-06-01 16:14:55 +02:00
mohamad
59f2f47949 refactor: Improve deployment workflow with retry logic for image pushes and optimized build process
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Has been cancelled
2025-06-01 16:03:58 +02:00
mohamad
1e9957de91 refactor: Enhance deployment workflow for backend and frontend images
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 24s
2025-06-01 16:00:55 +02:00
mohamad
6ed7e32922 refactor: Update .dockerignore for improved clarity and organization
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Has been cancelled
- Consolidated and categorized ignored files for better readability.
- Added entries for logs, local development, and documentation.
- Removed redundant entries and ensured proper grouping of related files.
2025-06-01 15:56:49 +02:00
mohamad
cc1f910e4c refactor: Standardize user creation in Dockerfile and improve multi-stage build syntax
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 4m40s
2025-06-01 15:47:42 +02:00
mohamad
cd98b7b854 refactor: Update backend Dockerfile to use Alpine package names
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 20s
2025-06-01 15:46:08 +02:00
mohamad
392a2ae049 refactor: Switch backend Dockerfile to use Alpine package manager
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 19s
2025-06-01 15:44:09 +02:00
mohamad
a51b18e8f5 refactor: Update Docker configurations for improved environment variable handling
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 22s
- Changed the frontend Dockerfile to use process.env for environment variables instead of direct interpolation.
- Updated the production Docker Compose file to set environment variables directly instead of using build args.
- Switched the backend Dockerfile base image from `python:3.11-slim` to `python:alpine` for a smaller image size and increased worker count from 4 to 8 for better performance.
2025-06-01 15:41:42 +02:00
mohamad
99d6c5ffaa refactor: Improve environment variable injection in Dockerfile for production
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 30s
2025-06-01 15:34:59 +02:00
mohamad
dd29f27a5b fix: Update API base URL for development environment
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 27s
2025-06-01 15:15:56 +02:00
mohamad
d05200b623 refactor: Update frontend components and Dockerfile for production
All checks were successful
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Successful in 38s
- Changed the build command in Dockerfile from `npm run build` to `npm run build-only` for optimized production builds.
- Simplified service worker initialization by removing conditional precaching logic.
- Enhanced styling and structure in `VListItem.vue` and `ListDetailPage.vue` for better readability and consistency with Valerie UI components.
- Improved focus handling in item addition and editing processes for better user experience.
- Cleaned up unused CSS classes and ensured consistent usage of Valerie UI components across the application.
2025-06-01 14:59:30 +02:00
mohamad
ed76816a32 Enhance deployment workflow with context variable debugging and fallback logic
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 1m26s
2025-06-01 14:50:52 +02:00
mohamad
8c5753ea77 fix: Update Docker image tags
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 17s
2025-06-01 14:47:45 +02:00
mohamad
12f35b539a fix: Update Docker login commands in deployment workflow
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 16s
2025-06-01 14:41:09 +02:00
mohamad
e104d26583 chore: Update deployment workflow to trigger on closed pull requests to prod
Some checks failed
Deploy to Production, build images and push to Gitea Registry / build_and_push (pull_request) Failing after 32s
2025-06-01 14:38:04 +02:00
mohamad
8ff31ecf91 refactor: Update deployment workflows and Dockerfiles for production
- Modified the GitHub Actions workflow to streamline the deployment process by installing Docker directly and using shell commands for building and pushing images.
- Changed the base image for the backend Dockerfile from `python:3.11-slim` to `python:alpine` for a smaller footprint.
- Updated the frontend Dockerfile to use `node:23-alpine` instead of `node:24-alpine`, and refactored the production stage to use `node:slim`. Added a script for runtime environment variable injection.
2025-06-01 14:37:09 +02:00
mohamad
1c87170955 refactor: use html for now
Some checks failed
Deploy to Production, build images and push to Gitea Registry / deploy (push) Failing after 2m14s
2025-06-01 14:27:46 +02:00
mohamad
74c73a9e8f refactor: Update GroupsPage to use standard HTML for now 2025-06-01 14:27:02 +02:00
mohamad
679169e4fb refactor: Simplify ChoresPage structure and enhance form functionality
- Removed redundant form elements and improved the layout for better readability.
- Streamlined the chore creation and editing process with enhanced validation and auto-save features.
- Updated keyboard shortcuts for improved accessibility and user experience.
- Enhanced modal interactions and improved loading states during data fetching.
- Cleaned up unused code and optimized the overall component structure.
2025-06-01 14:27:02 +02:00
google-labs-jules[bot]
a7fbc454a9 Refactor: Reset Alembic migrations and consolidate models.
This commit addresses issues with backend models, schemas, and migrations.

Key changes:
- Consolidated all SQLAlchemy model definitions into `be/app/models.py`.
- Emptied `be/app/models/expense.py` as its contents were duplicates.
- Verified and standardized Base class usage and SQLAlchemy imports in models.
- Confirmed the correctness of self-referential relationships in the `Expense` model.
- Added a clarifying comment to `SplitTypeEnum` regarding future extensibility.
- Corrected a typo in `Settlement.created_by_user_id`.

Migration Cleanup:
- Deleted all existing Alembic migration files from `be/alembic/versions/`.
- Created a new, single initial migration script (`0001_initial_schema.py`) that defines the entire database schema based on the current state of the SQLAlchemy models. This provides a clean slate for future migrations.

This reset was performed because the previous migration history was complex and contained a revision that was incompatible with the current model definitions. Starting fresh ensures consistency between the models and the database schema from the initial point.
2025-06-01 14:26:37 +02:00
google-labs-jules[bot]
813ed911f1 Okay, I've made some changes to integrate the Valerie UI components into the Account, Group Detail, and List Detail pages. This is part of the ongoing effort to standardize the UI and make the code easier to maintain.
Here's a breakdown of the changes:

1.  **`AccountPage.vue`**:
    *   I replaced the main heading with `VHeading`.
    *   I updated the loading spinner to `VSpinner`.
    *   I converted the error alert to `VAlert` with an action button.
    *   I refactored the Profile, Password, and Notifications sections to use `VCard` for their structure.
    *   The form elements within these cards (name, email, passwords) now use `VFormField` and `VInput`.
    *   Action buttons like "Save Changes" and "Change Password" are now `VButton` with an integrated `VSpinner` for loading states.
    *   The notification preferences list uses `VList` and `VListItem`, with each preference toggle converted to `VToggleSwitch`.

2.  **`GroupDetailPage.vue`**:
    *   I updated the page-level loading spinner, error alert, and main heading to `VSpinner`, `VAlert`, and `VHeading`.
    *   I refactored the "Group Members", "Invite Members", "Chores", and "Expenses" sections from custom "neo-card" styling to use `VCard`.
    *   Headers within these cards use `VHeading` and action buttons use `VButton` (I kept Material Icons where `VIcon` wasn't a direct replacement).
    *   Lists of members, chores, and expenses now use `VList` and `VListItem`.
    *   Buttons within list items (e.g., "Remove member") are `VButton` with `VSpinner`.
    *   Role indicators and frequency/split type "chips" are now `VBadge` components, and I updated the helper functions to return VBadge-compatible variants.
    *   The "Invite Members" form elements (input for code, copy button) use `VFormField`, `VInput`, and `VButton`.
    *   I simplified empty states within card bodies using `VIcon` and text.

3.  **`ListDetailPage.vue`**: This complex page required several steps to refactor:
    *   **Page-Level & Header:** I updated the loading state to `VSpinner`, the error alert to `VAlert`, and the main title to `VHeading`. Header action buttons are `VButton` with icons, and the list status is `VBadge`.
    *   **Modals:** I converted all five custom modals (OCR, Confirmation, Edit Item, Settle Share, Cost Summary shell) to use `VModal`. Internal forms and actions within these modals now use `VFormField`, `VInput`, `VButton`, `VSpinner`, `VList`, `VListItem`, and `VAlert` as appropriate. I removed the `onClickOutside` logic.
    *   **Main Items List:** The loading state uses `VCard` with `VSpinner`, and the empty state uses `VCard variant="empty-state"`. The list itself is now a `VCard` containing a `VList`. Each item is a `VListItem` with internal content refactored to use `VCheckbox`, `VInput` (for price), and `VButton` with `VIcon` for actions.
    *   **Add Item Form:** I re-structured this below the items list, using `VFormField`, `VInput`, and `VButton` with `VIcon`.
    *   **Expenses Section:** The main card uses `VCard` with `VHeading` and `VButton` in the header. Loading/error/empty states use `VSpinner`, `VAlert`, `VIcon`. The expenses list is `VList`, with each expense item as a `VListItem`. Statuses are `VBadge`.

This refactoring significantly increases the usage of the Valerie UI component library across these key application pages. This should help create a more consistent experience for you and make development smoother. Next, I'll focus on the Chores-related pages.
2025-06-01 14:26:37 +02:00
google-labs-jules[bot]
272e5abe41 refactor: Integrate Valerie UI components into Group and List pages
This commit refactors parts of `GroupsPage.vue`, `ListsPage.vue`, and the shared `CreateListModal.vue` to use the newly created Valerie UI components.

Key changes include:

1.  **Modals:**
    *   The "Create Group Dialog" in `GroupsPage.vue` now uses `VModal`, `VFormField`, `VInput`, `VButton`, and `VSpinner`.
    *   The `CreateListModal.vue` component (used by both pages) has been internally refactored to use `VModal`, `VFormField`, `VInput`, `VTextarea`, `VSelect`, `VButton`, and `VSpinner`.

2.  **Forms:**
    *   The "Join Group" form in `GroupsPage.vue` now uses `VFormField`, `VInput`, `VButton`, and `VSpinner`.

3.  **Alerts:**
    *   Error alerts in both `GroupsPage.vue` and `ListsPage.vue` now use the `VAlert` component, with retry buttons placed in the `actions` slot.

4.  **Empty States:**
    *   The empty state displays (e.g., "No Groups Yet", "No lists found") in both pages now use the `VCard` component with `variant="empty-state"`, mapping content to the relevant props and slots.

5.  **Buttons:**
    *   Various standalone buttons (e.g., "Create New Group", "Create New List", "List" button on group cards) have been updated to use the `VButton` component with appropriate props for variants, sizes, and icons.

**Scope of this Refactor:**
*   The focus was on replacing direct usages of custom-styled modal dialogs, form elements, alerts, and buttons with their Valerie UI component counterparts.
*   Highly custom card-like structures such as `neo-group-card` (in `GroupsPage.vue`) and `neo-list-card` (in `ListsPage.vue`), along with their specific "create" card variants, have been kept with their existing custom styling for this phase. This is due to their unique layouts and styling not directly mapping to the current generic `VCard` component without significant effort or potential introduction of overly specific props to `VCard`. Only buttons within these custom cards were refactored.
*   The internal item rendering within `neo-list-card` (custom checkboxes, add item input) also remains custom for now.

This refactoring improves consistency by leveraging the standardized Valerie UI components for common UI patterns like modals, forms, alerts, and buttons on these pages.
2025-06-01 14:26:37 +02:00
google-labs-jules[bot]
fc16f169b1 Jules was unable to complete the task in time. Please review the work done so far and provide feedback for Jules to continue. 2025-06-01 14:26:37 +02:00
google-labs-jules[bot]
3811dc7ee5 Refactor: Polish backend based on review
I reviewed the backend codebase covering schema, API endpoints, error handling, and tests.

Key changes I implemented:
- Updated `app/models.py`:
    - Added `parent_expense_id` and `last_occurrence` fields to the `Expense` model to align with the `add_recurring_expenses.py` migration.
    - Added `parent_expense` and `child_expenses` self-referential relationships to the `Expense` model.
- Updated `app/core/exceptions.py`:
    - Removed the unused and improperly defined `BalanceCalculationError` class.

I identified areas for future work:
- Create a new Alembic migration if necessary to ensure `parent_expense_id` and `last_occurrence` columns are correctly reflected in the database, or verify the existing `add_recurring_expenses.py` migration's status.
- Significantly improve API test coverage, particularly for:
    - Chores module (personal and group)
    - Groups, Invites, Lists, Items, OCR endpoints
    - Full CRUD operations for Expenses and Settlements
    - Recurring expense functionalities.
2025-06-01 14:26:36 +02:00
mohamad
136c4df7ac feat: Integrate Storybook for component development 2025-06-01 14:26:36 +02:00
mohamad
821a26e681 feat: Add Recurrence Pattern and Update Expense Schema
- Introduced a new `RecurrencePattern` model to manage recurrence details for expenses, allowing for daily, weekly, monthly, and yearly patterns.
- Updated the `Expense` model to include fields for recurrence management, such as `is_recurring`, `recurrence_pattern_id`, and `next_occurrence`.
- Modified the database schema to reflect these changes, including alterations to existing columns and the removal of obsolete fields.
- Enhanced the expense creation logic to accommodate recurring expenses and updated related CRUD operations accordingly.
- Implemented necessary migrations to ensure database integrity and support for the new features.
2025-06-01 14:23:05 +02:00
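The `next_occurrence` field mentioned above implies date arithmetic for each pattern kind. A self-contained sketch of how it might be derived (a hypothetical helper, not the repo's code), covering the daily, weekly, monthly, and yearly cases the commit lists, including end-of-month clamping:

```python
# Illustrative sketch: compute the next occurrence date for a recurrence
# pattern. Short months and leap days are clamped (e.g. Jan 31 -> Feb 28).
from datetime import date, timedelta

def next_occurrence(last: date, frequency: str, interval: int = 1) -> date:
    if frequency == "daily":
        return last + timedelta(days=interval)
    if frequency == "weekly":
        return last + timedelta(weeks=interval)
    if frequency == "monthly":
        month0 = last.month - 1 + interval
        year, month = last.year + month0 // 12, month0 % 12 + 1
        leap = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_month = [31, 29 if leap else 28, 31, 30, 31, 30,
                         31, 31, 30, 31, 30, 31][month - 1]
        return date(year, month, min(last.day, days_in_month))
    if frequency == "yearly":
        try:
            return last.replace(year=last.year + interval)
        except ValueError:  # Feb 29 mapped onto a non-leap year
            return last.replace(year=last.year + interval, day=28)
    raise ValueError(f"unknown frequency: {frequency}")

print(next_occurrence(date(2025, 1, 31), "monthly"))  # 2025-02-28
```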
mohamad
ee6d96d9ec feat: Revamp ChoresPage with enhanced UI and functionality
- Redesigned the ChoresPage layout to improve user experience, introducing a new header with tabs for viewing overdue, today, upcoming, all pending, and completed chores.
- Implemented a calendar view for better visualization of chores, allowing users to navigate between months and add chores directly on specific dates.
- Added a loading state to enhance feedback during data fetching.
- Introduced a shortcuts modal to display keyboard shortcuts for improved accessibility.
- Enhanced chore card interactions with improved styling and responsiveness, including a transition effect for list items.
- Updated chore management functionalities, including create, edit, and delete operations, with better validation and auto-save features.
2025-06-01 13:19:28 +02:00
mohamad
8c52bbb307 feat: Integrate Storybook for component development 2025-05-31 14:43:59 +02:00
mohamad
ce67570cfb feat: Update deployment workflow and enhance ListDetailPage functionality
- Modified the production deployment workflow to trigger on pushes to the 'prod' branch and updated Docker registry login to use Gitea Container Registry.
- Enhanced ListDetailPage.vue to improve loading states and error handling, introducing a new loading mechanism for items and utilizing session storage for cached data.
- Implemented Intersection Observer for pre-fetching list details to optimize user experience during navigation.
- Improved touch feedback for list cards and optimized styles for mobile responsiveness.
2025-05-31 14:08:40 +02:00
mohamad
cb51186830 feat: Add production deployment configuration and environment setup
- Introduced `docker-compose.prod.yml` to define services for production deployment, including PostgreSQL, FastAPI backend, frontend, and Redis.
- Created `env.production.template` to outline necessary environment variables for production, ensuring sensitive data is not committed.
- Added `PRODUCTION.md` as a deployment guide detailing the setup process using Docker Compose and Gitea Actions for CI/CD.
- Implemented Gitea workflows for build, test, and deployment processes to streamline production updates.
- Updated backend and frontend Dockerfiles for optimized production builds and configurations.
- Enhanced application settings to support environment-specific configurations, including CORS and health checks.
2025-05-28 08:23:22 +02:00
mohamad
84b046508a feat: Implement refresh token functionality in authentication flow
- Added support for refresh tokens in the authentication backend, allowing users to obtain new access tokens using valid refresh tokens.
- Created a new `BearerResponseWithRefresh` model to structure responses containing both access and refresh tokens.
- Updated the `AuthenticationBackend` to handle login and logout processes with refresh token support.
- Introduced a new `/auth/jwt/refresh` endpoint to facilitate token refreshing, validating the refresh token and generating new tokens as needed.
- Modified OAuth callback logic to generate and return both access and refresh tokens upon successful authentication.
- Updated frontend API service to send the refresh token in the Authorization header for token refresh requests.
2025-05-25 12:51:02 +02:00
mohamad
a0d67f6c66 feat: Add comprehensive notes and tasks for project stabilization and enhancements
- Introduced a new `notes.md` file to document critical tasks and progress for stabilizing the core functionality of the MitList application.
- Documented the status and findings for key tasks, including backend financial logic fixes, frontend expense split settlement implementation, and core authentication flow reviews.
- Outlined remaining work for production deployment, including secret management, CI/CD pipeline setup, and performance optimizations.
- Updated the logging configuration to change the log level to WARNING for production readiness.
- Enhanced the database connection settings to disable SQL query logging in production.
- Added a new endpoint to list all chores for improved user experience and optimized database queries.
- Implemented various CRUD operations for chore assignments, including creation, retrieval, updating, and deletion.
- Updated frontend components and services to support new chore assignment features and improved error handling.
- Enhanced the expense management system with new fields and improved API interactions for better user experience.
2025-05-24 21:36:57 +02:00
mohamad
81577ac7e8 feat: Add Recurrence Pattern and Update Expense Schema
- Introduced a new `RecurrencePattern` model to manage recurrence details for expenses, allowing for daily, weekly, monthly, and yearly patterns.
- Updated the `Expense` model to include fields for recurrence management, such as `is_recurring`, `recurrence_pattern_id`, and `next_occurrence`.
- Modified the database schema to reflect these changes, including alterations to existing columns and the removal of obsolete fields.
- Enhanced the expense creation logic to accommodate recurring expenses and updated related CRUD operations accordingly.
- Implemented necessary migrations to ensure database integrity and support for the new features.
2025-05-23 21:01:49 +02:00
google-labs-jules[bot]
b0100a2e96 Fix: Ensure financial accuracy in cost splitting and balances
I've refactored the group balance summary logic to correctly account for
SettlementActivity. A SettlementActivity now reduces your
effective total_share_of_expenses, ensuring that net balances within
a group sum to zero. Previously, SettlementActivity amounts were
incorrectly added to total_settlements_paid, skewing balance
calculations.

I updated the existing `test_group_balance_summary_with_settlement_activity`
to assert the corrected balance outcomes.

I also added an extensive suite of API-level tests for:
- All expense splitting types (EQUAL, EXACT_AMOUNTS, PERCENTAGE, SHARES, ITEM_BASED),
  covering various scenarios and input validations.
- Group balance summary calculations, including multiple scenarios with
  SettlementActivity, partial payments, multiple expenses, and
  interactions with generic settlements. All balance tests verify that
  the sum of net balances is zero.

The CRUD operations for expenses and settlement activities were reviewed
and found to be sound, requiring no changes for this fix.

This resolves the flawed logic identified in
`be/tests/api/v1/test_costs.py` (test_group_balance_summary_with_settlement_activity)
and ensures that backend financial calculations are provably correct.
2025-05-22 17:04:46 +00:00
Mohamad.Elsena
5018ce02f7 feat: Implement recurring expenses feature with scheduling and management
- Added support for recurring expenses, allowing users to define recurrence patterns (daily, weekly, monthly, yearly) for expenses.
- Introduced `RecurrencePattern` model to manage recurrence details and linked it to the `Expense` model.
- Implemented background job scheduling using APScheduler to automatically generate new expenses based on defined patterns.
- Updated expense creation logic to handle recurring expenses, including validation and database interactions.
- Enhanced frontend components to allow users to create and manage recurring expenses through forms and lists.
- Updated documentation to reflect new features and usage guidelines for recurring expenses.
2025-05-22 16:37:14 +02:00
Mohamad.Elsena
52fc33b472 feat: Add CreateExpenseForm component and integrate into ListDetailPage
- Introduced CreateExpenseForm.vue for creating new expenses with fields for description, total amount, split type, and date.
- Integrated the CreateExpenseForm into ListDetailPage.vue, allowing users to add expenses directly from the list view.
- Enhanced UI with a modal for the expense creation form and added validation for required fields.
- Updated styles for consistency across the application.
- Implemented logic to refresh the expense list upon successful creation of a new expense.
2025-05-22 13:05:49 +02:00
whtvrboo
e7b072c2bd
Merge pull request #4 from whtvrboo/feat/traceable-expense-settlement
feat: Implement traceable expense splitting and settlement activities
2025-05-22 09:06:13 +02:00
google-labs-jules[bot]
f1152c5745 feat: Implement traceable expense splitting and settlement activities
Backend:
- Added `SettlementActivity` model to track payments against specific expense shares.
- Added `status` and `paid_at` to `ExpenseSplit` model.
- Added `overall_settlement_status` to `Expense` model.
- Implemented CRUD for `SettlementActivity`, including logic to update parent expense/split statuses.
- Updated `Expense` CRUD to initialize new status fields.
- Defined Pydantic schemas for `SettlementActivity` and updated `Expense/ExpenseSplit` schemas.
- Exposed API endpoints for creating/listing settlement activities and settling shares.
- Adjusted group balance summary logic to include settlement activities.
- Added comprehensive backend unit and API tests for new functionality.

Frontend (Foundation & TODOs due to my current capabilities):
- Created TypeScript interfaces for all new/updated models.
- Set up `listDetailStore.ts` with an action to handle `settleExpenseSplit` (API call is a placeholder) and refresh data.
- Created `SettleShareModal.vue` component for payment confirmation.
- Added unit tests for the new modal and store logic.
- Updated `ListDetailPage.vue` to display detailed expense/share statuses and settlement activities.
- `mitlist_doc.md` updated to reflect all backend changes and current frontend status.
- A `TODO.md` (implicitly within `mitlist_doc.md`'s new section) outlines necessary manual frontend integrations for `api.ts` and `ListDetailPage.vue` to complete the 'Settle Share' UI flow.

This set of changes provides the core backend infrastructure for precise expense share tracking and settlement, and lays the groundwork for full frontend integration.
2025-05-22 07:05:31 +00:00
whtvrboo
8bb960b605
Merge pull request #3 from whtvrboo/feat/frontend-tests
feat: Add comprehensive unit and E2E tests for Vue frontend
2025-05-22 08:43:11 +02:00
google-labs-jules[bot]
0bf7a7cb49 feat: Add comprehensive unit and E2E tests for Vue frontend
This commit introduces a suite of unit and E2E tests for the Vue.js
frontend, significantly improving code coverage and reliability.

Unit Test Summary:
- Setup: Configured Vitest and @vue/test-utils.
- Core UI Components: Added tests for EssentialLink, SocialLoginButtons,
  and NotificationDisplay.
- Pinia Stores: Implemented tests for auth, notifications, and offline
  stores, including detailed testing of actions, getters, and state
  management. Offline store tests were adapted to its event-driven design.
- Services:
  - api.ts: Tested Axios client config, interceptors (auth token refresh),
    and wrapper methods.
  - choreService.ts & groupService.ts: Tested all existing service
    functions for CRUD operations, mocking API interactions.
- Pages:
  - AccountPage.vue: Tested rendering, data fetching, form submissions
    (profile, password, preferences), and error handling.
  - ChoresPage.vue: Tested rendering, chore display (personal & grouped),
    CRUD modals, and state handling (loading, error, empty).
  - LoginPage.vue: Verified existing comprehensive tests.

E2E Test (Playwright) Summary:
- Auth (`auth.spec.ts`):
  - User signup, login, and logout flows.
  - Logout test updated with correct UI selectors.
- Group Management (`groups.spec.ts`):
  - User login handled via `beforeAll` and `storageState`.
  - Create group and view group details.
  - Update and Delete group tests are skipped as corresponding UI
    functionality is not present in GroupDetailPage.vue.
  - Selectors updated based on component code.
- List Management (`lists.spec.ts`):
  - User login handled similarly.
  - Create list (within a group), view list, add item to list,
    and mark item as complete.
  - Delete list test is skipped as corresponding UI functionality
    is not present.
  - Selectors based on component code.

This work establishes a strong testing foundation for the frontend.
Skipped E2E tests highlight areas where UI functionality for certain
CRUD operations (group update/delete, list delete) may need to be added
if desired.
2025-05-22 06:41:35 +00:00
whtvrboo
653788cfba
Merge pull request #2 from whtvrboo/feat/frontend-tests
feat: Add comprehensive unit tests for Vue frontend
2025-05-21 21:08:10 +02:00
google-labs-jules[bot]
c0dcccd970 feat: Add comprehensive unit tests for Vue frontend
This commit introduces a suite of unit tests for the Vue.js frontend,
significantly improving code coverage and reliability.

Key areas covered:
- **Setup**: Configured Vitest and @vue/test-utils.
- **Core UI Components**: Added tests for SocialLoginButtons and NotificationDisplay.
- **Pinia Stores**: Implemented tests for auth, notifications, and offline stores,
  including detailed testing of actions, getters, and state management.
  Offline store tests were adapted to its event-driven design.
- **Services**:
  - `api.ts`: Tested Axios client config, interceptors (auth token refresh),
    and wrapper methods.
  - `choreService.ts` & `groupService.ts`: Tested all existing service
    functions for CRUD operations, mocking API interactions.
- **Pages**:
  - `AccountPage.vue`: Tested rendering, data fetching, form submissions
    (profile, password, preferences), and error handling.
  - `ChoresPage.vue`: Tested rendering, chore display (personal & grouped),
    CRUD modals, and state handling (loading, error, empty).
  - `LoginPage.vue`: Verified existing comprehensive tests.

These tests provide a solid foundation for frontend testing. The next planned
step is to enhance E2E tests using Playwright.
2025-05-21 19:07:34 +00:00
mohamad
0204fb6f3a Add .env to .gitignore to prevent environment configuration files from being tracked 2025-05-21 20:23:57 +02:00
Mohamad.Elsena
29ccab2f7e feat: Implement chore management feature with personal and group chores
This commit introduces a comprehensive chore management system, allowing users to create, manage, and track both personal and group chores. Key changes include:
- Addition of new API endpoints for personal and group chores in `be/app/api/v1/endpoints/chores.py`.
- Implementation of chore models and schemas to support the new functionality in `be/app/models.py` and `be/app/schemas/chore.py`.
- Integration of chore services in the frontend to handle API interactions for chore management.
- Creation of new Vue components for displaying and managing chores, including `ChoresPage.vue` and `PersonalChoresPage.vue`.
- Updates to the router to include chore-related routes and navigation.

This feature enhances user collaboration and organization within shared living environments, aligning with the project's goal of streamlining household management.
2025-05-21 18:18:22 +02:00
Mohamad.Elsena
ed222c840a Remove obsolete Alembic migration files related to chore tables and assignments. This cleanup eliminates unused migration scripts that are no longer needed in the project. 2025-05-21 13:38:00 +02:00
whtvrboo
04b0ad7059
Merge pull request #1 from whtvrboo/feat/chore-management-backend-core
feat: Initial backend setup for Chore Management (Models, Migrations,…
2025-05-21 13:23:07 +02:00
google-labs-jules[bot]
16c9abb16a feat: Initial backend setup for Chore Management (Models, Migrations, Schemas, Chore CRUD)
I've implemented the foundational backend components for the chore management feature.

Key changes include:
- Definition of `Chore` and `ChoreAssignment` SQLAlchemy models in `be/app/models.py`.
- Addition of corresponding relationships to `User` and `Group` models.
- Creation of an Alembic migration script (`manual_0001_add_chore_tables.py`) for the new database tables. (Note: Migration not applied in sandbox).
- Implementation of a utility function `calculate_next_due_date` in `be/app/core/chore_utils.py` for determining chore due dates based on recurrence rules.
- Definition of Pydantic schemas (`ChoreCreate`, `ChorePublic`, `ChoreAssignmentCreate`, `ChoreAssignmentPublic`, etc.) in `be/app/schemas/chore.py` for API data validation.
- Implementation of CRUD operations (create, read, update, delete) for Chores in `be/app/crud/chore.py`.

This commit lays the groundwork for adding Chore Assignment CRUD operations and the API endpoints for both chores and their assignments.
2025-05-21 09:28:38 +00:00
Mohamad.Elsena
185e89351e Update expense creation to include current user ID for better tracking. Introduce a utility function to round monetary values to two decimal places. Enhance ListDetailPage styles by adding overflow handling for improved UI layout. 2025-05-21 09:34:51 +02:00
Mohamad
17bebbfab8 Refactor UI styles across multiple pages to enhance consistency and responsiveness. Update background colors to use CSS variables for improved theming in GroupDetailPage, GroupsPage, ListDetailPage, and ListsPage. Ensure all components align with the overall design system for a cohesive user experience. 2025-05-20 10:42:55 +02:00
Mohamad
fc355077ab Enhance database connection management by adding pool_pre_ping to ensure connections are live. Update connection pool settings for improved reliability. 2025-05-20 10:42:34 +02:00
mohamad
eb19230b22 Refactor frontend components and styles for improved UI consistency and responsiveness. Update HTML structure in index.html, enhance SCSS variables in valerie-ui.scss, and implement new layout styles across various pages. Adjust component props and event emissions for better data handling in CreateListModal and ConflictResolutionDialog. Add Material Icons for better visual representation in navigation. Ensure all changes align with the overall design system for a cohesive user experience. 2025-05-20 01:19:52 +02:00
mohamad
c8cdbd571e Add FastAPI database transaction management strategy and update requirements
Introduce a new technical specification for managing database transactions in FastAPI, ensuring ACID compliance through standardized practices. The specification outlines transaction handling for API endpoints, CRUD functions, and non-API operations, emphasizing the use of context managers and error handling.

Additionally, update the requirements file to include new testing dependencies for async operations, enhancing the testing framework for the application.
2025-05-20 01:19:37 +02:00
mohamad
d6d19397d3 Refactor authentication and user management to standardize session handling across OAuth flows. Update configuration to include default token type for JWT authentication. Enhance error handling with new exceptions for user operations, and clean up test cases for better clarity and reliability. 2025-05-20 01:19:21 +02:00
mohamad
323ce210ce Refactor database session management across multiple API endpoints to utilize a transactional session, enhancing consistency in transaction handling. Update dependencies in costs, financials, groups, health, invites, items, and lists modules for improved error handling and reliability. 2025-05-20 01:19:06 +02:00
mohamad
98b2f907de Refactor CRUD operations across multiple modules to standardize transaction handling using context managers, improving error logging and rollback mechanisms. Enhance error handling for database operations in expense, group, invite, item, list, settlement, and user modules, ensuring specific exceptions are raised for integrity and connection issues. 2025-05-20 01:18:49 +02:00
mohamad
e4175db4aa Implement test fixtures for async database sessions and enhance test coverage for CRUD operations. Introduce mock settings for financial endpoints and improve error handling in user and settlement tests. Refactor existing tests to utilize async mocks for better reliability and clarity. 2025-05-20 01:18:31 +02:00
mohamad
2b7816cf33 Update user model migration to include secure password hashing; set default hashed password for existing users. Refactor database session management for improved transaction handling and ensure session closure after use. 2025-05-20 01:17:47 +02:00
mohamad
5abe7839f1 Enhance configuration and error handling in the application; add new error messages for OCR and authentication processes. Refactor database session management to include transaction handling, and update models to track user creation for expenses and settlements. Update API endpoints to improve cost-sharing calculations and adjust invite management routes for clarity. 2025-05-17 13:56:17 +02:00
mohamad
c2aa62fa03 Update user model migration to set invalid password placeholder; enhance invite management with new endpoints for active invites and improved error handling in group invite creation. Refactor frontend to fetch and display active invite codes. 2025-05-16 22:31:44 +02:00
mohamad
f2ac73502c Enhance OAuth token handling in authentication flow; update frontend to support access and refresh tokens. Refactor auth store to manage refresh token state and improve token storage logic. 2025-05-16 22:08:56 +02:00
mohamad
9ff293b850 Ensure database transaction is committed after list creation in the API endpoint; improve reliability of list creation process. 2025-05-16 22:08:47 +02:00
mohamad
7a88ea258a Refactor database session management and exception handling across CRUD operations; streamline transaction handling in expense, group, invite, item, list, settlement, and user modules for improved reliability and clarity. Introduce specific operation errors for better error reporting. 2025-05-16 21:54:29 +02:00
mohamad
515534dcce Add conflict resolution for list creation and updates; implement offline action handling for list items. Enhance service worker with background sync capabilities and improve UI for offline states. 2025-05-16 02:07:41 +02:00
mohamad
3f0cfff9f1 Refactor authentication endpoints and user management; update CORS settings and JWT handling for improved security and compatibility with FastAPI-Users. Remove deprecated user-related endpoints and streamline API structure. 2025-05-14 01:04:09 +02:00
mohamad
72b988b79b Refactor authentication and database session handling; update user schemas for enhanced functionality and compatibility with FastAPI-Users. 2025-05-14 00:24:51 +02:00
mohamad
1c08e57afd fastapi-users, oauth, docker support, cleanup 2025-05-14 00:10:31 +02:00
mohamad
29682b7e9c Refactor input elements for consistency and readability; update styles for better alignment and spacing in SignupPage and ListDetailPage. 2025-05-13 22:46:40 +02:00
mohamad
18f759aa7c Add Sentry integration for error tracking; update requirements and configuration files. Introduce new Alembic migration for missing indexes and constraints in the database schema. 2025-05-13 21:45:45 +02:00
mohamad
9583aa4bab Remove repomix-output.xml file; update package.json and package-lock.json to include Supabase dependencies for enhanced authentication and real-time features. 2025-05-13 21:29:17 +02:00
mohamad
cacfb2a5e8 commit i guess 2025-05-13 20:33:02 +02:00
mohamad
227a3d6186 migrate to vue+vueuse+valerieui bc quasar customisation is sad 2025-05-13 19:23:15 +02:00
mohamad
9230d1f626 Remove ExampleComponent and example-store files 2025-05-09 00:11:02 +02:00
mohamad
5a910a29e2 Refactor data types in ConflictResolutionDialog and OfflineIndicator components for improved type safety; update OfflineAction interface to use 'unknown' instead of 'any' for data property, and enhance action label handling in OfflineIndicator for better clarity. 2025-05-09 00:10:37 +02:00
mohamad
db5f2d089e Implement offline functionality with conflict resolution; add OfflineIndicator component for user notifications, integrate offline action management in the store, and enhance service worker caching strategies for improved performance. 2025-05-08 23:52:11 +02:00
mohamad
7bbec7ad5f Add cost summary feature to ListDetailPage; implement API endpoints for costs and enhance UI with a dialog for displaying cost details, including user balances and total costs. 2025-05-08 23:38:07 +02:00
mohamad
f6a50e0d6a Update costs router in API to include prefix for improved endpoint organization and clarity. 2025-05-08 23:37:54 +02:00
mohamad
4283fe8a19 Refactor axios error handling to throw new Error instances for better stack trace clarity; update component lifecycle methods in AccountPage, ListsPage, and ListDetailPage to use void for asynchronous calls; adjust polling interval type in ListDetailPage for improved type safety. 2025-05-08 23:28:33 +02:00
mohamad
0dbee3bb4b Refactor logging in item API endpoints to use a local variable for user email; enhance clarity and maintainability of log messages. Update transaction management in item CRUD operations to ensure proper commit handling and version conflict checks. 2025-05-08 23:27:51 +02:00
mohamad
d99aef9d11 Add price input field for completed items in ListDetailPage; update item API endpoint to use new configuration; ensure price handling is consistent and type-safe during updates. 2025-05-08 23:21:12 +02:00
mohamad
8b6ddb91f8 Update ListDetailPage to change item quantity type from number to string; enhance logging for API calls and component lifecycle events for better debugging. 2025-05-08 22:53:38 +02:00
mohamad
e484c9e9a8 Enhance error handling and transaction management in item creation; explicitly commit changes and rollback on exceptions to ensure database integrity. 2025-05-08 22:53:26 +02:00
mohamad
f52b47f6df Refactor CRUD operations in group, item, and list modules to remove unnecessary transaction context; enhance error handling and improve code readability. Update API endpoint for OCR processing in configuration and add confirmation dialogs for item actions in ListDetailPage. 2025-05-08 22:34:07 +02:00
mohamad
262505c898 Refactor API integration across multiple components; introduce centralized API configuration, enhance error handling, and improve type safety in API calls for better maintainability and user experience. 2025-05-08 22:22:46 +02:00
mohamad
7836672f64 Refactor API calls in GroupDetailPage and GroupsPage components to remove versioning; enhance error handling for group creation and joining processes. 2025-05-08 21:59:13 +02:00
mohamad
fe252cfac8 Refactor CreateListModal and GroupDetailPage components; improve error handling and update API calls in ListsPage and ListDetailPage for better type safety and user feedback. 2025-05-08 21:35:02 +02:00
Mohamad
4f32670bda missing import 2025-05-08 19:12:23 +02:00
Mohamad
ff25af26f5 bugfix: hashed_password to password_hash 2025-05-08 19:12:00 +02:00
Mohamad
6198a29768 migrations fresh start 2025-05-08 19:11:21 +02:00
Mohamad.Elsena
c7fdb60130 Add database error messages and improve exception handling in CRUD operations 2025-05-08 16:00:12 +02:00
Mohamad.Elsena
5186892df6 Add CreateListModal and ListDetailPage components; enhance ListsPage with loading/error states and group filtering 2025-05-08 15:59:28 +02:00
Mohamad.Elsena
7b2c5c9ebd Svelte to Quasar 2025-05-08 15:02:09 +02:00
Mohamad
e3024ccd07 remove old cost splitter 2025-05-08 14:31:34 +02:00
mohamad
bbb3c3b7df 0705 2025-05-08 00:56:26 +02:00
mohamad
423d345fdf add_version_to_lists_and_items 2025-05-07 23:30:23 +02:00
mohamad
d2d484c327 001 2025-05-07 20:16:16 +02:00
374 changed files with 60609 additions and 8840 deletions

View File

@@ -0,0 +1,57 @@
---
description: FastAPI Database Transactions
globs:
alwaysApply: false
---
## FastAPI Database Transaction Management: Technical Specification

**Objective:** Ensure atomic, consistent, isolated, and durable (ACID) database operations through a standardized transaction management strategy.

**1. API Endpoint Transaction Scope (Primary Strategy):**

* **Mechanism:** A FastAPI dependency `get_transactional_session` (from `app.database` or `app.core.dependencies`) wraps database-modifying API request handlers.
* **Behavior:**
    * `async with AsyncSessionLocal() as session:` obtains a session.
    * `async with session.begin():` starts a transaction.
    * **Commit:** Automatic on successful completion of the `yield session` block (i.e., endpoint handler success).
    * **Rollback:** Automatic on any exception raised from the `yield session` block.
* **Usage:** Endpoints performing CUD (Create, Update, Delete) operations **MUST** use `db: AsyncSession = Depends(get_transactional_session)`.
* **Read-Only Endpoints:** May use `get_async_session` (alias `get_db`) or `get_transactional_session` (results in an empty transaction).

**2. CRUD Layer Function Design:**

* **Transaction Participation:** CRUD functions (in `app/crud/`) operate on the session provided by the caller.
* **Composability Pattern:** Employ `async with db.begin_nested() if db.in_transaction() else db.begin():` to wrap database modification logic within the CRUD function.
    * If an outer transaction exists (e.g., from `get_transactional_session`), `begin_nested()` creates a **savepoint**. The `async with` block commits/rolls back this savepoint.
    * If no outer transaction exists (e.g., direct call from a script), `begin()` starts a **new transaction**. The `async with` block commits/rolls back this transaction.
* **NO Direct `db.commit()` / `db.rollback()`:** CRUD functions **MUST NOT** call these directly. The `async with begin_nested()/begin()` block and the outermost transaction manager are responsible.
* **`await db.flush()`:** Use only when necessary within the `async with` block to:
    1. Obtain auto-generated IDs for subsequent operations in the *same* transaction.
    2. Force database constraint checks mid-transaction.
* **Error Handling:** Raise specific custom exceptions (e.g., `ListNotFoundError`, `DatabaseIntegrityError`). These exceptions will trigger rollbacks in the managing transaction contexts.
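
The composability rule can likewise be traced with a stand-in session (stdlib-only illustration; `FakeSession` and its logging are not the app's SQLAlchemy session). Inside an outer transaction the CRUD call gets a savepoint, so a failing call rolls back only its own work while the outer transaction survives:

```python
import asyncio
from contextlib import asynccontextmanager

# Stand-in for an AsyncSession: only in_transaction()/begin()/begin_nested().
class FakeSession:
    def __init__(self):
        self._depth = 0
        self.log = []

    def in_transaction(self):
        return self._depth > 0

    @asynccontextmanager
    async def begin(self):
        self._depth += 1
        self.log.append("BEGIN")
        try:
            yield
            self.log.append("COMMIT")
        except Exception:
            self.log.append("ROLLBACK")
            raise
        finally:
            self._depth -= 1

    @asynccontextmanager
    async def begin_nested(self):
        self.log.append("SAVEPOINT")
        try:
            yield
            self.log.append("RELEASE SAVEPOINT")
        except Exception:
            self.log.append("ROLLBACK TO SAVEPOINT")
            raise

# A CRUD function following the spec: no direct commit/rollback.
async def create_item(db, fail=False):
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db.log.append("INSERT item")
        if fail:
            raise ValueError("constraint violated")

async def demo():
    db = FakeSession()
    async with db.begin():                     # outer transaction (endpoint scope)
        await create_item(db)                  # participates via savepoint
        try:
            await create_item(db, fail=True)   # only the savepoint rolls back
        except ValueError:
            pass
    return db.log

log = asyncio.run(demo())
print(log)
```

The first insert is released, the failing one rolls back to its savepoint, and the outer transaction still commits, which is exactly why CRUD functions must not call `db.commit()` themselves.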

**3. Non-API Operations (Background Tasks, Scripts):**

* **Explicit Management:** These contexts **MUST** manage their own session and transaction lifecycles.
* **Pattern:**

```python
async with AsyncSessionLocal() as session:
    async with session.begin():  # Manages transaction for the task's scope
        try:
            # Call CRUD functions, which will participate via savepoints
            await crud_operation_1(db=session, ...)
            await crud_operation_2(db=session, ...)
            # Commit is handled by session.begin() context manager on success
        except Exception:
            # Rollback is handled by session.begin() context manager on error
            raise
```

**4. Key Principles Summary:**

* **API:** `get_transactional_session` for CUD.
* **CRUD:** Use `async with db.begin_nested() if db.in_transaction() else db.begin():`. No direct commit/rollback. Use `flush()` strategically.
* **Background Tasks:** Explicit `AsyncSessionLocal()` and `session.begin()` context managers.

This strategy ensures a clear separation of concerns, promotes composable CRUD operations, and centralizes final transaction control at the appropriate layer.

View File

@@ -0,0 +1,71 @@
name: Build and Test
on:
  push:
    branches:
      - main
      - develop
  pull_request:
    branches:
      - main
      - develop
jobs:
  build-and-test:
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:17-alpine
        env:
          POSTGRES_USER: testuser
          POSTGRES_PASSWORD: testpassword
          POSTGRES_DB: testdb
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.11'
      - name: Set up Node.js
        uses: actions/setup-node@v3
        with:
          node-version: '24'
      - name: Install backend dependencies
        working-directory: ./be
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt
      - name: Install frontend dependencies
        working-directory: ./fe
        run: npm ci
      - name: Run backend tests
        working-directory: ./be
        env:
          DATABASE_URL: postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb
          SECRET_KEY: testsecretkey
          GEMINI_API_KEY: testgeminikey # Mock or skip tests requiring this if not available
          SESSION_SECRET_KEY: testsessionsecret
        run: pytest
      - name: Build frontend
        working-directory: ./fe
        run: npm run build
      # Add frontend test command if you have one e.g. npm test
      # - name: Run frontend tests
      #   working-directory: ./fe
      #   run: npm test

View File

@@ -0,0 +1,214 @@
name: Deploy to Production, build images and push to Gitea Registry
on:
  pull_request:
    types: [closed]
    branches:
      - prod
jobs:
  build_and_push:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Debug context variables
        run: |
          echo "Actor: ${{ gitea.actor }}"
          echo "Repository: ${{ gitea.repository_name }}"
          echo "Repository owner: ${{ gitea.repository_owner }}"
          echo "Event repository name: ${{ gitea.event.repository.name }}"
          echo "Event repository full name: ${{ gitea.event.repository.full_name }}"
          echo "Event repository owner login: ${{ gitea.event.repository.owner.login }}"
      - name: Login to Gitea Registry
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          echo $GITEA_PASSWORD | docker login git.vinylnostalgia.com -u $GITEA_USERNAME --password-stdin
      - name: Set repository variables
        id: vars
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"
          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi
          echo "actor=$ACTOR" >> $GITHUB_OUTPUT
          echo "repo_name=$REPO_NAME" >> $GITHUB_OUTPUT
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"
      - name: Build backend image with optimizations
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"
          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi
          echo "Building backend image..."
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"
          # Build with BuildKit optimizations
          DOCKER_BUILDKIT=1 docker build \
            --compress \
            --no-cache \
            --squash \
            --build-arg BUILDKIT_INLINE_CACHE=1 \
            -t git.vinylnostalgia.com/$ACTOR/$REPO_NAME-backend:latest \
            ./be -f ./be/Dockerfile.prod
          echo "Backend image built successfully"
      - name: Push backend image with retry logic
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"
          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi
          # Push with retries
          max_retries=5
          retry_count=0
          base_wait=10
          while [ $retry_count -lt $max_retries ]; do
            echo "Pushing backend image (attempt $((retry_count + 1)) of $max_retries)..."
            if docker push git.vinylnostalgia.com/$ACTOR/$REPO_NAME-backend:latest; then
              echo "Backend image pushed successfully"
              break
            fi
            retry_count=$((retry_count + 1))
            if [ $retry_count -lt $max_retries ]; then
              wait_time=$((base_wait * retry_count))
              echo "Push failed, retrying in $wait_time seconds..."
              sleep $wait_time
            fi
          done
          if [ $retry_count -eq $max_retries ]; then
            echo "Failed to push backend image after $max_retries attempts"
            exit 1
          fi
      - name: Build frontend image with optimizations
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"
          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi
          echo "Building frontend image..."
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"
          # Build with BuildKit optimizations
          DOCKER_BUILDKIT=1 docker build \
            --compress \
            --no-cache \
            --squash \
            --build-arg BUILDKIT_INLINE_CACHE=1 \
            -t git.vinylnostalgia.com/$ACTOR/$REPO_NAME-frontend:latest \
            ./fe -f ./fe/Dockerfile.prod
          echo "Frontend image built successfully"
      - name: Push frontend image with retry logic
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"
          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi
          # Push with retries and linearly increasing backoff
          max_retries=5
          retry_count=0
          base_wait=10
          while [ $retry_count -lt $max_retries ]; do
            echo "Pushing frontend image (attempt $((retry_count + 1)) of $max_retries)..."
            if docker push git.vinylnostalgia.com/$ACTOR/$REPO_NAME-frontend:latest; then
              echo "Frontend image pushed successfully"
              break
            fi
            retry_count=$((retry_count + 1))
            if [ $retry_count -lt $max_retries ]; then
              wait_time=$((base_wait * retry_count))
              echo "Push failed, retrying in $wait_time seconds..."
              sleep $wait_time
            fi
          done
          if [ $retry_count -eq $max_retries ]; then
            echo "Failed to push frontend image after $max_retries attempts"
            exit 1
          fi
      - name: Cleanup Docker resources
        if: always()
        run: |
          echo "Cleaning up Docker resources..."
          docker system prune -af --volumes
          docker logout git.vinylnostalgia.com
          echo "Cleanup completed"
      - name: Show final image sizes
        if: always()
        run: |
          echo "Final image sizes:"
          docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.Size}}" | grep -E "(vinylnostalgia|REPOSITORY)"

View File

@@ -1,30 +1,55 @@
# Git files
# Git
.git
.gitignore
# Virtual environment
.venv
venv/
env/
ENV/
*.env # Ignore local .env files within the backend directory if any
# Python cache
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# IDE files
# Virtual Environment
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# Test artifacts
# Logs
*.log
# Local development
.env
.env.local
.env.*.local
# Docker
Dockerfile*
docker-compose*
.dockerignore
# Tests
tests/
test/
.pytest_cache/
.coverage
htmlcov/
.coverage*
# Other build/temp files
*.egg-info/
dist/
build/
*.db # e.g., sqlite temp dbs


@@ -1,35 +1,75 @@
# be/Dockerfile
# Choose a suitable Python base image
FROM python:3.11-slim
# Multi-stage build for production - optimized for size
FROM python:3.11-slim AS builder
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE 1 # Prevent python from writing pyc files
ENV PYTHONUNBUFFERED 1 # Keep stdout/stderr unbuffered
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
# Set the working directory in the container
WORKDIR /app
# Install system dependencies if needed (e.g., for psycopg2 build)
# RUN apt-get update && apt-get install -y --no-install-recommends gcc build-essential libpq-dev && rm -rf /var/lib/apt/lists/*
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
g++ \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
# Upgrade pip first
RUN pip install --no-cache-dir --upgrade pip
# Copy only requirements first to leverage Docker cache
COPY requirements.txt requirements.txt
# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt
# Copy the rest of the application code into the working directory
COPY . .
# This includes your 'app/' directory, alembic.ini, etc.
# Production stage - minimal image
FROM python:3.11-slim AS production
# Expose the port the app runs on
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH=/home/appuser/.local/bin:$PATH
# Install only runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# Create non-root user
RUN groupadd -g 1001 appuser && \
useradd -u 1001 -g appuser -m appuser
# Copy Python packages from builder stage
COPY --from=builder /root/.local /home/appuser/.local
# Set working directory
WORKDIR /app
# Copy only necessary application files (be selective)
COPY --chown=appuser:appuser app/ ./app/
COPY --chown=appuser:appuser alembic/ ./alembic/
COPY --chown=appuser:appuser alembic.ini ./
COPY --chown=appuser:appuser *.py ./
COPY --chown=appuser:appuser requirements.txt ./
COPY --chown=appuser:appuser entrypoint.sh /app/entrypoint.sh
RUN chmod +x /app/entrypoint.sh
# Create logs directory
RUN mkdir -p /app/logs && chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Expose port
EXPOSE 8000
# Command to run the application using uvicorn
# The default command for production (can be overridden in docker-compose for development)
# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"]
# Production command
ENTRYPOINT ["/app/entrypoint.sh"]
CMD ["uvicorn", "app.main:app", \
"--host", "0.0.0.0", \
"--port", "8000", \
"--workers", "4", \
"--access-log", \
"--log-level", "info"]

be/Dockerfile.prod Normal file

@@ -0,0 +1,72 @@
# Multi-stage build for production - optimized for size
FROM python:3.11-slim AS builder
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
g++ \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt
# Production stage - minimal image
FROM python:3.11-slim AS production
# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH=/home/appuser/.local/bin:$PATH
# Install only runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
curl \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# Create non-root user
RUN groupadd -g 1001 appuser && \
useradd -u 1001 -g appuser -m appuser
# Copy Python packages from builder stage
COPY --from=builder /root/.local /home/appuser/.local
# Set working directory
WORKDIR /app
# Copy only necessary application files (be selective)
COPY --chown=appuser:appuser app/ ./app/
COPY --chown=appuser:appuser alembic/ ./alembic/
COPY --chown=appuser:appuser alembic.ini ./
COPY --chown=appuser:appuser *.py ./
COPY --chown=appuser:appuser requirements.txt ./
# Create logs directory
RUN mkdir -p /app/logs && chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Expose port
EXPOSE 8000
# Production command
CMD ["uvicorn", "app.main:app", \
"--host", "0.0.0.0", \
"--port", "8000", \
"--workers", "4", \
"--access-log", \
"--log-level", "info"]


@@ -1,31 +1,28 @@
from logging.config import fileConfig
import os
import sys
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from alembic import context
# Ensure the 'app' directory is in the Python path
# Adjust the path if your project structure is different
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
# Import your app's Base and settings
from app.models import Base # Import Base from your models module
import app.models # Ensure all models are loaded and registered to app.database.Base
from app.database import Base as DatabaseBase # Explicitly get Base from database.py
from app.config import settings # Import settings to get DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Get alembic config
config = context.config
# Set the sqlalchemy.url from your application settings
# Use a synchronous version of the URL for Alembic's operations
sync_db_url = settings.DATABASE_URL.replace("+asyncpg", "") if settings.DATABASE_URL else None
if not sync_db_url:
# Ensure DATABASE_URL is available and use it directly
if not settings.DATABASE_URL:
raise ValueError("DATABASE_URL not found in settings for Alembic.")
config.set_main_option('sqlalchemy.url', sync_db_url)
config.set_main_option('sqlalchemy.url', settings.DATABASE_URL)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
@@ -36,26 +33,15 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
target_metadata = DatabaseBase.metadata # Use metadata from app.database.Base
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
"""Run migrations in 'offline' mode."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
@@ -67,30 +53,32 @@ def run_migrations_offline() -> None:
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
settings.DATABASE_URL,
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
async with connectable.connect() as connection:
await connection.run_sync(_run_migrations)
await connectable.dispose()
def _run_migrations(connection):
context.configure(
connection=connection, target_metadata=target_metadata
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True
)
with context.begin_transaction():
context.run_migrations()
# This section only runs when executing alembic commands directly (not when imported)
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
import asyncio
asyncio.run(run_migrations_online())
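
The removed code derived a synchronous URL by stripping the `+asyncpg` driver suffix before handing it to Alembic. As a standalone helper, the old approach looked like this (no longer needed now that env.py uses `create_async_engine` directly):

```python
def to_sync_url(database_url: str) -> str:
    """Drop the async driver suffix so sync-only tools accept the DSN,
    e.g. 'postgresql+asyncpg://...' -> 'postgresql://...'."""
    return database_url.replace("+asyncpg", "")
```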

be/alembic/migrations.py Normal file

@@ -0,0 +1,74 @@
"""
Async migrations handler for FastAPI application.
This file is separate from env.py to avoid Alembic context issues.
"""
import os
import sys
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import pool
from alembic.config import Config
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
from alembic.operations import Operations
# Ensure the app directory is in the Python path
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
from app.database import Base as DatabaseBase
from app.config import settings
def _get_migration_fn(script_directory, current_rev):
"""Create a migration function that knows how to upgrade from current revision."""
def migration_fn(rev, context):
# Get all upgrade steps from current revision to head
revisions = script_directory._upgrade_revs("head", current_rev)
for revision_step in revisions:
# Resolve the Script object wrapped by each RevisionStep and run its upgrade()
script = script_directory.get_revision(revision_step.revision.revision)
script.module.upgrade(context)
return migration_fn
async def run_migrations():
"""Run database migrations asynchronously."""
# Get alembic configuration and script directory
alembic_cfg = Config(os.path.join(os.path.dirname(__file__), '..', 'alembic.ini'))
script_directory = ScriptDirectory.from_config(alembic_cfg)
# Create async engine
engine = create_async_engine(
settings.DATABASE_URL,
poolclass=pool.NullPool,
)
async with engine.connect() as connection:
def get_current_rev(conn):
migration_context = MigrationContext.configure(
conn,
opts={
'target_metadata': DatabaseBase.metadata,
'script': script_directory
}
)
return migration_context.get_current_revision()
current_rev = await connection.run_sync(get_current_rev)
def upgrade_to_head(conn):
migration_context = MigrationContext.configure(
conn,
opts={
'target_metadata': DatabaseBase.metadata,
'script': script_directory,
'as_sql': False,
}
)
# Set the migration function
migration_context._migrations_fn = _get_migration_fn(script_directory, current_rev)
with migration_context.begin_transaction():
migration_context.run_migrations()
await connection.run_sync(upgrade_to_head)
await engine.dispose()
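
`_get_migration_fn` walks the upgrade steps from the current revision to head. A simplified model of that walk over a linear history — illustrative only, since the real `ScriptDirectory._upgrade_revs` also handles branches and merges:

```python
def upgrade_steps(history, current):
    """history: revision ids ordered oldest -> newest; return the ids
    still to apply. current=None means an empty database, so every
    step applies."""
    if current is None:
        return list(history)
    return list(history[history.index(current) + 1:])
```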


@@ -0,0 +1,305 @@
"""Initial schema setup
Revision ID: 0001_initial_schema
Revises:
Create Date: YYYY-MM-DD HH:MM:SS.ffffff
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '0001_initial_schema'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
user_role_enum = postgresql.ENUM('owner', 'member', name='userroleenum', create_type=False)
split_type_enum = postgresql.ENUM('EQUAL', 'EXACT_AMOUNTS', 'PERCENTAGE', 'SHARES', 'ITEM_BASED', name='splittypeenum', create_type=False)
expense_split_status_enum = postgresql.ENUM('unpaid', 'partially_paid', 'paid', name='expensesplitstatusenum', create_type=False)
expense_overall_status_enum = postgresql.ENUM('unpaid', 'partially_paid', 'paid', name='expenseoverallstatusenum', create_type=False)
recurrence_type_enum = postgresql.ENUM('DAILY', 'WEEKLY', 'MONTHLY', 'YEARLY', name='recurrencetypeenum', create_type=False)
chore_frequency_enum = postgresql.ENUM('one_time', 'daily', 'weekly', 'monthly', 'custom', name='chorefrequencyenum', create_type=False)
chore_type_enum = postgresql.ENUM('personal', 'group', name='choretypeenum', create_type=False)
def upgrade(context=None) -> None:  # context defaults to None for the Alembic CLI; migrations.py passes it explicitly
# Create enums
user_role_enum.create(op.get_bind(), checkfirst=True)
split_type_enum.create(op.get_bind(), checkfirst=True)
expense_split_status_enum.create(op.get_bind(), checkfirst=True)
expense_overall_status_enum.create(op.get_bind(), checkfirst=True)
recurrence_type_enum.create(op.get_bind(), checkfirst=True)
chore_frequency_enum.create(op.get_bind(), checkfirst=True)
chore_type_enum.create(op.get_bind(), checkfirst=True)
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('is_superuser', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('is_verified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False)
op.create_table('groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False)
op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False)
op.create_table('user_groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('role', user_role_enum, nullable=False),
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group')
)
op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False)
op.create_table('invites',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('code', sa.String(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False)
op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true'))
op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False)
op.create_table('lists',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False)
op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False)
op.create_table('recurrence_patterns',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('type', recurrence_type_enum, nullable=False),
sa.Column('interval', sa.Integer(), server_default='1', nullable=False),
sa.Column('days_of_week', sa.String(), nullable=True),
sa.Column('end_date', sa.DateTime(timezone=True), nullable=True),
sa.Column('max_occurrences', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_recurrence_patterns_id'), 'recurrence_patterns', ['id'], unique=False)
op.create_table('items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('list_id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('quantity', sa.String(), nullable=True),
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True),
sa.Column('added_by_id', sa.Integer(), nullable=False),
sa.Column('completed_by_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False)
op.create_table('chores',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('type', chore_type_enum, nullable=False),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('frequency', chore_frequency_enum, nullable=False),
sa.Column('custom_interval_days', sa.Integer(), nullable=True),
sa.Column('next_due_date', sa.Date(), nullable=False),
sa.Column('last_completed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_chores_created_by_id'), 'chores', ['created_by_id'], unique=False)
op.create_index(op.f('ix_chores_group_id'), 'chores', ['group_id'], unique=False)
op.create_index(op.f('ix_chores_id'), 'chores', ['id'], unique=False)
op.create_index(op.f('ix_chores_name'), 'chores', ['name'], unique=False)
op.create_table('expenses',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(), nullable=False),
sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('currency', sa.String(), server_default='USD', nullable=False),
sa.Column('expense_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('split_type', split_type_enum, nullable=False),
sa.Column('list_id', sa.Integer(), nullable=True),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('item_id', sa.Integer(), nullable=True),
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.Column('overall_settlement_status', expense_overall_status_enum, server_default='unpaid', nullable=False),
sa.Column('is_recurring', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('recurrence_pattern_id', sa.Integer(), nullable=True),
sa.Column('next_occurrence', sa.DateTime(timezone=True), nullable=True),
sa.Column('parent_expense_id', sa.Integer(), nullable=True),
sa.Column('last_occurrence', sa.DateTime(timezone=True), nullable=True),
sa.CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.ForeignKeyConstraint(['item_id'], ['items.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ),
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['parent_expense_id'], ['expenses.id'], ),
sa.ForeignKeyConstraint(['recurrence_pattern_id'], ['recurrence_patterns.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_expenses_created_by_user_id'), 'expenses', ['created_by_user_id'], unique=False)
op.create_index(op.f('ix_expenses_group_id'), 'expenses', ['group_id'], unique=False)
op.create_index(op.f('ix_expenses_id'), 'expenses', ['id'], unique=False)
op.create_index(op.f('ix_expenses_list_id'), 'expenses', ['list_id'], unique=False)
op.create_index(op.f('ix_expenses_paid_by_user_id'), 'expenses', ['paid_by_user_id'], unique=False)
op.create_index('ix_expenses_recurring_next_occurrence', 'expenses', ['is_recurring', 'next_occurrence'], unique=False, postgresql_where=sa.text('is_recurring = true'))
op.create_table('chore_assignments',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('chore_id', sa.Integer(), nullable=False),
sa.Column('assigned_to_user_id', sa.Integer(), nullable=False),
sa.Column('due_date', sa.Date(), nullable=False),
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['assigned_to_user_id'], ['users.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['chore_id'], ['chores.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_chore_assignments_assigned_to_user_id'), 'chore_assignments', ['assigned_to_user_id'], unique=False)
op.create_index(op.f('ix_chore_assignments_chore_id'), 'chore_assignments', ['chore_id'], unique=False)
op.create_index(op.f('ix_chore_assignments_id'), 'chore_assignments', ['id'], unique=False)
op.create_table('expense_splits',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('expense_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('owed_amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('share_percentage', sa.Numeric(precision=5, scale=2), nullable=True),
sa.Column('share_units', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('status', expense_split_status_enum, server_default='unpaid', nullable=False),
sa.Column('paid_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['expense_id'], ['expenses.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split')
)
op.create_index(op.f('ix_expense_splits_id'), 'expense_splits', ['id'], unique=False)
op.create_index(op.f('ix_expense_splits_user_id'), 'expense_splits', ['user_id'], unique=False)
op.create_table('settlements',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
sa.Column('paid_to_user_id', sa.Integer(), nullable=False),
sa.Column('amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('settlement_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
sa.CheckConstraint('paid_by_user_id != paid_to_user_id', name='chk_settlement_different_users'),
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['paid_to_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_settlements_created_by_user_id'), 'settlements', ['created_by_user_id'], unique=False)
op.create_index(op.f('ix_settlements_group_id'), 'settlements', ['group_id'], unique=False)
op.create_index(op.f('ix_settlements_id'), 'settlements', ['id'], unique=False)
op.create_index(op.f('ix_settlements_paid_by_user_id'), 'settlements', ['paid_by_user_id'], unique=False)
op.create_index(op.f('ix_settlements_paid_to_user_id'), 'settlements', ['paid_to_user_id'], unique=False)
op.create_table('settlement_activities',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('expense_split_id', sa.Integer(), nullable=False),
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
sa.Column('paid_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('amount_paid', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['expense_split_id'], ['expense_splits.id'], ),
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_settlement_activities_created_by_user_id'), 'settlement_activities', ['created_by_user_id'], unique=False)
op.create_index(op.f('ix_settlement_activities_expense_split_id'), 'settlement_activities', ['expense_split_id'], unique=False)
op.create_index(op.f('ix_settlement_activities_paid_by_user_id'), 'settlement_activities', ['paid_by_user_id'], unique=False)
def downgrade(context=None) -> None:  # context defaults to None for the Alembic CLI; migrations.py passes it explicitly
op.drop_table('settlement_activities')
op.drop_table('settlements')
op.drop_table('expense_splits')
op.drop_table('chore_assignments')
op.drop_table('expenses')
op.drop_table('chores')
op.drop_table('items')
op.drop_table('recurrence_patterns')
op.drop_table('lists')
op.drop_table('invites')
op.drop_table('user_groups')
op.drop_table('groups')
op.drop_table('users')
# Drop enums in reverse order
chore_type_enum.drop(op.get_bind(), checkfirst=True)
chore_frequency_enum.drop(op.get_bind(), checkfirst=True)
recurrence_type_enum.drop(op.get_bind(), checkfirst=True)
expense_overall_status_enum.drop(op.get_bind(), checkfirst=True)
expense_split_status_enum.drop(op.get_bind(), checkfirst=True)
split_type_enum.drop(op.get_bind(), checkfirst=True)
user_role_enum.drop(op.get_bind(), checkfirst=True)


@@ -1,32 +0,0 @@
"""Add invite table and relationships
Revision ID: 563ee77c5214
Revises: 69b0c1432084
Create Date: 2025-03-30 18:51:19.926810
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '563ee77c5214'
down_revision: Union[str, None] = '69b0c1432084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


@@ -1,32 +0,0 @@
"""Initial database setup
Revision ID: 643956b3f4de
Revises:
Create Date: 2025-03-29 20:49:01.018626
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '643956b3f4de'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


@@ -1,32 +0,0 @@
"""Add invite table and relationships
Revision ID: 69b0c1432084
Revises: 6f80b82dbdf8
Create Date: 2025-03-30 18:50:48.072504
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '69b0c1432084'
down_revision: Union[str, None] = '6f80b82dbdf8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


@@ -1,32 +0,0 @@
"""Add invite table and relationships
Revision ID: 6f80b82dbdf8
Revises: f42efe4f4bca
Create Date: 2025-03-30 18:49:26.968637
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '6f80b82dbdf8'
down_revision: Union[str, None] = 'f42efe4f4bca'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


@@ -1,72 +0,0 @@
"""Add User, Group, UserGroup models
Revision ID: 85a3c075e73a
Revises: c6cbef99588b
Create Date: 2025-03-30 12:46:07.322285
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '85a3c075e73a'
down_revision: Union[str, None] = 'c6cbef99588b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('password_hash', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False)
op.create_table('groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False)
op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False)
op.create_table('user_groups',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('role', sa.Enum('owner', 'member', name='userroleenum'), nullable=False),
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group')
)
op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_user_groups_id'), table_name='user_groups')
op.drop_table('user_groups')
op.drop_index(op.f('ix_groups_name'), table_name='groups')
op.drop_index(op.f('ix_groups_id'), table_name='groups')
op.drop_table('groups')
op.drop_index(op.f('ix_users_name'), table_name='users')
op.drop_index(op.f('ix_users_id'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###

View File

@@ -0,0 +1,133 @@
"""Add position to Item model for reordering
Revision ID: 91d00c100f5b
Revises: 0001_initial_schema
Create Date: 2025-06-07 14:59:48.761124
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '91d00c100f5b'
down_revision: Union[str, None] = '0001_initial_schema'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_apscheduler_jobs_next_run_time', table_name='apscheduler_jobs')
op.drop_table('apscheduler_jobs')
op.alter_column('chore_assignments', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.alter_column('expenses', 'currency',
existing_type=sa.VARCHAR(),
server_default=None,
existing_nullable=False)
op.alter_column('expenses', 'is_recurring',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.drop_index('ix_expenses_recurring_next_occurrence', table_name='expenses', postgresql_where='(is_recurring = true)')
op.alter_column('invites', 'is_active',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.add_column('items', sa.Column('position', sa.Integer(), server_default='0', nullable=False))
op.alter_column('items', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.create_index('ix_items_list_id_position', 'items', ['list_id', 'position'], unique=False)
op.alter_column('lists', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.alter_column('recurrence_patterns', 'interval',
existing_type=sa.INTEGER(),
server_default=None,
existing_nullable=False)
op.create_index(op.f('ix_settlement_activities_id'), 'settlement_activities', ['id'], unique=False)
op.create_index('ix_settlement_activity_created_by_user_id', 'settlement_activities', ['created_by_user_id'], unique=False)
op.create_index('ix_settlement_activity_expense_split_id', 'settlement_activities', ['expense_split_id'], unique=False)
op.create_index('ix_settlement_activity_paid_by_user_id', 'settlement_activities', ['paid_by_user_id'], unique=False)
op.alter_column('users', 'is_active',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.alter_column('users', 'is_superuser',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
op.alter_column('users', 'is_verified',
existing_type=sa.BOOLEAN(),
server_default=None,
existing_nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('users', 'is_verified',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.alter_column('users', 'is_superuser',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.alter_column('users', 'is_active',
existing_type=sa.BOOLEAN(),
server_default=sa.text('true'),
existing_nullable=False)
op.drop_index('ix_settlement_activity_paid_by_user_id', table_name='settlement_activities')
op.drop_index('ix_settlement_activity_expense_split_id', table_name='settlement_activities')
op.drop_index('ix_settlement_activity_created_by_user_id', table_name='settlement_activities')
op.drop_index(op.f('ix_settlement_activities_id'), table_name='settlement_activities')
op.alter_column('recurrence_patterns', 'interval',
existing_type=sa.INTEGER(),
server_default=sa.text('1'),
existing_nullable=False)
op.alter_column('lists', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.drop_index('ix_items_list_id_position', table_name='items')
op.alter_column('items', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.drop_column('items', 'position')
op.alter_column('invites', 'is_active',
existing_type=sa.BOOLEAN(),
server_default=sa.text('true'),
existing_nullable=False)
op.create_index('ix_expenses_recurring_next_occurrence', 'expenses', ['is_recurring', 'next_occurrence'], unique=False, postgresql_where='(is_recurring = true)')
op.alter_column('expenses', 'is_recurring',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.alter_column('expenses', 'currency',
existing_type=sa.VARCHAR(),
server_default=sa.text("'USD'::character varying"),
existing_nullable=False)
op.alter_column('chore_assignments', 'is_complete',
existing_type=sa.BOOLEAN(),
server_default=sa.text('false'),
existing_nullable=False)
op.create_table('apscheduler_jobs',
sa.Column('id', sa.VARCHAR(length=191), autoincrement=False, nullable=False),
sa.Column('next_run_time', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True),
sa.Column('job_state', postgresql.BYTEA(), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name='apscheduler_jobs_pkey')
)
op.create_index('ix_apscheduler_jobs_next_run_time', 'apscheduler_jobs', ['next_run_time'], unique=False)
# ### end Alembic commands ###
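The `91d00c100f5b` revision adds a `position` column and the composite `ix_items_list_id_position` index to back item reordering. A minimal sketch of the access pattern that index serves, using stdlib `sqlite3` with the schema reduced to the relevant columns (illustrative only, not the app's actual query layer):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE items (
        id INTEGER PRIMARY KEY,
        list_id INTEGER NOT NULL,
        name TEXT NOT NULL,
        position INTEGER NOT NULL DEFAULT 0
    )
""")
# Composite index mirroring ix_items_list_id_position
conn.execute("CREATE INDEX ix_items_list_id_position ON items (list_id, position)")
conn.executemany(
    "INSERT INTO items (list_id, name, position) VALUES (?, ?, ?)",
    [(1, "milk", 2), (1, "bread", 0), (1, "eggs", 1), (2, "soap", 0)],
)
# The reordering query the index serves: one list's items, in display order
rows = conn.execute(
    "SELECT name FROM items WHERE list_id = ? ORDER BY position", (1,)
).fetchall()
print([r[0] for r in rows])  # ['bread', 'eggs', 'milk']
```

Because the index leads with `list_id` and then `position`, this query can read the items of a single list already sorted, without a separate sort step.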

View File

@@ -1,32 +0,0 @@
"""Initial database setup
Revision ID: c6cbef99588b
Revises: 643956b3f4de
Create Date: 2025-03-30 12:18:51.207858
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'c6cbef99588b'
down_revision: Union[str, None] = '643956b3f4de'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@@ -1,73 +0,0 @@
"""Add list and item tables
Revision ID: d25788f63e2c
Revises: d90ab7116920
Create Date: 2025-03-30 19:43:49.925240
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd25788f63e2c'
down_revision: Union[str, None] = 'd90ab7116920'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('lists',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=True),
sa.Column('is_complete', sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False)
op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False)
op.create_table('items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('list_id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('quantity', sa.String(), nullable=True),
sa.Column('is_complete', sa.Boolean(), nullable=False),
sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True),
sa.Column('added_by_id', sa.Integer(), nullable=False),
sa.Column('completed_by_id', sa.Integer(), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False)
op.drop_index('ix_invites_code', table_name='invites')
op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_invites_code'), table_name='invites')
op.create_index('ix_invites_code', 'invites', ['code'], unique=True)
op.drop_index(op.f('ix_items_name'), table_name='items')
op.drop_index(op.f('ix_items_id'), table_name='items')
op.drop_table('items')
op.drop_index(op.f('ix_lists_name'), table_name='lists')
op.drop_index(op.f('ix_lists_id'), table_name='lists')
op.drop_table('lists')
# ### end Alembic commands ###

View File

@@ -1,32 +0,0 @@
"""Add invite table and relationships
Revision ID: d90ab7116920
Revises: 563ee77c5214
Create Date: 2025-03-30 18:57:39.047729
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'd90ab7116920'
down_revision: Union[str, None] = '563ee77c5214'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@@ -1,89 +0,0 @@
"""Add expense tracking tables and item price columns
Revision ID: ebbe5cdba808
Revises: d25788f63e2c
Create Date: 2025-04-02 23:51:31.432547
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ebbe5cdba808'
down_revision: Union[str, None] = 'd25788f63e2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('expense_records',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('list_id', sa.Integer(), nullable=False),
sa.Column('calculated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('calculated_by_id', sa.Integer(), nullable=False),
sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('participants', sa.ARRAY(sa.Integer()), nullable=False),
sa.Column('split_type', sa.Enum('equal', name='splittypeenum'), nullable=False),
sa.Column('is_settled', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['calculated_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_expense_records_id'), 'expense_records', ['id'], unique=False)
op.create_index(op.f('ix_expense_records_list_id'), 'expense_records', ['list_id'], unique=False)
op.create_table('expense_shares',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('expense_record_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('amount_owed', sa.Numeric(precision=10, scale=2), nullable=False),
sa.Column('is_paid', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['expense_record_id'], ['expense_records.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('expense_record_id', 'user_id', name='uq_expense_share_user')
)
op.create_index(op.f('ix_expense_shares_expense_record_id'), 'expense_shares', ['expense_record_id'], unique=False)
op.create_index(op.f('ix_expense_shares_id'), 'expense_shares', ['id'], unique=False)
op.create_index(op.f('ix_expense_shares_user_id'), 'expense_shares', ['user_id'], unique=False)
op.create_table('settlement_activities',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('expense_record_id', sa.Integer(), nullable=False),
sa.Column('payer_user_id', sa.Integer(), nullable=False),
sa.Column('affected_user_id', sa.Integer(), nullable=False),
sa.Column('activity_type', sa.Enum('marked_paid', 'marked_unpaid', name='settlementactivitytypeenum'), nullable=False),
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['affected_user_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['expense_record_id'], ['expense_records.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['payer_user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_settlement_activities_expense_record_id'), 'settlement_activities', ['expense_record_id'], unique=False)
op.create_index(op.f('ix_settlement_activities_id'), 'settlement_activities', ['id'], unique=False)
op.add_column('items', sa.Column('price_added_by_id', sa.Integer(), nullable=True))
op.add_column('items', sa.Column('price_added_at', sa.DateTime(timezone=True), nullable=True))
op.create_foreign_key('fk_items_price_added_by_id_users', 'items', 'users', ['price_added_by_id'], ['id'])  # name the constraint explicitly; an unnamed FK cannot be dropped in downgrade
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('fk_items_price_added_by_id_users', 'items', type_='foreignkey')
op.drop_column('items', 'price_added_at')
op.drop_column('items', 'price_added_by_id')
op.drop_index(op.f('ix_settlement_activities_id'), table_name='settlement_activities')
op.drop_index(op.f('ix_settlement_activities_expense_record_id'), table_name='settlement_activities')
op.drop_table('settlement_activities')
op.drop_index(op.f('ix_expense_shares_user_id'), table_name='expense_shares')
op.drop_index(op.f('ix_expense_shares_id'), table_name='expense_shares')
op.drop_index(op.f('ix_expense_shares_expense_record_id'), table_name='expense_shares')
op.drop_table('expense_shares')
op.drop_index(op.f('ix_expense_records_list_id'), table_name='expense_records')
op.drop_index(op.f('ix_expense_records_id'), table_name='expense_records')
op.drop_table('expense_records')
# ### end Alembic commands ###

View File

@@ -1,49 +0,0 @@
"""Add invite table and relationships
Revision ID: f42efe4f4bca
Revises: 85a3c075e73a
Create Date: 2025-03-30 18:41:50.854172
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'f42efe4f4bca'
down_revision: Union[str, None] = '85a3c075e73a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('invites',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('code', sa.String(), nullable=False),
sa.Column('group_id', sa.Integer(), nullable=False),
sa.Column('created_by_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true'))
op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=True)
op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_invites_id'), table_name='invites')
op.drop_index(op.f('ix_invites_code'), table_name='invites')
op.drop_index('ix_invites_active_code', table_name='invites', postgresql_where=sa.text('is_active = true'))
op.drop_table('invites')
# ### end Alembic commands ###
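The `ix_invites_active_code` index in this revision is a partial unique index: the same code may recur across expired invites, but only one active invite per code is allowed. SQLite also supports partial indexes, so the behavior can be sketched with stdlib `sqlite3` (simplified schema, illustrative only):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE invites (
        id INTEGER PRIMARY KEY,
        code TEXT NOT NULL,
        is_active INTEGER NOT NULL
    )
""")
# Partial unique index analogous to ix_invites_active_code:
# uniqueness is enforced only among rows where is_active is true.
conn.execute(
    "CREATE UNIQUE INDEX ix_invites_active_code ON invites (code) WHERE is_active = 1"
)
conn.execute("INSERT INTO invites (code, is_active) VALUES ('ABC123', 0)")  # expired copy
conn.execute("INSERT INTO invites (code, is_active) VALUES ('ABC123', 1)")  # active copy: allowed
try:
    conn.execute("INSERT INTO invites (code, is_active) VALUES ('ABC123', 1)")
except sqlite3.IntegrityError as exc:
    print("second active invite rejected:", exc)
```

The expired row coexists with the active one; only the second active insert violates the index.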

be/app/api/auth/oauth.py Normal file
View File

@@ -0,0 +1,95 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_transactional_session
from app.models import User
from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
from app.config import settings
router = APIRouter()
@router.get('/google/login')
async def google_login(request: Request):
return await oauth.google.authorize_redirect(request, settings.GOOGLE_REDIRECT_URI)
@router.get('/google/callback')
async def google_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
token_data = await oauth.google.authorize_access_token(request)
user_info = await oauth.google.parse_id_token(request, token_data)
# Check if user exists
existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none()
user_to_login = existing_user
if not existing_user:
# Create new user
new_user = User(
email=user_info['email'],
name=user_info.get('name', user_info.get('email')),
is_verified=True, # Email is verified by Google
is_active=True
)
db.add(new_user)
await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user
# Generate JWT tokens using the new backend
access_strategy = get_jwt_strategy()
refresh_strategy = get_refresh_jwt_strategy()
access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url)
@router.get('/apple/login')
async def apple_login(request: Request):
return await oauth.apple.authorize_redirect(request, settings.APPLE_REDIRECT_URI)
@router.get('/apple/callback')
async def apple_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
token_data = await oauth.apple.authorize_access_token(request)
# .get() evaluates its default eagerly, which would call userinfo() even when
# the token response already carries the user; fetch lazily instead.
user_info = token_data.get('user')
if not user_info:
user_info = await oauth.apple.userinfo(token=token_data) if hasattr(oauth.apple, 'userinfo') else {}
if 'email' not in user_info and 'sub' in token_data:
parsed_id_token = await oauth.apple.parse_id_token(request, token_data) if hasattr(oauth.apple, 'parse_id_token') else {}
user_info = {**parsed_id_token, **user_info}
if 'email' not in user_info:
return RedirectResponse(url=f"{settings.FRONTEND_URL}/auth/callback?error=apple_email_missing")
# Check if user exists
existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none()
user_to_login = existing_user
if not existing_user:
# Create new user
name_info = user_info.get('name', {})
first_name = name_info.get('firstName', '')
last_name = name_info.get('lastName', '')
full_name = f"{first_name} {last_name}".strip() if first_name or last_name else user_info.get('email')
new_user = User(
email=user_info['email'],
name=full_name,
is_verified=True, # Email is verified by Apple
is_active=True
)
db.add(new_user)
await db.flush() # Use flush instead of commit since we're in a transaction
user_to_login = new_user
# Generate JWT tokens using the new backend
access_strategy = get_jwt_strategy()
refresh_strategy = get_refresh_jwt_strategy()
access_token = await access_strategy.write_token(user_to_login)
refresh_token = await refresh_strategy.write_token(user_to_login)
# Redirect to frontend with tokens
redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"
return RedirectResponse(url=redirect_url)
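Both callbacks hand the tokens to the SPA by interpolating them into a redirect URL. JWT segments are URL-safe base64, so the f-string usually works, but building the query with `urllib.parse` guards against any value that needs escaping. A hedged sketch (the `FRONTEND_URL` value below is a stand-in, not the app's real setting):

```python
from urllib.parse import urlencode, urlsplit, parse_qs

FRONTEND_URL = "https://app.example.com"  # stand-in for settings.FRONTEND_URL

def build_callback_url(access_token: str, refresh_token: str) -> str:
    # Same shape as the f-string in the callbacks, but URL-encodes both values.
    query = urlencode({"access_token": access_token, "refresh_token": refresh_token})
    return f"{FRONTEND_URL}/auth/callback?{query}"

url = build_callback_url("header.payload.sig", "header.payload2.sig2")
params = parse_qs(urlsplit(url).query)
print(params["access_token"][0])  # header.payload.sig
```

The frontend's `/auth/callback` route can recover both tokens with the same `parse_qs` logic (or `URLSearchParams` in the browser).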

View File

@@ -1,71 +0,0 @@
# app/api/dependencies.py
import logging
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from jose import JWTError
from app.database import get_db
from app.core.security import verify_access_token
from app.crud import user as crud_user
from app.models import User as UserModel # Import the SQLAlchemy model
logger = logging.getLogger(__name__)
# Define the OAuth2 scheme
# tokenUrl should point to your login endpoint relative to the base path
# It's used by Swagger UI for the "Authorize" button flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") # Corrected path
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_db)
) -> UserModel:
"""
Dependency to get the current user based on the JWT token.
- Extracts token using OAuth2PasswordBearer.
- Verifies the token (signature, expiry).
- Fetches the user from the database based on the token's subject (email).
- Raises HTTPException 401 if any step fails.
Returns:
The authenticated user's database model instance.
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
payload = verify_access_token(token)
if payload is None:
logger.warning("Token verification failed (invalid, expired, or malformed).")
raise credentials_exception
email: Optional[str] = payload.get("sub")
if email is None:
logger.error("Token payload missing 'sub' (subject/email).")
raise credentials_exception # Token is malformed
# Fetch user from database
user = await crud_user.get_user_by_email(db, email=email)
if user is None:
logger.warning(f"User corresponding to token subject not found: {email}")
# Could happen if user deleted after token issuance
raise credentials_exception # Treat as invalid credentials
logger.debug(f"Authenticated user retrieved: {user.email} (ID: {user.id})")
return user
# Optional: Dependency for getting the *active* current user
# You might add an `is_active` flag to your User model later
# async def get_current_active_user(
# current_user: UserModel = Depends(get_current_user)
# ) -> UserModel:
# if not current_user.is_active: # Assuming an is_active attribute
# logger.warning(f"Authentication attempt by inactive user: {current_user.email}")
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
# return current_user
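`OAuth2PasswordBearer` extracts the raw token from the `Authorization: Bearer <token>` header before `get_current_user` ever sees it. Roughly, that extraction amounts to the following stdlib sketch (illustrative only; FastAPI's own implementation differs in detail and responds with 401 instead of returning `None`):

```python
from typing import Optional

def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the raw token from a 'Bearer <token>' header value, else None."""
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    # Scheme comparison is case-insensitive; an empty token is rejected.
    if scheme.lower() != "bearer" or not token:
        return None
    return token

print(extract_bearer_token("Bearer abc.def.ghi"))  # abc.def.ghi
print(extract_bearer_token("Basic dXNlcjpwYXNz"))  # None
```

Everything after this point in `get_current_user` operates on the bare JWT string, which is why `verify_access_token` receives only the token and not the header.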

View File

@@ -1,26 +1,25 @@
# app/api/v1/api.py
from fastapi import APIRouter
from app.api.v1.endpoints import health
from app.api.v1.endpoints import auth
from app.api.v1.endpoints import users
from app.api.v1.endpoints import groups
from app.api.v1.endpoints import invites
from app.api.v1.endpoints import lists
from app.api.v1.endpoints import items
from app.api.v1.endpoints import ocr
from app.api.v1.endpoints import expenses
from app.api.v1.endpoints import costs
from app.api.v1.endpoints import financials
from app.api.v1.endpoints import chores
api_router_v1 = APIRouter()
api_router_v1.include_router(health.router) # Path /health defined inside
api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
api_router_v1.include_router(health.router)
api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"])
api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"])
api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
api_router_v1.include_router(items.router, tags=["Items"])
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
api_router_v1.include_router(expenses.router, tags=["Expenses"])
api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"])
api_router_v1.include_router(financials.router, prefix="/financials", tags=["Financials"])
api_router_v1.include_router(chores.router, prefix="/chores", tags=["Chores"])
# Add other v1 endpoint routers here later
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])

View File

@@ -1,91 +0,0 @@
# app/api/v1/endpoints/auth.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token
from app.crud import user as crud_user
from app.core.security import verify_password, create_access_token
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/signup",
response_model=UserPublic, # Return public user info, not the password hash
status_code=status.HTTP_201_CREATED, # Indicate resource creation
summary="Register New User",
description="Creates a new user account.",
tags=["Authentication"]
)
async def signup(
user_in: UserCreate,
db: AsyncSession = Depends(get_db)
):
"""
Handles user registration.
- Validates input data.
- Checks if email already exists.
- Hashes the password.
- Stores the new user in the database.
"""
logger.info(f"Signup attempt for email: {user_in.email}")
existing_user = await crud_user.get_user_by_email(db, email=user_in.email)
if existing_user:
logger.warning(f"Signup failed: Email already registered - {user_in.email}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered.",
)
try:
created_user = await crud_user.create_user(db=db, user_in=user_in)
logger.info(f"User created successfully: {created_user.email} (ID: {created_user.id})")
# Note: UserPublic schema automatically excludes the hashed password
return created_user
except Exception as e:
logger.error(f"Error during user creation for {user_in.email}: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred during user creation.",
)
@router.post(
"/login",
response_model=Token,
summary="User Login",
description="Authenticates a user and returns an access token.",
tags=["Authentication"]
)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(), # Use standard form for username/password
db: AsyncSession = Depends(get_db)
):
"""
Handles user login.
- Finds user by email (provided in 'username' field of form).
- Verifies the provided password against the stored hash.
- Generates and returns a JWT access token upon successful authentication.
"""
logger.info(f"Login attempt for user: {form_data.username}")
user = await crud_user.get_user_by_email(db, email=form_data.username)
# Check if user exists and password is correct
# Use the correct attribute name 'password_hash' from the User model
if not user or not verify_password(form_data.password, user.password_hash):
logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"}, # Standard header for 401
)
# Generate JWT
access_token = create_access_token(subject=user.email) # Use email as subject
logger.info(f"Login successful, token generated for user: {user.email}")
return Token(access_token=access_token, token_type="bearer")
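`create_access_token` lives in `app.core.security` and is not shown in this diff. As a rough illustration only, a JWT-style HS256 token with the email as `sub` can be minted with the stdlib like this (the secret and expiry below are placeholders, not the app's configuration):

```python
import base64
import hashlib
import hmac
import json
import time

SECRET = b"change-me"  # placeholder; the real key comes from settings

def b64url(data: bytes) -> str:
    # JWT uses unpadded URL-safe base64
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def create_access_token(subject: str, expires_in: int = 1800) -> str:
    """Build an HS256-style JWT: header.payload.signature."""
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = b64url(json.dumps({"sub": subject, "exp": int(time.time()) + expires_in}).encode())
    signing_input = f"{header}.{payload}".encode()
    signature = b64url(hmac.new(SECRET, signing_input, hashlib.sha256).digest())
    return f"{header}.{payload}.{signature}"

token = create_access_token("alice@example.com")
print(token.count("."))  # 2
```

The login endpoint then wraps the resulting string in the `Token` schema with `token_type="bearer"`; verification later recomputes the HMAC over the first two segments and compares signatures.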

View File

@@ -0,0 +1,453 @@
# app/api/v1/endpoints/chores.py
import logging
from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session, get_session
from app.auth import current_active_user
from app.models import User as UserModel, Chore as ChoreModel, ChoreTypeEnum
from app.schemas.chore import ChoreCreate, ChoreUpdate, ChorePublic, ChoreAssignmentCreate, ChoreAssignmentUpdate, ChoreAssignmentPublic
from app.crud import chore as crud_chore
from app.core.exceptions import ChoreNotFoundError, PermissionDeniedError, GroupNotFoundError, DatabaseIntegrityError
logger = logging.getLogger(__name__)
router = APIRouter()
# --- Aggregate Chores Endpoint ---
@router.get(
"/all",
response_model=PyList[ChorePublic],
summary="List All Chores",
tags=["Chores"]
)
async def list_all_chores(
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all chores (personal and group) for the current user in a single optimized request."""
logger.info(f"User {current_user.email} listing all their chores")
# Use the optimized function that reduces database queries
all_chores = await crud_chore.get_all_user_chores(db=db, user_id=current_user.id)
return all_chores
# --- Personal Chores Endpoints ---
@router.post(
"/personal",
response_model=ChorePublic,
status_code=status.HTTP_201_CREATED,
summary="Create Personal Chore",
tags=["Chores", "Personal Chores"]
)
async def create_personal_chore(
chore_in: ChoreCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new personal chore for the current user."""
logger.info(f"User {current_user.email} creating personal chore: {chore_in.name}")
if chore_in.type != ChoreTypeEnum.personal:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Chore type must be personal.")
if chore_in.group_id is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="group_id must be null for personal chores.")
try:
return await crud_chore.create_chore(db=db, chore_in=chore_in, user_id=current_user.id)
except ValueError as e:
logger.warning(f"ValueError creating personal chore for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError creating personal chore for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.get(
"/personal",
response_model=PyList[ChorePublic],
summary="List Personal Chores",
tags=["Chores", "Personal Chores"]
)
async def list_personal_chores(
db: AsyncSession = Depends(get_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all personal chores for the current user."""
logger.info(f"User {current_user.email} listing their personal chores")
return await crud_chore.get_personal_chores(db=db, user_id=current_user.id)
@router.put(
"/personal/{chore_id}",
response_model=ChorePublic,
summary="Update Personal Chore",
tags=["Chores", "Personal Chores"]
)
async def update_personal_chore(
chore_id: int,
chore_in: ChoreUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Updates a personal chore for the current user."""
logger.info(f"User {current_user.email} updating personal chore ID: {chore_id}")
if chore_in.type is not None and chore_in.type != ChoreTypeEnum.personal:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change chore type to group via this endpoint.")
if chore_in.group_id is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="group_id must be null for personal chores.")
try:
updated_chore = await crud_chore.update_chore(db=db, chore_id=chore_id, chore_in=chore_in, user_id=current_user.id, group_id=None)
if not updated_chore:
raise ChoreNotFoundError(chore_id=chore_id)
if updated_chore.type != ChoreTypeEnum.personal or updated_chore.created_by_id != current_user.id:
# This should ideally be caught by the CRUD layer permission checks
raise PermissionDeniedError(detail="Chore is not a personal chore of the current user or does not exist.")
return updated_chore
except ChoreNotFoundError as e:
logger.warning(f"Personal chore {e.chore_id} not found for user {current_user.email} during update.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} updating personal chore {chore_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except ValueError as e:
logger.warning(f"ValueError updating personal chore {chore_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError updating personal chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.delete(
"/personal/{chore_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Personal Chore",
tags=["Chores", "Personal Chores"]
)
async def delete_personal_chore(
chore_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Deletes a personal chore for the current user."""
logger.info(f"User {current_user.email} deleting personal chore ID: {chore_id}")
try:
# First, verify it's a personal chore belonging to the user
chore_to_delete = await crud_chore.get_chore_by_id(db, chore_id)
if not chore_to_delete or chore_to_delete.type != ChoreTypeEnum.personal or chore_to_delete.created_by_id != current_user.id:
raise ChoreNotFoundError(chore_id=chore_id, detail="Personal chore not found or not owned by user.")
success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=None)
if not success:
# This case should be rare if the above check passes and DB is consistent
raise ChoreNotFoundError(chore_id=chore_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except ChoreNotFoundError as e:
logger.warning(f"Personal chore {e.chore_id} not found for user {current_user.email} during delete.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e: # Should be caught by the check above
logger.warning(f"Permission denied for user {current_user.email} deleting personal chore {chore_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError deleting personal chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
# --- Group Chores Endpoints ---
# (Same structure as the personal chore endpoints, scoped to a group)
@router.post(
"/groups/{group_id}/chores",
response_model=ChorePublic,
status_code=status.HTTP_201_CREATED,
summary="Create Group Chore",
tags=["Chores", "Group Chores"]
)
async def create_group_chore(
group_id: int,
chore_in: ChoreCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new chore within a specific group."""
logger.info(f"User {current_user.email} creating chore in group {group_id}: {chore_in.name}")
if chore_in.type != ChoreTypeEnum.group:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Chore type must be group.")
    if chore_in.group_id is not None and chore_in.group_id != group_id:  # group_id in body, if provided, must match the path
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Chore's group_id ({chore_in.group_id}) must match path group_id ({group_id}) or be omitted.")
# Ensure chore_in has the correct group_id and type for the CRUD operation
chore_payload = chore_in.model_copy(update={"group_id": group_id, "type": ChoreTypeEnum.group})
try:
return await crud_chore.create_chore(db=db, chore_in=chore_payload, user_id=current_user.id, group_id=group_id)
except GroupNotFoundError as e:
logger.warning(f"Group {e.group_id} not found for chore creation by user {current_user.email}.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} in group {group_id} for chore creation: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except ValueError as e:
logger.warning(f"ValueError creating group chore for user {current_user.email} in group {group_id}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError creating group chore for {current_user.email} in group {group_id}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.get(
"/groups/{group_id}/chores",
response_model=PyList[ChorePublic],
summary="List Group Chores",
tags=["Chores", "Group Chores"]
)
async def list_group_chores(
group_id: int,
db: AsyncSession = Depends(get_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all chores for a specific group, if the user is a member."""
logger.info(f"User {current_user.email} listing chores for group {group_id}")
try:
return await crud_chore.get_chores_by_group_id(db=db, group_id=group_id, user_id=current_user.id)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} accessing chores for group {group_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
@router.put(
"/groups/{group_id}/chores/{chore_id}",
response_model=ChorePublic,
summary="Update Group Chore",
tags=["Chores", "Group Chores"]
)
async def update_group_chore(
group_id: int,
chore_id: int,
chore_in: ChoreUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Updates a chore's details within a specific group."""
logger.info(f"User {current_user.email} updating chore ID {chore_id} in group {group_id}")
if chore_in.type is not None and chore_in.type != ChoreTypeEnum.group:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change chore type to personal via this endpoint.")
if chore_in.group_id is not None and chore_in.group_id != group_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Chore's group_id, if provided, must match the path group_id ({group_id}).")
    # Normalize the payload: type was validated above, so force group type and the path's group_id
    chore_payload = chore_in.model_copy(update={"type": ChoreTypeEnum.group, "group_id": group_id})
try:
updated_chore = await crud_chore.update_chore(db=db, chore_id=chore_id, chore_in=chore_payload, user_id=current_user.id, group_id=group_id)
if not updated_chore:
raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)
return updated_chore
except ChoreNotFoundError as e:
logger.warning(f"Chore {e.chore_id} in group {e.group_id} not found for user {current_user.email} during update.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} updating chore {chore_id} in group {group_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except ValueError as e:
logger.warning(f"ValueError updating group chore {chore_id} for user {current_user.email} in group {group_id}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError updating group chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.delete(
"/groups/{group_id}/chores/{chore_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Group Chore",
tags=["Chores", "Group Chores"]
)
async def delete_group_chore(
group_id: int,
chore_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Deletes a chore from a group, ensuring user has permission."""
logger.info(f"User {current_user.email} deleting chore ID {chore_id} from group {group_id}")
try:
# Verify chore exists and belongs to the group before attempting deletion via CRUD
# This gives a more precise error if the chore exists but isn't in this group.
chore_to_delete = await crud_chore.get_chore_by_id_and_group(db, chore_id, group_id, current_user.id) # checks permission too
        if not chore_to_delete:  # get_chore_by_id_and_group raises PermissionDeniedError if user is not a member
raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)
success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=group_id)
if not success:
# This case should be rare if the above check passes and DB is consistent
raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except ChoreNotFoundError as e:
logger.warning(f"Chore {e.chore_id} in group {e.group_id} not found for user {current_user.email} during delete.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} deleting chore {chore_id} in group {group_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError deleting group chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
# === CHORE ASSIGNMENT ENDPOINTS ===
@router.post(
"/assignments",
response_model=ChoreAssignmentPublic,
status_code=status.HTTP_201_CREATED,
summary="Create Chore Assignment",
tags=["Chore Assignments"]
)
async def create_chore_assignment(
assignment_in: ChoreAssignmentCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new chore assignment. User must have permission to manage the chore."""
logger.info(f"User {current_user.email} creating assignment for chore {assignment_in.chore_id}")
try:
return await crud_chore.create_chore_assignment(db=db, assignment_in=assignment_in, user_id=current_user.id)
except ChoreNotFoundError as e:
logger.warning(f"Chore {e.chore_id} not found for assignment creation by user {current_user.email}.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} creating assignment: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except ValueError as e:
logger.warning(f"ValueError creating assignment for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError creating assignment for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.get(
"/assignments/my",
response_model=PyList[ChoreAssignmentPublic],
summary="List My Chore Assignments",
tags=["Chore Assignments"]
)
async def list_my_assignments(
include_completed: bool = False,
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all chore assignments for the current user."""
logger.info(f"User {current_user.email} listing their assignments (include_completed={include_completed})")
try:
return await crud_chore.get_user_assignments(db=db, user_id=current_user.id, include_completed=include_completed)
except Exception as e:
logger.error(f"Error listing assignments for user {current_user.email}: {e}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to retrieve assignments")
@router.get(
"/chores/{chore_id}/assignments",
response_model=PyList[ChoreAssignmentPublic],
summary="List Chore Assignments",
tags=["Chore Assignments"]
)
async def list_chore_assignments(
chore_id: int,
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all assignments for a specific chore."""
logger.info(f"User {current_user.email} listing assignments for chore {chore_id}")
try:
return await crud_chore.get_chore_assignments(db=db, chore_id=chore_id, user_id=current_user.id)
except ChoreNotFoundError as e:
logger.warning(f"Chore {e.chore_id} not found for assignment listing by user {current_user.email}.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} listing assignments for chore {chore_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
@router.put(
"/assignments/{assignment_id}",
response_model=ChoreAssignmentPublic,
summary="Update Chore Assignment",
tags=["Chore Assignments"]
)
async def update_chore_assignment(
assignment_id: int,
assignment_in: ChoreAssignmentUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Updates a chore assignment. Only assignee can mark complete, managers can reschedule."""
logger.info(f"User {current_user.email} updating assignment {assignment_id}")
try:
updated_assignment = await crud_chore.update_chore_assignment(
db=db, assignment_id=assignment_id, assignment_in=assignment_in, user_id=current_user.id
)
if not updated_assignment:
raise ChoreNotFoundError(assignment_id=assignment_id)
return updated_assignment
except ChoreNotFoundError as e:
logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during update.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} updating assignment {assignment_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except ValueError as e:
logger.warning(f"ValueError updating assignment {assignment_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError updating assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.delete(
"/assignments/{assignment_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Chore Assignment",
tags=["Chore Assignments"]
)
async def delete_chore_assignment(
assignment_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Deletes a chore assignment. User must have permission to manage the chore."""
logger.info(f"User {current_user.email} deleting assignment {assignment_id}")
try:
success = await crud_chore.delete_chore_assignment(db=db, assignment_id=assignment_id, user_id=current_user.id)
if not success:
raise ChoreNotFoundError(assignment_id=assignment_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except ChoreNotFoundError as e:
logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during delete.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} deleting assignment {assignment_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError deleting assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
@router.patch(
"/assignments/{assignment_id}/complete",
response_model=ChoreAssignmentPublic,
summary="Mark Assignment Complete",
tags=["Chore Assignments"]
)
async def complete_chore_assignment(
assignment_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Convenience endpoint to mark an assignment as complete."""
logger.info(f"User {current_user.email} marking assignment {assignment_id} as complete")
assignment_update = ChoreAssignmentUpdate(is_complete=True)
try:
updated_assignment = await crud_chore.update_chore_assignment(
db=db, assignment_id=assignment_id, assignment_in=assignment_update, user_id=current_user.id
)
if not updated_assignment:
raise ChoreNotFoundError(assignment_id=assignment_id)
return updated_assignment
except ChoreNotFoundError as e:
logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during completion.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
except PermissionDeniedError as e:
logger.warning(f"Permission denied for user {current_user.email} completing assignment {assignment_id}: {e.detail}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
except DatabaseIntegrityError as e:
logger.error(f"DatabaseIntegrityError completing assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
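Every handler in this file repeats the same exception-to-HTTP-status ladder (`ChoreNotFoundError` → 404, `PermissionDeniedError` → 403, `ValueError` → 400, `DatabaseIntegrityError` → 500). As a hypothetical refactor sketch (not part of this diff), that mapping could be centralized; the exception classes below are local stand-ins so the snippet runs on its own.

```python
# Hypothetical refactor sketch: centralize the repeated exception-to-status
# mapping. The exception classes are stand-ins for the app's real ones.
class ChoreNotFoundError(Exception): ...
class PermissionDeniedError(Exception): ...
class DatabaseIntegrityError(Exception): ...

STATUS_BY_EXC = {
    ChoreNotFoundError: 404,
    PermissionDeniedError: 403,
    ValueError: 400,
    DatabaseIntegrityError: 500,
}

def http_status_for(exc: Exception) -> int:
    """Map a domain exception to the HTTP status the handlers above use."""
    for exc_type, code in STATUS_BY_EXC.items():
        if isinstance(exc, exc_type):
            return code
    return 500  # unknown errors surface as internal server errors
```

A decorator or FastAPI `exception_handler` registration built on this table would remove most of the per-endpoint `try`/`except` blocks.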


@@ -0,0 +1,423 @@
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import (
User as UserModel,
Group as GroupModel,
List as ListModel,
Expense as ExpenseModel,
Item as ItemModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel,
    SettlementActivity as SettlementActivityModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
    """
    Calculate suggested settlements to balance the finances within a group.

    Takes the current balances of all users and suggests settlements that
    minimize the number of transactions needed to settle all debts.

    Args:
        user_balances: list of UserBalanceDetail objects with current balances.

    Returns:
        list of SuggestedSettlement objects representing the suggested payments.
    """
# Create list of users who owe money (negative balance) and who are owed money (positive balance)
debtors = [] # Users who owe money (negative balance)
creditors = [] # Users who are owed money (positive balance)
    # Threshold below which a balance is treated as settled (absorbs rounding residue)
epsilon = Decimal('0.01')
# Sort users into debtors and creditors
for user in user_balances:
# Skip users with zero balance (or very close to zero)
if abs(user.net_balance) < epsilon:
continue
if user.net_balance < Decimal('0'):
# User owes money
debtors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': -user.net_balance # Convert to positive amount
})
else:
# User is owed money
creditors.append({
'user_id': user.user_id,
'user_identifier': user.user_identifier,
'amount': user.net_balance
})
# Sort by amount (descending) to handle largest debts first
debtors.sort(key=lambda x: x['amount'], reverse=True)
creditors.sort(key=lambda x: x['amount'], reverse=True)
settlements = []
# Iterate through debtors and match them with creditors
while debtors and creditors:
debtor = debtors[0]
creditor = creditors[0]
# Determine the settlement amount (the smaller of the two amounts)
amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
# Create settlement record
if amount > Decimal('0'):
settlements.append(
SuggestedSettlement(
from_user_id=debtor['user_id'],
from_user_identifier=debtor['user_identifier'],
to_user_id=creditor['user_id'],
to_user_identifier=creditor['user_identifier'],
amount=amount
)
)
# Update balances
debtor['amount'] -= amount
creditor['amount'] -= amount
# Remove users who have settled their debts/credits
if debtor['amount'] < epsilon:
debtors.pop(0)
if creditor['amount'] < epsilon:
creditors.pop(0)
return settlements
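The greedy matching above can be exercised in isolation. This sketch mirrors the same logic with plain dicts and tuples in place of the `UserBalanceDetail`/`SuggestedSettlement` schemas (a substitution made only so the example is self-contained):

```python
# Standalone sketch of the greedy settlement matching, with plain values
# instead of the app's Pydantic schemas.
from decimal import Decimal, ROUND_HALF_UP

def suggest(balances: dict[int, Decimal]) -> list[tuple[int, int, Decimal]]:
    """Return (from_user, to_user, amount) triples that settle all balances."""
    eps = Decimal("0.01")
    # Negative balance = owes money; positive = is owed money
    debtors = [[uid, -b] for uid, b in balances.items() if b <= -eps]
    creditors = [[uid, b] for uid, b in balances.items() if b >= eps]
    debtors.sort(key=lambda x: x[1], reverse=True)    # largest debts first
    creditors.sort(key=lambda x: x[1], reverse=True)  # largest credits first
    out: list[tuple[int, int, Decimal]] = []
    while debtors and creditors:
        amount = min(debtors[0][1], creditors[0][1]).quantize(eps, rounding=ROUND_HALF_UP)
        if amount > 0:
            out.append((debtors[0][0], creditors[0][0], amount))
        debtors[0][1] -= amount
        creditors[0][1] -= amount
        if debtors[0][1] < eps:
            debtors.pop(0)
        if creditors[0][1] < eps:
            creditors.pop(0)
    return out
```

For balances {1: -50, 2: +30, 3: +20}, the greedy pass pairs user 1 with the largest creditor first, yielding two transfers (1→2 for 30, then 1→3 for 20) rather than the three a naive pairing could produce.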
@router.get(
"/lists/{list_id}/cost-summary",
response_model=ListCostSummary,
summary="Get Cost Summary for a List",
tags=["Costs"],
responses={
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
}
)
async def get_list_cost_summary(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves a calculated cost summary for a specific list, detailing total costs,
equal shares per user, and individual user balances based on their contributions.
The user must have access to the list to view its cost summary.
Costs are split among group members if the list belongs to a group, or just for
the creator if it's a personal list. All users who added items with prices are
included in the calculation.
"""
logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")
# 1. Verify user has access to the target list
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
raise
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
raise
# 2. Get the list with its items and users
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == list_id)
)
db_list = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
# 3. Get or create an expense for this list
expense_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.list_id == list_id)
.options(selectinload(ExpenseModel.splits))
)
db_expense = expense_result.scalars().first()
if not db_expense:
# Create a new expense for this list
        total_amount = sum((item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")), Decimal("0.00"))
if total_amount == Decimal("0"):
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# Create expense with ITEM_BASED split type
expense_in = ExpenseCreate(
description=f"Cost summary for list {db_list.name}",
total_amount=total_amount,
list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=db_list.creator.id
)
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in, current_user_id=current_user.id)
# 4. Calculate cost summary from expense splits
participating_users = set()
user_items_added_value = {}
total_list_cost = Decimal("0.00")
# Get all users who added items
for item in db_list.items:
if item.price is not None and item.price > Decimal("0") and item.added_by_user:
participating_users.add(item.added_by_user)
user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
total_list_cost += item.price
# Get all users from expense splits
for split in db_expense.splits:
if split.user:
participating_users.add(split.user)
num_participating_users = len(participating_users)
if num_participating_users == 0:
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# This is the ideal equal share, returned in the summary
equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# Sort users for deterministic remainder distribution
sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id)
user_final_shares = {}
if num_participating_users > 0:
base_share_unrounded = total_list_cost / Decimal(num_participating_users)
# Calculate initial share for each user, rounding down
for user in sorted_participating_users:
user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
# Calculate sum of rounded down shares
sum_of_rounded_shares = sum(user_final_shares.values())
# Calculate remaining pennies to be distributed
remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP))
# Distribute remaining pennies one by one to sorted users
for i in range(remaining_pennies):
user_to_adjust = sorted_participating_users[i % num_participating_users]
user_final_shares[user_to_adjust.id] += Decimal("0.01")
user_balances = []
for user in sorted_participating_users: # Iterate over sorted users
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
# current_user_share is now the precisely calculated share for this user
current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email
user_balances.append(
UserCostShare(
user_id=user.id,
user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
user_balances.sort(key=lambda x: x.user_identifier)
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user_for_response, # Use the ideal share for the response field
user_balances=user_balances
)
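The rounding logic above is a largest-remainder split: round every share down to the cent, then hand out the leftover cents one at a time in a fixed (sorted) user order, so the shares always sum exactly to the total. A condensed, self-contained sketch, assuming (as holds for an even split) that the leftover is less than one cent per user:

```python
# Largest-remainder even split: shares sum exactly to the total.
from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP

def split_evenly(total: Decimal, user_ids: list[int]) -> dict[int, Decimal]:
    n = len(user_ids)
    # Round each share down first...
    base = (total / n).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    shares = {uid: base for uid in sorted(user_ids)}
    # ...then count the leftover pennies and distribute them deterministically
    pennies = int(((total - base * n) * 100).to_integral_value(rounding=ROUND_HALF_UP))
    for i, uid in enumerate(sorted(user_ids)):
        if i < pennies:
            shares[uid] += Decimal("0.01")
    return shares
```

Splitting 100.00 three ways gives 33.34 to the lowest-sorted user and 33.33 to the other two; without the penny redistribution, three equal rounded shares would sum to 99.99 or 100.02 depending on the rounding mode.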
@router.get(
"/groups/{group_id}/balance-summary",
response_model=GroupBalanceSummary,
summary="Get Detailed Balance Summary for a Group",
tags=["Costs", "Groups"],
responses={
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
}
)
async def get_group_balance_summary(
group_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves a detailed financial balance summary for all users within a specific group.
It considers all expenses, their splits, and all settlements recorded for the group.
The user must be a member of the group to view its balance summary.
"""
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
# 1. Verify user is a member of the target group
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user))  # eager-load users: assoc.user is read below, and lazy loads fail in async sessions
        .where(GroupModel.id == group_id)
    )
db_group_for_check = group_check.scalars().first()
if not db_group_for_check:
raise GroupNotFoundError(group_id)
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
if not user_is_member:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
# 2. Get all expenses and settlements for the group
expenses_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
)
expenses = expenses_result.scalars().all()
settlements_result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.options(
selectinload(SettlementModel.paid_by_user),
selectinload(SettlementModel.paid_to_user)
)
)
settlements = settlements_result.scalars().all()
# Fetch SettlementActivities related to the group's expenses
# This requires joining SettlementActivity -> ExpenseSplit -> Expense
settlement_activities_result = await db.execute(
select(SettlementActivityModel)
.join(ExpenseSplitModel, SettlementActivityModel.expense_split_id == ExpenseSplitModel.id)
.join(ExpenseModel, ExpenseSplitModel.expense_id == ExpenseModel.id)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(SettlementActivityModel.payer)) # Optional: if you need payer details directly
)
settlement_activities = settlement_activities_result.scalars().all()
# 3. Calculate user balances
user_balances_data = {}
# Initialize UserBalanceDetail for each group member
for assoc in db_group_for_check.member_associations:
if assoc.user:
user_balances_data[assoc.user.id] = {
"user_id": assoc.user.id,
"user_identifier": assoc.user.name if assoc.user.name else assoc.user.email,
"total_paid_for_expenses": Decimal("0.00"),
"initial_total_share_of_expenses": Decimal("0.00"),
"total_amount_paid_via_settlement_activities": Decimal("0.00"),
"total_generic_settlements_paid": Decimal("0.00"),
"total_generic_settlements_received": Decimal("0.00"),
}
# Process Expenses
for expense in expenses:
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id]["total_paid_for_expenses"] += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id]["initial_total_share_of_expenses"] += split.owed_amount
# Process Settlement Activities (SettlementActivityModel)
for activity in settlement_activities:
if activity.paid_by_user_id in user_balances_data:
user_balances_data[activity.paid_by_user_id]["total_amount_paid_via_settlement_activities"] += activity.amount_paid
# Process Generic Settlements (SettlementModel)
for settlement in settlements:
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id]["total_generic_settlements_paid"] += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id]["total_generic_settlements_received"] += settlement.amount
# Calculate Final Balances
final_user_balances = []
for user_id, data in user_balances_data.items():
initial_total_share_of_expenses = data["initial_total_share_of_expenses"]
total_amount_paid_via_settlement_activities = data["total_amount_paid_via_settlement_activities"]
adjusted_total_share_of_expenses = initial_total_share_of_expenses - total_amount_paid_via_settlement_activities
total_paid_for_expenses = data["total_paid_for_expenses"]
total_generic_settlements_received = data["total_generic_settlements_received"]
total_generic_settlements_paid = data["total_generic_settlements_paid"]
net_balance = (
total_paid_for_expenses + total_generic_settlements_received
) - (adjusted_total_share_of_expenses + total_generic_settlements_paid)
# Quantize all final values for UserBalanceDetail schema
user_detail = UserBalanceDetail(
user_id=data["user_id"],
user_identifier=data["user_identifier"],
total_paid_for_expenses=total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
# Store adjusted_total_share_of_expenses in total_share_of_expenses
total_share_of_expenses=adjusted_total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
# Store total_generic_settlements_paid in total_settlements_paid
total_settlements_paid=total_generic_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
total_settlements_received=total_generic_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
net_balance=net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
final_user_balances.append(user_detail)
# Sort by user identifier
final_user_balances.sort(key=lambda x: x.user_identifier)
# Calculate suggested settlements
suggested_settlements = calculate_suggested_settlements(final_user_balances)
# Calculate overall totals for the group
overall_total_expenses = sum((expense.total_amount for expense in expenses), Decimal("0.00"))
overall_total_settlements = sum((settlement.amount for settlement in settlements), Decimal("0.00"))
return GroupBalanceSummary(
group_id=db_group_for_check.id,
group_name=db_group_for_check.name,
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=final_user_balances,
suggested_settlements=suggested_settlements
)
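The net-balance arithmetic above (expenses paid plus generic settlements received, minus the activity-adjusted share plus generic settlements paid) can be checked with a minimal standalone sketch; the function name and figures below are hypothetical, not part of the endpoint:

```python
from decimal import Decimal, ROUND_HALF_UP

def net_balance(paid_expenses: Decimal, share: Decimal,
                settled_via_activities: Decimal,
                generic_paid: Decimal, generic_received: Decimal) -> Decimal:
    """Mirror of the endpoint's formula: positive means the user is owed,
    negative means the user owes."""
    adjusted_share = share - settled_via_activities
    net = (paid_expenses + generic_received) - (adjusted_share + generic_paid)
    return net.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

# Hypothetical user: paid 90.00 of group expenses, owes a 30.00 share,
# and already cleared 10.00 of that share via settlement activities.
print(net_balance(Decimal("90.00"), Decimal("30.00"),
                  Decimal("10.00"), Decimal("0.00"), Decimal("0.00")))  # 70.00
```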


@@ -1,45 +0,0 @@
# app/api/v1/endpoints/expenses.py
import logging
from typing import List as PyList
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, SettlementActivityTypeEnum
from app.schemas.expense import (
ExpenseRecordPublic,
ExpenseSharePublic,
SettleShareRequest
)
from app.schemas.message import Message
from app.crud import expense as crud_expense
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpenseRecordPublic], tags=["Expenses"])
async def read_list_expense_records(
list_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""Retrieves all historical expense calculation records for a specific list."""
records = await crud_expense.get_expense_records_for_list(db, list_id=list_id)
return records
@router.post("/expenses/{expense_record_id}/settle", response_model=Message, tags=["Expenses"])
async def settle_expense_share(
expense_record_id: int,
settle_request: SettleShareRequest,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""Marks a specific user's share within an expense record as paid."""
affected_user_id = settle_request.affected_user_id
share_to_update = await crud_expense.get_expense_share(db, record_id=expense_record_id, user_id=affected_user_id)
if not share_to_update:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Expense share not found")
await crud_expense.mark_share_as_paid(db, share_id=share_to_update.id, is_paid_status=True)
return Message(detail="Share successfully marked as paid")


@@ -0,0 +1,658 @@
# app/api/v1/endpoints/financials.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from typing import List as PyList, Optional, Sequence
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import (
User as UserModel,
Group as GroupModel,
List as ListModel,
UserGroup as UserGroupModel,
UserRoleEnum,
ExpenseSplit as ExpenseSplitModel
)
from app.schemas.expense import (
ExpenseCreate, ExpensePublic,
SettlementCreate, SettlementPublic,
ExpenseUpdate, SettlementUpdate
)
from app.schemas.settlement_activity import SettlementActivityCreate, SettlementActivityPublic # Added
from app.crud import expense as crud_expense
from app.crud import settlement as crud_settlement
from app.crud import settlement_activity as crud_settlement_activity # Added
from app.crud import group as crud_group
from app.crud import list as crud_list
from app.core.exceptions import (
ListNotFoundError, GroupNotFoundError, UserNotFoundError,
InvalidOperationError, GroupPermissionError, ListPermissionError,
ItemNotFoundError, GroupMembershipError
)
logger = logging.getLogger(__name__)
router = APIRouter()
# --- Helper for permissions ---
async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"):
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_creator=False)
except ListPermissionError as e:
logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}")
raise ListPermissionError(list_id, action=action)
except ListNotFoundError:
raise
# --- Expense Endpoints ---
@router.post(
"/expenses",
response_model=ExpensePublic,
status_code=status.HTTP_201_CREATED,
summary="Create New Expense",
tags=["Expenses"]
)
async def create_new_expense(
expense_in: ExpenseCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
effective_group_id = expense_in.group_id
is_group_context = False
if expense_in.list_id:
# Check basic access to list (implies membership if list is in group)
await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for")
list_obj = await db.get(ListModel, expense_in.list_id)
if not list_obj:
raise ListNotFoundError(expense_in.list_id)
if list_obj.group_id:
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.")
effective_group_id = list_obj.group_id
is_group_context = True # Expense is tied to a group via the list
elif expense_in.group_id:
raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.")
# If list is personal, no group check needed yet, handled by payer check below.
elif effective_group_id: # Only group_id provided for expense
is_group_context = True
# Ensure user is at least a member to create expense in group context
await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for")
else:
# Reached only when the expense has neither a list_id nor a group_id.
raise InvalidOperationError("Expense must be linked to a list_id or group_id.")
# Finalize expense payload with correct group_id if derived
expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id})
# --- Granular Permission Check for Payer ---
if expense_in_final.paid_by_user_id != current_user.id:
logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}")
# If creating expense paid by someone else, user MUST be owner IF in group context
if is_group_context and effective_group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user")
except GroupPermissionError as e:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}")
else:
# Cannot create expense paid by someone else for a personal list (no group context)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.")
# If paying for self, basic list/group membership check above is sufficient.
try:
created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
return created_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except NotImplementedError as e:
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
expense_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense:
raise ItemNotFoundError(item_id=expense_id)
if expense.list_id:
await check_list_access_for_financials(db, expense.list_id, current_user.id)
elif expense.group_id:
await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
elif expense.paid_by_user_id != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
return expense
@router.get("/expenses", response_model=PyList[ExpensePublic], summary="List Expenses", tags=["Expenses"])
async def list_expenses(
list_id: Optional[int] = Query(None, description="Filter by list ID"),
group_id: Optional[int] = Query(None, description="Filter by group ID"),
isRecurring: Optional[bool] = Query(None, description="Filter by recurring expenses"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
List expenses with optional filters.
If list_id is provided, returns expenses for that list (user must have list access).
If group_id is provided, returns expenses for that group (user must be group member).
If both are provided, returns expenses for the list (list_id takes precedence).
If neither is provided, returns all expenses the user has access to.
"""
logger.info(f"User {current_user.email} listing expenses with filters: list_id={list_id}, group_id={group_id}, isRecurring={isRecurring}")
if list_id:
# Use existing list expenses endpoint logic
await check_list_access_for_financials(db, list_id, current_user.id)
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
elif group_id:
# Use existing group expenses endpoint logic
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
else:
# Get all expenses the user has access to (user's personal expenses + group expenses + list expenses)
expenses = await crud_expense.get_user_accessible_expenses(db, user_id=current_user.id, skip=skip, limit=limit)
# Apply recurring filter if specified
if isRecurring is not None:
expenses = [expense for expense in expenses if bool(expense.recurrence_rule) == isRecurring]
return expenses
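The recurring filter above keys on the truthiness of `recurrence_rule`; a small sketch with a hypothetical stand-in for `ExpenseModel` illustrates the behaviour (a `None` or empty rule counts as non-recurring):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeExpense:  # hypothetical stand-in for ExpenseModel
    id: int
    recurrence_rule: Optional[str]

def filter_recurring(expenses: List[FakeExpense],
                     is_recurring: Optional[bool]) -> List[FakeExpense]:
    # Same predicate as the endpoint: truthiness of recurrence_rule.
    if is_recurring is None:
        return list(expenses)
    return [e for e in expenses if bool(e.recurrence_rule) == is_recurring]

expenses = [FakeExpense(1, "FREQ=WEEKLY"), FakeExpense(2, None), FakeExpense(3, "")]
print([e.id for e in filter_recurring(expenses, True)])   # [1]
print([e.id for e in filter_recurring(expenses, False)])  # [2, 3]
```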
@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
async def list_group_expenses(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
return expenses
@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
async def update_expense_details(
expense_id: int,
expense_in: ExpenseUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Updates an existing expense (description, currency, expense_date only).
Requires the current version number for optimistic locking.
User must have permission to modify the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
raise ItemNotFoundError(item_id=expense_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_modify = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
can_modify = True
logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError: # User not even a member
pass # Keep can_modify as False
except GroupPermissionError: # User is member but not owner
pass # Keep can_modify as False
except GroupNotFoundError: # Group doesn't exist (data integrity issue)
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
pass # Keep can_modify as False
# Note: If expense is only linked to a personal list (no group), only payer can modify.
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")
try:
updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
return updated_expense
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
async def delete_expense_record(
expense_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Deletes an expense and its associated splits.
Requires expected_version query parameter for optimistic locking.
User must have permission to delete the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
# Return 204 even if not found, as the end state is achieved (item is gone)
logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404
# --- Granular Permission Check ---
can_delete = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_delete = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
can_delete = True
logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")
try:
await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
logger.info(f"Expense ID {expense_id} deleted successfully.")
# No need to return content on 204
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
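The optimistic-locking contract shared by the update and delete endpoints (compare the caller's expected version, then bump on write; a stale version maps to HTTP 409) can be sketched in isolation. `VersionConflict` and the dict-based record here are hypothetical illustrations, not the CRUD layer's actual implementation:

```python
class VersionConflict(Exception):
    """Raised when the caller's expected version is stale (maps to HTTP 409)."""

def apply_update(record: dict, expected_version: int, changes: dict) -> dict:
    # Compare-then-bump: only write if the caller saw the latest version.
    if record["version"] != expected_version:
        raise VersionConflict(
            f"version mismatch: expected {expected_version}, found {record['version']}"
        )
    record.update(changes)
    record["version"] += 1
    return record

expense = {"id": 7, "description": "Groceries", "version": 3}
apply_update(expense, expected_version=3, changes={"description": "Groceries (corrected)"})
print(expense["version"])  # 4; a second call with expected_version=3 now raises
```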
# --- Settlement Activity Endpoints (for ExpenseSplits) ---
@router.post(
"/expense_splits/{expense_split_id}/settle",
response_model=SettlementActivityPublic,
status_code=status.HTTP_201_CREATED,
summary="Record a Settlement Activity for an Expense Split",
tags=["Expenses", "Settlements"]
)
async def record_settlement_for_expense_split(
expense_split_id: int,
activity_in: SettlementActivityCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} attempting to record settlement for expense_split_id {expense_split_id} with amount {activity_in.amount_paid}")
if activity_in.expense_split_id != expense_split_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Expense split ID in path does not match expense split ID in request body."
)
# Fetch the ExpenseSplit and its parent Expense to check context (group/list)
stmt = (
select(ExpenseSplitModel)
.options(joinedload(ExpenseSplitModel.expense)) # Load parent expense
.where(ExpenseSplitModel.id == expense_split_id)
)
result = await db.execute(stmt)
expense_split = result.scalar_one_or_none()
if not expense_split:
raise ItemNotFoundError(item_id=expense_split_id, detail_suffix="Expense split not found.")
parent_expense = expense_split.expense
if not parent_expense:
# Should not happen if data integrity is maintained
logger.error(f"Data integrity issue: ExpenseSplit {expense_split_id} has no parent Expense.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Associated expense not found for this split.")
# --- Permission Checks ---
# The user performing the action (current_user) must be either:
# 1. The person who is making the payment (activity_in.paid_by_user_id).
# 2. An owner of the group, if the expense is tied to a group.
#
# Additionally, the payment (activity_in.paid_by_user_id) should ideally be made by the user who owes the split (expense_split.user_id).
# For simplicity, we'll first check if current_user is the one making the payment.
# More complex scenarios (e.g., a group owner settling on behalf of someone) are handled next.
can_record_settlement = False
if current_user.id == activity_in.paid_by_user_id:
# User is recording their own payment. This is allowed if they are the one who owes this split,
# or if they are paying for someone else and have group owner rights (covered below).
# We also need to ensure the person *being paid for* (activity_in.paid_by_user_id) is actually the one who owes this split.
if activity_in.paid_by_user_id != expense_split.user_id:
# Allow if current_user is group owner (checked next)
pass # Will be checked by group owner logic
else:
can_record_settlement = True # User is settling their own owed split
logger.info(f"User {current_user.email} is settling their own expense split {expense_split_id}.")
if not can_record_settlement and parent_expense.group_id:
try:
# Check if current_user is an owner of the group associated with the expense
await crud_group.check_user_role_in_group(
db,
group_id=parent_expense.group_id,
user_id=current_user.id,
required_role=UserRoleEnum.owner,
action="record settlement activities for group members"
)
can_record_settlement = True
logger.info(f"Group owner {current_user.email} is recording settlement for expense split {expense_split_id} in group {parent_expense.group_id}.")
except (GroupPermissionError, GroupMembershipError, GroupNotFoundError):
# If not group owner, and not settling own split, then permission denied.
pass # can_record_settlement remains False
if not can_record_settlement:
logger.warning(f"User {current_user.email} does not have permission to record settlement for expense split {expense_split_id}.")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to record this settlement activity. Must be the payer or a group owner."
)
# Final check: if someone is recording a payment for a split, the `paid_by_user_id` in the activity
# should match the `user_id` of the `ExpenseSplit` (the person who owes).
# The above permissions allow the current_user to *initiate* this, but the data itself must be consistent.
if activity_in.paid_by_user_id != expense_split.user_id:
logger.warning(f"Attempt to record settlement for expense split {expense_split_id} where activity payer ({activity_in.paid_by_user_id}) "
f"does not match split owner ({expense_split.user_id}). Only allowed if current_user is group owner and recording on behalf of split owner.")
# A mismatch is only valid when a group owner records the payment on behalf of the
# split owner; create_settlement_activity still stores current_user.id as created_by_user_id,
# while paid_by_user_id identifies the person whose debt is being cleared.
if current_user.id != expense_split.user_id and not (parent_expense.group_id and await crud_group.is_user_role_in_group(db, group_id=parent_expense.group_id, user_id=current_user.id, role=UserRoleEnum.owner)):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The payer ID ({activity_in.paid_by_user_id}) in the settlement activity must match the user ID of the expense split owner ({expense_split.user_id}), unless you are a group owner acting on their behalf."
)
try:
created_activity = await crud_settlement_activity.create_settlement_activity(
db=db,
settlement_activity_in=activity_in,
current_user_id=current_user.id
)
logger.info(f"Settlement activity {created_activity.id} recorded for expense split {expense_split_id} by user {current_user.email}")
return created_activity
except UserNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User referenced in settlement activity not found: {str(e)}")
except InvalidOperationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error recording settlement activity for expense_split_id {expense_split_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while recording settlement activity.")
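The permission rules spelled out in the comments above (a user settling their own owed split, or a group owner recording payment on the split owner's behalf) reduce to a small pure predicate. This is an illustrative simplification with hypothetical arguments, not the endpoint's exact control flow:

```python
def can_record_settlement(current_user_id: int,
                          activity_payer_id: int,
                          split_owner_id: int,
                          is_group_expense: bool,
                          is_group_owner: bool) -> bool:
    """True iff the acting user may record this settlement activity."""
    # Case 1: the user settles their own owed split.
    if current_user_id == activity_payer_id == split_owner_id:
        return True
    # Case 2: a group owner records the settlement on the split owner's behalf.
    if is_group_expense and is_group_owner and activity_payer_id == split_owner_id:
        return True
    return False

print(can_record_settlement(1, 1, 1, False, False))  # True: settling own split
print(can_record_settlement(2, 1, 1, True, True))    # True: owner acting on behalf
print(can_record_settlement(2, 2, 1, True, False))   # False: payer is not the split owner
```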
@router.get(
"/expense_splits/{expense_split_id}/settlement_activities",
response_model=PyList[SettlementActivityPublic],
summary="List Settlement Activities for an Expense Split",
tags=["Expenses", "Settlements"]
)
async def list_settlement_activities_for_split(
expense_split_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing settlement activities for expense_split_id {expense_split_id}")
# Fetch the ExpenseSplit and its parent Expense to check context (group/list) for permissions
stmt = (
select(ExpenseSplitModel)
.options(joinedload(ExpenseSplitModel.expense)) # Load parent expense
.where(ExpenseSplitModel.id == expense_split_id)
)
result = await db.execute(stmt)
expense_split = result.scalar_one_or_none()
if not expense_split:
raise ItemNotFoundError(item_id=expense_split_id, detail_suffix="Expense split not found.")
parent_expense = expense_split.expense
if not parent_expense:
logger.error(f"Data integrity issue: ExpenseSplit {expense_split_id} has no parent Expense.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Associated expense not found for this split.")
# --- Permission Check (similar to viewing an expense) ---
# User must have access to the parent expense.
can_view_activities = False
if parent_expense.list_id:
try:
await check_list_access_for_financials(db, parent_expense.list_id, current_user.id, action="view settlement activities for list expense")
can_view_activities = True
except (ListPermissionError, ListNotFoundError):
pass # Keep can_view_activities False
elif parent_expense.group_id:
try:
await crud_group.check_group_membership(db, group_id=parent_expense.group_id, user_id=current_user.id, action="view settlement activities for group expense")
can_view_activities = True
except (GroupMembershipError, GroupNotFoundError):
pass # Keep can_view_activities False
elif parent_expense.paid_by_user_id == current_user.id or expense_split.user_id == current_user.id:
# If expense is not tied to list/group (e.g. item-based personal expense),
# allow if current user paid the expense OR is the one who owes this specific split.
can_view_activities = True
if not can_view_activities:
logger.warning(f"User {current_user.email} does not have permission to view settlement activities for expense split {expense_split_id}.")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to view settlement activities for this expense split."
)
activities = await crud_settlement_activity.get_settlement_activities_for_split(
db=db, expense_split_id=expense_split_id, skip=skip, limit=limit
)
return activities
# --- Settlement Endpoints ---
@router.post(
"/settlements",
response_model=SettlementPublic,
status_code=status.HTTP_201_CREATED,
summary="Record New Settlement",
tags=["Settlements"]
)
async def create_new_settlement(
settlement_in: SettlementCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
try:
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
except GroupMembershipError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
except GroupNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
try:
created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
return created_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
settlement_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement:
raise ItemNotFoundError(item_id=settlement_id)
is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
try:
await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
except GroupMembershipError:
if not is_party_to_settlement:
raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
return settlement
@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
async def list_group_settlements(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
return settlements
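The `skip`/`limit` query parameters above window the result set. In plain list terms the windowing is simply (a sketch, not the actual `get_settlements_for_group` implementation):

```python
def paginate(items, skip=0, limit=100):
    # Equivalent of the endpoint's Query constraints (skip >= 0, 1 <= limit <= 200):
    # drop the first `skip` items, then return at most `limit` of the rest.
    return items[skip:skip + limit]
```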
@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
async def update_settlement_details(
settlement_id: int,
settlement_in: SettlementUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Updates an existing settlement (description, settlement_date only).
Requires the current version number for optimistic locking.
    User must have permission (e.g., be an involved party or a group admin).
"""
logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
raise ItemNotFoundError(item_id=settlement_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_modify = True
# 2. OR User is owner of the group the settlement belongs to
# Note: Settlements always have a group_id based on current model
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
can_modify = True
logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")
pass
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")
try:
updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
return updated_settlement
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
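The version handling that `crud_settlement.update_settlement` relies on (its body is not shown in this diff) boils down to a compare-and-bump on the row's `version` column. A minimal pure-Python sketch of that contract, with hypothetical names:

```python
from dataclasses import dataclass


class StaleVersionError(Exception):
    """Raised when the client's version no longer matches the stored row."""


@dataclass
class Settlement:
    id: int
    version: int
    description: str


def apply_update(settlement: Settlement, incoming_version: int, description: str) -> Settlement:
    # Optimistic locking: reject the write if another client updated the row
    # since this client last read it. The endpoint above maps this to 409 Conflict.
    if incoming_version != settlement.version:
        raise StaleVersionError(
            f"version mismatch: expected {settlement.version}, got {incoming_version}"
        )
    settlement.description = description
    settlement.version += 1  # bump so concurrent writers holding the old version fail
    return settlement
```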
@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
async def delete_settlement_record(
settlement_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Deletes a settlement.
Requires expected_version query parameter for optimistic locking.
    User must have permission (e.g., be an involved party or a group admin).
"""
logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# --- Granular Permission Check ---
can_delete = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_delete = True
# 2. OR User is owner of the group the settlement belongs to
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
can_delete = True
logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")
try:
await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
logger.info(f"Settlement ID {settlement_id} deleted successfully.")
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
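The update and delete endpoints above duplicate the same party-or-owner permission check; it could be factored into a single helper. A sketch of that shared predicate (the helper name is hypothetical, not part of the codebase):

```python
def can_act_on_settlement(user_id: int, paid_by_user_id: int,
                          paid_to_user_id: int, is_group_owner: bool) -> bool:
    # A user may modify or delete a settlement if they are the payer,
    # the payee, or the owner of the group the settlement belongs to.
    is_party = user_id in (paid_by_user_id, paid_to_user_id)
    return is_party or is_group_owner
```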


@@ -5,14 +5,24 @@ from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.database import get_transactional_session, get_session
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum # Import model and enum
from app.schemas.group import GroupCreate, GroupPublic
from app.schemas.invite import InviteCodePublic
from app.schemas.message import Message # For simple responses
from app.schemas.list import ListPublic, ListDetail
from app.crud import group as crud_group
from app.crud import invite as crud_invite
from app.crud import list as crud_list
from app.core.exceptions import (
GroupNotFoundError,
GroupPermissionError,
GroupMembershipError,
GroupOperationError,
GroupValidationError,
InviteCreationError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -26,8 +36,8 @@ router = APIRouter()
)
async def create_group(
group_in: GroupCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new group, adding the creator as the owner."""
logger.info(f"User {current_user.email} creating group: {group_in.name}")
@@ -44,8 +54,8 @@ async def create_group(
tags=["Groups"]
)
async def read_user_groups(
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all groups the current user is a member of."""
logger.info(f"Fetching groups for user: {current_user.email}")
@@ -61,8 +71,8 @@ async def read_user_groups(
)
async def read_group(
group_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves details for a specific group, including members, if the user is part of it."""
logger.info(f"User {current_user.email} requesting details for group ID: {group_id}")
@@ -70,18 +80,14 @@ async def read_group(
is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
if not is_member:
logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a member of this group")
raise GroupMembershipError(group_id, "view group details")
group = await crud_group.get_group_by_id(db=db, group_id=group_id)
if not group:
logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
raise GroupNotFoundError(group_id)
# Manually construct the members list with UserPublic schema if needed
# Pydantic v2's from_attributes should handle this if relationships are loaded
# members_public = [UserPublic.model_validate(assoc.user) for assoc in group.member_associations]
# return GroupPublic.model_validate(group, update={"members": members_public})
return group # Rely on Pydantic conversion and eager loading
return group
@router.post(
@@ -92,8 +98,8 @@ async def read_group(
)
async def create_group_invite(
group_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Generates a new invite code for the group. Requires owner/admin role (MVP: owner only)."""
logger.info(f"User {current_user.email} attempting to create invite for group {group_id}")
@@ -102,21 +108,59 @@ async def create_group_invite(
# --- Permission Check (MVP: Owner only) ---
if user_role != UserRoleEnum.owner:
logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only group owners can create invites")
raise GroupPermissionError(group_id, "create invites")
# Check if group exists (implicitly done by role check, but good practice)
group = await crud_group.get_group_by_id(db, group_id)
if not group:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
raise GroupNotFoundError(group_id)
invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id)
if not invite:
logger.error(f"Failed to generate unique invite code for group {group_id}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not generate invite code")
# This case should ideally be covered by exceptions from create_invite now
raise InviteCreationError(group_id)
logger.info(f"Invite code created for group {group_id} by user {current_user.email}")
logger.info(f"User {current_user.email} created invite code for group {group_id}")
return invite
@router.get(
"/{group_id}/invites",
response_model=InviteCodePublic, # Or Optional[InviteCodePublic] if it can be null
summary="Get Group Active Invite Code",
tags=["Groups", "Invites"]
)
async def get_group_active_invite(
group_id: int,
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves the active invite code for the group. Requires group membership (owner/admin to be stricter later if needed)."""
logger.info(f"User {current_user.email} attempting to get active invite for group {group_id}")
# Permission check: Ensure user is a member of the group to view invite code
# Using get_user_role_in_group which also checks membership indirectly
user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)
if user_role is None: # Not a member
logger.warning(f"Permission denied: User {current_user.email} is not a member of group {group_id} and cannot view invite code.")
# More specific error or let GroupPermissionError handle if we want to be generic
raise GroupMembershipError(group_id, "view invite code for this group (not a member)")
# Fetch the active invite for the group
invite = await crud_invite.get_active_invite_for_group(db, group_id=group_id)
if not invite:
# This case means no active (non-expired, active=true) invite exists.
# The frontend can then prompt to generate one.
logger.info(f"No active invite code found for group {group_id} when requested by {current_user.email}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active invite code found for this group. Please generate one."
)
logger.info(f"User {current_user.email} retrieved active invite code for group {group_id}")
return invite # Pydantic will convert InviteModel to InviteCodePublic
@router.delete(
"/{group_id}/leave",
response_model=Message,
@@ -125,31 +169,32 @@ async def create_group_invite(
)
async def leave_group(
group_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Removes the current user from the specified group."""
"""Removes the current user from the specified group. If the owner is the last member, the group will be deleted."""
logger.info(f"User {current_user.email} attempting to leave group {group_id}")
user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)
if user_role is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="You are not a member of this group")
raise GroupMembershipError(group_id, "leave (you are not a member)")
# --- MVP: Prevent owner leaving if they are the last member/owner ---
# Check if owner is the last member
if user_role == UserRoleEnum.owner:
member_count = await crud_group.get_group_member_count(db, group_id)
# More robust check: count owners. For now, just check member count.
if member_count <= 1:
logger.warning(f"Owner {current_user.email} attempted to leave group {group_id} as last member.")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Owner cannot leave the group as the last member. Delete the group or transfer ownership.")
# Delete the group since owner is the last member
logger.info(f"Owner {current_user.email} is the last member. Deleting group {group_id}")
await crud_group.delete_group(db, group_id)
return Message(detail="Group deleted as you were the last member")
# Proceed with removal
# Proceed with removal for non-owner or if there are other members
deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id)
if not deleted:
# Should not happen if role check passed, but handle defensively
logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to leave group")
raise GroupOperationError("Failed to leave group")
logger.info(f"User {current_user.email} successfully left group {group_id}")
return Message(detail="Successfully left the group")
@@ -164,8 +209,8 @@ async def leave_group(
async def remove_group_member(
group_id: int,
user_id_to_remove: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Removes a specified user from the group. Requires current user to be owner."""
logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}")
@@ -174,23 +219,49 @@ async def remove_group_member(
# --- Permission Check ---
if owner_role != UserRoleEnum.owner:
logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only group owners can remove members")
raise GroupPermissionError(group_id, "remove members")
# Prevent owner removing themselves via this endpoint
if current_user.id == user_id_to_remove:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.")
raise GroupValidationError("Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.")
# Check if target user is actually in the group
target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove)
if target_role is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User to remove is not a member of this group")
raise GroupMembershipError(group_id, "remove this user (they are not a member)")
# Proceed with removal
deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove)
if not deleted:
logger.error(f"Owner {current_user.email} failed to remove user {user_id_to_remove} from group {group_id}.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to remove member")
raise GroupOperationError("Failed to remove member")
logger.info(f"Owner {current_user.email} successfully removed user {user_id_to_remove} from group {group_id}")
return Message(detail="Successfully removed member from the group")
@router.get(
"/{group_id}/lists",
response_model=List[ListDetail],
summary="Get Group Lists",
tags=["Groups", "Lists"]
)
async def read_group_lists(
group_id: int,
db: AsyncSession = Depends(get_session), # Use read-only session for GET
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all lists belonging to a specific group, if the user is a member."""
logger.info(f"User {current_user.email} requesting lists for group ID: {group_id}")
# Check if user is a member first
is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
if not is_member:
logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}")
raise GroupMembershipError(group_id, "view group lists")
# Get all lists for the user and filter by group_id
lists = await crud_list.get_lists_for_user(db=db, user_id=current_user.id)
group_lists = [lst for lst in lists if lst.group_id == group_id]
return group_lists
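Throughout this diff, `get_db` is swapped for `get_transactional_session` (writes) or `get_session` (reads). Assuming the usual FastAPI session-per-request pattern, the transactional variant wraps the handler in a commit/rollback boundary. A dependency-free sketch of that behavior, using a fake session so it runs without a database:

```python
import asyncio
from contextlib import asynccontextmanager


class FakeSession:
    """Stand-in for an AsyncSession so the pattern runs without a database."""
    def __init__(self):
        self.committed = False
        self.rolled_back = False

    async def commit(self):
        self.committed = True

    async def rollback(self):
        self.rolled_back = True


@asynccontextmanager
async def get_transactional_session():
    # Write endpoints get a session whose transaction spans the whole request:
    # commit if the handler succeeds, roll back if it raises.
    session = FakeSession()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise


async def demo():
    async with get_transactional_session() as s:
        pass  # a successful handler body
    return s


session = asyncio.run(demo())
```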


@@ -1,11 +1,12 @@
# app/api/v1/endpoints/health.py
import logging
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from app.database import get_db # Import the dependency function
from app.schemas.health import HealthStatus # Import the response schema
from app.database import get_transactional_session
from app.schemas.health import HealthStatus
from app.core.exceptions import DatabaseConnectionError
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -15,9 +16,9 @@ router = APIRouter()
response_model=HealthStatus,
summary="Perform a Health Check",
description="Checks the operational status of the API and its connection to the database.",
tags=["Health"] # Group this endpoint in Swagger UI
tags=["Health"]
)
async def check_health(db: AsyncSession = Depends(get_db)):
async def check_health(db: AsyncSession = Depends(get_transactional_session)):
"""
Health check endpoint. Verifies API reachability and database connection.
"""
@@ -30,16 +31,8 @@ async def check_health(db: AsyncSession = Depends(get_db)):
else:
# This case should ideally not happen with 'SELECT 1'
logger.error("Health check failed: Database connection check returned unexpected result.")
# Raise 503 Service Unavailable
raise HTTPException(
status_code=503,
detail="Database connection error: Unexpected result"
)
raise DatabaseConnectionError("Unexpected result from database connection check")
except Exception as e:
logger.error(f"Health check failed: Database connection error - {e}", exc_info=True) # Log stack trace
# Raise 503 Service Unavailable
raise HTTPException(
status_code=503,
detail=f"Database connection error: {e}"
)
logger.error(f"Health check failed: Database connection error - {e}", exc_info=True)
raise DatabaseConnectionError(str(e))
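Replacing inline `HTTPException`s with domain exceptions such as `DatabaseConnectionError` only works if an app-level handler translates them back into HTTP responses. A minimal stand-in of that mapping (the class names mirror the imports above, but the `status_code` attributes and handler shape are assumptions about `app.core.exceptions`):

```python
class AppError(Exception):
    """Base class for domain exceptions; subclasses carry an HTTP status."""
    status_code = 500


class DatabaseConnectionError(AppError):
    status_code = 503  # service unavailable, as the old inline handler returned


class ItemNotFoundError(AppError):
    status_code = 404


def response_for(exc: AppError) -> tuple[int, str]:
    # What a FastAPI app.add_exception_handler(AppError, ...) would emit.
    return exc.status_code, str(exc)
```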


@@ -3,57 +3,77 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum
from app.schemas.invite import InviteAccept
from app.schemas.message import Message
from app.schemas.group import GroupPublic
from app.crud import invite as crud_invite
from app.crud import group as crud_group
from app.core.exceptions import (
InviteNotFoundError,
InviteExpiredError,
InviteAlreadyUsedError,
InviteCreationError,
GroupNotFoundError,
GroupMembershipError,
GroupOperationError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/accept", # Route relative to prefix "/invites"
response_model=Message,
response_model=GroupPublic,
summary="Accept Group Invite",
tags=["Invites"]
)
async def accept_invite(
invite_in: InviteAccept,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Allows an authenticated user to accept an invite using its code."""
code = invite_in.code
logger.info(f"User {current_user.email} attempting to accept invite code: {code}")
"""Accepts a group invite using the provided invite code."""
logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.code}")
# Find the active, non-expired invite
invite = await crud_invite.get_active_invite_by_code(db=db, code=code)
# Get the invite - this function should only return valid, active invites
invite = await crud_invite.get_active_invite_by_code(db, code=invite_in.code)
if not invite:
logger.warning(f"Invite code '{code}' not found, expired, or already used.")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invite code is invalid or expired")
logger.warning(f"Invalid or inactive invite code attempted by user {current_user.email}: {invite_in.code}")
# We can use a more generic error or a specific one. InviteNotFound is reasonable.
raise InviteNotFoundError(invite_in.code)
group_id = invite.group_id
# Check if group still exists
group = await crud_group.get_group_by_id(db, group_id=invite.group_id)
if not group:
logger.error(f"Group {invite.group_id} not found for invite {invite_in.code}")
raise GroupNotFoundError(invite.group_id)
# Check if user is already in the group
is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
# Check if user is already a member
is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id)
if is_member:
logger.info(f"User {current_user.email} is already a member of group {group_id}. Invite '{code}' still deactivated.")
# Deactivate invite even if already member, to prevent reuse
await crud_invite.deactivate_invite(db=db, invite=invite)
return Message(detail="You are already a member of this group.")
logger.warning(f"User {current_user.email} already a member of group {invite.group_id}")
raise GroupMembershipError(invite.group_id, "join (already a member)")
# Add user to the group as a member
added = await crud_group.add_user_to_group(db=db, group_id=group_id, user_id=current_user.id, role=UserRoleEnum.member)
if not added:
# Should not happen if is_member check was correct, but handle defensively
logger.error(f"Failed to add user {current_user.email} to group {group_id} via invite '{code}' despite not being a member.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not join group.")
# Add user to the group
added_to_group = await crud_group.add_user_to_group(db, group_id=invite.group_id, user_id=current_user.id)
if not added_to_group:
logger.error(f"Failed to add user {current_user.email} to group {invite.group_id} during invite acceptance.")
# This could be a race condition or other issue, treat as an operational error.
raise GroupOperationError("Failed to add user to group.")
# Deactivate the invite (single-use)
await crud_invite.deactivate_invite(db=db, invite=invite)
# Deactivate the invite so it cannot be used again
await crud_invite.deactivate_invite(db, invite=invite)
logger.info(f"User {current_user.email} successfully joined group {group_id} using invite '{code}'.")
return Message(detail="Successfully joined the group.")
logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.code}")
# Re-fetch the group to get the updated member list
updated_group = await crud_group.get_group_by_id(db, group_id=invite.group_id)
if not updated_group:
# This should ideally not happen as we found it before
logger.error(f"Could not re-fetch group {invite.group_id} after user {current_user.email} joined.")
raise GroupNotFoundError(invite.group_id)
return updated_group
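The accept flow above is: validate the code, reject existing members, add the user, then deactivate the invite so each code is single-use. Stripped of the database layer, the state transitions look like this (names are illustrative, not the actual crud functions):

```python
from dataclasses import dataclass


@dataclass
class Invite:
    code: str
    group_id: int
    active: bool = True


def accept_invite(invite: Invite, members: set[int], user_id: int) -> set[int]:
    # Mirrors the endpoint: invalid/used codes and existing members are rejected
    # before any state changes; the invite is burned only on a successful join.
    if not invite.active:
        raise ValueError("invite code is invalid or already used")
    if user_id in members:
        raise ValueError("already a member of this group")
    members.add(user_id)    # add the user to the group
    invite.active = False   # deactivate: each code is single-use
    return members
```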


@@ -1,12 +1,12 @@
# app/api/v1/endpoints/items.py
import logging
from typing import List as PyList
from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.api.dependencies import get_current_user
from app.database import get_transactional_session
from app.auth import current_active_user
# --- Import Models Correctly ---
from app.models import User as UserModel
from app.models import Item as ItemModel # <-- IMPORT Item and alias it
@@ -14,6 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
from app.crud import item as crud_item
from app.crud import list as crud_list
from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -22,19 +23,21 @@ router = APIRouter()
# Now ItemModel is defined before being used as a type hint
async def get_item_and_verify_access(
item_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
) -> ItemModel: # Now this type hint is valid
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user)
) -> ItemModel:
"""Dependency to get an item and verify the user has access to its list."""
item_db = await crud_item.get_item_by_id(db, item_id=item_id)
if not item_db:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Item not found")
raise ItemNotFoundError(item_id)
# Check permission on the parent list
list_db = await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id)
if not list_db:
# User doesn't have access to the list this item belongs to
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to access this item's list")
return item_db # Return the fetched item if authorized
try:
await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id)
except ListPermissionError as e:
# Re-raise with a more specific message
raise ListPermissionError(item_db.list_id, "access this item's list")
return item_db
# --- Endpoints ---
@@ -49,25 +52,23 @@ async def get_item_and_verify_access(
async def create_list_item(
list_id: int,
item_in: ItemCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Adds a new item to a specific list. User must have access to the list."""
logger.info(f"User {current_user.email} adding item to list {list_id}: {item_in.name}")
user_email = current_user.email # Access email attribute before async operations
logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}")
# Verify user has access to the target list
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
if not list_db:
# Check if list exists at all for correct error code
exists = await crud_list.get_list_by_id(db, list_id)
status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
detail = "List not found" if not exists else "You do not have permission to add items to this list"
logger.warning(f"Add item failed for list {list_id} by user {current_user.email}: {detail}")
raise HTTPException(status_code=status_code, detail=detail)
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
# Re-raise with a more specific message
raise ListPermissionError(list_id, "add items to this list")
created_item = await crud_item.create_item(
db=db, item_in=item_in, list_id=list_id, user_id=current_user.id
)
logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {current_user.email}.")
logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {user_email}.")
return created_item
@@ -79,63 +80,102 @@ async def read_list_items(
)
async def read_list_items(
list_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
# Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
):
"""Retrieves all items for a specific list if the user has access."""
logger.info(f"User {current_user.email} listing items for list {list_id}")
user_email = current_user.email # Access email attribute before async operations
logger.info(f"User {user_email} listing items for list {list_id}")
# Verify user has access to the list
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
# Re-raise with a more specific message
raise ListPermissionError(list_id, "view items in this list")
items = await crud_item.get_items_by_list_id(db=db, list_id=list_id)
return items
@router.put(
"/lists/{list_id}/items/{item_id}", # Nested under lists
response_model=ItemPublic,
summary="Update Item",
tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
}
)
async def update_item(
list_id: int,
item_id: int,
item_in: ItemUpdate,
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user), # Need user ID for completed_by
):
"""
Updates an item's details (name, quantity, is_complete, price).
User must have access to the list the item belongs to.
The client MUST provide the current `version` of the item in the `item_in` payload.
If the version does not match, a 409 Conflict is returned.
Sets/unsets `completed_by_id` based on `is_complete` flag.
"""
user_email = current_user.email # Access email attribute before async operations
logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}")
# Permission check is handled by get_item_and_verify_access dependency
try:
updated_item = await crud_item.update_item(
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
)
logger.info(f"Item {item_id} updated successfully by user {user_email} to version {updated_item.version}.")
return updated_item
except ConflictError as e:
logger.warning(f"Conflict updating item {item_id} for user {user_email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e:
logger.error(f"Error updating item {item_id} for user {user_email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")
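The compare-and-bump behaviour behind these 409 responses can be sketched without SQLAlchemy. The names below (`Item`, `update_with_version_check`) are hypothetical stand-ins for the project's `crud_item` logic, shown only to illustrate the optimistic-locking pattern:

```python
from dataclasses import dataclass

class ConflictError(Exception):
    """Raised when the client's version no longer matches the stored row."""

@dataclass
class Item:
    id: int
    name: str
    version: int = 1

def update_with_version_check(item: Item, new_name: str, expected_version: int) -> Item:
    # Reject the write if another client bumped the version since this
    # client last read the item; the endpoint maps this to HTTP 409.
    if item.version != expected_version:
        raise ConflictError(
            f"Expected version {expected_version}, current version is {item.version}."
        )
    item.name = new_name
    item.version += 1  # every successful write advances the version
    return item
```

A client that receives the 409 should re-fetch the item, merge its change, and retry with the fresh version.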
@router.delete(
"/lists/{list_id}/items/{item_id}", # Nested under lists
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Item",
tags=["Items"],
responses={
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
}
)
async def delete_item(
list_id: int,
item_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user), # Log who deleted it
):
"""
Deletes an item. User must have access to the list the item belongs to.
(MVP: Any member with list access can delete items).
If `expected_version` is provided and does not match the item's current version,
a 409 Conflict is returned.
"""
user_email = current_user.email # Access email attribute before async operations
logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
# Permission check is handled by get_item_and_verify_access dependency
if expected_version is not None and item_db.version != expected_version:
logger.warning(
f"Conflict deleting item {item_id} for user {user_email}. "
f"Expected version {expected_version}, actual version {item_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
)
await crud_item.delete_item(db=db, item_db=item_db)
logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {user_email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT)


@@ -1,20 +1,27 @@
# app/api/v1/endpoints/lists.py
import logging
from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail
from app.schemas.message import Message # For simple responses
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.crud import group as crud_group # Need for group membership check
from app.schemas.list import ListStatus, ListStatusWithId
from app.schemas.expense import ExpensePublic # Import ExpensePublic
from app.core.exceptions import (
GroupMembershipError,
ListNotFoundError,
ListPermissionError,
ListStatusNotFoundError,
ConflictError, # Added ConflictError
DatabaseIntegrityError # Added DatabaseIntegrityError
)
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -24,17 +31,24 @@ router = APIRouter()
response_model=ListPublic, # Return basic list info on creation
status_code=status.HTTP_201_CREATED,
summary="Create New List",
tags=["Lists"],
responses={
status.HTTP_409_CONFLICT: {
"description": "Conflict: A list with this name already exists in the specified group",
"model": ListPublic
}
}
)
async def create_list(
list_in: ListCreate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Creates a new shopping list.
- If `group_id` is provided, the user must be a member of that group.
- If `group_id` is null, it's a personal list.
- If a list with the same name already exists in the group, returns 409 with the existing list.
"""
logger.info(f"User {current_user.email} creating list: {list_in.name}")
group_id = list_in.group_id
@@ -44,25 +58,42 @@ async def create_list(
is_member = await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id)
if not is_member:
logger.warning(f"User {current_user.email} attempted to create list in group {group_id} but is not a member.")
raise GroupMembershipError(group_id, "create lists")
try:
created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id)
logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.")
return created_list
except DatabaseIntegrityError as e:
# Check if this is a unique constraint violation
if "unique constraint" in str(e).lower():
# Find the existing list with the same name in the group
existing_list = await crud_list.get_list_by_name_and_group(
db=db,
name=list_in.name,
group_id=group_id,
user_id=current_user.id
)
if existing_list:
logger.info(f"List '{list_in.name}' already exists in group {group_id}. Returning existing list.")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"A list named '{list_in.name}' already exists in this group.",
headers={"X-Existing-List": str(existing_list.id)}
)
# If it's not a unique constraint or we couldn't find the existing list, re-raise
raise
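On the client side, the 409 plus the `X-Existing-List` header is enough to recover the duplicate list without a second lookup. A minimal sketch (hypothetical helper, not part of the repo):

```python
from typing import Dict, Optional

def existing_list_id(status_code: int, headers: Dict[str, str]) -> Optional[int]:
    # A 409 from POST /lists carries the duplicate's id in X-Existing-List,
    # so the client can fall back to the list that already exists.
    if status_code == 409 and "X-Existing-List" in headers:
        return int(headers["X-Existing-List"])
    return None
```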
@router.get(
"", # Route relative to prefix "/lists"
response_model=PyList[ListDetail], # Return a list of detailed list info including items
summary="List Accessible Lists",
tags=["Lists"]
)
async def read_lists(
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
# Add pagination parameters later if needed: skip: int = 0, limit: int = 100
):
"""
@@ -75,6 +106,39 @@ async def read_lists(
return lists
@router.get(
"/statuses",
response_model=PyList[ListStatusWithId],
summary="Get Status for Multiple Lists",
tags=["Lists"]
)
async def read_lists_statuses(
ids: PyList[int] = Query(...),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves the status for a list of lists.
- `updated_at`: The timestamp of the last update to the list itself.
- `item_count`: The total number of items in the list.
The user must have permission to view each list requested.
Lists that the user does not have permission for will be omitted from the response.
"""
logger.info(f"User {current_user.email} requesting statuses for list IDs: {ids}")
statuses = await crud_list.get_lists_statuses_by_ids(db=db, list_ids=ids, user_id=current_user.id)
# The CRUD function returns a list of Row objects, so we map them to the Pydantic model
return [
ListStatusWithId(
id=s.id,
updated_at=s.updated_at,
item_count=s.item_count,
latest_item_updated_at=s.latest_item_updated_at
) for s in statuses
]
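The shape of the status aggregation can be mirrored in plain Python. This sketch (hypothetical names, standing in for `get_lists_statuses_by_ids`) keeps the endpoint's contract that lists the user cannot view are silently omitted:

```python
from datetime import datetime
from typing import Dict, List, Optional, Set, Tuple

def lists_statuses(
    list_ids: List[int],
    accessible: Set[int],                    # list ids the user may view
    item_stamps: Dict[int, List[datetime]],  # list id -> item update times
) -> List[Tuple[int, int, Optional[datetime]]]:
    out = []
    for list_id in list_ids:
        if list_id not in accessible:
            continue  # omitted rather than raising, per the endpoint docstring
        stamps = item_stamps.get(list_id, [])
        # (id, item_count, latest_item_updated_at)
        out.append((list_id, len(stamps), max(stamps) if stamps else None))
    return out
```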
@router.get(
"/{list_id}",
response_model=ListDetail, # Return detailed list info including items
@@ -83,29 +147,16 @@ async def read_lists(
)
async def read_list(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves details for a specific list, including its items,
if the user has permission (creator or group member).
"""
logger.info(f"User {current_user.email} requesting details for list ID: {list_id}")
# Use the helper to fetch and check permission simultaneously
# The check_list_permission function will raise appropriate exceptions
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
# list_db already has items loaded due to check_list_permission
return list_db
@@ -113,118 +164,123 @@ async def read_list(
"/{list_id}",
response_model=ListPublic, # Return updated basic info
summary="Update List",
tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
}
)
async def update_list(
list_id: int,
list_in: ListUpdate,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Updates a list's details (name, description, is_complete).
Requires user to be the creator or a member of the list's group.
(MVP: Allows any member to update these fields).
The client MUST provide the current `version` of the list in the `list_in` payload.
If the version does not match, a 409 Conflict is returned.
"""
logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
if not list_db:
exists = await crud_list.get_list_by_id(db, list_id)
status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
detail = "List not found" if not exists else "You do not have permission to update this list"
logger.warning(f"Update failed for list {list_id} by user {current_user.email}: {detail}")
raise HTTPException(status_code=status_code, detail=detail)
# Prevent changing group_id or creator via this endpoint for simplicity
# if list_in.group_id is not None or list_in.created_by_id is not None:
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint")
try:
updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
return updated_list
except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response
logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e: # Catch other potential errors from crud operation
logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}")
# Consider a more generic error, but for now, let's keep it specific if possible
# Re-raising might be better if crud layer already raises appropriate HTTPExceptions
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.")
@router.delete(
"/{list_id}",
status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
summary="Delete List",
tags=["Lists"],
responses={ # Add 409 to responses
status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
}
)
async def delete_list(
list_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Deletes a list. Requires user to be the creator of the list.
(Alternatively, could allow group owner).
If `expected_version` is provided and does not match the list's current version,
a 409 Conflict is returned.
"""
logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
# Use the helper, requiring creator permission
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)
if not list_db:
exists = await crud_list.get_list_by_id(db, list_id)
status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
detail = "List not found" if not exists else "Only the list creator can delete this list"
logger.warning(f"Delete failed for list {list_id} by user {current_user.email}: {detail}")
raise HTTPException(status_code=status_code, detail=detail)
if expected_version is not None and list_db.version != expected_version:
logger.warning(
f"Conflict deleting list {list_id} for user {current_user.email}. "
f"Expected version {expected_version}, actual version {list_db.version}."
)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
)
await crud_list.delete_list(db=db, list_db=list_db)
logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.get(
"/{list_id}/status",
response_model=ListStatus,
summary="Get List Status",
tags=["Lists"]
)
async def read_list_status(
list_id: int,
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves the update timestamp and item count for a specific list
if the user has permission (creator or group member).
"""
logger.info(f"User {current_user.email} requesting status for list ID: {list_id}")
# check_list_permission raises the appropriate error if the list is missing
# or the user lacks access; get_list_status then handles the not-found case.
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
return await crud_list.get_list_status(db=db, list_id=list_id)
@router.get(
"/{list_id}/expenses",
response_model=PyList[ExpensePublic],
summary="Get Expenses for List",
tags=["Lists", "Expenses"]
)
async def read_list_expenses(
list_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
Retrieves expenses associated with a specific list
if the user has permission (creator or group member).
"""
from app.crud import expense as crud_expense
logger.info(f"User {current_user.email} requesting expenses for list ID: {list_id}")
# Check if user has permission to access this list
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
# Get expenses for this list
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
return expenses
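FastAPI itself rejects out-of-range `skip`/`limit` values with a 422 rather than clamping them; a client can pre-clamp its own parameters to stay inside the bounds this endpoint declares. A sketch (hypothetical helper) under those same bounds:

```python
from typing import Tuple

def clamp_pagination(skip: int, limit: int, max_limit: int = 200) -> Tuple[int, int]:
    # Mirrors Query(0, ge=0) and Query(100, ge=1, le=200): negative skips
    # become 0 and the limit is forced into [1, max_limit].
    return max(skip, 0), min(max(limit, 1), max_limit)
```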


@@ -1,21 +1,27 @@
# app/api/v1/endpoints/ocr.py
import logging
from typing import List
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
from google.api_core import exceptions as google_exceptions
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.ocr import OcrExtractResponse
from app.core.gemini import GeminiOCRService, gemini_initialization_error
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OCRProcessingError
)
from app.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
ocr_service = GeminiOCRService()
@router.post(
"/extract-items",
@@ -24,8 +30,7 @@ MAX_FILE_SIZE_MB = 10 # Set a reasonable max file size
tags=["OCR"]
)
async def ocr_extract_items(
current_user: UserModel = Depends(current_active_user),
image_file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP) of the shopping list or receipt."),
):
"""
@@ -35,74 +40,37 @@ async def ocr_extract_items(
# Check if Gemini client initialized correctly
if gemini_initialization_error:
logger.error("OCR endpoint called but Gemini client failed to initialize.")
raise OCRServiceUnavailableError(gemini_initialization_error)
logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")
# --- File Validation ---
if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
raise InvalidFileTypeError()
# Simple size check
contents = await image_file.read()
if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
raise FileTooLargeError()
try:
# Use the ocr_service instance instead of the standalone function
extracted_items = await ocr_service.extract_items(image_data=contents)
logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
return OcrExtractResponse(extracted_items=extracted_items)
except OCRServiceUnavailableError:
raise OCRServiceUnavailableError()
except OCRServiceConfigError:
raise OCRServiceConfigError()
except OCRQuotaExceededError:
raise OCRQuotaExceededError()
except Exception as e:
# Catch any other unexpected errors
logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}")
raise OCRProcessingError(str(e))
finally:
# Ensure file handle is closed
await image_file.close()
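The upload validation that moved into `settings` can be sketched standalone. The default values below are assumptions: they mirror the module-level constants this diff removes, which may differ from the deployed `settings`:

```python
# Assumed defaults, mirroring the old module-level constants.
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
MAX_FILE_SIZE_MB = 10

class InvalidFileTypeError(Exception):
    pass

class FileTooLargeError(Exception):
    pass

def validate_upload(content_type: str, size_bytes: int) -> None:
    # Reject unsupported MIME types before looking at the payload.
    if content_type not in ALLOWED_IMAGE_TYPES:
        raise InvalidFileTypeError(f"Invalid file type: {content_type}")
    # Reject payloads over the size cap (checked after reading the body,
    # as the endpoint does with len(contents)).
    if size_bytes > MAX_FILE_SIZE_MB * 1024 * 1024:
        raise FileTooLargeError(f"File exceeds {MAX_FILE_SIZE_MB} MB")
```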


@@ -1,30 +0,0 @@
# app/api/v1/endpoints/users.py
import logging
from fastapi import APIRouter, Depends, HTTPException
from app.api.dependencies import get_current_user # Import the dependency
from app.schemas.user import UserPublic # Import the response schema
from app.models import User as UserModel # Import the DB model for type hinting
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get(
"/me",
response_model=UserPublic, # Use the public schema to avoid exposing hash
summary="Get Current User",
description="Retrieves the details of the currently authenticated user.",
tags=["Users"]
)
async def read_users_me(
current_user: UserModel = Depends(get_current_user) # Apply the dependency
):
"""
Returns the data for the user associated with the current valid access token.
"""
logger.info(f"Fetching details for current user: {current_user.email}")
# The 'current_user' object is the SQLAlchemy model instance returned by the dependency.
# Pydantic's response_model will automatically convert it using UserPublic schema.
return current_user
# Add other user-related endpoints here later (e.g., update user, list users (admin))


@@ -3,7 +3,7 @@ import pytest
from httpx import AsyncClient
from app.schemas.user import UserPublic # For response validation
# from app.core.security import create_access_token # Commented out as FastAPI-Users handles token creation
pytestmark = pytest.mark.asyncio
@@ -51,15 +51,15 @@ async def test_read_users_me_invalid_token(client: AsyncClient):
assert response.status_code == 401
assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency
# async def test_read_users_me_expired_token(client: AsyncClient):
# # Create a short-lived token manually (or adjust settings temporarily)
# email = "testexpired@example.com"
# # Assume create_access_token allows timedelta override
# # expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
# # headers = {"Authorization": f"Bearer {expired_token}"}
# # response = await client.get("/api/v1/users/me", headers=headers)
# # assert response.status_code == 401
# # assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency
# Add test case for valid token but user deleted from DB if needed

be/app/auth.py (new file, 151 lines)

@@ -0,0 +1,151 @@
from typing import Optional
from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin
from fastapi_users.authentication import (
AuthenticationBackend,
BearerTransport,
JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from authlib.integrations.starlette_client import OAuth
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from .database import get_session
from .models import User
from .config import settings
# OAuth2 configuration
config = Config('.env')
oauth = OAuth(config)
# Google OAuth2 setup
oauth.register(
name='google',
server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
client_kwargs={
'scope': 'openid email profile',
'redirect_uri': settings.GOOGLE_REDIRECT_URI
}
)
# Apple OAuth2 setup
oauth.register(
name='apple',
server_metadata_url='https://appleid.apple.com/.well-known/openid-configuration',
client_kwargs={
'scope': 'openid email name',
'redirect_uri': settings.APPLE_REDIRECT_URI
}
)
# Custom Bearer Response with Refresh Token
class BearerResponseWithRefresh(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Custom Bearer Transport that supports refresh tokens
class BearerTransportWithRefresh(BearerTransport):
async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
if refresh_token:
bearer_response = BearerResponseWithRefresh(
access_token=token,
refresh_token=refresh_token,
token_type="bearer"
)
else:
# Fallback to standard response if no refresh token
bearer_response = {
"access_token": token,
"token_type": "bearer"
}
return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
# Custom Authentication Backend with Refresh Token Support
class AuthenticationBackendWithRefresh(AuthenticationBackend):
def __init__(
self,
name: str,
transport: BearerTransportWithRefresh,
get_strategy,
get_refresh_strategy,
):
self.name = name
self.transport = transport
self.get_strategy = get_strategy
self.get_refresh_strategy = get_refresh_strategy
async def login(self, strategy, user) -> Response:
# Generate both access and refresh tokens
access_token = await strategy.write_token(user)
refresh_strategy = self.get_refresh_strategy()
refresh_token = await refresh_strategy.write_token(user)
return await self.transport.get_login_response(
token=access_token,
refresh_token=refresh_token
)
async def logout(self, strategy, user, token) -> Response:
return await self.transport.get_logout_response()
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
reset_password_token_secret = settings.SECRET_KEY
verification_token_secret = settings.SECRET_KEY
async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.")
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
):
print(f"User {user.id} has forgotten their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
):
print(f"Verification requested for user {user.id}. Verification token: {token}")
async def on_after_login(
self, user: User, request: Optional[Request] = None, response: Optional[Response] = None
):
print(f"User {user.id} has logged in.")
async def get_user_db(session: AsyncSession = Depends(get_session)):
yield SQLAlchemyUserDatabase(session, User)
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
yield UserManager(user_db)
# Updated transport with refresh token support
bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
def get_refresh_jwt_strategy() -> JWTStrategy:
# Refresh tokens last longer - 7 days
return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
# Updated auth backend with refresh token support
auth_backend = AuthenticationBackendWithRefresh(
name="jwt",
transport=bearer_transport,
get_strategy=get_jwt_strategy,
get_refresh_strategy=get_refresh_jwt_strategy,
)
fastapi_users = FastAPIUsers[User, int](
get_user_manager,
[auth_backend],
)
current_active_user = fastapi_users.current_user(active=True)
current_superuser = fastapi_users.current_user(active=True, superuser=True)
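To make the login flow above concrete: the backend writes two JWTs with different lifetimes and returns them together. A minimal stdlib-only sketch of that shape (illustrative stand-in only; the real tokens come from fastapi-users' `JWTStrategy.write_token`, and `SECRET`, `_sign`, and `login_response` here are hypothetical names):

```python
import base64
import hashlib
import hmac
import json
import time

SECRET = "change-me"  # stand-in; the app reads settings.SECRET_KEY from the environment


def _sign(payload: dict, lifetime_seconds: int) -> str:
    """Create a compact signed token (illustrative stand-in for JWTStrategy.write_token)."""
    body = dict(payload, exp=int(time.time()) + lifetime_seconds)
    raw = json.dumps(body, sort_keys=True).encode()
    sig = hmac.new(SECRET.encode(), raw, hashlib.sha256).hexdigest()
    return base64.urlsafe_b64encode(raw).decode() + "." + sig


def login_response(user_id: int) -> dict:
    """Mirror AuthenticationBackendWithRefresh.login: short-lived access token plus a 7-day refresh token."""
    return {
        "access_token": _sign({"sub": user_id}, 480 * 60),           # ACCESS_TOKEN_EXPIRE_MINUTES
        "refresh_token": _sign({"sub": user_id}, 7 * 24 * 60 * 60),  # refresh token lifetime
        "token_type": "bearer",
    }
```

The two tokens share a payload but differ in `exp`, which is why the transport returns both fields in one JSON body.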


@@ -3,6 +3,8 @@ import os
from pydantic_settings import BaseSettings
from dotenv import load_dotenv
import logging
import secrets
from typing import List
load_dotenv()
logger = logging.getLogger(__name__)
@@ -10,39 +12,196 @@ logger = logging.getLogger(__name__)
class Settings(BaseSettings):
DATABASE_URL: str | None = None
GEMINI_API_KEY: str | None = None
SENTRY_DSN: str | None = None # Sentry DSN for error tracking
# --- Environment Settings ---
ENVIRONMENT: str = "development" # development, staging, production
# --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users)
SECRET_KEY: str # Must be set via environment variable
TOKEN_TYPE: str = "bearer" # Default token type for JWT authentication
# FastAPI-Users handles JWT algorithm internally
# --- OCR Settings ---
MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
OCR_ITEM_EXTRACTION_PROMPT: str = """
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
Add 2 underscores before and after the item name, if it is struck through.
If the image does not appear to be a shopping list or receipt, state that clearly.
Example output for a grocery list:
Milk
Eggs
Bread
__Apples__
Organic Bananas
"""
# --- OCR Error Messages ---
OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later."
OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support."
OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing."
OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later."
OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}"
OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB"
OCR_PROCESSING_ERROR: str = "Error processing image: {detail}"
# --- Gemini AI Settings ---
GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
GEMINI_SAFETY_SETTINGS: dict = {
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
}
GEMINI_GENERATION_CONFIG: dict = {
"candidate_count": 1,
"max_output_tokens": 2048,
"temperature": 0.9,
"top_p": 1,
"top_k": 1
}
# --- API Settings ---
API_PREFIX: str = "/api" # Base path for all API endpoints
API_OPENAPI_URL: str = "/api/openapi.json"
API_DOCS_URL: str = "/api/docs"
API_REDOC_URL: str = "/api/redoc"
# CORS Origins - environment dependent
CORS_ORIGINS: str = "http://localhost:5173,http://localhost:5174,http://localhost:8000,http://127.0.0.1:5173,http://127.0.0.1:5174,http://127.0.0.1:8000"
FRONTEND_URL: str = "http://localhost:5173" # URL for the frontend application
# --- API Metadata ---
API_TITLE: str = "Shared Lists API"
API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
API_VERSION: str = "0.1.0"
ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
# --- Logging Settings ---
LOG_LEVEL: str = "WARNING"
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# --- Health Check Settings ---
HEALTH_STATUS_OK: str = "ok"
HEALTH_STATUS_ERROR: str = "error"
# --- HTTP Status Messages ---
HTTP_400_DETAIL: str = "Bad Request"
HTTP_401_DETAIL: str = "Unauthorized"
HTTP_403_DETAIL: str = "Forbidden"
HTTP_404_DETAIL: str = "Not Found"
HTTP_422_DETAIL: str = "Unprocessable Entity"
HTTP_429_DETAIL: str = "Too Many Requests"
HTTP_500_DETAIL: str = "Internal Server Error"
HTTP_503_DETAIL: str = "Service Unavailable"
# --- Database Error Messages ---
DB_CONNECTION_ERROR: str = "Database connection error"
DB_INTEGRITY_ERROR: str = "Database integrity error"
DB_TRANSACTION_ERROR: str = "Database transaction error"
DB_QUERY_ERROR: str = "Database query error"
# --- Auth Error Messages ---
AUTH_INVALID_CREDENTIALS: str = "Invalid username or password"
AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
AUTH_JWT_ERROR: str = "JWT token error: {error}"
AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}"
AUTH_HEADER_NAME: str = "WWW-Authenticate"
AUTH_HEADER_PREFIX: str = "Bearer"
# OAuth Settings
# IMPORTANT: For Google OAuth to work, you MUST set the following environment variables
# (e.g., in your .env file):
# GOOGLE_CLIENT_ID: Your Google Cloud project's OAuth 2.0 Client ID
# GOOGLE_CLIENT_SECRET: Your Google Cloud project's OAuth 2.0 Client Secret
# Ensure the GOOGLE_REDIRECT_URI below matches the one configured in your Google Cloud Console.
GOOGLE_CLIENT_ID: str = ""
GOOGLE_CLIENT_SECRET: str = ""
GOOGLE_REDIRECT_URI: str = "https://mitlistbe.mohamad.dev/api/v1/auth/google/callback"
APPLE_CLIENT_ID: str = ""
APPLE_TEAM_ID: str = ""
APPLE_KEY_ID: str = ""
APPLE_PRIVATE_KEY: str = ""
APPLE_REDIRECT_URI: str = "https://mitlistbe.mohamad.dev/api/v1/auth/apple/callback"
# Session Settings
SESSION_SECRET_KEY: str = "your-session-secret-key" # Change this in production
ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours instead of 30 minutes
# Redis Settings
REDIS_URL: str = "redis://localhost:6379"
REDIS_PASSWORD: str = ""
class Config:
env_file = ".env"
env_file_encoding = 'utf-8'
extra = "ignore"
@property
def cors_origins_list(self) -> List[str]:
"""Convert CORS_ORIGINS string to list"""
return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
@property
def is_production(self) -> bool:
"""Check if running in production environment"""
return self.ENVIRONMENT.lower() == "production"
@property
def is_development(self) -> bool:
"""Check if running in development environment"""
return self.ENVIRONMENT.lower() == "development"
@property
def docs_url(self) -> str | None:
"""Return docs URL only in development"""
return self.API_DOCS_URL if self.is_development else None
@property
def redoc_url(self) -> str | None:
"""Return redoc URL only in development"""
return self.API_REDOC_URL if self.is_development else None
@property
def openapi_url(self) -> str | None:
"""Return OpenAPI URL only in development"""
return self.API_OPENAPI_URL if self.is_development else None
settings = Settings()
# Validation for critical settings
if settings.DATABASE_URL is None:
raise ValueError("DATABASE_URL environment variable must be set.")
# Enforce secure secret key
if not settings.SECRET_KEY:
raise ValueError("SECRET_KEY environment variable must be set. Generate a secure key using: openssl rand -hex 32")
# Validate secret key strength
if len(settings.SECRET_KEY) < 32:
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
# Production-specific validations
if settings.is_production:
if settings.SESSION_SECRET_KEY == "your-session-secret-key":
raise ValueError("SESSION_SECRET_KEY must be changed from default value in production")
if not settings.SENTRY_DSN:
logger.warning("SENTRY_DSN not set in production environment. Error tracking will be unavailable.")
if settings.GEMINI_API_KEY is None:
logger.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.")
else:
# Optional: Log partial key for confirmation (avoid logging full key)
logger.info(f"GEMINI_API_KEY loaded (starts with: {settings.GEMINI_API_KEY[:4]}...).")
# Log environment information
logger.info(f"Application starting in {settings.ENVIRONMENT} environment")
if settings.is_production:
logger.info("Production mode: API documentation disabled")
else:
logger.info(f"Development mode: API documentation available at {settings.API_DOCS_URL}")

be/app/core/api_config.py Normal file

@@ -0,0 +1,102 @@
from typing import Dict, Any
from app.config import settings
# API Version
API_VERSION = "v1"
# API Prefix
API_PREFIX = f"/api/{API_VERSION}"
# API Endpoints
class APIEndpoints:
# Auth
AUTH = {
"LOGIN": "/auth/login",
"SIGNUP": "/auth/signup",
"REFRESH_TOKEN": "/auth/refresh-token",
}
# Users
USERS = {
"PROFILE": "/users/profile",
"UPDATE_PROFILE": "/users/profile",
}
# Lists
LISTS = {
"BASE": "/lists",
"BY_ID": "/lists/{id}",
"ITEMS": "/lists/{list_id}/items",
"ITEM": "/lists/{list_id}/items/{item_id}",
}
# Groups
GROUPS = {
"BASE": "/groups",
"BY_ID": "/groups/{id}",
"LISTS": "/groups/{group_id}/lists",
"MEMBERS": "/groups/{group_id}/members",
}
# Invites
INVITES = {
"BASE": "/invites",
"BY_ID": "/invites/{id}",
"ACCEPT": "/invites/{id}/accept",
"DECLINE": "/invites/{id}/decline",
}
# OCR
OCR = {
"PROCESS": "/ocr/process",
}
# Financials
FINANCIALS = {
"EXPENSES": "/financials/expenses",
"EXPENSE": "/financials/expenses/{id}",
"SETTLEMENTS": "/financials/settlements",
"SETTLEMENT": "/financials/settlements/{id}",
}
# Health
HEALTH = {
"CHECK": "/health",
}
# API Metadata
API_METADATA = {
"title": settings.API_TITLE,
"description": settings.API_DESCRIPTION,
"version": settings.API_VERSION,
"openapi_url": settings.API_OPENAPI_URL,
"docs_url": settings.API_DOCS_URL,
"redoc_url": settings.API_REDOC_URL,
}
# API Tags
API_TAGS = [
{"name": "Authentication", "description": "Authentication and authorization endpoints"},
{"name": "Users", "description": "User management endpoints"},
{"name": "Lists", "description": "Shopping list management endpoints"},
{"name": "Groups", "description": "Group management endpoints"},
{"name": "Invites", "description": "Group invitation management endpoints"},
{"name": "OCR", "description": "Optical Character Recognition endpoints"},
{"name": "Financials", "description": "Financial management endpoints"},
{"name": "Health", "description": "Health check endpoints"},
]
# Helper function to get full API URL
def get_api_url(endpoint: str, **kwargs) -> str:
"""
Get the full API URL for an endpoint.
Args:
endpoint: The endpoint path
**kwargs: Path parameters to format the endpoint
Returns:
str: The full API URL
"""
formatted_endpoint = endpoint.format(**kwargs)
return f"{API_PREFIX}{formatted_endpoint}"
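For example, formatting a parameterized endpoint through the helper yields the versioned path. A standalone sketch (the constants are inlined here rather than imported from `app.config`):

```python
API_PREFIX = "/api/v1"

# Mirrors APIEndpoints.LISTS["ITEM"] above
LISTS_ITEM = "/lists/{list_id}/items/{item_id}"


def get_api_url(endpoint: str, **kwargs) -> str:
    """Format path parameters into the endpoint and prepend the API prefix."""
    return f"{API_PREFIX}{endpoint.format(**kwargs)}"


print(get_api_url(LISTS_ITEM, list_id=7, item_id=3))  # → /api/v1/lists/7/items/3
```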


@@ -0,0 +1,79 @@
from datetime import date, timedelta
from typing import Optional
from app.models import ChoreFrequencyEnum
def calculate_next_due_date(
current_due_date: date,
frequency: ChoreFrequencyEnum,
custom_interval_days: Optional[int] = None,
last_completed_date: Optional[date] = None
) -> date:
"""
Calculates the next due date for a chore.
Uses current_due_date as a base if last_completed_date is not provided.
Calculates from last_completed_date if provided.
"""
if frequency == ChoreFrequencyEnum.one_time:
if last_completed_date:
raise ValueError("Cannot calculate next due date for a completed one-time chore.")
return current_due_date
base_date = last_completed_date if last_completed_date else current_due_date
if hasattr(base_date, 'date') and callable(getattr(base_date, 'date')):
base_date = base_date.date() # type: ignore
next_due: date
if frequency == ChoreFrequencyEnum.daily:
next_due = base_date + timedelta(days=1)
elif frequency == ChoreFrequencyEnum.weekly:
next_due = base_date + timedelta(weeks=1)
elif frequency == ChoreFrequencyEnum.monthly:
month = base_date.month + 1
year = base_date.year + (month - 1) // 12
month = (month - 1) % 12 + 1
day_of_target_month_last = (date(year, month % 12 + 1, 1) - timedelta(days=1)).day if month < 12 else 31
day = min(base_date.day, day_of_target_month_last)
next_due = date(year, month, day)
elif frequency == ChoreFrequencyEnum.custom:
if not custom_interval_days or custom_interval_days <= 0:
raise ValueError("Custom frequency requires a positive custom_interval_days.")
next_due = base_date + timedelta(days=custom_interval_days)
else:
raise ValueError(f"Unknown or unsupported chore frequency: {frequency}")
today = date.today()
reference_future_date = max(today, base_date)
# This loop ensures the next_due date is always in the future relative to the reference_future_date.
while next_due <= reference_future_date:
current_base_for_recalc = next_due
if frequency == ChoreFrequencyEnum.daily:
next_due = current_base_for_recalc + timedelta(days=1)
elif frequency == ChoreFrequencyEnum.weekly:
next_due = current_base_for_recalc + timedelta(weeks=1)
elif frequency == ChoreFrequencyEnum.monthly:
month = current_base_for_recalc.month + 1
year = current_base_for_recalc.year + (month - 1) // 12
month = (month - 1) % 12 + 1
day_of_target_month_last = (date(year, month % 12 + 1, 1) - timedelta(days=1)).day if month < 12 else 31
day = min(current_base_for_recalc.day, day_of_target_month_last)
next_due = date(year, month, day)
elif frequency == ChoreFrequencyEnum.custom:
if not custom_interval_days or custom_interval_days <= 0: # Should have been validated
raise ValueError("Custom frequency requires positive interval during recalc.")
next_due = current_base_for_recalc + timedelta(days=custom_interval_days)
else: # Should not be reached
break
# Safety break: if date hasn't changed, interval is zero or logic error.
if next_due == current_base_for_recalc:
# Log error ideally, then advance by one day to prevent infinite loop.
next_due += timedelta(days=1)
break
return next_due
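The month-end clamping in the monthly branch is the subtle part of the function above; here is a standalone check of just that arithmetic (stdlib only, with the enum and model imports omitted):

```python
from datetime import date, timedelta


def next_monthly(base: date) -> date:
    """Roll forward one month, clamping the day to the target month's length
    (same arithmetic as the monthly branch of calculate_next_due_date)."""
    month = base.month + 1
    year = base.year + (month - 1) // 12
    month = (month - 1) % 12 + 1
    # Last day of the target month: the day before the 1st of the month after it.
    last_day = (date(year, month % 12 + 1, 1) - timedelta(days=1)).day if month < 12 else 31
    return date(year, month, min(base.day, last_day))


print(next_monthly(date(2025, 1, 31)))   # Jan 31 -> 2025-02-28 (clamped to month end)
print(next_monthly(date(2024, 12, 15)))  # Dec 15 -> 2025-01-15 (year rollover)
```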

be/app/core/exceptions.py Normal file

@@ -0,0 +1,357 @@
from fastapi import HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from typing import Optional
class ListNotFoundError(HTTPException):
"""Raised when a list is not found."""
def __init__(self, list_id: int):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"List {list_id} not found"
)
class ListPermissionError(HTTPException):
"""Raised when a user doesn't have permission to access a list."""
def __init__(self, list_id: int, action: str = "access"):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"You do not have permission to {action} list {list_id}"
)
class ListCreatorRequiredError(HTTPException):
"""Raised when an action requires the list creator but the user is not the creator."""
def __init__(self, list_id: int, action: str):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Only the list creator can {action} list {list_id}"
)
class GroupNotFoundError(HTTPException):
"""Raised when a group is not found."""
def __init__(self, group_id: int):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found"
)
class GroupPermissionError(HTTPException):
"""Raised when a user doesn't have permission to perform an action in a group."""
def __init__(self, group_id: int, action: str):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"You do not have permission to {action} in group {group_id}"
)
class GroupMembershipError(HTTPException):
"""Raised when a user attempts to perform an action that requires group membership."""
def __init__(self, group_id: int, action: str = "access"):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"You must be a member of group {group_id} to {action}"
)
class GroupOperationError(HTTPException):
"""Raised when a group operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class GroupValidationError(HTTPException):
"""Raised when a group operation is invalid."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail
)
class ItemNotFoundError(HTTPException):
"""Raised when an item is not found."""
def __init__(self, item_id: int):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Item {item_id} not found"
)
class UserNotFoundError(HTTPException):
"""Raised when a user is not found."""
def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None):
detail_msg = "User not found."
if user_id:
detail_msg = f"User with ID {user_id} not found."
elif identifier:
detail_msg = f"User with identifier '{identifier}' not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail_msg
)
class InvalidOperationError(HTTPException):
"""Raised when an operation is invalid or disallowed by business logic."""
def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
super().__init__(
status_code=status_code,
detail=detail
)
class DatabaseConnectionError(HTTPException):
"""Raised when there is an error connecting to the database."""
def __init__(self):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=settings.DB_CONNECTION_ERROR
)
class DatabaseIntegrityError(HTTPException):
"""Raised when a database integrity constraint is violated."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.DB_INTEGRITY_ERROR
)
class DatabaseTransactionError(HTTPException):
"""Raised when a database transaction fails."""
def __init__(self, detail: str = settings.DB_TRANSACTION_ERROR):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class DatabaseQueryError(HTTPException):
"""Raised when a database query fails."""
def __init__(self, detail: str = settings.DB_QUERY_ERROR):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ExpenseOperationError(HTTPException):
"""Raised when an expense operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class OCRServiceUnavailableError(HTTPException):
"""Raised when the OCR service is unavailable."""
def __init__(self):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=settings.OCR_SERVICE_UNAVAILABLE
)
class OCRServiceConfigError(HTTPException):
"""Raised when there is an error in the OCR service configuration."""
def __init__(self):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=settings.OCR_SERVICE_CONFIG_ERROR
)
class OCRUnexpectedError(HTTPException):
"""Raised when there is an unexpected error in the OCR service."""
def __init__(self):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=settings.OCR_UNEXPECTED_ERROR
)
class OCRQuotaExceededError(HTTPException):
"""Raised when the OCR service quota is exceeded."""
def __init__(self):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=settings.OCR_QUOTA_EXCEEDED
)
class InvalidFileTypeError(HTTPException):
"""Raised when an invalid file type is uploaded for OCR."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
)
class FileTooLargeError(HTTPException):
"""Raised when an uploaded file exceeds the size limit."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
)
class OCRProcessingError(HTTPException):
"""Raised when there is an error processing the image with OCR."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=settings.OCR_PROCESSING_ERROR.format(detail=detail)
)
class EmailAlreadyRegisteredError(HTTPException):
"""Raised when attempting to register with an email that is already in use."""
def __init__(self):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered."
)
class UserCreationError(HTTPException):
"""Raised when there is an error creating a new user."""
def __init__(self):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred during user creation."
)
class InviteNotFoundError(HTTPException):
"""Raised when an invite is not found."""
def __init__(self, invite_code: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Invite code {invite_code} not found"
)
class InviteExpiredError(HTTPException):
"""Raised when an invite has expired."""
def __init__(self, invite_code: str):
super().__init__(
status_code=status.HTTP_410_GONE,
detail=f"Invite code {invite_code} has expired"
)
class InviteAlreadyUsedError(HTTPException):
"""Raised when an invite has already been used."""
def __init__(self, invite_code: str):
super().__init__(
status_code=status.HTTP_410_GONE,
detail=f"Invite code {invite_code} has already been used"
)
class InviteCreationError(HTTPException):
"""Raised when an invite cannot be created."""
def __init__(self, group_id: int):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to create invite for group {group_id}"
)
class ListStatusNotFoundError(HTTPException):
"""Raised when a list's status cannot be retrieved."""
def __init__(self, list_id: int):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Status for list {list_id} not found"
)
class InviteOperationError(HTTPException):
"""Raised when an invite operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class SettlementOperationError(HTTPException):
"""Raised when a settlement operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ConflictError(HTTPException):
"""Raised when an optimistic lock version conflict occurs."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail
)
class InvalidCredentialsError(HTTPException):
"""Raised when login credentials are invalid."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_INVALID_CREDENTIALS,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""}
)
class NotAuthenticatedError(HTTPException):
"""Raised when the user is not authenticated."""
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_NOT_AUTHENTICATED,
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""}
)
class JWTError(HTTPException):
"""Raised when there is an error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_JWT_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
class JWTUnexpectedError(HTTPException):
"""Raised when there is an unexpected error with the JWT token."""
def __init__(self, error: str):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
)
class ListOperationError(HTTPException):
"""Raised when a list operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ItemOperationError(HTTPException):
"""Raised when an item operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class UserOperationError(HTTPException):
"""Raised when a user operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ChoreNotFoundError(HTTPException):
"""Raised when a chore is not found."""
def __init__(self, chore_id: int, group_id: Optional[int] = None, detail: Optional[str] = None):
if detail:
error_detail = detail
elif group_id is not None:
error_detail = f"Chore {chore_id} not found in group {group_id}"
else:
error_detail = f"Chore {chore_id} not found"
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=error_detail
)
class PermissionDeniedError(HTTPException):
"""Raised when a user is denied permission for an action."""
def __init__(self, detail: str = "Permission denied."):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail
)
# Financials & Cost Splitting specific errors


@@ -4,8 +4,14 @@ from typing import List
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
from google.api_core import exceptions as google_exceptions
from app.config import settings
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
OCRProcessingError
)
logger = logging.getLogger(__name__)
@@ -19,26 +25,12 @@ try:
genai.configure(api_key=settings.GEMINI_API_KEY)
# Initialize the specific model we want to use
gemini_flash_client = genai.GenerativeModel(
model_name=settings.GEMINI_MODEL_NAME,
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
else:
# Store error if API key is missing
gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
@@ -58,10 +50,10 @@ def get_gemini_client():
Raises an exception if initialization failed.
"""
if gemini_initialization_error:
raise OCRServiceConfigError()
if gemini_flash_client is None:
# This case should ideally be covered by the check above, but as a safeguard:
raise OCRServiceConfigError()
return gemini_flash_client
# Define the prompt as a constant
@@ -79,37 +71,41 @@ Apples
Organic Bananas
"""
async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
"""
Uses Gemini Flash to extract shopping list items from image bytes.
Args:
image_bytes: The image content as bytes.
mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
Returns:
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
"""
try:
client = get_gemini_client() # Raises OCRServiceConfigError if not initialized
# Prepare image part for multimodal input
image_part = {
"mime_type": mime_type,
"data": image_bytes
}
# Prepare the full prompt content
prompt_parts = [
settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
image_part # Then the image
]
logger.info("Sending image to Gemini for item extraction...")
# Make the API call
# Use generate_content_async for async FastAPI
response = await client.generate_content_async(prompt_parts)
@@ -122,9 +118,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes) -> List[str]:
finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
if finish_reason == 'SAFETY':
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
else:
raise OCRUnexpectedError()
# Extract text - assumes the first part of the first candidate is the text response
raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text
@@ -145,10 +141,89 @@ async def extract_items_from_image_gemini(image_bytes: bytes) -> List[str]:
except google_exceptions.GoogleAPIError as e:
logger.error(f"Gemini API Error: {e}", exc_info=True)
# Re-raise specific Google API errors for endpoint to handle (e.g., quota)
raise e
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
# Catch other unexpected errors during generation or processing
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
# Wrap in a generic ValueError or re-raise
raise ValueError(f"Failed to process image with Gemini: {e}") from e
# Wrap in a custom exception
raise OCRUnexpectedError()
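The `GoogleAPIError` handler above detects quota exhaustion by a substring match on the error message. A minimal standalone sketch of that mapping (the `GoogleAPIError` class here is a stand-in for `google_exceptions.GoogleAPIError`; the OCR exception names mirror the ones in this diff):

```python
class OCRQuotaExceededError(Exception): ...
class OCRServiceUnavailableError(Exception): ...

class GoogleAPIError(Exception):
    """Stand-in for google.api_core.exceptions.GoogleAPIError."""

def map_api_error(e: GoogleAPIError) -> Exception:
    # Quota exhaustion is detected by substring match on the message;
    # anything else becomes a generic service-unavailable error.
    if "quota" in str(e).lower():
        return OCRQuotaExceededError()
    return OCRServiceUnavailableError()

print(type(map_api_error(GoogleAPIError("429: Quota exceeded"))).__name__)
print(type(map_api_error(GoogleAPIError("500: internal error"))).__name__)
```

Note the substring check is heuristic: a message that happens to contain "quota" in another context would be misclassified, so matching on the exception subclass (e.g. a resource-exhausted error type) would be more robust if the client library exposes one.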
class GeminiOCRService:
def __init__(self):
try:
genai.configure(api_key=settings.GEMINI_API_KEY)
self.model = genai.GenerativeModel(
model_name=settings.GEMINI_MODEL_NAME,
generation_config=genai.types.GenerationConfig(
**settings.GEMINI_GENERATION_CONFIG
)
)
except Exception as e:
logger.error(f"Failed to initialize Gemini client: {e}")
raise OCRServiceConfigError()
async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
"""
Extract shopping list items from an image using Gemini Vision.
Args:
image_data: The image content as bytes.
mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
Returns:
A list of extracted item strings.
Raises:
OCRServiceConfigError: If the Gemini client is not initialized.
OCRQuotaExceededError: If API quota is exceeded.
OCRServiceUnavailableError: For general API call errors.
OCRProcessingError: If the response is blocked or contains no usable text.
OCRUnexpectedError: For any other unexpected errors.
"""
try:
# Create image part
image_parts = [{"mime_type": mime_type, "data": image_data}]
# Generate content
response = await self.model.generate_content_async(
contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
)
# Process response
if not response.text:
logger.warning("Gemini response is empty")
raise OCRUnexpectedError()
# Check for safety blocks
if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'):
finish_reason = response.candidates[0].finish_reason
if finish_reason == 'SAFETY':
safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A'
raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
# Split response into lines and clean up
items = []
for line in response.text.splitlines():
cleaned_line = line.strip()
if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"):
items.append(cleaned_line)
logger.info(f"Extracted {len(items)} potential items.")
return items
except google_exceptions.GoogleAPIError as e:
logger.error(f"Error during OCR extraction: {e}")
if "quota" in str(e).lower():
raise OCRQuotaExceededError()
raise OCRServiceUnavailableError()
except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
# Re-raise specific OCR exceptions
raise
except Exception as e:
logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
raise OCRUnexpectedError()
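The response-cleaning loop in `extract_items` keeps non-trivial lines and drops anything starting with "Example". Extracted into a pure function, the same filtering can be sketched as:

```python
def clean_items(raw_text: str) -> list[str]:
    # Mirrors the filtering in extract_items: strip whitespace, keep lines
    # longer than one character, and drop prompt echoes starting "Example".
    items = []
    for line in raw_text.splitlines():
        cleaned = line.strip()
        if cleaned and len(cleaned) > 1 and not cleaned.startswith("Example"):
            items.append(cleaned)
    return items

print(clean_items("Milk\n Eggs \nExample: Bread\n-\n"))
```

Pulling the loop into a helper like this also makes the heuristics unit-testable without mocking the Gemini client.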

be/app/core/scheduler.py

@@ -0,0 +1,73 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.triggers.cron import CronTrigger
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from app.config import settings
from app.jobs.recurring_expenses import generate_recurring_expenses
from app.db.session import async_session
import logging
logger = logging.getLogger(__name__)
# Convert async database URL to sync URL for APScheduler
# Replace postgresql+asyncpg:// with postgresql://
sync_db_url = settings.DATABASE_URL.replace('postgresql+asyncpg://', 'postgresql://')
# Configure the scheduler
jobstores = {
'default': SQLAlchemyJobStore(url=sync_db_url)
}
executors = {
'default': ThreadPoolExecutor(20)
}
job_defaults = {
'coalesce': False,
'max_instances': 1
}
scheduler = AsyncIOScheduler(
jobstores=jobstores,
executors=executors,
job_defaults=job_defaults,
timezone='UTC'
)
async def run_recurring_expenses_job():
"""Wrapper function to run the recurring expenses job with a database session."""
try:
async with async_session() as session:
await generate_recurring_expenses(session)
except Exception as e:
logger.error(f"Error running recurring expenses job: {str(e)}")
raise
def init_scheduler():
"""Initialize and start the scheduler."""
try:
# Add the recurring expenses job
scheduler.add_job(
run_recurring_expenses_job,
trigger=CronTrigger(hour=0, minute=0), # Run at midnight UTC
id='generate_recurring_expenses',
name='Generate Recurring Expenses',
replace_existing=True
)
# Start the scheduler
scheduler.start()
logger.info("Scheduler started successfully")
except Exception as e:
logger.error(f"Error initializing scheduler: {str(e)}")
raise
def shutdown_scheduler():
"""Shutdown the scheduler gracefully."""
try:
scheduler.shutdown()
logger.info("Scheduler shut down successfully")
except Exception as e:
logger.error(f"Error shutting down scheduler: {str(e)}")
raise
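The URL rewrite at the top of scheduler.py exists because `SQLAlchemyJobStore` runs on a synchronous engine, while the app's `DATABASE_URL` uses the asyncpg driver. The conversion is a plain string replacement:

```python
def to_sync_url(async_url: str) -> str:
    # APScheduler's SQLAlchemyJobStore needs a synchronous driver, so the
    # asyncpg marker is stripped from the async database URL.
    return async_url.replace("postgresql+asyncpg://", "postgresql://")

print(to_sync_url("postgresql+asyncpg://user:pw@db:5432/app"))
```

A URL that already lacks the `+asyncpg` marker passes through unchanged, so the helper is safe to apply unconditionally.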


@@ -8,6 +8,9 @@ from passlib.context import CryptContext
from app.config import settings # Import settings from config
# --- Password Hashing ---
# These functions are used for password hashing and verification
# They complement FastAPI-Users but provide direct access to the underlying password functionality
# when needed outside of the FastAPI-Users authentication flow.
# Configure passlib context
# Using bcrypt as the default hashing scheme
@@ -17,6 +20,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verifies a plain text password against a hashed password.
This is used by FastAPI-Users internally, but also exposed here for custom authentication flows
if needed.
Args:
plain_password: The password attempt.
@@ -34,6 +39,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
def hash_password(password: str) -> str:
"""
Hashes a plain text password using the configured context (bcrypt).
This is used by FastAPI-Users internally, but also exposed here for
custom user creation or password reset flows if needed.
Args:
password: The plain text password to hash.
@@ -45,66 +52,22 @@ def hash_password(password: str) -> str:
# --- JSON Web Tokens (JWT) ---
# FastAPI-Users now handles all JWT token creation and validation.
# The code below is commented out because FastAPI-Users provides these features.
# It's kept for reference in case a custom implementation is needed later.
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Creates a JWT access token.
Args:
subject: The subject of the token (e.g., user ID or email).
expires_delta: Optional timedelta object for token expiry. If None,
uses ACCESS_TOKEN_EXPIRE_MINUTES from settings.
Returns:
The encoded JWT access token string.
"""
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# Data to encode in the token payload
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def verify_access_token(token: str) -> Optional[dict]:
"""
Verifies a JWT access token and returns its payload if valid.
Args:
token: The JWT token string to verify.
Returns:
The decoded token payload (dict) if the token is valid and not expired,
otherwise None.
"""
try:
# Decode the token. This also automatically verifies:
# - Signature (using SECRET_KEY and ALGORITHM)
# - Expiration ('exp' claim)
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
return payload
except JWTError as e:
# Handles InvalidSignatureError, ExpiredSignatureError, etc.
print(f"JWT Error: {e}") # Log the error for debugging
return None
except Exception as e:
# Handle other potential unexpected errors during decoding
print(f"Unexpected error decoding JWT: {e}")
return None
# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# Example of a potential future implementation:
# def get_subject_from_token(token: str) -> Optional[str]:
# payload = verify_access_token(token)
# if payload:
# """
# Extract the subject (user ID) from a JWT token.
# This would be used if we need to validate tokens outside of FastAPI-Users flow.
# For now, use fastapi_users.current_user dependency instead.
# """
# # This would need to use FastAPI-Users' token verification if ever implemented
# # For example, by decoding the token using the strategy from the auth backend
# try:
# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# return payload.get("sub")
# except JWTError:
# return None
# return None


@@ -1,83 +0,0 @@
# be/tests/core/test_gemini.py
import pytest
import os
from unittest.mock import patch, MagicMock
# Store original key if exists, then clear it for testing missing key scenario
original_api_key = os.environ.get("GEMINI_API_KEY")
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
# --- Test Module Import ---
# This forces the module-level initialization code in gemini.py to run
# We need to reload modules because settings might have been cached
from importlib import reload
from app.config import settings as app_settings
from app.core import gemini as gemini_core
# Reload settings first to ensure GEMINI_API_KEY is None initially
reload(app_settings)
# Reload gemini core to trigger initialization logic with potentially missing key
reload(gemini_core)
def test_gemini_initialization_without_key():
"""Verify behavior when GEMINI_API_KEY is not set."""
# Reload modules again to ensure clean state for this specific test
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
reload(app_settings)
reload(gemini_core)
assert gemini_core.gemini_flash_client is None
assert gemini_core.gemini_initialization_error is not None
assert "GEMINI_API_KEY not configured" in gemini_core.gemini_initialization_error
with pytest.raises(RuntimeError, match="GEMINI_API_KEY not configured"):
gemini_core.get_gemini_client()
@patch('google.generativeai.configure')
@patch('google.generativeai.GenerativeModel')
def test_gemini_initialization_with_key(mock_generative_model: MagicMock, mock_configure: MagicMock):
"""Verify initialization logic is called when key is present (using mocks)."""
# Set a dummy key in the environment for this test
test_key = "TEST_API_KEY_123"
os.environ["GEMINI_API_KEY"] = test_key
# Reload settings and gemini module to pick up the new key
reload(app_settings)
reload(gemini_core)
# Assertions
mock_configure.assert_called_once_with(api_key=test_key)
mock_generative_model.assert_called_once_with(
model_name="gemini-1.5-flash-latest",
safety_settings=pytest.ANY, # Check safety settings were passed (ANY allows flexibility)
# generation_config=pytest.ANY # Check if you added default generation config
)
assert gemini_core.gemini_flash_client is not None
assert gemini_core.gemini_initialization_error is None
# Test get_gemini_client() success path
client = gemini_core.get_gemini_client()
assert client is not None # Should return the mocked client instance
# Clean up environment variable after test
if original_api_key:
os.environ["GEMINI_API_KEY"] = original_api_key
else:
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
# Reload modules one last time to restore state for other tests
reload(app_settings)
reload(gemini_core)
# Restore original key after all tests in the module run (if needed)
def teardown_module(module):
if original_api_key:
os.environ["GEMINI_API_KEY"] = original_api_key
else:
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
reload(app_settings)
reload(gemini_core)


@@ -1,86 +0,0 @@
# Example: be/tests/core/test_security.py
import pytest
from datetime import timedelta
from jose import jwt, JWTError
import time
from app.core.security import (
hash_password,
verify_password,
create_access_token,
verify_access_token,
)
from app.config import settings # Import settings for testing JWT config
# --- Password Hashing Tests ---
def test_hash_password_returns_string():
password = "testpassword"
hashed = hash_password(password)
assert isinstance(hashed, str)
assert password != hashed # Ensure it's not plain text
def test_verify_password_correct():
password = "correct_password"
hashed = hash_password(password)
assert verify_password(password, hashed) is True
def test_verify_password_incorrect():
hashed = hash_password("correct_password")
assert verify_password("wrong_password", hashed) is False
def test_verify_password_invalid_hash_format():
# Passlib's verify handles many format errors gracefully
assert verify_password("any_password", "invalid_hash_string") is False
# --- JWT Tests ---
def test_create_access_token():
subject = "testuser@example.com"
token = create_access_token(subject=subject)
assert isinstance(token, str)
# Decode manually for basic check (verification done in verify_access_token tests)
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == subject
assert "exp" in payload
assert isinstance(payload["exp"], int)
def test_verify_access_token_valid():
subject = "test_subject_valid"
token = create_access_token(subject=subject)
payload = verify_access_token(token)
assert payload is not None
assert payload["sub"] == subject
def test_verify_access_token_invalid_signature():
subject = "test_subject_invalid_sig"
token = create_access_token(subject=subject)
# Attempt to verify with a wrong key
wrong_key = settings.SECRET_KEY + "wrong"
with pytest.raises(JWTError): # Decoding with wrong key should raise JWTError internally
jwt.decode(token, wrong_key, algorithms=[settings.ALGORITHM])
# Our verify function should catch this and return None
assert verify_access_token(token + "tamper") is None # Tampering token often invalidates sig
# Note: Testing verify_access_token directly returning None for wrong key is tricky
# as the error happens *during* jwt.decode. We rely on it catching JWTError.
def test_verify_access_token_expired():
# Create a token that expires almost immediately
subject = "test_subject_expired"
expires_delta = timedelta(seconds=-1) # Expired 1 second ago
token = create_access_token(subject=subject, expires_delta=expires_delta)
# Wait briefly just in case of timing issues, though negative delta should guarantee expiry
time.sleep(0.1)
# Decoding expired token raises ExpiredSignatureError internally
with pytest.raises(JWTError): # Specifically ExpiredSignatureError, but JWTError catches it
jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# Our verify function should catch this and return None
assert verify_access_token(token) is None
def test_verify_access_token_malformed():
assert verify_access_token("this.is.not.a.valid.token") is None

be/app/crud/chore.py

@@ -0,0 +1,505 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy import union_all
from typing import List, Optional
import logging
from datetime import date, datetime
from app.models import Chore, Group, User, ChoreAssignment, ChoreFrequencyEnum, ChoreTypeEnum, UserGroup
from app.schemas.chore import ChoreCreate, ChoreUpdate, ChoreAssignmentCreate, ChoreAssignmentUpdate
from app.core.chore_utils import calculate_next_due_date
from app.crud.group import get_group_by_id, is_user_member
from app.core.exceptions import ChoreNotFoundError, GroupNotFoundError, PermissionDeniedError, DatabaseIntegrityError
logger = logging.getLogger(__name__)
async def get_all_user_chores(db: AsyncSession, user_id: int) -> List[Chore]:
"""Gets all chores (personal and group) for a user in optimized queries."""
# Get personal chores query
personal_chores_query = (
select(Chore)
.where(
Chore.created_by_id == user_id,
Chore.type == ChoreTypeEnum.personal
)
)
# Get user's group IDs first
user_groups_result = await db.execute(
select(UserGroup.group_id).where(UserGroup.user_id == user_id)
)
user_group_ids = user_groups_result.scalars().all()
all_chores = []
# Execute personal chores query
personal_result = await db.execute(
personal_chores_query
.options(
selectinload(Chore.creator),
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
)
.order_by(Chore.next_due_date, Chore.name)
)
all_chores.extend(personal_result.scalars().all())
# If user has groups, get all group chores in one query
if user_group_ids:
group_chores_result = await db.execute(
select(Chore)
.where(
Chore.group_id.in_(user_group_ids),
Chore.type == ChoreTypeEnum.group
)
.options(
selectinload(Chore.creator),
selectinload(Chore.group),
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
)
.order_by(Chore.next_due_date, Chore.name)
)
all_chores.extend(group_chores_result.scalars().all())
return all_chores
async def create_chore(
db: AsyncSession,
chore_in: ChoreCreate,
user_id: int,
group_id: Optional[int] = None
) -> Chore:
"""Creates a new chore, either personal or within a specific group."""
# Use the transaction pattern from the FastAPI strategy
async with db.begin_nested() if db.in_transaction() else db.begin():
if chore_in.type == ChoreTypeEnum.group:
if not group_id:
raise ValueError("group_id is required for group chores")
# Validate group existence and user membership
group = await get_group_by_id(db, group_id)
if not group:
raise GroupNotFoundError(group_id)
if not await is_user_member(db, group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
else: # personal chore
if group_id:
raise ValueError("group_id must be None for personal chores")
db_chore = Chore(
**chore_in.model_dump(exclude_unset=True, exclude={'group_id'}),
group_id=group_id,
created_by_id=user_id,
)
# Specific check for custom frequency
if chore_in.frequency == ChoreFrequencyEnum.custom and chore_in.custom_interval_days is None:
raise ValueError("custom_interval_days must be set for custom frequency chores.")
db.add(db_chore)
await db.flush() # Get the ID for the chore
try:
# Load relationships for the response with eager loading
result = await db.execute(
select(Chore)
.where(Chore.id == db_chore.id)
.options(
selectinload(Chore.creator),
selectinload(Chore.group),
selectinload(Chore.assignments)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error creating chore: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not create chore. Error: {str(e)}")
async def get_chore_by_id(db: AsyncSession, chore_id: int) -> Optional[Chore]:
"""Gets a chore by ID."""
result = await db.execute(
select(Chore)
.where(Chore.id == chore_id)
.options(selectinload(Chore.creator), selectinload(Chore.group), selectinload(Chore.assignments))
)
return result.scalar_one_or_none()
async def get_chore_by_id_and_group(
db: AsyncSession,
chore_id: int,
group_id: int,
user_id: int
) -> Optional[Chore]:
"""Gets a specific group chore by ID, ensuring it belongs to the group and user is a member."""
if not await is_user_member(db, group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
chore = await get_chore_by_id(db, chore_id)
if chore and chore.group_id == group_id and chore.type == ChoreTypeEnum.group:
return chore
return None
async def get_personal_chores(
db: AsyncSession,
user_id: int
) -> List[Chore]:
"""Gets all personal chores for a user with optimized eager loading."""
result = await db.execute(
select(Chore)
.where(
Chore.created_by_id == user_id,
Chore.type == ChoreTypeEnum.personal
)
.options(
selectinload(Chore.creator),
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
)
.order_by(Chore.next_due_date, Chore.name)
)
return result.scalars().all()
async def get_chores_by_group_id(
db: AsyncSession,
group_id: int,
user_id: int
) -> List[Chore]:
"""Gets all chores for a specific group with optimized eager loading, if the user is a member."""
if not await is_user_member(db, group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
result = await db.execute(
select(Chore)
.where(
Chore.group_id == group_id,
Chore.type == ChoreTypeEnum.group
)
.options(
selectinload(Chore.creator),
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
)
.order_by(Chore.next_due_date, Chore.name)
)
return result.scalars().all()
async def update_chore(
db: AsyncSession,
chore_id: int,
chore_in: ChoreUpdate,
user_id: int,
group_id: Optional[int] = None
) -> Optional[Chore]:
"""Updates a chore's details using proper transaction management."""
async with db.begin_nested() if db.in_transaction() else db.begin():
db_chore = await get_chore_by_id(db, chore_id)
if not db_chore:
raise ChoreNotFoundError(chore_id, group_id)
# Check permissions
if db_chore.type == ChoreTypeEnum.group:
if not group_id:
raise ValueError("group_id is required for group chores")
if not await is_user_member(db, group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
if db_chore.group_id != group_id:
raise ChoreNotFoundError(chore_id, group_id)
else: # personal chore
if group_id:
raise ValueError("group_id must be None for personal chores")
if db_chore.created_by_id != user_id:
raise PermissionDeniedError(detail="Only the creator can update personal chores")
update_data = chore_in.model_dump(exclude_unset=True)
# Handle type change
if 'type' in update_data:
new_type = update_data['type']
if new_type == ChoreTypeEnum.group and not group_id:
raise ValueError("group_id is required for group chores")
if new_type == ChoreTypeEnum.personal and group_id:
raise ValueError("group_id must be None for personal chores")
# Recalculate next_due_date if needed
recalculate = False
if 'frequency' in update_data and update_data['frequency'] != db_chore.frequency:
recalculate = True
if 'custom_interval_days' in update_data and update_data['custom_interval_days'] != db_chore.custom_interval_days:
recalculate = True
current_next_due_date_for_calc = db_chore.next_due_date
if 'next_due_date' in update_data:
current_next_due_date_for_calc = update_data['next_due_date']
if not ('frequency' in update_data or 'custom_interval_days' in update_data):
recalculate = False
for field, value in update_data.items():
setattr(db_chore, field, value)
if recalculate:
db_chore.next_due_date = calculate_next_due_date(
current_due_date=current_next_due_date_for_calc,
frequency=db_chore.frequency,
custom_interval_days=db_chore.custom_interval_days,
last_completed_date=db_chore.last_completed_at
)
if db_chore.frequency == ChoreFrequencyEnum.custom and db_chore.custom_interval_days is None:
raise ValueError("custom_interval_days must be set for custom frequency chores.")
try:
await db.flush() # Flush changes within the transaction
result = await db.execute(
select(Chore)
.where(Chore.id == db_chore.id)
.options(
selectinload(Chore.creator),
selectinload(Chore.group),
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error updating chore {chore_id}: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not update chore {chore_id}. Error: {str(e)}")
async def delete_chore(
db: AsyncSession,
chore_id: int,
user_id: int,
group_id: Optional[int] = None
) -> bool:
"""Deletes a chore and its assignments using proper transaction management, ensuring user has permission."""
async with db.begin_nested() if db.in_transaction() else db.begin():
db_chore = await get_chore_by_id(db, chore_id)
if not db_chore:
raise ChoreNotFoundError(chore_id, group_id)
# Check permissions
if db_chore.type == ChoreTypeEnum.group:
if not group_id:
raise ValueError("group_id is required for group chores")
if not await is_user_member(db, group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
if db_chore.group_id != group_id:
raise ChoreNotFoundError(chore_id, group_id)
else: # personal chore
if group_id:
raise ValueError("group_id must be None for personal chores")
if db_chore.created_by_id != user_id:
raise PermissionDeniedError(detail="Only the creator can delete personal chores")
try:
await db.delete(db_chore)
await db.flush() # Ensure deletion is processed within the transaction
return True
except Exception as e:
logger.error(f"Error deleting chore {chore_id}: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not delete chore {chore_id}. Error: {str(e)}")
# === CHORE ASSIGNMENT CRUD FUNCTIONS ===
async def create_chore_assignment(
db: AsyncSession,
assignment_in: ChoreAssignmentCreate,
user_id: int
) -> ChoreAssignment:
"""Creates a new chore assignment. User must be able to manage the chore."""
async with db.begin_nested() if db.in_transaction() else db.begin():
# Get the chore and validate permissions
chore = await get_chore_by_id(db, assignment_in.chore_id)
if not chore:
raise ChoreNotFoundError(chore_id=assignment_in.chore_id)
# Check permissions to assign this chore
if chore.type == ChoreTypeEnum.personal:
if chore.created_by_id != user_id:
raise PermissionDeniedError(detail="Only the creator can assign personal chores")
else: # group chore
if not await is_user_member(db, chore.group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")
# For group chores, check if assignee is also a group member
if not await is_user_member(db, chore.group_id, assignment_in.assigned_to_user_id):
raise PermissionDeniedError(detail=f"Cannot assign chore to user {assignment_in.assigned_to_user_id} who is not a group member")
db_assignment = ChoreAssignment(**assignment_in.model_dump(exclude_unset=True))
db.add(db_assignment)
await db.flush() # Get the ID for the assignment
try:
# Load relationships for the response
result = await db.execute(
select(ChoreAssignment)
.where(ChoreAssignment.id == db_assignment.id)
.options(
selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
selectinload(ChoreAssignment.assigned_user)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error creating chore assignment: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not create chore assignment. Error: {str(e)}")
async def get_chore_assignment_by_id(db: AsyncSession, assignment_id: int) -> Optional[ChoreAssignment]:
"""Gets a chore assignment by ID."""
result = await db.execute(
select(ChoreAssignment)
.where(ChoreAssignment.id == assignment_id)
.options(
selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
selectinload(ChoreAssignment.assigned_user)
)
)
return result.scalar_one_or_none()
async def get_user_assignments(
db: AsyncSession,
user_id: int,
include_completed: bool = False
) -> List[ChoreAssignment]:
"""Gets all chore assignments for a user."""
query = select(ChoreAssignment).where(ChoreAssignment.assigned_to_user_id == user_id)
if not include_completed:
query = query.where(ChoreAssignment.is_complete == False)
query = query.options(
selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
selectinload(ChoreAssignment.assigned_user)
).order_by(ChoreAssignment.due_date, ChoreAssignment.id)
result = await db.execute(query)
return result.scalars().all()
async def get_chore_assignments(
db: AsyncSession,
chore_id: int,
user_id: int
) -> List[ChoreAssignment]:
"""Gets all assignments for a specific chore. User must have permission to view the chore."""
chore = await get_chore_by_id(db, chore_id)
if not chore:
raise ChoreNotFoundError(chore_id=chore_id)
# Check permissions
if chore.type == ChoreTypeEnum.personal:
if chore.created_by_id != user_id:
raise PermissionDeniedError(detail="Can only view assignments for own personal chores")
else: # group chore
if not await is_user_member(db, chore.group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")
result = await db.execute(
select(ChoreAssignment)
.where(ChoreAssignment.chore_id == chore_id)
.options(
selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
selectinload(ChoreAssignment.assigned_user)
)
.order_by(ChoreAssignment.due_date, ChoreAssignment.id)
)
return result.scalars().all()
async def update_chore_assignment(
db: AsyncSession,
assignment_id: int,
assignment_in: ChoreAssignmentUpdate,
user_id: int
) -> Optional[ChoreAssignment]:
"""Updates a chore assignment. Only the assignee can mark it complete."""
async with db.begin_nested() if db.in_transaction() else db.begin():
db_assignment = await get_chore_assignment_by_id(db, assignment_id)
if not db_assignment:
raise ChoreNotFoundError(assignment_id=assignment_id)
# Load the chore for permission checking
chore = await get_chore_by_id(db, db_assignment.chore_id)
if not chore:
raise ChoreNotFoundError(chore_id=db_assignment.chore_id)
# Check permissions - only assignee can complete, but chore managers can reschedule
can_manage = False
if chore.type == ChoreTypeEnum.personal:
can_manage = chore.created_by_id == user_id
else: # group chore
can_manage = await is_user_member(db, chore.group_id, user_id)
can_complete = db_assignment.assigned_to_user_id == user_id
update_data = assignment_in.model_dump(exclude_unset=True)
# Check specific permissions for different updates
if 'is_complete' in update_data and not can_complete:
raise PermissionDeniedError(detail="Only the assignee can mark assignments as complete")
if 'due_date' in update_data and not can_manage:
raise PermissionDeniedError(detail="Only chore managers can reschedule assignments")
# Handle completion logic
if 'is_complete' in update_data and update_data['is_complete']:
if not db_assignment.is_complete: # Only if not already complete
update_data['completed_at'] = datetime.utcnow()
# Update parent chore's last_completed_at and recalculate next_due_date
chore.last_completed_at = update_data['completed_at']
chore.next_due_date = calculate_next_due_date(
current_due_date=chore.next_due_date,
frequency=chore.frequency,
custom_interval_days=chore.custom_interval_days,
last_completed_date=chore.last_completed_at
)
elif 'is_complete' in update_data and not update_data['is_complete']:
# If marking as incomplete, clear completed_at
update_data['completed_at'] = None
# Apply updates
for field, value in update_data.items():
setattr(db_assignment, field, value)
try:
await db.flush() # Flush changes within the transaction
# Load relationships for the response
result = await db.execute(
select(ChoreAssignment)
.where(ChoreAssignment.id == db_assignment.id)
.options(
selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
selectinload(ChoreAssignment.assigned_user)
)
)
return result.scalar_one()
except Exception as e:
logger.error(f"Error updating chore assignment {assignment_id}: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not update chore assignment {assignment_id}. Error: {str(e)}")
async def delete_chore_assignment(
db: AsyncSession,
assignment_id: int,
user_id: int
) -> bool:
"""Deletes a chore assignment. User must have permission to manage the chore."""
async with db.begin_nested() if db.in_transaction() else db.begin():
db_assignment = await get_chore_assignment_by_id(db, assignment_id)
if not db_assignment:
raise ChoreNotFoundError(assignment_id=assignment_id)
# Load the chore for permission checking
chore = await get_chore_by_id(db, db_assignment.chore_id)
if not chore:
raise ChoreNotFoundError(chore_id=db_assignment.chore_id)
# Check permissions
if chore.type == ChoreTypeEnum.personal:
if chore.created_by_id != user_id:
raise PermissionDeniedError(detail="Only the creator can delete personal chore assignments")
else: # group chore
if not await is_user_member(db, chore.group_id, user_id):
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")
try:
await db.delete(db_assignment)
await db.flush() # Ensure deletion is processed within the transaction
return True
except Exception as e:
logger.error(f"Error deleting chore assignment {assignment_id}: {e}", exc_info=True)
raise DatabaseIntegrityError(f"Could not delete chore assignment {assignment_id}. Error: {str(e)}")


@@ -1,88 +1,699 @@
# app/crud/expense.py
import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, Any
from collections import defaultdict  # defaultdict lives in collections, not typing
from datetime import datetime, timezone # Added timezone
from app.models import (
Item as ItemModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
ExpenseRecord as ExpenseRecordModel,
ExpenseShare as ExpenseShareModel,
SettlementActivity as SettlementActivityModel,
SplitTypeEnum,
ExpenseOverallStatusEnum, # Added
ExpenseSplitStatusEnum, # Added
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate
from app.core.exceptions import (
# Using existing specific exceptions where possible
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError, # Import the new exception
DatabaseConnectionError, # Added
DatabaseIntegrityError, # Added
DatabaseQueryError, # Added
DatabaseTransactionError,# Added
ExpenseOperationError # Added specific exception
)
from app.models import RecurrencePattern
async def get_priced_items_for_list(db: AsyncSession, list_id: int) -> Sequence[ItemModel]:
result = await db.execute(select(ItemModel).where(ItemModel.list_id == list_id, ItemModel.price.is_not(None)))
return result.scalars().all()
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
# This should be a proper HTTPException subclass if used in API layer
# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic
# pass
async def get_group_member_ids(db: AsyncSession, group_id: int) -> PyList[int]:
result = await db.execute(select(UserGroupModel.user_id).where(UserGroupModel.group_id == group_id))
return result.scalars().all()
logger = logging.getLogger(__name__) # Initialize logger
def _round_money(amount: Decimal) -> Decimal:
"""Rounds a Decimal to two decimal places using ROUND_HALF_UP."""
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
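`_round_money` pins every monetary rounding in this module to `ROUND_HALF_UP`. That choice matters: Python's built-in `round()` uses banker's rounding (ties to even), so `Decimal.quantize` is the reliable way to get the "half goes up" behavior usually expected for money. A standalone illustration:

```python
from decimal import Decimal, ROUND_HALF_UP

def round_money(amount: Decimal) -> Decimal:
    """Round to two decimal places; exact halves round up (10.005 -> 10.01)."""
    return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

print(round_money(Decimal("10.005")))  # 10.01
print(round_money(Decimal("2.344")))   # 2.34
# Note: round(10.005, 2) on floats is unreliable because 10.005 has no exact
# binary representation; always construct Decimals from strings.
```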
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
"""
Determines the list of users an expense should be split amongst.
Priority: Group members (if group_id), then List's group members or creator (if list_id).
Fallback to only the payer if no other context yields users.
"""
users_to_split_with: PyList[UserModel] = []
processed_user_ids = set()
async def _add_user(user: Optional[UserModel]):
if user and user.id not in processed_user_ids:
users_to_split_with.append(user)
processed_user_ids.add(user.id)
if expense_group_id:
group_result = await db.execute(
select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
.where(GroupModel.id == expense_group_id)
)
group = group_result.scalars().first()
if not group:
raise GroupNotFoundError(expense_group_id)
for assoc in group.member_associations:
await _add_user(assoc.user)
elif expense_list_id: # Only if group_id was not primary context
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == expense_list_id)
)
db_list = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(expense_list_id)
if db_list.group:
for assoc in db_list.group.member_associations:
await _add_user(assoc.user)
elif db_list.creator:
await _add_user(db_list.creator)
if not users_to_split_with:
payer_user = await db.get(UserModel, expense_paid_by_user_id)
if not payer_user:
# This should have been caught earlier if paid_by_user_id was validated before calling this helper
raise UserNotFoundError(user_id=expense_paid_by_user_id)
await _add_user(payer_user)
if not users_to_split_with:
# This should ideally not be reached if payer is always a fallback
raise InvalidOperationError("Could not determine any users for splitting the expense.")
return users_to_split_with
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
"""Creates a new expense and its associated splits.
Args:
db: Database session
expense_in: Expense creation data
current_user_id: ID of the user creating the expense
Returns:
The created expense with splits
Raises:
UserNotFoundError: If payer or split users don't exist
ListNotFoundError: If specified list doesn't exist
GroupNotFoundError: If specified group doesn't exist
InvalidOperationError: For various validation failures
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
# 1. Validate payer
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")
# 2. Context Resolution and Validation (now part of the transaction)
if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
final_group_id = await _resolve_expense_context(db, expense_in)
# Further validation for item_id if provided
db_item_instance = None
if expense_in.item_id:
db_item_instance = await db.get(ItemModel, expense_in.item_id)
if not db_item_instance:
raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
# Potentially link item's list/group if not already set on expense_in
if db_item_instance.list_id and not expense_in.list_id:
expense_in.list_id = db_item_instance.list_id
# Re-resolve context if list_id was derived from item
final_group_id = await _resolve_expense_context(db, expense_in)
# Create recurrence pattern if this is a recurring expense
recurrence_pattern = None
if expense_in.is_recurring and expense_in.recurrence_pattern:
recurrence_pattern = RecurrencePattern(
type=expense_in.recurrence_pattern.type,
interval=expense_in.recurrence_pattern.interval,
days_of_week=expense_in.recurrence_pattern.days_of_week,
end_date=expense_in.recurrence_pattern.end_date,
max_occurrences=expense_in.recurrence_pattern.max_occurrences,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
db.add(recurrence_pattern)
await db.flush()
# 3. Create the ExpenseModel instance
db_expense = ExpenseModel(
description=expense_in.description,
total_amount=_round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id, # Use resolved group_id
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id,
overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
is_recurring=expense_in.is_recurring,
recurrence_pattern=recurrence_pattern,
next_occurrence=expense_in.expense_date if expense_in.is_recurring else None
)
db.add(db_expense)
await db.flush() # Get expense ID
# 4. Generate splits (passing current_user_id through kwargs if needed by specific split types)
splits_to_create = await _generate_expense_splits(
db=db,
expense_model=db_expense,
expense_in=expense_in,
current_user_id=current_user_id # Pass for item-based splits needing creator info
)
for split_model in splits_to_create:
split_model.expense_id = db_expense.id # Set FK after db_expense has ID
db.add_all(splits_to_create)
await db.flush() # Persist splits
# 5. Re-fetch the expense with all necessary relationships for the response
stmt = (
select(ExpenseModel)
.where(ExpenseModel.id == db_expense.id)
.options(
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.created_by_user), # If you have this relationship
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group),
selectinload(ExpenseModel.item),
selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.settlement_activities)
)
)
result = await db.execute(stmt)
loaded_expense = result.scalar_one_or_none()
if loaded_expense is None:
# The context manager will handle rollback if an exception is raised.
# await transaction.rollback() # Should be handled by context manager
raise ExpenseOperationError("Failed to load expense after creation.")
# await transaction.commit() # Explicit commit removed, context manager handles it.
return loaded_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are business logic validation errors, re-raise them.
# If a transaction was started, the context manager handles rollback.
raise
except IntegrityError as e:
# Context manager handles rollback.
logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}")
except SQLAlchemyError as e:
# Context manager handles rollback.
logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}")
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
"""Resolves and validates the expense's context (list and group).
Returns the final group_id for the expense after validation.
"""
final_group_id = expense_in.group_id
# If list_id is provided, validate it and potentially derive group_id
if expense_in.list_id:
list_obj = await db.get(ListModel, expense_in.list_id)
if not list_obj:
raise ListNotFoundError(expense_in.list_id)
# If list belongs to a group, verify consistency or inherit group_id
if list_obj.group_id:
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
raise InvalidOperationError(
f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
f"but expense was specified for group {expense_in.group_id}."
)
final_group_id = list_obj.group_id # Prioritize list's group
# If only group_id is provided (no list_id), validate group_id
elif final_group_id:
group_obj = await db.get(GroupModel, final_group_id)
if not group_obj:
raise GroupNotFoundError(final_group_id)
return final_group_id
async def _generate_expense_splits(
db: AsyncSession,
expense_model: ExpenseModel,
expense_in: ExpenseCreate,
**kwargs: Any
) -> PyList[ExpenseSplitModel]:
"""Generates appropriate expense splits based on split type."""
splits_to_create: PyList[ExpenseSplitModel] = []
# Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based)
common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs}
# Create splits based on the split type
if expense_in.split_type == SplitTypeEnum.EQUAL:
splits_to_create = await _create_equal_splits(**common_args)
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
splits_to_create = await _create_exact_amount_splits(**common_args)
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
splits_to_create = await _create_percentage_splits(**common_args)
elif expense_in.split_type == SplitTypeEnum.SHARES:
splits_to_create = await _create_shares_splits(**common_args)
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
splits_to_create = await _create_item_based_splits(**common_args)
else:
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
if not splits_to_create:
raise InvalidOperationError("No expense splits were generated.")
return splits_to_create
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
"""Creates equal splits among users."""
users_for_splitting = await get_users_for_splitting(
db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
)
if not users_for_splitting:
raise InvalidOperationError("No users found for EQUAL split.")
num_users = len(users_for_splitting)
amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
remainder = expense_model.total_amount - (amount_per_user * num_users)
splits = []
for i, user in enumerate(users_for_splitting):
split_amount = amount_per_user
if i == 0 and remainder != Decimal('0'):
split_amount = round_money_func(amount_per_user + remainder)
splits.append(ExpenseSplitModel(
user_id=user.id,
owed_amount=split_amount,
status=ExpenseSplitStatusEnum.unpaid # Explicitly set default status
))
return splits
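The equal-split branch above rounds each share and then assigns the leftover cent(s) to the first participant so the shares sum exactly to the expense total. A minimal arithmetic sketch of that strategy (plain functions, not the ORM code):

```python
from decimal import Decimal, ROUND_HALF_UP

def equal_split(total: Decimal, num_users: int) -> list[Decimal]:
    """Split `total` evenly across `num_users`; the first share absorbs any
    rounding remainder so the list always sums to `total`."""
    per_user = (total / Decimal(num_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    shares = [per_user] * num_users
    shares[0] = per_user + (total - per_user * num_users)  # fold remainder into first share
    return shares

print(equal_split(Decimal("100.00"), 3))  # [Decimal('33.34'), Decimal('33.33'), Decimal('33.33')]
```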
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
"""Creates splits with exact amounts."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
current_total = Decimal("0.00")
splits = []
for split_in in expense_in.splits_in:
if split_in.owed_amount <= Decimal('0'):
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
rounded_amount = round_money_func(split_in.owed_amount)
current_total += rounded_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=rounded_amount,
status=ExpenseSplitStatusEnum.unpaid # Explicitly set default status
))
if round_money_func(current_total) != expense_model.total_amount:
raise InvalidOperationError(
f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
)
return splits
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
"""Creates splits based on percentages."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
total_percentage = Decimal("0.00")
current_total = Decimal("0.00")
splits = []
for split_in in expense_in.splits_in:
if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")):
raise InvalidOperationError(
f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}."
)
total_percentage += split_in.share_percentage
owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100")))
current_total += owed_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=owed_amount,
share_percentage=split_in.share_percentage,
status=ExpenseSplitStatusEnum.unpaid # Explicitly set default status
))
if round_money_func(total_percentage) != Decimal("100.00"):
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
# Adjust for rounding differences
if current_total != expense_model.total_amount and splits:
diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
return splits
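Percentage splits can drift from the expense total by a cent or two after per-share rounding; the code above compensates by pushing the difference into the last split. A sketch of the same idea, assuming the percentages already sum to 100:

```python
from decimal import Decimal, ROUND_HALF_UP

def percentage_split(total: Decimal, percentages: list[Decimal]) -> list[Decimal]:
    """Round each percentage share, then fold any rounding drift into the
    last share so the result sums exactly to `total`."""
    def q(a: Decimal) -> Decimal:
        return a.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    shares = [q(total * p / Decimal("100")) for p in percentages]
    drift = total - sum(shares)
    if drift:
        shares[-1] = q(shares[-1] + drift)
    return shares

# A 33.33/33.33/33.34 split of 10.00 needs a one-cent adjustment:
print(percentage_split(Decimal("10.00"), [Decimal("33.33"), Decimal("33.33"), Decimal("33.34")]))
```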
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
"""Creates splits based on shares."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for SHARES split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
# Calculate total shares
total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0)
if total_shares == 0:
raise InvalidOperationError("Total shares cannot be zero for SHARES split.")
splits = []
current_total = Decimal("0.00")
for split_in in expense_in.splits_in:
if not (split_in.share_units and split_in.share_units > 0):
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
owed_amount = round_money_func(expense_model.total_amount * share_ratio)
current_total += owed_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=owed_amount,
share_units=split_in.share_units,
status=ExpenseSplitStatusEnum.unpaid # Explicitly set default status
))
# Adjust for rounding differences
if current_total != expense_model.total_amount and splits:
diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
return splits
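Share-unit splits are the same pattern with a ratio instead of a percentage: each participant owes `total * units / total_units`, and the last split again absorbs rounding drift. A sketch:

```python
from decimal import Decimal, ROUND_HALF_UP

def shares_split(total: Decimal, units: list[int]) -> list[Decimal]:
    """Split `total` proportionally to integer share units; the last share
    absorbs any rounding drift."""
    total_units = sum(units)
    def q(a: Decimal) -> Decimal:
        return a.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    shares = [q(total * Decimal(u) / Decimal(total_units)) for u in units]
    drift = total - sum(shares)
    if drift:
        shares[-1] = q(shares[-1] + drift)
    return shares

print(shares_split(Decimal("50.00"), [2, 1, 1]))  # 25.00 / 12.50 / 12.50
```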
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
"""Creates splits based on items in a shopping list."""
if not expense_model.list_id:
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
if expense_in.splits_in:
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
# Build query to fetch relevant items
items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
if expense_model.item_id:
items_query = items_query.where(ItemModel.id == expense_model.item_id)
else:
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
# Load items with their adders
items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
relevant_items = items_result.scalars().all()
if not relevant_items:
error_msg = (
f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
if expense_model.item_id else
f"List {expense_model.list_id} has no priced items to base the expense on."
)
raise InvalidOperationError(error_msg)
# Aggregate owed amounts by user
calculated_total = Decimal("0.00")
user_owed_amounts = defaultdict(Decimal)
processed_items = 0
for item in relevant_items:
if item.price is None or item.price <= Decimal("0"):
if expense_model.item_id:
raise InvalidOperationError(
f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
)
continue
if not item.added_by_user:
logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
calculated_total += item.price
user_owed_amounts[item.added_by_user.id] += item.price
processed_items += 1
if processed_items == 0:
raise InvalidOperationError(
f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
)
# Validate total matches calculated total
if round_money_func(calculated_total) != expense_model.total_amount:
raise InvalidOperationError(
f"Expense total amount ({expense_model.total_amount}) does not match the "
f"calculated total from item prices ({calculated_total})."
)
# Create splits based on aggregated amounts
splits = []
for user_id, owed_amount in user_owed_amounts.items():
splits.append(ExpenseSplitModel(
user_id=user_id,
owed_amount=round_money_func(owed_amount),
status=ExpenseSplitStatusEnum.unpaid # Explicitly set default status
))
return splits
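The item-based branch aggregates item prices per adder with a `defaultdict(Decimal)`, so each user owes exactly the sum of the items they put on the list. In isolation, with hypothetical `(price, added_by_user_id)` pairs standing in for `ItemModel` rows:

```python
from collections import defaultdict
from decimal import Decimal

# Hypothetical rows: (price, added_by_user_id)
items = [(Decimal("4.50"), 1), (Decimal("2.00"), 2), (Decimal("3.25"), 1)]

owed = defaultdict(Decimal)  # missing keys start at Decimal('0')
for price, user_id in items:
    owed[user_id] += price

print(dict(owed))  # {1: Decimal('7.75'), 2: Decimal('2.00')}
```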
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
"""Validates that all users in the splits exist."""
user_ids_in_split = [s.user_id for s in splits_in]
user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
found_user_ids = {row[0] for row in user_results}
if len(found_user_ids) != len(user_ids_in_split):
missing_user_ids = set(user_ids_in_split) - found_user_ids
raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.options(
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities)),
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group),
selectinload(ExpenseModel.item)
)
.where(ExpenseModel.id == expense_id)
)
return result.scalars().first()
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.list_id == list_id)
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities))
)
)
return result.scalars().all()
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities))
)
)
return result.scalars().all()
async def get_user_accessible_expenses(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
"""
Get all expenses that a user has access to:
- Expenses they paid for
- Expenses in groups they are members of
- Expenses for lists they have access to
"""
from app.models import UserGroup as UserGroupModel, List as ListModel # Import here to avoid circular imports
# Build the query for accessible expenses
# 1. Expenses paid by the user
paid_by_condition = ExpenseModel.paid_by_user_id == user_id
# 2. Expenses in groups where user is a member
group_member_subquery = select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
group_expenses_condition = ExpenseModel.group_id.in_(group_member_subquery)
# 3. Expenses for lists where user is creator or has access (simplified to creator for now)
user_lists_subquery = select(ListModel.id).where(ListModel.created_by_id == user_id)
list_expenses_condition = ExpenseModel.list_id.in_(user_lists_subquery)
# Combine all conditions with OR
combined_condition = paid_by_condition | group_expenses_condition | list_expenses_condition
result = await db.execute(
select(ExpenseModel)
.where(combined_condition)
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities)),
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group)
)
)
return result.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
"""
Updates an existing expense.
Only allows updates to description, currency, and expense_date to avoid split complexities.
Requires version matching for optimistic locking.
"""
if expense_db.version != expense_in.version:
raise InvalidOperationError(
f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
)
update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data
# Fields that are safe to update without affecting splits or core logic
allowed_to_update = {"description", "currency", "expense_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(expense_db, field, value)
updated_something = True
else:
# If any other field is present in the update payload, it's an invalid operation for this simple update
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.")
if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update):
# No updatable fields were provided (e.g. only `version` was sent); the version
# still increments below if it matched, acting as a no-op touch.
pass  # Alternatively: raise InvalidOperationError("No updatable fields provided.")
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
# db.add(expense_db) # Not strictly necessary as expense_db is already tracked by the session
await db.flush() # Persist changes to the DB and run constraints
await db.refresh(expense_db) # Refresh the object from the DB
return expense_db
except InvalidOperationError: # Re-raise validation errors to be handled by the caller
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
# No generic Exception catch here, let other unexpected errors propagate if not SQLAlchemy related.
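`update_expense` guards concurrent edits with optimistic locking: the client echoes back the `version` it read, the write is rejected if the row has moved on, and a successful write bumps the version. Stripped of the ORM, the protocol looks like this (the class names below are illustrative, not the real models):

```python
from dataclasses import dataclass

class StaleVersionError(Exception):
    """Client's version no longer matches the stored row; refresh and retry."""

@dataclass
class ExpenseRow:  # illustrative stand-in for the ORM model
    id: int
    description: str
    version: int

def apply_update(row: ExpenseRow, new_description: str, client_version: int) -> ExpenseRow:
    """Optimistic locking: reject the write if the row changed since the client
    read it; otherwise apply the change and bump the version."""
    if row.version != client_version:
        raise StaleVersionError(
            f"Expense {row.id}: client has version {client_version}, "
            f"current is {row.version}."
        )
    row.description = new_description
    row.version += 1
    return row

row = ExpenseRow(id=1, description="Groceries", version=3)
apply_update(row, "Groceries (June)", client_version=3)  # succeeds, version -> 4
```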
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
"""
Deletes an expense. Requires version matching if expected_version is provided.
Associated ExpenseSplits are cascade deleted by the database foreign key constraint.
"""
if expected_version is not None and expense_db.version != expected_version:
raise InvalidOperationError(
f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(expense_db)
await db.flush() # Ensure the delete operation is sent to the database
except InvalidOperationError: # Re-raise validation errors
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
return None
# Note: The InvalidOperationError is a simple ValueError placeholder.
# For API endpoints, these should be translated to appropriate HTTPExceptions.
# Ensure app.core.exceptions has proper HTTP error classes if needed.


@@ -2,67 +2,95 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List
from sqlalchemy import delete, func
import logging # Add logging import
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate
from app.models import UserRoleEnum # Import enum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError # Import GroupPermissionError
)
# --- Keep existing functions: get_user_by_email, create_user ---
# (These are actually user CRUD, should ideally be in user.py, but keep for now if working)
from app.core.security import hash_password
from app.schemas.user import UserCreate # Assuming create_user uses this
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
result = await db.execute(select(UserModel).filter(UserModel.email == email))
return result.scalars().first()
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
_hashed_password = hash_password(user_in.password)
db_user = UserModel(
email=user_in.email,
password_hash=_hashed_password, # Use correct keyword argument
name=user_in.name
)
db.add(db_user)
await db.commit()
await db.refresh(db_user)
return db_user
# --- End User CRUD ---
logger = logging.getLogger(__name__) # Initialize logger
# --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
"""Creates a group and adds the creator as the owner."""
try:
# Use the composability pattern for transactions as per fastapi-db-strategy.
# This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
# or starts a new transaction if called outside of one (e.g., from a script).
async with db.begin_nested() if db.in_transaction() else db.begin():
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
db.add(db_group)
await db.flush() # Assigns db_group.id, needed for the UserGroup row
# Add creator as owner
db_user_group = UserGroupModel(
user_id=creator_id,
group_id=db_group.id,
role=UserRoleEnum.owner
)
db.add(db_user_group)
await db.flush() # Sends the INSERT for the user-group link; flush does not commit
# After creation and linking, explicitly load the group with its member associations and users
stmt = (
select(GroupModel)
.where(GroupModel.id == db_group.id)
.options(
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
)
)
result = await db.execute(stmt)
loaded_group = result.scalar_one_or_none()
if loaded_group is None:
# This should not happen if we just created it, but as a safeguard
raise GroupOperationError("Failed to load group after creation.")
return loaded_group
except IntegrityError as e:
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
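The `begin_nested() if in_transaction() else begin()` expression used throughout these CRUD functions can be isolated into a small sketch. `FakeSession` below is a hypothetical stand-in used only to show which branch runs, not SQLAlchemy's real API surface (and sync rather than async, for brevity):

```python
class _Scope:
    """Trivial context manager standing in for a transaction/savepoint."""
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

class FakeSession:
    def __init__(self, already_in_transaction: bool):
        self._in_tx = already_in_transaction
        self.calls: list[str] = []

    def in_transaction(self) -> bool:
        return self._in_tx

    def begin(self) -> _Scope:
        self.calls.append("begin")          # outermost transaction
        return _Scope()

    def begin_nested(self) -> _Scope:
        self.calls.append("begin_nested")   # SAVEPOINT inside an open transaction
        return _Scope()

def open_scope(db: FakeSession) -> _Scope:
    # Same conditional shape as in create_group:
    return db.begin_nested() if db.in_transaction() else db.begin()
```

Called from inside `get_transactional_session` the function composes via a savepoint; called from a script it opens its own transaction.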
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
"""Gets all groups a user is a member of, with optimized eager loading."""
try:
result = await db.execute(
select(GroupModel)
.join(UserGroupModel)
.where(UserGroupModel.user_id == user_id)
.options(
selectinload(GroupModel.member_associations).options(
selectinload(UserGroupModel.user)
)
)
)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query user groups: {str(e)}")
async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]:
"""Gets a single group by its ID, optionally loading members."""
# Use selectinload to eager load members and their user details
try:
result = await db.execute(
select(GroupModel)
.where(GroupModel.id == group_id)
)
)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query group: {str(e)}")
async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool:
"""Checks if a user is a member of a specific group."""
try:
result = await db.execute(
select(UserGroupModel.id) # Select a single column for the existence check
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
.limit(1)
)
return result.scalar_one_or_none() is not None
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]:
"""Gets the role of a user in a specific group."""
try:
result = await db.execute(
select(UserGroupModel.role)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
)
return result.scalar_one_or_none() # None if not a member, else the UserRoleEnum value
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query user role: {str(e)}")
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
"""Adds a user to a group if they aren't already a member."""
try:
# Check if user is already a member before starting a transaction
existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
existing_result = await db.execute(existing_stmt)
if existing_result.scalar_one_or_none():
return None
# Use a single transaction
async with db.begin_nested() if db.in_transaction() else db.begin():
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
db.add(db_user_group)
await db.flush() # Assigns ID to db_user_group
# Eagerly load the 'user' and 'group' relationships for the response
stmt = (
select(UserGroupModel)
.where(UserGroupModel.id == db_user_group.id)
.options(
selectinload(UserGroupModel.user),
selectinload(UserGroupModel.group)
)
)
result = await db.execute(stmt)
loaded_user_group = result.scalar_one_or_none()
if loaded_user_group is None:
raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")
return loaded_user_group
except IntegrityError as e:
logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
"""Removes a user from a group."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
result = await db.execute(
delete(UserGroupModel)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
.returning(UserGroupModel.id) # Lets us check whether a row was actually deleted
)
return result.scalar_one_or_none() is not None # True if a row was deleted
except OperationalError as e:
logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
"""Counts the number of members in a group."""
try:
result = await db.execute(
select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id)
)
return result.scalar_one()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to count group members: {str(e)}")
async def check_group_membership(
db: AsyncSession,
group_id: int,
user_id: int,
action: str = "access this group"
) -> None:
"""
Checks if a user is a member of a group. Raises exceptions if not found or not a member.
Raises:
GroupNotFoundError: If the group_id does not exist.
GroupMembershipError: If the user_id is not a member of the group.
"""
try:
# Check group existence first
group_exists = await db.get(GroupModel, group_id)
if not group_exists:
raise GroupNotFoundError(group_id)
# Check membership
membership = await db.execute(
select(UserGroupModel.id)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
.limit(1)
)
if membership.scalar_one_or_none() is None:
raise GroupMembershipError(group_id, action=action)
# If we reach here, the user is a member
return None
except GroupNotFoundError: # Re-raise specific errors
raise
except GroupMembershipError:
raise
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
async def check_user_role_in_group(
db: AsyncSession,
group_id: int,
user_id: int,
required_role: UserRoleEnum,
action: str = "perform this action"
) -> None:
"""
Checks if a user is a member of a group and has the required role (or higher).
Raises:
GroupNotFoundError: If the group_id does not exist.
GroupMembershipError: If the user_id is not a member of the group.
GroupPermissionError: If the user does not have the required role.
"""
# First, ensure user is a member (this also checks group existence)
await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")
# Get the user's actual role
actual_role = await get_user_role_in_group(db, group_id, user_id)
# Define role hierarchy (assuming owner > member)
role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}
if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0):
raise GroupPermissionError(
group_id=group_id,
action=f"{action} (requires at least '{required_role.value}' role)"
)
# If role is sufficient, return None
return None
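The numeric hierarchy in `check_user_role_in_group` reduces to a small pure function. `Role` below is a local stand-in for `UserRoleEnum`, not the app's actual enum:

```python
from enum import Enum
from typing import Optional

class Role(str, Enum):
    owner = "owner"
    member = "member"

# Higher rank implies all lower-ranked permissions (owner > member).
ROLE_RANK = {Role.owner: 2, Role.member: 1}

def has_required_role(actual: Optional[Role], required: Role) -> bool:
    # A missing membership (None) or an unknown role ranks as 0 and always fails.
    return ROLE_RANK.get(actual, 0) >= ROLE_RANK.get(required, 0)
```

Adding a new role (e.g. an admin between owner and member) only requires extending `ROLE_RANK`, which is the point of encoding the hierarchy as numbers rather than as pairwise comparisons.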
async def delete_group(db: AsyncSession, group_id: int) -> None:
"""
Deletes a group and all its associated data (members, invites, lists, etc.).
The cascade delete in the models will handle the deletion of related records.
Raises:
GroupNotFoundError: If the group doesn't exist.
DatabaseError: If there's an error during deletion.
"""
try:
# Get the group first to ensure it exists
group = await get_group_by_id(db, group_id)
if not group:
raise GroupNotFoundError(group_id)
# Delete the group - cascading delete will handle related records
await db.delete(group)
await db.flush()
logger.info(f"Group {group_id} deleted successfully")
except OperationalError as e:
logger.error(f"Database connection error while deleting group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while deleting group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete group: {str(e)}")

View File

# app/crud/invite.py
import logging # Add logging import
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete # Import delete statement
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Related models for selectinload
from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
InviteOperationError # Add new specific exception
)
logger = logging.getLogger(__name__) # Initialize logger
# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
"""Deactivates all currently active invite codes for a specific group."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
stmt = (
select(InviteModel)
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
)
result = await db.execute(stmt)
active_invites = result.scalars().all()
if not active_invites:
return # No active invites to deactivate
for invite in active_invites:
invite.is_active = False
db.add(invite)
# No explicit flush or commit here; the caller's transaction handles persistence.
except OperationalError as e:
logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
"""Creates a new invite code for a group, deactivating any existing active ones for that group first."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Deactivate existing active invites for this group
await deactivate_all_active_invites_for_group(db, group_id)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
potential_code = None
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
potential_code = secrets.token_urlsafe(16)
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
existing_result = await db.execute(existing_check_stmt)
if existing_result.scalar_one_or_none() is None:
break
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
raise InviteOperationError("Invite code collision detected just before creation attempt.")
db_invite = InviteModel(
code=potential_code,
group_id=group_id,
created_by_id=creator_id,
expires_at=expires_at,
is_active=True
)
db.add(db_invite)
await db.flush()
stmt = (
select(InviteModel)
.where(InviteModel.id == db_invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
loaded_invite = result.scalar_one_or_none()
if loaded_invite is None:
raise InviteOperationError("Failed to load invite after creation and flush.")
return loaded_invite
except InviteOperationError: # Already specific, re-raise
raise
except IntegrityError as e:
logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
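The collision-retry loop inside `create_invite` is independent of the database. A standalone sketch, with `is_taken` as any caller-supplied predicate (in the real code, the active-invite existence query):

```python
import secrets
from typing import Callable

MAX_CODE_GENERATION_ATTEMPTS = 5

def generate_unique_code(is_taken: Callable[[str], bool], n_bytes: int = 16) -> str:
    """Draw random URL-safe codes until one is free, mirroring create_invite."""
    for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
        code = secrets.token_urlsafe(n_bytes)
        if not is_taken(code):
            return code
    # Mirrors InviteOperationError in the real code.
    raise RuntimeError("Failed to generate a unique invite code after several attempts.")
```

With 16 bytes of entropy a collision among active invites is vanishingly unlikely, so the retry cap exists purely as a safety valve.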
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
"""Gets the currently active and non-expired invite for a specific group."""
now = datetime.now(timezone.utc)
try:
stmt = (
select(InviteModel).where(
InviteModel.group_id == group_id,
InviteModel.is_active == True,
InviteModel.expires_at > now # Still respect expiry, even if very long
)
.order_by(InviteModel.created_at.desc()) # Get the most recent one if multiple (should not happen)
.limit(1)
.options(
selectinload(InviteModel.group), # Eager load group
selectinload(InviteModel.creator) # Eager load creator
)
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
"""Gets an active and non-expired invite by its code."""
now = datetime.now(timezone.utc)
try:
stmt = (
select(InviteModel).where(
InviteModel.code == code,
InviteModel.is_active == True,
InviteModel.expires_at > now
)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}")
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
"""Marks an invite as inactive (used)."""
"""Marks an invite as inactive (used) and reloads with relationships."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
invite.is_active = False
db.add(invite) # Add to session to track change
await db.flush() # Persist is_active change
# Re-fetch with relationships
stmt = (
select(InviteModel)
.where(InviteModel.id == invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
updated_invite = result.scalar_one_or_none()
if updated_invite is None: # Should not happen as invite is passed in
raise InviteOperationError("Failed to load invite after deactivation.")
return updated_invite
except OperationalError as e:
logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass
# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...

View File

# app/crud/item.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from datetime import datetime, timezone
from app.models import Item as ItemModel
import logging # Add logging import
from sqlalchemy import func
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
from app.schemas.item import ItemCreate, ItemUpdate
from app.core.exceptions import (
ItemNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError,
ItemOperationError # Add if specific item operation errors are needed
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
"""Creates a new item record for a specific list."""
"""Creates a new item record for a specific list, setting its position."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
# Get the current max position in the list
max_pos_stmt = select(func.max(ItemModel.position)).where(ItemModel.list_id == list_id)
max_pos_result = await db.execute(max_pos_stmt)
max_pos = max_pos_result.scalar_one_or_none() or 0
db_item = ItemModel(
name=item_in.name,
quantity=item_in.quantity,
list_id=list_id,
added_by_id=user_id,
is_complete=False,
position=max_pos + 1 # Set the new position
)
db.add(db_item)
await db.flush() # Assigns ID
# Re-fetch with relationships
stmt = (
select(ItemModel)
.where(ItemModel.id == db_item.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user) # Will be None but loads relationship
)
)
result = await db.execute(stmt)
loaded_item = result.scalar_one_or_none()
if loaded_item is None:
# await transaction.rollback() # Redundant, context manager handles rollback on exception
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
return loaded_item
except IntegrityError as e:
logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
# and context manager handles rollback.
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
"""Gets all items belonging to a specific list, ordered by creation time."""
try:
stmt = (
select(ItemModel)
.where(ItemModel.list_id == list_id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user)
)
.order_by(ItemModel.created_at.asc())
)
result = await db.execute(stmt)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query items: {str(e)}")
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
"""Gets a single item by its ID."""
try:
stmt = (
select(ItemModel)
.where(ItemModel.id == item_id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list) # Often useful to get the parent list
)
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query item: {str(e)}")
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
"""Updates an existing item record, checking for version conflicts and handling reordering."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if item_db.version != item_in.version:
raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})
# --- Handle Reordering ---
if 'position' in update_data:
new_position = update_data.pop('position') # Remove from update_data to handle separately
# We need the full list to reorder, making sure it's loaded and ordered
list_id = item_db.list_id
stmt = select(ItemModel).where(ItemModel.list_id == list_id).order_by(ItemModel.position.asc(), ItemModel.created_at.asc())
result = await db.execute(stmt)
items_in_list = result.scalars().all()
# Find the item to move
item_to_move = next((it for it in items_in_list if it.id == item_db.id), None)
if item_to_move:
items_in_list.remove(item_to_move)
# Insert at the new position (adjust for 1-based index from frontend)
# Clamp position to be within bounds
insert_pos = max(0, min(new_position - 1, len(items_in_list)))
items_in_list.insert(insert_pos, item_to_move)
# Re-assign positions
for i, item in enumerate(items_in_list):
item.position = i + 1
# --- End Handle Reordering ---
if 'is_complete' in update_data:
if update_data['is_complete'] is True:
if item_db.completed_by_id is None:
update_data['completed_by_id'] = user_id
elif update_data['is_complete'] is False:
update_data['completed_by_id'] = None
for key, value in update_data.items():
setattr(item_db, key, value)
item_db.version += 1
db.add(item_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(ItemModel)
.where(ItemModel.id == item_db.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list)
)
)
result = await db.execute(stmt)
updated_item = result.scalar_one_or_none()
if updated_item is None: # Should not happen
# Rollback will be handled by context manager on raise
raise ItemOperationError("Failed to load item after update.")
return updated_item
except IntegrityError as e:
logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
raise
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
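The remove/clamp/insert/renumber dance in the reordering branch of `update_item` can be checked in isolation. Here items are reduced to their ids and the function returns the new ordering; this is a sketch of the logic, not the actual helper:

```python
def reorder(ids: list[int], moved_id: int, new_position: int) -> list[int]:
    """Move `moved_id` to 1-based `new_position`, clamping out-of-range targets."""
    remaining = [i for i in ids if i != moved_id]
    # Clamp: position 1 maps to index 0; anything past the end appends.
    insert_pos = max(0, min(new_position - 1, len(remaining)))
    remaining.insert(insert_pos, moved_id)
    return remaining  # positions would then be reassigned as index + 1
```

Clamping means a stale or malicious `position` from the client degrades gracefully to "move to front" or "move to back" rather than raising.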
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
"""Deletes an item record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(item_db)
await db.flush() # Ensure the DELETE is sent to the database; the context manager handles commit/rollback
except OperationalError as e:
logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
# Ensure ItemOperationError is defined in app.core.exceptions if used
# Example: class ItemOperationError(AppException): pass

View File

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
import logging # Add logging import
from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
from app.schemas.list import ListCreate, ListUpdate
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError,
ListOperationError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
"""Creates a new list record."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
db_list = ListModel(
name=list_in.name,
description=list_in.description,
group_id=list_in.group_id,
created_by_id=creator_id,
is_complete=False # Default on creation
)
db.add(db_list)
await db.flush() # Assigns ID
# Re-fetch with relationships for the response
stmt = (
select(ListModel)
.where(ListModel.id == db_list.id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Optionally add if items are always needed in response
)
)
result = await db.execute(stmt)
loaded_list = result.scalar_one_or_none()
if loaded_list is None:
raise ListOperationError("Failed to load list after creation.")
return loaded_list
except IntegrityError as e:
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
"""
Gets all lists accessible by a user:
- Personal lists created by the user (group_id is NULL).
- Lists belonging to groups the user is a member of.
"""
# Get IDs of groups the user is a member of
"""Gets all lists accessible by a user."""
try:
group_ids_result = await db.execute(
select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
)
user_group_ids = group_ids_result.scalars().all()
conditions = [
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
]
if user_group_ids:
conditions.append(ListModel.group_id.in_(user_group_ids))
query = (
select(ListModel)
.where(or_(*conditions))
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group),
selectinload(ListModel.items).options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
)
.order_by(ListModel.updated_at.desc())
)
result = await db.execute(query)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query user lists: {str(e)}")
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
"""Gets a single list by ID, optionally loading its items."""
try:
query = (
select(ListModel)
.where(ListModel.id == list_id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
)
)
if load_items:
query = query.options(
selectinload(ListModel.items).options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
)
result = await db.execute(query)
return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query list: {str(e)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
"""Updates an existing list record."""
update_data = list_in.model_dump(exclude_unset=True) # Get only provided fields
"""Updates an existing list record, checking for version conflicts."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
)
update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
for key, value in update_data.items():
setattr(list_db, key, value)
list_db.version += 1
db.add(list_db) # Add the already attached list_db to mark it dirty for the session
await db.flush()
# Re-fetch with relationships for the response
stmt = (
select(ListModel)
.where(ListModel.id == list_db.id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Optionally add if items are always needed in response
)
)
result = await db.execute(stmt)
updated_list = result.scalar_one_or_none()
if updated_list is None: # Should not happen
raise ListOperationError("Failed to load list after update.")
return updated_list
except IntegrityError as e:
logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError:
raise
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
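The version check in `update_list` is optimistic locking: the client sends back the version it last saw, and the write is rejected if the row has moved on. A self-contained sketch of the pattern without a database (`Doc`, `DocUpdate`, and `apply_update` are illustrative names, not part of this codebase):

```python
from dataclasses import dataclass

@dataclass
class Doc:
    id: int
    name: str
    version: int

@dataclass
class DocUpdate:
    name: str
    version: int  # the version the client last saw

class ConflictError(Exception):
    pass

def apply_update(doc: Doc, update: DocUpdate) -> Doc:
    # Reject the write if the client's snapshot is stale.
    if doc.version != update.version:
        raise ConflictError(
            f"Doc {doc.id}: your version is {update.version}, current is {doc.version}"
        )
    doc.name = update.name
    doc.version += 1  # bump so concurrent writers holding the old version fail
    return doc

doc = Doc(id=1, name="Groceries", version=3)
apply_update(doc, DocUpdate(name="Weekly groceries", version=3))
print(doc.version)  # → 4
```

A second update carrying the old version 3 would now raise `ConflictError`, which the API layer can surface as HTTP 409 with a "please refresh" message, exactly as the CRUD function's error text suggests.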
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
"""Deletes a list record."""
# Items should be deleted automatically due to cascade="all, delete-orphan"
# on List.items relationship and ondelete="CASCADE" on Item.list_id FK
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
await db.delete(list_db)
except OperationalError as e:
logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
# --- Helper for Permission Checks ---
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
"""Fetches a list and verifies user permission."""
try:
list_db = await get_list_by_id(db, list_id=list_id, load_items=True)
if not list_db:
raise ListNotFoundError(list_id)
is_creator = list_db.created_by_id == user_id
if require_creator:
if not is_creator:
raise ListCreatorRequiredError(list_id, "access")
return list_db
if is_creator:
return list_db # Creator always has access
if list_db.group_id:
from app.crud.group import is_user_member # Local import to avoid a circular import
is_member = await is_user_member(db, group_id=list_db.group_id, user_id=user_id)
if not is_member:
raise ListPermissionError(list_id)
return list_db
# Personal list and not the creator -> no access
raise ListPermissionError(list_id)
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}")
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
"""Gets the update timestamps and item count for a list."""
try:
query = (
select(
ListModel.updated_at,
sql_func.count(ItemModel.id).label("item_count"),
sql_func.max(ItemModel.updated_at).label("latest_item_updated_at")
)
.select_from(ListModel)
.outerjoin(ItemModel, ItemModel.list_id == ListModel.id)
.where(ListModel.id == list_id)
.group_by(ListModel.id)
)
result = await db.execute(query)
status = result.first()
if status is None:
raise ListNotFoundError(list_id)
return ListStatus(
updated_at=status.updated_at,
item_count=status.item_count,
latest_item_updated_at=status.latest_item_updated_at
)
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to get list status: {str(e)}")
async def get_list_by_name_and_group(
db: AsyncSession,
name: str,
group_id: Optional[int],
user_id: int # user_id is for permission check, not direct list attribute
) -> Optional[ListModel]:
"""
Gets a list by name and group, ensuring the user has permission to access it.
Used for conflict resolution when creating lists.
"""
try:
# Base query for the list itself
base_query = select(ListModel).where(ListModel.name == name)
if group_id is not None:
base_query = base_query.where(ListModel.group_id == group_id)
else:
base_query = base_query.where(ListModel.group_id.is_(None))
# Add eager loading for common relationships
base_query = base_query.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
)
list_result = await db.execute(base_query)
target_list = list_result.scalar_one_or_none()
if not target_list:
return None
# Permission check
is_creator = target_list.created_by_id == user_id
if is_creator:
return target_list
if target_list.group_id:
from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction
is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
if is_member_of_group:
return target_list
# If not creator and (not a group list or not a member of the group list)
return None
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}")
async def get_lists_statuses_by_ids(db: AsyncSession, list_ids: PyList[int], user_id: int):
"""
Gets status rows (id, updated_at, item_count, latest_item_updated_at) for the given
lists, restricted to those the user has permission to access.
"""
if not list_ids:
return []
try:
# First, get the groups the user is a member of
group_ids_result = await db.execute(
select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
)
user_group_ids = group_ids_result.scalars().all()
# Build the permission logic
permission_filter = or_(
# User is the creator of the list
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)),
# List belongs to a group the user is a member of
ListModel.group_id.in_(user_group_ids)
)
# Main query to get list data and item counts
query = (
select(
ListModel.id,
ListModel.updated_at,
sql_func.count(ItemModel.id).label("item_count"),
sql_func.max(ItemModel.updated_at).label("latest_item_updated_at")
)
.outerjoin(ItemModel, ListModel.id == ItemModel.list_id)
.where(
and_(
ListModel.id.in_(list_ids),
permission_filter
)
)
.group_by(ListModel.id)
)
result = await db.execute(query)
# Rows contain (id, updated_at, item_count, latest_item_updated_at);
# the permission filter above already excludes lists the user cannot access.
return result.all()
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to get lists statuses: {str(e)}")

be/app/crud/settlement.py Normal file

@ -0,0 +1,281 @@
# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
import logging # Add logging import
from app.models import (
Settlement as SettlementModel,
User as UserModel,
Group as GroupModel,
UserGroup as UserGroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import (
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
SettlementOperationError,
ConflictError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Permission check example (can be in API layer too)
# if current_user_id not in [payer.id, payee.id]:
# is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
# is_member_result = await db.execute(is_member_stmt)
# if not is_member_result.scalar_one_or_none():
# raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description,
created_by_user_id=current_user_id
)
db.add(db_settlement)
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == db_settlement.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)
loaded_settlement = result.scalar_one_or_none()
if loaded_settlement is None:
raise SettlementOperationError("Failed to load settlement after creation.")
return loaded_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are validation errors, re-raise them.
# If a transaction was started, context manager handles rollback.
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
try:
result = await db.execute(
select(SettlementModel)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
.where(SettlementModel.id == settlement_id)
)
return result.scalars().first()
except OperationalError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
except SQLAlchemyError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
try:
result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}")
async def get_settlements_involving_user(
db: AsyncSession,
user_id: int,
group_id: Optional[int] = None,
skip: int = 0,
limit: int = 100
) -> Sequence[SettlementModel]:
try:
query = (
select(SettlementModel)
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
if group_id:
query = query.where(SettlementModel.group_id == group_id)
result = await db.execute(query)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}")
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
"""
Updates an existing settlement.
Only allows updates to description and settlement_date.
Requires version matching for optimistic locking.
Assumes SettlementUpdate schema includes a version field.
Assumes SettlementModel has version and updated_at fields.
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Ensure the settlement_db passed is managed by the current session if not already.
# This is usually true if fetched by an endpoint dependency using the same session.
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")
if settlement_db.version != settlement_in.version:
raise ConflictError( # Make sure ConflictError is defined in exceptions
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
# Silently ignore fields not allowed to update or raise error:
# else:
# raise InvalidOperationError(f"Field '{field}' cannot be updated.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
# No updatable fields were actually provided, or they didn't change
# Still, we might want to return the re-loaded settlement if version matched.
pass
settlement_db.version += 1
settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field
db.add(settlement_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == settlement_db.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group),
selectinload(SettlementModel.created_by_user)
)
)
result = await db.execute(stmt)
updated_settlement = result.scalar_one_or_none()
if updated_settlement is None: # Should not happen
raise SettlementOperationError("Failed to load settlement after update.")
return updated_settlement
except ConflictError:
raise
except InvalidOperationError:
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
"""
Deletes a settlement. Requires version matching if expected_version is provided.
Assumes SettlementModel has a version field.
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise ConflictError( # Make sure ConflictError is defined
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh."
)
await db.delete(settlement_db)
except ConflictError:
raise
except OperationalError as e:
logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
# Example: class SettlementOperationError(AppException): pass
# Example: class ConflictError(AppException): status_code = 409


@ -0,0 +1,211 @@
from typing import List, Optional
from decimal import Decimal
from datetime import datetime, timezone
from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload
from app.models import (
SettlementActivity,
ExpenseSplit,
Expense,
User,
ExpenseSplitStatusEnum,
ExpenseOverallStatusEnum,
)
# Placeholder for Pydantic schema - actual schema definition is a later step
# from app.schemas.settlement_activity import SettlementActivityCreate # Assuming this path
from pydantic import BaseModel # Using pydantic BaseModel directly for the placeholder
class SettlementActivityCreatePlaceholder(BaseModel):
expense_split_id: int
paid_by_user_id: int
amount_paid: Decimal
paid_at: Optional[datetime] = None
class Config:
orm_mode = True # Pydantic V1 style orm_mode
# from_attributes = True # Pydantic V2 style
async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
"""
Updates the status of an ExpenseSplit based on its settlement activities.
Also updates the overall status of the parent Expense.
"""
# Fetch the ExpenseSplit with its related settlement_activities and the parent expense
result = await db.execute(
select(ExpenseSplit)
.options(
selectinload(ExpenseSplit.settlement_activities),
joinedload(ExpenseSplit.expense) # To get expense_id easily
)
.where(ExpenseSplit.id == expense_split_id)
)
expense_split = result.scalar_one_or_none()
if not expense_split:
# Or raise an exception, depending on desired error handling
return None
# Calculate total_paid from all settlement_activities for that split
total_paid = sum(activity.amount_paid for activity in expense_split.settlement_activities)
total_paid = Decimal(total_paid).quantize(Decimal("0.01")) # Ensure two decimal places
# Compare total_paid with ExpenseSplit.owed_amount
if total_paid >= expense_split.owed_amount:
expense_split.status = ExpenseSplitStatusEnum.paid
# Set paid_at to the latest relevant SettlementActivity or current time
# For simplicity, let's find the latest paid_at from activities, or use now()
latest_paid_at = None
if expense_split.settlement_activities:
latest_paid_at = max((act.paid_at for act in expense_split.settlement_activities if act.paid_at), default=None)
expense_split.paid_at = latest_paid_at if latest_paid_at else datetime.now(timezone.utc)
elif total_paid > 0:
expense_split.status = ExpenseSplitStatusEnum.partially_paid
expense_split.paid_at = None # Clear paid_at if not fully paid
else: # total_paid == 0
expense_split.status = ExpenseSplitStatusEnum.unpaid
expense_split.paid_at = None # Clear paid_at
await db.flush()
await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense']) # Refresh to get updated data and related expense
return expense_split
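The status transition in `update_expense_split_status` depends only on how the total paid compares with the amount owed. Just that decision, extracted into a sketch (plain strings stand in for `ExpenseSplitStatusEnum`):

```python
from decimal import Decimal

def split_status(total_paid: Decimal, owed: Decimal) -> str:
    # Fully covered (or overpaid) -> paid; anything above zero -> partially paid.
    if total_paid >= owed:
        return "paid"
    if total_paid > 0:
        return "partially_paid"
    return "unpaid"

print(split_status(Decimal("5.00"), Decimal("5.00")))  # → paid
print(split_status(Decimal("2.50"), Decimal("5.00")))  # → partially_paid
print(split_status(Decimal("0.00"), Decimal("5.00")))  # → unpaid
```

Note that `total_paid >= owed` also classifies overpayment as `paid`; if overpayment should be rejected instead, that check belongs in validation before the activity is recorded.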
async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
"""
Updates the overall_status of an Expense based on the status of its splits.
"""
# Fetch the Expense with its related splits
result = await db.execute(
select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
)
expense = result.scalar_one_or_none()
if not expense:
# Or raise an exception
return None
if not expense.splits: # No splits, should not happen for a valid expense but handle defensively
expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid # Or some other default/error state
await db.flush()
await db.refresh(expense)
return expense
num_splits = len(expense.splits)
num_paid_splits = 0
num_partially_paid_splits = 0
num_unpaid_splits = 0
for split in expense.splits:
if split.status == ExpenseSplitStatusEnum.paid:
num_paid_splits += 1
elif split.status == ExpenseSplitStatusEnum.partially_paid:
num_partially_paid_splits += 1
else: # unpaid
num_unpaid_splits += 1
if num_paid_splits == num_splits:
expense.overall_settlement_status = ExpenseOverallStatusEnum.paid
elif num_unpaid_splits == num_splits:
expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid
else: # Mix of paid, partially_paid, or unpaid but not all unpaid/paid
expense.overall_settlement_status = ExpenseOverallStatusEnum.partially_paid
await db.flush()
await db.refresh(expense, attribute_names=['overall_settlement_status'])
return expense
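The rollup in `update_expense_overall_status` is pure counting: all splits paid means the expense is paid, all unpaid means unpaid, and any mixture is partially paid. The same decision as a standalone function (strings stand in for `ExpenseOverallStatusEnum`):

```python
from typing import List

def overall_status(split_statuses: List[str]) -> str:
    # split_statuses holds "paid" / "partially_paid" / "unpaid", one per split.
    if not split_statuses:
        return "unpaid"  # defensive default, mirroring the CRUD function
    if all(s == "paid" for s in split_statuses):
        return "paid"
    if all(s == "unpaid" for s in split_statuses):
        return "unpaid"
    return "partially_paid"

print(overall_status(["paid", "paid"]))    # → paid
print(overall_status(["paid", "unpaid"]))  # → partially_paid
```

Keeping this logic in one place means a settlement activity only needs to trigger the two status recomputations, rather than scattering status rules across endpoints.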
async def create_settlement_activity(
db: AsyncSession,
settlement_activity_in: SettlementActivityCreatePlaceholder,
current_user_id: int
) -> Optional[SettlementActivity]:
"""
Creates a new settlement activity, then updates the parent expense split and expense statuses.
"""
# Validate ExpenseSplit
split_result = await db.execute(select(ExpenseSplit).where(ExpenseSplit.id == settlement_activity_in.expense_split_id))
expense_split = split_result.scalar_one_or_none()
if not expense_split:
# Consider raising an HTTPException in an API layer
return None # ExpenseSplit not found
# Validate User (paid_by_user_id)
user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id))
paid_by_user = user_result.scalar_one_or_none()
if not paid_by_user:
return None # User not found
# Create SettlementActivity instance
db_settlement_activity = SettlementActivity(
expense_split_id=settlement_activity_in.expense_split_id,
paid_by_user_id=settlement_activity_in.paid_by_user_id,
amount_paid=settlement_activity_in.amount_paid,
paid_at=settlement_activity_in.paid_at if settlement_activity_in.paid_at else datetime.now(timezone.utc),
created_by_user_id=current_user_id # The user recording the activity
)
db.add(db_settlement_activity)
await db.flush() # Flush to get the ID for db_settlement_activity
# Update statuses
updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id)
if updated_split and updated_split.expense_id:
await update_expense_overall_status(db, expense_id=updated_split.expense_id)
else:
# This case implies update_expense_split_status returned None or expense_id was missing.
# This could be a problem, consider logging or raising an error.
# For now, the transaction would roll back if an exception is raised.
# If not raising, the overall status update might be skipped.
pass # Or handle error
await db.refresh(db_settlement_activity, attribute_names=['split', 'payer', 'creator']) # Refresh to load relationships
return db_settlement_activity
async def get_settlement_activity_by_id(
db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
"""
Fetches a single SettlementActivity by its ID, loading relationships.
"""
result = await db.execute(
select(SettlementActivity)
.options(
selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense), # Load split and its parent expense
selectinload(SettlementActivity.payer), # Load the user who paid
selectinload(SettlementActivity.creator) # Load the user who created the record
)
.where(SettlementActivity.id == settlement_activity_id)
)
return result.scalar_one_or_none()
async def get_settlement_activities_for_split(
db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
"""
Fetches a list of SettlementActivity records associated with a given expense_split_id.
"""
result = await db.execute(
select(SettlementActivity)
.where(SettlementActivity.expense_split_id == expense_split_id)
.options(
selectinload(SettlementActivity.payer), # Load the user who paid
selectinload(SettlementActivity.creator) # Load the user who created the record
)
.order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
# Further CRUD operations like update/delete can be added later if needed.
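The split-status derivation that `update_expense_split_status` performs, summing settlement activities against the owed amount, can be sketched in isolation. This is a hedged sketch of the assumed semantics, not the actual CRUD function; `derive_split_status` is a name invented here:

```python
from decimal import Decimal

def derive_split_status(owed_amount: Decimal, payments: list[Decimal]) -> str:
    """Roll a split's recorded payments up into unpaid / partially_paid / paid."""
    total_paid = sum(payments, Decimal("0"))
    if total_paid <= Decimal("0"):
        return "unpaid"
    if total_paid < owed_amount:
        return "partially_paid"
    return "paid"
```

The ORM version persists the result on `ExpenseSplit.status` and then lets `update_expense_overall_status` aggregate across all splits of the expense.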

View File

@ -1,28 +1,90 @@
# app/crud/user.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional
import logging # Add logging import
from app.models import User as UserModel # Alias to avoid name clash
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
from app.schemas.user import UserCreate
from app.core.security import hash_password
from app.core.exceptions import (
UserCreationError,
EmailAlreadyRegisteredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
UserOperationError # Add if specific user operation errors are needed
)
logger = logging.getLogger(__name__) # Initialize logger
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
"""Fetches a user from the database by email."""
result = await db.execute(select(UserModel).filter(UserModel.email == email))
"""Fetches a user from the database by email, with common relationships."""
try:
# A single SELECT needs no explicit db.begin(); the session manages the
# connection. Wrap this in a transaction only if more reads are added here.
stmt = (
select(UserModel)
.filter(UserModel.email == email)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.created_groups) # Groups user created
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
"""Creates a new user record in the database."""
_hashed_password = hash_password(user_in.password) # Keep local var name if you like
# Create SQLAlchemy model instance - explicitly map fields
"""Creates a new user record in the database with common relationships loaded."""
try:
async with (db.begin_nested() if db.in_transaction() else db.begin()):
_hashed_password = hash_password(user_in.password)
db_user = UserModel(
email=user_in.email,
# Use the correct keyword argument matching the model column name
password_hash=_hashed_password,
hashed_password=_hashed_password, # Field name in model is hashed_password
name=user_in.name
)
db.add(db_user)
await db.commit()
await db.refresh(db_user) # Refresh to get DB-generated values like ID, created_at
return db_user
await db.flush() # Flush to get DB-generated values like ID
# Re-fetch with relationships
stmt = (
select(UserModel)
.where(UserModel.id == db_user.id)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),
selectinload(UserModel.created_groups)
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
loaded_user = result.scalar_one_or_none()
if loaded_user is None:
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
return loaded_user
except IntegrityError as e:
logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
raise EmailAlreadyRegisteredError(email=user_in.email)
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
# Ensure UserOperationError is defined in app.core.exceptions if used
# Example: class UserOperationError(AppException): pass
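The trailing note above asks for `UserOperationError` to exist in `app.core.exceptions`. A minimal sketch of such a hierarchy, assuming an `AppException` base class as the note suggests (both names are taken from the note, the shape is illustrative):

```python
class AppException(Exception):
    """Assumed base class for application-level errors."""
    def __init__(self, detail: str):
        self.detail = detail
        super().__init__(detail)

class UserOperationError(AppException):
    """Raised when a user record cannot be loaded or mutated as expected."""
```

With this in place, the `raise UserOperationError("Failed to load user after creation.")` in `create_user` resolves cleanly.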

View File

@ -11,9 +11,10 @@ if not settings.DATABASE_URL:
# pool_recycle=3600 helps prevent stale connections on some DBs
engine = create_async_engine(
settings.DATABASE_URL,
echo=True, # Log SQL queries (useful for debugging)
echo=False, # Disable SQL query logging for production (use DEBUG log level to enable)
future=True, # Use SQLAlchemy 2.0 style features
pool_recycle=3600 # Optional: recycle connections after 1 hour
pool_recycle=3600, # Optional: recycle connections after 1 hour
pool_pre_ping=True # Add this line to ensure connections are live
)
# Create a configured "Session" class
@ -30,18 +31,27 @@ AsyncSessionLocal = sessionmaker(
Base = declarative_base()
# Dependency to get DB session in path operations
async def get_db() -> AsyncSession: # type: ignore
async def get_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession.
Dependency function that yields an AsyncSession for read-only operations.
Ensures the session is closed after the request.
"""
async with AsyncSessionLocal() as session:
try:
yield session
# Optionally commit if your endpoints modify data directly
# await session.commit() # Usually commit happens within endpoint logic
except Exception:
await session.rollback()
raise
finally:
await session.close() # Not strictly necessary with async context manager, but explicit
# The 'async with' block handles session.close() automatically.
async def get_transactional_session() -> AsyncSession: # type: ignore
"""
Dependency function that yields an AsyncSession and manages a transaction.
Commits the transaction if the request handler succeeds, otherwise rolls it back.
Ensures the session is closed after the request.
This follows the FastAPI-DB strategy for endpoint-level transaction management.
"""
async with AsyncSessionLocal() as session:
async with session.begin():
yield session
# Transaction is automatically committed on success or rolled back on exception
# Alias for backward compatibility
get_db = get_session
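The commit-on-success / rollback-on-error contract that `get_transactional_session` gets from `session.begin()` can be sketched without FastAPI or SQLAlchemy. `FakeSession` and `transactional` are illustrative names, not part of the app:

```python
import asyncio
from contextlib import asynccontextmanager

class FakeSession:
    """Stand-in for AsyncSession, recording what happened to the transaction."""
    def __init__(self):
        self.committed = False
        self.rolled_back = False
    async def commit(self):
        self.committed = True
    async def rollback(self):
        self.rolled_back = True

@asynccontextmanager
async def transactional(session: FakeSession):
    """Commit if the wrapped block succeeds, roll back (and re-raise) if it fails."""
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise

async def demo() -> tuple[FakeSession, FakeSession]:
    ok = FakeSession()
    async with transactional(ok):
        pass  # handler body succeeds -> commit
    bad = FakeSession()
    try:
        async with transactional(bad):
            raise RuntimeError("handler failed")  # -> rollback, then re-raised
    except RuntimeError:
        pass
    return ok, bad
```

In the real dependency, FastAPI drives the generator: the `yield session` inside `async with session.begin():` hands the session to the endpoint, and the transaction outcome is decided when the endpoint returns or raises.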

be/app/db/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from app.db.session import async_session
__all__ = ["async_session"]

be/app/db/session.py Normal file
View File

@ -0,0 +1,4 @@
from app.database import AsyncSessionLocal
# Export the async session factory
async_session = AsyncSessionLocal

View File

@ -0,0 +1,119 @@
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from app.models import Expense, RecurrencePattern
from app.crud.expense import create_expense
from app.schemas.expense import ExpenseCreate
import logging
from typing import Optional
logger = logging.getLogger(__name__)
async def generate_recurring_expenses(db: AsyncSession) -> None:
"""
Background job to generate recurring expenses.
Should be run daily to check for and create new recurring expenses.
"""
try:
# Get all active recurring expenses that need to be generated.
# Use an aware datetime: next_occurrence is a DateTime(timezone=True) column.
now = datetime.now(timezone.utc)
query = select(Expense).join(RecurrencePattern).where(
and_(
Expense.is_recurring == True,
Expense.next_occurrence <= now,
# Check if we haven't reached max occurrences
(
(RecurrencePattern.max_occurrences == None) |
(RecurrencePattern.max_occurrences > 0)
),
# Check if we haven't reached end date
(
(RecurrencePattern.end_date == None) |
(RecurrencePattern.end_date > now)
)
)
)
result = await db.execute(query)
recurring_expenses = result.scalars().all()
for expense in recurring_expenses:
try:
await _generate_next_occurrence(db, expense)
except Exception as e:
logger.error(f"Error generating next occurrence for expense {expense.id}: {str(e)}")
continue
except Exception as e:
logger.error(f"Error in generate_recurring_expenses job: {str(e)}")
raise
async def _generate_next_occurrence(db: AsyncSession, expense: Expense) -> None:
"""Generate the next occurrence of a recurring expense."""
pattern = expense.recurrence_pattern
if not pattern:
return
# Calculate next occurrence date
next_date = _calculate_next_occurrence(expense.next_occurrence, pattern)
if not next_date:
return
# Create new expense based on template
new_expense = ExpenseCreate(
description=expense.description,
total_amount=expense.total_amount,
currency=expense.currency,
expense_date=next_date,
split_type=expense.split_type,
list_id=expense.list_id,
group_id=expense.group_id,
item_id=expense.item_id,
paid_by_user_id=expense.paid_by_user_id,
is_recurring=False, # Generated expenses are not recurring
splits_in=None # Will be generated based on split_type
)
# Create the new expense
created_expense = await create_expense(db, new_expense, expense.created_by_user_id)
# Update the original expense
expense.last_occurrence = next_date
expense.next_occurrence = _calculate_next_occurrence(next_date, pattern)
if pattern.max_occurrences:
pattern.max_occurrences -= 1
await db.flush()
def _calculate_next_occurrence(current_date: datetime, pattern: RecurrencePattern) -> Optional[datetime]:
"""Calculate the next occurrence date based on the pattern."""
if not current_date:
return None
if pattern.type == 'daily':
return current_date + timedelta(days=pattern.interval)
elif pattern.type == 'weekly':
if not pattern.days_of_week:
return current_date + timedelta(weeks=pattern.interval)
# Find next day of week
current_weekday = current_date.weekday()
next_weekday = min((d for d in pattern.days_of_week if d > current_weekday),
default=min(pattern.days_of_week))
days_ahead = next_weekday - current_weekday
if days_ahead <= 0:
days_ahead += 7
return current_date + timedelta(days=days_ahead)
elif pattern.type == 'monthly':
# Add months, clamping the day to the target month's length
# (e.g. Jan 31 + 1 month -> Feb 28/29); a bare replace() would raise ValueError.
import calendar
year = current_date.year + (current_date.month + pattern.interval - 1) // 12
month = (current_date.month + pattern.interval - 1) % 12 + 1
day = min(current_date.day, calendar.monthrange(year, month)[1])
return current_date.replace(year=year, month=month, day=day)
elif pattern.type == 'yearly':
try:
return current_date.replace(year=current_date.year + pattern.interval)
except ValueError:  # Feb 29 -> non-leap year
return current_date.replace(year=current_date.year + pattern.interval, day=28)
return None
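Month arithmetic with `replace()` is the fragile part of `_calculate_next_occurrence`: Jan 31 has no direct counterpart one month later. A standalone sketch of day-clamped month addition (`add_months` is a name invented here for illustration):

```python
import calendar
from datetime import datetime

def add_months(current: datetime, months: int) -> datetime:
    """Shift a datetime by whole months, clamping to the last valid day of the target month."""
    year = current.year + (current.month + months - 1) // 12
    month = (current.month + months - 1) % 12 + 1
    day = min(current.day, calendar.monthrange(year, month)[1])
    return current.replace(year=year, month=month, day=day)
```

The `// 12` and `% 12` arithmetic also handles year rollover, so November + 2 months lands in January of the next year.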

View File

@ -1,55 +1,204 @@
# app/main.py
import logging
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Depends, status, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration
from fastapi_users.authentication import JWTStrategy
from pydantic import BaseModel
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic.config import Config
from alembic import command
import os
import sys
from app.api.api_router import api_router # Import the main combined router
# Import database and models if needed for startup/shutdown events later
# from . import database, models
from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User
from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session
from sqlalchemy import select
# Response model for refresh endpoint
class RefreshResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
# Initialize Sentry only if DSN is provided
if settings.SENTRY_DSN:
sentry_sdk.init(
dsn=settings.SENTRY_DSN,
integrations=[
FastApiIntegration(),
],
# Adjust traces_sample_rate for production
traces_sample_rate=0.1 if settings.is_production else 1.0,
environment=settings.ENVIRONMENT,
# Enable PII data only in development
send_default_pii=not settings.is_production
)
# --- Logging Setup ---
# Configure logging (can be more sophisticated later, e.g., using logging.yaml)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=getattr(logging, settings.LOG_LEVEL),
format=settings.LOG_FORMAT
)
logger = logging.getLogger(__name__)
# --- FastAPI App Instance ---
# Create API metadata with environment-dependent settings
api_metadata = {
**API_METADATA,
"docs_url": settings.docs_url,
"redoc_url": settings.redoc_url,
"openapi_url": settings.openapi_url,
}
app = FastAPI(
title="Shared Lists API",
description="API for managing shared shopping lists, OCR, and cost splitting.",
version="0.1.0",
openapi_url="/api/openapi.json", # Place OpenAPI spec under /api
docs_url="/api/docs", # Place Swagger UI under /api
redoc_url="/api/redoc" # Place ReDoc under /api
**api_metadata,
openapi_tags=API_TAGS
)
# Add session middleware for OAuth
app.add_middleware(
SessionMiddleware,
secret_key=settings.SESSION_SECRET_KEY
)
# --- CORS Middleware ---
# Define allowed origins. Be specific in production!
# Use ["*"] for wide open access during early development if needed,
# but restrict it as soon as possible.
# SvelteKit default dev port is 5173
origins = [
"http://localhost:5174",
"http://localhost:8000", # Allow requests from the API itself (e.g., Swagger UI)
# Add your deployed frontend URL here later
# "https://your-frontend-domain.com",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # List of origins that are allowed to make requests
allow_credentials=True, # Allow cookies to be included in requests
allow_methods=["*"], # Allow all methods (GET, POST, PUT, DELETE, etc.)
allow_headers=["*"], # Allow all headers
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"]
)
# --- End CORS Middleware ---
# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
request: Request,
refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
access_strategy: JWTStrategy = Depends(get_jwt_strategy),
):
"""
Refresh access token using a valid refresh token.
Send refresh token in Authorization header: Bearer <refresh_token>
"""
try:
# Get refresh token from Authorization header
authorization = request.headers.get("Authorization")
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token missing or invalid format",
headers={"WWW-Authenticate": "Bearer"},
)
refresh_token = authorization.split(" ")[1]
# Validate refresh token and get user data
try:
# Decode the refresh token to get the user identifier
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
# Get user from database
# get_session is an async-generator dependency, not a context manager,
# so use the session factory directly here.
from app.database import AsyncSessionLocal
async with AsyncSessionLocal() as session:
result = await session.execute(select(User).where(User.id == int(user_id)))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
# Generate new tokens
new_access_token = await access_strategy.write_token(user)
new_refresh_token = await refresh_strategy.write_token(user)
return RefreshResponse(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error refreshing token: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token"
)
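The Authorization-header handling at the top of the endpoint is worth isolating. A small helper with the same contract (`extract_bearer_token` is a name invented here, not part of the app):

```python
from typing import Optional

def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' header, else None."""
    if not authorization or not authorization.startswith("Bearer "):
        return None
    token = authorization.split(" ", 1)[1].strip()
    return token or None
```

Returning `None` for every malformed shape (missing header, wrong scheme, empty token) lets the endpoint map all of them to a single 401 response.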
# --- Include API Routers ---
# All API endpoints will be prefixed with /api
app.include_router(api_router, prefix="/api")
# Include OAuth routes first (no auth required)
app.include_router(oauth_router, prefix="/auth", tags=["auth"])
# Include FastAPI-Users routes
app.include_router(
fastapi_users.get_auth_router(auth_backend),
prefix="/auth/jwt",
tags=["auth"],
)
app.include_router(
fastapi_users.get_register_router(UserPublic, UserCreate),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_verify_router(UserPublic),
prefix="/auth",
tags=["auth"],
)
app.include_router(
fastapi_users.get_users_router(UserPublic, UserUpdate),
prefix="/users",
tags=["users"],
)
# Include your API router
app.include_router(api_router, prefix=settings.API_PREFIX)
# --- End Include API Routers ---
# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
"""
Health check endpoint for load balancers and monitoring.
"""
return {
"status": settings.HEALTH_STATUS_OK,
"environment": settings.ENVIRONMENT,
"version": settings.API_VERSION
}
# --- Root Endpoint (Optional - outside the main API structure) ---
@app.get("/", tags=["Root"])
@ -59,26 +208,53 @@ async def read_root():
Useful for basic reachability checks.
"""
logger.info("Root endpoint '/' accessed.")
# You could redirect to the docs or return a simple message
# from fastapi.responses import RedirectResponse
# return RedirectResponse(url="/api/docs")
return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"}
return {
"message": settings.ROOT_MESSAGE,
"environment": settings.ENVIRONMENT,
"version": settings.API_VERSION
}
# --- End Root Endpoint ---
async def run_migrations():
"""Run database migrations."""
try:
logger.info("Running database migrations...")
# Get the absolute path to the alembic directory
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
alembic_path = os.path.join(base_path, 'alembic')
# --- Application Startup/Shutdown Events (Optional) ---
# @app.on_event("startup")
# async def startup_event():
# logger.info("Application startup: Connecting to database...")
# # You might perform initial checks or warm-up here
# # await database.engine.connect() # Example check (get_db handles sessions per request)
# logger.info("Application startup complete.")
# Add alembic directory to Python path
if alembic_path not in sys.path:
sys.path.insert(0, alembic_path)
# @app.on_event("shutdown")
# async def shutdown_event():
# logger.info("Application shutdown: Disconnecting from database...")
# # await database.engine.dispose() # Close connection pool
# logger.info("Application shutdown complete.")
# Import and run migrations
from migrations import run_migrations as run_db_migrations
await run_db_migrations()
logger.info("Database migrations completed successfully.")
except Exception as e:
logger.error(f"Error running migrations: {e}")
raise
@app.on_event("startup")
async def startup_event():
"""Initialize services on startup."""
logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")
# Run database migrations
# await run_migrations()
# Initialize scheduler
init_scheduler()
logger.info("Application startup complete.")
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup services on shutdown."""
logger.info("Application shutdown: Disconnecting from database...")
# await database.engine.dispose() # Close connection pool
shutdown_scheduler()
logger.info("Application shutdown complete.")
# --- End Events ---

View File

@ -20,9 +20,10 @@ from sqlalchemy import (
text as sa_text,
Text, # <-- Add Text for description
Numeric, # <-- Add Numeric for price
ARRAY
CheckConstraint,
Date # Added Date for Chore model
)
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, backref
from .database import Base
@ -32,14 +33,43 @@ class UserRoleEnum(enum.Enum):
member = "member"
class SplitTypeEnum(enum.Enum):
equal = "equal"
# Add other types later if needed (e.g., custom, percentage)
# custom = "custom"
EQUAL = "EQUAL" # Split equally among all involved users
EXACT_AMOUNTS = "EXACT_AMOUNTS" # Specific amounts for each user (defined in ExpenseSplit)
PERCENTAGE = "PERCENTAGE" # Percentage for each user (defined in ExpenseSplit)
SHARES = "SHARES" # Proportional to shares/units (defined in ExpenseSplit)
ITEM_BASED = "ITEM_BASED" # If an expense is derived directly from item prices and who added them
# Consider renaming to a more generic term like 'DERIVED' or 'ENTITY_DRIVEN'
# if expenses might be derived from other entities in the future.
# Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense)
class SettlementActivityTypeEnum(enum.Enum):
marked_paid = "marked_paid"
marked_unpaid = "marked_unpaid"
# Add other activity types later if needed
class ExpenseSplitStatusEnum(enum.Enum):
unpaid = "unpaid"
partially_paid = "partially_paid"
paid = "paid"
class ExpenseOverallStatusEnum(enum.Enum):
unpaid = "unpaid"
partially_paid = "partially_paid"
paid = "paid"
class RecurrenceTypeEnum(enum.Enum):
DAILY = "DAILY"
WEEKLY = "WEEKLY"
MONTHLY = "MONTHLY"
YEARLY = "YEARLY"
# Add more types as needed
# Define ChoreFrequencyEnum
class ChoreFrequencyEnum(enum.Enum):
one_time = "one_time"
daily = "daily"
weekly = "weekly"
monthly = "monthly"
custom = "custom"
class ChoreTypeEnum(enum.Enum):
personal = "personal"
group = "group"
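`SplitTypeEnum.EQUAL` hides a rounding subtlety: dividing a total per head rarely sums back to the total in cents. A hedged sketch of remainder-distributing equal splits (`equal_split` is illustrative, not the app's actual splitting code):

```python
from decimal import Decimal, ROUND_DOWN

def equal_split(total: Decimal, num_users: int) -> list[Decimal]:
    """Split `total` into num_users cent-rounded shares that sum exactly to total."""
    base = (total / num_users).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    shares = [base] * num_users
    # Hand out the leftover cents one at a time to the first users.
    remainder = total - base * num_users
    cents = int((remainder * 100).to_integral_value())
    for i in range(cents):
        shares[i] += Decimal("0.01")
    return shares
```

The other enum members shift this bookkeeping onto `ExpenseSplit` fields: `share_percentage` for PERCENTAGE and `share_units` for SHARES.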
# --- User Model ---
class User(Base):
@ -47,8 +77,11 @@ class User(Base):
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True, nullable=False)
password_hash = Column(String, nullable=False)
hashed_password = Column(String, nullable=False)
name = Column(String, index=True, nullable=True)
is_active = Column(Boolean, default=True, nullable=False)
is_superuser = Column(Boolean, default=False, nullable=False)
is_verified = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
# --- Relationships ---
@ -62,6 +95,20 @@ class User(Base):
completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
# --- End NEW Relationships ---
# --- Relationships for Cost Splitting ---
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# --- Relationships for Chores ---
created_chores = relationship("Chore", foreign_keys="[Chore.created_by_id]", back_populates="creator")
assigned_chores = relationship("ChoreAssignment", back_populates="assigned_user", cascade="all, delete-orphan")
# --- End Relationships for Chores ---
# --- Group Model ---
class Group(Base):
@ -81,6 +128,15 @@ class Group(Base):
lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group
# --- End NEW Relationship ---
# --- Relationships for Cost Splitting ---
expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan")
settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# --- Relationship for Chores ---
chores = relationship("Chore", back_populates="group", cascade="all, delete-orphan")
# --- End Relationship for Chores ---
# --- UserGroup Association Model ---
class UserGroup(Base):
@ -128,15 +184,29 @@ class List(Base):
is_complete = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships ---
creator = relationship("User", back_populates="created_lists") # Link to User.created_lists
group = relationship("Group", back_populates="lists") # Link to Group.lists
items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes
items = relationship(
"Item",
back_populates="list",
cascade="all, delete-orphan",
order_by="Item.position.asc(), Item.created_at.asc()" # Default order by position, then creation
)
# --- Relationships for Cost Splitting ---
expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# === NEW: Item Model ===
class Item(Base):
__tablename__ = "items"
__table_args__ = (
Index('ix_items_list_id_position', 'list_id', 'position'),
)
id = Column(Integer, primary_key=True, index=True)
list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) # Belongs to which list
@ -144,61 +214,219 @@ class Item(Base):
quantity = Column(String, nullable=True) # Flexible quantity (e.g., "1", "2 lbs", "a bunch")
is_complete = Column(Boolean, default=False, nullable=False)
price = Column(Numeric(10, 2), nullable=True) # For cost splitting later (e.g., 12345678.99)
price_added_by_id = Column(Integer, ForeignKey("users.id"), nullable=True)
price_added_at = Column(DateTime(timezone=True), nullable=True)
position = Column(Integer, nullable=False, server_default='0') # For ordering
added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who added this item
completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# --- Relationships ---
list = relationship("List", back_populates="items") # Link to List.items
added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items
completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items
# === NEW: ExpenseRecord Model ===
class ExpenseRecord(Base):
__tablename__ = "expense_records"
# --- Relationships for Cost Splitting ---
# If an item directly results in an expense, or an expense can be tied to an item.
expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses
# --- End Relationships for Cost Splitting ---
# === NEW Models for Advanced Cost Splitting ===
class Expense(Base):
__tablename__ = "expenses"
id = Column(Integer, primary_key=True, index=True)
list_id = Column(Integer, ForeignKey("lists.id"), index=True, nullable=False)
calculated_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
calculated_by_id = Column(Integer, ForeignKey("users.id"), nullable=False)
description = Column(String, nullable=False)
total_amount = Column(Numeric(10, 2), nullable=False)
participants = Column(ARRAY(Integer), nullable=False)
split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False, default=SplitTypeEnum.equal)
is_settled = Column(Boolean, default=False, nullable=False)
currency = Column(String, nullable=False, default="USD")
expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False)
# Foreign Keys
list_id = Column(Integer, ForeignKey("lists.id"), nullable=True, index=True)
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True)
item_id = Column(Integer, ForeignKey("items.id"), nullable=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# Relationships
list = relationship("List")
calculator = relationship("User")
shares = relationship("ExpenseShare", back_populates="expense_record", cascade="all, delete-orphan")
settlement_activities = relationship("SettlementActivity", back_populates="expense_record", cascade="all, delete-orphan")
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created")
list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan")
parent_expense = relationship("Expense", remote_side=[id], back_populates="child_expenses")
child_expenses = relationship("Expense", back_populates="parent_expense")
overall_settlement_status = Column(SAEnum(ExpenseOverallStatusEnum, name="expenseoverallstatusenum", create_type=True), nullable=False, server_default=ExpenseOverallStatusEnum.unpaid.value, default=ExpenseOverallStatusEnum.unpaid)
# --- Recurrence fields ---
is_recurring = Column(Boolean, default=False, nullable=False)
recurrence_pattern_id = Column(Integer, ForeignKey("recurrence_patterns.id"), nullable=True)
recurrence_pattern = relationship("RecurrencePattern", back_populates="expenses", uselist=False) # One-to-one
next_occurrence = Column(DateTime(timezone=True), nullable=True) # For recurring expenses
parent_expense_id = Column(Integer, ForeignKey("expenses.id"), nullable=True)
last_occurrence = Column(DateTime(timezone=True), nullable=True)
__table_args__ = (
# Ensure at least one context is provided
CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
)
class ExpenseSplit(Base):
__tablename__ = "expense_splits"
__table_args__ = (
UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),
Index('ix_expense_splits_user_id', 'user_id'), # For looking up user's splits
)
id = Column(Integer, primary_key=True, index=True)
expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
owed_amount = Column(Numeric(10, 2), nullable=False)
share_percentage = Column(Numeric(5, 2), nullable=True)
share_units = Column(Integer, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationships
expense = relationship("Expense", back_populates="splits")
user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits")
settlement_activities = relationship("SettlementActivity", back_populates="split", cascade="all, delete-orphan")
# New fields for tracking payment status
status = Column(SAEnum(ExpenseSplitStatusEnum, name="expensesplitstatusenum", create_type=True), nullable=False, server_default=ExpenseSplitStatusEnum.unpaid.value, default=ExpenseSplitStatusEnum.unpaid)
paid_at = Column(DateTime(timezone=True), nullable=True) # Timestamp when the split was fully paid
class Settlement(Base):
__tablename__ = "settlements"
id = Column(Integer, primary_key=True, index=True)
group_id = Column(Integer, ForeignKey("groups.id"), nullable=False, index=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
amount = Column(Numeric(10, 2), nullable=False)
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
description = Column(Text, nullable=True)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# Relationships
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created")
__table_args__ = (
# Ensure payer and payee are different users
CheckConstraint('paid_by_user_id != paid_to_user_id', name='chk_settlement_different_users'),
)
# Potential future: PaymentMethod model, etc.
class SettlementActivity(Base):
__tablename__ = "settlement_activities"
id = Column(Integer, primary_key=True, index=True)
expense_split_id = Column(Integer, ForeignKey("expense_splits.id"), nullable=False, index=True)
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) # User who made this part of the payment
paid_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
amount_paid = Column(Numeric(10, 2), nullable=False)
created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) # User who recorded this activity
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# --- Relationships ---
split = relationship("ExpenseSplit", back_populates="settlement_activities")
payer = relationship("User", foreign_keys=[paid_by_user_id], backref="made_settlement_activities")
creator = relationship("User", foreign_keys=[created_by_user_id], backref="created_settlement_activities")
__table_args__ = (
Index('ix_settlement_activity_expense_split_id', 'expense_split_id'),
Index('ix_settlement_activity_paid_by_user_id', 'paid_by_user_id'),
Index('ix_settlement_activity_created_by_user_id', 'created_by_user_id'),
)
# --- Chore Model ---
class Chore(Base):
__tablename__ = "chores"
id = Column(Integer, primary_key=True, index=True)
type = Column(SAEnum(ChoreTypeEnum, name="choretypeenum", create_type=True), nullable=False)
group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=True, index=True)
name = Column(String, nullable=False, index=True)
description = Column(Text, nullable=True)
created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
frequency = Column(SAEnum(ChoreFrequencyEnum, name="chorefrequencyenum", create_type=True), nullable=False)
custom_interval_days = Column(Integer, nullable=True) # Only if frequency is 'custom'
next_due_date = Column(Date, nullable=False) # Changed to Date
last_completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# --- Relationships ---
group = relationship("Group", back_populates="chores")
creator = relationship("User", back_populates="created_chores")
assignments = relationship("ChoreAssignment", back_populates="chore", cascade="all, delete-orphan")
# --- ChoreAssignment Model ---
class ChoreAssignment(Base):
__tablename__ = "chore_assignments"
id = Column(Integer, primary_key=True, index=True)
chore_id = Column(Integer, ForeignKey("chores.id", ondelete="CASCADE"), nullable=False, index=True)
assigned_to_user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
due_date = Column(Date, nullable=False) # Specific due date for this instance, changed to Date
is_complete = Column(Boolean, default=False, nullable=False)
completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# --- Relationships ---
chore = relationship("Chore", back_populates="assignments")
assigned_user = relationship("User", back_populates="assigned_chores")
# === NEW: RecurrencePattern Model ===
class RecurrencePattern(Base):
__tablename__ = "recurrence_patterns"
id = Column(Integer, primary_key=True, index=True)
type = Column(SAEnum(RecurrenceTypeEnum, name="recurrencetypeenum", create_type=True), nullable=False)
interval = Column(Integer, default=1, nullable=False) # e.g., every 1 day, every 2 weeks
days_of_week = Column(String, nullable=True) # For weekly recurrences, e.g., "MON,TUE,FRI"
# day_of_month = Column(Integer, nullable=True) # For monthly on a specific day
# week_of_month = Column(Integer, nullable=True) # For monthly on a specific week (e.g., 2nd week)
# month_of_year = Column(Integer, nullable=True) # For yearly recurrences
end_date = Column(DateTime(timezone=True), nullable=True)
max_occurrences = Column(Integer, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationship back to Expenses that use this pattern (could be one-to-many if patterns are shared)
# However, the current CRUD implies one RecurrencePattern per Expense if recurring.
# If a pattern can be shared, this would be a one-to-many (RecurrencePattern to many Expenses).
# For now, assuming one-to-one as implied by current Expense.recurrence_pattern relationship setup.
expenses = relationship("Expense", back_populates="recurrence_pattern")
# === END: RecurrencePattern Model ===
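
The RecurrencePattern model stores a `type`, an `interval`, and a CSV `days_of_week` string such as "MON,TUE,FRI" (note the Pydantic schema later in this diff instead models `days_of_week` as a list of ints). As a sketch of how these columns might drive `next_occurrence` computation, here is a hypothetical helper that is not part of this diff, covering only the daily and weekly types:

```python
from datetime import datetime, timedelta
from typing import Optional

# datetime.weekday(): Monday == 0 ... Sunday == 6
WEEKDAYS = ["MON", "TUE", "WED", "THU", "FRI", "SAT", "SUN"]

def next_occurrence(last: datetime, rec_type: str, interval: int,
                    days_of_week: Optional[str] = None) -> datetime:
    """Next occurrence after `last` for the simple pattern types above.

    Only 'daily' and 'weekly' are handled here; monthly/yearly (and the
    commented-out day_of_month fields) are elided. `days_of_week` is the
    CSV string from the model, e.g. "MON,TUE,FRI".
    """
    if rec_type == "daily":
        return last + timedelta(days=interval)
    if rec_type == "weekly":
        if days_of_week:
            # Advance day by day until we land on one of the wanted weekdays.
            wanted = {WEEKDAYS.index(d) for d in days_of_week.split(",")}
            candidate = last + timedelta(days=1)
            while candidate.weekday() not in wanted:
                candidate += timedelta(days=1)
            return candidate
        return last + timedelta(weeks=interval)
    raise ValueError(f"unsupported recurrence type: {rec_type}")
```

The weekly-with-days branch ignores `interval` for brevity; a full implementation would also honor `end_date` and `max_occurrences`.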

be/app/models/expense.py Normal file


@ -1,9 +1,11 @@
# app/schemas/auth.py
from pydantic import BaseModel, EmailStr
from app.config import settings
class Token(BaseModel):
access_token: str
refresh_token: str # Added refresh token
token_type: str = settings.TOKEN_TYPE # Use configured token type
# Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel):

be/app/schemas/chore.py Normal file

@ -0,0 +1,111 @@
from datetime import date, datetime
from typing import Optional, List
from pydantic import BaseModel, ConfigDict, field_validator
# Assuming ChoreFrequencyEnum is imported from models
# Adjust the import path if necessary based on your project structure.
# e.g., from app.models import ChoreFrequencyEnum
from ..models import ChoreFrequencyEnum, ChoreTypeEnum, User as UserModel # For UserPublic relation
from .user import UserPublic # For embedding user information
# Chore Schemas
class ChoreBase(BaseModel):
name: str
description: Optional[str] = None
frequency: ChoreFrequencyEnum
custom_interval_days: Optional[int] = None
next_due_date: date # For creation, this will be the initial due date
type: ChoreTypeEnum
@field_validator('custom_interval_days')
@classmethod
def check_custom_interval_days(cls, value, values):
# In Pydantic v2, previously validated fields are available via `values.data`;
# `frequency` is declared before this field, so it is present here.
# Assumes the enum exposes a 'custom' member, per the model's comment.
frequency = values.data.get('frequency')
if frequency == ChoreFrequencyEnum.custom and value is None:
raise ValueError("custom_interval_days is required when frequency is 'custom'")
if frequency != ChoreFrequencyEnum.custom and value is not None:
raise ValueError("custom_interval_days must be None unless frequency is 'custom'")
return value
class ChoreCreate(ChoreBase):
group_id: Optional[int] = None
@field_validator('group_id')
@classmethod
def validate_group_id(cls, v, values):
if values.data.get('type') == ChoreTypeEnum.group and v is None:
raise ValueError("group_id is required for group chores")
if values.data.get('type') == ChoreTypeEnum.personal and v is not None:
raise ValueError("group_id must be None for personal chores")
return v
class ChoreUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
frequency: Optional[ChoreFrequencyEnum] = None
custom_interval_days: Optional[int] = None
next_due_date: Optional[date] = None # Allow updating next_due_date directly if needed
type: Optional[ChoreTypeEnum] = None
group_id: Optional[int] = None
# last_completed_at should generally not be updated directly by user
@field_validator('group_id')
@classmethod
def validate_group_id(cls, v, values):
if values.data.get('type') == ChoreTypeEnum.group and v is None:
raise ValueError("group_id is required for group chores")
if values.data.get('type') == ChoreTypeEnum.personal and v is not None:
raise ValueError("group_id must be None for personal chores")
return v
class ChorePublic(ChoreBase):
id: int
group_id: Optional[int] = None
created_by_id: int
last_completed_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
creator: Optional[UserPublic] = None # Embed creator UserPublic schema
# group: Optional[GroupPublic] = None # Embed GroupPublic schema if needed
model_config = ConfigDict(from_attributes=True)
# Chore Assignment Schemas
class ChoreAssignmentBase(BaseModel):
chore_id: int
assigned_to_user_id: int
due_date: date
class ChoreAssignmentCreate(ChoreAssignmentBase):
pass
class ChoreAssignmentUpdate(BaseModel):
# Only completion status and perhaps due_date can be updated for an assignment
is_complete: Optional[bool] = None
due_date: Optional[date] = None # If rescheduling an existing assignment is allowed
class ChoreAssignmentPublic(ChoreAssignmentBase):
id: int
is_complete: bool
completed_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
# Embed ChorePublic and UserPublic for richer responses
chore: Optional[ChorePublic] = None
assigned_user: Optional[UserPublic] = None
model_config = ConfigDict(from_attributes=True)
# To handle potential circular imports if ChorePublic needs GroupPublic and GroupPublic needs ChorePublic
# We can update forward refs after all models are defined.
# ChorePublic.model_rebuild() # If using Pydantic v2 and forward refs were used with strings
# ChoreAssignmentPublic.model_rebuild()

be/app/schemas/cost.py Normal file

@ -0,0 +1,55 @@
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal
class UserCostShare(BaseModel):
user_id: int
user_identifier: str # Name or email
items_added_value: Decimal = Decimal("0.00") # Total value of items this user added
amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users)
balance: Decimal # items_added_value - amount_due
model_config = ConfigDict(from_attributes=True)
class ListCostSummary(BaseModel):
list_id: int
list_name: str
total_list_cost: Decimal
num_participating_users: int
equal_share_per_user: Decimal
user_balances: List[UserCostShare]
model_config = ConfigDict(from_attributes=True)
class UserBalanceDetail(BaseModel):
user_id: int
user_identifier: str # Name or email
total_paid_for_expenses: Decimal = Decimal("0.00")
total_share_of_expenses: Decimal = Decimal("0.00")
total_settlements_paid: Decimal = Decimal("0.00")
total_settlements_received: Decimal = Decimal("0.00")
net_balance: Decimal = Decimal("0.00") # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)
model_config = ConfigDict(from_attributes=True)
class SuggestedSettlement(BaseModel):
from_user_id: int
from_user_identifier: str # Name or email of payer
to_user_id: int
to_user_identifier: str # Name or email of payee
amount: Decimal
model_config = ConfigDict(from_attributes=True)
class GroupBalanceSummary(BaseModel):
group_id: int
group_name: str
overall_total_expenses: Decimal = Decimal("0.00")
overall_total_settlements: Decimal = Decimal("0.00")
user_balances: List[UserBalanceDetail]
# Optional: Could add a list of suggested settlements to zero out balances
suggested_settlements: Optional[List[SuggestedSettlement]] = None
model_config = ConfigDict(from_attributes=True)
# class SuggestedSettlement(BaseModel):
# from_user_id: int
# to_user_id: int
# amount: Decimal
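
`GroupBalanceSummary` leaves room for `suggested_settlements` that zero out the users' net balances. A common approach (a sketch under the schema's sign convention, not the app's actual implementation) greedily matches the largest debtor with the largest creditor:

```python
from decimal import Decimal

def suggest_settlements(net_balances: dict) -> list:
    """Return (from_user_id, to_user_id, amount) transfers that zero out balances.

    Positive net balance = the user is owed money; negative = the user owes.
    Balances are assumed to sum to zero, as they do for a closed group.
    """
    creditors = sorted([[uid, amt] for uid, amt in net_balances.items() if amt > 0],
                       key=lambda p: p[1], reverse=True)
    debtors = sorted([[uid, -amt] for uid, amt in net_balances.items() if amt < 0],
                     key=lambda p: p[1], reverse=True)
    transfers = []
    i = j = 0
    while i < len(debtors) and j < len(creditors):
        # Pay as much as possible between the current largest debtor/creditor.
        pay = min(debtors[i][1], creditors[j][1])
        transfers.append((debtors[i][0], creditors[j][0], pay))
        debtors[i][1] -= pay
        creditors[j][1] -= pay
        if debtors[i][1] == 0:
            i += 1
        if creditors[j][1] == 0:
            j += 1
    return transfers
```

This produces at most `n - 1` transfers for `n` users; finding the provably minimal number of transfers is NP-hard in general, so the greedy heuristic is the usual choice.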


@ -1,49 +1,180 @@
# app/schemas/expense.py
from pydantic import BaseModel, ConfigDict, validator, Field
from typing import List, Optional
from decimal import Decimal
from datetime import datetime
from app.models import SplitTypeEnum, ExpenseSplitStatusEnum, ExpenseOverallStatusEnum
from app.schemas.user import UserPublic # For user details in responses
from app.schemas.settlement_activity import SettlementActivityPublic # For settlement activities
# --- ExpenseSplit Schemas ---
class ExpenseSplitBase(BaseModel):
user_id: int
owed_amount: Decimal
share_percentage: Optional[Decimal] = None
share_units: Optional[int] = None
class ExpenseSplitCreate(ExpenseSplitBase):
pass # All fields from base are needed for creation
class ExpenseSplitPublic(ExpenseSplitBase):
id: int
expense_id: int
user: Optional[UserPublic] = None # If we want to nest user details
created_at: datetime
updated_at: datetime
status: ExpenseSplitStatusEnum # New field
paid_at: Optional[datetime] = None # New field
settlement_activities: List[SettlementActivityPublic] = [] # New field
model_config = ConfigDict(from_attributes=True)
# --- Expense Schemas ---
class RecurrencePatternBase(BaseModel):
type: str = Field(..., description="Type of recurrence: daily, weekly, monthly, yearly")
interval: int = Field(..., description="Interval of recurrence (e.g., every X days/weeks/months/years)")
days_of_week: Optional[List[int]] = Field(None, description="Days of week for weekly recurrence (0-6, Sunday-Saturday)")
end_date: Optional[datetime] = Field(None, description="Optional end date for the recurrence")
max_occurrences: Optional[int] = Field(None, description="Optional maximum number of occurrences")
model_config = ConfigDict(from_attributes=True)
class RecurrencePatternCreate(RecurrencePatternBase):
pass
class RecurrencePatternUpdate(RecurrencePatternBase):
pass
class RecurrencePatternInDB(RecurrencePatternBase):
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ExpenseBase(BaseModel):
description: str
total_amount: Decimal
currency: Optional[str] = "USD"
expense_date: Optional[datetime] = None
split_type: SplitTypeEnum
list_id: Optional[int] = None
group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa
item_id: Optional[int] = None
paid_by_user_id: int
is_recurring: bool = Field(False, description="Whether this is a recurring expense")
recurrence_pattern: Optional[RecurrencePatternCreate] = Field(None, description="Recurrence pattern for recurring expenses")
class ExpenseCreate(ExpenseBase):
# For EQUAL split, splits are generated. For others, they might be provided.
# This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES,
# then 'splits_in' should be provided.
splits_in: Optional[List[ExpenseSplitCreate]] = None
@validator('total_amount')
def total_amount_must_be_positive(cls, v):
if v <= Decimal('0'):
raise ValueError('Total amount must be positive')
return v
# Basic validation: if list_id is None, group_id must be provided.
# More complex cross-field validation might be needed.
@validator('group_id', always=True)
def check_list_or_group_id(cls, v, values):
if values.get('list_id') is None and v is None:
raise ValueError('Either list_id or group_id must be provided for an expense')
return v
@validator('recurrence_pattern')
def validate_recurrence_pattern(cls, v, values):
if values.get('is_recurring') and not v:
raise ValueError('Recurrence pattern is required for recurring expenses')
if not values.get('is_recurring') and v:
raise ValueError('Recurrence pattern should not be provided for non-recurring expenses')
return v
class ExpenseUpdate(BaseModel):
description: Optional[str] = None
total_amount: Optional[Decimal] = None
currency: Optional[str] = None
expense_date: Optional[datetime] = None
split_type: Optional[SplitTypeEnum] = None
list_id: Optional[int] = None
group_id: Optional[int] = None
item_id: Optional[int] = None
# paid_by_user_id is usually not updatable directly to maintain integrity.
# Updating splits would be a more complex operation, potentially a separate endpoint or careful logic.
version: int # For optimistic locking
is_recurring: Optional[bool] = None
recurrence_pattern: Optional[RecurrencePatternUpdate] = None
next_occurrence: Optional[datetime] = None
class ExpensePublic(ExpenseBase):
id: int
created_at: datetime
updated_at: datetime
version: int
created_by_user_id: int
splits: List[ExpenseSplitPublic] = []
paid_by_user: Optional[UserPublic] = None # If nesting user details
overall_settlement_status: ExpenseOverallStatusEnum # New field
# list: Optional[ListPublic] # If nesting list details
# group: Optional[GroupPublic] # If nesting group details
# item: Optional[ItemPublic] # If nesting item details
is_recurring: bool
next_occurrence: Optional[datetime]
last_occurrence: Optional[datetime]
recurrence_pattern: Optional[RecurrencePatternInDB]
parent_expense_id: Optional[int]
generated_expenses: List['ExpensePublic'] = []
model_config = ConfigDict(from_attributes=True)
# --- Settlement Schemas ---
class SettlementBase(BaseModel):
group_id: int
paid_by_user_id: int
paid_to_user_id: int
amount: Decimal
settlement_date: Optional[datetime] = None
description: Optional[str] = None
class SettlementCreate(SettlementBase):
@validator('amount')
def amount_must_be_positive(cls, v):
if v <= Decimal('0'):
raise ValueError('Settlement amount must be positive')
return v
@validator('paid_to_user_id')
def payer_and_payee_must_be_different(cls, v, values):
if 'paid_by_user_id' in values and v == values['paid_by_user_id']:
raise ValueError('Payer and payee cannot be the same user')
return v
class SettlementUpdate(BaseModel):
amount: Optional[Decimal] = None
settlement_date: Optional[datetime] = None
description: Optional[str] = None
# group_id, paid_by_user_id, paid_to_user_id are typically not updatable.
version: int # For optimistic locking
class SettlementPublic(SettlementBase):
id: int
created_at: datetime
updated_at: datetime
version: int
created_by_user_id: int
# payer: Optional[UserPublic] # If we want to include payer details
# payee: Optional[UserPublic] # If we want to include payee details
# group: Optional[GroupPublic] # If we want to include group details
model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed
# from app.schemas.user import UserPublic
# from app.schemas.list import ListPublic
# from app.schemas.group import GroupPublic
# from app.schemas.item import ItemPublic
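
The comment in `ExpenseCreate` says the CRUD layer generates splits when `split_type` is EQUAL. A cent-accurate equal split has to round down and hand the leftover cents to someone; a minimal sketch of that rule (hypothetical helper, the real CRUD is not shown in this diff):

```python
from decimal import Decimal, ROUND_DOWN

def equal_splits(total: Decimal, user_ids: list) -> dict:
    """Split `total` evenly, assigning leftover cents to the first users."""
    n = len(user_ids)
    base = (total / n).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    # Whatever rounding left over is distributed one cent at a time.
    remainder_cents = int((total - base * n) * 100)
    splits = {}
    for idx, uid in enumerate(user_ids):
        extra = Decimal("0.01") if idx < remainder_cents else Decimal("0")
        splits[uid] = base + extra
    return splits
```

The key invariant is that the splits always sum exactly to `total`, which naive `total / n` per user does not guarantee.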


@ -1,5 +1,5 @@
# app/schemas/group.py
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, computed_field
from datetime import datetime
from typing import Optional, List
@ -15,7 +15,25 @@ class GroupPublic(BaseModel):
name: str
created_by_id: int
created_at: datetime
member_associations: Optional[List["UserGroupPublic"]] = None
@computed_field
@property
def members(self) -> Optional[List[UserPublic]]:
if not self.member_associations:
return None
return [assoc.user for assoc in self.member_associations if assoc.user]
model_config = ConfigDict(from_attributes=True)
# Properties for UserGroup association
class UserGroupPublic(BaseModel):
id: int
user_id: int
group_id: int
role: str
joined_at: datetime
user: Optional[UserPublic] = None
model_config = ConfigDict(from_attributes=True)
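
The `GroupPublic` change derives `members` from the `member_associations` via a `computed_field`, so the API shape stays stable while the underlying relation changes. A self-contained sketch of that pattern with stand-in schemas (not the app's actual classes):

```python
from typing import List, Optional
from pydantic import BaseModel, computed_field

class UserPublic(BaseModel):        # stand-in for the app's UserPublic
    id: int
    name: str

class UserGroupPublic(BaseModel):   # stand-in association schema
    user_id: int
    user: Optional[UserPublic] = None

class GroupPublic(BaseModel):
    id: int
    member_associations: Optional[List[UserGroupPublic]] = None

    @computed_field            # included in model_dump() like a regular field
    @property
    def members(self) -> Optional[List[UserPublic]]:
        if not self.member_associations:
            return None
        # Skip associations whose user relation was not loaded.
        return [a.user for a in self.member_associations if a.user]

g = GroupPublic(id=1, member_associations=[
    UserGroupPublic(user_id=2, user=UserPublic(id=2, name="Ada")),
    UserGroupPublic(user_id=3, user=None),  # association without loaded user
])
```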


@ -1,9 +1,10 @@
# app/schemas/health.py
from pydantic import BaseModel
from app.config import settings
class HealthStatus(BaseModel):
"""
Response model for the health check endpoint.
"""
status: str = settings.HEALTH_STATUS_OK # Use configured default value
database: str


@ -16,6 +16,7 @@ class ItemPublic(BaseModel):
completed_by_id: Optional[int] = None
created_at: datetime
updated_at: datetime
version: int
model_config = ConfigDict(from_attributes=True)
# Properties to receive via API on creation
@ -31,4 +32,6 @@ class ItemUpdate(BaseModel):
quantity: Optional[str] = None
is_complete: Optional[bool] = None
price: Optional[Decimal] = None # Price added here for update
position: Optional[int] = None # For reordering
version: int
# completed_by_id will be set internally if is_complete is true


@ -16,6 +16,7 @@ class ListUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
is_complete: Optional[bool] = None
version: int # Client must provide the version for updates
# Potentially add group_id update later if needed
# Base properties returned by API (common fields)
@ -28,6 +29,7 @@ class ListBase(BaseModel):
is_complete: bool
created_at: datetime
updated_at: datetime
version: int # Include version in responses
model_config = ConfigDict(from_attributes=True)
@ -40,6 +42,9 @@ class ListDetail(ListBase):
items: List[ItemPublic] = [] # Include list of items
class ListStatus(BaseModel):
updated_at: datetime
item_count: int
latest_item_updated_at: Optional[datetime] = None
class ListStatusWithId(ListStatus):
id: int


@ -0,0 +1,43 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing import Optional, List
from decimal import Decimal
from datetime import datetime
from app.schemas.user import UserPublic # Assuming UserPublic is defined here
class SettlementActivityBase(BaseModel):
expense_split_id: int
paid_by_user_id: int
amount_paid: Decimal
paid_at: Optional[datetime] = None
class SettlementActivityCreate(SettlementActivityBase):
@field_validator('amount_paid')
@classmethod
def amount_must_be_positive(cls, v: Decimal) -> Decimal:
if v <= Decimal("0"):
raise ValueError("Amount paid must be a positive value.")
return v
class SettlementActivityPublic(SettlementActivityBase):
id: int
created_by_user_id: int # User who recorded this activity
created_at: datetime
updated_at: datetime
payer: Optional[UserPublic] = None # User who made this part of the payment
creator: Optional[UserPublic] = None # User who recorded this activity
model_config = ConfigDict(from_attributes=True)
# Schema for updating a settlement activity (if needed in the future)
# class SettlementActivityUpdate(BaseModel):
# amount_paid: Optional[Decimal] = None
# paid_at: Optional[datetime] = None
# @field_validator('amount_paid')
# @classmethod
# def amount_must_be_positive_if_provided(cls, v: Optional[Decimal]) -> Optional[Decimal]:
# if v is not None and v <= Decimal("0"):
# raise ValueError("Amount paid must be a positive value.")
# return v
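
Each `SettlementActivity` records a partial payment against one `ExpenseSplit`, and the split's status (`unpaid` appears as the server default in the models; a partially-paid and a paid state are implied) follows from the amounts recorded so far. A minimal sketch of that derivation, with the status names assumed rather than taken from the enum definition:

```python
from decimal import Decimal

def split_status(owed: Decimal, activity_amounts: list) -> str:
    """Derive a split's settlement status from its recorded activity amounts."""
    paid = sum(activity_amounts, Decimal("0"))
    if paid <= 0:
        return "unpaid"
    if paid < owed:
        return "partially_paid"
    return "paid"
```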


@ -12,14 +12,27 @@ class UserBase(BaseModel):
class UserCreate(UserBase):
password: str
def create_update_dict(self):
return {
"email": self.email,
"name": self.name,
"password": self.password,
"is_active": True,
"is_superuser": False,
"is_verified": False
}
# Properties to receive via API on update
class UserUpdate(UserBase):
password: Optional[str] = None
is_active: Optional[bool] = None
is_superuser: Optional[bool] = None
is_verified: Optional[bool] = None
# Properties stored in DB
class UserInDBBase(UserBase):
id: int
password_hash: str
created_at: datetime
model_config = ConfigDict(from_attributes=True) # Use orm_mode in Pydantic v1

be/entrypoint.sh Executable file

@ -0,0 +1,10 @@
#!/bin/sh
set -e
# Run database migrations
echo "Running database migrations..."
alembic upgrade head
# Execute the command passed as arguments to this script
echo "Starting application..."
exec "$@"

be/pytest.ini Normal file

@ -0,0 +1,5 @@
[pytest]
pythonpath = .
testpaths = tests
python_files = test_*.py
asyncio_mode = auto


@ -10,3 +10,18 @@ passlib[bcrypt]>=1.7.4
python-jose[cryptography]>=3.3.0
pydantic[email]
google-generativeai>=0.5.0
sentry-sdk[fastapi]>=1.39.0
python-multipart>=0.0.6 # Required for form data handling
fastapi-users[sqlalchemy]>=12.1.2
email-validator>=2.0.0
fastapi-users[oauth]>=12.1.2
authlib>=1.3.0
itsdangerous>=2.1.2
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-cov>=4.1.0
httpx>=0.24.0 # For async HTTP testing
aiosqlite>=0.19.0 # For async SQLite support in tests
# Scheduler
APScheduler==3.10.4


@ -0,0 +1,649 @@
import pytest
from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Callable, Dict, Any
from unittest.mock import patch, MagicMock
from app.models import User as UserModel, Group as GroupModel, List as ListModel
from app.schemas.expense import ExpenseCreate, ExpensePublic, ExpenseUpdate
# from app.config import settings # Comment out the original import
# Helper to create a URL for an endpoint
# API_V1_STR = settings.API_V1_STR # Comment out the original assignment
@pytest.fixture(scope="module")
def mock_settings_financials():
    mock_settings = MagicMock()
    mock_settings.API_V1_STR = "/api/v1"
    return mock_settings


# Patch the settings for every test in this module
@pytest.fixture(autouse=True)
def patch_settings_financials(mock_settings_financials):
    with patch("app.config.settings", mock_settings_financials):
        yield


def expense_url(endpoint: str = "") -> str:
    # Import inside the function so the patched settings object is seen at call time
    from app.config import settings
    return f"{settings.API_V1_STR}/financials/expenses{endpoint}"


def settlement_url(endpoint: str = "") -> str:
    from app.config import settings  # same deferred-import trick as expense_url
    return f"{settings.API_V1_STR}/financials/settlements{endpoint}"
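The deferred `from app.config import settings` import inside the URL helpers is what makes the patch visible; a module-level import taken before patching would keep the original object. A self-contained sketch of the mechanism, using a throwaway `fake_config` module in place of `app.config`:

```python
import sys
import types
from unittest.mock import MagicMock, patch

# Throwaway stand-in for app.config: a real module object in sys.modules.
fake_config = types.ModuleType("fake_config")
fake_config.settings = types.SimpleNamespace(API_V1_STR="/real")
sys.modules["fake_config"] = fake_config

mock_settings = MagicMock()
mock_settings.API_V1_STR = "/api/v1"

def expense_url(endpoint: str = "") -> str:
    from fake_config import settings  # resolved at call time, not import time
    return f"{settings.API_V1_STR}/financials/expenses{endpoint}"

# patch() swaps the attribute on the module object, so lookups made while
# the patch is active see the mock; afterwards the original is restored.
with patch("fake_config.settings", mock_settings):
    patched = expense_url("/1")
unpatched = expense_url("/1")
```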
@pytest.mark.asyncio
async def test_create_new_expense_success_list_context(
client: AsyncClient,
db_session: AsyncSession,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_user_is_member: ListModel,
) -> None:
expense_data = ExpenseCreate(
description="Test Expense for List",
amount=100.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=test_list_user_is_member.id,
group_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_201_CREATED
content = response.json()
assert content["description"] == expense_data.description
assert content["amount"] == expense_data.amount
assert content["currency"] == expense_data.currency
assert content["paid_by_user_id"] == test_user.id
assert content["list_id"] == test_list_user_is_member.id
if test_list_user_is_member.group_id:
assert content["group_id"] == test_list_user_is_member.group_id
else:
assert content["group_id"] is None
assert "id" in content
assert "created_at" in content
assert "updated_at" in content
assert "version" in content
assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_success_group_context(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
) -> None:
expense_data = ExpenseCreate(
description="Test Expense for Group",
amount=50.00,
currency="EUR",
paid_by_user_id=test_user.id,
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_201_CREATED
content = response.json()
assert content["description"] == expense_data.description
assert content["paid_by_user_id"] == test_user.id
assert content["group_id"] == test_group_user_is_member.id
assert content["list_id"] is None
assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_fail_no_list_or_group(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
) -> None:
expense_data = ExpenseCreate(
description="Test Invalid Expense",
amount=10.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=None,
group_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
content = response.json()
assert "Expense must be linked to a list_id or group_id" in content["detail"]
@pytest.mark.asyncio
async def test_create_new_expense_fail_paid_by_other_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
another_user_in_group: UserModel,
) -> None:
expense_data = ExpenseCreate(
description="Expense paid by other",
amount=75.00,
currency="GBP",
paid_by_user_id=another_user_in_group.id,
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Only group owners can create expenses paid by others" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
created_expense: ExpensePublic,
) -> None:
response = await client.get(
expense_url(f"/{created_expense.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == created_expense.id
assert content["description"] == created_expense.description
@pytest.mark.asyncio
async def test_get_expense_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.get(
expense_url("/999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_forbidden_personal_expense_other_user(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
personal_expense_of_another_user: ExpensePublic,
) -> None:
response = await client.get(
expense_url(f"/{personal_expense_of_another_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this expense" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_forbidden_not_member_of_list_or_group(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
another_user: UserModel,
expense_in_inaccessible_list_or_group: ExpensePublic,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_inaccessible_list_or_group.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this expense" in content["detail"]
@pytest.mark.asyncio
async def test_get_expense_success_in_list_user_has_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_in_accessible_list: ExpensePublic,
test_list_user_is_member: ListModel,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_accessible_list.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == expense_in_accessible_list.id
assert content["list_id"] == test_list_user_is_member.id
@pytest.mark.asyncio
async def test_get_expense_success_in_group_user_has_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_in_accessible_group: ExpensePublic,
test_group_user_is_member: GroupModel,
) -> None:
response = await client.get(
expense_url(f"/{expense_in_accessible_group.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == expense_in_accessible_group.id
assert content["group_id"] == test_group_user_is_member.id
@pytest.mark.asyncio
async def test_list_list_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_user_is_member: ListModel,
) -> None:
response = await client.get(
expense_url(f"?list_id={test_list_user_is_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
for expense in content:
assert expense["list_id"] == test_list_user_is_member.id
@pytest.mark.asyncio
async def test_list_list_expenses_list_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.get(
expense_url("?list_id=999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "List not found" in content["detail"]
@pytest.mark.asyncio
async def test_list_list_expenses_no_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_list_user_not_member: ListModel,
) -> None:
response = await client.get(
expense_url(f"?list_id={test_list_user_not_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this list" in content["detail"]
@pytest.mark.asyncio
async def test_list_list_expenses_empty(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_list_user_is_member_no_expenses: ListModel,
) -> None:
response = await client.get(
expense_url(f"?list_id={test_list_user_is_member_no_expenses.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 0
@pytest.mark.asyncio
async def test_list_list_expenses_pagination(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_with_multiple_expenses: ListModel,
created_expenses_for_list: list[ExpensePublic],
) -> None:
# Test first page
response = await client.get(
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=0&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_list[0].id
assert content[1]["id"] == created_expenses_for_list[1].id
# Test second page
response = await client.get(
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=2&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_list[2].id
assert content[1]["id"] == created_expenses_for_list[3].id
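The skip/limit contract these pagination tests exercise can be sketched as a plain slice over a stable ordering (a sketch of the semantics, not the endpoint's actual query code):

```python
# skip/limit pagination as the tests assume it: a stable ordering
# sliced as items[skip : skip + limit].
def paginate(items: list, skip: int = 0, limit: int = 100) -> list:
    return items[skip:skip + limit]

expenses = [{"id": i} for i in range(1, 6)]
page1 = paginate(expenses, skip=0, limit=2)  # first two expenses
page2 = paginate(expenses, skip=2, limit=2)  # next two expenses
```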
@pytest.mark.asyncio
async def test_list_group_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel,
) -> None:
response = await client.get(
expense_url(f"?group_id={test_group_user_is_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
for expense in content:
assert expense["group_id"] == test_group_user_is_member.id
@pytest.mark.asyncio
async def test_list_group_expenses_group_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.get(
expense_url("?group_id=999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Group not found" in content["detail"]
@pytest.mark.asyncio
async def test_list_group_expenses_no_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_group_user_not_member: GroupModel,
) -> None:
response = await client.get(
expense_url(f"?group_id={test_group_user_not_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to access this group" in content["detail"]
@pytest.mark.asyncio
async def test_list_group_expenses_empty(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_group_user_is_member_no_expenses: GroupModel,
) -> None:
response = await client.get(
expense_url(f"?group_id={test_group_user_is_member_no_expenses.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 0
@pytest.mark.asyncio
async def test_list_group_expenses_pagination(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_with_multiple_expenses: GroupModel,
created_expenses_for_group: list[ExpensePublic],
) -> None:
# Test first page
response = await client.get(
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=0&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_group[0].id
assert content[1]["id"] == created_expenses_for_group[1].id
# Test second page
response = await client.get(
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=2&limit=2"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 2
assert content[0]["id"] == created_expenses_for_group[2].id
assert content[1]["id"] == created_expenses_for_group[3].id
@pytest.mark.asyncio
async def test_update_expense_success_payer_updates_details(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user: ExpensePublic,
) -> None:
update_data = ExpenseUpdate(
description="Updated expense description",
version=expense_paid_by_test_user.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["description"] == update_data.description
assert content["version"] == expense_paid_by_test_user.version + 1
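The `version` assertions reflect optimistic locking: `ExpenseUpdate` echoes the version the client last read, and a successful update bumps it by one. A hedged sketch of the rule (not the service's actual implementation):

```python
# Optimistic locking sketch: reject updates carrying a stale version,
# otherwise merge the changes and increment the version counter.
class StaleVersionError(Exception):
    pass

def apply_update(expense: dict, update: dict) -> dict:
    if update["version"] != expense["version"]:
        raise StaleVersionError(f"expected version {expense['version']}")
    merged = {**expense, **{k: v for k, v in update.items() if k != "version"}}
    merged["version"] = expense["version"] + 1
    return merged

expense = {"id": 1, "description": "old", "version": 1}
updated = apply_update(expense, {"description": "new", "version": 1})
```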
@pytest.mark.asyncio
async def test_update_expense_success_group_owner_updates_others_expense(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
another_user_in_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
description="Updated by group owner",
version=expense_paid_by_another_in_group_where_test_user_is_owner.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
headers=group_owner_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["description"] == update_data.description
assert content["version"] == expense_paid_by_another_in_group_where_test_user_is_owner.version + 1
@pytest.mark.asyncio
async def test_update_expense_fail_not_payer_nor_group_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
another_user_in_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
description="Attempted update by non-owner",
version=expense_paid_by_another_in_group_where_test_user_is_member.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to update this expense" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_fail_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
update_data = ExpenseUpdate(
description="Update attempt on non-existent expense",
version=1,
)
response = await client.put(
expense_url("/999"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_fail_change_paid_by_user_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user_in_group: ExpensePublic,
another_user_in_same_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
paid_by_user_id=another_user_in_same_group.id,
version=expense_paid_by_test_user_in_group.version,
)
response = await client.put(
expense_url(f"/{expense_paid_by_test_user_in_group.id}"),
headers=normal_user_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Only group owners can change the payer of an expense" in content["detail"]
@pytest.mark.asyncio
async def test_update_expense_success_owner_changes_paid_by_user(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_in_group_owner_group: ExpensePublic,
another_user_in_same_group: UserModel,
) -> None:
update_data = ExpenseUpdate(
paid_by_user_id=another_user_in_same_group.id,
version=expense_in_group_owner_group.version,
)
response = await client.put(
expense_url(f"/{expense_in_group_owner_group.id}"),
headers=group_owner_token_headers,
json=update_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["paid_by_user_id"] == another_user_in_same_group.id
assert content["version"] == expense_in_group_owner_group.version + 1
@pytest.mark.asyncio
async def test_delete_expense_success_payer(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_test_user: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@pytest.mark.asyncio
async def test_delete_expense_success_group_owner(
client: AsyncClient,
group_owner_token_headers: Dict[str, str],
group_owner: UserModel,
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
headers=group_owner_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@pytest.mark.asyncio
async def test_delete_expense_fail_not_payer_nor_group_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
) -> None:
response = await client.delete(
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "You do not have permission to delete this expense" in content["detail"]
@pytest.mark.asyncio
async def test_delete_expense_fail_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
response = await client.delete(
expense_url("/999"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "Expense not found" in content["detail"]
@pytest.mark.asyncio
async def test_delete_expense_idempotency(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
expense_paid_by_test_user: ExpensePublic,
) -> None:
# First delete
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
# Second delete should also succeed
response = await client.delete(
expense_url(f"/{expense_paid_by_test_user.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_204_NO_CONTENT
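The idempotency contract this test checks means deleting an already-absent expense still reports success, so client retries are safe. Sketched against a toy in-memory store (not the real persistence layer):

```python
# Idempotent DELETE sketch: removing a missing id is not an error,
# and both calls report 204 No Content.
def delete_expense(store: dict, expense_id: int) -> int:
    store.pop(expense_id, None)  # absent key is silently ignored
    return 204

store = {1: {"description": "to delete"}}
first = delete_expense(store, 1)
second = delete_expense(store, 1)  # retry after the row is gone
```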
# GET /settlements/{settlement_id}
# POST /settlements
# GET /groups/{group_id}/settlements
# PUT /settlements/{settlement_id}
# DELETE /settlements/{settlement_id}

File diff suppressed because it is too large


@@ -0,0 +1,411 @@
import pytest
import httpx
from typing import List, Dict, Any
from decimal import Decimal
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession  # used by db_session-typed test parameters below
from app.models import (
User,
Group,
Expense,
ExpenseSplit,
SettlementActivity,
UserRoleEnum,
SplitTypeEnum,
ExpenseOverallStatusEnum,
ExpenseSplitStatusEnum
)
from app.schemas.settlement_activity import SettlementActivityPublic, SettlementActivityCreate
from app.schemas.expense import ExpensePublic, ExpenseSplitPublic
from app.core.config import settings # For API prefix
# db_session, event_loop, and client are assumed to be provided by conftest.py (or similar setup).
# Basic user/auth fixtures are defined below in case conftest.py does not provide them.
@pytest.fixture
async def test_user1_api(db_session, client: httpx.AsyncClient) -> Dict[str, Any]:
    user = User(email="api.user1@example.com", name="API User 1", hashed_password="password1")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    # Simulate token login: a real setup would call the login endpoint.
    # Here the user and fake headers are returned directly; the tests assume the
    # current_active_user dependency resolves to this user when these headers are
    # used (or that the dependency is overridden/mocked).
    return {"user": user, "headers": {"Authorization": f"Bearer token-for-{user.id}"}}
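The fixture comment above assumes the `current_active_user` dependency resolves to the fixture user. A minimal self-contained sketch of the FastAPI-style override mechanism it alludes to (`FakeApp` and `resolve` stand in for the real app and dependency resolution):

```python
# FastAPI keeps a dependency_overrides dict on the app; tests register a
# replacement callable there and the framework consults it before running
# the real dependency. FakeApp models just that lookup.
class FakeApp:
    def __init__(self):
        self.dependency_overrides: dict = {}

def current_active_user():
    raise RuntimeError("real token validation would run here")

def resolve(app: FakeApp, dependency):
    # Overrides win over the real dependency, as in FastAPI.
    return app.dependency_overrides.get(dependency, dependency)()

app = FakeApp()
fixture_user = {"id": 1, "email": "api.user1@example.com"}
app.dependency_overrides[current_active_user] = lambda: fixture_user
resolved = resolve(app, current_active_user)
```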
@pytest.fixture
async def test_user2_api(db_session, client: httpx.AsyncClient) -> Dict[str, Any]:
    user = User(email="api.user2@example.com", name="API User 2", hashed_password="password2")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return {"user": user, "headers": {"Authorization": f"Bearer token-for-{user.id}"}}
@pytest.fixture
async def test_group_user1_owner_api(db_session, test_user1_api: Dict[str, Any]) -> Group:
user1 = test_user1_api["user"]
group = Group(name="API Test Group", created_by_id=user1.id)
db_session.add(group)
await db_session.flush() # Get group.id
# Add user1 as owner
from app.models import UserGroup
user_group_assoc = UserGroup(user_id=user1.id, group_id=group.id, role=UserRoleEnum.owner)
db_session.add(user_group_assoc)
await db_session.commit()
await db_session.refresh(group)
return group
@pytest.fixture
async def test_expense_in_group_api(db_session, test_user1_api: Dict[str, Any], test_group_user1_owner_api: Group) -> Expense:
user1 = test_user1_api["user"]
expense = Expense(
description="Group API Expense",
total_amount=Decimal("50.00"),
currency="USD",
group_id=test_group_user1_owner_api.id,
paid_by_user_id=user1.id,
created_by_user_id=user1.id,
split_type=SplitTypeEnum.EQUAL,
overall_settlement_status=ExpenseOverallStatusEnum.unpaid
)
db_session.add(expense)
await db_session.commit()
await db_session.refresh(expense)
return expense
@pytest.fixture
async def test_expense_split_for_user2_api(db_session, test_expense_in_group_api: Expense, test_user1_api: Dict[str, Any], test_user2_api: Dict[str, Any]) -> ExpenseSplit:
user1 = test_user1_api["user"]
user2 = test_user2_api["user"]
# Split for User 1 (payer)
split1 = ExpenseSplit(
expense_id=test_expense_in_group_api.id,
user_id=user1.id,
owed_amount=Decimal("25.00"),
status=ExpenseSplitStatusEnum.unpaid
)
# Split for User 2 (owes)
split2 = ExpenseSplit(
expense_id=test_expense_in_group_api.id,
user_id=user2.id,
owed_amount=Decimal("25.00"),
status=ExpenseSplitStatusEnum.unpaid
)
db_session.add_all([split1, split2])
# Add user2 to the group as a member for permission checks
from app.models import UserGroup
user_group_assoc = UserGroup(user_id=user2.id, group_id=test_expense_in_group_api.group_id, role=UserRoleEnum.member)
db_session.add(user_group_assoc)
await db_session.commit()
await db_session.refresh(split1)
await db_session.refresh(split2)
return split2 # Return the split that user2 owes
# --- Tests for POST /expense_splits/{expense_split_id}/settle ---
@pytest.mark.asyncio
async def test_settle_expense_split_by_self_success(
client: httpx.AsyncClient,
test_user2_api: Dict[str, Any], # User2 will settle their own split
test_expense_split_for_user2_api: ExpenseSplit,
db_session: AsyncSession # To verify db changes
):
user2 = test_user2_api["user"]
user2_headers = test_user2_api["headers"]
split_to_settle = test_expense_split_for_user2_api
payload = SettlementActivityCreate(
expense_split_id=split_to_settle.id,
paid_by_user_id=user2.id, # User2 is paying
amount_paid=split_to_settle.owed_amount # Full payment
)
response = await client.post(
f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
json=payload.model_dump(mode='json'), # Pydantic v2
headers=user2_headers
)
assert response.status_code == 201
activity_data = response.json()
assert activity_data["amount_paid"] == str(split_to_settle.owed_amount) # Compare as string due to JSON
assert activity_data["paid_by_user_id"] == user2.id
assert activity_data["expense_split_id"] == split_to_settle.id
assert "id" in activity_data
# Verify DB state
await db_session.refresh(split_to_settle)
assert split_to_settle.status == ExpenseSplitStatusEnum.paid
assert split_to_settle.paid_at is not None
# Verify parent expense status (this requires other splits to be paid too)
# For a focused test, we might need to ensure the other split (user1's share) is also paid.
# Or, accept 'partially_paid' if only this one is paid.
parent_expense_id = split_to_settle.expense_id
parent_expense = await db_session.get(Expense, parent_expense_id)
await db_session.refresh(parent_expense, attribute_names=['splits']) # Load splits to check status
all_splits_paid = all(s.status == ExpenseSplitStatusEnum.paid for s in parent_expense.splits)
if all_splits_paid:
assert parent_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid
else:
assert parent_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid
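The branch above encodes a status rollup rule: an expense is paid only when every split is paid, partially paid when some but not all are, and unpaid otherwise. Sketched standalone (string values assumed to mirror the enum members):

```python
# Rollup of per-split statuses into the expense's overall status,
# matching the assertions in the test above.
def overall_status(split_statuses: list) -> str:
    if split_statuses and all(s == "paid" for s in split_statuses):
        return "paid"
    if any(s == "paid" for s in split_statuses):
        return "partially_paid"
    return "unpaid"

mixed = overall_status(["paid", "unpaid"])
done = overall_status(["paid", "paid"])
```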
@pytest.mark.asyncio
async def test_settle_expense_split_by_group_owner_success(
client: httpx.AsyncClient,
test_user1_api: Dict[str, Any], # User1 is group owner
test_user2_api: Dict[str, Any], # User2 owes the split
test_expense_split_for_user2_api: ExpenseSplit,
db_session: AsyncSession
):
user1_headers = test_user1_api["headers"]
user_who_owes = test_user2_api["user"]
split_to_settle = test_expense_split_for_user2_api
payload = SettlementActivityCreate(
expense_split_id=split_to_settle.id,
paid_by_user_id=user_who_owes.id, # User1 (owner) records that User2 has paid
amount_paid=split_to_settle.owed_amount
)
response = await client.post(
f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
json=payload.model_dump(mode='json'),
headers=user1_headers # Authenticated as group owner
)
assert response.status_code == 201
activity_data = response.json()
assert activity_data["paid_by_user_id"] == user_who_owes.id
assert activity_data["created_by_user_id"] == test_user1_api["user"].id # Activity created by owner
await db_session.refresh(split_to_settle)
assert split_to_settle.status == ExpenseSplitStatusEnum.paid
@pytest.mark.asyncio
async def test_settle_expense_split_path_body_id_mismatch(
client: httpx.AsyncClient, test_user2_api: Dict[str, Any], test_expense_split_for_user2_api: ExpenseSplit
):
user2_headers = test_user2_api["headers"]
split_to_settle = test_expense_split_for_user2_api
payload = SettlementActivityCreate(
expense_split_id=split_to_settle.id + 1, # Mismatch
paid_by_user_id=test_user2_api["user"].id,
amount_paid=split_to_settle.owed_amount
)
response = await client.post(
f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
json=payload.model_dump(mode='json'), headers=user2_headers
)
assert response.status_code == 400 # As per API endpoint logic
@pytest.mark.asyncio
async def test_settle_expense_split_not_found(
client: httpx.AsyncClient, test_user2_api: Dict[str, Any]
):
user2_headers = test_user2_api["headers"]
payload = SettlementActivityCreate(expense_split_id=9999, paid_by_user_id=test_user2_api["user"].id, amount_paid=Decimal("10.00"))
response = await client.post(
f"{settings.API_V1_STR}/expense_splits/9999/settle",
json=payload.model_dump(mode='json'), headers=user2_headers
)
assert response.status_code == 404 # ItemNotFoundError
@pytest.mark.asyncio
async def test_settle_expense_split_insufficient_permissions(
client: httpx.AsyncClient,
test_user1_api: Dict[str, Any], # User1 is not group owner for this setup, nor involved in split
test_user2_api: Dict[str, Any],
test_expense_split_for_user2_api: ExpenseSplit, # User2 owes this
db_session: AsyncSession
):
# Create a new user (user3) who is not involved and not an owner
user3 = User(email="api.user3@example.com", name="API User 3", hashed_password="password3")
db_session.add(user3)
await db_session.commit()
user3_headers = {"Authorization": f"Bearer token-for-{user3.id}"}
split_owner = test_user2_api["user"] # User2 owns the split
split_to_settle = test_expense_split_for_user2_api
payload = SettlementActivityCreate(
expense_split_id=split_to_settle.id,
paid_by_user_id=split_owner.id, # User2 is paying
amount_paid=split_to_settle.owed_amount
)
# User3 (neither payer nor group owner) tries to record User2's payment
response = await client.post(
f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
json=payload.model_dump(mode='json'),
headers=user3_headers # Authenticated as User3
)
assert response.status_code == 403
# --- Tests for GET /expense_splits/{expense_split_id}/settlement_activities ---
@pytest.mark.asyncio
async def test_get_settlement_activities_success(
client: httpx.AsyncClient,
test_user1_api: Dict[str, Any], # Group owner / expense creator
test_user2_api: Dict[str, Any], # User who owes and pays
test_expense_split_for_user2_api: ExpenseSplit,
db_session: AsyncSession
):
user1_headers = test_user1_api["headers"]
user2 = test_user2_api["user"]
split = test_expense_split_for_user2_api
# Create a settlement activity first
activity_payload = SettlementActivityCreate(expense_split_id=split.id, paid_by_user_id=user2.id, amount_paid=Decimal("10.00"))
await client.post(
f"{settings.API_V1_STR}/expense_splits/{split.id}/settle",
json=activity_payload.model_dump(mode='json'), headers=test_user2_api["headers"] # User2 settles
)
# User1 (group owner) fetches activities
response = await client.get(
f"{settings.API_V1_STR}/expense_splits/{split.id}/settlement_activities",
headers=user1_headers
)
assert response.status_code == 200
activities_data = response.json()
assert isinstance(activities_data, list)
assert len(activities_data) == 1
assert activities_data[0]["amount_paid"] == "10.00"
assert activities_data[0]["paid_by_user_id"] == user2.id
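The comparison against the string `"10.00"` reflects how Decimal amounts cross the wire: serialized as JSON strings so no precision or formatting is lost. A small illustration:

```python
from decimal import Decimal
import json

# Decimal amounts are sent as JSON strings; the textual form survives
# the round trip exactly, which a float encoding would not guarantee.
amount = Decimal("10.00")
wire = json.dumps({"amount_paid": str(amount)})
decoded = json.loads(wire)
```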
@pytest.mark.asyncio
async def test_get_settlement_activities_split_not_found(
client: httpx.AsyncClient, test_user1_api: Dict[str, Any]
):
user1_headers = test_user1_api["headers"]
response = await client.get(
f"{settings.API_V1_STR}/expense_splits/9999/settlement_activities",
headers=user1_headers
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_get_settlement_activities_no_permission(
client: httpx.AsyncClient,
test_expense_split_for_user2_api: ExpenseSplit, # Belongs to group of user1/user2
db_session: AsyncSession
):
# Create a new user (user3) who is not in the group
user3 = User(email="api.user3.other@example.com", name="API User 3 Other", hashed_password="password3")
db_session.add(user3)
await db_session.commit()
user3_headers = {"Authorization": f"Bearer token-for-{user3.id}"}
response = await client.get(
f"{settings.API_V1_STR}/expense_splits/{test_expense_split_for_user2_api.id}/settlement_activities",
headers=user3_headers # Authenticated as User3
)
assert response.status_code == 403
# --- Test existing expense endpoints for new fields ---
@pytest.mark.asyncio
async def test_get_expense_by_id_includes_new_fields(
client: httpx.AsyncClient,
test_user1_api: Dict[str, Any], # User in group
test_expense_in_group_api: Expense,
test_expense_split_for_user2_api: ExpenseSplit # one of the splits
):
user1_headers = test_user1_api["headers"]
expense_id = test_expense_in_group_api.id
response = await client.get(f"{settings.API_V1_STR}/expenses/{expense_id}", headers=user1_headers)
assert response.status_code == 200
expense_data = response.json()
assert "overall_settlement_status" in expense_data
assert expense_data["overall_settlement_status"] == ExpenseOverallStatusEnum.unpaid.value # Initial state
assert "splits" in expense_data
assert len(expense_data["splits"]) > 0
found_split = False
for split_json in expense_data["splits"]:
if split_json["id"] == test_expense_split_for_user2_api.id:
found_split = True
assert "status" in split_json
assert split_json["status"] == ExpenseSplitStatusEnum.unpaid.value # Initial state
assert "paid_at" in split_json # Should be null initially
assert split_json["paid_at"] is None
assert "settlement_activities" in split_json
assert isinstance(split_json["settlement_activities"], list)
assert len(split_json["settlement_activities"]) == 0 # No activities yet
break
assert found_split, "The specific test split was not found in the expense data."
# Placeholder for conftest.py content if needed for local execution understanding
"""
# conftest.py (example structure)
import pytest
import asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from app.main import app # Your FastAPI app
from app.database import Base, get_transactional_session # Your DB setup
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"
engine = create_async_engine(TEST_DATABASE_URL, echo=True)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession)
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session", autouse=True)
async def setup_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db_session() -> AsyncSession:
async with TestingSessionLocal() as session:
# Transaction is handled by get_transactional_session override or test logic
yield session
# Rollback changes after test if not using transactional tests per case
# await session.rollback() # Or rely on test isolation method
@pytest.fixture
async def client(db_session) -> AsyncClient: # Depends on db_session to ensure DB is ready
async def override_get_transactional_session():
# Provide the test session, potentially managing transactions per test
# This is a simplified version; real setup might involve nested transactions
# or ensuring each test runs in its own transaction that's rolled back.
try:
yield db_session
# await db_session.commit() # Or commit if test is meant to persist then rollback globally
except Exception:
# await db_session.rollback()
raise
# finally:
# await db_session.rollback() # Ensure rollback after each test using this fixture
app.dependency_overrides[get_transactional_session] = override_get_transactional_session
async with AsyncClient(app=app, base_url="http://test") as c:
yield c
del app.dependency_overrides[get_transactional_session] # Clean up
"""

56
be/tests/conftest.py Normal file
View File

@@ -0,0 +1,56 @@
import pytest
import asyncio
from typing import AsyncGenerator
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.main import app
from app.models import Base
from app.database import get_db
from app.config import settings
# Create test database engine
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
async def test_db():
"""Create test database and tables."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
async def db_session(test_db) -> AsyncGenerator[AsyncSession, None]:
"""Create a fresh database session for each test."""
async with TestingSessionLocal() as session:
yield session
@pytest.fixture
async def client(db_session) -> AsyncGenerator[TestClient, None]:
"""Create a test client with the test database session."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,208 @@
import pytest
from fastapi import status
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
GroupNotFoundError,
GroupPermissionError,
GroupMembershipError,
GroupOperationError,
GroupValidationError,
ItemNotFoundError,
UserNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseTransactionError,
DatabaseQueryError,
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OCRProcessingError,
EmailAlreadyRegisteredError,
UserCreationError,
InviteNotFoundError,
InviteExpiredError,
InviteAlreadyUsedError,
InviteCreationError,
ListStatusNotFoundError,
ConflictError,
InvalidCredentialsError,
NotAuthenticatedError,
JWTError,
JWTUnexpectedError
)
def test_list_not_found_error():
list_id = 123
with pytest.raises(ListNotFoundError) as excinfo:
raise ListNotFoundError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"List {list_id} not found"
def test_list_permission_error():
list_id = 456
action = "delete"
with pytest.raises(ListPermissionError) as excinfo:
raise ListPermissionError(list_id=list_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}"
def test_list_permission_error_default_action():
list_id = 789
with pytest.raises(ListPermissionError) as excinfo:
raise ListPermissionError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to access list {list_id}"
def test_list_creator_required_error():
list_id = 101
action = "update"
with pytest.raises(ListCreatorRequiredError) as excinfo:
raise ListCreatorRequiredError(list_id=list_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}"
def test_group_not_found_error():
group_id = 202
with pytest.raises(GroupNotFoundError) as excinfo:
raise GroupNotFoundError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Group {group_id} not found"
def test_group_permission_error():
group_id = 303
action = "invite"
with pytest.raises(GroupPermissionError) as excinfo:
raise GroupPermissionError(group_id=group_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to {action} in group {group_id}"
def test_group_membership_error():
group_id = 404
action = "post"
with pytest.raises(GroupMembershipError) as excinfo:
raise GroupMembershipError(group_id=group_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}"
def test_group_membership_error_default_action():
group_id = 505
with pytest.raises(GroupMembershipError) as excinfo:
raise GroupMembershipError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You must be a member of group {group_id} to access"
def test_group_operation_error():
detail_msg = "Failed to perform group operation."
with pytest.raises(GroupOperationError) as excinfo:
raise GroupOperationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == detail_msg
def test_group_validation_error():
detail_msg = "Invalid group data."
with pytest.raises(GroupValidationError) as excinfo:
raise GroupValidationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == detail_msg
def test_item_not_found_error():
item_id = 606
with pytest.raises(ItemNotFoundError) as excinfo:
raise ItemNotFoundError(item_id=item_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Item {item_id} not found"
def test_user_not_found_error_no_identifier():
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError()
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == "User not found."
def test_user_not_found_error_with_id():
user_id = 707
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError(user_id=user_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"User with ID {user_id} not found."
def test_user_not_found_error_with_identifier_string():
identifier = "test_user"
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError(identifier=identifier)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"User with identifier '{identifier}' not found."
def test_invalid_operation_error():
detail_msg = "This operation is not allowed."
with pytest.raises(InvalidOperationError) as excinfo:
raise InvalidOperationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == detail_msg
def test_invalid_operation_error_custom_status():
detail_msg = "This operation is forbidden."
custom_status = status.HTTP_403_FORBIDDEN
with pytest.raises(InvalidOperationError) as excinfo:
raise InvalidOperationError(detail=detail_msg, status_code=custom_status)
assert excinfo.value.status_code == custom_status
assert excinfo.value.detail == detail_msg
def test_email_already_registered_error():
with pytest.raises(EmailAlreadyRegisteredError) as excinfo:
raise EmailAlreadyRegisteredError()
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == "Email already registered."
def test_user_creation_error():
with pytest.raises(UserCreationError) as excinfo:
raise UserCreationError()
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == "An error occurred during user creation."
def test_invite_not_found_error():
invite_code = "TESTCODE123"
with pytest.raises(InviteNotFoundError) as excinfo:
raise InviteNotFoundError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Invite code {invite_code} not found"
def test_invite_expired_error():
invite_code = "EXPIREDCODE"
with pytest.raises(InviteExpiredError) as excinfo:
raise InviteExpiredError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_410_GONE
assert excinfo.value.detail == f"Invite code {invite_code} has expired"
def test_invite_already_used_error():
invite_code = "USEDCODE"
with pytest.raises(InviteAlreadyUsedError) as excinfo:
raise InviteAlreadyUsedError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_410_GONE
assert excinfo.value.detail == f"Invite code {invite_code} has already been used"
def test_invite_creation_error():
group_id = 909
with pytest.raises(InviteCreationError) as excinfo:
raise InviteCreationError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == f"Failed to create invite for group {group_id}"
def test_list_status_not_found_error():
list_id = 808
with pytest.raises(ListStatusNotFoundError) as excinfo:
raise ListStatusNotFoundError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Status for list {list_id} not found"
def test_conflict_error():
detail_msg = "Resource version mismatch."
with pytest.raises(ConflictError) as excinfo:
raise ConflictError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_409_CONFLICT
assert excinfo.value.detail == detail_msg
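Every exception tested above follows the same shape: an error type carrying an HTTP `status_code` and a formatted `detail` string. A hypothetical minimal version of that pattern (the real classes in `app.core.exceptions` presumably subclass FastAPI's `HTTPException`; `HTTPError` and `ListNotFound` here are illustrative stand-ins) looks like this:

```python
# Hypothetical sketch of the status_code/detail exception pattern the tests exercise.
class HTTPError(Exception):
    def __init__(self, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)


class ListNotFound(HTTPError):
    def __init__(self, list_id: int) -> None:
        # Each domain error bakes its status code and detail template into __init__.
        super().__init__(status_code=404, detail=f"List {list_id} not found")


try:
    raise ListNotFound(list_id=123)
except HTTPError as exc:
    assert exc.status_code == 404
    assert exc.detail == "List 123 not found"
```

Because the detail string is built in one place, the tests above can assert on it exactly rather than pattern-matching.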

View File

@@ -0,0 +1,300 @@
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
# Modules to test
from app.core import gemini
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
# Default Mock Settings
@pytest.fixture
def mock_gemini_settings():
settings_mock = MagicMock()
settings_mock.GEMINI_API_KEY = "test_api_key"
settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision"
settings_mock.GEMINI_SAFETY_SETTINGS = {
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
}
settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
return settings_mock
@pytest.fixture
def mock_generative_model_instance():
model_instance = AsyncMock(spec=genai.GenerativeModel)
model_instance.generate_content_async = AsyncMock()
return model_instance
@pytest.fixture
def patch_google_ai_client(mock_generative_model_instance):
with patch('google.generativeai.GenerativeModel', return_value=mock_generative_model_instance) as mock_generative_model, \
patch('google.generativeai.configure') as mock_configure:
yield mock_configure, mock_generative_model, mock_generative_model_instance
# --- Test Gemini Client Initialization (Global Client) ---
# Parametrize to test different scenarios for the global client init
@pytest.mark.parametrize(
"api_key_present, configure_raises, model_init_raises, expected_error_message_part",
[
(True, None, None, None), # Success
(False, None, None, "GEMINI_API_KEY not configured"), # API key missing
(True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"), # genai.configure error
(True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"), # GenerativeModel error
]
)
@patch('app.core.gemini.genai') # Patch genai within the gemini module
def test_global_gemini_client_initialization(
mock_genai_module,
mock_gemini_settings,
api_key_present,
configure_raises,
model_init_raises,
expected_error_message_part
):
"""Tests the global gemini_flash_client initialization logic in app.core.gemini."""
# We need to reload the module to re-trigger its top-level initialization code.
# This is a bit tricky. A common pattern is to put init logic in a function.
# For now, we'll try to simulate it by controlling mocks before module access.
with patch('app.core.gemini.settings', mock_gemini_settings):
if not api_key_present:
mock_gemini_settings.GEMINI_API_KEY = None
mock_genai_module.configure = MagicMock()
mock_genai_module.GenerativeModel = MagicMock()
mock_genai_module.types = genai.types # Keep original types
mock_genai_module.HarmCategory = genai.HarmCategory
mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold
if configure_raises:
mock_genai_module.configure.side_effect = configure_raises
if model_init_raises:
mock_genai_module.GenerativeModel.side_effect = model_init_raises
# Python modules are singletons. To re-run top-level code, we need to unload and reload.
# This is generally discouraged. It's better to have an explicit init function.
# For this test, we'll check the state variables set by the module's import-time code.
import importlib
importlib.reload(gemini) # This re-runs the try-except block at the top of gemini.py
if expected_error_message_part:
assert gemini.gemini_initialization_error is not None
assert expected_error_message_part in gemini.gemini_initialization_error
assert gemini.gemini_flash_client is None
else:
assert gemini.gemini_initialization_error is None
assert gemini.gemini_flash_client is not None
mock_genai_module.configure.assert_called_once_with(api_key="test_api_key")
mock_genai_module.GenerativeModel.assert_called_once()
# Could add more assertions about safety_settings and generation_config here
# Clean up after reload for other tests
importlib.reload(gemini)
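The comments above note that `importlib.reload` is a workaround for module-level initialization, and that an explicit init function would test more cleanly. A sketch of that alternative, under the assumption that the module keeps `_client`/`_init_error` state and the same error messages as `app.core.gemini` (the names here are illustrative):

```python
# Sketch of the "explicit init function" alternative suggested above,
# which lets tests re-initialize without importlib.reload.
from typing import Optional

_client: Optional[object] = None
_init_error: Optional[str] = None


def init_client(api_key: Optional[str]) -> None:
    """Idempotent, explicitly-called initializer; tests pass in whatever settings they need."""
    global _client, _init_error
    _client, _init_error = None, None
    if not api_key:
        _init_error = "GEMINI_API_KEY not configured"
        return
    _client = object()  # stand-in for genai.GenerativeModel(...)


def get_client():
    if _init_error:
        raise RuntimeError(f"Gemini client could not be initialized: {_init_error}")
    if _client is None:
        raise RuntimeError("Gemini client is not available (unknown initialization issue).")
    return _client


init_client("test_api_key")
assert get_client() is not None

init_client(None)
try:
    get_client()
except RuntimeError as exc:
    assert "GEMINI_API_KEY not configured" in str(exc)
```

Each test then calls `init_client(...)` with its own inputs, and no module reload or global-state cleanup is needed.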
# --- Test get_gemini_client ---
# Assuming the global client tests above set the stage for these
def test_get_gemini_client_success(monkeypatch):
    # monkeypatch restores the module globals automatically after the test
    monkeypatch.setattr(gemini, "gemini_flash_client", MagicMock(spec=genai.GenerativeModel))
    monkeypatch.setattr(gemini, "gemini_initialization_error", None)
    client = gemini.get_gemini_client()
    assert client is not None
def test_get_gemini_client_init_error(monkeypatch):
    monkeypatch.setattr(gemini, "gemini_flash_client", None)
    monkeypatch.setattr(gemini, "gemini_initialization_error", "Test init error")
    with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
        gemini.get_gemini_client()
def test_get_gemini_client_none_client_unknown_issue(monkeypatch):
    monkeypatch.setattr(gemini, "gemini_flash_client", None)
    monkeypatch.setattr(gemini, "gemini_initialization_error", None)  # no init error, but client is None
    with pytest.raises(RuntimeError, match=r"Gemini client is not available \(unknown initialization issue\)\."):
        gemini.get_gemini_client()
# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_success(
mock_gemini_settings,
mock_generative_model_instance,
patch_google_ai_client
):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
image_bytes = b"dummy_image_bytes"
mime_type = "image/png"
items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)
mock_generative_model_instance.generate_content_async.assert_called_once_with([
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": mime_type, "data": image_bytes}
])
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(mock_gemini_settings):
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', None), \
patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
image_bytes = b"dummy_image_bytes"
with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
await gemini.extract_items_from_image_gemini(image_bytes)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_api_quota_error(
mock_gemini_settings,
mock_generative_model_instance
):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
image_bytes = b"dummy_image_bytes"
with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
await gemini.extract_items_from_image_gemini(image_bytes)
# --- Tests for GeminiOCRService --- (Example tests, more needed)
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel')
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
MockGenerativeModel.return_value = mock_generative_model_instance
with patch('app.core.gemini.settings', mock_gemini_settings):
service = gemini.GeminiOCRService()
MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
assert service.model == mock_generative_model_instance
# Could add assertions for safety_settings and generation_config if they are set directly on model
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
with patch('app.core.gemini.settings', mock_gemini_settings):
with pytest.raises(OCRServiceConfigError):
gemini.GeminiOCRService()
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(
mock_gemini_settings,
mock_generative_model_instance
):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
mime_type = "image/png"
items = await service.extract_items(image_bytes, mime_type)
mock_generative_model_instance.generate_content_async.assert_called_once_with([
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": mime_type, "data": image_bytes}
])
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(
mock_gemini_settings,
mock_generative_model_instance
):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
with pytest.raises(OCRQuotaExceededError):
await service.extract_items(image_bytes)
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(
mock_gemini_settings,
mock_generative_model_instance
):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ServiceUnavailable("Service unavailable")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
with pytest.raises(OCRServiceUnavailableError):
await service.extract_items(image_bytes)
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(
mock_gemini_settings,
mock_generative_model_instance
):
mock_response = MagicMock()
mock_response.text = ""
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP'
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
service = gemini.GeminiOCRService()
image_bytes = b"dummy_image_bytes"
items = await service.extract_items(image_bytes)
assert items == []

View File

@@ -0,0 +1,216 @@
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.security import (
verify_password,
hash_password,
# create_access_token,
# create_refresh_token,
# verify_access_token,
# verify_refresh_token,
pwd_context, # Import for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
# from app.config import settings
# --- Tests for Password Hashing ---
def test_hash_password():
password = "securepassword123"
hashed = hash_password(password)
assert isinstance(hashed, str)
assert hashed != password
# Check that the default scheme (bcrypt) is used by verifying the hash prefix
# bcrypt hashes typically start with $2b$ or $2a$ or $2y$
assert hashed.startswith("$2b$") or hashed.startswith("$2a$") or hashed.startswith("$2y$")
def test_verify_password_correct():
password = "testpassword"
hashed_password = pwd_context.hash(password) # Use the same context for consistency
assert verify_password(password, hashed_password) is True
def test_verify_password_incorrect():
password = "testpassword"
wrong_password = "wrongpassword"
hashed_password = pwd_context.hash(password)
assert verify_password(wrong_password, hashed_password) is False
def test_verify_password_invalid_hash_format():
password = "testpassword"
invalid_hash = "notarealhash"
assert verify_password(password, invalid_hash) is False
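The three behaviors tested above (round-trip verify, reject wrong password, reject malformed hash) can be illustrated with a standard-library stand-in for the `hash_password`/`verify_password` pair. This sketch uses PBKDF2 instead of passlib's bcrypt and deliberately renamed functions (`sketch_hash_password`, `sketch_verify_password`) so it is clearly not the app code:

```python
# Illustrative stand-in for the hash/verify pair under test, stdlib only (PBKDF2,
# not bcrypt). An assumption-labeled sketch, not the implementation in app.core.security.
import hashlib
import hmac
import os


def sketch_hash_password(password: str, *, iterations: int = 100_000) -> str:
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    # Self-describing format so verify can recover its parameters
    return f"pbkdf2${iterations}${salt.hex()}${digest.hex()}"


def sketch_verify_password(password: str, stored: str) -> bool:
    try:
        _, iters, salt_hex, digest_hex = stored.split("$")
    except ValueError:
        return False  # malformed hash, mirroring test_verify_password_invalid_hash_format
    candidate = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt_hex), int(iters))
    return hmac.compare_digest(candidate.hex(), digest_hex)


h = sketch_hash_password("testpassword")
assert sketch_verify_password("testpassword", h) is True
assert sketch_verify_password("wrongpassword", h) is False
assert sketch_verify_password("testpassword", "notarealhash") is False
```

The constant-time comparison (`hmac.compare_digest`) and the format check that returns `False` rather than raising are the two properties the tests above implicitly rely on.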
# --- Tests for JWT Creation ---
# Mock settings for JWT tests
# @pytest.fixture(scope="module")
# def mock_jwt_settings():
# mock_settings = MagicMock()
# mock_settings.SECRET_KEY = "testsecretkey"
# mock_settings.ALGORITHM = "HS256"
# mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
# mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days
# return mock_settings
# @patch('app.core.security.settings')
# def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "user@example.com"
# token = create_access_token(subject)
# assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == subject
# assert decoded_payload["type"] == "access"
# assert "exp" in decoded_payload
# # Check if expiry is roughly correct (within a small delta)
# expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# @patch('app.core.security.settings')
# def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# # ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
# subject = 123 # Subject can be int
# custom_delta = timedelta(hours=1)
# token = create_access_token(subject, expires_delta=custom_delta)
# assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == str(subject)
# assert decoded_payload["type"] == "access"
# expected_expiry = datetime.now(timezone.utc) + custom_delta
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# @patch('app.core.security.settings')
# def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
# subject = "refresh_subject"
# token = create_refresh_token(subject)
# assert isinstance(token, str)
# decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
# assert decoded_payload["sub"] == subject
# assert decoded_payload["type"] == "refresh"
# expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
# assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# --- Tests for JWT Verification --- (More tests to be added here)
# @patch('app.core.security.settings')
# def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "test_user_valid_access"
# token = create_access_token(subject)
# payload = verify_access_token(token)
# assert payload is not None
# assert payload["sub"] == subject
# assert payload["type"] == "access"
# @patch('app.core.security.settings')
# def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "test_user_invalid_sig"
# # Create token with correct key
# token = create_access_token(subject)
# # Try to verify with wrong key
# mock_settings_global.SECRET_KEY = "wrongsecretkey"
# payload = verify_access_token(token)
# assert payload is None
# @patch('app.core.security.settings')
# @patch('app.core.security.datetime') # Mock datetime to control token expiry
# def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
# # Set current time for token creation
# now = datetime.now(timezone.utc)
# mock_datetime.now.return_value = now
# mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
# mock_datetime.timedelta = timedelta # Ensure original timedelta is used
# subject = "test_user_expired"
# token = create_access_token(subject)
# # Advance time beyond expiry for verification
# mock_datetime.now.return_value = now + timedelta(minutes=5)
# payload = verify_access_token(token)
# assert payload is None
# @patch('app.core.security.settings')
# def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation
# subject = "test_user_wrong_type"
# # Create a refresh token
# refresh_token = create_refresh_token(subject)
# # Try to verify it as an access token
# payload = verify_access_token(refresh_token)
# assert payload is None
# @patch('app.core.security.settings')
# def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
# subject = "test_user_valid_refresh"
# token = create_refresh_token(subject)
# payload = verify_refresh_token(token)
# assert payload is not None
# assert payload["sub"] == subject
# assert payload["type"] == "refresh"
# @patch('app.core.security.settings')
# @patch('app.core.security.datetime')
# def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
# now = datetime.now(timezone.utc)
# mock_datetime.now.return_value = now
# mock_datetime.fromtimestamp = datetime.fromtimestamp
# mock_datetime.timedelta = timedelta
# subject = "test_user_expired_refresh"
# token = create_refresh_token(subject)
# mock_datetime.now.return_value = now + timedelta(minutes=5)
# payload = verify_refresh_token(token)
# assert payload is None
# @patch('app.core.security.settings')
# def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
# mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
# mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
# subject = "test_user_wrong_type_refresh"
# access_token = create_access_token(subject)
# payload = verify_refresh_token(access_token)
# assert payload is None
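The commented-out expiry tests above hinge on one pattern: patching `app.core.security.datetime` so the test controls what "now" is, then advancing it past the token lifetime. A minimal sketch of that clock-control technique using only `unittest.mock` (no JWT library involved, names are illustrative):

```python
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock

# Stand-in for the patched module-level datetime: the test, not the wall
# clock, decides what now() returns.
mock_datetime = MagicMock()
now = datetime(2025, 1, 1, tzinfo=timezone.utc)
mock_datetime.now.return_value = now

# "Create" a token valid for one minute at the mocked time.
issued_at = mock_datetime.now(timezone.utc)
expires_at = issued_at + timedelta(minutes=1)

# Advance the mocked clock past expiry, exactly as the tests do.
mock_datetime.now.return_value = now + timedelta(minutes=5)
assert mock_datetime.now(timezone.utc) > expires_at  # token now expired
```

Note the tests also restore `fromtimestamp` and `timedelta` on the mock so that `jwt.decode` still gets the real implementations.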

@ -0,0 +1,369 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional
from app.crud.expense import (
create_expense,
get_expense_by_id,
get_expenses_for_list,
get_expenses_for_group,
update_expense,
delete_expense,
get_users_for_splitting
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Item as ItemModel,
SplitTypeEnum,
ExpenseOverallStatusEnum, # Added
ExpenseSplitStatusEnum # Added
)
from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError,
ExpenseNotFoundError,
DatabaseTransactionError,
ConflictError
)
# General Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin_nested = AsyncMock() # For nested transactions within functions
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
# Mock session.begin() to return an async context manager
mock_transaction_context = AsyncMock()
session.begin = MagicMock(return_value=mock_transaction_context)
return session
@pytest.fixture
def basic_user_model():
return UserModel(id=1, name="Test User", email="test@example.com", version=1)
@pytest.fixture
def another_user_model():
return UserModel(id=2, name="Another User", email="another@example.com", version=1)
@pytest.fixture
def basic_group_model(basic_user_model, another_user_model):
group = GroupModel(id=1, name="Test Group", version=1)
# Simulate member_associations for get_users_for_splitting if needed directly
# group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model), UserGroupModel(user_id=2, group_id=1, user=another_user_model)]
return group
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, created_by_id=basic_user_model.id, creator=basic_user_model, version=1)
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
return ExpenseCreate(
description="Grocery run",
total_amount=Decimal("30.00"),
currency="USD",
expense_date=datetime.now(timezone.utc).date(),
split_type=SplitTypeEnum.EQUAL,
list_id=basic_list_model.id,
group_id=None, # Derived from list
item_id=None,
paid_by_user_id=basic_user_model.id,
splits_in=None
)
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
return ExpenseCreate(
description="Movies",
total_amount=Decimal("50.00"),
currency="USD",
expense_date=datetime.now(timezone.utc).date(),
split_type=SplitTypeEnum.EQUAL,
list_id=None,
group_id=basic_group_model.id,
item_id=None,
paid_by_user_id=basic_user_model.id,
splits_in=None
)
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
return ExpenseCreate(
description="Dinner",
total_amount=Decimal("100.00"),
expense_date=datetime.now(timezone.utc).date(),
currency="USD",
split_type=SplitTypeEnum.EXACT_AMOUNTS,
group_id=basic_group_model.id,
paid_by_user_id=basic_user_model.id,
splits_in=[
ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
]
)
@pytest.fixture
def expense_update_data():
return ExpenseUpdate(
description="Updated Dinner",
total_amount=Decimal("120.00"),
version=1 # Ensure version is provided for updates
)
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model):
expense = ExpenseModel(
id=1,
description=expense_create_data_equal_split_group_ctx.description,
total_amount=expense_create_data_equal_split_group_ctx.total_amount,
currency=expense_create_data_equal_split_group_ctx.currency,
expense_date=expense_create_data_equal_split_group_ctx.expense_date,
split_type=expense_create_data_equal_split_group_ctx.split_type,
list_id=expense_create_data_equal_split_group_ctx.list_id,
group_id=expense_create_data_equal_split_group_ctx.group_id,
group=basic_group_model, # Link to group fixture
item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
created_by_user_id=basic_user_model.id,
paid_by=basic_user_model,
created_by_user=basic_user_model,
version=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# Simulate splits for an existing expense
expense.splits = [
ExpenseSplitModel(id=1, expense_id=1, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
ExpenseSplitModel(id=2, expense_id=1, user_id=2, owed_amount=Decimal("25.00"), version=1) # Assuming another_user_model has id 2
]
return expense
# Tests for get_users_for_splitting
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
mock_db_session.get.return_value = basic_group_model # Mock get for group
users = await get_users_for_splitting(mock_db_session, expense_group_id=basic_group_model.id, expense_list_id=None, expense_paid_by_user_id=basic_user_model.id)
assert len(users) == 2
assert basic_user_model in users
assert another_user_model in users
@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
basic_list_model.group = basic_group_model # Ensure list is associated with the group
mock_db_session.get.return_value = basic_list_model # Mock get for list
users = await get_users_for_splitting(mock_db_session, expense_group_id=None, expense_list_id=basic_list_model.id, expense_paid_by_user_id=basic_user_model.id)
assert len(users) == 2
assert basic_user_model in users
assert another_user_model in users
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
# Setup mocks
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # paid_by_user, then group
# Mock get_users_for_splitting directly
with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
mock_get_users.return_value = [basic_user_model, another_user_model]
async def mock_refresh(instance, attribute_names=None, with_for_update=None):
if isinstance(instance, ExpenseModel):
instance.id = 1 # Simulate ID assignment after flush
instance.version = 1
instance.created_at = datetime.now(timezone.utc)
instance.updated_at = datetime.now(timezone.utc)
# Simulate splits being added to the session and linked by refresh
instance.splits = [
ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1)
]
return None
mock_db_session.refresh.side_effect = mock_refresh
created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)
mock_db_session.add.assert_called()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_expense is not None
assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
assert created_expense.split_type == SplitTypeEnum.EQUAL
assert len(created_expense.splits) == 2
expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
for split in created_expense.splits:
assert split.owed_amount == expected_amount_per_user
assert split.status == ExpenseSplitStatusEnum.unpaid # Verify initial split status
assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid # Verify initial expense status
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model] # Payer, Group, User1 in split, User2 in split
async def mock_refresh(instance, attribute_names=None, with_for_update=None):
if isinstance(instance, ExpenseModel):
instance.id = 2
instance.version = 1
instance.splits = [
ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("40.00"))
]
return None
mock_db_session.refresh.side_effect = mock_refresh
created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)
mock_db_session.add.assert_called()
mock_db_session.flush.assert_called_once()
assert created_expense is not None
assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
assert len(created_expense.splits) == 2
assert created_expense.splits[0].owed_amount == Decimal("60.00")
assert created_expense.splits[1].owed_amount == Decimal("40.00")
for split in created_expense.splits:
assert split.status == ExpenseSplitStatusEnum.unpaid # Verify initial split status
assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid # Verify initial expense status
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
mock_db_session.get.side_effect = [None] # Payer not found, group lookup won't happen
with pytest.raises(UserNotFoundError):
await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 999) # current_user_id is for creator, paid_by_user_id is in schema
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
mock_db_session.get.return_value = basic_user_model # Payer found
expense_data = expense_create_data_equal_split_group_ctx.model_copy()
expense_data.list_id = None
expense_data.group_id = None
with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
await create_expense(mock_db_session, expense_data, basic_user_model.id)
# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
mock_db_session.get.return_value = db_expense_model
expense = await get_expense_by_id(mock_db_session, db_expense_model.id)
assert expense is not None
assert expense.id == db_expense_model.id
# Comparing options against fresh MagicMock() instances would always fail,
# so only check the model class and primary key; the loader options depend
# on the implementation of get_expense_by_id.
mock_db_session.get.assert_called_once()
call_args = mock_db_session.get.call_args
assert call_args.args[0] is ExpenseModel
assert call_args.args[1] == db_expense_model.id
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
mock_db_session.get.return_value = None
expense = await get_expense_by_id(mock_db_session, 999)
assert expense is None
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_expense_model]
mock_db_session.execute.return_value = mock_result
expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)
assert len(expenses) == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_expense_model]
mock_db_session.execute.return_value = mock_result
expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)
assert len(expenses) == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- update_expense Tests ---
@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
expense_update_data.version = db_expense_model.version # Match version
# Simulate that the db_expense_model is returned by session.get
mock_db_session.get.return_value = db_expense_model
updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
mock_db_session.add.assert_called_once_with(db_expense_model)
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_expense_model)
assert updated_expense.description == expense_update_data.description
assert updated_expense.total_amount == expense_update_data.total_amount
 # assert updated_expense.version == expense_update_data.version + 1 # Version incremented past the matched version (comparing against db_expense_model.version is vacuous: same object)
@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
mock_db_session.get.return_value = None # Expense not found
with pytest.raises(ExpenseNotFoundError):
await update_expense(mock_db_session, 999, expense_update_data, basic_user_model.id)
@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
expense_update_data.version = db_expense_model.version + 1 # Create version mismatch
mock_db_session.get.return_value = db_expense_model
with pytest.raises(ConflictError):
await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
mock_db_session.rollback.assert_called_once()
# --- delete_expense Tests ---
@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
mock_db_session.get.return_value = db_expense_model # Simulate expense found
await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)
mock_db_session.delete.assert_called_once_with(db_expense_model)
# Assuming delete_expense uses session.begin() and commits
mock_db_session.begin().commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
mock_db_session.get.return_value = None # Expense not found
with pytest.raises(ExpenseNotFoundError):
await delete_expense(mock_db_session, 999, basic_user_model.id)
mock_db_session.rollback.assert_not_called() # Rollback might be called by begin() context manager exit
@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
mock_db_session.get.return_value = db_expense_model
mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
with pytest.raises(DatabaseTransactionError):
await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)
mock_db_session.begin().rollback.assert_called_once() # Rollback from the transaction context
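The `mock_db_session.begin().commit` / `.rollback` assertions above work because of a MagicMock property worth making explicit: a mocked `begin` hands back the same `return_value` object on every call, so the test and the code under test share one transaction-context mock. A self-contained sketch of the fixture pattern these tests use:

```python
from unittest.mock import AsyncMock, MagicMock

# Same shape as the mock_db_session fixture: begin() is a sync callable
# returning an AsyncMock, which already supports "async with".
session = AsyncMock()
ctx = AsyncMock()
session.begin = MagicMock(return_value=ctx)

assert session.begin() is session.begin()  # identical object every call
assert session.begin() is ctx
# Hence whatever the code under test does to its transaction context is
# observable in the test as session.begin().commit / session.begin().rollback.
```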

be/tests/crud/test_group.py Normal file

@ -0,0 +1,354 @@
import pytest
from unittest.mock import ANY, AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import delete, func
from app.crud.group import (
create_group,
get_user_groups,
get_group_by_id,
is_user_member,
get_user_role_in_group,
add_user_to_group,
remove_user_from_group,
get_group_member_count,
check_group_membership,
check_user_role_in_group,
update_group_member_role # Assuming this will be added
)
from app.schemas.group import GroupCreate, GroupUpdate # Added GroupUpdate
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError,
UserNotFoundError, # For adding user to group
ConflictError # For updates
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
mock_transaction_context = AsyncMock()
session.begin = MagicMock(return_value=mock_transaction_context)
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
return session
@pytest.fixture
def group_create_data():
return GroupCreate(name="Test Group")
@pytest.fixture
def group_update_data():
return GroupUpdate(name="Updated Test Group", version=1)
@pytest.fixture
def creator_user_model():
return UserModel(id=1, name="Creator User", email="creator@example.com", version=1)
@pytest.fixture
def member_user_model():
return UserModel(id=2, name="Member User", email="member@example.com", version=1)
@pytest.fixture
def non_member_user_model():
return UserModel(id=3, name="Non Member User", email="nonmember@example.com", version=1)
@pytest.fixture
def db_group_model(creator_user_model):
return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model, version=1)
@pytest.fixture
def db_user_group_owner_assoc(db_group_model, creator_user_model):
return UserGroupModel(id=1, user_id=creator_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.owner, user=creator_user_model, group=db_group_model, version=1)
@pytest.fixture
def db_user_group_member_assoc(db_group_model, member_user_model):
return UserGroupModel(id=2, user_id=member_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.member, user=member_user_model, group=db_group_model, version=1)
# --- create_group Tests ---
@pytest.mark.asyncio
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
async def mock_refresh(instance, attribute_names=None, with_for_update=None):
if isinstance(instance, GroupModel):
instance.id = 1 # Simulate ID assignment by DB
instance.version = 1
# Simulate the UserGroup association being added and refreshed if done via relationship back_populates
instance.members = [UserGroupModel(user_id=creator_user_model.id, group_id=instance.id, role=UserRoleEnum.owner, version=1)]
elif isinstance(instance, UserGroupModel):
instance.id = 1 # Simulate ID for UserGroupModel
instance.version = 1
return None
mock_db_session.refresh.side_effect = mock_refresh
# Mock the user get for the creator
mock_db_session.get.return_value = creator_user_model
created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id)
assert mock_db_session.add.call_count == 2 # Group and UserGroup
mock_db_session.flush.assert_called()
assert mock_db_session.refresh.call_count >= 1 # Called for group, maybe for UserGroup too
assert created_group is not None
assert created_group.name == group_create_data.name
assert created_group.created_by_id == creator_user_model.id
assert len(created_group.members) == 1
assert created_group.members[0].role == UserRoleEnum.owner
@pytest.mark.asyncio
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
mock_db_session.get.return_value = creator_user_model # Creator user found
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_group(mock_db_session, group_create_data, creator_user_model.id)
mock_db_session.rollback.assert_called_once()
# --- get_user_groups Tests ---
@pytest.mark.asyncio
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
# Mock the execute call that fetches groups for a user
mock_result_groups = AsyncMock()
mock_result_groups.scalars.return_value.all.return_value = [db_group_model]
mock_db_session.execute.return_value = mock_result_groups
groups = await get_user_groups(mock_db_session, creator_user_model.id)
assert len(groups) == 1
assert groups[0].name == db_group_model.name
mock_db_session.execute.assert_called_once()
# --- get_group_by_id Tests ---
@pytest.mark.asyncio
async def test_get_group_by_id_found(mock_db_session, db_group_model):
mock_db_session.get.return_value = db_group_model
group = await get_group_by_id(mock_db_session, db_group_model.id)
assert group is not None
assert group.id == db_group_model.id
mock_db_session.get.assert_called_once_with(GroupModel, db_group_model.id, options=ANY) # options for eager loading
@pytest.mark.asyncio
async def test_get_group_by_id_not_found(mock_db_session):
mock_db_session.get.return_value = None
group = await get_group_by_id(mock_db_session, 999)
assert group is None
# --- is_user_member Tests ---
from unittest.mock import ANY # For checking options in get
@pytest.mark.asyncio
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = db_user_group_owner_assoc.id
mock_db_session.execute.return_value = mock_result
is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id)
assert is_member is True
@pytest.mark.asyncio
async def test_is_user_member_false(mock_db_session, db_group_model, non_member_user_model):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = None
mock_db_session.execute.return_value = mock_result
is_member = await is_user_member(mock_db_session, db_group_model.id, non_member_user_model.id)
assert is_member is False
# --- get_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner
mock_db_session.execute.return_value = mock_result
role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)
assert role == UserRoleEnum.owner
# --- add_user_to_group Tests ---
@pytest.mark.asyncio
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model, non_member_user_model):
# Mock is_user_member to return False initially
with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = False
# Mock get for the user to be added
mock_db_session.get.return_value = non_member_user_model
async def mock_refresh_user_group(instance, attribute_names=None, with_for_update=None):
instance.id = 100
instance.version = 1
return None
mock_db_session.refresh.side_effect = mock_refresh_user_group
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model, non_member_user_model.id, UserRoleEnum.member)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert user_group_assoc is not None
assert user_group_assoc.user_id == non_member_user_model.id
assert user_group_assoc.group_id == db_group_model.id
assert user_group_assoc.role == UserRoleEnum.member
@pytest.mark.asyncio
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model):
with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = True # User is already a member
# No need to mock session.get for the user if is_user_member is true first
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model, creator_user_model.id)
assert user_group_assoc is None # Should return None if user already member
mock_db_session.add.assert_not_called()
@pytest.mark.asyncio
async def test_add_user_to_group_user_not_found(mock_db_session, db_group_model):
with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = False # User not member initially
mock_db_session.get.return_value = None # User to be added not found
with pytest.raises(UserNotFoundError):
await add_user_to_group(mock_db_session, db_group_model, 999, UserRoleEnum.member)
mock_db_session.add.assert_not_called()
# --- remove_user_from_group Tests ---
@pytest.mark.asyncio
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model, db_user_group_member_assoc):
# Mock get_user_role_in_group to confirm user is not owner
with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
mock_get_role.return_value = UserRoleEnum.member
# Mock the execute call for the delete statement
mock_delete_result = AsyncMock()
mock_delete_result.rowcount = 1 # Simulate one row was affected/deleted
mock_db_session.execute.return_value = mock_delete_result
removed = await remove_user_from_group(mock_db_session, db_group_model, member_user_model.id)
assert removed is True
mock_db_session.execute.assert_called_once()
# Check that the delete statement was indeed called, e.g., by checking the structure of the query passed to execute
# This is a bit more involved if you want to match the exact SQLAlchemy delete object.
# For now, assert_called_once() confirms it was called.
@pytest.mark.asyncio
async def test_remove_user_from_group_owner_last_member(mock_db_session, db_group_model, creator_user_model):
with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role, \
patch('app.crud.group.get_group_member_count', new_callable=AsyncMock) as mock_member_count:
mock_get_role.return_value = UserRoleEnum.owner
mock_member_count.return_value = 1 # This user is the last member
with pytest.raises(GroupOperationError, match="Cannot remove the sole owner of a group. Delete the group instead."):
await remove_user_from_group(mock_db_session, db_group_model, creator_user_model.id)
mock_db_session.execute.assert_not_called() # Delete should not be called
@pytest.mark.asyncio
async def test_remove_user_from_group_not_member(mock_db_session, db_group_model, non_member_user_model):
# Mock get_user_role_in_group to return None, indicating not a member or role not found (effectively not a member for removal purposes)
with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
mock_get_role.return_value = None
 # If get_user_role_in_group returning None already raises GroupMembershipError,
 # mocking `execute` is unnecessary; if the function instead attempts a DELETE
 # that affects 0 rows, rowcount = 0 is the correct simulation. Mocking it covers both cases.
mock_delete_result = AsyncMock()
mock_delete_result.rowcount = 0
mock_db_session.execute.return_value = mock_delete_result
with pytest.raises(GroupMembershipError, match="User is not a member of the group or cannot be removed."):
await remove_user_from_group(mock_db_session, db_group_model, non_member_user_model.id)
 # Whether execute() is called depends on the implementation: a role pre-check
 # may short-circuit before the DELETE, while a delete-then-inspect approach
 # would run it and observe rowcount == 0. The mock above covers both paths,
 # so no assertion is made on execute() here.
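The rowcount-driven branch discussed in the comments above can be sketched in isolation. `remove_rows` below is a hypothetical stand-in for the delete logic in `app.crud.group`, not the real implementation; it treats a DELETE affecting zero rows as "nothing to remove":

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

async def remove_rows(session) -> bool:
    # Hypothetical delete helper: zero affected rows means no membership row matched.
    result = await session.execute("DELETE ...")  # placeholder statement
    return result.rowcount > 0

session = AsyncMock()
not_deleted = MagicMock()
not_deleted.rowcount = 0               # simulate "no matching membership row"
session.execute.return_value = not_deleted

assert asyncio.run(remove_rows(session)) is False
```

The real CRUD function would raise `GroupMembershipError` on the `False` path rather than return it; the sketch only shows how the mocked `rowcount` drives that decision.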
# --- get_group_member_count Tests ---
@pytest.mark.asyncio
async def test_get_group_member_count_success(mock_db_session, db_group_model):
mock_result_count = AsyncMock()
mock_result_count.scalar_one.return_value = 5 # Example count
mock_db_session.execute.return_value = mock_result_count
count = await get_group_member_count(mock_db_session, db_group_model.id)
assert count == 5
# --- check_group_membership Tests ---
@pytest.mark.asyncio
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
# Mock get_group_by_id
with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group, \
patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_get_group.return_value = db_group_model
mock_is_member.return_value = True
group = await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
assert group is db_group_model
@pytest.mark.asyncio
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group:
mock_get_group.return_value = None
with pytest.raises(GroupNotFoundError):
await check_group_membership(mock_db_session, 999, creator_user_model.id)
@pytest.mark.asyncio
async def test_check_group_membership_not_member(mock_db_session, db_group_model, non_member_user_model):
with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group, \
patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_get_group.return_value = db_group_model
mock_is_member.return_value = False
with pytest.raises(GroupMembershipError, match="User is not a member of the specified group"):
await check_group_membership(mock_db_session, db_group_model.id, non_member_user_model.id)
# --- check_user_role_in_group (standalone check, not just membership) ---
@pytest.mark.asyncio
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
# This test assumes check_group_membership is called internally first, or similar logic applies
with patch('app.crud.group.check_group_membership', new_callable=AsyncMock) as mock_check_membership, \
patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
mock_check_membership.return_value = db_group_model # Group exists and user is member
mock_get_role.return_value = UserRoleEnum.owner
# Check if owner has owner role (should pass)
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.owner)
# Check if owner has member role (should pass, as owner is implicitly a member with higher privileges)
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
@pytest.mark.asyncio
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
with patch('app.crud.group.check_group_membership', new_callable=AsyncMock) as mock_check_membership, \
patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
mock_check_membership.return_value = db_group_model
mock_get_role.return_value = UserRoleEnum.member
with pytest.raises(GroupPermissionError, match="User does not have the required role in the group."):
await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)
# Future test ideas, to be moved to a proper test planning tool or issue tracker.
# Consider these during major refactors or when expanding test coverage.
# Example of a DB operational error test (can be adapted for other functions)
# @pytest.mark.asyncio
# async def test_create_group_operational_error(mock_db_session, group_create_data, creator_user_model):
# mock_db_session.get.return_value = creator_user_model
# mock_db_session.flush.side_effect = OperationalError("mock operational error", "params", "orig")
# with pytest.raises(DatabaseConnectionError): # Assuming OperationalError maps to this
# await create_group(mock_db_session, group_create_data, creator_user_model.id)
# mock_db_session.rollback.assert_called_once()
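The commented-out sketch above cannot run without `app.crud.group`, but the pattern it describes can be exercised standalone. Below, `create_thing` and `FakeOperationalError` are illustrative stand-ins for `create_group` and SQLAlchemy's `OperationalError` (the names and error-handling shape are assumptions, not the actual implementation):

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


class FakeOperationalError(Exception):
    """Stand-in for sqlalchemy.exc.OperationalError."""


async def create_thing(session, data):
    # Stand-in for create_group: add, flush, and roll back on DB errors.
    session.add(data)
    try:
        await session.flush()
    except FakeOperationalError:
        await session.rollback()
        raise


session = AsyncMock()
session.add = MagicMock()
session.flush.side_effect = FakeOperationalError("mock operational error")

try:
    asyncio.run(create_thing(session, object()))
except FakeOperationalError:
    pass

session.rollback.assert_called_once()  # the assertion the commented test wants to make
```

The key point is that `side_effect` on the awaited `flush` raises inside the function under test, letting the test assert the rollback path ran.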


@@ -0,0 +1,174 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised
from datetime import datetime, timedelta, timezone
import secrets
from app.crud.invite import (
create_invite,
get_active_invite_by_code,
deactivate_invite,
MAX_CODE_GENERATION_ATTEMPTS
)
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context
# Invite CRUD has no dedicated schemas; the models imported above provide the needed types.
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.execute = AsyncMock()
return session
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
@pytest.fixture
def user_model(): # Creator
return UserModel(id=1, name="Creator User")
@pytest.fixture
def db_invite_model(group_model, user_model):
return InviteModel(
id=1,
code="test_invite_code_123",
group_id=group_model.id,
created_by_id=user_model.id,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
is_active=True
)
# --- create_invite Tests ---
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe') # Patch secrets.token_urlsafe
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
generated_code = "unique_code_123"
mock_token_urlsafe.return_value = generated_code
# Mock DB execute for checking existing code (first attempt, no existing code)
mock_existing_check_result = MagicMock() # sync result API -> MagicMock
mock_existing_check_result.scalar_one_or_none.return_value = None
mock_db_session.execute.return_value = mock_existing_check_result
invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)
mock_token_urlsafe.assert_called_once_with(16)
mock_db_session.execute.assert_called_once() # For the uniqueness check
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once_with(invite)
assert invite is not None
assert invite.code == generated_code
assert invite.group_id == group_model.id
assert invite.created_by_id == user_model.id
assert invite.is_active is True
assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4) # Check expiry is roughly correct
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
colliding_code = "colliding_code"
unique_code = "finally_unique_code"
mock_token_urlsafe.side_effect = [colliding_code, unique_code] # First call collides, second is unique
# Mock DB execute for checking existing code
mock_collision_check_result = MagicMock() # sync result API -> MagicMock
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Simulate collision (ID found)
mock_no_collision_check_result = MagicMock()
mock_no_collision_check_result.scalar_one_or_none.return_value = None # No collision
mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
assert mock_token_urlsafe.call_count == 2
assert mock_db_session.execute.call_count == 2
assert invite is not None
assert invite.code == unique_code
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
mock_token_urlsafe.return_value = "always_colliding_code"
mock_collision_check_result = MagicMock()
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Always collide
mock_db_session.execute.return_value = mock_collision_check_result
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
assert invite is None
assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
mock_db_session.add.assert_not_called()
# --- get_active_invite_by_code Tests ---
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
db_invite_model.is_active = True
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
mock_result = MagicMock() # sync result API -> MagicMock
mock_result.scalars.return_value.first.return_value = db_invite_model
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is not None
assert invite.code == db_invite_model.code
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")
assert invite is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
db_invite_model.is_active = False # Inactive
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
db_invite_model.is_active = True
db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1) # Expired
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is None
# --- deactivate_invite Tests ---
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
db_invite_model.is_active = True # Ensure it starts active
deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model)
mock_db_session.add.assert_called_once_with(db_invite_model)
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_invite_model)
assert deactivated_invite.is_active is False
# It might be useful to test DB error cases (OperationalError, etc.) for each function
# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that.
# create_invite has its own DB interaction within the loop, so that's covered.
# get_active_invite_by_code and deactivate_invite are simpler DB ops.
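The three `create_invite` tests above all exercise the same collision-retry loop. A minimal, self-contained sketch of that loop (with a local `generate_unique_code` and `MAX_ATTEMPTS` standing in for the internals of `create_invite` and `MAX_CODE_GENERATION_ATTEMPTS`; names are illustrative):

```python
import asyncio
import secrets
from unittest.mock import AsyncMock, MagicMock, patch

MAX_ATTEMPTS = 3  # stand-in for MAX_CODE_GENERATION_ATTEMPTS


async def generate_unique_code(session):
    # Retry on collision; give up (return None) after MAX_ATTEMPTS tries.
    for _ in range(MAX_ATTEMPTS):
        code = secrets.token_urlsafe(16)
        result = await session.execute(("check", code))  # placeholder for a SELECT by code
        if result.scalar_one_or_none() is None:
            return code
    return None


session = AsyncMock()
collision, free = MagicMock(), MagicMock()     # result objects are sync -> MagicMock
collision.scalar_one_or_none.return_value = 1  # code already exists
free.scalar_one_or_none.return_value = None    # code is unused
session.execute.side_effect = [collision, free]

with patch("secrets.token_urlsafe", side_effect=["taken", "unique"]) as tok:
    code = asyncio.run(generate_unique_code(session))

assert code == "unique"
assert tok.call_count == 2
```

Feeding `side_effect` lists to both the token generator and `execute` is exactly how the collision and max-attempts tests above steer the loop.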

be/tests/crud/test_item.py Normal file

@@ -0,0 +1,186 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from datetime import datetime, timezone
from app.crud.item import (
create_item,
get_items_by_list_id,
get_item_by_id,
update_item,
delete_item
)
from app.schemas.item import ItemCreate, ItemUpdate
from app.models import Item as ItemModel, User as UserModel, List as ListModel
from app.core.exceptions import (
ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = MagicMock(return_value=AsyncMock()) # sync call returning an async context manager
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock() # Though not directly used in item.py, good for consistency
session.flush = AsyncMock()
return session
@pytest.fixture
def item_create_data():
return ItemCreate(name="Test Item", quantity="1 pack")
@pytest.fixture
def item_update_data():
return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)
@pytest.fixture
def user_model():
return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def list_model():
return ListModel(id=1, name="Test List")
@pytest.fixture
def db_item_model(list_model, user_model):
return ItemModel(
id=1,
name="Existing Item",
quantity="1 unit",
list_id=list_model.id,
added_by_id=user_model.id,
is_complete=False,
version=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# --- create_item Tests ---
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
async def mock_refresh(instance):
instance.id = 10 # Simulate ID assignment
instance.version = 1 # Simulate version init
instance.created_at = datetime.now(timezone.utc)
instance.updated_at = datetime.now(timezone.utc)
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(created_item)
assert created_item is not None
assert created_item.name == item_create_data.name
assert created_item.list_id == list_model.id
assert created_item.added_by_id == user_model.id
assert created_item.is_complete is False
assert created_item.version == 1
@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
mock_db_session.rollback.assert_called_once()
# --- get_items_by_list_id Tests ---
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
mock_result = MagicMock() # sync result API -> MagicMock
mock_result.scalars.return_value.all.return_value = [db_item_model]
mock_db_session.execute.return_value = mock_result
items = await get_items_by_list_id(mock_db_session, list_model.id)
assert len(items) == 1
assert items[0].id == db_item_model.id
mock_db_session.execute.assert_called_once()
# --- get_item_by_id Tests ---
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = db_item_model
mock_db_session.execute.return_value = mock_result
item = await get_item_by_id(mock_db_session, db_item_model.id)
assert item is not None
assert item.id == db_item_model.id
@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
item = await get_item_by_id(mock_db_session, 999)
assert item is None
# --- update_item Tests ---
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
item_update_data.version = db_item_model.version # Match versions for successful update
item_update_data.name = "Newly Updated Name"
item_update_data.is_complete = True # Test completion logic
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
mock_db_session.add.assert_called_once_with(db_item_model) # add is used for existing objects too
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_item_model)
assert updated_item.name == "Newly Updated Name"
assert updated_item.version == db_item_model.version + 1 # Check version increment logic in function
assert updated_item.is_complete is True
assert updated_item.completed_by_id == user_model.id
@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
item_update_data.version = db_item_model.version + 1 # Create a version mismatch
with pytest.raises(ConflictError):
await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
mock_db_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
db_item_model.is_complete = True # Start as complete
db_item_model.completed_by_id = user_model.id
db_item_model.version = 1
item_update_data.version = 1
item_update_data.is_complete = False
item_update_data.name = db_item_model.name # No name change for this test
item_update_data.quantity = db_item_model.quantity
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
assert updated_item.is_complete is False
assert updated_item.completed_by_id is None
assert updated_item.version == 2
# --- delete_item Tests ---
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
result = await delete_item(mock_db_session, db_item_model)
assert result is None
mock_db_session.delete.assert_called_once_with(db_item_model)
# Assuming delete_item commits the session or is called within a transaction that commits.
# If delete_item itself doesn't commit, this might need to be adjusted based on calling context.
# mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
with pytest.raises(DatabaseTransactionError): # delete_item wraps DB errors in DatabaseTransactionError
await delete_item(mock_db_session, db_item_model)
mock_db_session.rollback.assert_called_once()
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function.
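The version check behind `test_update_item_version_conflict` is plain optimistic locking. A minimal sketch of the rule those tests encode (`apply_versioned_update` and the local `ConflictError` are illustrative stand-ins, not the actual `update_item`):

```python
class ConflictError(Exception):
    """Stand-in for app.core.exceptions.ConflictError."""


def apply_versioned_update(current_version: int, update_version: int) -> int:
    # Optimistic locking: the client must send back the version it last read;
    # a mismatch means another writer updated the row first.
    if update_version != current_version:
        raise ConflictError(f"expected version {current_version}, got {update_version}")
    return current_version + 1  # a successful update bumps the version


assert apply_versioned_update(1, 1) == 2
try:
    apply_versioned_update(2, 1)  # stale client version -> conflict
except ConflictError:
    pass
```

This is why `test_update_item_success` matches the versions first and then asserts the increment, while the conflict test deliberately mismatches them.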

be/tests/crud/test_list.py Normal file

@@ -0,0 +1,351 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import func as sql_func # For get_list_status
from datetime import datetime, timezone
from app.crud.list import (
create_list,
get_lists_for_user,
get_list_by_id,
update_list,
delete_list,
check_list_permission,
get_list_status
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock() # Overall session mock
# For session.begin() and session.begin_nested()
# These are sync methods returning an async context manager.
# The returned AsyncMock will act as the async context manager.
mock_transaction_context = AsyncMock()
session.begin = MagicMock(return_value=mock_transaction_context)
session.begin_nested = MagicMock(return_value=mock_transaction_context) # Can use the same or a new one
# Async methods on the session itself
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.execute = AsyncMock() # Correct: execute is async
session.get = AsyncMock() # Correct: get is async
session.flush = AsyncMock() # Correct: flush is async
# Sync methods on the session
session.add = MagicMock()
session.delete = MagicMock()
session.in_transaction = MagicMock(return_value=False)
return session
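The `begin`/`begin_nested` comments in the fixture above describe a subtlety worth demonstrating in isolation: the method call itself is synchronous, but the object it returns is entered with `async with`. A self-contained sketch using only `unittest.mock`:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

session = AsyncMock()
# begin() is a plain (sync) call; what it RETURNS is the async context manager.
session.begin = MagicMock(return_value=AsyncMock())

async def demo():
    async with session.begin():            # not awaited: begin() returns the CM directly
        await session.execute("SELECT 1")  # execute, by contrast, is awaited

asyncio.run(demo())
session.begin.assert_called_once()
session.execute.assert_awaited_once()
```

If `begin` were left as an `AsyncMock`, `async with session.begin():` would try to enter a coroutine and fail, which is why the fixture wraps it in `MagicMock`.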
@pytest.fixture
def list_create_data():
return ListCreate(name="New Shopping List", description="Groceries for the week")
@pytest.fixture
def list_update_data():
return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)
@pytest.fixture
def user_model():
return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def another_user_model():
return UserModel(id=2, name="Another User", email="another@example.com")
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
@pytest.fixture
def db_list_personal_model(user_model):
return ListModel(
id=1, name="Personal List", created_by_id=user_model.id, creator=user_model,
version=1, updated_at=datetime.now(timezone.utc), items=[]
)
@pytest.fixture
def db_list_group_model(user_model, group_model):
return ListModel(
id=2, name="Group List", created_by_id=user_model.id, creator=user_model,
group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[]
)
# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
async def mock_refresh(instance):
instance.id = 100
instance.version = 1
instance.updated_at = datetime.now(timezone.utc)
return None
mock_db_session.refresh.side_effect = mock_refresh
mock_result = MagicMock() # sync result API -> MagicMock
mock_result.scalar_one_or_none.return_value = ListModel(
id=100,
name=list_create_data.name,
description=list_create_data.description,
created_by_id=user_model.id,
version=1,
updated_at=datetime.now(timezone.utc)
)
mock_db_session.execute.return_value = mock_result
created_list = await create_list(mock_db_session, list_create_data, user_model.id)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
assert created_list.name == list_create_data.name
assert created_list.created_by_id == user_model.id
# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
# Mock for the object returned by .scalars() for group_ids query
mock_group_ids_scalar_result = MagicMock()
mock_group_ids_scalar_result.all.return_value = [db_list_group_model.group_id]
# Mock for the object returned by await session.execute() for group_ids query
mock_group_ids_execute_result = MagicMock()
mock_group_ids_execute_result.scalars.return_value = mock_group_ids_scalar_result
# Mock for the object returned by .scalars() for lists query
mock_lists_scalar_result = MagicMock()
mock_lists_scalar_result.all.return_value = [db_list_personal_model, db_list_group_model]
# Mock for the object returned by await session.execute() for lists query
mock_lists_execute_result = MagicMock()
mock_lists_execute_result.scalars.return_value = mock_lists_scalar_result
mock_db_session.execute.side_effect = [mock_group_ids_execute_result, mock_lists_execute_result]
lists = await get_lists_for_user(mock_db_session, user_model.id)
assert len(lists) == 2
assert db_list_personal_model in lists
assert db_list_group_model in lists
assert mock_db_session.execute.call_count == 2
# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)
assert found_list is not None
assert found_list.id == db_list_personal_model.id
@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)
assert found_list is not None
assert len(found_list.items) == 1
# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)
assert updated_list.name == list_update_data.name
assert updated_list.version == db_list_personal_model.version + 1
mock_db_session.add.assert_called_once_with(db_list_personal_model)
mock_db_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version + 1 # Simulate version mismatch
# When update_list is called with a version mismatch, it should raise ConflictError
with pytest.raises(ConflictError):
await update_list(mock_db_session, db_list_personal_model, list_update_data)
# update_list raises ConflictError on version mismatch; whether it also calls
# session.rollback() is an implementation detail of update_list, so it is not
# asserted here to keep the test decoupled from that choice.
# mock_db_session.rollback.assert_called_once() # too implementation-specific
# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
await delete_list(mock_db_session, db_list_personal_model)
mock_db_session.delete.assert_called_once_with(db_list_personal_model)
# mock_db_session.flush.assert_called_once() # delete usually implies a flush
# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_personal_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
assert ret_list.id == db_list_personal_model.id
@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_group_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = True
ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
assert ret_list.id == db_list_group_model.id
mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_group_model
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = False
with pytest.raises(ListPermissionError):
await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_creator_required_fail(mock_db_session, db_list_group_model, another_user_model):
# Simulate another_user_model is not the creator of db_list_group_model
# db_list_group_model.created_by_id is user_model.id (1), another_user_model.id is 2
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = db_list_group_model # List is found
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
# No need to mock is_user_member if require_creator is True and user is not creator
with pytest.raises(ListCreatorRequiredError):
await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id, require_creator=True)
@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = None # Simulate list not found
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
with pytest.raises(ListNotFoundError):
await check_list_permission(mock_db_session, 999, user_model.id)
# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
# This test is more complex due to multiple potential execute calls or specific query structures
# For simplicity, assuming the primary query for the list model uses the same pattern:
# Mock for finding the list by ID (first execute call in get_list_status)
mock_list_scalar = MagicMock()
mock_list_scalar.first.return_value = db_list_personal_model
mock_list_execute = MagicMock()
mock_list_execute.scalars.return_value = mock_list_scalar
# Mock for counting total items (second execute call)
mock_total_items_scalar = MagicMock()
mock_total_items_scalar.one.return_value = 5
mock_total_items_execute = MagicMock()
mock_total_items_execute.scalars.return_value = mock_total_items_scalar
# Mock for counting completed items (third execute call)
mock_completed_items_scalar = MagicMock()
mock_completed_items_scalar.one.return_value = 2
mock_completed_items_execute = MagicMock()
mock_completed_items_execute.scalars.return_value = mock_completed_items_scalar
mock_db_session.execute.side_effect = [
mock_list_execute,
mock_total_items_execute,
mock_completed_items_execute
]
status = await get_list_status(mock_db_session, db_list_personal_model.id)
assert status.list_id == db_list_personal_model.id
assert status.total_items == 5
assert status.completed_items == 2
assert status.name == db_list_personal_model.name
assert mock_db_session.execute.call_count == 3
@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
# Mock for the object returned by .scalars()
mock_scalar_result = MagicMock()
mock_scalar_result.first.return_value = None # List not found
# Mock for the object returned by await session.execute()
mock_execute_result = MagicMock()
mock_execute_result.scalars.return_value = mock_scalar_result
mock_db_session.execute.return_value = mock_execute_result
with pytest.raises(ListNotFoundError):
await get_list_status(mock_db_session, 999)
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
# TODO: Test check_list_permission with require_creator=True cases.
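Several tests above (`test_get_lists_for_user_mix`, `test_get_list_status_success`) feed a list to `execute.side_effect` so that consecutive queries receive distinct result objects. The pattern in isolation, runnable with only the standard library:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

session = AsyncMock()
first, second = MagicMock(), MagicMock()  # result objects are sync -> MagicMock
first.scalars.return_value.all.return_value = [1, 2]
second.scalars.return_value.one.return_value = 5
session.execute.side_effect = [first, second]  # consumed in call order

async def demo():
    ids = (await session.execute("query one")).scalars().all()
    count = (await session.execute("query two")).scalars().one()
    return ids, count

ids, count = asyncio.run(demo())
assert ids == [1, 2]
assert count == 5
assert session.execute.await_count == 2
```

Because `side_effect` items are returned strictly in call order, these tests are coupled to the order in which the function under test issues its queries; reordering queries in the CRUD code will break them.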


@@ -0,0 +1,277 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList
from app.crud.settlement import (
create_settlement,
get_settlement_by_id,
get_settlements_for_group,
get_settlements_involving_user,
update_settlement,
delete_settlement
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError, ConflictError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
# begin()/begin_nested() are sync calls returning async context managers (see the test_list fixture)
session.begin = MagicMock(return_value=AsyncMock())
session.begin_nested = MagicMock(return_value=AsyncMock())
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
return session
@pytest.fixture
def settlement_create_data():
return SettlementCreate(
group_id=1,
paid_by_user_id=1,
paid_to_user_id=2,
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Test settlement"
)
@pytest.fixture
def settlement_update_data():
return SettlementUpdate(
description="Updated settlement description",
settlement_date=datetime.now(timezone.utc),
version=1 # Assuming version is required for update
)
@pytest.fixture
def db_settlement_model():
return SettlementModel(
id=1,
group_id=1,
paid_by_user_id=1,
paid_to_user_id=2,
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Original settlement",
created_by_user_id=1,
version=1, # Initial version
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
payer=UserModel(id=1, name="Payer User"),
payee=UserModel(id=2, name="Payee User"),
group=GroupModel(id=1, name="Test Group"),
created_by_user=UserModel(id=1, name="Payer User") # Same as payer for simplicity
)
@pytest.fixture
def payer_user_model():
return UserModel(id=1, name="Payer User", email="payer@example.com")
@pytest.fixture
def payee_user_model():
return UserModel(id=2, name="Payee User", email="payee@example.com")
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
mock_result = MagicMock()  # the Result API (scalar_one_or_none, scalars) is synchronous
mock_result.scalar_one_or_none.return_value = SettlementModel(
id=1,
group_id=settlement_create_data.group_id,
paid_by_user_id=settlement_create_data.paid_by_user_id,
paid_to_user_id=settlement_create_data.paid_to_user_id,
amount=settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_create_data.settlement_date,
description=settlement_create_data.description,
created_by_user_id=1,
version=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_db_session.execute.return_value = mock_result
created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
assert created_settlement is not None
assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data, payee_user_model, group_model):
mock_db_session.get.side_effect = [None, payee_user_model, group_model]
with pytest.raises(UserNotFoundError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payer" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model):
mock_db_session.get.side_effect = [payer_user_model, None, group_model]
with pytest.raises(UserNotFoundError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payee" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
with pytest.raises(GroupNotFoundError):
await create_settlement(mock_db_session, settlement_create_data, 1)
@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
with pytest.raises(InvalidOperationError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payer and Payee cannot be the same user" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Failed to save settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()
# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
mock_result = MagicMock()  # the Result API (scalars().first()) is synchronous
mock_result.scalars.return_value.first.return_value = db_settlement_model
mock_db_session.execute.return_value = mock_result
settlement = await get_settlement_by_id(mock_db_session, 1)
assert settlement is not None
assert settlement.id == 1
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
mock_result = MagicMock()  # the Result API (scalars().first()) is synchronous
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
settlement = await get_settlement_by_id(mock_db_session, 999)
assert settlement is None
# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
mock_result = MagicMock()  # the Result API (scalars().all()) is synchronous
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
settlements = await get_settlements_for_group(mock_db_session, group_id=1)
assert len(settlements) == 1
assert settlements[0].group_id == 1
mock_db_session.execute.assert_called_once()
# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
mock_result = MagicMock()  # the Result API (scalars().all()) is synchronous
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
assert len(settlements) == 1
assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
mock_result = MagicMock()  # the Result API (scalars().all()) is synchronous
mock_result.scalars.return_value.all.return_value = [db_settlement_model]
mock_db_session.execute.return_value = mock_result
settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
assert len(settlements) == 1
mock_db_session.execute.assert_called_once()
# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version
mock_result = MagicMock()  # the Result API (scalar_one_or_none) is synchronous
mock_result.scalar_one_or_none.return_value = db_settlement_model
mock_db_session.execute.return_value = mock_result
updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
mock_db_session.add.assert_called_once_with(db_settlement_model)
mock_db_session.flush.assert_called_once()
assert updated_settlement.description == settlement_update_data.description
assert updated_settlement.settlement_date == settlement_update_data.settlement_date
assert updated_settlement.version == db_settlement_model.version + 1
@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version + 1
with pytest.raises(ConflictError):
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
mock_db_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
# Create an update payload with a field not allowed to be updated, e.g., 'amount'
invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
assert "Field 'amount' cannot be updated" in str(excinfo.value)
@pytest.mark.asyncio
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
assert "Failed to update settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()
# Tests for delete_settlement
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
await delete_settlement(mock_db_session, db_settlement_model)
mock_db_session.delete.assert_called_once_with(db_settlement_model)
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version)
mock_db_session.delete.assert_called_once_with(db_settlement_model)
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
db_settlement_model.version = 2
with pytest.raises(ConflictError):
await delete_settlement(mock_db_session, db_settlement_model, expected_version=1)
mock_db_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await delete_settlement(mock_db_session, db_settlement_model)
assert "Failed to delete settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()
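The `ConflictError` tests above exercise version-based optimistic locking. A self-contained sketch of that pattern (a plain stand-in class, not the SQLAlchemy model or the app's exception):

```python
from dataclasses import dataclass

class ConflictError(Exception):
    """Stand-in for app.core.exceptions.ConflictError."""

@dataclass
class SettlementStandin:
    id: int
    description: str
    version: int

def apply_update(settlement: SettlementStandin, new_description: str, expected_version: int) -> SettlementStandin:
    # Core of the optimistic-locking contract: refuse the write unless
    # the caller saw the current version, then bump the version.
    if expected_version != settlement.version:
        raise ConflictError(
            f"version mismatch: expected {expected_version}, found {settlement.version}"
        )
    settlement.description = new_description
    settlement.version += 1
    return settlement

s = SettlementStandin(id=1, description="Original", version=1)
apply_update(s, "Updated", expected_version=1)

try:
    apply_update(s, "Stale write", expected_version=1)  # stale: version is now 2
    conflicted = False
except ConflictError:
    conflicted = True
```

This is why `test_update_settlement_success` asserts `version == db_settlement_model.version + 1` while the mismatch test expects `ConflictError`.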


@ -0,0 +1,369 @@
import pytest
from decimal import Decimal
from datetime import datetime, timezone
from typing import AsyncGenerator, List
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
User,
Group,
Expense,
ExpenseSplit,
SettlementActivity,
ExpenseSplitStatusEnum,
ExpenseOverallStatusEnum,
SplitTypeEnum,
UserRoleEnum
)
from app.crud.settlement_activity import (
create_settlement_activity,
get_settlement_activity_by_id,
get_settlement_activities_for_split,
update_expense_split_status, # For direct testing if needed
update_expense_overall_status # For direct testing if needed
)
from app.schemas.settlement_activity import SettlementActivityCreate as SettlementActivityCreateSchema
@pytest.fixture
async def test_user1(db_session: AsyncSession) -> User:
user = User(email="user1@example.com", name="Test User 1", hashed_password="password1")
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
return user
@pytest.fixture
async def test_user2(db_session: AsyncSession) -> User:
user = User(email="user2@example.com", name="Test User 2", hashed_password="password2")
db_session.add(user)
await db_session.commit()
await db_session.refresh(user)
return user
@pytest.fixture
async def test_group(db_session: AsyncSession, test_user1: User) -> Group:
group = Group(name="Test Group", created_by_id=test_user1.id)
db_session.add(group)
await db_session.commit()
# Add user1 as owner and user2 as member (can be done in specific tests if needed)
await db_session.refresh(group)
return group
@pytest.fixture
async def test_expense(db_session: AsyncSession, test_user1: User, test_group: Group) -> Expense:
expense = Expense(
description="Test Expense for Settlement",
total_amount=Decimal("20.00"),
currency="USD",
expense_date=datetime.now(timezone.utc),
split_type=SplitTypeEnum.EQUAL,
group_id=test_group.id,
paid_by_user_id=test_user1.id,
created_by_user_id=test_user1.id,
overall_settlement_status=ExpenseOverallStatusEnum.unpaid # Initial status
)
db_session.add(expense)
await db_session.commit()
await db_session.refresh(expense)
return expense
@pytest.fixture
async def test_expense_split_user2_owes(db_session: AsyncSession, test_expense: Expense, test_user2: User) -> ExpenseSplit:
# User2 owes 10.00 to User1 (who paid the expense)
split = ExpenseSplit(
expense_id=test_expense.id,
user_id=test_user2.id,
owed_amount=Decimal("10.00"),
status=ExpenseSplitStatusEnum.unpaid # Initial status
)
db_session.add(split)
await db_session.commit()
await db_session.refresh(split)
return split
@pytest.fixture
async def test_expense_split_user1_owes_self_for_completeness(db_session: AsyncSession, test_expense: Expense, test_user1: User) -> ExpenseSplit:
# User1's own share (owes 10.00 to self, effectively settled)
# This is often how splits are represented, even for the payer
split = ExpenseSplit(
expense_id=test_expense.id,
user_id=test_user1.id,
owed_amount=Decimal("10.00"), # User1's share of the 20.00 expense
status=ExpenseSplitStatusEnum.unpaid # Initial status, though payer's own share might be considered paid by some logic
)
db_session.add(split)
await db_session.commit()
await db_session.refresh(split)
return split
# --- Tests for create_settlement_activity ---
@pytest.mark.asyncio
async def test_create_settlement_activity_full_payment(
db_session: AsyncSession,
test_user1: User, # Creator of activity, Payer of expense
test_user2: User, # Payer of this settlement activity (settling their debt)
test_expense: Expense,
test_expense_split_user2_owes: ExpenseSplit,
test_expense_split_user1_owes_self_for_completeness: ExpenseSplit # User1's own share
):
# Scenario: User2 fully pays their 10.00 share.
# User1's share is also part of the expense. Let's assume it's 'paid' by default or handled separately.
# For this test, we focus on User2's split.
# To make overall expense paid, User1's split also needs to be considered paid.
# We can manually update User1's split status to paid for this test case.
test_expense_split_user1_owes_self_for_completeness.status = ExpenseSplitStatusEnum.paid
test_expense_split_user1_owes_self_for_completeness.paid_at = datetime.now(timezone.utc)
db_session.add(test_expense_split_user1_owes_self_for_completeness)
await db_session.commit()
await db_session.refresh(test_expense_split_user1_owes_self_for_completeness)
await db_session.refresh(test_expense) # Refresh expense to reflect split status change
activity_data = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id, # User2 is paying their share
amount_paid=Decimal("10.00")
)
created_activity = await create_settlement_activity(
db=db_session,
settlement_activity_in=activity_data,
current_user_id=test_user2.id # User2 is recording their own payment
)
assert created_activity is not None
assert created_activity.expense_split_id == test_expense_split_user2_owes.id
assert created_activity.paid_by_user_id == test_user2.id
assert created_activity.amount_paid == Decimal("10.00")
assert created_activity.created_by_user_id == test_user2.id
await db_session.refresh(test_expense_split_user2_owes)
await db_session.refresh(test_expense) # Refresh to get updated overall_status
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
assert test_expense_split_user2_owes.paid_at is not None
# Check parent expense status
# This depends on all splits being paid for the expense to be fully paid.
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid
@pytest.mark.asyncio
async def test_create_settlement_activity_partial_payment(
db_session: AsyncSession,
test_user1: User, # Creator of activity
test_user2: User, # Payer of this settlement activity
test_expense: Expense,
test_expense_split_user2_owes: ExpenseSplit
):
activity_data = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id,
amount_paid=Decimal("5.00")
)
created_activity = await create_settlement_activity(
db=db_session,
settlement_activity_in=activity_data,
current_user_id=test_user2.id # User2 records their payment
)
assert created_activity is not None
assert created_activity.amount_paid == Decimal("5.00")
await db_session.refresh(test_expense_split_user2_owes)
await db_session.refresh(test_expense)
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.partially_paid
assert test_expense_split_user2_owes.paid_at is None
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid # Assuming other splits are unpaid or partially paid
@pytest.mark.asyncio
async def test_create_settlement_activity_multiple_payments_to_full(
db_session: AsyncSession,
test_user1: User,
test_user2: User,
test_expense: Expense,
test_expense_split_user2_owes: ExpenseSplit,
test_expense_split_user1_owes_self_for_completeness: ExpenseSplit # User1's own share
):
# Assume user1's share is already 'paid' for overall expense status testing
test_expense_split_user1_owes_self_for_completeness.status = ExpenseSplitStatusEnum.paid
test_expense_split_user1_owes_self_for_completeness.paid_at = datetime.now(timezone.utc)
db_session.add(test_expense_split_user1_owes_self_for_completeness)
await db_session.commit()
# First partial payment
activity_data1 = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id,
amount_paid=Decimal("3.00")
)
await create_settlement_activity(db=db_session, settlement_activity_in=activity_data1, current_user_id=test_user2.id)
await db_session.refresh(test_expense_split_user2_owes)
await db_session.refresh(test_expense)
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.partially_paid
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid
# Second payment completing the amount
activity_data2 = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id,
amount_paid=Decimal("7.00") # 3.00 + 7.00 = 10.00
)
await create_settlement_activity(db=db_session, settlement_activity_in=activity_data2, current_user_id=test_user2.id)
await db_session.refresh(test_expense_split_user2_owes)
await db_session.refresh(test_expense)
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
assert test_expense_split_user2_owes.paid_at is not None
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid
@pytest.mark.asyncio
async def test_create_settlement_activity_invalid_split_id(
db_session: AsyncSession, test_user1: User
):
activity_data = SettlementActivityCreateSchema(
expense_split_id=99999, # Non-existent
paid_by_user_id=test_user1.id,
amount_paid=Decimal("10.00")
)
# The CRUD function returns None for not found related objects
result = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user1.id)
assert result is None
@pytest.mark.asyncio
async def test_create_settlement_activity_invalid_paid_by_user_id(
db_session: AsyncSession, test_user1: User, test_expense_split_user2_owes: ExpenseSplit
):
activity_data = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=99999, # Non-existent
amount_paid=Decimal("10.00")
)
result = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user1.id)
assert result is None
# --- Tests for get_settlement_activity_by_id ---
@pytest.mark.asyncio
async def test_get_settlement_activity_by_id_found(
db_session: AsyncSession, test_user2: User, test_expense_split_user2_owes: ExpenseSplit
):
activity_data = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id,
amount_paid=Decimal("5.00")
)
created = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user2.id)
assert created is not None
fetched = await get_settlement_activity_by_id(db=db_session, settlement_activity_id=created.id)
assert fetched is not None
assert fetched.id == created.id
assert fetched.amount_paid == Decimal("5.00")
@pytest.mark.asyncio
async def test_get_settlement_activity_by_id_not_found(db_session: AsyncSession):
fetched = await get_settlement_activity_by_id(db=db_session, settlement_activity_id=99999)
assert fetched is None
# --- Tests for get_settlement_activities_for_split ---
@pytest.mark.asyncio
async def test_get_settlement_activities_for_split_multiple_found(
db_session: AsyncSession, test_user2: User, test_expense_split_user2_owes: ExpenseSplit
):
act1_data = SettlementActivityCreateSchema(expense_split_id=test_expense_split_user2_owes.id, paid_by_user_id=test_user2.id, amount_paid=Decimal("2.00"))
act2_data = SettlementActivityCreateSchema(expense_split_id=test_expense_split_user2_owes.id, paid_by_user_id=test_user2.id, amount_paid=Decimal("3.00"))
await create_settlement_activity(db=db_session, settlement_activity_in=act1_data, current_user_id=test_user2.id)
await create_settlement_activity(db=db_session, settlement_activity_in=act2_data, current_user_id=test_user2.id)
activities: List[SettlementActivity] = await get_settlement_activities_for_split(db=db_session, expense_split_id=test_expense_split_user2_owes.id)
assert len(activities) == 2
amounts = sorted([act.amount_paid for act in activities])
assert amounts == [Decimal("2.00"), Decimal("3.00")]
@pytest.mark.asyncio
async def test_get_settlement_activities_for_split_none_found(
db_session: AsyncSession, test_expense_split_user2_owes: ExpenseSplit # A split with no activities
):
activities: List[SettlementActivity] = await get_settlement_activities_for_split(db=db_session, expense_split_id=test_expense_split_user2_owes.id)
assert len(activities) == 0
# Note: Direct tests for helper functions update_expense_split_status and update_expense_overall_status
# could be added if complex logic within them isn't fully covered by create_settlement_activity tests.
# However, their effects are validated through the main CRUD function here.
# For example, to test update_expense_split_status directly:
# 1. Create an ExpenseSplit.
# 2. Create one or more SettlementActivity instances directly in the DB session for that split.
# 3. Call await update_expense_split_status(db_session, expense_split_id=split.id).
# 4. Assert the split.status and split.paid_at are as expected.
# Similar for update_expense_overall_status by setting up multiple splits.
# For now, relying on indirect testing via create_settlement_activity.
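For step 3/4 in the outline above, the status derivation itself can be sketched in isolation (hypothetical helper name; the real logic lives inside `update_expense_split_status` and also stamps `paid_at`):

```python
from decimal import Decimal
from enum import Enum

class SplitStatus(Enum):
    unpaid = "unpaid"
    partially_paid = "partially_paid"
    paid = "paid"

def derive_split_status(owed_amount: Decimal, activity_amounts: list) -> SplitStatus:
    # The split's status follows from the sum of its settlement
    # activities vs. the owed amount; a total >= owed counts as paid.
    total_paid = sum(activity_amounts, Decimal("0.00"))
    if total_paid <= Decimal("0.00"):
        return SplitStatus.unpaid
    if total_paid >= owed_amount:
        return SplitStatus.paid
    return SplitStatus.partially_paid

partial = derive_split_status(Decimal("10.00"), [Decimal("3.00")])
full = derive_split_status(Decimal("10.00"), [Decimal("3.00"), Decimal("7.00")])
over = derive_split_status(Decimal("10.00"), [Decimal("12.00")])  # overpayment
```

A direct test of the real helper would create the activities in the session and assert the persisted `status`/`paid_at` instead.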
# More tests can be added for edge cases, such as:
# - Overpayment (current logic in update_expense_split_status treats >= owed_amount as 'paid').
# - Different users creating the activity vs. paying for it (permission aspects, though that's more for API tests).
# - Interactions with different expense split types if that affects status updates.
# - Ensuring `overall_settlement_status` correctly reflects if one split is paid, another is unpaid, etc.
# (e.g. test_expense_split_user1_owes_self_for_completeness is set to unpaid initially).
# A test case where one split becomes 'paid' but another remains 'unpaid' should result in 'partially_paid' for the expense.
@pytest.mark.asyncio
async def test_create_settlement_activity_overall_status_becomes_partially_paid(
db_session: AsyncSession,
test_user1: User,
test_user2: User,
test_expense: Expense, # Overall status is initially unpaid
test_expense_split_user2_owes: ExpenseSplit, # User2's split, initially unpaid
test_expense_split_user1_owes_self_for_completeness: ExpenseSplit # User1's split, also initially unpaid
):
# Sanity check: both splits and expense are unpaid initially
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.unpaid
assert test_expense_split_user1_owes_self_for_completeness.status == ExpenseSplitStatusEnum.unpaid
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid
# User2 fully pays their 10.00 share.
activity_data = SettlementActivityCreateSchema(
expense_split_id=test_expense_split_user2_owes.id,
paid_by_user_id=test_user2.id, # User2 is paying their share
amount_paid=Decimal("10.00")
)
await create_settlement_activity(
db=db_session,
settlement_activity_in=activity_data,
current_user_id=test_user2.id # User2 is recording their own payment
)
await db_session.refresh(test_expense_split_user2_owes)
await db_session.refresh(test_expense_split_user1_owes_self_for_completeness) # Ensure its status is current
await db_session.refresh(test_expense)
assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
assert test_expense_split_user1_owes_self_for_completeness.status == ExpenseSplitStatusEnum.unpaid # User1's split is still unpaid
# Since one split is paid and the other is unpaid, the overall expense status should be partially_paid
assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid
# Example of a placeholder for db_session fixture if not provided by conftest.py
# @pytest.fixture
# async def db_session() -> AsyncGenerator[AsyncSession, None]:
# # This needs to be implemented based on your test database setup
# # e.g., using a test-specific database and creating a new session per test
# # from app.database import SessionLocal # Assuming SessionLocal is your session factory
# # async with SessionLocal() as session:
# # async with session.begin(): # Start a transaction
# # yield session
# # # Transaction will be rolled back here after the test
# pass # Replace with actual implementation if needed
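The transaction-per-test idea in the commented fixture can be illustrated with stdlib sqlite3 (the real fixture would wrap an `AsyncSession` the same way — begin, yield, roll back):

```python
import sqlite3

def run_in_rolled_back_transaction(conn, work):
    # Each "test" runs inside a transaction that is always rolled back,
    # so the database is left untouched regardless of what work() does.
    try:
        conn.execute("BEGIN")
        work(conn)
    finally:
        conn.rollback()

conn = sqlite3.connect(":memory:")
conn.isolation_level = None  # manage transactions manually
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")

run_in_rolled_back_transaction(
    conn, lambda c: c.execute("INSERT INTO users (name) VALUES ('temp')")
)
count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
```

After the rollback the insert is gone, which is exactly the isolation each test fixture above depends on.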

be/tests/crud/test_user.py Normal file

@ -0,0 +1,128 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.exc import IntegrityError, OperationalError
from app.crud.user import get_user_by_email, create_user
from app.schemas.user import UserCreate
from app.models import User as UserModel
from app.core.exceptions import (
UserCreationError,
EmailAlreadyRegisteredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.begin_nested = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
session.in_transaction = MagicMock(return_value=False)
return session
@pytest.fixture
def user_create_data():
return UserCreate(email="test@example.com", password="password123", name="Test User")
@pytest.fixture
def existing_user_data():
return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")
# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
mock_result = MagicMock()  # the Result API (scalars().first()) is synchronous
mock_result.scalars.return_value.first.return_value = existing_user_data
mock_db_session.execute.return_value = mock_result
user = await get_user_by_email(mock_db_session, "exists@example.com")
assert user is not None
assert user.email == "exists@example.com"
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
mock_result = MagicMock()  # the Result API (scalars().first()) is synchronous
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
assert user is None
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_db_connection_error(mock_db_session):
mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await get_user_by_email(mock_db_session, "test@example.com")
@pytest.mark.asyncio
async def test_get_user_by_email_db_query_error(mock_db_session):
# Simulate a generic SQLAlchemyError that is not OperationalError
mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig") # Using IntegrityError as an example of SQLAlchemyError
with pytest.raises(DatabaseQueryError):
await get_user_by_email(mock_db_session, "test@example.com")
# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
mock_result = MagicMock()  # the Result API (scalar_one_or_none) is synchronous
mock_result.scalar_one_or_none.return_value = UserModel(
id=1,
email=user_create_data.email,
name=user_create_data.name,
password_hash="hashed_password" # This would be set by the actual hash_password function
)
mock_db_session.execute.return_value = mock_result
created_user = await create_user(mock_db_session, user_create_data)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
assert created_user is not None
assert created_user.email == user_create_data.email
assert created_user.name == user_create_data.name
assert created_user.id == 1
@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig")
with pytest.raises(EmailAlreadyRegisteredError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
# Simulate an IntegrityError that is not related to a unique constraint
mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await create_user(mock_db_session, user_create_data)
# also test OperationalError on flush
mock_db_session.begin.side_effect = None # reset side effect
mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
# Simulate a generic SQLAlchemyError on flush that is not IntegrityError or OperationalError
from sqlalchemy.exc import SQLAlchemyError
mock_db_session.flush.side_effect = SQLAlchemyError("Simulated non-specific SQLAlchemyError")
with pytest.raises(DatabaseTransactionError):
await create_user(mock_db_session, user_create_data)

docker-compose.prod.yml Normal file

@ -0,0 +1,28 @@
services:
backend:
container_name: fastapi_backend_prod
build:
context: ./be
dockerfile: Dockerfile.prod
target: production
environment:
- DATABASE_URL=${DATABASE_URL}
- GEMINI_API_KEY=${GEMINI_API_KEY}
- SECRET_KEY=${SECRET_KEY}
- SESSION_SECRET_KEY=${SESSION_SECRET_KEY}
- SENTRY_DSN=${SENTRY_DSN}
- LOG_LEVEL=INFO
- ENVIRONMENT=production
- CORS_ORIGINS=${CORS_ORIGINS}
- FRONTEND_URL=${FRONTEND_URL}
restart: unless-stopped
frontend:
container_name: frontend_prod
build:
context: ./fe
dockerfile: Dockerfile.prod
target: production
environment:
- VITE_API_URL=https://mitlistbe.mohamad.dev
restart: unless-stopped


@ -1,14 +1,11 @@
# docker-compose.yml (in project root)
version: '3.8'
services:
db:
image: postgres:15 # Use a specific PostgreSQL version
image: postgres:17 # Use a specific PostgreSQL version
container_name: postgres_db
environment:
POSTGRES_USER: dev_user # Define DB user
POSTGRES_PASSWORD: dev_password # Define DB password
POSTGRES_DB: dev_db # Define Database name
POSTGRES_USER: xxx # Define DB user
POSTGRES_PASSWORD: xxx # Define DB password
POSTGRES_DB: xxx # Define Database name
volumes:
- postgres_data:/var/lib/postgresql/data # Persist data using a named volume
ports:
@@ -36,30 +33,38 @@ services:
# Pass the database URL to the backend container
# Uses the service name 'db' as the host, and credentials defined above
# IMPORTANT: Use the correct async driver prefix if your app needs it!
- DATABASE_URL=postgresql+asyncpg://dev_user:dev_password@db:5432/dev_db
- DATABASE_URL=postgresql+asyncpg://mitlist_owner:xxx@ep-small-sound-a9ketcef-pooler.gwc.azure.neon.tech/testnewmig
- GEMINI_API_KEY=xxx
- SECRET_KEY=xxx
# Add other environment variables needed by the backend here
# - SOME_OTHER_VAR=some_value
depends_on:
db: # Wait for the db service to be healthy before starting backend
db:
# Wait for the db service to be healthy before starting backend
condition: service_healthy
command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] # Override CMD for development reload
command: [
"uvicorn",
"app.main:app",
"--host",
"0.0.0.0",
"--port",
"8000",
"--reload",
] # Override CMD for development reload
restart: unless-stopped
pgadmin: # Optional service for database administration
image: dpage/pgadmin4:latest
container_name: pgadmin4_server
environment:
PGADMIN_DEFAULT_EMAIL: admin@example.com # Change as needed
PGADMIN_DEFAULT_PASSWORD: admin_password # Change to a secure password
PGADMIN_CONFIG_SERVER_MODE: 'False' # Run in Desktop mode for easier local dev server setup
volumes:
- pgadmin_data:/var/lib/pgadmin # Persist pgAdmin configuration
frontend:
container_name: vite_frontend
build:
context: ./fe
dockerfile: Dockerfile
ports:
- "5050:80" # Map container port 80 to host port 5050
- "80:80"
depends_on:
- db # Depends on the database service
- backend
restart: unless-stopped
volumes: # Define named volumes for data persistence
volumes:
# Define named volumes for data persistence
postgres_data:
pgadmin_data:

docs/PRODUCTION.md Normal file

@@ -0,0 +1,196 @@
# Production Deployment Guide (Gitea Actions)
This guide covers deploying the mitlist application to a production environment using Docker Compose and Gitea Actions for CI/CD.
## 🚀 Quick Start
1. **Clone the repository** (if not already done):
```bash
git clone <your-repo>
cd mitlist
```
2. **Configure Gitea Secrets**:
In your Gitea repository settings, go to "Secrets" and add the following secrets. These will be used by the `deploy-prod.yml` workflow.
* `DOCKER_USERNAME`: Your Docker Hub username (or username for your container registry).
* `DOCKER_PASSWORD`: Your Docker Hub password (or token for your container registry).
* `SERVER_HOST`: IP address or hostname of your production server.
* `SERVER_USERNAME`: Username for SSH access to your production server.
* `SSH_PRIVATE_KEY`: Your private SSH key for accessing the production server.
* `SERVER_PORT`: (Optional) SSH port for your server (defaults to 22).
* `POSTGRES_USER`: Production database username.
* `POSTGRES_PASSWORD`: Production database password.
* `POSTGRES_DB`: Production database name.
* `DATABASE_URL`: Production database connection string.
* `SECRET_KEY`: FastAPI application secret key.
* `SESSION_SECRET_KEY`: FastAPI session secret key.
* `GEMINI_API_KEY`: API key for Gemini.
* `REDIS_PASSWORD`: Password for Redis.
* `SENTRY_DSN`: (Optional) Sentry DSN for backend error tracking.
* `CORS_ORIGINS`: Comma-separated list of allowed CORS origins for production (e.g., `https://yourdomain.com`).
* `FRONTEND_URL`: The public URL of your frontend (e.g., `https://yourdomain.com`).
* `VITE_API_URL`: The public API URL for the frontend (e.g., `https://yourdomain.com/api`).
* `VITE_SENTRY_DSN`: (Optional) Sentry DSN for frontend error tracking.
3. **Prepare your Production Server**:
* Install Docker and Docker Compose (see Prerequisites section below).
* Ensure your server can be accessed via SSH using the key you added to Gitea secrets.
* Create the deployment directory on your server (e.g., `/srv/mitlist`).
* Copy the `docker-compose.prod.yml` file to this directory on your server.
4. **Push to `main` branch**:
Once the Gitea workflows (`.gitea/workflows/build-test.yml` and `.gitea/workflows/deploy-prod.yml`) are in your repository, pushing to the `main` branch will automatically trigger the deployment workflow.
## 📋 Prerequisites (Server Setup)
### System Requirements
- **OS**: Ubuntu 20.04+ / CentOS 8+ / Debian 11+
- **RAM**: Minimum 2GB, Recommended 4GB+
- **Storage**: Minimum 20GB free space
- **CPU**: 2+ cores recommended
### Software Dependencies (on Production Server)
- Docker 20.10+
- Docker Compose 2.0+
### Installation Commands
**Ubuntu/Debian:**
```bash
# Update system
sudo apt update && sudo apt upgrade -y
# Install Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
sudo usermod -aG docker $USER # Add your deployment user to docker group
# Install Docker Compose
sudo apt install docker-compose-plugin
# Reboot or log out/in to apply group changes
# sudo reboot
```
**CentOS/RHEL:**
```bash
# Update system
sudo yum update -y
# Install Docker
sudo yum install -y yum-utils
sudo yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo
sudo yum install docker-ce docker-ce-cli containerd.io docker-compose-plugin
sudo systemctl start docker
sudo systemctl enable docker
sudo usermod -aG docker $USER # Add your deployment user to docker group
# Reboot or log out/in to apply group changes
# sudo reboot
```
## 🔧 Configuration Overview
* **`docker-compose.prod.yml`**: Defines the production services (database, backend, frontend, redis). This file needs to be on your production server in the deployment directory.
* **`.gitea/workflows/build-test.yml`**: Gitea workflow that builds and runs tests on every push to `main` or `develop`, and on pull requests to these branches.
* **`.gitea/workflows/deploy-prod.yml`**: Gitea workflow that triggers on pushes to the `main` branch. It builds and pushes Docker images to your container registry and then SSHes into your production server to update environment variables and restart services using `docker-compose`.
* **`env.production.template`**: A template file showing the environment variables needed. These are now set directly in the Gitea deployment workflow via secrets.
## 🚀 Deployment Process (via Gitea Actions)
1. **Code Push**: Developer pushes code to the `main` branch.
2. **Build & Test Workflow**: (Optional, if you keep `build-test.yml` active on `main` as well) The `build-test.yml` workflow runs, ensuring code quality.
3. **Deploy Workflow Trigger**: The `deploy-prod.yml` workflow is triggered.
4. **Checkout Code**: The workflow checks out the latest code.
5. **Login to Registry**: Logs into your specified Docker container registry.
6. **Build & Push Images**: Builds the production Docker images for the backend and frontend and pushes them to the registry.
7. **SSH to Server**: Connects to your production server via SSH.
8. **Set Environment Variables**: Creates/updates the `.env.production` file on the server using the Gitea secrets.
9. **Pull New Images**: Runs `docker-compose pull` to fetch the newly pushed images.
10. **Restart Services**: Runs `docker-compose up -d` to restart the services with the new images and configuration.
11. **Prune Images**: Cleans up old, unused Docker images on the server.
## 🏗️ Simplified Architecture
With the removal of nginx as a reverse proxy, the architecture is simpler:
```
        [ User / Internet ]
                 |
                 v
[ Frontend Service (Port 80) ]           <-- Serves Vue.js app (e.g., via `serve`)
                 |
                 v  (API Calls)
[ Backend Service (Internal Port 8000) ] <-- FastAPI
         |               |
         v               v
  [ PostgreSQL ]     [ Redis ]
   (Database)         (Cache)
```
* The **Frontend** service now directly exposes port 80 (or another port you configure) to the internet.
* The **Backend** service is still internal and accessed by the frontend via its Docker network name (`backend:8000`).
**Note on SSL/HTTPS**: Since nginx is removed, SSL termination is not handled by this setup. You would typically handle SSL at a higher level, for example:
* Using a cloud provider's load balancer with SSL termination.
* Placing another reverse proxy (like Caddy, Traefik, or a dedicated nginx instance) in front of your Docker setup on the server, configured for SSL.
* Using services like Cloudflare that can provide SSL for your domain.
## 📊 Monitoring & Logging
### Health Checks
* **Backend**: `http://<your_server_ip_or_domain>/api/health` (reachable only if you map the backend port in `docker-compose.prod.yml` or have the frontend proxy API requests).
* **Frontend**: The `serve` package used by the frontend doesn't have a dedicated health check endpoint by default. You can check if the main page loads.
### Log Access
```bash
# On your production server, in the deployment directory
docker-compose -f docker-compose.prod.yml logs -f
# Specific service logs
docker-compose -f docker-compose.prod.yml logs -f backend
docker-compose -f docker-compose.prod.yml logs -f frontend
```
## 🔄 Maintenance
### Database Backups
Manual backups can still be performed on the server:
```bash
# Ensure your .env.production file is sourced or vars are available
docker exec postgres_db_prod pg_dump -U $POSTGRES_USER $POSTGRES_DB > backup-$(date +%Y%m%d).sql
```
Consider automating this with a cron job on your server.
### Updates
Updates are now handled by pushing to the `main` branch, which triggers the Gitea deployment workflow.
## 🐛 Troubleshooting
### Gitea Workflow Failures
* Check the Gitea Actions logs for the specific workflow run to identify errors.
* Ensure all secrets are correctly configured in Gitea.
* Verify Docker Hub/registry credentials.
* Check SSH connectivity to your server from the Gitea runner (if using self-hosted runners, ensure network access).
### Service Not Starting on Server
* SSH into your server.
* Navigate to your deployment directory (e.g., `/srv/mitlist`).
* Check logs: `docker-compose -f docker-compose.prod.yml logs <service_name>`
* Ensure `.env.production` has the correct values.
* Check `docker ps` to see running containers.
### Frontend Not Accessible
* Verify the frontend service is running (`docker ps`).
* Check frontend logs: `docker-compose -f docker-compose.prod.yml logs frontend`.
* Ensure the port mapping in `docker-compose.prod.yml` for the frontend service (e.g., `80:3000`) is correct and not blocked by a firewall on your server.
## 📝 Changelog
### v1.1.0 (Gitea Actions Deployment)
- Removed nginx reverse proxy and related shell scripts.
- Frontend now served directly using `serve`.
- Added Gitea Actions workflows for CI (build/test) and CD (deploy to production).
- Updated deployment documentation to reflect Gitea Actions strategy.
- Simplified `docker-compose.prod.yml`.

docs/expense-system.md Normal file

@@ -0,0 +1,368 @@
# Expense System Documentation
## Overview
The expense system is a core feature that allows users to track shared expenses, split them among group members, and manage settlements. The system supports various split types and integrates with lists, groups, and items.
## Core Components
### 1. Expenses
An expense represents a shared cost that needs to be split among multiple users.
#### Key Properties
- `id`: Unique identifier
- `description`: Description of the expense
- `total_amount`: Total cost of the expense (Decimal)
- `currency`: Currency code (defaults to "USD")
- `expense_date`: When the expense occurred
- `split_type`: How the expense should be divided
- `list_id`: Optional reference to a shopping list
- `group_id`: Optional reference to a group
- `item_id`: Optional reference to a specific item
- `paid_by_user_id`: User who paid for the expense
- `created_by_user_id`: User who created the expense record
- `version`: For optimistic locking
- `overall_settlement_status`: Overall payment status
#### Status Types
```typescript
enum ExpenseOverallStatusEnum {
  UNPAID = "unpaid",
  PARTIALLY_PAID = "partially_paid",
  PAID = "paid",
}
```
### 2. Expense Splits
Splits represent how an expense is divided among users.
#### Key Properties
- `id`: Unique identifier
- `expense_id`: Reference to parent expense
- `user_id`: User who owes this portion
- `owed_amount`: Amount owed by the user
- `share_percentage`: Percentage share (for percentage-based splits)
- `share_units`: Number of shares (for share-based splits)
- `status`: Current payment status
- `paid_at`: When the split was paid
- `settlement_activities`: List of payment activities
#### Status Types
```typescript
enum ExpenseSplitStatusEnum {
  UNPAID = "unpaid",
  PARTIALLY_PAID = "partially_paid",
  PAID = "paid",
}
```
### 3. Settlement Activities
Settlement activities track individual payments made towards expense splits.
#### Key Properties
- `id`: Unique identifier
- `expense_split_id`: Reference to the split being paid
- `paid_by_user_id`: User making the payment
- `amount_paid`: Amount being paid
- `paid_at`: When the payment was made
- `created_by_user_id`: User who recorded the payment
## Split Types
The system supports multiple ways to split expenses:
### 1. Equal Split
- Divides the total amount equally among all participants
- Handles rounding differences by adding remainder to first split
- No additional data required
### 2. Exact Amounts
- Users specify exact amounts for each person
- Sum of amounts must equal total expense
- Requires `splits_in` data with exact amounts
### 3. Percentage Based
- Users specify percentage shares
- Percentages must sum to 100%
- Requires `splits_in` data with percentages
### 4. Share Based
- Users specify number of shares
- Amount divided proportionally to shares
- Requires `splits_in` data with share units
### 5. Item Based
- Splits based on items in a shopping list
- Each item's cost is assigned to its adder
- Requires `list_id` and optionally `item_id`
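The equal-split rule above (divide evenly, add the rounding remainder to the first split) can be sketched as follows. This is an illustrative helper, not code from the mitlist backend; it works in integer cents to avoid floating-point drift:

```typescript
// Hypothetical helper: split a total (in cents) equally among N users,
// assigning the leftover cents to the first split, as described above.
function equalSplit(totalCents: number, participants: number): number[] {
  if (participants <= 0) throw new Error("participants must be positive");
  const base = Math.floor(totalCents / participants);
  const remainder = totalCents - base * participants;
  return Array.from({ length: participants }, (_, i) =>
    i === 0 ? base + remainder : base,
  );
}

// Example: 100.00 split three ways
console.log(equalSplit(10000, 3)); // [3334, 3333, 3333]
```

The same shape extends to share-based splits by weighting `base` per user and still dumping the remainder on one split so the totals always reconcile.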
## Integration Points
### 1. Lists
- Expenses can be associated with shopping lists
- Item-based splits use list items to determine splits
- List's group context can determine split participants
### 2. Groups
- Expenses can be directly associated with groups
- Group membership determines who can be included in splits
- Group context is required if no list is specified
### 3. Items
- Expenses can be linked to specific items
- Item prices are used for item-based splits
- Items must belong to a list
### 4. Users
- Users can be payers, debtors, or payment recorders
- User relationships are tracked in splits and settlements
- User context is required for all financial operations
## Key Operations
### 1. Creating Expenses
1. Validate context (list/group)
2. Create expense record
3. Generate splits based on split type
4. Validate total amounts match
5. Save all records in transaction
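Step 4 ("validate total amounts match") can be made concrete with a small sketch. The function name and shape are hypothetical; the point is that the comparison happens in integer cents, so no floating-point tolerance is needed:

```typescript
// Illustrative validation for exact-amount splits: the per-user owed
// amounts must sum to the expense total before anything is saved.
function validateExactSplits(totalCents: number, owedCents: number[]): void {
  const sum = owedCents.reduce((a, b) => a + b, 0);
  if (sum !== totalCents) {
    throw new Error(`split total ${sum} does not match expense total ${totalCents}`);
  }
}
```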
### 2. Updating Expenses
- Limited to non-financial fields:
- Description
- Currency
- Expense date
- Uses optimistic locking via version field
- Cannot modify splits after creation
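The optimistic-locking flow via the `version` field can be sketched like this (names and shapes are illustrative, not the actual mitlist API): the client submits the version it last read, and a mismatch signals a concurrent modification.

```typescript
interface Expense { id: number; description: string; version: number; }

// Hypothetical update step: reject stale writes, bump version on success.
function applyUpdate(
  current: Expense,
  patch: { description: string; version: number },
): Expense {
  if (patch.version !== current.version) {
    throw new Error("version conflict: expense was modified concurrently");
  }
  return { ...current, description: patch.description, version: current.version + 1 };
}
```

In SQL terms this typically becomes `UPDATE ... WHERE id = :id AND version = :expected`, treating zero affected rows as a conflict.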
### 3. Recording Payments
1. Create settlement activity
2. Update split status
3. Recalculate expense overall status
4. All operations in single transaction
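Step 3 (recalculating the expense's overall status from its splits) is a simple roll-up. A minimal sketch, assuming status values matching the enums defined earlier:

```typescript
type SplitStatus = "unpaid" | "partially_paid" | "paid";

// Illustrative roll-up: all paid -> paid, all unpaid -> unpaid,
// anything in between -> partially_paid.
function overallStatus(splits: SplitStatus[]): SplitStatus {
  if (splits.length === 0) return "unpaid";
  if (splits.every((s) => s === "paid")) return "paid";
  if (splits.every((s) => s === "unpaid")) return "unpaid";
  return "partially_paid";
}
```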
### 4. Deleting Expenses
- Requires version matching
- Cascades to splits and settlements
- All operations in single transaction
## Best Practices
1. **Data Integrity**
- Always use transactions for multi-step operations
- Validate totals match before saving
- Use optimistic locking for updates
2. **Error Handling**
- Handle database errors appropriately
- Validate user permissions
- Check for concurrent modifications
3. **Performance**
- Use appropriate indexes
- Load relationships efficiently
- Batch operations when possible
4. **Security**
- Validate user permissions
- Sanitize input data
- Use proper access controls
## Common Use Cases
1. **Group Dinner**
- Create expense with total amount
- Use equal split or exact amounts
- Record payments as they occur
2. **Shopping List**
- Create item-based expense
- System automatically splits based on items
- Track payments per person
3. **Rent Sharing**
- Create expense with total rent
- Use percentage or share-based split
- Record monthly payments
4. **Trip Expenses**
- Create multiple expenses
- Mix different split types
- Track overall balances
## Recurring Expenses
Recurring expenses are expenses that repeat at regular intervals. They are useful for regular payments like rent, utilities, or subscription services.
### Recurrence Types
1. **Daily**
- Repeats every X days
- Example: Daily parking fee
2. **Weekly**
- Repeats every X weeks on specific days
- Example: Weekly cleaning service
3. **Monthly**
- Repeats every X months on the same date
- Example: Monthly rent payment
4. **Yearly**
- Repeats every X years on the same date
- Example: Annual insurance premium
### Implementation Details
1. **Recurrence Pattern**
```typescript
interface RecurrencePattern {
  type: "daily" | "weekly" | "monthly" | "yearly";
  interval: number; // Every X days/weeks/months/years
  daysOfWeek?: number[]; // For weekly recurrence (0-6, Sunday-Saturday)
  endDate?: string; // Optional end date for the recurrence
  maxOccurrences?: number; // Optional maximum number of occurrences
}
```
2. **Recurring Expense Properties**
- All standard expense properties
- `recurrence_pattern`: Defines how the expense repeats
- `next_occurrence`: When the next expense will be created
- `last_occurrence`: When the last expense was created
- `is_recurring`: Boolean flag to identify recurring expenses
3. **Generation Process**
- System automatically creates new expenses based on the pattern
- Each generated expense is a regular expense with its own splits
- Original recurring expense serves as a template
- Generated expenses can be modified individually
4. **Management Features**
- Pause/resume recurrence
- Modify future occurrences
- Skip specific occurrences
- End recurrence early
- View all generated expenses
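Computing `next_occurrence` from the pattern can be sketched as below. This is an assumption about how the generation step might work, using the `RecurrencePattern` shape above trimmed to the fields used; UTC date methods sidestep time-zone drift, and month-length clamping is left to the native `Date` behavior:

```typescript
interface RecurrenceStep {
  type: "daily" | "weekly" | "monthly" | "yearly";
  interval: number; // every X days/weeks/months/years
}

// Hypothetical helper: advance a due date by one recurrence step.
function nextOccurrence(last: Date, p: RecurrenceStep): Date {
  const next = new Date(last.getTime());
  switch (p.type) {
    case "daily":
      next.setUTCDate(next.getUTCDate() + p.interval);
      break;
    case "weekly":
      next.setUTCDate(next.getUTCDate() + 7 * p.interval);
      break;
    case "monthly":
      next.setUTCMonth(next.getUTCMonth() + p.interval);
      break;
    case "yearly":
      next.setUTCFullYear(next.getUTCFullYear() + p.interval);
      break;
  }
  return next;
}
```

Note that monthly recurrence from e.g. Jan 31 rolls over rather than clamping to the last day of February; a production scheduler would need an explicit rule for that case.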
### Best Practices for Recurring Expenses
1. **Data Management**
- Keep original recurring expense as template
- Generate new expenses in advance
- Clean up old generated expenses periodically
2. **User Experience**
- Clear indication of recurring expenses
- Easy way to modify future occurrences
- Option to handle exceptions
3. **Performance**
- Batch process expense generation
- Index recurring expense queries
- Cache frequently accessed patterns
### Example Use Cases
1. **Monthly Rent**
```json
{
  "description": "Monthly Rent",
  "total_amount": "2000.00",
  "split_type": "PERCENTAGE",
  "recurrence_pattern": {
    "type": "monthly",
    "interval": 1,
    "endDate": "2024-12-31"
  }
}
```
2. **Weekly Cleaning Service**
```json
{
  "description": "Weekly Cleaning",
  "total_amount": "150.00",
  "split_type": "EQUAL",
  "recurrence_pattern": {
    "type": "weekly",
    "interval": 1,
    "daysOfWeek": [1] // Every Monday
  }
}
```
## API Considerations
1. **Decimal Handling**
- Use string representation for decimals in API
- Convert to Decimal for calculations
- Round to 2 decimal places for money
2. **Date Handling**
- Use ISO format for dates
- Store in UTC
- Convert to local time for display
3. **Status Updates**
- Update split status on payment
- Recalculate overall status
- Notify relevant users
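The decimal convention in point 1 (strings over the wire, exact arithmetic internally, two decimal places for money) might look like this. These helper names are illustrative, and integer cents stand in here for the backend's `Decimal` type:

```typescript
// Hypothetical API boundary helpers: money travels as strings,
// arithmetic happens in integer cents.
function toCents(amount: string): number {
  const cents = Math.round(parseFloat(amount) * 100);
  if (!Number.isFinite(cents)) throw new Error(`invalid amount: ${amount}`);
  return cents;
}

function toAmountString(cents: number): string {
  return (cents / 100).toFixed(2); // round to 2 decimal places for money
}

console.log(toAmountString(toCents("19.99") + toCents("0.01"))); // "20.00"
```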
## Future Considerations
1. **Potential Enhancements**
- Recurring expenses
- Bulk operations
- Advanced reporting
- Currency conversion
2. **Scalability**
- Handle large groups
- Optimize for frequent updates
- Consider caching strategies
3. **Integration**
- Payment providers
- Accounting systems
- Export capabilities

docs/mitlist_doc.md Normal file

@@ -0,0 +1,177 @@
## Project Documentation: Shared Household Management PWA
**Version:** 1.1 (Tech Stack Update)
**Date:** 2025-04-22
### 1. Project Overview
**1.1. Concept:**
Develop a Progressive Web App (PWA) designed to streamline household coordination and shared responsibilities. The application enables users within defined groups (e.g., households, roommates, families) to collaboratively manage shopping lists, track and split expenses with historical accuracy, and manage recurring or one-off household chores.
**1.2. Goals:**
- Simplify the creation, management, and sharing of shopping lists.
- Provide an efficient way to add items via image capture and OCR (using Gemini 1.5 Flash).
- Enable transparent and traceable tracking and splitting of shared expenses related to shopping lists.
- Offer a clear system for managing and assigning recurring or single-instance household chores.
- Deliver a seamless, near-native user experience across devices through PWA technologies, including robust offline capabilities.
- Foster better communication and coordination within shared living environments.
**1.3. Target Audience:**
- Roommates sharing household expenses and chores.
- Families coordinating grocery shopping and household tasks.
- Couples managing shared finances and responsibilities.
- Groups organizing events or trips involving shared purchases.
### 2. Key Features (V1 Scope)
The Minimum Viable Product (V1) focuses on delivering the core functionalities with a high degree of polish and reliability:
- **User Authentication & Group Management (using `fastapi-users`):**
- Secure email/password signup, login, password reset, email verification (leveraging `fastapi-users` features).
- Ability to create user groups (e.g., "Home", "Trip").
- Invite members to groups via unique, shareable codes/links.
- Basic role distinction (Owner, Member) for group administration.
- Ability for users to view groups and leave groups.
- **Shared Shopping List Management:**
- CRUD operations for shopping lists (Create, Read, Update, Delete).
- Option to create personal lists or share lists with specific groups.
- Real-time (or near real-time via polling/basic WebSocket) updates for shared lists.
- CRUD operations for items within lists (name, quantity, notes).
- Ability to mark items as purchased.
- Attribution for who added/completed items in shared lists.
- **OCR Integration (Gemini 1.5 Flash):**
- Capture images (receipts, handwritten lists) via browser (`input capture` / `getUserMedia`).
- Backend processing using Google AI API (Gemini 1.5 Flash model) with tailored prompts to extract item names.
- User review and edit screen for confirming/correcting extracted items before adding them to the list.
- Clear progress indicators and error handling.
- **Cost Splitting (Traceable):**
- Ability to add prices to completed items on a list, recording who added the price and when.
- Functionality to trigger an expense calculation for a list based on items with prices.
- Creation of immutable `ExpenseRecord` entries detailing the total amount, participants, and calculation time/user.
- Generation of `ExpenseShare` entries detailing the amount owed per participant for each `ExpenseRecord`.
- Ability for participants to mark their specific `ExpenseShare` as paid, logged via a `SettlementActivity` record for full traceability.
- View displaying historical expense records and their settlement status for each list.
- V1 focuses on equal splitting among all group members associated with the list at the time of calculation.
- **Chore Management (Recurring & Assignable):**
- CRUD operations for chores within a group context.
- Ability to define chores as one-time or recurring (daily, weekly, monthly, custom intervals).
- System calculates `next_due_date` based on frequency.
- Manual assignment of chores (specific instances/due dates) to group members via `ChoreAssignments`.
- Ability for assigned users to mark their specific `ChoreAssignment` as complete.
- Automatic update of the parent chore's `last_completed_at` and recalculation of `next_due_date` upon completion of recurring chores.
- Dedicated view for users to see their pending assigned chores ("My Chores").
- **PWA Core Functionality:**
- Installable on user devices via `manifest.json`.
- Offline access to cached data (lists, items, chores, basic expense info) via Service Workers and IndexedDB.
- Background synchronization queue for actions performed offline (adding items, marking complete, adding prices, completing chores).
- Basic conflict resolution strategy (e.g., last-write-wins with user notification) for offline data sync.
### 3. User Experience (UX) Philosophy
- **User-Centered & Collaborative:** Focus on intuitive workflows for both individual task management and seamless group collaboration. Minimize friction in common tasks like adding items, splitting costs, and completing chores.
- **Native-like PWA Experience:** Leverage Service Workers, caching (IndexedDB), and `manifest.json` to provide fast loading, reliable offline functionality, and installability, mimicking a native app experience.
- **Clarity & Accessibility:** Prioritize clear information hierarchy, legible typography, sufficient contrast, and adherence to WCAG accessibility standards for usability by all users. Utilize **Valerie UI** components designed with accessibility in mind.
- **Informative Feedback:** Provide immediate visual feedback for user actions (loading states, confirmations, animations). Clearly communicate offline status, sync progress, OCR processing status, and data conflicts.
### 4. Architecture & Technology Stack
- **Frontend:**
- **Framework:** Vue.js (Vue 3 with Composition API, built with Vite).
- **Styling & UI Components:** **Valerie UI** (as the primary component library and design system).
- **State Management:** Pinia (official state management library for Vue).
- **PWA:** Vite PWA plugin (leveraging Workbox.js under the hood) for Service Worker generation, manifest management, and caching strategies. IndexedDB for offline data storage.
- **Backend:**
- **Framework:** FastAPI (Python, high-performance, async support, automatic docs).
- **Database:** PostgreSQL (reliable relational database with JSONB support).
- **ORM:** SQLAlchemy (version 2.0+ with native async support).
- **Migrations:** Alembic (for managing database schema changes).
- **Authentication & User Management:** **`fastapi-users`** (handles user models, password hashing, JWT/cookie authentication, and core auth endpoints like signup, login, password reset, email verification).
- **Cloud Services & APIs:**
- **OCR:** Google AI API (using `gemini-1.5-flash-latest` model).
- **Hosting (Backend):** Containerized deployment (Docker) on cloud platforms like Google Cloud Run, AWS Fargate, or DigitalOcean App Platform.
- **Hosting (Frontend):** Static hosting platforms like Vercel, Netlify, or Cloudflare Pages (optimized for Vite-built Vue apps).
- **DevOps & Monitoring:**
- **Version Control:** Git (hosted on GitHub, GitLab, etc.).
- **Containerization:** Docker & Docker Compose (for local development and deployment consistency).
- **CI/CD:** GitHub Actions (or similar) for automated testing and deployment pipelines (using Vite build commands for frontend).
- **Error Tracking:** Sentry (or similar) for real-time error monitoring.
- **Logging:** Standard Python logging configured within FastAPI.
### 5. Data Model Highlights
Key database tables supporting the application's features:
- `Users`: Stores user account information. The schema will align with `fastapi-users` requirements (e.g., `id`, `email`, `hashed_password`, `is_active`, `is_superuser`, `is_verified`), with potential custom fields added as needed.
- `Groups`: Defines shared groups (name, owner).
- `UserGroups`: Many-to-many relationship linking users to groups with roles (owner/member).
- `Lists`: Stores shopping list details (name, description, creator, associated group, completion status).
- `Items`: Stores individual shopping list items (name, quantity, price, completion status, list association, user attribution for adding/pricing).
- `ExpenseRecords`: Logs each instance of a cost split calculation for a list (total amount, participants, calculation time/user, overall settlement status).
- `ExpenseShares`: Details the amount owed by each participant for a specific `ExpenseRecord` (links to user and record, amount, paid status).
- `SettlementActivities`: Records every action taken to mark an `ExpenseShare` as paid (links to record, payer, affected user, timestamp).
- `Chores`: Defines chore templates (name, description, group association, recurrence rules, next due date).
- `ChoreAssignments`: Tracks specific instances of chores assigned to users (links to chore, user, due date, completion status).
### 6. Core User Flows (Summarized)
- **Onboarding:** Signup/Login (via `fastapi-users` flow) -> Optional guided tour -> Create/Join first group -> Dashboard.
- **List Creation & Sharing:** Create List -> Choose Personal or Share with Group -> List appears on dashboard (and shared members' dashboards).
- **Adding Items (Manual):** Open List -> Type item name -> Item added.
- **Adding Items (OCR):** Open List -> Tap "Add via Photo" -> Capture/Select Image -> Upload/Process (Gemini) -> Review/Edit extracted items -> Confirm -> Items added to list.
- **Shopping & Price Entry:** Open List -> Check off items -> Enter price for completed items -> Price saved.
- **Cost Splitting Cycle:** View List -> Click "Calculate Split" -> Backend creates traceable `ExpenseRecord` & `ExpenseShares` -> View Expense History -> Participants mark their shares paid (creating `SettlementActivity`).
- **Chore Cycle:** Create Chore (define recurrence) -> Chore appears in group list -> (Manual Assignment) Assign chore instance to user -> User views "My Chores" -> User marks assignment complete -> Backend updates status and recalculates next due date for recurring chores.
- **Offline Usage:** Open app offline -> View cached lists/chores -> Add/complete items/chores -> Changes queued -> Go online -> Background sync processes queue -> UI updates, conflicts notified.
### 7. Development Roadmap (Phase Summary)
1. **Phase 1: Planning & Design:** User stories, flows, sharing/sync models, tech stack, architecture, schema design.
2. **Phase 2: Core App Setup:** Project initialization (Git, **Vue.js with Vite**, FastAPI), DB connection (SQLAlchemy/Alembic), basic PWA config (**Vite PWA plugin**, manifest, SW), **Valerie UI integration**, **Pinia setup**, Docker setup, CI checks.
3. **Phase 3: User Auth & Group Management:** Backend: Integrate **`fastapi-users`**, configure its routers, adapt user model. Frontend: Implement auth pages using **Vue components**, **Pinia for auth state**, and calling `fastapi-users` endpoints. Implement Group Management features.
4. **Phase 4: Shared Shopping List CRUD:** Backend/Frontend for List/Item CRUD, permissions, basic real-time updates (polling), offline sync refinement for lists/items.
5. **Phase 5: OCR Integration (Gemini Flash):** Backend integration with Google AI SDK, image capture/upload UI, OCR processing endpoint, review/edit screen, integration with list items.
6. **Phase 6: Cost Splitting (Traceable):** Backend/Frontend for adding prices, calculating splits (creating historical records), viewing expense history, marking shares paid (with activity logging).
7. **Phase 7: Chore Splitting Module:** Backend/Frontend for Chore CRUD (including recurrence), manual assignment, completion tracking, "My Chores" view, recurrence handling logic.
8. **Phase 8: Testing, Refinement & Beta Launch:** Comprehensive E2E testing, usability testing, accessibility checks, performance tuning, deployment to beta environment, feedback collection.
9. **Phase 9: Final Release & Post-Launch Monitoring:** Address beta feedback, final deployment to production, setup monitoring (errors, performance, costs).
_(Estimated Total Duration: Approx. 17-19 Weeks for V1)_
### 8. Risk Management & Mitigation
- **Collaboration Complexity:** (Risk) Permissions and real-time sync can be complex. (Mitigation) Start simple, test permissions thoroughly, use clear data models.
- **OCR Accuracy/Cost (Gemini):** (Risk) OCR isn't perfect; API calls have costs/quotas. (Mitigation) Use capable model (Gemini Flash), mandatory user review step, clear error feedback, monitor API usage/costs, secure API keys.
- **Offline Sync Conflicts:** (Risk) Concurrent offline edits can clash. (Mitigation) Implement the defined strategy (last-write-wins plus user notification), robust queue processing, thorough testing of conflict scenarios.
- **PWA Consistency:** (Risk) Behavior varies across browsers/OS (esp. iOS). (Mitigation) Rigorous cross-platform testing, use standard tools (Vite PWA plugin/Workbox), follow best practices.
- **Traceability Overhead:** (Risk) Storing detailed history increases DB size/complexity. (Mitigation) Design efficient queries, use appropriate indexing, plan for potential data archiving later.
- **User Adoption:** (Risk) Users might not consistently use groups/features. (Mitigation) Smooth onboarding, clear value proposition, reliable core features.
- **Valerie UI Maturity/Flexibility:** (Risk, if "Valerie UI" is niche or custom) Potential limitations in component availability or customization. (Mitigation) Thoroughly evaluate Valerie UI early, have fallback styling strategies if needed, or contribute to/extend the library.
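The last-write-wins-plus-notify mitigation above fits in a few lines; a hedged sketch, with hypothetical model and function names (the real items would carry server-assigned timestamps):

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class ItemVersion:
    """One version of a list item, as seen by the server or a queued offline edit."""
    item_id: int
    name: str
    updated_at: datetime


def resolve_last_write_wins(
    server: ItemVersion, queued: ItemVersion
) -> tuple[ItemVersion, bool]:
    """Apply the queued offline edit only if it is newer than the server row.
    The boolean tells the client to notify the user that their edit was dropped."""
    if queued.updated_at > server.updated_at:
        return queued, False
    return server, True
```

The notify flag is what makes the strategy tolerable in practice: silently discarding an offline edit is the failure mode users complain about, not the conflict itself.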
### 9. Testing Strategy
- **Unit Tests:** Backend logic (calculations, permissions, recurrence), Frontend component logic (**Vue Test Utils** for Vue components, Pinia store testing).
- **Integration Tests:** Backend API endpoints interacting with DB and external APIs (Gemini - mocked).
- **End-to-End (E2E) Tests:** (Playwright/Cypress) Simulate full user flows across features.
- **PWA Testing:** Manual and automated checks for installability, offline functionality (caching, sync queue), cross-browser/OS compatibility.
- **Accessibility Testing:** Automated tools (axe-core) + manual checks (keyboard nav, screen readers), leveraging **Valerie UI's** accessibility features.
- **Usability Testing:** Regular sessions with target users throughout development.
- **Security Testing:** Basic checks (OWASP Top 10 awareness), dependency scanning, secure handling of secrets/tokens (rely on `fastapi-users` security practices).
- **Manual Testing:** Exploratory testing, edge case validation, testing diverse OCR inputs.
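For the integration-test layer, the Gemini dependency can be mocked by injecting the OCR client into the parsing logic; `extract_items` and `client.generate` are illustrative names for this sketch, not the project's actual API:

```python
from unittest import mock


def extract_items(image_bytes: bytes, client) -> list[str]:
    """Turn a receipt image into candidate item names via an injected OCR client.
    Injection lets integration tests swap the real Gemini SDK for a mock."""
    raw = client.generate(image_bytes)
    return [line.strip() for line in raw.splitlines() if line.strip()]


# The external API is replaced by a mock returning a canned response, so the
# test exercises parsing and call wiring without network access or API cost.
fake_client = mock.Mock()
fake_client.generate.return_value = "Milk\n  Eggs \n\nBread"

items = extract_items(b"fake-image-bytes", fake_client)
assert items == ["Milk", "Eggs", "Bread"]
fake_client.generate.assert_called_once_with(b"fake-image-bytes")
```

The same injected-client shape also keeps API usage (and cost) out of CI, which matters given the Gemini quota risk noted earlier.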
### 10. Future Enhancements (Post-V1)
- Advanced Cost Splitting (by item, percentage, unequal splits).
- Payment Integration (Stripe Connect for settling debts).
- Real-time Collaboration (WebSockets for instant updates).
- Push Notifications (reminders for chores, expenses, list updates).
- Advanced Chore Features (assignment algorithms, calendar view).
- Enhanced OCR (handling more formats, potential fine-tuning).
- User Profile Customization (avatars, etc., extending `fastapi-users` model).
- Analytics Dashboard (spending insights, chore completion stats).
- Recipe Integration / Pantry Inventory Tracking.
### 11. Conclusion
This project aims to deliver a modern, user-friendly PWA that effectively addresses common household coordination challenges. By combining collaborative list management, intelligent OCR, traceable expense splitting, and flexible chore tracking with a robust offline-first PWA architecture built on **Vue.js, Pinia, Valerie UI, and FastAPI with `fastapi-users`**, the application will provide significant value to roommates, families, and other shared living groups. The focus on a well-defined V1, traceable data, and a solid technical foundation sets the stage for future growth and feature expansion.

**env.production.template** (new file, 46 lines)
# Production Environment Variables Template
# Copy this file to .env.production and fill in the actual values
# NEVER commit the actual .env.production file to version control

# Database Configuration
POSTGRES_USER=mitlist_user
POSTGRES_PASSWORD=your_secure_database_password_here
POSTGRES_DB=mitlist_prod
DATABASE_URL=postgresql+asyncpg://mitlist_user:your_secure_database_password_here@db:5432/mitlist_prod

# Security Keys (Generate with: openssl rand -hex 32)
SECRET_KEY=your_secret_key_here_minimum_32_characters_long
SESSION_SECRET_KEY=your_session_secret_key_here_minimum_32_characters_long

# API Keys
GEMINI_API_KEY=your_gemini_api_key_here

# Redis Configuration
REDIS_PASSWORD=your_redis_password_here

# Sentry Configuration (Optional but recommended)
SENTRY_DSN=your_sentry_dsn_here

# CORS Configuration
CORS_ORIGINS=https://yourdomain.com,https://www.yourdomain.com
FRONTEND_URL=https://yourdomain.com

# Frontend Build Variables
VITE_API_URL=https://yourdomain.com/api
VITE_SENTRY_DSN=your_frontend_sentry_dsn_here
VITE_ROUTER_MODE=history

# Google OAuth Configuration - Replace with your actual credentials
GOOGLE_CLIENT_ID="YOUR_GOOGLE_CLIENT_ID_HERE"
GOOGLE_CLIENT_SECRET="YOUR_GOOGLE_CLIENT_SECRET_HERE"
GOOGLE_REDIRECT_URI=https://yourdomain.com/auth/google/callback

APPLE_CLIENT_ID=your_apple_client_id
APPLE_TEAM_ID=your_apple_team_id
APPLE_KEY_ID=your_apple_key_id
APPLE_PRIVATE_KEY=your_apple_private_key
APPLE_REDIRECT_URI=https://yourdomain.com/auth/apple/callback

# Production Settings
ENVIRONMENT=production
LOG_LEVEL=INFO
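As the template's comment notes, the two secret keys are generated with `openssl rand -hex 32`; if openssl is unavailable, Python's standard `secrets` module produces an equivalent 64-character hex value:

```python
import secrets

# Equivalent to `openssl rand -hex 32`: 32 cryptographically random bytes,
# rendered as 64 lowercase hex characters.
secret_key = secrets.token_hex(32)
session_secret_key = secrets.token_hex(32)

print(f"SECRET_KEY={secret_key}")
print(f"SESSION_SECRET_KEY={session_secret_key}")
```

Generate each key independently; reusing one value for both `SECRET_KEY` and `SESSION_SECRET_KEY` weakens the isolation between token signing and session handling.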

**fe/.editorconfig** (new file, 9 lines)
[*.{js,jsx,mjs,cjs,ts,tsx,mts,cts,vue,css,scss,sass,less,styl}]
charset = utf-8
indent_size = 2
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true
end_of_line = lf
max_line_length = 100

**fe/.env** (new file, empty)

Some files were not shown because too many files have changed in this diff.