Compare commits
188 Commits
188 commits (SHA1 only; the author and date columns are not populated in this view):

f8788ee42d, b0ec84b8ca, 198222c3ff, 7ef225daec, 6e56e164df, 550fac1c0c, 944976b1cc, 92c919785a, a1acee6e59, 331eaf7c35,
b9b2bfb469, 5f05cd9377, ddaa20af3c, cef359238b, 5fffd4d2f5, 397cf28673, d6c7fde40c, 77178cc67e, 0aa88d0af7, fc09848a33,
b9aace0c4e, d8db5721f4, 6e79fbfa04, 5c882996a9, 6306e70df7, dbfbe7922e, 57b913d135, 588abb1217, d150dd28c9, 6b54566cef,
d623c4b27c, fc49e830fc, af6324ddef, 6924a016c8, 0fcc94ae8d, c0aa654e83, ec361fe9ab, 9d404d04d5, 92c70813fb, 2d16116716,
3e328c2902, effaef7d08, 12e2890a4a, f98bdb6b11, 5d50606fc2, 30af7ab692, 4effbf5c03, 5c9ba3f38c, 8034824c97, 82205f6158,
2a2045c24a, c1ebd16e5a, 554814ad63, f2609f53ec, 4fef642970, dda39532d6, 6d5e950918, e6c15210c1, b07ab09f88, 5cb13862ef,
843b3411e4, 7da93d1fe9, 02238974aa, ca1ac94b57, e52ab871bc, c6c204f64a, a059768d8a, 09c3160fbb, 287155a783, c50395ae86,
4540ad359e, 3738819065, c14b432082, c204c25314, 02ab812ef0, 20daadc112, 5dcabd51f7, 8f1da5d440, 0f9d83a233, cb5bfcf7b5,
e16c749019, 7223606fdc, f4eeb00acf, 43e2d88ffe, 32841ea727, 26e06ddeaa, f2df1c50dd, 411c3c91b2, 5a2b311a4f, 9b09b461bd,
9f8de46d06, b1a74edb6a, 161292ff3b, 55d08d36e0, 59f2f47949, 1e9957de91, 6ed7e32922, cc1f910e4c, cd98b7b854, 392a2ae049,
a51b18e8f5, 99d6c5ffaa, dd29f27a5b, d05200b623, ed76816a32, 8c5753ea77, 12f35b539a, e104d26583, 8ff31ecf91, 1c87170955,
74c73a9e8f, 679169e4fb, a7fbc454a9, 813ed911f1, 272e5abe41, fc16f169b1, 3811dc7ee5, 136c4df7ac, 821a26e681, ee6d96d9ec,
8c52bbb307, ce67570cfb, cb51186830, 84b046508a, a0d67f6c66, 81577ac7e8, b0100a2e96, 5018ce02f7, 52fc33b472, e7b072c2bd,
f1152c5745, 8bb960b605, 0bf7a7cb49, 653788cfba, c0dcccd970, 0204fb6f3a, 29ccab2f7e, ed222c840a, 04b0ad7059, 16c9abb16a,
185e89351e, 17bebbfab8, fc355077ab, eb19230b22, c8cdbd571e, d6d19397d3, 323ce210ce, 98b2f907de, e4175db4aa, 2b7816cf33,
5abe7839f1, c2aa62fa03, f2ac73502c, 9ff293b850, 7a88ea258a, 515534dcce, 3f0cfff9f1, 72b988b79b, 1c08e57afd, 29682b7e9c,
18f759aa7c, 9583aa4bab, cacfb2a5e8, 227a3d6186, 9230d1f626, 5a910a29e2, db5f2d089e, 7bbec7ad5f, f6a50e0d6a, 4283fe8a19,
0dbee3bb4b, d99aef9d11, 8b6ddb91f8, e484c9e9a8, f52b47f6df, 262505c898, 7836672f64, fe252cfac8, 4f32670bda, ff25af26f5,
6198a29768, c7fdb60130, 5186892df6, 7b2c5c9ebd, e3024ccd07, bbb3c3b7df, 423d345fdf, d2d484c327
.cursor/rules/fastapi-db-strategy.mdc (new file, 57 lines):

---
description: FastAPI Database Transactions
globs:
alwaysApply: false
---

## FastAPI Database Transaction Management: Technical Specification

**Objective:** Ensure atomic, consistent, isolated, and durable (ACID) database operations through a standardized transaction management strategy.

**1. API Endpoint Transaction Scope (Primary Strategy):**

* **Mechanism:** A FastAPI dependency `get_transactional_session` (from `app.database` or `app.core.dependencies`) wraps database-modifying API request handlers (sketched after this list).
* **Behavior:**
  * `async with AsyncSessionLocal() as session:` obtains a session.
  * `async with session.begin():` starts a transaction.
  * **Commit:** automatic on successful completion of the `yield session` block (i.e., endpoint handler success).
  * **Rollback:** automatic on any exception raised from the `yield session` block.
* **Usage:** Endpoints performing CUD (Create, Update, Delete) operations **MUST** use `db: AsyncSession = Depends(get_transactional_session)`.
* **Read-Only Endpoints:** May use `get_async_session` (alias `get_db`) or `get_transactional_session` (which results in an empty transaction).
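A minimal sketch of the dependency described above, assuming `AsyncSessionLocal` is the project's async sessionmaker in `app.database`; the real implementation may differ:

```python
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession

from app.database import AsyncSessionLocal  # assumed location of the sessionmaker


async def get_transactional_session() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSessionLocal() as session:  # obtain a session
        async with session.begin():             # open a transaction
            yield session
            # session.begin() commits on clean exit of this block and
            # rolls back automatically if the endpoint handler raised.
```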
**2. CRUD Layer Function Design:**

* **Transaction Participation:** CRUD functions (in `app/crud/`) operate on the session provided by the caller.
* **Composability Pattern:** Employ `async with db.begin_nested() if db.in_transaction() else db.begin():` to wrap database modification logic within the CRUD function (see the sketch after this list).
  * If an outer transaction exists (e.g., from `get_transactional_session`), `begin_nested()` creates a **savepoint**. The `async with` block commits/rolls back this savepoint.
  * If no outer transaction exists (e.g., a direct call from a script), `begin()` starts a **new transaction**. The `async with` block commits/rolls back this transaction.
* **NO Direct `db.commit()` / `db.rollback()`:** CRUD functions **MUST NOT** call these directly. The `async with begin_nested()/begin()` block and the outermost transaction manager are responsible.
* **`await db.flush()`:** Use only when necessary within the `async with` block to:
  1. Obtain auto-generated IDs for subsequent operations in the *same* transaction.
  2. Force database constraint checks mid-transaction.
* **Error Handling:** Raise specific custom exceptions (e.g., `ListNotFoundError`, `DatabaseIntegrityError`). These exceptions trigger rollbacks in the managing transaction contexts.
**3. Non-API Operations (Background Tasks, Scripts):**

* **Explicit Management:** These contexts **MUST** manage their own session and transaction lifecycles.
* **Pattern:**

```python
async with AsyncSessionLocal() as session:
    async with session.begin():  # Manages transaction for the task's scope
        try:
            # Call CRUD functions, which will participate via savepoints
            await crud_operation_1(db=session, ...)
            await crud_operation_2(db=session, ...)
            # Commit is handled by session.begin() context manager on success
        except Exception:
            # Rollback is handled by session.begin() context manager on error
            raise
```

**4. Key Principles Summary:**

* **API:** `get_transactional_session` for CUD.
* **CRUD:** Use `async with db.begin_nested() if db.in_transaction() else db.begin():`. No direct commit/rollback. Use `flush()` strategically.
* **Background Tasks:** Explicit `AsyncSessionLocal()` and `session.begin()` context managers.

This strategy ensures a clear separation of concerns, promotes composable CRUD operations, and centralizes final transaction control at the appropriate layer.
.gitea/workflows/build-test.yml (new file, 71 lines):

```yaml
name: Build and Test

on:
  push:
    branches:
      - main
      - develop
  pull_request:
    branches:
      - main
      - develop

jobs:
  build-and-test:
    runs-on: ubuntu-latest
    services:
      postgres:
        image: postgres:17-alpine
        env:
          POSTGRES_USER: testuser
          POSTGRES_PASSWORD: testpassword
          POSTGRES_DB: testdb
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.11'

      - name: Set up Node.js
        uses: actions/setup-node@v3
        with:
          node-version: '24'

      - name: Install backend dependencies
        working-directory: ./be
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt

      - name: Install frontend dependencies
        working-directory: ./fe
        run: npm ci

      - name: Run backend tests
        working-directory: ./be
        env:
          DATABASE_URL: postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb
          SECRET_KEY: testsecretkey
          GEMINI_API_KEY: testgeminikey # Mock or skip tests requiring this if not available
          SESSION_SECRET_KEY: testsessionsecret
        run: pytest

      - name: Build frontend
        working-directory: ./fe
        run: npm run build

      # Add frontend test command if you have one e.g. npm test
      # - name: Run frontend tests
      #   working-directory: ./fe
      #   run: npm test
```
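The `DATABASE_URL` exported by the test step is the same setting the application reads, so the backend suite presumably builds its engine from it. A sketch of how a test fixture could consume it; the `conftest.py` below is an assumption for illustration (it requires pytest-asyncio), not a file from this diff:

```python
# conftest.py (hypothetical): wires pytest to the CI Postgres service
import os

import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine


@pytest.fixture()
async def db_session():
    # Falls back to the service credentials defined in build-test.yml
    url = os.environ.get(
        "DATABASE_URL",
        "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
    )
    engine = create_async_engine(url)
    session_factory = async_sessionmaker(engine, class_=AsyncSession)
    async with session_factory() as session:
        yield session
    await engine.dispose()
```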
.gitea/workflows/deploy-prod.yml (new file, 214 lines). The push steps' original comments described the retry wait as "compression" and "exponential backoff"; the wait actually grows linearly (`base_wait * retry_count`), and the comments below are corrected accordingly:

```yaml
name: Deploy to Production, build images and push to Gitea Registry

on:
  pull_request:
    types: [closed]
    branches:
      - prod

jobs:
  build_and_push:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Debug context variables
        run: |
          echo "Actor: ${{ gitea.actor }}"
          echo "Repository: ${{ gitea.repository_name }}"
          echo "Repository owner: ${{ gitea.repository_owner }}"
          echo "Event repository name: ${{ gitea.event.repository.name }}"
          echo "Event repository full name: ${{ gitea.event.repository.full_name }}"
          echo "Event repository owner login: ${{ gitea.event.repository.owner.login }}"

      - name: Login to Gitea Registry
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          echo $GITEA_PASSWORD | docker login git.vinylnostalgia.com -u $GITEA_USERNAME --password-stdin

      - name: Set repository variables
        id: vars
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"

          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi

          echo "actor=$ACTOR" >> $GITHUB_OUTPUT
          echo "repo_name=$REPO_NAME" >> $GITHUB_OUTPUT
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"

      - name: Build backend image with optimizations
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"

          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi

          echo "Building backend image..."
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"

          # Build with BuildKit optimizations
          DOCKER_BUILDKIT=1 docker build \
            --compress \
            --no-cache \
            --squash \
            --build-arg BUILDKIT_INLINE_CACHE=1 \
            -t git.vinylnostalgia.com/$ACTOR/$REPO_NAME-backend:latest \
            ./be -f ./be/Dockerfile.prod

          echo "Backend image built successfully"

      - name: Push backend image with retry logic
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"

          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi

          # Push with retries and linearly increasing backoff
          max_retries=5
          retry_count=0
          base_wait=10

          while [ $retry_count -lt $max_retries ]; do
            echo "Pushing backend image (attempt $((retry_count + 1)) of $max_retries)..."

            if docker push git.vinylnostalgia.com/$ACTOR/$REPO_NAME-backend:latest; then
              echo "Backend image pushed successfully"
              break
            fi

            retry_count=$((retry_count + 1))
            if [ $retry_count -lt $max_retries ]; then
              wait_time=$((base_wait * retry_count))
              echo "Push failed, retrying in $wait_time seconds..."
              sleep $wait_time
            fi
          done

          if [ $retry_count -eq $max_retries ]; then
            echo "Failed to push backend image after $max_retries attempts"
            exit 1
          fi

      - name: Build frontend image with optimizations
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"

          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi

          echo "Building frontend image..."
          echo "Using ACTOR: $ACTOR"
          echo "Using REPO_NAME: $REPO_NAME"

          # Build with BuildKit optimizations
          DOCKER_BUILDKIT=1 docker build \
            --compress \
            --no-cache \
            --squash \
            --build-arg BUILDKIT_INLINE_CACHE=1 \
            -t git.vinylnostalgia.com/$ACTOR/$REPO_NAME-frontend:latest \
            ./fe -f ./fe/Dockerfile.prod

          echo "Frontend image built successfully"

      - name: Push frontend image with retry logic
        env:
          GITEA_USERNAME: ${{ secrets.ME_USERNAME }}
          GITEA_PASSWORD: ${{ secrets.ME_PASSWORD }}
        run: |
          REPO_NAME="${{ gitea.repository_name }}"
          ACTOR="${{ gitea.actor }}"

          # Use fallback if variables are empty
          if [ -z "$REPO_NAME" ]; then
            REPO_NAME="${{ gitea.event.repository.name }}"
          fi
          if [ -z "$ACTOR" ]; then
            ACTOR="${{ gitea.event.repository.owner.login }}"
          fi

          # Push with retries and linearly increasing backoff
          max_retries=5
          retry_count=0
          base_wait=10

          while [ $retry_count -lt $max_retries ]; do
            echo "Pushing frontend image (attempt $((retry_count + 1)) of $max_retries)..."

            if docker push git.vinylnostalgia.com/$ACTOR/$REPO_NAME-frontend:latest; then
              echo "Frontend image pushed successfully"
              break
            fi

            retry_count=$((retry_count + 1))
            if [ $retry_count -lt $max_retries ]; then
              wait_time=$((base_wait * retry_count))
              echo "Push failed, retrying in $wait_time seconds..."
              sleep $wait_time
            fi
          done

          if [ $retry_count -eq $max_retries ]; then
            echo "Failed to push frontend image after $max_retries attempts"
            exit 1
          fi

      - name: Cleanup Docker resources
        if: always()
        run: |
          echo "Cleaning up Docker resources..."
          docker system prune -af --volumes
          docker logout git.vinylnostalgia.com
          echo "Cleanup completed"

      - name: Show final image sizes
        if: always()
        run: |
          echo "Final image sizes:"
          docker images --format "table {{.Repository}}\t{{.Tag}}\t{{.Size}}" | grep -E "(vinylnostalgia|REPOSITORY)"
```
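A caveat on the build flags: `--squash` has historically required the Docker daemon's experimental mode and is not supported by BuildKit builds, so on some runner Docker versions this flag may need to be dropped. Also, since no `--cache-from` is passed and `--no-cache` is set, the `BUILDKIT_INLINE_CACHE=1` hint embeds cache metadata that these builds never actually consume.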
The next hunk carries no file name in this view; its entries identify it as the backend `.dockerignore`:

```diff
@@ -1,30 +1,55 @@
-# Git files
+# Git
 .git
 .gitignore
 
-# Virtual environment
-.venv
-venv/
-env/
-ENV/
-*.env # Ignore local .env files within the backend directory if any
-
-# Python cache
+# Python
 __pycache__/
 *.py[cod]
 *$py.class
+*.so
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
 
-# IDE files
+# Virtual Environment
+venv/
+ENV/
+
+# IDE
 .idea/
 .vscode/
+*.swp
+*.swo
 
-# Test artifacts
+# Logs
+*.log
+
+# Local development
+.env
+.env.local
+.env.*.local
+
+# Docker
+Dockerfile*
+docker-compose*
+.dockerignore
+
+# Tests
+tests/
+test/
 .pytest_cache/
+.coverage
 htmlcov/
-.coverage*
-
-# Other build/temp files
-*.egg-info/
-dist/
-build/
-*.db # e.g., sqlite temp dbs
```
be/Dockerfile (modified; the old single-stage image is replaced by a multi-stage build):

```diff
@@ -1,35 +1,75 @@
-# be/Dockerfile
-
-# Choose a suitable Python base image
-FROM python:3.11-slim
+# Multi-stage build for production - optimized for size
+FROM python:3.11-slim AS builder
 
 # Set environment variables
-ENV PYTHONDONTWRITEBYTECODE 1 # Prevent python from writing pyc files
-ENV PYTHONUNBUFFERED 1 # Keep stdout/stderr unbuffered
+ENV PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    PIP_NO_CACHE_DIR=1 \
+    PIP_DISABLE_PIP_VERSION_CHECK=1
 
-# Set the working directory in the container
-WORKDIR /app
-
-# Install system dependencies if needed (e.g., for psycopg2 build)
-# RUN apt-get update && apt-get install -y --no-install-recommends gcc build-essential libpq-dev && rm -rf /var/lib/apt/lists/*
+# Install build dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    gcc \
+    g++ \
+    libpq-dev \
+    && rm -rf /var/lib/apt/lists/*
 
 # Install Python dependencies
-# Upgrade pip first
-RUN pip install --no-cache-dir --upgrade pip
-# Copy only requirements first to leverage Docker cache
-COPY requirements.txt requirements.txt
-# Install dependencies
-RUN pip install --no-cache-dir -r requirements.txt
+COPY requirements.txt .
+RUN pip install --user --no-cache-dir -r requirements.txt
 
-# Copy the rest of the application code into the working directory
-COPY . .
-# This includes your 'app/' directory, alembic.ini, etc.
+# Production stage - minimal image
+FROM python:3.11-slim AS production
 
-# Expose the port the app runs on
+# Set environment variables
+ENV PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    PATH=/home/appuser/.local/bin:$PATH
+
+# Install only runtime dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libpq5 \
+    curl \
+    && rm -rf /var/lib/apt/lists/* \
+    && apt-get clean
+
+# Create non-root user
+RUN groupadd -g 1001 appuser && \
+    useradd -u 1001 -g appuser -m appuser
+
+# Copy Python packages from builder stage
+COPY --from=builder /root/.local /home/appuser/.local
+
+# Set working directory
+WORKDIR /app
+
+# Copy only necessary application files (be selective)
+COPY --chown=appuser:appuser app/ ./app/
+COPY --chown=appuser:appuser alembic/ ./alembic/
+COPY --chown=appuser:appuser alembic.ini ./
+COPY --chown=appuser:appuser *.py ./
+COPY --chown=appuser:appuser requirements.txt ./
+COPY --chown=appuser:appuser entrypoint.sh /app/entrypoint.sh
+RUN chmod +x /app/entrypoint.sh
+
+# Create logs directory
+RUN mkdir -p /app/logs && chown -R appuser:appuser /app
+
+# Switch to non-root user
+USER appuser
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
+    CMD curl -f http://localhost:8000/health || exit 1
+
+# Expose port
 EXPOSE 8000
 
-# Command to run the application using uvicorn
-# The default command for production (can be overridden in docker-compose for development)
-# Note: Make sure 'app.main:app' correctly points to your FastAPI app instance
-# relative to the WORKDIR (/app). If your main.py is directly in /app, this is correct.
-CMD ["uvicorn", "app.main:app", "--host", "localhost", "--port", "8000"]
+# Production command
+ENTRYPOINT ["/app/entrypoint.sh"]
+CMD ["uvicorn", "app.main:app", \
+     "--host", "0.0.0.0", \
+     "--port", "8000", \
+     "--workers", "4", \
+     "--access-log", \
+     "--log-level", "info"]
```
be/Dockerfile.prod (new file, 72 lines; the same multi-stage layout as the updated be/Dockerfile, without the entrypoint script):

```dockerfile
# Multi-stage build for production - optimized for size
FROM python:3.11-slim AS builder

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    libpq-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

# Production stage - minimal image
FROM python:3.11-slim AS production

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PATH=/home/appuser/.local/bin:$PATH

# Install only runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libpq5 \
    curl \
    && rm -rf /var/lib/apt/lists/* \
    && apt-get clean

# Create non-root user
RUN groupadd -g 1001 appuser && \
    useradd -u 1001 -g appuser -m appuser

# Copy Python packages from builder stage
COPY --from=builder /root/.local /home/appuser/.local

# Set working directory
WORKDIR /app

# Copy only necessary application files (be selective)
COPY --chown=appuser:appuser app/ ./app/
COPY --chown=appuser:appuser alembic/ ./alembic/
COPY --chown=appuser:appuser alembic.ini ./
COPY --chown=appuser:appuser *.py ./
COPY --chown=appuser:appuser requirements.txt ./

# Create logs directory
RUN mkdir -p /app/logs && chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Expose port
EXPOSE 8000

# Production command
CMD ["uvicorn", "app.main:app", \
     "--host", "0.0.0.0", \
     "--port", "8000", \
     "--workers", "4", \
     "--access-log", \
     "--log-level", "info"]
```
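The design choice here: the builder stage installs dependencies with `pip install --user`, so everything lands under `/root/.local`; the production stage copies that directory into the unprivileged `appuser`'s home and prepends it to `PATH`. The build toolchain (`gcc`, `g++`, `libpq-dev`) never reaches the final image; only the `libpq5` runtime library and `curl` (needed by the health check) are installed there.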
The following hunks are unnamed in this view; the content and the neighboring `be/alembic/` files identify them as `be/alembic/env.py`:

```diff
@@ -1,31 +1,28 @@
 from logging.config import fileConfig
 import os
 import sys
 
 from sqlalchemy import engine_from_config
 from sqlalchemy import pool
+from sqlalchemy.ext.asyncio import create_async_engine
 from alembic import context
 
 
 # Ensure the 'app' directory is in the Python path
 # Adjust the path if your project structure is different
 sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
 
 # Import your app's Base and settings
-from app.models import Base  # Import Base from your models module
+import app.models  # Ensure all models are loaded and registered to app.database.Base
+from app.database import Base as DatabaseBase  # Explicitly get Base from database.py
 from app.config import settings  # Import settings to get DATABASE_URL
 
-# this is the Alembic Config object, which provides
-# access to the values within the .ini file in use.
+# Get alembic config
 config = context.config
 
 # Set the sqlalchemy.url from your application settings
-# Use a synchronous version of the URL for Alembic's operations
-sync_db_url = settings.DATABASE_URL.replace("+asyncpg", "") if settings.DATABASE_URL else None
-if not sync_db_url:
+# Ensure DATABASE_URL is available and use it directly
+if not settings.DATABASE_URL:
     raise ValueError("DATABASE_URL not found in settings for Alembic.")
-config.set_main_option('sqlalchemy.url', sync_db_url)
+config.set_main_option('sqlalchemy.url', settings.DATABASE_URL)
 
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
@@ -36,26 +33,15 @@ if config.config_file_name is not None:
 # for 'autogenerate' support
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
-target_metadata = Base.metadata
+target_metadata = DatabaseBase.metadata  # Use metadata from app.database.Base
 
 # other values from the config, defined by the needs of env.py,
 # can be acquired:
 # my_important_option = config.get_main_option("my_important_option")
 # ... etc.
 
 
 def run_migrations_offline() -> None:
-    """Run migrations in 'offline' mode.
-
-    This configures the context with just a URL
-    and not an Engine, though an Engine is acceptable
-    here as well. By skipping the Engine creation
-    we don't even need a DBAPI to be available.
-
-    Calls to context.execute() here emit the given string to the
-    script output.
-
-    """
+    """Run migrations in 'offline' mode."""
     url = config.get_main_option("sqlalchemy.url")
     context.configure(
         url=url,
@@ -67,30 +53,32 @@ def run_migrations_offline() -> None:
     with context.begin_transaction():
         context.run_migrations()
 
 
-def run_migrations_online() -> None:
-    """Run migrations in 'online' mode.
-
-    In this scenario we need to create an Engine
-    and associate a connection with the context.
-
-    """
-    connectable = engine_from_config(
-        config.get_section(config.config_ini_section, {}),
-        prefix="sqlalchemy.",
+async def run_migrations_online() -> None:
+    """Run migrations in 'online' mode."""
+    connectable = create_async_engine(
+        settings.DATABASE_URL,
         poolclass=pool.NullPool,
     )
 
-    with connectable.connect() as connection:
-        context.configure(
-            connection=connection, target_metadata=target_metadata
-        )
+    async with connectable.connect() as connection:
+        await connection.run_sync(_run_migrations)
 
-        with context.begin_transaction():
-            context.run_migrations()
+    await connectable.dispose()
+
+
+def _run_migrations(connection):
+    context.configure(
+        connection=connection,
+        target_metadata=target_metadata,
+        compare_type=True,
+        compare_server_default=True
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
 
 
+# This section only runs when executing alembic commands directly (not when imported)
 if context.is_offline_mode():
     run_migrations_offline()
 else:
-    run_migrations_online()
+    import asyncio
+    asyncio.run(run_migrations_online())
```
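The shift in this file: the old env.py stripped the `+asyncpg` suffix to hand Alembic a synchronous URL, while the new version keeps the async URL, builds an async engine, and drives the actual migration body through `connection.run_sync(_run_migrations)`. The `asyncio.run(...)` at the bottom fires only when Alembic invokes env.py directly.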
be/alembic/migrations.py (new file, 74 lines):

```python
"""
Async migrations handler for FastAPI application.
This file is separate from env.py to avoid Alembic context issues.
"""
import os
import sys
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import pool
from alembic.config import Config
from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
from alembic.operations import Operations

# Ensure the app directory is in the Python path
sys.path.insert(0, os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))

from app.database import Base as DatabaseBase
from app.config import settings


def _get_migration_fn(script_directory, current_rev):
    """Create a migration function that knows how to upgrade from current revision."""
    def migration_fn(rev, context):
        # Get all upgrade steps from current revision to head
        revisions = script_directory._upgrade_revs("head", current_rev)
        for revision_step in revisions:
            # Access the revision string from the Script object, which is within the RevisionStep object
            script = script_directory.get_revision(revision_step.revision.revision)
            script.module.upgrade(context)
    return migration_fn


async def run_migrations():
    """Run database migrations asynchronously."""
    # Get alembic configuration and script directory
    alembic_cfg = Config(os.path.join(os.path.dirname(__file__), '..', 'alembic.ini'))
    script_directory = ScriptDirectory.from_config(alembic_cfg)

    # Create async engine
    engine = create_async_engine(
        settings.DATABASE_URL,
        poolclass=pool.NullPool,
    )

    async with engine.connect() as connection:
        def get_current_rev(conn):
            migration_context = MigrationContext.configure(
                conn,
                opts={
                    'target_metadata': DatabaseBase.metadata,
                    'script': script_directory
                }
            )
            return migration_context.get_current_revision()

        current_rev = await connection.run_sync(get_current_rev)

        def upgrade_to_head(conn):
            migration_context = MigrationContext.configure(
                conn,
                opts={
                    'target_metadata': DatabaseBase.metadata,
                    'script': script_directory,
                    'as_sql': False,
                }
            )

            # Set the migration function
            migration_context._migrations_fn = _get_migration_fn(script_directory, current_rev)

            with migration_context.begin_transaction():
                migration_context.run_migrations()

        await connection.run_sync(upgrade_to_head)

    await engine.dispose()
```
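Per its docstring, this module exists so the FastAPI application itself can run migrations. A plausible call site is a startup hook; the lifespan wiring below is an illustrative assumption, not part of this diff, and the import path depends on how the project exposes `be/alembic/migrations.py`:

```python
# Hypothetical startup wiring; app.main may do this differently.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from alembic.migrations import run_migrations  # assumes be/ is the working directory


@asynccontextmanager
async def lifespan(app: FastAPI):
    await run_migrations()  # upgrade the schema to head before serving traffic
    yield


app = FastAPI(lifespan=lifespan)
```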
305
be/alembic/versions/0001_initial_schema.py
Normal file
305
be/alembic/versions/0001_initial_schema.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
"""Initial schema setup
|
||||||
|
|
||||||
|
Revision ID: 0001_initial_schema
|
||||||
|
Revises:
|
||||||
|
Create Date: YYYY-MM-DD HH:MM:SS.ffffff
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '0001_initial_schema'
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
user_role_enum = postgresql.ENUM('owner', 'member', name='userroleenum', create_type=False)
|
||||||
|
split_type_enum = postgresql.ENUM('EQUAL', 'EXACT_AMOUNTS', 'PERCENTAGE', 'SHARES', 'ITEM_BASED', name='splittypeenum', create_type=False)
|
||||||
|
expense_split_status_enum = postgresql.ENUM('unpaid', 'partially_paid', 'paid', name='expensesplitstatusenum', create_type=False)
|
||||||
|
expense_overall_status_enum = postgresql.ENUM('unpaid', 'partially_paid', 'paid', name='expenseoverallstatusenum', create_type=False)
|
||||||
|
recurrence_type_enum = postgresql.ENUM('DAILY', 'WEEKLY', 'MONTHLY', 'YEARLY', name='recurrencetypeenum', create_type=False)
|
||||||
|
chore_frequency_enum = postgresql.ENUM('one_time', 'daily', 'weekly', 'monthly', 'custom', name='chorefrequencyenum', create_type=False)
|
||||||
|
chore_type_enum = postgresql.ENUM('personal', 'group', name='choretypeenum', create_type=False)
|
||||||
|
|
||||||
|
def upgrade(context=None) -> None: # Add context=None for compatibility, real arg passed by Alembic
|
||||||
|
# Create enums
|
||||||
|
user_role_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
split_type_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
expense_split_status_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
expense_overall_status_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
recurrence_type_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
chore_frequency_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
chore_type_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
op.create_table('users',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('email', sa.String(), nullable=False),
|
||||||
|
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=True),
|
||||||
|
sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||||
|
sa.Column('is_superuser', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('is_verified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||||
|
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('groups',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('created_by_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('user_groups',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('role', user_role_enum, nullable=False),
|
||||||
|
sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('invites',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('code', sa.String(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_by_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False)
|
||||||
|
op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true'))
|
||||||
|
op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('lists',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('created_by_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('recurrence_patterns',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('type', recurrence_type_enum, nullable=False),
|
||||||
|
sa.Column('interval', sa.Integer(), server_default='1', nullable=False),
|
||||||
|
sa.Column('days_of_week', sa.String(), nullable=True),
|
||||||
|
sa.Column('end_date', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('max_occurrences', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_recurrence_patterns_id'), 'recurrence_patterns', ['id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('items',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('list_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('quantity', sa.String(), nullable=True),
|
||||||
|
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True),
|
||||||
|
sa.Column('added_by_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('completed_by_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('chores',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('type', chore_type_enum, nullable=False),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('name', sa.String(), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
|
sa.Column('created_by_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('frequency', chore_frequency_enum, nullable=False),
|
||||||
|
sa.Column('custom_interval_days', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('next_due_date', sa.Date(), nullable=False),
|
||||||
|
sa.Column('last_completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_chores_created_by_id'), 'chores', ['created_by_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_chores_group_id'), 'chores', ['group_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_chores_id'), 'chores', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_chores_name'), 'chores', ['name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('expenses',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('description', sa.String(), nullable=False),
|
||||||
|
sa.Column('total_amount', sa.Numeric(precision=10, scale=2), nullable=False),
|
||||||
|
sa.Column('currency', sa.String(), server_default='USD', nullable=False),
|
||||||
|
sa.Column('expense_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('split_type', split_type_enum, nullable=False),
|
||||||
|
sa.Column('list_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('item_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_by_user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('version', sa.Integer(), server_default='1', nullable=False),
|
||||||
|
sa.Column('overall_settlement_status', expense_overall_status_enum, server_default='unpaid', nullable=False),
|
||||||
|
sa.Column('is_recurring', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('recurrence_pattern_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('next_occurrence', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('parent_expense_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('last_occurrence', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
|
||||||
|
sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['item_id'], ['items.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['parent_expense_id'], ['expenses.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['recurrence_pattern_id'], ['recurrence_patterns.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_expenses_created_by_user_id'), 'expenses', ['created_by_user_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_expenses_group_id'), 'expenses', ['group_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_expenses_id'), 'expenses', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_expenses_list_id'), 'expenses', ['list_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_expenses_paid_by_user_id'), 'expenses', ['paid_by_user_id'], unique=False)
|
||||||
|
op.create_index('ix_expenses_recurring_next_occurrence', 'expenses', ['is_recurring', 'next_occurrence'], unique=False, postgresql_where=sa.text('is_recurring = true'))
|
||||||
|
|
||||||
|
op.create_table('chore_assignments',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('chore_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('assigned_to_user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('due_date', sa.Date(), nullable=False),
|
||||||
|
sa.Column('is_complete', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['assigned_to_user_id'], ['users.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['chore_id'], ['chores.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_chore_assignments_assigned_to_user_id'), 'chore_assignments', ['assigned_to_user_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_chore_assignments_chore_id'), 'chore_assignments', ['chore_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_chore_assignments_id'), 'chore_assignments', ['id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('expense_splits',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('expense_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('owed_amount', sa.Numeric(precision=10, scale=2), nullable=False),
|
||||||
|
sa.Column('share_percentage', sa.Numeric(precision=5, scale=2), nullable=True),
|
||||||
|
sa.Column('share_units', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||||
|
sa.Column('status', expense_split_status_enum, server_default='unpaid', nullable=False),
|
||||||
|
sa.Column('paid_at', sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['expense_id'], ['expenses.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id'),
|
||||||
|
sa.UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_expense_splits_id'), 'expense_splits', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_expense_splits_user_id'), 'expense_splits', ['user_id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('settlements',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('group_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
|
||||||
|
    sa.Column('paid_to_user_id', sa.Integer(), nullable=False),
    sa.Column('amount', sa.Numeric(precision=10, scale=2), nullable=False),
    sa.Column('settlement_date', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('created_by_user_id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('version', sa.Integer(), server_default='1', nullable=False),
    sa.CheckConstraint('paid_by_user_id != paid_to_user_id', name='chk_settlement_different_users'),
    sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
    sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['paid_to_user_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_settlements_created_by_user_id'), 'settlements', ['created_by_user_id'], unique=False)
    op.create_index(op.f('ix_settlements_group_id'), 'settlements', ['group_id'], unique=False)
    op.create_index(op.f('ix_settlements_id'), 'settlements', ['id'], unique=False)
    op.create_index(op.f('ix_settlements_paid_by_user_id'), 'settlements', ['paid_by_user_id'], unique=False)
    op.create_index(op.f('ix_settlements_paid_to_user_id'), 'settlements', ['paid_to_user_id'], unique=False)

    op.create_table('settlement_activities',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('expense_split_id', sa.Integer(), nullable=False),
    sa.Column('paid_by_user_id', sa.Integer(), nullable=False),
    sa.Column('paid_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('amount_paid', sa.Numeric(precision=10, scale=2), nullable=False),
    sa.Column('created_by_user_id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['created_by_user_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['expense_split_id'], ['expense_splits.id'], ),
    sa.ForeignKeyConstraint(['paid_by_user_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_settlement_activities_created_by_user_id'), 'settlement_activities', ['created_by_user_id'], unique=False)
    op.create_index(op.f('ix_settlement_activities_expense_split_id'), 'settlement_activities', ['expense_split_id'], unique=False)
    op.create_index(op.f('ix_settlement_activities_paid_by_user_id'), 'settlement_activities', ['paid_by_user_id'], unique=False)


def downgrade(context=None) -> None:  # Add context=None for compatibility, real arg passed by Alembic
    op.drop_table('settlement_activities')
    op.drop_table('settlements')
    op.drop_table('expense_splits')
    op.drop_table('chore_assignments')
    op.drop_table('expenses')
    op.drop_table('chores')
    op.drop_table('items')
    op.drop_table('recurrence_patterns')
    op.drop_table('lists')
    op.drop_table('invites')
    op.drop_table('user_groups')
    op.drop_table('groups')
    op.drop_table('users')

    # Drop enums in reverse order
    chore_type_enum.drop(op.get_bind(), checkfirst=True)
    chore_frequency_enum.drop(op.get_bind(), checkfirst=True)
    recurrence_type_enum.drop(op.get_bind(), checkfirst=True)
    expense_overall_status_enum.drop(op.get_bind(), checkfirst=True)
    expense_split_status_enum.drop(op.get_bind(), checkfirst=True)
    split_type_enum.drop(op.get_bind(), checkfirst=True)
    user_role_enum.drop(op.get_bind(), checkfirst=True)
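The enum objects dropped above are declared earlier in this migration file, outside this excerpt. A minimal sketch of what those module-level declarations presumably look like; only the userroleenum and choretypeenum value lists are corroborated elsewhere in this diff, everything else about the sketch is an assumption:

import sqlalchemy as sa

# Assumed module-level declarations, shared by upgrade() and downgrade().
# The remaining enum objects (split_type_enum, recurrence_type_enum, etc.)
# would be declared the same way with their own value lists.
user_role_enum = sa.Enum('owner', 'member', name='userroleenum')
chore_type_enum = sa.Enum('personal', 'group', name='choretypeenum')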
@@ -1,32 +0,0 @@
"""Add invite table and relationships

Revision ID: 563ee77c5214
Revises: 69b0c1432084
Create Date: 2025-03-30 18:51:19.926810

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '563ee77c5214'
down_revision: Union[str, None] = '69b0c1432084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,32 +0,0 @@
"""Initial database setup

Revision ID: 643956b3f4de
Revises:
Create Date: 2025-03-29 20:49:01.018626

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '643956b3f4de'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,32 +0,0 @@
"""Add invite table and relationships

Revision ID: 69b0c1432084
Revises: 6f80b82dbdf8
Create Date: 2025-03-30 18:50:48.072504

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '69b0c1432084'
down_revision: Union[str, None] = '6f80b82dbdf8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,32 +0,0 @@
"""Add invite table and relationships

Revision ID: 6f80b82dbdf8
Revises: f42efe4f4bca
Create Date: 2025-03-30 18:49:26.968637

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '6f80b82dbdf8'
down_revision: Union[str, None] = 'f42efe4f4bca'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,72 +0,0 @@
"""Add User, Group, UserGroup models

Revision ID: 85a3c075e73a
Revises: c6cbef99588b
Create Date: 2025-03-30 12:46:07.322285

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '85a3c075e73a'
down_revision: Union[str, None] = 'c6cbef99588b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('email', sa.String(), nullable=False),
    sa.Column('password_hash', sa.String(), nullable=False),
    sa.Column('name', sa.String(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_name'), 'users', ['name'], unique=False)
    op.create_table('groups',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('created_by_id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_groups_id'), 'groups', ['id'], unique=False)
    op.create_index(op.f('ix_groups_name'), 'groups', ['name'], unique=False)
    op.create_table('user_groups',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('group_id', sa.Integer(), nullable=False),
    sa.Column('role', sa.Enum('owner', 'member', name='userroleenum'), nullable=False),
    sa.Column('joined_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('user_id', 'group_id', name='uq_user_group')
    )
    op.create_index(op.f('ix_user_groups_id'), 'user_groups', ['id'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_user_groups_id'), table_name='user_groups')
    op.drop_table('user_groups')
    op.drop_index(op.f('ix_groups_name'), table_name='groups')
    op.drop_index(op.f('ix_groups_id'), table_name='groups')
    op.drop_table('groups')
    op.drop_index(op.f('ix_users_name'), table_name='users')
    op.drop_index(op.f('ix_users_id'), table_name='users')
    op.drop_index(op.f('ix_users_email'), table_name='users')
    op.drop_table('users')
    # ### end Alembic commands ###
@@ -0,0 +1,133 @@
"""Add position to Item model for reordering

Revision ID: 91d00c100f5b
Revises: 0001_initial_schema
Create Date: 2025-06-07 14:59:48.761124

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '91d00c100f5b'
down_revision: Union[str, None] = '0001_initial_schema'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index('ix_apscheduler_jobs_next_run_time', table_name='apscheduler_jobs')
    op.drop_table('apscheduler_jobs')
    op.alter_column('chore_assignments', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.alter_column('expenses', 'currency',
               existing_type=sa.VARCHAR(),
               server_default=None,
               existing_nullable=False)
    op.alter_column('expenses', 'is_recurring',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.drop_index('ix_expenses_recurring_next_occurrence', table_name='expenses', postgresql_where='(is_recurring = true)')
    op.alter_column('invites', 'is_active',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.add_column('items', sa.Column('position', sa.Integer(), server_default='0', nullable=False))
    op.alter_column('items', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.create_index('ix_items_list_id_position', 'items', ['list_id', 'position'], unique=False)
    op.alter_column('lists', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.alter_column('recurrence_patterns', 'interval',
               existing_type=sa.INTEGER(),
               server_default=None,
               existing_nullable=False)
    op.create_index(op.f('ix_settlement_activities_id'), 'settlement_activities', ['id'], unique=False)
    op.create_index('ix_settlement_activity_created_by_user_id', 'settlement_activities', ['created_by_user_id'], unique=False)
    op.create_index('ix_settlement_activity_expense_split_id', 'settlement_activities', ['expense_split_id'], unique=False)
    op.create_index('ix_settlement_activity_paid_by_user_id', 'settlement_activities', ['paid_by_user_id'], unique=False)
    op.alter_column('users', 'is_active',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.alter_column('users', 'is_superuser',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    op.alter_column('users', 'is_verified',
               existing_type=sa.BOOLEAN(),
               server_default=None,
               existing_nullable=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column('users', 'is_verified',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.alter_column('users', 'is_superuser',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.alter_column('users', 'is_active',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('true'),
               existing_nullable=False)
    op.drop_index('ix_settlement_activity_paid_by_user_id', table_name='settlement_activities')
    op.drop_index('ix_settlement_activity_expense_split_id', table_name='settlement_activities')
    op.drop_index('ix_settlement_activity_created_by_user_id', table_name='settlement_activities')
    op.drop_index(op.f('ix_settlement_activities_id'), table_name='settlement_activities')
    op.alter_column('recurrence_patterns', 'interval',
               existing_type=sa.INTEGER(),
               server_default=sa.text('1'),
               existing_nullable=False)
    op.alter_column('lists', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.drop_index('ix_items_list_id_position', table_name='items')
    op.alter_column('items', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.drop_column('items', 'position')
    op.alter_column('invites', 'is_active',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('true'),
               existing_nullable=False)
    op.create_index('ix_expenses_recurring_next_occurrence', 'expenses', ['is_recurring', 'next_occurrence'], unique=False, postgresql_where='(is_recurring = true)')
    op.alter_column('expenses', 'is_recurring',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.alter_column('expenses', 'currency',
               existing_type=sa.VARCHAR(),
               server_default=sa.text("'USD'::character varying"),
               existing_nullable=False)
    op.alter_column('chore_assignments', 'is_complete',
               existing_type=sa.BOOLEAN(),
               server_default=sa.text('false'),
               existing_nullable=False)
    op.create_table('apscheduler_jobs',
    sa.Column('id', sa.VARCHAR(length=191), autoincrement=False, nullable=False),
    sa.Column('next_run_time', sa.DOUBLE_PRECISION(precision=53), autoincrement=False, nullable=True),
    sa.Column('job_state', postgresql.BYTEA(), autoincrement=False, nullable=False),
    sa.PrimaryKeyConstraint('id', name='apscheduler_jobs_pkey')
    )
    op.create_index('ix_apscheduler_jobs_next_run_time', 'apscheduler_jobs', ['next_run_time'], unique=False)
    # ### end Alembic commands ###
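A short usage sketch of the new ix_items_list_id_position composite index: fetching a list's items already ordered for display. The Item model and its import path are assumptions consistent with the app.models imports seen elsewhere in this diff:

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

async def get_ordered_items(db: AsyncSession, list_id: int):
    # Assumes an Item model mapped to the 'items' table; the composite index
    # lets PostgreSQL satisfy this filter-plus-sort directly from the index.
    from app.models import Item  # assumed import path
    stmt = select(Item).where(Item.list_id == list_id).order_by(Item.position, Item.id)
    return (await db.execute(stmt)).scalars().all()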
@@ -1,32 +0,0 @@
"""Initial database setup

Revision ID: c6cbef99588b
Revises: 643956b3f4de
Create Date: 2025-03-30 12:18:51.207858

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'c6cbef99588b'
down_revision: Union[str, None] = '643956b3f4de'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,73 +0,0 @@
"""Add list and item tables

Revision ID: d25788f63e2c
Revises: d90ab7116920
Create Date: 2025-03-30 19:43:49.925240

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'd25788f63e2c'
down_revision: Union[str, None] = 'd90ab7116920'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('lists',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('created_by_id', sa.Integer(), nullable=False),
    sa.Column('group_id', sa.Integer(), nullable=True),
    sa.Column('is_complete', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_lists_id'), 'lists', ['id'], unique=False)
    op.create_index(op.f('ix_lists_name'), 'lists', ['name'], unique=False)
    op.create_table('items',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('list_id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('quantity', sa.String(), nullable=True),
    sa.Column('is_complete', sa.Boolean(), nullable=False),
    sa.Column('price', sa.Numeric(precision=10, scale=2), nullable=True),
    sa.Column('added_by_id', sa.Integer(), nullable=False),
    sa.Column('completed_by_id', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.ForeignKeyConstraint(['added_by_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['completed_by_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['list_id'], ['lists.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
    op.create_index(op.f('ix_items_name'), 'items', ['name'], unique=False)
    op.drop_index('ix_invites_code', table_name='invites')
    op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_invites_code'), table_name='invites')
    op.create_index('ix_invites_code', 'invites', ['code'], unique=True)
    op.drop_index(op.f('ix_items_name'), table_name='items')
    op.drop_index(op.f('ix_items_id'), table_name='items')
    op.drop_table('items')
    op.drop_index(op.f('ix_lists_name'), table_name='lists')
    op.drop_index(op.f('ix_lists_id'), table_name='lists')
    op.drop_table('lists')
    # ### end Alembic commands ###
@@ -1,32 +0,0 @@
"""Add invite table and relationships

Revision ID: d90ab7116920
Revises: 563ee77c5214
Create Date: 2025-03-30 18:57:39.047729

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'd90ab7116920'
down_revision: Union[str, None] = '563ee77c5214'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,49 +0,0 @@
"""Add invite table and relationships

Revision ID: f42efe4f4bca
Revises: 85a3c075e73a
Create Date: 2025-03-30 18:41:50.854172

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'f42efe4f4bca'
down_revision: Union[str, None] = '85a3c075e73a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('invites',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('code', sa.String(), nullable=False),
    sa.Column('group_id', sa.Integer(), nullable=False),
    sa.Column('created_by_id', sa.Integer(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.ForeignKeyConstraint(['created_by_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['group_id'], ['groups.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index('ix_invites_active_code', 'invites', ['code'], unique=True, postgresql_where=sa.text('is_active = true'))
    op.create_index(op.f('ix_invites_code'), 'invites', ['code'], unique=True)
    op.create_index(op.f('ix_invites_id'), 'invites', ['id'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_invites_id'), table_name='invites')
    op.drop_index(op.f('ix_invites_code'), table_name='invites')
    op.drop_index('ix_invites_active_code', table_name='invites', postgresql_where=sa.text('is_active = true'))
    op.drop_table('invites')
    # ### end Alembic commands ###
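For reference, a sketch of the same partial unique index expressed at the model level rather than in a migration; the Invite model itself is not part of this excerpt and its import path is an assumption:

from sqlalchemy import Index, text
from app.models import Invite  # assumed import path

# Unique only among active invites, mirroring ix_invites_active_code above;
# inactive invites may share a code with a newly issued active one.
invite_active_code_idx = Index(
    'ix_invites_active_code',
    Invite.code,
    unique=True,
    postgresql_where=text('is_active = true'),
)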
95
be/app/api/auth/oauth.py
Normal file
@@ -0,0 +1,95 @@
from fastapi import APIRouter, Depends, Request
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_transactional_session
from app.models import User
from app.auth import oauth, fastapi_users, auth_backend, get_jwt_strategy, get_refresh_jwt_strategy
from app.config import settings

router = APIRouter()

@router.get('/google/login')
async def google_login(request: Request):
    return await oauth.google.authorize_redirect(request, settings.GOOGLE_REDIRECT_URI)

@router.get('/google/callback')
async def google_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
    token_data = await oauth.google.authorize_access_token(request)
    user_info = await oauth.google.parse_id_token(request, token_data)

    # Check if user exists
    existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none()

    user_to_login = existing_user
    if not existing_user:
        # Create new user
        new_user = User(
            email=user_info['email'],
            name=user_info.get('name', user_info.get('email')),
            is_verified=True,  # Email is verified by Google
            is_active=True
        )
        db.add(new_user)
        await db.flush()  # Use flush instead of commit since we're in a transaction
        user_to_login = new_user

    # Generate JWT tokens using the new backend
    access_strategy = get_jwt_strategy()
    refresh_strategy = get_refresh_jwt_strategy()

    access_token = await access_strategy.write_token(user_to_login)
    refresh_token = await refresh_strategy.write_token(user_to_login)

    # Redirect to frontend with tokens
    redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"

    return RedirectResponse(url=redirect_url)

@router.get('/apple/login')
async def apple_login(request: Request):
    return await oauth.apple.authorize_redirect(request, settings.APPLE_REDIRECT_URI)

@router.get('/apple/callback')
async def apple_callback(request: Request, db: AsyncSession = Depends(get_transactional_session)):
    token_data = await oauth.apple.authorize_access_token(request)
    user_info = token_data.get('user', await oauth.apple.userinfo(token=token_data) if hasattr(oauth.apple, 'userinfo') else {})
    if 'email' not in user_info and 'sub' in token_data:
        parsed_id_token = await oauth.apple.parse_id_token(request, token_data) if hasattr(oauth.apple, 'parse_id_token') else {}
        user_info = {**parsed_id_token, **user_info}

    if 'email' not in user_info:
        return RedirectResponse(url=f"{settings.FRONTEND_URL}/auth/callback?error=apple_email_missing")

    # Check if user exists
    existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none()

    user_to_login = existing_user
    if not existing_user:
        # Create new user
        name_info = user_info.get('name', {})
        first_name = name_info.get('firstName', '')
        last_name = name_info.get('lastName', '')
        full_name = f"{first_name} {last_name}".strip() if first_name or last_name else user_info.get('email')

        new_user = User(
            email=user_info['email'],
            name=full_name,
            is_verified=True,  # Email is verified by Apple
            is_active=True
        )
        db.add(new_user)
        await db.flush()  # Use flush instead of commit since we're in a transaction
        user_to_login = new_user

    # Generate JWT tokens using the new backend
    access_strategy = get_jwt_strategy()
    refresh_strategy = get_refresh_jwt_strategy()

    access_token = await access_strategy.write_token(user_to_login)
    refresh_token = await refresh_strategy.write_token(user_to_login)

    # Redirect to frontend with tokens
    redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}"

    return RedirectResponse(url=redirect_url)
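One integration detail worth noting: Authlib's authorize_redirect / authorize_access_token round-trip stores OAuth state in the Starlette session, so the application presumably installs SessionMiddleware somewhere in its app setup. A minimal sketch; the secret-key setting name is an assumption, not part of this diff:

from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware

from app.config import settings

app = FastAPI()
# Required for Authlib's state/nonce handling between /login and /callback.
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY)  # setting name assumed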
@@ -1,71 +0,0 @@
# app/api/dependencies.py
import logging
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from jose import JWTError

from app.database import get_db
from app.core.security import verify_access_token
from app.crud import user as crud_user
from app.models import User as UserModel  # Import the SQLAlchemy model

logger = logging.getLogger(__name__)

# Define the OAuth2 scheme
# tokenUrl should point to your login endpoint relative to the base path
# It's used by Swagger UI for the "Authorize" button flow.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")  # Corrected path

async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_db)
) -> UserModel:
    """
    Dependency to get the current user based on the JWT token.

    - Extracts token using OAuth2PasswordBearer.
    - Verifies the token (signature, expiry).
    - Fetches the user from the database based on the token's subject (email).
    - Raises HTTPException 401 if any step fails.

    Returns:
        The authenticated user's database model instance.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    payload = verify_access_token(token)
    if payload is None:
        logger.warning("Token verification failed (invalid, expired, or malformed).")
        raise credentials_exception

    email: Optional[str] = payload.get("sub")
    if email is None:
        logger.error("Token payload missing 'sub' (subject/email).")
        raise credentials_exception  # Token is malformed

    # Fetch user from database
    user = await crud_user.get_user_by_email(db, email=email)
    if user is None:
        logger.warning(f"User corresponding to token subject not found: {email}")
        # Could happen if user deleted after token issuance
        raise credentials_exception  # Treat as invalid credentials

    logger.debug(f"Authenticated user retrieved: {user.email} (ID: {user.id})")
    return user

# Optional: Dependency for getting the *active* current user
# You might add an `is_active` flag to your User model later
# async def get_current_active_user(
#     current_user: UserModel = Depends(get_current_user)
# ) -> UserModel:
#     if not current_user.is_active:  # Assuming an is_active attribute
#         logger.warning(f"Authentication attempt by inactive user: {current_user.email}")
#         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
#     return current_user
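This deleted dependency is superseded by fastapi-users: the new endpoints import current_active_user from app.auth, which is not part of this excerpt. A minimal sketch of the wiring that presumably lives there, under the assumption that it follows the standard fastapi-users pattern; the user-manager dependency is stubbed here and the backend details are assumptions:

from fastapi_users import FastAPIUsers
from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy

from app.models import User

SECRET = "change-me"  # placeholder; the real secret presumably comes from settings

def get_jwt_strategy() -> JWTStrategy:
    return JWTStrategy(secret=SECRET, lifetime_seconds=3600)

async def get_user_manager():
    # Stub: the real dependency yields a UserManager bound to the user database.
    raise NotImplementedError

auth_backend = AuthenticationBackend(
    name="jwt",
    transport=BearerTransport(tokenUrl="/api/v1/auth/jwt/login"),  # path assumed
    get_strategy=get_jwt_strategy,
)

fastapi_users = FastAPIUsers[User, int](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)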
@@ -1,24 +1,25 @@
 # app/api/v1/api.py
 from fastapi import APIRouter
 
 from app.api.v1.endpoints import health
-from app.api.v1.endpoints import auth
-from app.api.v1.endpoints import users
 from app.api.v1.endpoints import groups
 from app.api.v1.endpoints import invites
 from app.api.v1.endpoints import lists
 from app.api.v1.endpoints import items
 from app.api.v1.endpoints import ocr
+from app.api.v1.endpoints import costs
+from app.api.v1.endpoints import financials
+from app.api.v1.endpoints import chores
 
 api_router_v1 = APIRouter()
 
-api_router_v1.include_router(health.router) # Path /health defined inside
+api_router_v1.include_router(health.router)
-api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"])
-api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
 api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"])
 api_router_v1.include_router(invites.router, prefix="/invites", tags=["Invites"])
 api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
 api_router_v1.include_router(items.router, tags=["Items"])
 api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
+api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"])
+api_router_v1.include_router(financials.router, prefix="/financials", tags=["Financials"])
+api_router_v1.include_router(chores.router, prefix="/chores", tags=["Chores"])
 # Add other v1 endpoint routers here later
 # e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
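A sketch of how api_router_v1 is presumably mounted; app/main.py is not in this excerpt, but the "/api/v1" prefix is consistent with the tokenUrl seen in the deleted dependencies module above:

from fastapi import FastAPI
from app.api.v1.api import api_router_v1

app = FastAPI()
# Every router registered on api_router_v1 becomes reachable under /api/v1/...
app.include_router(api_router_v1, prefix="/api/v1")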
@@ -1,91 +0,0 @@
# app/api/v1/endpoints/auth.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.schemas.user import UserCreate, UserPublic
from app.schemas.auth import Token
from app.crud import user as crud_user
from app.core.security import verify_password, create_access_token

logger = logging.getLogger(__name__)
router = APIRouter()

@router.post(
    "/signup",
    response_model=UserPublic,  # Return public user info, not the password hash
    status_code=status.HTTP_201_CREATED,  # Indicate resource creation
    summary="Register New User",
    description="Creates a new user account.",
    tags=["Authentication"]
)
async def signup(
    user_in: UserCreate,
    db: AsyncSession = Depends(get_db)
):
    """
    Handles user registration.
    - Validates input data.
    - Checks if email already exists.
    - Hashes the password.
    - Stores the new user in the database.
    """
    logger.info(f"Signup attempt for email: {user_in.email}")
    existing_user = await crud_user.get_user_by_email(db, email=user_in.email)
    if existing_user:
        logger.warning(f"Signup failed: Email already registered - {user_in.email}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered.",
        )

    try:
        created_user = await crud_user.create_user(db=db, user_in=user_in)
        logger.info(f"User created successfully: {created_user.email} (ID: {created_user.id})")
        # Note: UserPublic schema automatically excludes the hashed password
        return created_user
    except Exception as e:
        logger.error(f"Error during user creation for {user_in.email}: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred during user creation.",
        )


@router.post(
    "/login",
    response_model=Token,
    summary="User Login",
    description="Authenticates a user and returns an access token.",
    tags=["Authentication"]
)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),  # Use standard form for username/password
    db: AsyncSession = Depends(get_db)
):
    """
    Handles user login.
    - Finds user by email (provided in 'username' field of form).
    - Verifies the provided password against the stored hash.
    - Generates and returns a JWT access token upon successful authentication.
    """
    logger.info(f"Login attempt for user: {form_data.username}")
    user = await crud_user.get_user_by_email(db, email=form_data.username)

    # Check if user exists and password is correct
    # Use the correct attribute name 'password_hash' from the User model
    if not user or not verify_password(form_data.password, user.password_hash):  # <-- CORRECTED LINE
        logger.warning(f"Login failed: Invalid credentials for user {form_data.username}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},  # Standard header for 401
        )

    # Generate JWT
    access_token = create_access_token(subject=user.email)  # Use email as subject
    logger.info(f"Login successful, token generated for user: {user.email}")
    return Token(access_token=access_token, token_type="bearer")
453
be/app/api/v1/endpoints/chores.py
Normal file
@@ -0,0 +1,453 @@
# app/api/v1/endpoints/chores.py
import logging
from typing import List as PyList, Optional

from fastapi import APIRouter, Depends, HTTPException, status, Response
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_transactional_session, get_session
from app.auth import current_active_user
from app.models import User as UserModel, Chore as ChoreModel, ChoreTypeEnum
from app.schemas.chore import ChoreCreate, ChoreUpdate, ChorePublic, ChoreAssignmentCreate, ChoreAssignmentUpdate, ChoreAssignmentPublic
from app.crud import chore as crud_chore
from app.core.exceptions import ChoreNotFoundError, PermissionDeniedError, GroupNotFoundError, DatabaseIntegrityError

logger = logging.getLogger(__name__)
router = APIRouter()

# Add this new endpoint before the personal chores section
@router.get(
    "/all",
    response_model=PyList[ChorePublic],
    summary="List All Chores",
    tags=["Chores"]
)
async def list_all_chores(
    db: AsyncSession = Depends(get_session),  # Use read-only session for GET
    current_user: UserModel = Depends(current_active_user),
):
    """Retrieves all chores (personal and group) for the current user in a single optimized request."""
    logger.info(f"User {current_user.email} listing all their chores")

    # Use the optimized function that reduces database queries
    all_chores = await crud_chore.get_all_user_chores(db=db, user_id=current_user.id)

    return all_chores

# --- Personal Chores Endpoints ---

@router.post(
    "/personal",
    response_model=ChorePublic,
    status_code=status.HTTP_201_CREATED,
    summary="Create Personal Chore",
    tags=["Chores", "Personal Chores"]
)
async def create_personal_chore(
    chore_in: ChoreCreate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Creates a new personal chore for the current user."""
    logger.info(f"User {current_user.email} creating personal chore: {chore_in.name}")
    if chore_in.type != ChoreTypeEnum.personal:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Chore type must be personal.")
    if chore_in.group_id is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="group_id must be null for personal chores.")
    try:
        return await crud_chore.create_chore(db=db, chore_in=chore_in, user_id=current_user.id)
    except ValueError as e:
        logger.warning(f"ValueError creating personal chore for user {current_user.email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError creating personal chore for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

@router.get(
    "/personal",
    response_model=PyList[ChorePublic],
    summary="List Personal Chores",
    tags=["Chores", "Personal Chores"]
)
async def list_personal_chores(
    db: AsyncSession = Depends(get_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Retrieves all personal chores for the current user."""
    logger.info(f"User {current_user.email} listing their personal chores")
    return await crud_chore.get_personal_chores(db=db, user_id=current_user.id)

@router.put(
    "/personal/{chore_id}",
    response_model=ChorePublic,
    summary="Update Personal Chore",
    tags=["Chores", "Personal Chores"]
)
async def update_personal_chore(
    chore_id: int,
    chore_in: ChoreUpdate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Updates a personal chore for the current user."""
    logger.info(f"User {current_user.email} updating personal chore ID: {chore_id}")
    if chore_in.type is not None and chore_in.type != ChoreTypeEnum.personal:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change chore type to group via this endpoint.")
    if chore_in.group_id is not None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="group_id must be null for personal chores.")
    try:
        updated_chore = await crud_chore.update_chore(db=db, chore_id=chore_id, chore_in=chore_in, user_id=current_user.id, group_id=None)
        if not updated_chore:
            raise ChoreNotFoundError(chore_id=chore_id)
        if updated_chore.type != ChoreTypeEnum.personal or updated_chore.created_by_id != current_user.id:
            # This should ideally be caught by the CRUD layer permission checks
            raise PermissionDeniedError(detail="Chore is not a personal chore of the current user or does not exist.")
        return updated_chore
    except ChoreNotFoundError as e:
        logger.warning(f"Personal chore {e.chore_id} not found for user {current_user.email} during update.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} updating personal chore {chore_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except ValueError as e:
        logger.warning(f"ValueError updating personal chore {chore_id} for user {current_user.email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError updating personal chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

@router.delete(
    "/personal/{chore_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete Personal Chore",
    tags=["Chores", "Personal Chores"]
)
async def delete_personal_chore(
    chore_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Deletes a personal chore for the current user."""
    logger.info(f"User {current_user.email} deleting personal chore ID: {chore_id}")
    try:
        # First, verify it's a personal chore belonging to the user
        chore_to_delete = await crud_chore.get_chore_by_id(db, chore_id)
        if not chore_to_delete or chore_to_delete.type != ChoreTypeEnum.personal or chore_to_delete.created_by_id != current_user.id:
            raise ChoreNotFoundError(chore_id=chore_id, detail="Personal chore not found or not owned by user.")

        success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=None)
        if not success:
            # This case should be rare if the above check passes and DB is consistent
            raise ChoreNotFoundError(chore_id=chore_id)
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except ChoreNotFoundError as e:
        logger.warning(f"Personal chore {e.chore_id} not found for user {current_user.email} during delete.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:  # Should be caught by the check above
        logger.warning(f"Permission denied for user {current_user.email} deleting personal chore {chore_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError deleting personal chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
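The endpoints above alternate between get_session for reads and get_transactional_session for writes. A minimal sketch of how those two dependencies are presumably defined in app/database.py, which is not part of this excerpt; the engine URL and session-factory name are placeholders:

from typing import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")  # placeholder URL
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

async def get_session() -> AsyncIterator[AsyncSession]:
    # Plain session for read-only GET handlers; no transaction management here.
    async with async_session_maker() as session:
        yield session

async def get_transactional_session() -> AsyncIterator[AsyncSession]:
    # Wraps the request in one transaction: commit on success, rollback on error.
    # This is why handlers call flush() rather than commit(), as noted in oauth.py above.
    async with async_session_maker() as session:
        async with session.begin():
            yield session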
|
# --- Group Chores Endpoints ---
|
||||||
|
# (These would be similar to what you might have had before, but now explicitly part of this router)
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/groups/{group_id}/chores",
|
||||||
|
response_model=ChorePublic,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Create Group Chore",
|
||||||
|
tags=["Chores", "Group Chores"]
|
||||||
|
)
|
||||||
|
async def create_group_chore(
|
||||||
|
group_id: int,
|
||||||
|
chore_in: ChoreCreate,
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Creates a new chore within a specific group."""
|
||||||
|
logger.info(f"User {current_user.email} creating chore in group {group_id}: {chore_in.name}")
|
||||||
|
if chore_in.type != ChoreTypeEnum.group:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Chore type must be group.")
|
||||||
|
if chore_in.group_id != group_id and chore_in.group_id is not None: # Make sure chore_in.group_id matches path if provided
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Chore's group_id ({chore_in.group_id}) must match path group_id ({group_id}) or be omitted.")
|
||||||
|
|
||||||
|
# Ensure chore_in has the correct group_id and type for the CRUD operation
|
||||||
|
chore_payload = chore_in.model_copy(update={"group_id": group_id, "type": ChoreTypeEnum.group})
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await crud_chore.create_chore(db=db, chore_in=chore_payload, user_id=current_user.id, group_id=group_id)
|
||||||
|
except GroupNotFoundError as e:
|
||||||
|
logger.warning(f"Group {e.group_id} not found for chore creation by user {current_user.email}.")
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
|
||||||
|
except PermissionDeniedError as e:
|
||||||
|
logger.warning(f"Permission denied for user {current_user.email} in group {group_id} for chore creation: {e.detail}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"ValueError creating group chore for user {current_user.email} in group {group_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
except DatabaseIntegrityError as e:
|
||||||
|
logger.error(f"DatabaseIntegrityError creating group chore for {current_user.email} in group {group_id}: {e.detail}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/groups/{group_id}/chores",
|
||||||
|
response_model=PyList[ChorePublic],
|
||||||
|
summary="List Group Chores",
|
||||||
|
tags=["Chores", "Group Chores"]
|
||||||
|
)
|
||||||
|
async def list_group_chores(
|
||||||
|
group_id: int,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Retrieves all chores for a specific group, if the user is a member."""
|
||||||
|
logger.info(f"User {current_user.email} listing chores for group {group_id}")
|
||||||
|
try:
|
||||||
|
return await crud_chore.get_chores_by_group_id(db=db, group_id=group_id, user_id=current_user.id)
|
||||||
|
except PermissionDeniedError as e:
|
||||||
|
logger.warning(f"Permission denied for user {current_user.email} accessing chores for group {group_id}: {e.detail}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"/groups/{group_id}/chores/{chore_id}",
|
||||||
|
response_model=ChorePublic,
|
||||||
|
summary="Update Group Chore",
|
||||||
|
tags=["Chores", "Group Chores"]
|
||||||
|
)
|
||||||
|
async def update_group_chore(
|
||||||
|
group_id: int,
|
||||||
|
chore_id: int,
|
||||||
|
chore_in: ChoreUpdate,
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Updates a chore's details within a specific group."""
|
||||||
|
logger.info(f"User {current_user.email} updating chore ID {chore_id} in group {group_id}")
|
||||||
|
if chore_in.type is not None and chore_in.type != ChoreTypeEnum.group:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change chore type to personal via this endpoint.")
|
||||||
|
if chore_in.group_id is not None and chore_in.group_id != group_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Chore's group_id if provided must match path group_id ({group_id}).")
|
||||||
|
|
||||||
|
# Ensure chore_in has the correct type for the CRUD operation
|
||||||
|
chore_payload = chore_in.model_copy(update={"type": ChoreTypeEnum.group, "group_id": group_id} if chore_in.type is None else {"group_id": group_id})
|
||||||
|
|
||||||
|
try:
|
||||||
|
updated_chore = await crud_chore.update_chore(db=db, chore_id=chore_id, chore_in=chore_payload, user_id=current_user.id, group_id=group_id)
|
||||||
|
if not updated_chore:
|
||||||
|
raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)
|
||||||
|
return updated_chore
|
||||||
|
except ChoreNotFoundError as e:
|
||||||
|
logger.warning(f"Chore {e.chore_id} in group {e.group_id} not found for user {current_user.email} during update.")
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
|
||||||
|
except PermissionDeniedError as e:
|
||||||
|
logger.warning(f"Permission denied for user {current_user.email} updating chore {chore_id} in group {group_id}: {e.detail}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"ValueError updating group chore {chore_id} for user {current_user.email} in group {group_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
except DatabaseIntegrityError as e:
|
||||||
|
logger.error(f"DatabaseIntegrityError updating group chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)
|
||||||
|
|
||||||
|
@router.delete(
    "/groups/{group_id}/chores/{chore_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete Group Chore",
    tags=["Chores", "Group Chores"]
)
async def delete_group_chore(
    group_id: int,
    chore_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Deletes a chore from a group, ensuring the user has permission."""
    logger.info(f"User {current_user.email} deleting chore ID {chore_id} from group {group_id}")
    try:
        # Verify the chore exists and belongs to the group before attempting deletion via CRUD.
        # This gives a more precise error if the chore exists but isn't in this group.
        chore_to_delete = await crud_chore.get_chore_by_id_and_group(db, chore_id, group_id, current_user.id)  # also checks permission
        if not chore_to_delete:  # get_chore_by_id_and_group raises PermissionDeniedError if the user is not a member
            raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)

        success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=group_id)
        if not success:
            # This case should be rare if the check above passes and the DB is consistent
            raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id)
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except ChoreNotFoundError as e:
        logger.warning(f"Chore {e.chore_id} in group {e.group_id} not found for user {current_user.email} during delete.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} deleting chore {chore_id} in group {group_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError deleting group chore {chore_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)


# === CHORE ASSIGNMENT ENDPOINTS ===

@router.post(
    "/assignments",
    response_model=ChoreAssignmentPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Create Chore Assignment",
    tags=["Chore Assignments"]
)
async def create_chore_assignment(
    assignment_in: ChoreAssignmentCreate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Creates a new chore assignment. User must have permission to manage the chore."""
    logger.info(f"User {current_user.email} creating assignment for chore {assignment_in.chore_id}")
    try:
        return await crud_chore.create_chore_assignment(db=db, assignment_in=assignment_in, user_id=current_user.id)
    except ChoreNotFoundError as e:
        logger.warning(f"Chore {e.chore_id} not found for assignment creation by user {current_user.email}.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} creating assignment: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except ValueError as e:
        logger.warning(f"ValueError creating assignment for user {current_user.email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError creating assignment for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

@router.get(
    "/assignments/my",
    response_model=PyList[ChoreAssignmentPublic],
    summary="List My Chore Assignments",
    tags=["Chore Assignments"]
)
async def list_my_assignments(
    include_completed: bool = False,
    db: AsyncSession = Depends(get_session),  # Use a read-only session for GET
    current_user: UserModel = Depends(current_active_user),
):
    """Retrieves all chore assignments for the current user."""
    logger.info(f"User {current_user.email} listing their assignments (include_completed={include_completed})")
    try:
        return await crud_chore.get_user_assignments(db=db, user_id=current_user.id, include_completed=include_completed)
    except Exception as e:
        logger.error(f"Error listing assignments for user {current_user.email}: {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to retrieve assignments")

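# Usage sketch (path prefix is illustrative, not taken from this diff):
#   GET /assignments/my                          -> open assignments only
#   GET /assignments/my?include_completed=true   -> finished ones included
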
@router.get(
    "/chores/{chore_id}/assignments",
    response_model=PyList[ChoreAssignmentPublic],
    summary="List Chore Assignments",
    tags=["Chore Assignments"]
)
async def list_chore_assignments(
    chore_id: int,
    db: AsyncSession = Depends(get_session),  # Use a read-only session for GET
    current_user: UserModel = Depends(current_active_user),
):
    """Retrieves all assignments for a specific chore."""
    logger.info(f"User {current_user.email} listing assignments for chore {chore_id}")
    try:
        return await crud_chore.get_chore_assignments(db=db, chore_id=chore_id, user_id=current_user.id)
    except ChoreNotFoundError as e:
        logger.warning(f"Chore {e.chore_id} not found for assignment listing by user {current_user.email}.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} listing assignments for chore {chore_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)

@router.put(
    "/assignments/{assignment_id}",
    response_model=ChoreAssignmentPublic,
    summary="Update Chore Assignment",
    tags=["Chore Assignments"]
)
async def update_chore_assignment(
    assignment_id: int,
    assignment_in: ChoreAssignmentUpdate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Updates a chore assignment. Only the assignee can mark it complete; managers can reschedule."""
    logger.info(f"User {current_user.email} updating assignment {assignment_id}")
    try:
        updated_assignment = await crud_chore.update_chore_assignment(
            db=db, assignment_id=assignment_id, assignment_in=assignment_in, user_id=current_user.id
        )
        if not updated_assignment:
            raise ChoreNotFoundError(assignment_id=assignment_id)
        return updated_assignment
    except ChoreNotFoundError as e:
        logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during update.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} updating assignment {assignment_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except ValueError as e:
        logger.warning(f"ValueError updating assignment {assignment_id} for user {current_user.email}: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError updating assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

@router.delete(
    "/assignments/{assignment_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete Chore Assignment",
    tags=["Chore Assignments"]
)
async def delete_chore_assignment(
    assignment_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Deletes a chore assignment. User must have permission to manage the chore."""
    logger.info(f"User {current_user.email} deleting assignment {assignment_id}")
    try:
        success = await crud_chore.delete_chore_assignment(db=db, assignment_id=assignment_id, user_id=current_user.id)
        if not success:
            raise ChoreNotFoundError(assignment_id=assignment_id)
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except ChoreNotFoundError as e:
        logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during delete.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} deleting assignment {assignment_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError deleting assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

@router.patch(
    "/assignments/{assignment_id}/complete",
    response_model=ChoreAssignmentPublic,
    summary="Mark Assignment Complete",
    tags=["Chore Assignments"]
)
async def complete_chore_assignment(
    assignment_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """Convenience endpoint to mark an assignment as complete."""
    logger.info(f"User {current_user.email} marking assignment {assignment_id} as complete")
    assignment_update = ChoreAssignmentUpdate(is_complete=True)
    try:
        updated_assignment = await crud_chore.update_chore_assignment(
            db=db, assignment_id=assignment_id, assignment_in=assignment_update, user_id=current_user.id
        )
        if not updated_assignment:
            raise ChoreNotFoundError(assignment_id=assignment_id)
        return updated_assignment
    except ChoreNotFoundError as e:
        logger.warning(f"Assignment {assignment_id} not found for user {current_user.email} during completion.")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=e.detail)
    except PermissionDeniedError as e:
        logger.warning(f"Permission denied for user {current_user.email} completing assignment {assignment_id}: {e.detail}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.detail)
    except DatabaseIntegrityError as e:
        logger.error(f"DatabaseIntegrityError completing assignment {assignment_id} for {current_user.email}: {e.detail}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail)

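# Hedged client-side sketch (httpx; the base URL, router prefix, and token are
# illustrative assumptions, not taken from this diff): the PATCH convenience
# route has the same effect as a PUT carrying {"is_complete": true}.
#
#   import httpx
#   headers = {"Authorization": "Bearer <TOKEN>"}
#   httpx.patch("http://localhost:8000/api/v1/assignments/42/complete", headers=headers)
#   # equivalent long form:
#   httpx.put("http://localhost:8000/api/v1/assignments/42",
#             json={"is_complete": True}, headers=headers)
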
be/app/api/v1/endpoints/costs.py (new file, 423 lines)
@@ -0,0 +1,423 @@
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List

from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import (
    User as UserModel,
    Group as GroupModel,
    List as ListModel,
    Expense as ExpenseModel,
    Item as ItemModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    ExpenseSplit as ExpenseSplitModel,
    Settlement as SettlementModel,
    SettlementActivity as SettlementActivityModel  # Added
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError

logger = logging.getLogger(__name__)
router = APIRouter()

def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
    """
    Calculate suggested settlements to balance the finances within a group.

    This function takes the current balances of all users and greedily matches
    the largest debtor with the largest creditor, which keeps the number of
    transactions small (at most one fewer than the number of users with a
    non-zero balance).

    Args:
        user_balances: List of UserBalanceDetail objects with their current balances

    Returns:
        List of SuggestedSettlement objects representing the suggested payments
    """
    debtors = []    # Users who owe money (negative balance)
    creditors = []  # Users who are owed money (positive balance)

    # Threshold below which a balance is treated as zero (Decimal rounding noise)
    epsilon = Decimal('0.01')

    # Sort users into debtors and creditors
    for user in user_balances:
        # Skip users with a zero (or near-zero) balance
        if abs(user.net_balance) < epsilon:
            continue

        if user.net_balance < Decimal('0'):
            # User owes money
            debtors.append({
                'user_id': user.user_id,
                'user_identifier': user.user_identifier,
                'amount': -user.net_balance  # Convert to a positive amount
            })
        else:
            # User is owed money
            creditors.append({
                'user_id': user.user_id,
                'user_identifier': user.user_identifier,
                'amount': user.net_balance
            })

    # Sort by amount (descending) to handle the largest debts first
    debtors.sort(key=lambda x: x['amount'], reverse=True)
    creditors.sort(key=lambda x: x['amount'], reverse=True)

    settlements = []

    # Match debtors with creditors until one side is exhausted
    while debtors and creditors:
        debtor = debtors[0]
        creditor = creditors[0]

        # The settlement amount is the smaller of the two outstanding amounts
        amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)

        # Create the settlement record
        if amount > Decimal('0'):
            settlements.append(
                SuggestedSettlement(
                    from_user_id=debtor['user_id'],
                    from_user_identifier=debtor['user_identifier'],
                    to_user_id=creditor['user_id'],
                    to_user_identifier=creditor['user_identifier'],
                    amount=amount
                )
            )

        # Update balances
        debtor['amount'] -= amount
        creditor['amount'] -= amount

        # Remove users who have settled their debts/credits
        if debtor['amount'] < epsilon:
            debtors.pop(0)
        if creditor['amount'] < epsilon:
            creditors.pop(0)

    return settlements

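# Illustrative run (schema fields beyond those used above are elided): with
# net balances A = +30.00, B = -10.00, C = -20.00, the greedy matcher first
# pairs C (largest debtor) with A, then B with A:
#   -> [C pays A 20.00, B pays A 10.00]
# Two transfers for three non-zero balances, i.e. n - 1.
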
@router.get(
    "/lists/{list_id}/cost-summary",
    response_model=ListCostSummary,
    summary="Get Cost Summary for a List",
    tags=["Costs"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
        status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
    }
)
async def get_list_cost_summary(
    list_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Retrieves a calculated cost summary for a specific list, detailing total costs,
    equal shares per user, and individual user balances based on their contributions.

    The user must have access to the list to view its cost summary.
    Costs are split among group members if the list belongs to a group, or just for
    the creator if it's a personal list. All users who added items with prices are
    included in the calculation.
    """
    logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")

    # 1. Verify the user has access to the target list
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
        raise
    except ListNotFoundError as e:
        logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
        raise

    # 2. Get the list with its items and users
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
            selectinload(ListModel.creator)
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    # 3. Get or create an expense for this list
    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits))
    )
    db_expense = expense_result.scalars().first()

    if not db_expense:
        # Create a new expense for this list. Seed the sum with a Decimal so an
        # empty item list yields Decimal("0") rather than int 0.
        total_amount = sum((item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")), Decimal("0"))
        if total_amount == Decimal("0"):
            return ListCostSummary(
                list_id=db_list.id,
                list_name=db_list.name,
                total_list_cost=Decimal("0.00"),
                num_participating_users=0,
                equal_share_per_user=Decimal("0.00"),
                user_balances=[]
            )

        # Create an expense with the ITEM_BASED split type
        expense_in = ExpenseCreate(
            description=f"Cost summary for list {db_list.name}",
            total_amount=total_amount,
            list_id=list_id,
            split_type=SplitTypeEnum.ITEM_BASED,
            paid_by_user_id=db_list.creator.id
        )
        db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in, current_user_id=current_user.id)

    # 4. Calculate the cost summary from expense splits
    participating_users = set()
    user_items_added_value = {}
    total_list_cost = Decimal("0.00")

    # Get all users who added items
    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0") and item.added_by_user:
            participating_users.add(item.added_by_user)
            user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
            total_list_cost += item.price

    # Get all users from expense splits
    for split in db_expense.splits:
        if split.user:
            participating_users.add(split.user)

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=Decimal("0.00"),
            num_participating_users=0,
            equal_share_per_user=Decimal("0.00"),
            user_balances=[]
        )

    # This is the ideal equal share, returned in the summary
    equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

    # Sort users for deterministic remainder distribution
    sorted_participating_users = sorted(list(participating_users), key=lambda u: u.id)

    user_final_shares = {}
    if num_participating_users > 0:
        base_share_unrounded = total_list_cost / Decimal(num_participating_users)

        # Calculate the initial share for each user, rounding down
        for user in sorted_participating_users:
            user_final_shares[user.id] = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)

        # Sum of the rounded-down shares
        sum_of_rounded_shares = sum(user_final_shares.values())

        # Remaining pennies still to be distributed
        remaining_pennies = int(((total_list_cost - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP))

        # Distribute the remaining pennies one by one to the sorted users
        for i in range(remaining_pennies):
            user_to_adjust = sorted_participating_users[i % num_participating_users]
            user_final_shares[user_to_adjust.id] += Decimal("0.01")

    user_balances = []
    for user in sorted_participating_users:  # Iterate over sorted users
        items_added = user_items_added_value.get(user.id, Decimal("0.00"))
        # current_user_share is the precisely calculated share for this user
        current_user_share = user_final_shares.get(user.id, Decimal("0.00"))

        balance = items_added - current_user_share
        user_identifier = user.name if user.name else user.email
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user_identifier,
                items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
            )
        )

    user_balances.sort(key=lambda x: x.user_identifier)
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user_for_response,  # The ideal share for the response field
        user_balances=user_balances
    )

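# Minimal self-contained sketch of the remainder-distribution step above:
# round every share down, then hand leftover pennies to users in id order so
# the shares always sum exactly to the total (the helper name is illustrative).
#
#   from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP
#
#   def split_evenly(total: Decimal, user_ids: list[int]) -> dict[int, Decimal]:
#       ids = sorted(user_ids)
#       base = (total / len(ids)).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
#       shares = {uid: base for uid in ids}
#       leftover = int(((total - sum(shares.values())) * 100).to_integral_value(rounding=ROUND_HALF_UP))
#       for i in range(leftover):
#           shares[ids[i % len(ids)]] += Decimal("0.01")
#       return shares
#
#   split_evenly(Decimal("100.00"), [1, 2, 3])
#   # -> {1: Decimal("33.34"), 2: Decimal("33.33"), 3: Decimal("33.33")}
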
@router.get(
    "/groups/{group_id}/balance-summary",
    response_model=GroupBalanceSummary,
    summary="Get Detailed Balance Summary for a Group",
    tags=["Costs", "Groups"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
        status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
    }
)
async def get_group_balance_summary(
    group_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Retrieves a detailed financial balance summary for all users within a specific group.
    It considers all expenses, their splits, and all settlements recorded for the group.
    The user must be a member of the group to view its balance summary.
    """
    logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")

    # 1. Verify the user is a member of the target group
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations))
        .where(GroupModel.id == group_id)
    )
    db_group_for_check = group_check.scalars().first()

    if not db_group_for_check:
        raise GroupNotFoundError(group_id)

    user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
    if not user_is_member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")

    # 2. Get all expenses and settlements for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()

    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(
            selectinload(SettlementModel.paid_by_user),
            selectinload(SettlementModel.paid_to_user)
        )
    )
    settlements = settlements_result.scalars().all()

    # Fetch SettlementActivities related to the group's expenses.
    # This requires joining SettlementActivity -> ExpenseSplit -> Expense.
    settlement_activities_result = await db.execute(
        select(SettlementActivityModel)
        .join(ExpenseSplitModel, SettlementActivityModel.expense_split_id == ExpenseSplitModel.id)
        .join(ExpenseModel, ExpenseSplitModel.expense_id == ExpenseModel.id)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(SettlementActivityModel.payer))  # Optional: if payer details are needed directly
    )
    settlement_activities = settlement_activities_result.scalars().all()

    # 3. Calculate user balances
    user_balances_data = {}
    # Initialize a balance record for each group member
    for assoc in db_group_for_check.member_associations:
        if assoc.user:
            user_balances_data[assoc.user.id] = {
                "user_id": assoc.user.id,
                "user_identifier": assoc.user.name if assoc.user.name else assoc.user.email,
                "total_paid_for_expenses": Decimal("0.00"),
                "initial_total_share_of_expenses": Decimal("0.00"),
                "total_amount_paid_via_settlement_activities": Decimal("0.00"),
                "total_generic_settlements_paid": Decimal("0.00"),
                "total_generic_settlements_received": Decimal("0.00"),
            }

    # Process expenses
    for expense in expenses:
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id]["total_paid_for_expenses"] += expense.total_amount

        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id]["initial_total_share_of_expenses"] += split.owed_amount

    # Process settlement activities (SettlementActivityModel)
    for activity in settlement_activities:
        if activity.paid_by_user_id in user_balances_data:
            user_balances_data[activity.paid_by_user_id]["total_amount_paid_via_settlement_activities"] += activity.amount_paid

    # Process generic settlements (SettlementModel)
    for settlement in settlements:
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id]["total_generic_settlements_paid"] += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id]["total_generic_settlements_received"] += settlement.amount

    # Calculate final balances
    final_user_balances = []
    for user_id, data in user_balances_data.items():
        initial_total_share_of_expenses = data["initial_total_share_of_expenses"]
        total_amount_paid_via_settlement_activities = data["total_amount_paid_via_settlement_activities"]

        adjusted_total_share_of_expenses = initial_total_share_of_expenses - total_amount_paid_via_settlement_activities

        total_paid_for_expenses = data["total_paid_for_expenses"]
        total_generic_settlements_received = data["total_generic_settlements_received"]
        total_generic_settlements_paid = data["total_generic_settlements_paid"]

        net_balance = (
            total_paid_for_expenses + total_generic_settlements_received
        ) - (adjusted_total_share_of_expenses + total_generic_settlements_paid)

        # Quantize all final values for the UserBalanceDetail schema
        user_detail = UserBalanceDetail(
            user_id=data["user_id"],
            user_identifier=data["user_identifier"],
            total_paid_for_expenses=total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
            # The adjusted share (net of settlement activities) is reported as total_share_of_expenses
            total_share_of_expenses=adjusted_total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
            total_settlements_paid=total_generic_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
            total_settlements_received=total_generic_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
            net_balance=net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        )
        final_user_balances.append(user_detail)

    # Sort by user identifier
    final_user_balances.sort(key=lambda x: x.user_identifier)

    # Calculate suggested settlements
    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    # Calculate overall totals for the group. Seed the sums with a Decimal so an
    # empty group (no expenses/settlements) doesn't yield int 0, which has no .quantize().
    overall_total_expenses = sum((expense.total_amount for expense in expenses), Decimal("0.00"))
    overall_total_settlements = sum((settlement.amount for settlement in settlements), Decimal("0.00"))

    return GroupBalanceSummary(
        group_id=db_group_for_check.id,
        group_name=db_group_for_check.name,
        overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )

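# Worked example of the net-balance formula above: a member who paid 50.00 of
# group expenses, owes a 20.00 share of which 5.00 was already settled via a
# settlement activity, and who received 10.00 in generic settlements:
#   net = (50.00 + 10.00) - ((20.00 - 5.00) + 0.00) = 45.00
# A positive net balance means the group owes the member money.
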
be/app/api/v1/endpoints/financials.py (new file, 658 lines)
@@ -0,0 +1,658 @@
# app/api/v1/endpoints/financials.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from typing import List as PyList, Optional, Sequence

from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import (
    User as UserModel,
    Group as GroupModel,
    List as ListModel,
    UserGroup as UserGroupModel,
    UserRoleEnum,
    ExpenseSplit as ExpenseSplitModel
)
from app.schemas.expense import (
    ExpenseCreate, ExpensePublic,
    SettlementCreate, SettlementPublic,
    ExpenseUpdate, SettlementUpdate
)
from app.schemas.settlement_activity import SettlementActivityCreate, SettlementActivityPublic  # Added
from app.crud import expense as crud_expense
from app.crud import settlement as crud_settlement
from app.crud import settlement_activity as crud_settlement_activity  # Added
from app.crud import group as crud_group
from app.crud import list as crud_list
from app.core.exceptions import (
    ListNotFoundError, GroupNotFoundError, UserNotFoundError,
    InvalidOperationError, GroupPermissionError, ListPermissionError,
    ItemNotFoundError, GroupMembershipError
)

logger = logging.getLogger(__name__)
router = APIRouter()

# --- Helper for permissions ---
async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"):
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_creator=False)
    except ListPermissionError as e:
        logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}")
        raise ListPermissionError(list_id, action=action)
    except ListNotFoundError:
        raise


# --- Expense Endpoints ---
@router.post(
    "/expenses",
    response_model=ExpensePublic,
    status_code=status.HTTP_201_CREATED,
    summary="Create New Expense",
    tags=["Expenses"]
)
async def create_new_expense(
    expense_in: ExpenseCreate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
    effective_group_id = expense_in.group_id
    is_group_context = False

    if expense_in.list_id:
        # Check basic access to the list (implies membership if the list is in a group)
        await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for")
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)
        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.")
            effective_group_id = list_obj.group_id
            is_group_context = True  # Expense is tied to a group via the list
        elif expense_in.group_id:
            raise InvalidOperationError(f"Personal list {list_obj.id} cannot have an expense associated with group {expense_in.group_id}.")
        # If the list is personal, no group check is needed yet; the payer check below handles it.

    elif effective_group_id:  # Only group_id provided for the expense
        is_group_context = True
        # Ensure the user is at least a member to create an expense in a group context
        await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for")
    else:
        # Neither list_id nor group_id was provided.
        raise InvalidOperationError("Expense must be linked to a list_id or group_id.")

    # Finalize the expense payload with the derived group_id
    expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id})

    # --- Granular Permission Check for the Payer ---
    if expense_in_final.paid_by_user_id != current_user.id:
        logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}")
        # When creating an expense paid by someone else, the user MUST be a group owner (if in a group context)
        if is_group_context and effective_group_id:
            try:
                await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user")
            except GroupPermissionError as e:
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}")
        else:
            # Cannot create an expense paid by someone else for a personal list (no group context)
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.")
    # If paying for self, the basic list/group membership check above is sufficient.

    try:
        created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
        logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
        return created_expense
    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except NotImplementedError as e:
        raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

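# Summary of the group-resolution rules enforced above:
#   list_id set, list in group G, group_id unset  -> expense lands in G (group context)
#   list_id set, list in group G, group_id == G   -> G
#   list_id set, list in group G, group_id != G   -> InvalidOperationError
#   list_id set, personal list,   group_id set    -> InvalidOperationError
#   list_id set, personal list,   group_id unset  -> personal expense
#   list_id unset, group_id set                   -> group context (membership required)
#   neither set                                   -> InvalidOperationError
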
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
|
||||||
|
async def get_expense(
|
||||||
|
expense_id: int,
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
|
||||||
|
expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||||
|
if not expense:
|
||||||
|
raise ItemNotFoundError(item_id=expense_id)
|
||||||
|
|
||||||
|
if expense.list_id:
|
||||||
|
await check_list_access_for_financials(db, expense.list_id, current_user.id)
|
||||||
|
elif expense.group_id:
|
||||||
|
await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
|
||||||
|
elif expense.paid_by_user_id != current_user.id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
|
||||||
|
return expense
|
||||||
|
|
||||||
|
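# Access rule applied above, in priority order: list-linked expenses require
# list access; otherwise group-linked expenses require group membership;
# otherwise (a purely personal expense) only the payer may view it.
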
@router.get("/expenses", response_model=PyList[ExpensePublic], summary="List Expenses", tags=["Expenses"])
|
||||||
|
async def list_expenses(
|
||||||
|
list_id: Optional[int] = Query(None, description="Filter by list ID"),
|
||||||
|
group_id: Optional[int] = Query(None, description="Filter by group ID"),
|
||||||
|
isRecurring: Optional[bool] = Query(None, description="Filter by recurring expenses"),
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=200),
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List expenses with optional filters.
|
||||||
|
If list_id is provided, returns expenses for that list (user must have list access).
|
||||||
|
If group_id is provided, returns expenses for that group (user must be group member).
|
||||||
|
If both are provided, returns expenses for the list (list_id takes precedence).
|
||||||
|
If neither is provided, returns all expenses the user has access to.
|
||||||
|
"""
|
||||||
|
logger.info(f"User {current_user.email} listing expenses with filters: list_id={list_id}, group_id={group_id}, isRecurring={isRecurring}")
|
||||||
|
|
||||||
|
if list_id:
|
||||||
|
# Use existing list expenses endpoint logic
|
||||||
|
await check_list_access_for_financials(db, list_id, current_user.id)
|
||||||
|
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
|
||||||
|
elif group_id:
|
||||||
|
# Use existing group expenses endpoint logic
|
||||||
|
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
|
||||||
|
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
|
||||||
|
else:
|
||||||
|
# Get all expenses the user has access to (user's personal expenses + group expenses + list expenses)
|
||||||
|
expenses = await crud_expense.get_user_accessible_expenses(db, user_id=current_user.id, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
# Apply recurring filter if specified
|
||||||
|
if isRecurring is not None:
|
||||||
|
expenses = [expense for expense in expenses if bool(expense.recurrence_rule) == isRecurring]
|
||||||
|
|
||||||
|
return expenses
|
||||||
|
|
||||||
|
@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
|
||||||
|
async def list_group_expenses(
|
||||||
|
group_id: int,
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=200),
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
|
||||||
|
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
|
||||||
|
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
|
||||||
|
return expenses
|
||||||
|
|
||||||
|
@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
|
||||||
|
async def update_expense_details(
|
||||||
|
expense_id: int,
|
||||||
|
expense_in: ExpenseUpdate,
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates an existing expense (description, currency, expense_date only).
|
||||||
|
Requires the current version number for optimistic locking.
|
||||||
|
User must have permission to modify the expense (e.g., be the payer or group admin).
|
||||||
|
"""
|
||||||
|
logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
|
||||||
|
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||||
|
if not expense_db:
|
||||||
|
raise ItemNotFoundError(item_id=expense_id)
|
||||||
|
|
||||||
|
# --- Granular Permission Check ---
|
||||||
|
can_modify = False
|
||||||
|
# 1. User paid for the expense
|
||||||
|
if expense_db.paid_by_user_id == current_user.id:
|
||||||
|
can_modify = True
|
||||||
|
# 2. OR User is owner of the group the expense belongs to
|
||||||
|
elif expense_db.group_id:
|
||||||
|
try:
|
||||||
|
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
|
||||||
|
can_modify = True
|
||||||
|
logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
|
||||||
|
except GroupMembershipError: # User not even a member
|
||||||
|
pass # Keep can_modify as False
|
||||||
|
except GroupPermissionError: # User is member but not owner
|
||||||
|
pass # Keep can_modify as False
|
||||||
|
except GroupNotFoundError: # Group doesn't exist (data integrity issue)
|
||||||
|
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
|
||||||
|
pass # Keep can_modify as False
|
||||||
|
# Note: If expense is only linked to a personal list (no group), only payer can modify.
|
||||||
|
|
||||||
|
if not can_modify:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
|
||||||
|
logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
|
||||||
|
return updated_expense
|
||||||
|
except InvalidOperationError as e:
|
||||||
|
# Check if it's a version conflict (409) or other validation error (400)
|
||||||
|
status_code = status.HTTP_400_BAD_REQUEST
|
||||||
|
if "version" in str(e).lower():
|
||||||
|
status_code = status.HTTP_409_CONFLICT
|
||||||
|
raise HTTPException(status_code=status_code, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||||
|
|
||||||
|
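# Hedged client-side sketch of the optimistic-locking contract above (assumes
# ExpensePublic exposes `version`; the client object, paths, and helper name
# are illustrative):
#
#   def update_with_retry(client, expense_id, changes, retries=3):
#       for _ in range(retries):
#           current = client.get(f"/expenses/{expense_id}").json()
#           resp = client.put(f"/expenses/{expense_id}",
#                             json={**changes, "version": current["version"]})
#           if resp.status_code != 409:  # 409 means the version went stale; re-read and retry
#               return resp
#       raise RuntimeError("too many version conflicts")
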
@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
|
||||||
|
async def delete_expense_record(
|
||||||
|
expense_id: int,
|
||||||
|
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Deletes an expense and its associated splits.
|
||||||
|
Requires expected_version query parameter for optimistic locking.
|
||||||
|
User must have permission to delete the expense (e.g., be the payer or group admin).
|
||||||
|
"""
|
||||||
|
logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
|
||||||
|
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||||
|
if not expense_db:
|
||||||
|
# Return 204 even if not found, as the end state is achieved (item is gone)
|
||||||
|
logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
# Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404
|
||||||
|
|
||||||
|
# --- Granular Permission Check ---
|
||||||
|
can_delete = False
|
||||||
|
# 1. User paid for the expense
|
||||||
|
if expense_db.paid_by_user_id == current_user.id:
|
||||||
|
can_delete = True
|
||||||
|
# 2. OR User is owner of the group the expense belongs to
|
||||||
|
elif expense_db.group_id:
|
||||||
|
try:
|
||||||
|
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
|
||||||
|
can_delete = True
|
||||||
|
logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
|
||||||
|
except GroupMembershipError:
|
||||||
|
pass
|
||||||
|
except GroupPermissionError:
|
||||||
|
pass
|
||||||
|
except GroupNotFoundError:
|
||||||
|
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not can_delete:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
|
||||||
|
logger.info(f"Expense ID {expense_id} deleted successfully.")
|
||||||
|
# No need to return content on 204
|
||||||
|
except InvalidOperationError as e:
|
||||||
|
# Check if it's a version conflict (409) or other validation error (400)
|
||||||
|
status_code = status.HTTP_400_BAD_REQUEST
|
||||||
|
if "version" in str(e).lower():
|
||||||
|
status_code = status.HTTP_409_CONFLICT
|
||||||
|
raise HTTPException(status_code=status_code, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||||
|
|
||||||
|
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
||||||
|
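# e.g. DELETE /expenses/17?expected_version=3 -> 204 on success,
# or 409 if another client already bumped the expense past version 3.
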
# --- Settlement Activity Endpoints (for ExpenseSplits) ---
@router.post(
    "/expense_splits/{expense_split_id}/settle",
    response_model=SettlementActivityPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Record a Settlement Activity for an Expense Split",
    tags=["Expenses", "Settlements"]
)
async def record_settlement_for_expense_split(
    expense_split_id: int,
    activity_in: SettlementActivityCreate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    logger.info(f"User {current_user.email} attempting to record settlement for expense_split_id {expense_split_id} with amount {activity_in.amount_paid}")

    if activity_in.expense_split_id != expense_split_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Expense split ID in path does not match expense split ID in request body."
        )

    # Fetch the ExpenseSplit and its parent Expense to check context (group/list)
    stmt = (
        select(ExpenseSplitModel)
        .options(joinedload(ExpenseSplitModel.expense))  # Load the parent expense
        .where(ExpenseSplitModel.id == expense_split_id)
    )
    result = await db.execute(stmt)
    expense_split = result.scalar_one_or_none()

    if not expense_split:
        raise ItemNotFoundError(item_id=expense_split_id, detail_suffix="Expense split not found.")

    parent_expense = expense_split.expense
    if not parent_expense:
        # Should not happen if data integrity is maintained
        logger.error(f"Data integrity issue: ExpenseSplit {expense_split_id} has no parent Expense.")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Associated expense not found for this split.")

    # --- Permission Checks ---
    # The acting user (current_user) must be either:
    #   1. the person making the payment (activity_in.paid_by_user_id), or
    #   2. an owner of the group, if the expense is tied to a group.
    # In addition, the payment should be made by the user who owes the split
    # (expense_split.user_id). We first check whether current_user is settling
    # their own split; the group-owner path (settling on someone's behalf) is
    # handled next.

    can_record_settlement = False
    if current_user.id == activity_in.paid_by_user_id:
        # The user is recording their own payment. This is allowed outright only
        # if they are the one who owes this split; otherwise the group-owner
        # check below must authorize it.
        if activity_in.paid_by_user_id != expense_split.user_id:
            pass  # Deferred to the group-owner check
        else:
            can_record_settlement = True  # The user is settling their own owed split
            logger.info(f"User {current_user.email} is settling their own expense split {expense_split_id}.")

    if not can_record_settlement and parent_expense.group_id:
        try:
            # Check whether current_user is an owner of the group associated with the expense
            await crud_group.check_user_role_in_group(
                db,
                group_id=parent_expense.group_id,
                user_id=current_user.id,
                required_role=UserRoleEnum.owner,
                action="record settlement activities for group members"
            )
            can_record_settlement = True
            logger.info(f"Group owner {current_user.email} is recording settlement for expense split {expense_split_id} in group {parent_expense.group_id}.")
        except (GroupPermissionError, GroupMembershipError, GroupNotFoundError):
            # Not a group owner and not settling their own split: permission denied.
            pass  # can_record_settlement remains False

    if not can_record_settlement:
        logger.warning(f"User {current_user.email} does not have permission to record settlement for expense split {expense_split_id}.")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You do not have permission to record this settlement activity. Must be the payer or a group owner."
        )

    # Final consistency check: the `paid_by_user_id` in the activity should match
    # the `user_id` of the ExpenseSplit (the person who owes). The permissions
    # above decide who may *initiate* the action; this check keeps the data
    # consistent. A group owner may record on behalf of the split owner, in which
    # case activity_in.paid_by_user_id must still be the split owner's ID while
    # the CRUD layer records current_user.id as created_by_user_id.
    if activity_in.paid_by_user_id != expense_split.user_id:
        logger.warning(f"Attempt to record settlement for expense split {expense_split_id} where activity payer ({activity_in.paid_by_user_id}) "
                       f"does not match split owner ({expense_split.user_id}). Only allowed if current_user is a group owner acting on behalf of the split owner.")
        if current_user.id != expense_split.user_id and not (parent_expense.group_id and await crud_group.is_user_role_in_group(db, group_id=parent_expense.group_id, user_id=current_user.id, role=UserRoleEnum.owner)):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"The payer ID ({activity_in.paid_by_user_id}) in the settlement activity must match the user ID of the expense split owner ({expense_split.user_id}), unless you are a group owner acting on their behalf."
            )

    try:
        created_activity = await crud_settlement_activity.create_settlement_activity(
            db=db,
            settlement_activity_in=activity_in,
            current_user_id=current_user.id
        )
        logger.info(f"Settlement activity {created_activity.id} recorded for expense split {expense_split_id} by user {current_user.email}")
        return created_activity
    except UserNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User referenced in settlement activity not found: {str(e)}")
    except InvalidOperationError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error recording settlement activity for expense_split_id {expense_split_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while recording settlement activity.")

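# Permission matrix enforced above (caller = current_user, payer = activity_in.paid_by_user_id):
#   caller == payer == split owner              -> allowed (settling own debt)
#   caller == payer != split owner              -> only if caller is a group owner
#   caller is group owner, payer == split owner -> allowed (recorded on the owner's behalf)
#   anyone else                                 -> 403
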
@router.get(
    "/expense_splits/{expense_split_id}/settlement_activities",
    response_model=PyList[SettlementActivityPublic],
    summary="List Settlement Activities for an Expense Split",
    tags=["Expenses", "Settlements"]
)
async def list_settlement_activities_for_split(
    expense_split_id: int,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=200),
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    logger.info(f"User {current_user.email} listing settlement activities for expense_split_id {expense_split_id}")

    # Fetch the ExpenseSplit and its parent Expense to check context (group/list) for permissions
    stmt = (
        select(ExpenseSplitModel)
        .options(joinedload(ExpenseSplitModel.expense))  # Load parent expense
        .where(ExpenseSplitModel.id == expense_split_id)
    )
    result = await db.execute(stmt)
    expense_split = result.scalar_one_or_none()

    if not expense_split:
        raise ItemNotFoundError(item_id=expense_split_id, detail_suffix="Expense split not found.")

    parent_expense = expense_split.expense
    if not parent_expense:
        logger.error(f"Data integrity issue: ExpenseSplit {expense_split_id} has no parent Expense.")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Associated expense not found for this split.")

    # --- Permission Check (similar to viewing an expense) ---
    # User must have access to the parent expense.
    can_view_activities = False
    if parent_expense.list_id:
        try:
            await check_list_access_for_financials(db, parent_expense.list_id, current_user.id, action="view settlement activities for list expense")
            can_view_activities = True
        except (ListPermissionError, ListNotFoundError):
            pass  # Keep can_view_activities False
    elif parent_expense.group_id:
        try:
            await crud_group.check_group_membership(db, group_id=parent_expense.group_id, user_id=current_user.id, action="view settlement activities for group expense")
            can_view_activities = True
        except (GroupMembershipError, GroupNotFoundError):
            pass  # Keep can_view_activities False
    elif parent_expense.paid_by_user_id == current_user.id or expense_split.user_id == current_user.id:
        # If the expense is not tied to a list or group (e.g. an item-based personal expense),
        # allow access if the current user paid the expense OR owes this specific split.
        can_view_activities = True

    if not can_view_activities:
        logger.warning(f"User {current_user.email} does not have permission to view settlement activities for expense split {expense_split_id}.")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You do not have permission to view settlement activities for this expense split."
        )

    activities = await crud_settlement_activity.get_settlement_activities_for_split(
        db=db, expense_split_id=expense_split_id, skip=skip, limit=limit
    )
    return activities

# --- Settlement Endpoints ---
@router.post(
    "/settlements",
    response_model=SettlementPublic,
    status_code=status.HTTP_201_CREATED,
    summary="Record New Settlement",
    tags=["Settlements"]
)
async def create_new_settlement(
    settlement_in: SettlementCreate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
    await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
    try:
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
        await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
    except GroupMembershipError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
    except GroupNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))

    try:
        created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
        logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
        return created_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

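# --- Illustrative request payload sketch for POST /settlements. Only group_id,
# paid_by_user_id, and paid_to_user_id are confirmed by the handler above; the
# remaining fields are assumptions about SettlementCreate. ---
example_settlement_payload = {
    "group_id": 1,
    "paid_by_user_id": 2,   # must be a member of group 1
    "paid_to_user_id": 3,   # must be a member of group 1
    "amount": "25.50",      # assumed field; monetary values are often sent as strings
    "description": "Paying back my share of groceries",  # assumed field
}
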
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
|
||||||
|
async def get_settlement(
|
||||||
|
settlement_id: int,
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
|
||||||
|
settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
|
||||||
|
if not settlement:
|
||||||
|
raise ItemNotFoundError(item_id=settlement_id)
|
||||||
|
|
||||||
|
is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
|
||||||
|
try:
|
||||||
|
await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
|
||||||
|
except GroupMembershipError:
|
||||||
|
if not is_party_to_settlement:
|
||||||
|
raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
|
||||||
|
logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
|
||||||
|
return settlement
|
||||||
|
|
||||||
|
@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
|
||||||
|
async def list_group_settlements(
|
||||||
|
group_id: int,
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=200),
|
||||||
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
|
current_user: UserModel = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
|
||||||
|
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
|
||||||
|
settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
|
||||||
|
return settlements
|
||||||
|
|
||||||
|
@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
async def update_settlement_details(
    settlement_id: int,
    settlement_in: SettlementUpdate,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Updates an existing settlement (description and settlement_date only).
    Requires the current version number for optimistic locking.
    The user must be an involved party (payer or payee) or the group owner.
    """
    logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
    settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
    if not settlement_db:
        raise ItemNotFoundError(item_id=settlement_id)

    # --- Granular Permission Check ---
    can_modify = False
    # 1. User is an involved party (payer or payee)
    is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
    if is_party:
        can_modify = True
    # 2. OR user is owner of the group the settlement belongs to
    # Note: settlements always have a group_id in the current model
    elif settlement_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
            can_modify = True
            logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
        except (GroupMembershipError, GroupPermissionError):
            pass
        except GroupNotFoundError:
            logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")

    if not can_modify:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")

    try:
        updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
        logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
        return updated_settlement
    except InvalidOperationError as e:
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

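# --- Sketch of the optimistic-locking check assumed to live inside
# crud_settlement.update_settlement (the CRUD module is not shown here). Raising
# InvalidOperationError with "version" in the message is what makes the endpoint
# above map the failure to 409 Conflict. ---
def _check_settlement_version(current_version: int, submitted_version: int) -> None:
    """Reject a write when the client's copy of the settlement is stale."""
    if current_version != submitted_version:
        raise InvalidOperationError(
            f"Settlement version mismatch. Submitted version {submitted_version}, "
            f"but current version is {current_version}. Please refresh and retry."
        )
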
@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
async def delete_settlement_record(
    settlement_id: int,
    expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Deletes a settlement.
    Requires the expected_version query parameter for optimistic locking.
    The user must be an involved party (payer or payee) or the group owner.
    """
    logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
    settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
    if not settlement_db:
        # Deleting a non-existent settlement is treated as a no-op.
        logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
        return Response(status_code=status.HTTP_204_NO_CONTENT)

    # --- Granular Permission Check ---
    can_delete = False
    # 1. User is an involved party (payer or payee)
    is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
    if is_party:
        can_delete = True
    # 2. OR user is owner of the group the settlement belongs to
    elif settlement_db.group_id:
        try:
            await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
            can_delete = True
            logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
        except (GroupMembershipError, GroupPermissionError):
            pass
        except GroupNotFoundError:
            logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")

    if not can_delete:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")

    try:
        await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
        logger.info(f"Settlement ID {settlement_id} deleted successfully.")
    except InvalidOperationError as e:
        status_code = status.HTTP_400_BAD_REQUEST
        if "version" in str(e).lower():
            status_code = status.HTTP_409_CONFLICT
        raise HTTPException(status_code=status_code, detail=str(e))
    except Exception as e:
        logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")

    return Response(status_code=status.HTTP_204_NO_CONTENT)
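# --- Illustrative client sketch for the optimistic-locking delete above. The base
# URL and auth scheme are assumptions. On 409 the caller should refetch the
# settlement and retry with the fresh version. ---
import httpx

async def delete_settlement_example(token: str, settlement_id: int, version: int) -> bool:
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1") as client:
        resp = await client.delete(
            f"/settlements/{settlement_id}",
            params={"expected_version": version},
            headers={"Authorization": f"Bearer {token}"},
        )
        if resp.status_code == 409:
            return False  # stale version; refetch and retry
        resp.raise_for_status()
        return True
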
@@ -5,14 +5,24 @@ from typing import List
 from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.database import get_db
-from app.api.dependencies import get_current_user
+from app.database import get_transactional_session, get_session
+from app.auth import current_active_user
 from app.models import User as UserModel, UserRoleEnum  # Import model and enum
 from app.schemas.group import GroupCreate, GroupPublic
 from app.schemas.invite import InviteCodePublic
 from app.schemas.message import Message  # For simple responses
+from app.schemas.list import ListPublic, ListDetail
 from app.crud import group as crud_group
 from app.crud import invite as crud_invite
+from app.crud import list as crud_list
+from app.core.exceptions import (
+    GroupNotFoundError,
+    GroupPermissionError,
+    GroupMembershipError,
+    GroupOperationError,
+    GroupValidationError,
+    InviteCreationError
+)
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -26,8 +36,8 @@ router = APIRouter()
 )
 async def create_group(
     group_in: GroupCreate,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """Creates a new group, adding the creator as the owner."""
     logger.info(f"User {current_user.email} creating group: {group_in.name}")
@@ -44,8 +54,8 @@ async def create_group(
     tags=["Groups"]
 )
 async def read_user_groups(
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_session),  # Use read-only session for GET
+    current_user: UserModel = Depends(current_active_user),
 ):
     """Retrieves all groups the current user is a member of."""
     logger.info(f"Fetching groups for user: {current_user.email}")
@@ -61,8 +71,8 @@ async def read_user_groups(
 )
 async def read_group(
     group_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_session),  # Use read-only session for GET
+    current_user: UserModel = Depends(current_active_user),
 ):
     """Retrieves details for a specific group, including members, if the user is part of it."""
     logger.info(f"User {current_user.email} requesting details for group ID: {group_id}")
@@ -70,18 +80,14 @@ async def read_group(
     is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
     if not is_member:
         logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}")
-        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a member of this group")
+        raise GroupMembershipError(group_id, "view group details")
 
     group = await crud_group.get_group_by_id(db=db, group_id=group_id)
     if not group:
         logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)")
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
+        raise GroupNotFoundError(group_id)
 
-    # Manually construct the members list with UserPublic schema if needed
-    # Pydantic v2's from_attributes should handle this if relationships are loaded
-    # members_public = [UserPublic.model_validate(assoc.user) for assoc in group.member_associations]
-    # return GroupPublic.model_validate(group, update={"members": members_public})
-    return group  # Rely on Pydantic conversion and eager loading
+    return group
 
 
 @router.post(
@@ -92,8 +98,8 @@ async def read_group(
 )
 async def create_group_invite(
     group_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """Generates a new invite code for the group. Requires owner/admin role (MVP: owner only)."""
     logger.info(f"User {current_user.email} attempting to create invite for group {group_id}")
@@ -102,21 +108,59 @@ async def create_group_invite(
     # --- Permission Check (MVP: Owner only) ---
     if user_role != UserRoleEnum.owner:
         logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}")
-        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only group owners can create invites")
+        raise GroupPermissionError(group_id, "create invites")
 
     # Check if group exists (implicitly done by role check, but good practice)
     group = await crud_group.get_group_by_id(db, group_id)
     if not group:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Group not found")
+        raise GroupNotFoundError(group_id)
 
     invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id)
     if not invite:
         logger.error(f"Failed to generate unique invite code for group {group_id}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not generate invite code")
+        # This case should ideally be covered by exceptions from create_invite now
+        raise InviteCreationError(group_id)
 
-    logger.info(f"Invite code created for group {group_id} by user {current_user.email}")
+    logger.info(f"User {current_user.email} created invite code for group {group_id}")
     return invite
 
 
+@router.get(
+    "/{group_id}/invites",
+    response_model=InviteCodePublic,  # Or Optional[InviteCodePublic] if it can be null
+    summary="Get Group Active Invite Code",
+    tags=["Groups", "Invites"]
+)
+async def get_group_active_invite(
+    group_id: int,
+    db: AsyncSession = Depends(get_session),  # Use read-only session for GET
+    current_user: UserModel = Depends(current_active_user),
+):
+    """Retrieves the active invite code for the group. Requires group membership (owner/admin to be stricter later if needed)."""
+    logger.info(f"User {current_user.email} attempting to get active invite for group {group_id}")
+
+    # Permission check: ensure the user is a member of the group before exposing the code.
+    # get_user_role_in_group also checks membership indirectly.
+    user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)
+    if user_role is None:  # Not a member
+        logger.warning(f"Permission denied: User {current_user.email} is not a member of group {group_id} and cannot view invite code.")
+        raise GroupMembershipError(group_id, "view invite code for this group (not a member)")
+
+    # Fetch the active invite for the group
+    invite = await crud_invite.get_active_invite_for_group(db, group_id=group_id)
+
+    if not invite:
+        # No active (non-expired, active=true) invite exists; the frontend can prompt to generate one.
+        logger.info(f"No active invite code found for group {group_id} when requested by {current_user.email}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="No active invite code found for this group. Please generate one."
+        )
+
+    logger.info(f"User {current_user.email} retrieved active invite code for group {group_id}")
+    return invite  # Pydantic will convert InviteModel to InviteCodePublic
+
+
 @router.delete(
     "/{group_id}/leave",
     response_model=Message,
@@ -125,31 +169,32 @@ async def create_group_invite(
 )
 async def leave_group(
     group_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
-    """Removes the current user from the specified group."""
+    """Removes the current user from the specified group. If the owner is the last member, the group will be deleted."""
     logger.info(f"User {current_user.email} attempting to leave group {group_id}")
     user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id)
 
     if user_role is None:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="You are not a member of this group")
+        raise GroupMembershipError(group_id, "leave (you are not a member)")
 
-    # --- MVP: Prevent owner leaving if they are the last member/owner ---
+    # Check if owner is the last member
     if user_role == UserRoleEnum.owner:
         member_count = await crud_group.get_group_member_count(db, group_id)
-        # More robust check: count owners. For now, just check member count.
         if member_count <= 1:
-            logger.warning(f"Owner {current_user.email} attempted to leave group {group_id} as last member.")
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Owner cannot leave the group as the last member. Delete the group or transfer ownership.")
+            # Delete the group since owner is the last member
+            logger.info(f"Owner {current_user.email} is the last member. Deleting group {group_id}")
+            await crud_group.delete_group(db, group_id)
+            return Message(detail="Group deleted as you were the last member")
 
-    # Proceed with removal
+    # Proceed with removal for non-owner or if there are other members
    deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id)
 
     if not deleted:
         # Should not happen if role check passed, but handle defensively
         logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to leave group")
+        raise GroupOperationError("Failed to leave group")
 
     logger.info(f"User {current_user.email} successfully left group {group_id}")
     return Message(detail="Successfully left the group")
@@ -164,8 +209,8 @@ async def leave_group(
 async def remove_group_member(
     group_id: int,
     user_id_to_remove: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """Removes a specified user from the group. Requires current user to be owner."""
     logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}")
@@ -174,23 +219,49 @@ async def remove_group_member(
     # --- Permission Check ---
     if owner_role != UserRoleEnum.owner:
         logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}")
-        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only group owners can remove members")
+        raise GroupPermissionError(group_id, "remove members")
 
     # Prevent owner removing themselves via this endpoint
     if current_user.id == user_id_to_remove:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.")
+        raise GroupValidationError("Owner cannot remove themselves using this endpoint. Use 'Leave Group' instead.")
 
     # Check if target user is actually in the group
     target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove)
     if target_role is None:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User to remove is not a member of this group")
+        raise GroupMembershipError(group_id, "remove this user (they are not a member)")
 
     # Proceed with removal
     deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove)
 
     if not deleted:
         logger.error(f"Owner {current_user.email} failed to remove user {user_id_to_remove} from group {group_id}.")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to remove member")
+        raise GroupOperationError("Failed to remove member")
 
     logger.info(f"Owner {current_user.email} successfully removed user {user_id_to_remove} from group {group_id}")
     return Message(detail="Successfully removed member from the group")
+
+
+@router.get(
+    "/{group_id}/lists",
+    response_model=List[ListDetail],
+    summary="Get Group Lists",
+    tags=["Groups", "Lists"]
+)
+async def read_group_lists(
+    group_id: int,
+    db: AsyncSession = Depends(get_session),  # Use read-only session for GET
+    current_user: UserModel = Depends(current_active_user),
+):
+    """Retrieves all lists belonging to a specific group, if the user is a member."""
+    logger.info(f"User {current_user.email} requesting lists for group ID: {group_id}")
+
+    # Check if user is a member first
+    is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
+    if not is_member:
+        logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}")
+        raise GroupMembershipError(group_id, "view group lists")
+
+    # Get all lists for the user and filter by group_id (avoid shadowing the `list` builtin)
+    lists = await crud_list.get_lists_for_user(db=db, user_id=current_user.id)
+    group_lists = [lst for lst in lists if lst.group_id == group_id]
+
+    return group_lists
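The hunks above repeatedly swap `get_db` for either `get_transactional_session` (mutating endpoints) or `get_session` (read-only GETs). A minimal sketch of how such a pair of dependencies might be defined in `app/database.py`; this is an assumption, since the actual module is not part of this diff:

# Sketch only: assumed shape of app/database.py
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/app")  # placeholder DSN
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)

async def get_session():
    # Read-only session: no implicit transaction management.
    async with async_session_factory() as session:
        yield session

async def get_transactional_session():
    # Wrap the request in a transaction: commit on success, roll back on error.
    async with async_session_factory() as session:
        async with session.begin():
            yield session
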
@@ -1,11 +1,12 @@
 # app/api/v1/endpoints/health.py
 import logging
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.sql import text
 
-from app.database import get_db  # Import the dependency function
-from app.schemas.health import HealthStatus  # Import the response schema
+from app.database import get_transactional_session
+from app.schemas.health import HealthStatus
+from app.core.exceptions import DatabaseConnectionError
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -15,9 +16,9 @@ router = APIRouter()
     response_model=HealthStatus,
     summary="Perform a Health Check",
     description="Checks the operational status of the API and its connection to the database.",
-    tags=["Health"]  # Group this endpoint in Swagger UI
+    tags=["Health"]
 )
-async def check_health(db: AsyncSession = Depends(get_db)):
+async def check_health(db: AsyncSession = Depends(get_transactional_session)):
     """
     Health check endpoint. Verifies API reachability and database connection.
     """
@@ -30,16 +31,8 @@ async def check_health(db: AsyncSession = Depends(get_transactional_session)):
         else:
             # This case should ideally not happen with 'SELECT 1'
             logger.error("Health check failed: Database connection check returned unexpected result.")
-            # Raise 503 Service Unavailable
-            raise HTTPException(
-                status_code=503,
-                detail="Database connection error: Unexpected result"
-            )
+            raise DatabaseConnectionError("Unexpected result from database connection check")
 
     except Exception as e:
-        logger.error(f"Health check failed: Database connection error - {e}", exc_info=True)  # Log stack trace
-        # Raise 503 Service Unavailable
-        raise HTTPException(
-            status_code=503,
-            detail=f"Database connection error: {e}"
-        )
+        logger.error(f"Health check failed: Database connection error - {e}", exc_info=True)
+        raise DatabaseConnectionError(str(e))
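Raising `DatabaseConnectionError` instead of building `HTTPException(503)` inline implies a global exception handler translates the domain error into an HTTP response. A plausible sketch of that handler; the actual `app/core/exceptions.py` is not shown, so the class shape and registration are assumptions:

# Sketch only: assumed custom exception and FastAPI handler registration
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

class DatabaseConnectionError(Exception):
    def __init__(self, detail: str):
        self.detail = detail
        super().__init__(detail)

app = FastAPI()

@app.exception_handler(DatabaseConnectionError)
async def database_connection_error_handler(request: Request, exc: DatabaseConnectionError):
    # Map database connectivity failures to 503 Service Unavailable.
    return JSONResponse(status_code=503, content={"detail": f"Database connection error: {exc.detail}"})
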
@@ -3,57 +3,77 @@ import logging
 from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.database import get_db
-from app.api.dependencies import get_current_user
+from app.database import get_transactional_session
+from app.auth import current_active_user
 from app.models import User as UserModel, UserRoleEnum
 from app.schemas.invite import InviteAccept
 from app.schemas.message import Message
+from app.schemas.group import GroupPublic
 from app.crud import invite as crud_invite
 from app.crud import group as crud_group
+from app.core.exceptions import (
+    InviteNotFoundError,
+    InviteExpiredError,
+    InviteAlreadyUsedError,
+    InviteCreationError,
+    GroupNotFoundError,
+    GroupMembershipError,
+    GroupOperationError
+)
 
 logger = logging.getLogger(__name__)
 router = APIRouter()
 
 @router.post(
     "/accept",  # Route relative to prefix "/invites"
-    response_model=Message,
+    response_model=GroupPublic,
     summary="Accept Group Invite",
     tags=["Invites"]
 )
 async def accept_invite(
     invite_in: InviteAccept,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
-    """Allows an authenticated user to accept an invite using its code."""
-    code = invite_in.code
-    logger.info(f"User {current_user.email} attempting to accept invite code: {code}")
+    """Accepts a group invite using the provided invite code."""
+    logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.code}")
 
-    # Find the active, non-expired invite
-    invite = await crud_invite.get_active_invite_by_code(db=db, code=code)
+    # Get the invite - this function should only return valid, active invites
+    invite = await crud_invite.get_active_invite_by_code(db, code=invite_in.code)
     if not invite:
-        logger.warning(f"Invite code '{code}' not found, expired, or already used.")
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Invite code is invalid or expired")
+        logger.warning(f"Invalid or inactive invite code attempted by user {current_user.email}: {invite_in.code}")
+        # We can use a more generic error or a specific one. InviteNotFound is reasonable.
+        raise InviteNotFoundError(invite_in.code)
 
-    group_id = invite.group_id
+    # Check if group still exists
+    group = await crud_group.get_group_by_id(db, group_id=invite.group_id)
+    if not group:
+        logger.error(f"Group {invite.group_id} not found for invite {invite_in.code}")
+        raise GroupNotFoundError(invite.group_id)
 
-    # Check if user is already in the group
-    is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id)
+    # Check if user is already a member
+    is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id)
     if is_member:
-        logger.info(f"User {current_user.email} is already a member of group {group_id}. Invite '{code}' still deactivated.")
-        # Deactivate invite even if already member, to prevent reuse
-        await crud_invite.deactivate_invite(db=db, invite=invite)
-        return Message(detail="You are already a member of this group.")
+        logger.warning(f"User {current_user.email} already a member of group {invite.group_id}")
+        raise GroupMembershipError(invite.group_id, "join (already a member)")
 
-    # Add user to the group as a member
-    added = await crud_group.add_user_to_group(db=db, group_id=group_id, user_id=current_user.id, role=UserRoleEnum.member)
-    if not added:
-        # Should not happen if is_member check was correct, but handle defensively
-        logger.error(f"Failed to add user {current_user.email} to group {group_id} via invite '{code}' despite not being a member.")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not join group.")
+    # Add user to the group
+    added_to_group = await crud_group.add_user_to_group(db, group_id=invite.group_id, user_id=current_user.id)
+    if not added_to_group:
+        logger.error(f"Failed to add user {current_user.email} to group {invite.group_id} during invite acceptance.")
+        # This could be a race condition or other issue, treat as an operational error.
+        raise GroupOperationError("Failed to add user to group.")
 
-    # Deactivate the invite (single-use)
-    await crud_invite.deactivate_invite(db=db, invite=invite)
+    # Deactivate the invite so it cannot be used again
+    await crud_invite.deactivate_invite(db, invite=invite)
 
-    logger.info(f"User {current_user.email} successfully joined group {group_id} using invite '{code}'.")
-    return Message(detail="Successfully joined the group.")
+    logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.code}")
+
+    # Re-fetch the group to get the updated member list
+    updated_group = await crud_group.get_group_by_id(db, group_id=invite.group_id)
+    if not updated_group:
+        # This should ideally not happen as we found it before
+        logger.error(f"Could not re-fetch group {invite.group_id} after user {current_user.email} joined.")
+        raise GroupNotFoundError(invite.group_id)
+
+    return updated_group
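With `response_model` switched from `Message` to `GroupPublic`, a client accepting an invite now receives the joined group directly, member list included. A sketch of the round trip; the base URL and auth scheme are assumptions:

import httpx

async def accept_invite_example(token: str, code: str) -> dict:
    async with httpx.AsyncClient(base_url="http://localhost:8000/api/v1") as client:
        resp = await client.post(
            "/invites/accept",
            headers={"Authorization": f"Bearer {token}"},
            json={"code": code},
        )
        resp.raise_for_status()
        return resp.json()  # GroupPublic: the joined group, including its members
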
@ -1,12 +1,12 @@
|
|||||||
# app/api/v1/endpoints/items.py
|
# app/api/v1/endpoints/items.py
|
||||||
import logging
|
import logging
|
||||||
from typing import List as PyList
|
from typing import List as PyList, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_transactional_session
|
||||||
from app.api.dependencies import get_current_user
|
from app.auth import current_active_user
|
||||||
# --- Import Models Correctly ---
|
# --- Import Models Correctly ---
|
||||||
from app.models import User as UserModel
|
from app.models import User as UserModel
|
||||||
from app.models import Item as ItemModel # <-- IMPORT Item and alias it
|
from app.models import Item as ItemModel # <-- IMPORT Item and alias it
|
||||||
@ -14,6 +14,7 @@ from app.models import Item as ItemModel # <-- IMPORT Item and alias it
|
|||||||
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
|
from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic
|
||||||
from app.crud import item as crud_item
|
from app.crud import item as crud_item
|
||||||
from app.crud import list as crud_list
|
from app.crud import list as crud_list
|
||||||
|
from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -22,19 +23,21 @@ router = APIRouter()
|
|||||||
# Now ItemModel is defined before being used as a type hint
|
# Now ItemModel is defined before being used as a type hint
|
||||||
async def get_item_and_verify_access(
|
async def get_item_and_verify_access(
|
||||||
item_id: int,
|
item_id: int,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
current_user: UserModel = Depends(get_current_user)
|
current_user: UserModel = Depends(current_active_user)
|
||||||
) -> ItemModel: # Now this type hint is valid
|
) -> ItemModel:
|
||||||
|
"""Dependency to get an item and verify the user has access to its list."""
|
||||||
item_db = await crud_item.get_item_by_id(db, item_id=item_id)
|
item_db = await crud_item.get_item_by_id(db, item_id=item_id)
|
||||||
if not item_db:
|
if not item_db:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Item not found")
|
raise ItemNotFoundError(item_id)
|
||||||
|
|
||||||
# Check permission on the parent list
|
# Check permission on the parent list
|
||||||
list_db = await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id)
|
try:
|
||||||
if not list_db:
|
await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id)
|
||||||
# User doesn't have access to the list this item belongs to
|
except ListPermissionError as e:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to access this item's list")
|
# Re-raise with a more specific message
|
||||||
return item_db # Return the fetched item if authorized
|
raise ListPermissionError(item_db.list_id, "access this item's list")
|
||||||
|
return item_db
|
||||||
|
|
||||||
|
|
||||||
# --- Endpoints ---
|
# --- Endpoints ---
|
||||||
@ -49,25 +52,23 @@ async def get_item_and_verify_access(
|
|||||||
async def create_list_item(
|
async def create_list_item(
|
||||||
list_id: int,
|
list_id: int,
|
||||||
item_in: ItemCreate,
|
item_in: ItemCreate,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
current_user: UserModel = Depends(get_current_user),
|
current_user: UserModel = Depends(current_active_user),
|
||||||
):
|
):
|
||||||
"""Adds a new item to a specific list. User must have access to the list."""
|
"""Adds a new item to a specific list. User must have access to the list."""
|
||||||
logger.info(f"User {current_user.email} adding item to list {list_id}: {item_in.name}")
|
user_email = current_user.email # Access email attribute before async operations
|
||||||
|
logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}")
|
||||||
# Verify user has access to the target list
|
# Verify user has access to the target list
|
||||||
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
try:
|
||||||
if not list_db:
|
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
||||||
# Check if list exists at all for correct error code
|
except ListPermissionError as e:
|
||||||
exists = await crud_list.get_list_by_id(db, list_id)
|
# Re-raise with a more specific message
|
||||||
status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
|
raise ListPermissionError(list_id, "add items to this list")
|
||||||
detail = "List not found" if not exists else "You do not have permission to add items to this list"
|
|
||||||
logger.warning(f"Add item failed for list {list_id} by user {current_user.email}: {detail}")
|
|
||||||
raise HTTPException(status_code=status_code, detail=detail)
|
|
||||||
|
|
||||||
created_item = await crud_item.create_item(
|
created_item = await crud_item.create_item(
|
||||||
db=db, item_in=item_in, list_id=list_id, user_id=current_user.id
|
db=db, item_in=item_in, list_id=list_id, user_id=current_user.id
|
||||||
)
|
)
|
||||||
logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {current_user.email}.")
|
logger.info(f"Item '{created_item.name}' (ID: {created_item.id}) added to list {list_id} by user {user_email}.")
|
||||||
return created_item
|
return created_item
|
||||||
|
|
||||||
|
|
||||||
@ -79,72 +80,102 @@ async def create_list_item(
|
|||||||
)
|
)
|
||||||
async def read_list_items(
|
async def read_list_items(
|
||||||
list_id: int,
|
list_id: int,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
current_user: UserModel = Depends(get_current_user),
|
current_user: UserModel = Depends(current_active_user),
|
||||||
# Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
|
# Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
|
||||||
):
|
):
|
||||||
"""Retrieves all items for a specific list if the user has access."""
|
"""Retrieves all items for a specific list if the user has access."""
|
||||||
logger.info(f"User {current_user.email} listing items for list {list_id}")
|
user_email = current_user.email # Access email attribute before async operations
|
||||||
|
logger.info(f"User {user_email} listing items for list {list_id}")
|
||||||
# Verify user has access to the list
|
# Verify user has access to the list
|
||||||
list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
try:
|
||||||
if not list_db:
|
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
||||||
exists = await crud_list.get_list_by_id(db, list_id)
|
except ListPermissionError as e:
|
||||||
status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
|
# Re-raise with a more specific message
|
||||||
detail = "List not found" if not exists else "You do not have permission to view items in this list"
|
raise ListPermissionError(list_id, "view items in this list")
|
||||||
logger.warning(f"List items failed for list {list_id} by user {current_user.email}: {detail}")
|
|
||||||
raise HTTPException(status_code=status_code, detail=detail)
|
|
||||||
|
|
||||||
items = await crud_item.get_items_by_list_id(db=db, list_id=list_id)
|
items = await crud_item.get_items_by_list_id(db=db, list_id=list_id)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/items/{item_id}", # Operate directly on item ID
|
"/lists/{list_id}/items/{item_id}", # Nested under lists
|
||||||
response_model=ItemPublic,
|
response_model=ItemPublic,
|
||||||
summary="Update Item",
|
summary="Update Item",
|
||||||
tags=["Items"]
|
tags=["Items"],
|
||||||
|
responses={
|
||||||
|
status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified by someone else"}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
async def update_item(
|
async def update_item(
|
||||||
item_id: int, # Item ID from path
|
list_id: int,
|
||||||
|
item_id: int,
|
||||||
item_in: ItemUpdate,
|
item_in: ItemUpdate,
|
||||||
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
|
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_transactional_session),
|
||||||
current_user: UserModel = Depends(get_current_user), # Need user ID for completed_by
|
current_user: UserModel = Depends(current_active_user), # Need user ID for completed_by
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Updates an item's details (name, quantity, is_complete, price).
|
Updates an item's details (name, quantity, is_complete, price).
|
||||||
User must have access to the list the item belongs to.
|
User must have access to the list the item belongs to.
|
||||||
|
The client MUST provide the current `version` of the item in the `item_in` payload.
|
||||||
|
If the version does not match, a 409 Conflict is returned.
|
||||||
Sets/unsets `completed_by_id` based on `is_complete` flag.
|
Sets/unsets `completed_by_id` based on `is_complete` flag.
|
||||||
"""
|
"""
|
||||||
logger.info(f"User {current_user.email} attempting to update item ID: {item_id}")
|
user_email = current_user.email # Access email attribute before async operations
|
||||||
|
logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}")
|
||||||
# Permission check is handled by get_item_and_verify_access dependency
|
# Permission check is handled by get_item_and_verify_access dependency
|
||||||
|
|
||||||
updated_item = await crud_item.update_item(
|
try:
|
||||||
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
|
updated_item = await crud_item.update_item(
|
||||||
)
|
db=db, item_db=item_db, item_in=item_in, user_id=current_user.id
|
||||||
logger.info(f"Item {item_id} updated successfully by user {current_user.email}.")
|
)
|
||||||
return updated_item
|
logger.info(f"Item {item_id} updated successfully by user {user_email} to version {updated_item.version}.")
|
||||||
|
return updated_item
|
+    except ConflictError as e:
+        logger.warning(f"Conflict updating item {item_id} for user {user_email}: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
+    except Exception as e:
+        logger.error(f"Error updating item {item_id} for user {user_email}: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the item.")


 @router.delete(
-    "/items/{item_id}", # Operate directly on item ID
+    "/lists/{list_id}/items/{item_id}", # Nested under lists
     status_code=status.HTTP_204_NO_CONTENT,
     summary="Delete Item",
-    tags=["Items"]
+    tags=["Items"],
+    responses={
+        status.HTTP_409_CONFLICT: {"description": "Conflict: Item has been modified, cannot delete specified version"}
+    }
 )
 async def delete_item(
-    item_id: int, # Item ID from path
+    list_id: int,
+    item_id: int,
+    expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
     item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user), # Log who deleted it
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user), # Log who deleted it
 ):
     """
     Deletes an item. User must have access to the list the item belongs to.
-    (MVP: Any member with list access can delete items).
+    If `expected_version` is provided and does not match the item's current version,
+    a 409 Conflict is returned.
     """
-    logger.info(f"User {current_user.email} attempting to delete item ID: {item_id}")
+    user_email = current_user.email # Access email attribute before async operations
+    logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}")
     # Permission check is handled by get_item_and_verify_access dependency
+    if expected_version is not None and item_db.version != expected_version:
+        logger.warning(
+            f"Conflict deleting item {item_id} for user {user_email}. "
+            f"Expected version {expected_version}, actual version {item_db.version}."
+        )
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"Item has been modified. Expected version {expected_version}, but current version is {item_db.version}. Please refresh."
+        )
     await crud_item.delete_item(db=db, item_db=item_db)
-    logger.info(f"Item {item_id} deleted successfully by user {current_user.email}.")
+    logger.info(f"Item {item_id} (version {item_db.version}) deleted successfully by user {user_email}.")
     return Response(status_code=status.HTTP_204_NO_CONTENT)
be/app/api/v1/endpoints/lists.py
@@ -1,18 +1,27 @@
 # app/api/v1/endpoints/lists.py
 import logging
-from typing import List as PyList # Alias for Python List type hint
+from typing import List as PyList, Optional # Alias for Python List type hint

-from fastapi import APIRouter, Depends, HTTPException, status, Response
+from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
 from sqlalchemy.ext.asyncio import AsyncSession

-from app.database import get_db
-from app.api.dependencies import get_current_user
+from app.database import get_transactional_session
+from app.auth import current_active_user
 from app.models import User as UserModel
 from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail
 from app.schemas.message import Message # For simple responses
 from app.crud import list as crud_list
 from app.crud import group as crud_group # Need for group membership check
-from app.schemas.list import ListStatus
+from app.schemas.list import ListStatus, ListStatusWithId
+from app.schemas.expense import ExpensePublic # Import ExpensePublic
+from app.core.exceptions import (
+    GroupMembershipError,
+    ListNotFoundError,
+    ListPermissionError,
+    ListStatusNotFoundError,
+    ConflictError, # Added ConflictError
+    DatabaseIntegrityError # Added DatabaseIntegrityError
+)

 logger = logging.getLogger(__name__)
 router = APIRouter()
@@ -22,17 +31,24 @@ router = APIRouter()
     response_model=ListPublic, # Return basic list info on creation
     status_code=status.HTTP_201_CREATED,
     summary="Create New List",
-    tags=["Lists"]
+    tags=["Lists"],
+    responses={
+        status.HTTP_409_CONFLICT: {
+            "description": "Conflict: A list with this name already exists in the specified group",
+            "model": ListPublic
+        }
+    }
 )
 async def create_list(
     list_in: ListCreate,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """
     Creates a new shopping list.
     - If `group_id` is provided, the user must be a member of that group.
     - If `group_id` is null, it's a personal list.
+    - If a list with the same name already exists in the group, returns 409 with the existing list.
     """
     logger.info(f"User {current_user.email} creating list: {list_in.name}")
     group_id = list_in.group_id
@@ -42,25 +58,42 @@ async def create_list(
         is_member = await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id)
         if not is_member:
             logger.warning(f"User {current_user.email} attempted to create list in group {group_id} but is not a member.")
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail="You are not a member of the specified group",
-            )
+            raise GroupMembershipError(group_id, "create lists")

-    created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id)
-    logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.")
-    return created_list
+    try:
+        created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id)
+        logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.")
+        return created_list
+    except DatabaseIntegrityError as e:
+        # Check if this is a unique constraint violation
+        if "unique constraint" in str(e).lower():
+            # Find the existing list with the same name in the group
+            existing_list = await crud_list.get_list_by_name_and_group(
+                db=db,
+                name=list_in.name,
+                group_id=group_id,
+                user_id=current_user.id
+            )
+            if existing_list:
+                logger.info(f"List '{list_in.name}' already exists in group {group_id}. Returning existing list.")
+                raise HTTPException(
+                    status_code=status.HTTP_409_CONFLICT,
+                    detail=f"A list named '{list_in.name}' already exists in this group.",
+                    headers={"X-Existing-List": str(existing_list.id)}
+                )
+        # If it's not a unique constraint or we couldn't find the existing list, re-raise
+        raise
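Reviewer note: the 409 path above makes create effectively idempotent per (group, name). A client sketch (hypothetical helper; endpoint paths assume the /api/v1 prefix used elsewhere in this changeset):

# Hypothetical client helper: create a list, or reuse the one the 409 points at.
import httpx

async def create_or_reuse_list(client: httpx.AsyncClient, name: str, group_id: int | None) -> int:
    resp = await client.post("/lists", json={"name": name, "group_id": group_id})
    if resp.status_code == 201:
        return resp.json()["id"]
    if resp.status_code == 409:
        # The server identifies the existing list via the X-Existing-List header.
        return int(resp.headers["X-Existing-List"])
    resp.raise_for_status()
    raise RuntimeError("unreachable")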
 @router.get(
     "", # Route relative to prefix "/lists"
-    response_model=PyList[ListPublic], # Return a list of basic list info
+    response_model=PyList[ListDetail], # Return a list of detailed list info including items
     summary="List Accessible Lists",
     tags=["Lists"]
 )
 async def read_lists(
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
     # Add pagination parameters later if needed: skip: int = 0, limit: int = 100
 ):
     """
@@ -73,6 +106,39 @@ async def read_lists(
     return lists


+@router.get(
+    "/statuses",
+    response_model=PyList[ListStatusWithId],
+    summary="Get Status for Multiple Lists",
+    tags=["Lists"]
+)
+async def read_lists_statuses(
+    ids: PyList[int] = Query(...),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
+):
+    """
+    Retrieves the status for a list of lists.
+    - `updated_at`: The timestamp of the last update to the list itself.
+    - `item_count`: The total number of items in the list.
+    The user must have permission to view each list requested.
+    Lists that the user does not have permission for will be omitted from the response.
+    """
+    logger.info(f"User {current_user.email} requesting statuses for list IDs: {ids}")
+
+    statuses = await crud_list.get_lists_statuses_by_ids(db=db, list_ids=ids, user_id=current_user.id)
+
+    # The CRUD function returns a list of Row objects, so we map them to the Pydantic model
+    return [
+        ListStatusWithId(
+            id=s.id,
+            updated_at=s.updated_at,
+            item_count=s.item_count,
+            latest_item_updated_at=s.latest_item_updated_at
+        ) for s in statuses
+    ]
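Reviewer note: a polling client can batch its freshness checks through this endpoint. A sketch (hypothetical helper; /api/v1 prefix assumed; FastAPI reads PyList[int] = Query(...) as repeated ids params):

# Hypothetical polling helper: fetch statuses for several lists in one round trip.
import httpx

async def fetch_statuses(client: httpx.AsyncClient, list_ids: list[int]) -> dict[int, str]:
    # Repeated query params, i.e. /lists/statuses?ids=1&ids=2&ids=3
    resp = await client.get("/lists/statuses", params=[("ids", i) for i in list_ids])
    resp.raise_for_status()
    # Map list id -> updated_at; lists the user cannot see are simply absent.
    return {row["id"]: row["updated_at"] for row in resp.json()}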
 @router.get(
     "/{list_id}",
     response_model=ListDetail, # Return detailed list info including items
@@ -81,29 +147,16 @@ async def read_lists(
 )
 async def read_list(
     list_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """
     Retrieves details for a specific list, including its items,
     if the user has permission (creator or group member).
     """
     logger.info(f"User {current_user.email} requesting details for list ID: {list_id}")
-    # Use the helper to fetch and check permission simultaneously
+    # The check_list_permission function will raise appropriate exceptions
     list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
-
-    if not list_db:
-        # check_list_permission returns None if list not found OR permission denied
-        # We need to check if the list exists at all to return 404 vs 403
-        exists = await crud_list.get_list_by_id(db, list_id)
-        if not exists:
-            logger.warning(f"List ID {list_id} not found for request by user {current_user.email}.")
-            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="List not found")
-        else:
-            logger.warning(f"Access denied: User {current_user.email} cannot access list {list_id}.")
-            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You do not have permission to access this list")
-
-    # list_db already has items loaded due to check_list_permission
     return list_db
@@ -111,101 +164,123 @@ async def read_list(
     "/{list_id}",
     response_model=ListPublic, # Return updated basic info
     summary="Update List",
-    tags=["Lists"]
+    tags=["Lists"],
+    responses={ # Add 409 to responses
+        status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"}
+    }
 )
 async def update_list(
     list_id: int,
     list_in: ListUpdate,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """
     Updates a list's details (name, description, is_complete).
     Requires user to be the creator or a member of the list's group.
-    (MVP: Allows any member to update these fields).
+    The client MUST provide the current `version` of the list in the `list_in` payload.
+    If the version does not match, a 409 Conflict is returned.
     """
-    logger.info(f"User {current_user.email} attempting to update list ID: {list_id}")
+    logger.info(f"User {current_user.email} attempting to update list ID: {list_id} with version {list_in.version}")
     list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
-
-    if not list_db:
-        exists = await crud_list.get_list_by_id(db, list_id)
-        status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
-        detail = "List not found" if not exists else "You do not have permission to update this list"
-        logger.warning(f"Update failed for list {list_id} by user {current_user.email}: {detail}")
-        raise HTTPException(status_code=status_code, detail=detail)
-
-    # Prevent changing group_id or creator via this endpoint for simplicity
-    # if list_in.group_id is not None or list_in.created_by_id is not None:
-    #     raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot change group or creator via this endpoint")
-
-    updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
-    logger.info(f"List {list_id} updated successfully by user {current_user.email}.")
-    return updated_list
+    try:
+        updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in)
+        logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.")
+        return updated_list
+    except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response
+        logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
+    except Exception as e: # Catch other potential errors from crud operation
+        logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}")
+        # Consider a more generic error, but for now, let's keep it specific if possible
+        # Re-raising might be better if crud layer already raises appropriate HTTPExceptions
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.")
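Reviewer note: unlike the deletes, the update above carries the version inside the payload. A client sketch (hypothetical helper; assumes the route is exposed as a PUT under /api/v1 and that ListUpdate accepts a version field, per the docstring):

# Hypothetical client helper: send the cached version with the update, back off on 409.
import httpx

async def rename_list(client: httpx.AsyncClient, list_id: int, new_name: str, cached_version: int):
    resp = await client.put(
        f"/lists/{list_id}",
        json={"name": new_name, "version": cached_version},
    )
    if resp.status_code == 409:
        # Stale version: refetch the list, merge changes, retry with the fresh version.
        return None
    resp.raise_for_status()
    return resp.json()  # Includes the incremented version.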
 @router.delete(
     "/{list_id}",
     status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body
     summary="Delete List",
-    tags=["Lists"]
+    tags=["Lists"],
+    responses={ # Add 409 to responses
+        status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"}
+    }
 )
 async def delete_list(
     list_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """
     Deletes a list. Requires user to be the creator of the list.
-    (Alternatively, could allow group owner).
+    If `expected_version` is provided and does not match the list's current version,
+    a 409 Conflict is returned.
     """
-    logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}")
+    logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}")
     # Use the helper, requiring creator permission
     list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True)

-    if not list_db:
-        exists = await crud_list.get_list_by_id(db, list_id)
-        status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
-        detail = "List not found" if not exists else "Only the list creator can delete this list"
-        logger.warning(f"Delete failed for list {list_id} by user {current_user.email}: {detail}")
-        raise HTTPException(status_code=status_code, detail=detail)
+    if expected_version is not None and list_db.version != expected_version:
+        logger.warning(
+            f"Conflict deleting list {list_id} for user {current_user.email}. "
+            f"Expected version {expected_version}, actual version {list_db.version}."
+        )
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"List has been modified. Expected version {expected_version}, but current version is {list_db.version}. Please refresh."
+        )

     await crud_list.delete_list(db=db, list_db=list_db)
-    logger.info(f"List {list_id} deleted successfully by user {current_user.email}.")
-    # Return Response with 204 status explicitly if needed, otherwise FastAPI handles it
+    logger.info(f"List {list_id} (version: {list_db.version}) deleted successfully by user {current_user.email}.")
     return Response(status_code=status.HTTP_204_NO_CONTENT)
 @router.get(
     "/{list_id}/status",
     response_model=ListStatus,
-    summary="Get List Status (for polling)",
+    summary="Get List Status",
     tags=["Lists"]
 )
 async def read_list_status(
     list_id: int,
-    db: AsyncSession = Depends(get_db),
-    current_user: UserModel = Depends(get_current_user),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
 ):
     """
-    Retrieves the last update time for the list and its items, plus item count.
-    Used for polling to check if a full refresh is needed.
-    Requires user to have permission to view the list.
+    Retrieves the update timestamp and item count for a specific list
+    if the user has permission (creator or group member).
     """
-    # Verify user has access to the list first
-    list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
-    if not list_db:
-        # Check if list exists at all for correct error code
-        exists = await crud_list.get_list_by_id(db, list_id)
-        status_code = status.HTTP_404_NOT_FOUND if not exists else status.HTTP_403_FORBIDDEN
-        detail = "List not found" if not exists else "You do not have permission to access this list's status"
-        logger.warning(f"Status check failed for list {list_id} by user {current_user.email}: {detail}")
-        raise HTTPException(status_code=status_code, detail=detail)
-
-    # Fetch the status details
-    list_status = await crud_list.get_list_status(db=db, list_id=list_id)
-    if not list_status:
-        # Should not happen if check_list_permission passed, but handle defensively
-        logger.error(f"Could not retrieve status for list {list_id} even though permission check passed.")
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="List status not found")
-
-    return list_status
+    logger.info(f"User {current_user.email} requesting status for list ID: {list_id}")
+    # The check_list_permission is not needed here as get_list_status handles not found
+    await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
+    return await crud_list.get_list_status(db=db, list_id=list_id)


+@router.get(
+    "/{list_id}/expenses",
+    response_model=PyList[ExpensePublic],
+    summary="Get Expenses for List",
+    tags=["Lists", "Expenses"]
+)
+async def read_list_expenses(
+    list_id: int,
+    skip: int = Query(0, ge=0),
+    limit: int = Query(100, ge=1, le=200),
+    db: AsyncSession = Depends(get_transactional_session),
+    current_user: UserModel = Depends(current_active_user),
+):
+    """
+    Retrieves expenses associated with a specific list
+    if the user has permission (creator or group member).
+    """
+    from app.crud import expense as crud_expense
+
+    logger.info(f"User {current_user.email} requesting expenses for list ID: {list_id}")
+
+    # Check if user has permission to access this list
+    await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
+
+    # Get expenses for this list
+    expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
+    return expenses
be/app/api/v1/endpoints/ocr.py
@@ -1,21 +1,27 @@
-# app/api/v1/endpoints/ocr.py
 import logging
 from typing import List

-from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
-from google.api_core import exceptions as google_exceptions # Import Google API exceptions
+from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
+from google.api_core import exceptions as google_exceptions

-from app.api.dependencies import get_current_user
+from app.auth import current_active_user
 from app.models import User as UserModel
 from app.schemas.ocr import OcrExtractResponse
-from app.core.gemini import extract_items_from_image_gemini, gemini_initialization_error # Import helper
+from app.core.gemini import GeminiOCRService, gemini_initialization_error
+from app.core.exceptions import (
+    OCRServiceUnavailableError,
+    OCRServiceConfigError,
+    OCRUnexpectedError,
+    OCRQuotaExceededError,
+    InvalidFileTypeError,
+    FileTooLargeError,
+    OCRProcessingError
+)
+from app.config import settings

 logger = logging.getLogger(__name__)
 router = APIRouter()
+ocr_service = GeminiOCRService()

-# Allowed image MIME types
-ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
-MAX_FILE_SIZE_MB = 10 # Set a reasonable max file size

 @router.post(
     "/extract-items",
@@ -24,8 +30,7 @@ MAX_FILE_SIZE_MB = 10 # Set a reasonable max file size
     tags=["OCR"]
 )
 async def ocr_extract_items(
-    current_user: UserModel = Depends(get_current_user),
-    # Use File(...) for better metadata handling than UploadFile directly as type hint
+    current_user: UserModel = Depends(current_active_user),
     image_file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP) of the shopping list or receipt."),
 ):
     """
@@ -34,75 +39,38 @@ async def ocr_extract_items(
     """
     # Check if Gemini client initialized correctly
     if gemini_initialization_error:
         logger.error("OCR endpoint called but Gemini client failed to initialize.")
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"OCR service unavailable: {gemini_initialization_error}"
-        )
+        raise OCRServiceUnavailableError(gemini_initialization_error)

     logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.")

     # --- File Validation ---
-    if image_file.content_type not in ALLOWED_IMAGE_TYPES:
+    if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES:
         logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}")
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"Invalid file type. Allowed types: {', '.join(ALLOWED_IMAGE_TYPES)}",
-        )
+        raise InvalidFileTypeError()

-    # Simple size check (FastAPI/Starlette might handle larger limits via config)
-    # Read content first to get size accurately
+    # Simple size check
     contents = await image_file.read()
-    if len(contents) > MAX_FILE_SIZE_MB * 1024 * 1024:
+    if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024:
         logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes")
-        raise HTTPException(
-            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
-            detail=f"File size exceeds limit of {MAX_FILE_SIZE_MB} MB.",
-        )
-    # --- End File Validation ---
+        raise FileTooLargeError()

     try:
-        # Call the Gemini helper function
-        extracted_items = await extract_items_from_image_gemini(image_bytes=contents)
+        # Use the ocr_service instance instead of the standalone function
+        extracted_items = await ocr_service.extract_items(image_data=contents)

         logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.")
         return OcrExtractResponse(extracted_items=extracted_items)

-    except ValueError as e:
-        # Handle errors from Gemini processing (blocked, empty response, etc.)
-        logger.warning(f"Gemini processing error for user {current_user.email}: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, # Or 400 Bad Request?
-            detail=f"Could not extract items from image: {e}",
-        )
-    except google_exceptions.ResourceExhausted as e:
-        # Specific handling for quota errors
-        logger.error(f"Gemini Quota Exceeded for user {current_user.email}: {e}", exc_info=True)
-        raise HTTPException(
-            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-            detail="OCR service quota exceeded. Please try again later.",
-        )
-    except google_exceptions.GoogleAPIError as e:
-        # Handle other Google API errors (e.g., invalid key, permissions)
-        logger.error(f"Gemini API Error for user {current_user.email}: {e}", exc_info=True)
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"OCR service error: {e}",
-        )
-    except RuntimeError as e:
-        # Catch initialization errors from get_gemini_client()
-        logger.error(f"Gemini client runtime error during OCR request: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail=f"OCR service configuration error: {e}"
-        )
+    except OCRServiceUnavailableError:
+        raise OCRServiceUnavailableError()
+    except OCRServiceConfigError:
+        raise OCRServiceConfigError()
+    except OCRQuotaExceededError:
+        raise OCRQuotaExceededError()
     except Exception as e:
-        # Catch any other unexpected errors
-        logger.exception(f"Unexpected error during OCR extraction for user {current_user.email}: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="An unexpected error occurred during item extraction.",
-        )
+        raise OCRProcessingError(str(e))
     finally:
-        # Ensure file handle is closed (UploadFile uses SpooledTemporaryFile)
+        # Ensure file handle is closed
        await image_file.close()
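Reviewer note: a quick manual check of the endpoint (host, token, and the /api/v1/ocr mount point are assumptions; the multipart field name image_file matches the parameter above):

curl -X POST "http://localhost:8000/api/v1/ocr/extract-items" \
  -H "Authorization: Bearer $TOKEN" \
  -F "image_file=@receipt.jpg;type=image/jpeg"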
be/app/api/v1/endpoints/users.py (file removed)
@@ -1,30 +0,0 @@
-# app/api/v1/endpoints/users.py
-import logging
-from fastapi import APIRouter, Depends, HTTPException
-
-from app.api.dependencies import get_current_user # Import the dependency
-from app.schemas.user import UserPublic # Import the response schema
-from app.models import User as UserModel # Import the DB model for type hinting
-
-logger = logging.getLogger(__name__)
-router = APIRouter()
-
-@router.get(
-    "/me",
-    response_model=UserPublic, # Use the public schema to avoid exposing hash
-    summary="Get Current User",
-    description="Retrieves the details of the currently authenticated user.",
-    tags=["Users"]
-)
-async def read_users_me(
-    current_user: UserModel = Depends(get_current_user) # Apply the dependency
-):
-    """
-    Returns the data for the user associated with the current valid access token.
-    """
-    logger.info(f"Fetching details for current user: {current_user.email}")
-    # The 'current_user' object is the SQLAlchemy model instance returned by the dependency.
-    # Pydantic's response_model will automatically convert it using UserPublic schema.
-    return current_user
-
-# Add other user-related endpoints here later (e.g., update user, list users (admin))
@@ -3,7 +3,7 @@ import pytest
 from httpx import AsyncClient

 from app.schemas.user import UserPublic # For response validation
-from app.core.security import create_access_token
+# from app.core.security import create_access_token # Commented out as FastAPI-Users handles token creation

 pytestmark = pytest.mark.asyncio

@@ -51,15 +51,15 @@ async def test_read_users_me_invalid_token(client: AsyncClient):
     assert response.status_code == 401
     assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency

-async def test_read_users_me_expired_token(client: AsyncClient):
-    # Create a short-lived token manually (or adjust settings temporarily)
-    email = "testexpired@example.com"
-    # Assume create_access_token allows timedelta override
-    expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
-    headers = {"Authorization": f"Bearer {expired_token}"}
-
-    response = await client.get("/api/v1/users/me", headers=headers)
-    assert response.status_code == 401
-    assert response.json()["detail"] == "Could not validate credentials"
+# async def test_read_users_me_expired_token(client: AsyncClient):
+#     # Create a short-lived token manually (or adjust settings temporarily)
+#     email = "testexpired@example.com"
+#     # # Assume create_access_token allows timedelta override
+#     # expired_token = create_access_token(subject=email, expires_delta=timedelta(seconds=-10))
+#     # headers = {"Authorization": f"Bearer {expired_token}"}
+#
+#     # response = await client.get("/api/v1/users/me", headers=headers)
+#     # assert response.status_code == 401
+#     # assert response.json()["detail"] == "Could not validate credentials" # Detail from our dependency

 # Add test case for valid token but user deleted from DB if needed
be/app/auth.py (new file, 151 lines)
@@ -0,0 +1,151 @@
+from typing import Optional
+
+from fastapi import Depends, Request
+from fastapi.security import OAuth2PasswordRequestForm
+from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin
+from fastapi_users.authentication import (
+    AuthenticationBackend,
+    BearerTransport,
+    JWTStrategy,
+)
+from fastapi_users.db import SQLAlchemyUserDatabase
+from sqlalchemy.ext.asyncio import AsyncSession
+from authlib.integrations.starlette_client import OAuth
+from starlette.config import Config
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import Response
+from pydantic import BaseModel
+from fastapi.responses import JSONResponse
+
+from .database import get_session
+from .models import User
+from .config import settings
+
+# OAuth2 configuration
+config = Config('.env')
+oauth = OAuth(config)
+
+# Google OAuth2 setup
+oauth.register(
+    name='google',
+    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
+    client_kwargs={
+        'scope': 'openid email profile',
+        'redirect_uri': settings.GOOGLE_REDIRECT_URI
+    }
+)
+
+# Apple OAuth2 setup
+oauth.register(
+    name='apple',
+    server_metadata_url='https://appleid.apple.com/.well-known/openid-configuration',
+    client_kwargs={
+        'scope': 'openid email name',
+        'redirect_uri': settings.APPLE_REDIRECT_URI
+    }
+)
+
+# Custom Bearer Response with Refresh Token
+class BearerResponseWithRefresh(BaseModel):
+    access_token: str
+    refresh_token: str
+    token_type: str = "bearer"
+
+# Custom Bearer Transport that supports refresh tokens
+class BearerTransportWithRefresh(BearerTransport):
+    async def get_login_response(self, token: str, refresh_token: str = None) -> Response:
+        if refresh_token:
+            bearer_response = BearerResponseWithRefresh(
+                access_token=token,
+                refresh_token=refresh_token,
+                token_type="bearer"
+            )
+        else:
+            # Fallback to standard response if no refresh token
+            bearer_response = {
+                "access_token": token,
+                "token_type": "bearer"
+            }
+        return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response)
+
+# Custom Authentication Backend with Refresh Token Support
+class AuthenticationBackendWithRefresh(AuthenticationBackend):
+    def __init__(
+        self,
+        name: str,
+        transport: BearerTransportWithRefresh,
+        get_strategy,
+        get_refresh_strategy,
+    ):
+        self.name = name
+        self.transport = transport
+        self.get_strategy = get_strategy
+        self.get_refresh_strategy = get_refresh_strategy
+
+    async def login(self, strategy, user) -> Response:
+        # Generate both access and refresh tokens
+        access_token = await strategy.write_token(user)
+        refresh_strategy = self.get_refresh_strategy()
+        refresh_token = await refresh_strategy.write_token(user)
+
+        return await self.transport.get_login_response(
+            token=access_token,
+            refresh_token=refresh_token
+        )
+
+    async def logout(self, strategy, user, token) -> Response:
+        return await self.transport.get_logout_response()
+
+class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
+    reset_password_token_secret = settings.SECRET_KEY
+    verification_token_secret = settings.SECRET_KEY
+
+    async def on_after_register(self, user: User, request: Optional[Request] = None):
+        print(f"User {user.id} has registered.")
+
+    async def on_after_forgot_password(
+        self, user: User, token: str, request: Optional[Request] = None
+    ):
+        print(f"User {user.id} has forgot their password. Reset token: {token}")
+
+    async def on_after_request_verify(
+        self, user: User, token: str, request: Optional[Request] = None
+    ):
+        print(f"Verification requested for user {user.id}. Verification token: {token}")
+
+    async def on_after_login(
+        self, user: User, request: Optional[Request] = None, response: Optional[Response] = None
+    ):
+        print(f"User {user.id} has logged in.")
+
+async def get_user_db(session: AsyncSession = Depends(get_session)):
+    yield SQLAlchemyUserDatabase(session, User)
+
+async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
+    yield UserManager(user_db)
+
+# Updated transport with refresh token support
+bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login")
+
+def get_jwt_strategy() -> JWTStrategy:
+    return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60)
+
+def get_refresh_jwt_strategy() -> JWTStrategy:
+    # Refresh tokens last longer - 7 days
+    return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60)
+
+# Updated auth backend with refresh token support
+auth_backend = AuthenticationBackendWithRefresh(
+    name="jwt",
+    transport=bearer_transport,
+    get_strategy=get_jwt_strategy,
+    get_refresh_strategy=get_refresh_jwt_strategy,
+)
+
+fastapi_users = FastAPIUsers[User, int](
+    get_user_manager,
+    [auth_backend],
+)
+
+current_active_user = fastapi_users.current_user(active=True)
+current_superuser = fastapi_users.current_user(active=True, superuser=True)
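Reviewer note: with BearerTransportWithRefresh in place, a successful login now returns both tokens. A sketch of the exchange (response shape taken from BearerResponseWithRefresh above; the exact mount point of auth/jwt/login depends on how the router is included):

POST /auth/jwt/login  (form fields: username, password)

{
  "access_token": "<JWT, expires after ACCESS_TOKEN_EXPIRE_MINUTES>",
  "refresh_token": "<JWT, expires after 7 days>",
  "token_type": "bearer"
}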
be/app/config.py
@@ -3,6 +3,8 @@ import os
 from pydantic_settings import BaseSettings
 from dotenv import load_dotenv
 import logging
+import secrets
+from typing import List

 load_dotenv()
 logger = logging.getLogger(__name__)
@@ -10,39 +12,196 @@ logger = logging.getLogger(__name__)
 class Settings(BaseSettings):
     DATABASE_URL: str | None = None
     GEMINI_API_KEY: str | None = None
+    SENTRY_DSN: str | None = None # Sentry DSN for error tracking

-    # --- JWT Settings ---
-    # Generate a strong secret key using: openssl rand -hex 32
-    SECRET_KEY: str = "a_very_insecure_default_secret_key_replace_me" # !! MUST BE CHANGED IN PRODUCTION !!
-    ALGORITHM: str = "HS256"
-    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # Default token lifetime: 30 minutes
+    # --- Environment Settings ---
+    ENVIRONMENT: str = "development" # development, staging, production
+
+    # --- JWT Settings --- (SECRET_KEY is used by FastAPI-Users)
+    SECRET_KEY: str # Must be set via environment variable
+    TOKEN_TYPE: str = "bearer" # Default token type for JWT authentication
+    # FastAPI-Users handles JWT algorithm internally
+
+    # --- OCR Settings ---
+    MAX_FILE_SIZE_MB: int = 10 # Maximum allowed file size for OCR processing
+    ALLOWED_IMAGE_TYPES: list[str] = ["image/jpeg", "image/png", "image/webp"] # Supported image formats
+    OCR_ITEM_EXTRACTION_PROMPT: str = """
+Extract the shopping list items from this image.
+List each distinct item on a new line.
+Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
+Focus only on the names of the products or items to be purchased.
+Add 2 underscores before and after the item name, if it is struck through.
+If the image does not appear to be a shopping list or receipt, state that clearly.
+Example output for a grocery list:
+Milk
+Eggs
+Bread
+__Apples__
+Organic Bananas
+"""
+    # --- OCR Error Messages ---
+    OCR_SERVICE_UNAVAILABLE: str = "OCR service is currently unavailable. Please try again later."
+    OCR_SERVICE_CONFIG_ERROR: str = "OCR service configuration error. Please contact support."
+    OCR_UNEXPECTED_ERROR: str = "An unexpected error occurred during OCR processing."
+    OCR_QUOTA_EXCEEDED: str = "OCR service quota exceeded. Please try again later."
+    OCR_INVALID_FILE_TYPE: str = "Invalid file type. Supported types: {types}"
+    OCR_FILE_TOO_LARGE: str = "File too large. Maximum size: {size}MB"
+    OCR_PROCESSING_ERROR: str = "Error processing image: {detail}"
+
+    # --- Gemini AI Settings ---
+    GEMINI_MODEL_NAME: str = "gemini-2.0-flash" # The model to use for OCR
+    GEMINI_SAFETY_SETTINGS: dict = {
+        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
+        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
+        "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
+        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
+    }
+    GEMINI_GENERATION_CONFIG: dict = {
+        "candidate_count": 1,
+        "max_output_tokens": 2048,
+        "temperature": 0.9,
+        "top_p": 1,
+        "top_k": 1
+    }
+
+    # --- API Settings ---
+    API_PREFIX: str = "/api" # Base path for all API endpoints
+    API_OPENAPI_URL: str = "/api/openapi.json"
+    API_DOCS_URL: str = "/api/docs"
+    API_REDOC_URL: str = "/api/redoc"
+
+    # CORS Origins - environment dependent
+    CORS_ORIGINS: str = "http://localhost:5173,http://localhost:5174,http://localhost:8000,http://127.0.0.1:5173,http://127.0.0.1:5174,http://127.0.0.1:8000"
+    FRONTEND_URL: str = "http://localhost:5173" # URL for the frontend application
+
+    # --- API Metadata ---
+    API_TITLE: str = "Shared Lists API"
+    API_DESCRIPTION: str = "API for managing shared shopping lists, OCR, and cost splitting."
+    API_VERSION: str = "0.1.0"
+    ROOT_MESSAGE: str = "Welcome to the Shared Lists API! Docs available at /api/docs"
+
+    # --- Logging Settings ---
+    LOG_LEVEL: str = "WARNING"
+    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+    # --- Health Check Settings ---
+    HEALTH_STATUS_OK: str = "ok"
+    HEALTH_STATUS_ERROR: str = "error"
+
+    # --- HTTP Status Messages ---
+    HTTP_400_DETAIL: str = "Bad Request"
+    HTTP_401_DETAIL: str = "Unauthorized"
+    HTTP_403_DETAIL: str = "Forbidden"
+    HTTP_404_DETAIL: str = "Not Found"
+    HTTP_422_DETAIL: str = "Unprocessable Entity"
+    HTTP_429_DETAIL: str = "Too Many Requests"
+    HTTP_500_DETAIL: str = "Internal Server Error"
+    HTTP_503_DETAIL: str = "Service Unavailable"
+
+    # --- Database Error Messages ---
+    DB_CONNECTION_ERROR: str = "Database connection error"
+    DB_INTEGRITY_ERROR: str = "Database integrity error"
+    DB_TRANSACTION_ERROR: str = "Database transaction error"
+    DB_QUERY_ERROR: str = "Database query error"
+
+    # --- Auth Error Messages ---
+    AUTH_INVALID_CREDENTIALS: str = "Invalid username or password"
+    AUTH_NOT_AUTHENTICATED: str = "Not authenticated"
+    AUTH_JWT_ERROR: str = "JWT token error: {error}"
+    AUTH_JWT_UNEXPECTED_ERROR: str = "Unexpected JWT error: {error}"
+    AUTH_HEADER_NAME: str = "WWW-Authenticate"
+    AUTH_HEADER_PREFIX: str = "Bearer"
+
+    # OAuth Settings
+    # IMPORTANT: For Google OAuth to work, you MUST set the following environment variables
+    # (e.g., in your .env file):
+    #   GOOGLE_CLIENT_ID: Your Google Cloud project's OAuth 2.0 Client ID
+    #   GOOGLE_CLIENT_SECRET: Your Google Cloud project's OAuth 2.0 Client Secret
+    # Ensure the GOOGLE_REDIRECT_URI below matches the one configured in your Google Cloud Console.
+    GOOGLE_CLIENT_ID: str = ""
+    GOOGLE_CLIENT_SECRET: str = ""
+    GOOGLE_REDIRECT_URI: str = "https://mitlistbe.mohamad.dev/api/v1/auth/google/callback"
+
+    APPLE_CLIENT_ID: str = ""
+    APPLE_TEAM_ID: str = ""
+    APPLE_KEY_ID: str = ""
+    APPLE_PRIVATE_KEY: str = ""
+    APPLE_REDIRECT_URI: str = "https://mitlistbe.mohamad.dev/api/v1/auth/apple/callback"
+
+    # Session Settings
+    SESSION_SECRET_KEY: str = "your-session-secret-key" # Change this in production
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 480 # 8 hours instead of 30 minutes
+
+    # Redis Settings
+    REDIS_URL: str = "redis://localhost:6379"
+    REDIS_PASSWORD: str = ""

     class Config:
         env_file = ".env"
         env_file_encoding = 'utf-8'
         extra = "ignore"

+    @property
+    def cors_origins_list(self) -> List[str]:
+        """Convert CORS_ORIGINS string to list"""
+        return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
+
+    @property
+    def is_production(self) -> bool:
+        """Check if running in production environment"""
+        return self.ENVIRONMENT.lower() == "production"
+
+    @property
+    def is_development(self) -> bool:
+        """Check if running in development environment"""
+        return self.ENVIRONMENT.lower() == "development"
+
+    @property
+    def docs_url(self) -> str | None:
+        """Return docs URL only in development"""
+        return self.API_DOCS_URL if self.is_development else None
+
+    @property
+    def redoc_url(self) -> str | None:
+        """Return redoc URL only in development"""
+        return self.API_REDOC_URL if self.is_development else None
+
+    @property
+    def openapi_url(self) -> str | None:
+        """Return OpenAPI URL only in development"""
+        return self.API_OPENAPI_URL if self.is_development else None

 settings = Settings()

 # Validation for critical settings
 if settings.DATABASE_URL is None:
-    print("Warning: DATABASE_URL environment variable not set.")
-    # raise ValueError("DATABASE_URL environment variable not set.")
+    raise ValueError("DATABASE_URL environment variable must be set.")

-# CRITICAL: Check if the default secret key is being used
-if settings.SECRET_KEY == "a_very_insecure_default_secret_key_replace_me":
-    print("*" * 80)
-    print("WARNING: Using default insecure SECRET_KEY. Please generate a strong key and set it in the environment variables!")
-    print("Use: openssl rand -hex 32")
-    print("*" * 80)
-    # Consider raising an error in a production environment check
-    # if os.getenv("ENVIRONMENT") == "production":
-    #     raise ValueError("Default SECRET_KEY is not allowed in production!")
+# Enforce secure secret key
+if not settings.SECRET_KEY:
+    raise ValueError("SECRET_KEY environment variable must be set. Generate a secure key using: openssl rand -hex 32")
+
+# Validate secret key strength
+if len(settings.SECRET_KEY) < 32:
+    raise ValueError("SECRET_KEY must be at least 32 characters long for security")
+
+# Production-specific validations
+if settings.is_production:
+    if settings.SESSION_SECRET_KEY == "your-session-secret-key":
+        raise ValueError("SESSION_SECRET_KEY must be changed from default value in production")
+
+    if not settings.SENTRY_DSN:
+        logger.warning("SENTRY_DSN not set in production environment. Error tracking will be unavailable.")

 if settings.GEMINI_API_KEY is None:
-    print.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.")
-    # You might raise an error here if Gemini is essential for startup
-    # raise ValueError("GEMINI_API_KEY must be set.")
+    logger.error("CRITICAL: GEMINI_API_KEY environment variable not set. Gemini features will be unavailable.")
 else:
     # Optional: Log partial key for confirmation (avoid logging full key)
     logger.info(f"GEMINI_API_KEY loaded (starts with: {settings.GEMINI_API_KEY[:4]}...).")
+
+# Log environment information
+logger.info(f"Application starting in {settings.ENVIRONMENT} environment")
+if settings.is_production:
+    logger.info("Production mode: API documentation disabled")
+else:
+    logger.info(f"Development mode: API documentation available at {settings.API_DOCS_URL}")
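Reviewer note: a minimal .env that passes the startup validations above might look like this (all values are placeholders; the asyncpg URL scheme is an assumption based on the async session usage):

DATABASE_URL=postgresql+asyncpg://user:pass@localhost:5432/mitlist
SECRET_KEY=<output of: openssl rand -hex 32>
ENVIRONMENT=development
GEMINI_API_KEY=<your Gemini API key>
SESSION_SECRET_KEY=<another long random string>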
be/app/core/api_config.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+from typing import Dict, Any
+from app.config import settings
+
+# API Version
+API_VERSION = "v1"
+
+# API Prefix
+API_PREFIX = f"/api/{API_VERSION}"
+
+# API Endpoints
+class APIEndpoints:
+    # Auth
+    AUTH = {
+        "LOGIN": "/auth/login",
+        "SIGNUP": "/auth/signup",
+        "REFRESH_TOKEN": "/auth/refresh-token",
+    }
+
+    # Users
+    USERS = {
+        "PROFILE": "/users/profile",
+        "UPDATE_PROFILE": "/users/profile",
+    }
+
+    # Lists
+    LISTS = {
+        "BASE": "/lists",
+        "BY_ID": "/lists/{id}",
+        "ITEMS": "/lists/{list_id}/items",
+        "ITEM": "/lists/{list_id}/items/{item_id}",
+    }
+
+    # Groups
+    GROUPS = {
+        "BASE": "/groups",
+        "BY_ID": "/groups/{id}",
+        "LISTS": "/groups/{group_id}/lists",
+        "MEMBERS": "/groups/{group_id}/members",
+    }
+
+    # Invites
+    INVITES = {
+        "BASE": "/invites",
+        "BY_ID": "/invites/{id}",
+        "ACCEPT": "/invites/{id}/accept",
+        "DECLINE": "/invites/{id}/decline",
+    }
+
+    # OCR
+    OCR = {
+        "PROCESS": "/ocr/process",
+    }
+
+    # Financials
+    FINANCIALS = {
+        "EXPENSES": "/financials/expenses",
+        "EXPENSE": "/financials/expenses/{id}",
+        "SETTLEMENTS": "/financials/settlements",
+        "SETTLEMENT": "/financials/settlements/{id}",
+    }
+
+    # Health
+    HEALTH = {
+        "CHECK": "/health",
+    }
+
+# API Metadata
+API_METADATA = {
+    "title": settings.API_TITLE,
+    "description": settings.API_DESCRIPTION,
+    "version": settings.API_VERSION,
+    "openapi_url": settings.API_OPENAPI_URL,
+    "docs_url": settings.API_DOCS_URL,
+    "redoc_url": settings.API_REDOC_URL,
+}
+
+# API Tags
+API_TAGS = [
+    {"name": "Authentication", "description": "Authentication and authorization endpoints"},
+    {"name": "Users", "description": "User management endpoints"},
+    {"name": "Lists", "description": "Shopping list management endpoints"},
+    {"name": "Groups", "description": "Group management endpoints"},
+    {"name": "Invites", "description": "Group invitation management endpoints"},
+    {"name": "OCR", "description": "Optical Character Recognition endpoints"},
+    {"name": "Financials", "description": "Financial management endpoints"},
+    {"name": "Health", "description": "Health check endpoints"},
+]
+
+# Helper function to get full API URL
+def get_api_url(endpoint: str, **kwargs) -> str:
+    """
+    Get the full API URL for an endpoint.
+
+    Args:
+        endpoint: The endpoint path
+        **kwargs: Path parameters to format the endpoint
+
+    Returns:
+        str: The full API URL
+    """
+    formatted_endpoint = endpoint.format(**kwargs)
+    return f"{API_PREFIX}{formatted_endpoint}"
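Reviewer note: usage sketch for the helper above (the result follows directly from API_PREFIX and the LISTS table):

from app.core.api_config import APIEndpoints, get_api_url

url = get_api_url(APIEndpoints.LISTS["ITEM"], list_id=7, item_id=42)
# -> "/api/v1/lists/7/items/42"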
79
be/app/core/chore_utils.py
Normal file
79
be/app/core/chore_utils.py
Normal file
@ -0,0 +1,79 @@
from datetime import date, timedelta
from typing import Optional
from app.models import ChoreFrequencyEnum

def calculate_next_due_date(
    current_due_date: date,
    frequency: ChoreFrequencyEnum,
    custom_interval_days: Optional[int] = None,
    last_completed_date: Optional[date] = None
) -> date:
    """
    Calculates the next due date for a chore.
    Uses current_due_date as a base if last_completed_date is not provided.
    Calculates from last_completed_date if provided.
    """
    if frequency == ChoreFrequencyEnum.one_time:
        if last_completed_date:
            raise ValueError("Cannot calculate next due date for a completed one-time chore.")
        return current_due_date

    base_date = last_completed_date if last_completed_date else current_due_date

    if hasattr(base_date, 'date') and callable(getattr(base_date, 'date')):
        base_date = base_date.date()  # type: ignore

    next_due: date

    if frequency == ChoreFrequencyEnum.daily:
        next_due = base_date + timedelta(days=1)
    elif frequency == ChoreFrequencyEnum.weekly:
        next_due = base_date + timedelta(weeks=1)
    elif frequency == ChoreFrequencyEnum.monthly:
        month = base_date.month + 1
        year = base_date.year + (month - 1) // 12
        month = (month - 1) % 12 + 1

        day_of_target_month_last = (date(year, month % 12 + 1, 1) - timedelta(days=1)).day if month < 12 else 31
        day = min(base_date.day, day_of_target_month_last)

        next_due = date(year, month, day)
    elif frequency == ChoreFrequencyEnum.custom:
        if not custom_interval_days or custom_interval_days <= 0:
            raise ValueError("Custom frequency requires a positive custom_interval_days.")
        next_due = base_date + timedelta(days=custom_interval_days)
    else:
        raise ValueError(f"Unknown or unsupported chore frequency: {frequency}")

    today = date.today()
    reference_future_date = max(today, base_date)

    # This loop ensures the next_due date is always in the future relative to the reference_future_date.
    while next_due <= reference_future_date:
        current_base_for_recalc = next_due

        if frequency == ChoreFrequencyEnum.daily:
            next_due = current_base_for_recalc + timedelta(days=1)
        elif frequency == ChoreFrequencyEnum.weekly:
            next_due = current_base_for_recalc + timedelta(weeks=1)
        elif frequency == ChoreFrequencyEnum.monthly:
            month = current_base_for_recalc.month + 1
            year = current_base_for_recalc.year + (month - 1) // 12
            month = (month - 1) % 12 + 1
            day_of_target_month_last = (date(year, month % 12 + 1, 1) - timedelta(days=1)).day if month < 12 else 31
            day = min(current_base_for_recalc.day, day_of_target_month_last)
            next_due = date(year, month, day)
        elif frequency == ChoreFrequencyEnum.custom:
            if not custom_interval_days or custom_interval_days <= 0:  # Should have been validated
                raise ValueError("Custom frequency requires positive interval during recalc.")
            next_due = current_base_for_recalc + timedelta(days=custom_interval_days)
        else:  # Should not be reached
            break

        # Safety break: if date hasn't changed, interval is zero or logic error.
        if next_due == current_base_for_recalc:
            # Log error ideally, then advance by one day to prevent infinite loop.
            next_due += timedelta(days=1)
            break

    return next_due
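Note: a small worked example of the month-end clamping above; the dates are illustrative.

from datetime import date
from app.models import ChoreFrequencyEnum
from app.core.chore_utils import calculate_next_due_date

# Monthly chore last due Jan 31: the month arithmetic yields month=2, and the
# day is clamped to February's last day, so an invalid "Feb 31" is never built.
# If the result is already in the past relative to today, the catch-up loop
# rolls it forward one period at a time.
print(calculate_next_due_date(date(2025, 1, 31), ChoreFrequencyEnum.monthly))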
357  be/app/core/exceptions.py  Normal file
@@ -0,0 +1,357 @@
from fastapi import HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.config import settings
from typing import Optional

class ListNotFoundError(HTTPException):
    """Raised when a list is not found."""
    def __init__(self, list_id: int):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"List {list_id} not found"
        )

class ListPermissionError(HTTPException):
    """Raised when a user doesn't have permission to access a list."""
    def __init__(self, list_id: int, action: str = "access"):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"You do not have permission to {action} list {list_id}"
        )

class ListCreatorRequiredError(HTTPException):
    """Raised when an action requires the list creator but the user is not the creator."""
    def __init__(self, list_id: int, action: str):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Only the list creator can {action} list {list_id}"
        )

class GroupNotFoundError(HTTPException):
    """Raised when a group is not found."""
    def __init__(self, group_id: int):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Group {group_id} not found"
        )

class GroupPermissionError(HTTPException):
    """Raised when a user doesn't have permission to perform an action in a group."""
    def __init__(self, group_id: int, action: str):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"You do not have permission to {action} in group {group_id}"
        )

class GroupMembershipError(HTTPException):
    """Raised when a user attempts to perform an action that requires group membership."""
    def __init__(self, group_id: int, action: str = "access"):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"You must be a member of group {group_id} to {action}"
        )

class GroupOperationError(HTTPException):
    """Raised when a group operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class GroupValidationError(HTTPException):
    """Raised when a group operation is invalid."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=detail
        )

class ItemNotFoundError(HTTPException):
    """Raised when an item is not found."""
    def __init__(self, item_id: int):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Item {item_id} not found"
        )

class UserNotFoundError(HTTPException):
    """Raised when a user is not found."""
    def __init__(self, user_id: Optional[int] = None, identifier: Optional[str] = None):
        detail_msg = "User not found."
        if user_id:
            detail_msg = f"User with ID {user_id} not found."
        elif identifier:
            detail_msg = f"User with identifier '{identifier}' not found."
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=detail_msg
        )

class InvalidOperationError(HTTPException):
    """Raised when an operation is invalid or disallowed by business logic."""
    def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
        super().__init__(
            status_code=status_code,
            detail=detail
        )

class DatabaseConnectionError(HTTPException):
    """Raised when there is an error connecting to the database."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=settings.DB_CONNECTION_ERROR
        )

class DatabaseIntegrityError(HTTPException):
    """Raised when a database integrity constraint is violated."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=settings.DB_INTEGRITY_ERROR
        )

class DatabaseTransactionError(HTTPException):
    """Raised when a database transaction fails."""
    def __init__(self, detail: str = settings.DB_TRANSACTION_ERROR):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class DatabaseQueryError(HTTPException):
    """Raised when a database query fails."""
    def __init__(self, detail: str = settings.DB_QUERY_ERROR):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class ExpenseOperationError(HTTPException):
    """Raised when an expense operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class OCRServiceUnavailableError(HTTPException):
    """Raised when the OCR service is unavailable."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=settings.OCR_SERVICE_UNAVAILABLE
        )

class OCRServiceConfigError(HTTPException):
    """Raised when there is an error in the OCR service configuration."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=settings.OCR_SERVICE_CONFIG_ERROR
        )

class OCRUnexpectedError(HTTPException):
    """Raised when there is an unexpected error in the OCR service."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=settings.OCR_UNEXPECTED_ERROR
        )

class OCRQuotaExceededError(HTTPException):
    """Raised when the OCR service quota is exceeded."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=settings.OCR_QUOTA_EXCEEDED
        )

class InvalidFileTypeError(HTTPException):
    """Raised when an invalid file type is uploaded for OCR."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
        )

class FileTooLargeError(HTTPException):
    """Raised when an uploaded file exceeds the size limit."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
        )

class OCRProcessingError(HTTPException):
    """Raised when there is an error processing the image with OCR."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=settings.OCR_PROCESSING_ERROR.format(detail=detail)
        )

class EmailAlreadyRegisteredError(HTTPException):
    """Raised when attempting to register with an email that is already in use."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered."
        )

class UserCreationError(HTTPException):
    """Raised when there is an error creating a new user."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred during user creation."
        )

class InviteNotFoundError(HTTPException):
    """Raised when an invite is not found."""
    def __init__(self, invite_code: str):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Invite code {invite_code} not found"
        )

class InviteExpiredError(HTTPException):
    """Raised when an invite has expired."""
    def __init__(self, invite_code: str):
        super().__init__(
            status_code=status.HTTP_410_GONE,
            detail=f"Invite code {invite_code} has expired"
        )

class InviteAlreadyUsedError(HTTPException):
    """Raised when an invite has already been used."""
    def __init__(self, invite_code: str):
        super().__init__(
            status_code=status.HTTP_410_GONE,
            detail=f"Invite code {invite_code} has already been used"
        )

class InviteCreationError(HTTPException):
    """Raised when an invite cannot be created."""
    def __init__(self, group_id: int):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create invite for group {group_id}"
        )

class ListStatusNotFoundError(HTTPException):
    """Raised when a list's status cannot be retrieved."""
    def __init__(self, list_id: int):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Status for list {list_id} not found"
        )

class InviteOperationError(HTTPException):
    """Raised when an invite operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class SettlementOperationError(HTTPException):
    """Raised when a settlement operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class ConflictError(HTTPException):
    """Raised when an optimistic lock version conflict occurs."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_409_CONFLICT,
            detail=detail
        )

class InvalidCredentialsError(HTTPException):
    """Raised when login credentials are invalid."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_INVALID_CREDENTIALS,
            headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_credentials\""}
        )

class NotAuthenticatedError(HTTPException):
    """Raised when the user is not authenticated."""
    def __init__(self):
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_NOT_AUTHENTICATED,
            headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"not_authenticated\""}
        )

class JWTError(HTTPException):
    """Raised when there is an error with the JWT token."""
    def __init__(self, error: str):
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_JWT_ERROR.format(error=error),
            headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
        )

class JWTUnexpectedError(HTTPException):
    """Raised when there is an unexpected error with the JWT token."""
    def __init__(self, error: str):
        super().__init__(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=settings.AUTH_JWT_UNEXPECTED_ERROR.format(error=error),
            headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
        )

class ListOperationError(HTTPException):
    """Raised when a list operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class ItemOperationError(HTTPException):
    """Raised when an item operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class UserOperationError(HTTPException):
    """Raised when a user operation fails."""
    def __init__(self, detail: str):
        super().__init__(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=detail
        )

class ChoreNotFoundError(HTTPException):
    """Raised when a chore is not found."""
    def __init__(self, chore_id: int, group_id: Optional[int] = None, detail: Optional[str] = None):
        if detail:
            error_detail = detail
        elif group_id is not None:
            error_detail = f"Chore {chore_id} not found in group {group_id}"
        else:
            error_detail = f"Chore {chore_id} not found"
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=error_detail
        )

class PermissionDeniedError(HTTPException):
    """Raised when a user is denied permission for an action."""
    def __init__(self, detail: str = "Permission denied."):
        super().__init__(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=detail
        )

# Financials & Cost Splitting specific errors
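Note: because every class above subclasses HTTPException, raising one inside a route is enough to produce the mapped status code. A minimal sketch; the endpoint and lookup helper are illustrative, not part of this diff.

from fastapi import FastAPI
from app.core.exceptions import ListNotFoundError

app = FastAPI()

async def fetch_list(list_id: int):
    return None  # hypothetical lookup that misses

@app.get("/lists/{list_id}")
async def read_list(list_id: int):
    db_list = await fetch_list(list_id)
    if db_list is None:
        # FastAPI renders this as a 404 with body {"detail": "List 42 not found"}
        raise ListNotFoundError(list_id)
    return db_list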
@@ -4,8 +4,14 @@ from typing import List
 import google.generativeai as genai
 from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
 from google.api_core import exceptions as google_exceptions
 
 from app.config import settings
+from app.core.exceptions import (
+    OCRServiceUnavailableError,
+    OCRServiceConfigError,
+    OCRUnexpectedError,
+    OCRQuotaExceededError,
+    OCRProcessingError
+)
 
 logger = logging.getLogger(__name__)
 
@@ -19,26 +25,12 @@ try:
         genai.configure(api_key=settings.GEMINI_API_KEY)
         # Initialize the specific model we want to use
         gemini_flash_client = genai.GenerativeModel(
-            model_name="gemini-2.0-flash",
-            # Optional: Add default safety settings
-            # Adjust these based on your expected content and risk tolerance
-            safety_settings={
-                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-            },
-            # Optional: Add default generation config (can be overridden per request)
-            # generation_config=genai.types.GenerationConfig(
-            #     # candidate_count=1, # Usually default is 1
-            #     # stop_sequences=["\n"],
-            #     # max_output_tokens=2048,
-            #     # temperature=0.9, # Controls randomness (0=deterministic, >1=more random)
-            #     # top_p=1,
-            #     # top_k=1
-            # )
+            model_name=settings.GEMINI_MODEL_NAME,
+            generation_config=genai.types.GenerationConfig(
+                **settings.GEMINI_GENERATION_CONFIG
+            )
         )
-        logger.info("Gemini AI client initialized successfully for model 'gemini-1.5-flash-latest'.")
+        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
     else:
         # Store error if API key is missing
         gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
@@ -58,10 +50,10 @@ def get_gemini_client():
     Raises an exception if initialization failed.
     """
     if gemini_initialization_error:
-        raise RuntimeError(f"Gemini client could not be initialized: {gemini_initialization_error}")
+        raise OCRServiceConfigError()
    if gemini_flash_client is None:
         # This case should ideally be covered by the check above, but as a safeguard:
-        raise RuntimeError("Gemini client is not available (unknown initialization issue).")
+        raise OCRServiceConfigError()
     return gemini_flash_client
 
 # Define the prompt as a constant
@@ -79,37 +71,41 @@ Apples
 Organic Bananas
 """
 
-async def extract_items_from_image_gemini(image_bytes: bytes) -> List[str]:
+async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
     """
     Uses Gemini Flash to extract shopping list items from image bytes.
 
     Args:
         image_bytes: The image content as bytes.
+        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
 
     Returns:
         A list of extracted item strings.
 
     Raises:
-        RuntimeError: If the Gemini client is not initialized.
-        google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
-        ValueError: If the response is blocked or contains no usable text.
+        OCRServiceConfigError: If the Gemini client is not initialized.
+        OCRQuotaExceededError: If API quota is exceeded.
+        OCRServiceUnavailableError: For general API call errors.
+        OCRProcessingError: If the response is blocked or contains no usable text.
+        OCRUnexpectedError: For any other unexpected errors.
     """
-    client = get_gemini_client() # Raises RuntimeError if not initialized
-
-    # Prepare image part for multimodal input
-    image_part = {
-        "mime_type": "image/jpeg", # Or image/png, image/webp etc. Adjust if needed or detect mime type
-        "data": image_bytes
-    }
-
-    # Prepare the full prompt content
-    prompt_parts = [
-        OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
-        image_part # Then the image
-    ]
-
-    logger.info("Sending image to Gemini for item extraction...")
     try:
+        client = get_gemini_client() # Raises OCRServiceConfigError if not initialized
+
+        # Prepare image part for multimodal input
+        image_part = {
+            "mime_type": mime_type,
+            "data": image_bytes
+        }
+
+        # Prepare the full prompt content
+        prompt_parts = [
+            settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first
+            image_part # Then the image
+        ]
+
+        logger.info("Sending image to Gemini for item extraction...")
+
         # Make the API call
         # Use generate_content_async for async FastAPI
         response = await client.generate_content_async(prompt_parts)
@@ -122,9 +118,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes) -> List[str]:
             finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN'
             safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A'
             if finish_reason == 'SAFETY':
-                raise ValueError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
+                raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
             else:
-                raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")
+                raise OCRUnexpectedError()
 
         # Extract text - assumes the first part of the first candidate is the text response
         raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text
@@ -145,10 +141,89 @@ async def extract_items_from_image_gemini(image_bytes: bytes) -> List[str]:
 
     except google_exceptions.GoogleAPIError as e:
         logger.error(f"Gemini API Error: {e}", exc_info=True)
-        # Re-raise specific Google API errors for endpoint to handle (e.g., quota)
-        raise e
+        if "quota" in str(e).lower():
+            raise OCRQuotaExceededError()
+        raise OCRServiceUnavailableError()
+    except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
+        # Re-raise specific OCR exceptions
+        raise
     except Exception as e:
         # Catch other unexpected errors during generation or processing
         logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
-        # Wrap in a generic ValueError or re-raise
-        raise ValueError(f"Failed to process image with Gemini: {e}") from e
+        # Wrap in a custom exception
+        raise OCRUnexpectedError()
+
+
+class GeminiOCRService:
+    def __init__(self):
+        try:
+            genai.configure(api_key=settings.GEMINI_API_KEY)
+            self.model = genai.GenerativeModel(
+                model_name=settings.GEMINI_MODEL_NAME,
+                generation_config=genai.types.GenerationConfig(
+                    **settings.GEMINI_GENERATION_CONFIG
+                )
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize Gemini client: {e}")
+            raise OCRServiceConfigError()
+
+    async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
+        """
+        Extract shopping list items from an image using Gemini Vision.
+
+        Args:
+            image_data: The image content as bytes.
+            mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").
+
+        Returns:
+            A list of extracted item strings.
+
+        Raises:
+            OCRServiceConfigError: If the Gemini client is not initialized.
+            OCRQuotaExceededError: If API quota is exceeded.
+            OCRServiceUnavailableError: For general API call errors.
+            OCRProcessingError: If the response is blocked or contains no usable text.
+            OCRUnexpectedError: For any other unexpected errors.
+        """
+        try:
+            # Create image part
+            image_parts = [{"mime_type": mime_type, "data": image_data}]
+
+            # Generate content
+            response = await self.model.generate_content_async(
+                contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
+            )
+
+            # Process response
+            if not response.text:
+                logger.warning("Gemini response is empty")
+                raise OCRUnexpectedError()
+
+            # Check for safety blocks
+            if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'):
+                finish_reason = response.candidates[0].finish_reason
+                if finish_reason == 'SAFETY':
+                    safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A'
+                    raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}")
+
+            # Split response into lines and clean up
+            items = []
+            for line in response.text.splitlines():
+                cleaned_line = line.strip()
+                if cleaned_line and len(cleaned_line) > 1 and not cleaned_line.startswith("Example"):
+                    items.append(cleaned_line)
+
+            logger.info(f"Extracted {len(items)} potential items.")
+            return items
+
+        except google_exceptions.GoogleAPIError as e:
+            logger.error(f"Error during OCR extraction: {e}")
+            if "quota" in str(e).lower():
+                raise OCRQuotaExceededError()
+            raise OCRServiceUnavailableError()
+        except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
+            # Re-raise specific OCR exceptions
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
+            raise OCRUnexpectedError()
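Note: a sketch of driving the new service from a standalone script; the file name is an assumption, and a FastAPI endpoint would pass UploadFile bytes instead.

import asyncio
from app.core.gemini import GeminiOCRService

async def main():
    service = GeminiOCRService()  # raises OCRServiceConfigError without a valid key
    with open("receipt.jpg", "rb") as f:  # hypothetical local image
        items = await service.extract_items(f.read(), mime_type="image/jpeg")
    print(items)

asyncio.run(main())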
73  be/app/core/scheduler.py  Normal file
@@ -0,0 +1,73 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.triggers.cron import CronTrigger
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from app.config import settings
from app.jobs.recurring_expenses import generate_recurring_expenses
from app.db.session import async_session
import logging

logger = logging.getLogger(__name__)

# Convert async database URL to sync URL for APScheduler
# Replace postgresql+asyncpg:// with postgresql://
sync_db_url = settings.DATABASE_URL.replace('postgresql+asyncpg://', 'postgresql://')

# Configure the scheduler
jobstores = {
    'default': SQLAlchemyJobStore(url=sync_db_url)
}

executors = {
    'default': ThreadPoolExecutor(20)
}

job_defaults = {
    'coalesce': False,
    'max_instances': 1
}

scheduler = AsyncIOScheduler(
    jobstores=jobstores,
    executors=executors,
    job_defaults=job_defaults,
    timezone='UTC'
)

async def run_recurring_expenses_job():
    """Wrapper function to run the recurring expenses job with a database session."""
    try:
        async with async_session() as session:
            await generate_recurring_expenses(session)
    except Exception as e:
        logger.error(f"Error running recurring expenses job: {str(e)}")
        raise

def init_scheduler():
    """Initialize and start the scheduler."""
    try:
        # Add the recurring expenses job
        scheduler.add_job(
            run_recurring_expenses_job,
            trigger=CronTrigger(hour=0, minute=0),  # Run at midnight UTC
            id='generate_recurring_expenses',
            name='Generate Recurring Expenses',
            replace_existing=True
        )

        # Start the scheduler
        scheduler.start()
        logger.info("Scheduler started successfully")
    except Exception as e:
        logger.error(f"Error initializing scheduler: {str(e)}")
        raise

def shutdown_scheduler():
    """Shutdown the scheduler gracefully."""
    try:
        scheduler.shutdown()
        logger.info("Scheduler shut down successfully")
    except Exception as e:
        logger.error(f"Error shutting down scheduler: {str(e)}")
        raise
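Note: init_scheduler and shutdown_scheduler are meant to be called from the app entrypoint, which is not part of this diff; a lifespan wiring sketch under that assumption:

from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.scheduler import init_scheduler, shutdown_scheduler

@asynccontextmanager
async def lifespan(app: FastAPI):
    init_scheduler()      # start APScheduler alongside the app
    yield
    shutdown_scheduler()  # stop it cleanly on shutdown

app = FastAPI(lifespan=lifespan)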
@@ -8,6 +8,9 @@ from passlib.context import CryptContext
 from app.config import settings # Import settings from config
 
 # --- Password Hashing ---
+# These functions are used for password hashing and verification
+# They complement FastAPI-Users but provide direct access to the underlying password functionality
+# when needed outside of the FastAPI-Users authentication flow.
 
 # Configure passlib context
 # Using bcrypt as the default hashing scheme
@@ -17,6 +20,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 def verify_password(plain_password: str, hashed_password: str) -> bool:
     """
     Verifies a plain text password against a hashed password.
+    This is used by FastAPI-Users internally, but also exposed here for custom authentication flows
+    if needed.
 
     Args:
         plain_password: The password attempt.
@@ -34,6 +39,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
 def hash_password(password: str) -> str:
     """
     Hashes a plain text password using the configured context (bcrypt).
+    This is used by FastAPI-Users internally, but also exposed here for
+    custom user creation or password reset flows if needed.
 
     Args:
         password: The plain text password to hash.
@@ -45,66 +52,22 @@ def hash_password(password: str) -> str:
 
 
 # --- JSON Web Tokens (JWT) ---
+# FastAPI-Users now handles all JWT token creation and validation.
+# The code below is commented out because FastAPI-Users provides these features.
+# It's kept for reference in case a custom implementation is needed later.
 
-def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
-    """
-    Creates a JWT access token.
-
-    Args:
-        subject: The subject of the token (e.g., user ID or email).
-        expires_delta: Optional timedelta object for token expiry. If None,
-                       uses ACCESS_TOKEN_EXPIRE_MINUTES from settings.
-
-    Returns:
-        The encoded JWT access token string.
-    """
-    if expires_delta:
-        expire = datetime.now(timezone.utc) + expires_delta
-    else:
-        expire = datetime.now(timezone.utc) + timedelta(
-            minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
-        )
-
-    # Data to encode in the token payload
-    to_encode = {"exp": expire, "sub": str(subject)}
-
-    encoded_jwt = jwt.encode(
-        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
-    )
-    return encoded_jwt
-
-
-def verify_access_token(token: str) -> Optional[dict]:
-    """
-    Verifies a JWT access token and returns its payload if valid.
-
-    Args:
-        token: The JWT token string to verify.
-
-    Returns:
-        The decoded token payload (dict) if the token is valid and not expired,
-        otherwise None.
-    """
-    try:
-        # Decode the token. This also automatically verifies:
-        # - Signature (using SECRET_KEY and ALGORITHM)
-        # - Expiration ('exp' claim)
-        payload = jwt.decode(
-            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
-        )
-        return payload
-    except JWTError as e:
-        # Handles InvalidSignatureError, ExpiredSignatureError, etc.
-        print(f"JWT Error: {e}") # Log the error for debugging
-        return None
-    except Exception as e:
-        # Handle other potential unexpected errors during decoding
-        print(f"Unexpected error decoding JWT: {e}")
-        return None
-
-
-# You might add a function here later to extract the 'sub' (subject/user id)
-# specifically, often used in dependency injection for authentication.
+# Example of a potential future implementation:
 # def get_subject_from_token(token: str) -> Optional[str]:
-#     payload = verify_access_token(token)
-#     if payload:
+#     """
+#     Extract the subject (user ID) from a JWT token.
+#     This would be used if we need to validate tokens outside of FastAPI-Users flow.
+#     For now, use fastapi_users.current_user dependency instead.
+#     """
+#     # This would need to use FastAPI-Users' token verification if ever implemented
+#     # For example, by decoding the token using the strategy from the auth backend
+#     try:
+#         payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
 #         return payload.get("sub")
+#     except JWTError:
+#         return None
 #     return None
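Note: the password helpers kept above remain directly usable outside the FastAPI-Users flow; a quick sketch with illustrative values:

from app.core.security import hash_password, verify_password

hashed = hash_password("s3cret!")           # bcrypt, salted per call
assert verify_password("s3cret!", hashed)   # True
assert not verify_password("wrong", hashed) # False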
@@ -1,83 +0,0 @@
# be/tests/core/test_gemini.py
import pytest
import os
from unittest.mock import patch, MagicMock

# Store original key if exists, then clear it for testing missing key scenario
original_api_key = os.environ.get("GEMINI_API_KEY")
if "GEMINI_API_KEY" in os.environ:
    del os.environ["GEMINI_API_KEY"]

# --- Test Module Import ---
# This forces the module-level initialization code in gemini.py to run
# We need to reload modules because settings might have been cached
from importlib import reload
from app.config import settings as app_settings
from app.core import gemini as gemini_core

# Reload settings first to ensure GEMINI_API_KEY is None initially
reload(app_settings)
# Reload gemini core to trigger initialization logic with potentially missing key
reload(gemini_core)


def test_gemini_initialization_without_key():
    """Verify behavior when GEMINI_API_KEY is not set."""
    # Reload modules again to ensure clean state for this specific test
    if "GEMINI_API_KEY" in os.environ:
        del os.environ["GEMINI_API_KEY"]
    reload(app_settings)
    reload(gemini_core)

    assert gemini_core.gemini_flash_client is None
    assert gemini_core.gemini_initialization_error is not None
    assert "GEMINI_API_KEY not configured" in gemini_core.gemini_initialization_error

    with pytest.raises(RuntimeError, match="GEMINI_API_KEY not configured"):
        gemini_core.get_gemini_client()


@patch('google.generativeai.configure')
@patch('google.generativeai.GenerativeModel')
def test_gemini_initialization_with_key(mock_generative_model: MagicMock, mock_configure: MagicMock):
    """Verify initialization logic is called when key is present (using mocks)."""
    # Set a dummy key in the environment for this test
    test_key = "TEST_API_KEY_123"
    os.environ["GEMINI_API_KEY"] = test_key

    # Reload settings and gemini module to pick up the new key
    reload(app_settings)
    reload(gemini_core)

    # Assertions
    mock_configure.assert_called_once_with(api_key=test_key)
    mock_generative_model.assert_called_once_with(
        model_name="gemini-1.5-flash-latest",
        safety_settings=pytest.ANY, # Check safety settings were passed (ANY allows flexibility)
        # generation_config=pytest.ANY # Check if you added default generation config
    )
    assert gemini_core.gemini_flash_client is not None
    assert gemini_core.gemini_initialization_error is None

    # Test get_gemini_client() success path
    client = gemini_core.get_gemini_client()
    assert client is not None # Should return the mocked client instance

    # Clean up environment variable after test
    if original_api_key:
        os.environ["GEMINI_API_KEY"] = original_api_key
    else:
        if "GEMINI_API_KEY" in os.environ:
            del os.environ["GEMINI_API_KEY"]
    # Reload modules one last time to restore state for other tests
    reload(app_settings)
    reload(gemini_core)


# Restore original key after all tests in the module run (if needed)
def teardown_module(module):
    if original_api_key:
        os.environ["GEMINI_API_KEY"] = original_api_key
    else:
        if "GEMINI_API_KEY" in os.environ:
            del os.environ["GEMINI_API_KEY"]
    reload(app_settings)
    reload(gemini_core)
@@ -1,86 +0,0 @@
# Example: be/tests/core/test_security.py
import pytest
from datetime import timedelta
from jose import jwt, JWTError
import time

from app.core.security import (
    hash_password,
    verify_password,
    create_access_token,
    verify_access_token,
)
from app.config import settings # Import settings for testing JWT config

# --- Password Hashing Tests ---

def test_hash_password_returns_string():
    password = "testpassword"
    hashed = hash_password(password)
    assert isinstance(hashed, str)
    assert password != hashed # Ensure it's not plain text

def test_verify_password_correct():
    password = "correct_password"
    hashed = hash_password(password)
    assert verify_password(password, hashed) is True

def test_verify_password_incorrect():
    hashed = hash_password("correct_password")
    assert verify_password("wrong_password", hashed) is False

def test_verify_password_invalid_hash_format():
    # Passlib's verify handles many format errors gracefully
    assert verify_password("any_password", "invalid_hash_string") is False


# --- JWT Tests ---

def test_create_access_token():
    subject = "testuser@example.com"
    token = create_access_token(subject=subject)
    assert isinstance(token, str)

    # Decode manually for basic check (verification done in verify_access_token tests)
    payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    assert payload["sub"] == subject
    assert "exp" in payload
    assert isinstance(payload["exp"], int)

def test_verify_access_token_valid():
    subject = "test_subject_valid"
    token = create_access_token(subject=subject)
    payload = verify_access_token(token)
    assert payload is not None
    assert payload["sub"] == subject

def test_verify_access_token_invalid_signature():
    subject = "test_subject_invalid_sig"
    token = create_access_token(subject=subject)
    # Attempt to verify with a wrong key
    wrong_key = settings.SECRET_KEY + "wrong"
    with pytest.raises(JWTError): # Decoding with wrong key should raise JWTError internally
        jwt.decode(token, wrong_key, algorithms=[settings.ALGORITHM])
    # Our verify function should catch this and return None
    assert verify_access_token(token + "tamper") is None # Tampering token often invalidates sig
    # Note: Testing verify_access_token directly returning None for wrong key is tricky
    # as the error happens *during* jwt.decode. We rely on it catching JWTError.

def test_verify_access_token_expired():
    # Create a token that expires almost immediately
    subject = "test_subject_expired"
    expires_delta = timedelta(seconds=-1) # Expired 1 second ago
    token = create_access_token(subject=subject, expires_delta=expires_delta)

    # Wait briefly just in case of timing issues, though negative delta should guarantee expiry
    time.sleep(0.1)

    # Decoding expired token raises ExpiredSignatureError internally
    with pytest.raises(JWTError): # Specifically ExpiredSignatureError, but JWTError catches it
        jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])

    # Our verify function should catch this and return None
    assert verify_access_token(token) is None

def test_verify_access_token_malformed():
    assert verify_access_token("this.is.not.a.valid.token") is None
505  be/app/crud/chore.py  Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
from sqlalchemy import union_all
|
||||||
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
from datetime import date, datetime
|
||||||
|
|
||||||
|
from app.models import Chore, Group, User, ChoreAssignment, ChoreFrequencyEnum, ChoreTypeEnum, UserGroup
|
||||||
|
from app.schemas.chore import ChoreCreate, ChoreUpdate, ChoreAssignmentCreate, ChoreAssignmentUpdate
|
||||||
|
from app.core.chore_utils import calculate_next_due_date
|
||||||
|
from app.crud.group import get_group_by_id, is_user_member
|
||||||
|
from app.core.exceptions import ChoreNotFoundError, GroupNotFoundError, PermissionDeniedError, DatabaseIntegrityError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def get_all_user_chores(db: AsyncSession, user_id: int) -> List[Chore]:
|
||||||
|
"""Gets all chores (personal and group) for a user in optimized queries."""
|
||||||
|
|
||||||
|
# Get personal chores query
|
||||||
|
personal_chores_query = (
|
||||||
|
select(Chore)
|
||||||
|
.where(
|
||||||
|
Chore.created_by_id == user_id,
|
||||||
|
Chore.type == ChoreTypeEnum.personal
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user's group IDs first
|
||||||
|
user_groups_result = await db.execute(
|
||||||
|
select(UserGroup.group_id).where(UserGroup.user_id == user_id)
|
||||||
|
)
|
||||||
|
user_group_ids = user_groups_result.scalars().all()
|
||||||
|
|
||||||
|
all_chores = []
|
||||||
|
|
||||||
|
# Execute personal chores query
|
||||||
|
personal_result = await db.execute(
|
||||||
|
personal_chores_query
|
||||||
|
.options(
|
||||||
|
selectinload(Chore.creator),
|
||||||
|
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
|
||||||
|
)
|
||||||
|
.order_by(Chore.next_due_date, Chore.name)
|
||||||
|
)
|
||||||
|
all_chores.extend(personal_result.scalars().all())
|
||||||
|
|
||||||
|
# If user has groups, get all group chores in one query
|
||||||
|
if user_group_ids:
|
||||||
|
group_chores_result = await db.execute(
|
||||||
|
select(Chore)
|
||||||
|
.where(
|
||||||
|
Chore.group_id.in_(user_group_ids),
|
||||||
|
Chore.type == ChoreTypeEnum.group
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
selectinload(Chore.creator),
|
||||||
|
selectinload(Chore.group),
|
||||||
|
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
|
||||||
|
)
|
||||||
|
.order_by(Chore.next_due_date, Chore.name)
|
||||||
|
)
|
||||||
|
all_chores.extend(group_chores_result.scalars().all())
|
||||||
|
|
||||||
|
return all_chores
|
||||||
|
|
||||||
|
async def create_chore(
|
||||||
|
db: AsyncSession,
|
||||||
|
chore_in: ChoreCreate,
|
||||||
|
user_id: int,
|
||||||
|
group_id: Optional[int] = None
|
||||||
|
) -> Chore:
|
||||||
|
"""Creates a new chore, either personal or within a specific group."""
|
||||||
|
# Use the transaction pattern from the FastAPI strategy
|
||||||
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
|
if chore_in.type == ChoreTypeEnum.group:
|
||||||
|
if not group_id:
|
||||||
|
raise ValueError("group_id is required for group chores")
|
||||||
|
# Validate group existence and user membership
|
||||||
|
group = await get_group_by_id(db, group_id)
|
||||||
|
if not group:
|
||||||
|
raise GroupNotFoundError(group_id)
|
||||||
|
if not await is_user_member(db, group_id, user_id):
|
||||||
|
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
|
||||||
|
else: # personal chore
|
||||||
|
if group_id:
|
||||||
|
raise ValueError("group_id must be None for personal chores")
|
||||||
|
|
||||||
|
db_chore = Chore(
|
||||||
|
**chore_in.model_dump(exclude_unset=True, exclude={'group_id'}),
|
||||||
|
group_id=group_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Specific check for custom frequency
|
||||||
|
if chore_in.frequency == ChoreFrequencyEnum.custom and chore_in.custom_interval_days is None:
|
||||||
|
raise ValueError("custom_interval_days must be set for custom frequency chores.")
|
||||||
|
|
||||||
|
db.add(db_chore)
|
||||||
|
await db.flush() # Get the ID for the chore
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load relationships for the response with eager loading
|
||||||
|
result = await db.execute(
|
||||||
|
select(Chore)
|
||||||
|
.where(Chore.id == db_chore.id)
|
||||||
|
.options(
|
||||||
|
selectinload(Chore.creator),
|
||||||
|
selectinload(Chore.group),
|
||||||
|
selectinload(Chore.assignments)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating chore: {e}", exc_info=True)
|
||||||
|
raise DatabaseIntegrityError(f"Could not create chore. Error: {str(e)}")
|
||||||
|
|
||||||
|
async def get_chore_by_id(db: AsyncSession, chore_id: int) -> Optional[Chore]:
|
||||||
|
"""Gets a chore by ID."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Chore)
|
||||||
|
.where(Chore.id == chore_id)
|
||||||
|
.options(selectinload(Chore.creator), selectinload(Chore.group), selectinload(Chore.assignments))
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_chore_by_id_and_group(
|
||||||
|
db: AsyncSession,
|
||||||
|
chore_id: int,
|
||||||
|
group_id: int,
|
||||||
|
user_id: int
|
||||||
|
) -> Optional[Chore]:
|
||||||
|
"""Gets a specific group chore by ID, ensuring it belongs to the group and user is a member."""
|
||||||
|
if not await is_user_member(db, group_id, user_id):
|
||||||
|
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
|
||||||
|
|
||||||
|
chore = await get_chore_by_id(db, chore_id)
|
||||||
|
if chore and chore.group_id == group_id and chore.type == ChoreTypeEnum.group:
|
||||||
|
return chore
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_personal_chores(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: int
|
||||||
|
) -> List[Chore]:
|
||||||
|
"""Gets all personal chores for a user with optimized eager loading."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Chore)
|
||||||
|
.where(
|
||||||
|
Chore.created_by_id == user_id,
|
||||||
|
Chore.type == ChoreTypeEnum.personal
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
selectinload(Chore.creator),
|
||||||
|
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
|
||||||
|
)
|
||||||
|
.order_by(Chore.next_due_date, Chore.name)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_chores_by_group_id(
|
||||||
|
db: AsyncSession,
|
||||||
|
group_id: int,
|
||||||
|
user_id: int
|
||||||
|
) -> List[Chore]:
|
||||||
|
"""Gets all chores for a specific group with optimized eager loading, if the user is a member."""
|
||||||
|
if not await is_user_member(db, group_id, user_id):
|
||||||
|
raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Chore)
|
||||||
|
.where(
|
||||||
|
Chore.group_id == group_id,
|
||||||
|
Chore.type == ChoreTypeEnum.group
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
selectinload(Chore.creator),
|
||||||
|
selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
|
||||||
|
)
|
||||||
|
.order_by(Chore.next_due_date, Chore.name)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def update_chore(
    db: AsyncSession,
    chore_id: int,
    chore_in: ChoreUpdate,
    user_id: int,
    group_id: Optional[int] = None
) -> Optional[Chore]:
    """Updates a chore's details using proper transaction management."""
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db_chore = await get_chore_by_id(db, chore_id)
        if not db_chore:
            raise ChoreNotFoundError(chore_id, group_id)

        # Check permissions
        if db_chore.type == ChoreTypeEnum.group:
            if not group_id:
                raise ValueError("group_id is required for group chores")
            if not await is_user_member(db, group_id, user_id):
                raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
            if db_chore.group_id != group_id:
                raise ChoreNotFoundError(chore_id, group_id)
        else:  # personal chore
            if group_id:
                raise ValueError("group_id must be None for personal chores")
            if db_chore.created_by_id != user_id:
                raise PermissionDeniedError(detail="Only the creator can update personal chores")

        update_data = chore_in.model_dump(exclude_unset=True)

        # Handle type change
        if 'type' in update_data:
            new_type = update_data['type']
            if new_type == ChoreTypeEnum.group and not group_id:
                raise ValueError("group_id is required for group chores")
            if new_type == ChoreTypeEnum.personal and group_id:
                raise ValueError("group_id must be None for personal chores")

        # Recalculate next_due_date if needed
        recalculate = False
        if 'frequency' in update_data and update_data['frequency'] != db_chore.frequency:
            recalculate = True
        if 'custom_interval_days' in update_data and update_data['custom_interval_days'] != db_chore.custom_interval_days:
            recalculate = True

        current_next_due_date_for_calc = db_chore.next_due_date
        if 'next_due_date' in update_data:
            current_next_due_date_for_calc = update_data['next_due_date']
            if not ('frequency' in update_data or 'custom_interval_days' in update_data):
                recalculate = False

        for field, value in update_data.items():
            setattr(db_chore, field, value)

        if recalculate:
            db_chore.next_due_date = calculate_next_due_date(
                current_due_date=current_next_due_date_for_calc,
                frequency=db_chore.frequency,
                custom_interval_days=db_chore.custom_interval_days,
                last_completed_date=db_chore.last_completed_at
            )

        if db_chore.frequency == ChoreFrequencyEnum.custom and db_chore.custom_interval_days is None:
            raise ValueError("custom_interval_days must be set for custom frequency chores.")

        try:
            await db.flush()  # Flush changes within the transaction
            result = await db.execute(
                select(Chore)
                .where(Chore.id == db_chore.id)
                .options(
                    selectinload(Chore.creator),
                    selectinload(Chore.group),
                    selectinload(Chore.assignments).selectinload(ChoreAssignment.assigned_user)
                )
            )
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error updating chore {chore_id}: {e}", exc_info=True)
            raise DatabaseIntegrityError(f"Could not update chore {chore_id}. Error: {str(e)}")
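The "async with db.begin_nested() if db.in_transaction() else db.begin()" guard above recurs throughout these CRUD modules and is what makes them composable: called from a request that already holds a transaction (for example via a get_transactional_session dependency), it only opens a SAVEPOINT, while a standalone caller gets a real transaction. The same idea factored into a reusable helper, shown as a sketch rather than something this codebase defines:

    from contextlib import asynccontextmanager
    from sqlalchemy.ext.asyncio import AsyncSession

    @asynccontextmanager
    async def transactional(db: AsyncSession):
        """Open a SAVEPOINT when already in a transaction, a real transaction otherwise."""
        if db.in_transaction():
            async with db.begin_nested():  # rolls back only this unit on error
                yield db
        else:
            async with db.begin():  # commits on clean exit, rolls back on exception
                yield db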
async def delete_chore(
    db: AsyncSession,
    chore_id: int,
    user_id: int,
    group_id: Optional[int] = None
) -> bool:
    """Deletes a chore and its assignments using proper transaction management, ensuring user has permission."""
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db_chore = await get_chore_by_id(db, chore_id)
        if not db_chore:
            raise ChoreNotFoundError(chore_id, group_id)

        # Check permissions
        if db_chore.type == ChoreTypeEnum.group:
            if not group_id:
                raise ValueError("group_id is required for group chores")
            if not await is_user_member(db, group_id, user_id):
                raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}")
            if db_chore.group_id != group_id:
                raise ChoreNotFoundError(chore_id, group_id)
        else:  # personal chore
            if group_id:
                raise ValueError("group_id must be None for personal chores")
            if db_chore.created_by_id != user_id:
                raise PermissionDeniedError(detail="Only the creator can delete personal chores")

        try:
            await db.delete(db_chore)
            await db.flush()  # Ensure deletion is processed within the transaction
            return True
        except Exception as e:
            logger.error(f"Error deleting chore {chore_id}: {e}", exc_info=True)
            raise DatabaseIntegrityError(f"Could not delete chore {chore_id}. Error: {str(e)}")


# === CHORE ASSIGNMENT CRUD FUNCTIONS ===

async def create_chore_assignment(
    db: AsyncSession,
    assignment_in: ChoreAssignmentCreate,
    user_id: int
) -> ChoreAssignment:
    """Creates a new chore assignment. User must be able to manage the chore."""
    async with db.begin_nested() if db.in_transaction() else db.begin():
        # Get the chore and validate permissions
        chore = await get_chore_by_id(db, assignment_in.chore_id)
        if not chore:
            raise ChoreNotFoundError(chore_id=assignment_in.chore_id)

        # Check permissions to assign this chore
        if chore.type == ChoreTypeEnum.personal:
            if chore.created_by_id != user_id:
                raise PermissionDeniedError(detail="Only the creator can assign personal chores")
        else:  # group chore
            if not await is_user_member(db, chore.group_id, user_id):
                raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")
            # For group chores, check if assignee is also a group member
            if not await is_user_member(db, chore.group_id, assignment_in.assigned_to_user_id):
                raise PermissionDeniedError(detail=f"Cannot assign chore to user {assignment_in.assigned_to_user_id} who is not a group member")

        db_assignment = ChoreAssignment(**assignment_in.model_dump(exclude_unset=True))
        db.add(db_assignment)
        await db.flush()  # Get the ID for the assignment

        try:
            # Load relationships for the response
            result = await db.execute(
                select(ChoreAssignment)
                .where(ChoreAssignment.id == db_assignment.id)
                .options(
                    selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
                    selectinload(ChoreAssignment.assigned_user)
                )
            )
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error creating chore assignment: {e}", exc_info=True)
            raise DatabaseIntegrityError(f"Could not create chore assignment. Error: {str(e)}")
async def get_chore_assignment_by_id(db: AsyncSession, assignment_id: int) -> Optional[ChoreAssignment]:
    """Gets a chore assignment by ID."""
    result = await db.execute(
        select(ChoreAssignment)
        .where(ChoreAssignment.id == assignment_id)
        .options(
            selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
            selectinload(ChoreAssignment.assigned_user)
        )
    )
    return result.scalar_one_or_none()


async def get_user_assignments(
    db: AsyncSession,
    user_id: int,
    include_completed: bool = False
) -> List[ChoreAssignment]:
    """Gets all chore assignments for a user."""
    query = select(ChoreAssignment).where(ChoreAssignment.assigned_to_user_id == user_id)

    if not include_completed:
        query = query.where(ChoreAssignment.is_complete == False)

    query = query.options(
        selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
        selectinload(ChoreAssignment.assigned_user)
    ).order_by(ChoreAssignment.due_date, ChoreAssignment.id)

    result = await db.execute(query)
    return result.scalars().all()
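A typical caller of get_user_assignments would be a FastAPI route in the chores router. The path, response schema, and auth dependency below are assumptions for illustration, not the project's actual endpoint (only get_transactional_session is mentioned elsewhere in this changeset):

    @router.get("/assignments/me", response_model=list[ChoreAssignmentPublic])  # schema name assumed
    async def list_my_assignments(
        include_completed: bool = False,
        db: AsyncSession = Depends(get_transactional_session),
        current_user: User = Depends(get_current_user),  # assumed auth dependency
    ):
        return await get_user_assignments(db, current_user.id, include_completed)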
async def get_chore_assignments(
    db: AsyncSession,
    chore_id: int,
    user_id: int
) -> List[ChoreAssignment]:
    """Gets all assignments for a specific chore. User must have permission to view the chore."""
    chore = await get_chore_by_id(db, chore_id)
    if not chore:
        raise ChoreNotFoundError(chore_id=chore_id)

    # Check permissions
    if chore.type == ChoreTypeEnum.personal:
        if chore.created_by_id != user_id:
            raise PermissionDeniedError(detail="Can only view assignments for own personal chores")
    else:  # group chore
        if not await is_user_member(db, chore.group_id, user_id):
            raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")

    result = await db.execute(
        select(ChoreAssignment)
        .where(ChoreAssignment.chore_id == chore_id)
        .options(
            selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
            selectinload(ChoreAssignment.assigned_user)
        )
        .order_by(ChoreAssignment.due_date, ChoreAssignment.id)
    )
    return result.scalars().all()


async def update_chore_assignment(
    db: AsyncSession,
    assignment_id: int,
    assignment_in: ChoreAssignmentUpdate,
    user_id: int
) -> Optional[ChoreAssignment]:
    """Updates a chore assignment. Only the assignee can mark it complete."""
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db_assignment = await get_chore_assignment_by_id(db, assignment_id)
        if not db_assignment:
            raise ChoreNotFoundError(assignment_id=assignment_id)

        # Load the chore for permission checking
        chore = await get_chore_by_id(db, db_assignment.chore_id)
        if not chore:
            raise ChoreNotFoundError(chore_id=db_assignment.chore_id)

        # Check permissions - only assignee can complete, but chore managers can reschedule
        can_manage = False
        if chore.type == ChoreTypeEnum.personal:
            can_manage = chore.created_by_id == user_id
        else:  # group chore
            can_manage = await is_user_member(db, chore.group_id, user_id)

        can_complete = db_assignment.assigned_to_user_id == user_id

        update_data = assignment_in.model_dump(exclude_unset=True)

        # Check specific permissions for different updates
        if 'is_complete' in update_data and not can_complete:
            raise PermissionDeniedError(detail="Only the assignee can mark assignments as complete")

        if 'due_date' in update_data and not can_manage:
            raise PermissionDeniedError(detail="Only chore managers can reschedule assignments")

        # Handle completion logic
        if 'is_complete' in update_data and update_data['is_complete']:
            if not db_assignment.is_complete:  # Only if not already complete
                update_data['completed_at'] = datetime.utcnow()

                # Update parent chore's last_completed_at and recalculate next_due_date
                chore.last_completed_at = update_data['completed_at']
                chore.next_due_date = calculate_next_due_date(
                    current_due_date=chore.next_due_date,
                    frequency=chore.frequency,
                    custom_interval_days=chore.custom_interval_days,
                    last_completed_date=chore.last_completed_at
                )
        elif 'is_complete' in update_data and not update_data['is_complete']:
            # If marking as incomplete, clear completed_at
            update_data['completed_at'] = None

        # Apply updates
        for field, value in update_data.items():
            setattr(db_assignment, field, value)

        try:
            await db.flush()  # Flush changes within the transaction

            # Load relationships for the response
            result = await db.execute(
                select(ChoreAssignment)
                .where(ChoreAssignment.id == db_assignment.id)
                .options(
                    selectinload(ChoreAssignment.chore).selectinload(Chore.creator),
                    selectinload(ChoreAssignment.assigned_user)
                )
            )
            return result.scalar_one()
        except Exception as e:
            logger.error(f"Error updating chore assignment {assignment_id}: {e}", exc_info=True)
            raise DatabaseIntegrityError(f"Could not update chore assignment {assignment_id}. Error: {str(e)}")
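calculate_next_due_date itself is not part of this diff. A minimal sketch of what such a helper could look like, assuming daily/weekly/monthly/custom frequencies and naive day arithmetic; the real scheduling rules may well differ:

    from datetime import datetime, timedelta
    from typing import Optional

    def calculate_next_due_date(
        current_due_date: datetime,
        frequency,  # ChoreFrequencyEnum member or its string value
        custom_interval_days: Optional[int] = None,
        last_completed_date: Optional[datetime] = None,
    ) -> datetime:
        # Advance from the completion date when there is one, else from the old due date.
        base = last_completed_date or current_due_date
        days = {"daily": 1, "weekly": 7, "monthly": 30}.get(
            getattr(frequency, "value", frequency)  # accept enum or raw string
        )
        if days is None:  # custom frequency
            if not custom_interval_days:
                raise ValueError("custom_interval_days required for custom frequency")
            days = custom_interval_days
        return base + timedelta(days=days)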
async def delete_chore_assignment(
    db: AsyncSession,
    assignment_id: int,
    user_id: int
) -> bool:
    """Deletes a chore assignment. User must have permission to manage the chore."""
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db_assignment = await get_chore_assignment_by_id(db, assignment_id)
        if not db_assignment:
            raise ChoreNotFoundError(assignment_id=assignment_id)

        # Load the chore for permission checking
        chore = await get_chore_by_id(db, db_assignment.chore_id)
        if not chore:
            raise ChoreNotFoundError(chore_id=db_assignment.chore_id)

        # Check permissions
        if chore.type == ChoreTypeEnum.personal:
            if chore.created_by_id != user_id:
                raise PermissionDeniedError(detail="Only the creator can delete personal chore assignments")
        else:  # group chore
            if not await is_user_member(db, chore.group_id, user_id):
                raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}")

        try:
            await db.delete(db_assignment)
            await db.flush()  # Ensure deletion is processed within the transaction
            return True
        except Exception as e:
            logger.error(f"Error deleting chore assignment {assignment_id}: {e}", exc_info=True)
            raise DatabaseIntegrityError(f"Could not delete chore assignment {assignment_id}. Error: {str(e)}")
be/app/crud/expense.py (new file, 699 lines)
@@ -0,0 +1,699 @@
# app/crud/expense.py
import logging  # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError  # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from collections import defaultdict  # defaultdict lives in collections, not typing
from typing import Callable, List as PyList, Optional, Sequence, Dict, Any
from datetime import datetime, timezone  # Added timezone

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel,
    ExpenseOverallStatusEnum,  # Added
    ExpenseSplitStatusEnum,  # Added
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate  # ExpenseUpdate is used by update_expense below
from app.core.exceptions import (
    # Using existing specific exceptions where possible
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,  # Import the new exception
    DatabaseConnectionError,  # Added
    DatabaseIntegrityError,  # Added
    DatabaseQueryError,  # Added
    DatabaseTransactionError,  # Added
    ExpenseOperationError  # Added specific exception
)
from app.models import RecurrencePattern

# Placeholder for InvalidOperationError if not defined in app.core.exceptions
# This should be a proper HTTPException subclass if used in API layer
# class CrudInvalidOperationError(ValueError):  # For internal CRUD validation logic
#     pass

logger = logging.getLogger(__name__)  # Initialize logger


def _round_money(amount: Decimal) -> Decimal:
    """Rounds a Decimal to two decimal places using ROUND_HALF_UP."""
    return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
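ROUND_HALF_UP is chosen deliberately here: Decimal's default context rounds half to even (banker's rounding), which is surprising for currency. For example:

    >>> from decimal import Decimal
    >>> _round_money(Decimal("10") / 3)
    Decimal('3.33')
    >>> _round_money(Decimal("2.665"))   # ROUND_HALF_EVEN would give Decimal('2.66')
    Decimal('2.67')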
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Determines the list of users an expense should be split amongst.
    Priority: Group members (if group_id), then List's group members or creator (if list_id).
    Fallback to only the payer if no other context yields users.
    """
    users_to_split_with: PyList[UserModel] = []
    processed_user_ids = set()

    async def _add_user(user: Optional[UserModel]):
        if user and user.id not in processed_user_ids:
            users_to_split_with.append(user)
            processed_user_ids.add(user.id)

    if expense_group_id:
        group_result = await db.execute(
            select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = group_result.scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for assoc in group.member_associations:
            await _add_user(assoc.user)

    elif expense_list_id:  # Only if group_id was not primary context
        list_result = await db.execute(
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
                selectinload(ListModel.creator)
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = list_result.scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)

        if db_list.group:
            for assoc in db_list.group.member_associations:
                await _add_user(assoc.user)
        elif db_list.creator:
            await _add_user(db_list.creator)

    if not users_to_split_with:
        payer_user = await db.get(UserModel, expense_paid_by_user_id)
        if not payer_user:
            # This should have been caught earlier if paid_by_user_id was validated before calling this helper
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        await _add_user(payer_user)

    if not users_to_split_with:
        # This should ideally not be reached if payer is always a fallback
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return users_to_split_with
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Creates a new expense and its associated splits.

    Args:
        db: Database session
        expense_in: Expense creation data
        current_user_id: ID of the user creating the expense

    Returns:
        The created expense with splits

    Raises:
        UserNotFoundError: If payer or split users don't exist
        ListNotFoundError: If specified list doesn't exist
        GroupNotFoundError: If specified group doesn't exist
        InvalidOperationError: For various validation failures
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Context resolution and validation (now part of the transaction)
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")

            final_group_id = await _resolve_expense_context(db, expense_in)
            # Further validation for item_id if provided
            db_item_instance = None
            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
                # Potentially link item's list/group if not already set on expense_in
                if db_item_instance.list_id and not expense_in.list_id:
                    expense_in.list_id = db_item_instance.list_id
                    # Re-resolve context if list_id was derived from item
                    final_group_id = await _resolve_expense_context(db, expense_in)

            # Create recurrence pattern if this is a recurring expense
            recurrence_pattern = None
            if expense_in.is_recurring and expense_in.recurrence_pattern:
                recurrence_pattern = RecurrencePattern(
                    type=expense_in.recurrence_pattern.type,
                    interval=expense_in.recurrence_pattern.interval,
                    days_of_week=expense_in.recurrence_pattern.days_of_week,
                    end_date=expense_in.recurrence_pattern.end_date,
                    max_occurrences=expense_in.recurrence_pattern.max_occurrences,
                    created_at=datetime.now(timezone.utc),
                    updated_at=datetime.now(timezone.utc)
                )
                db.add(recurrence_pattern)
                await db.flush()

            # 3. Create the ExpenseModel instance
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_in.expense_date or datetime.now(timezone.utc),
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,  # Use resolved group_id
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
                overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
                is_recurring=expense_in.is_recurring,
                recurrence_pattern=recurrence_pattern,
                next_occurrence=expense_in.expense_date if expense_in.is_recurring else None
            )
            db.add(db_expense)
            await db.flush()  # Get expense ID

            # 4. Generate splits (passing current_user_id through kwargs if needed by specific split types)
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id  # Pass for item-based splits needing creator info
            )

            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id  # Set FK after db_expense has ID
            db.add_all(splits_to_create)
            await db.flush()  # Persist splits

            # 5. Re-fetch the expense with all necessary relationships for the response
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),  # If you have this relationship
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.settlement_activities)
                )
            )
            result = await db.execute(stmt)
            loaded_expense = result.scalar_one_or_none()

            if loaded_expense is None:
                # The context manager will handle rollback if an exception is raised.
                raise ExpenseOperationError("Failed to load expense after creation.")

            # Explicit commit removed; the context manager handles it.
            return loaded_expense

    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
        # These are business logic validation errors; re-raise them.
        # If a transaction was started, the context manager handles rollback.
        raise
    except IntegrityError as e:
        # Context manager handles rollback.
        logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}")
    except SQLAlchemyError as e:
        # Context manager handles rollback.
        logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}")
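For orientation, creating an equal-split group expense from a service or API layer might look like the following; the exact ExpenseCreate field set lives in app/schemas/expense.py, so treat the literal values and the current_user variable here as illustrative:

    expense = await create_expense(
        db,
        ExpenseCreate(
            description="Weekly groceries",
            total_amount=Decimal("62.40"),
            split_type=SplitTypeEnum.EQUAL,
            group_id=7,  # split among the members of group 7
            paid_by_user_id=current_user.id,
            is_recurring=False,
        ),
        current_user_id=current_user.id,
    )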
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Resolves and validates the expense's context (list and group).

    Returns the final group_id for the expense after validation.
    """
    final_group_id = expense_in.group_id

    # If list_id is provided, validate it and potentially derive group_id
    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        # If list belongs to a group, verify consistency or inherit group_id
        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            final_group_id = list_obj.group_id  # Prioritize list's group

    # If only group_id is provided (no list_id), validate group_id
    elif final_group_id:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id
async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Generates appropriate expense splits based on split type."""

    splits_to_create: PyList[ExpenseSplitModel] = []

    # Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based)
    common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs}

    # Create splits based on the split type
    if expense_in.split_type == SplitTypeEnum.EQUAL:
        splits_to_create = await _create_equal_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
        splits_to_create = await _create_exact_amount_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
        splits_to_create = await _create_percentage_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.SHARES:
        splits_to_create = await _create_shares_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
        splits_to_create = await _create_item_based_splits(**common_args)
    else:
        raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")

    return splits_to_create
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates equal splits among users."""

    users_for_splitting = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not users_for_splitting:
        raise InvalidOperationError("No users found for EQUAL split.")

    num_users = len(users_for_splitting)
    amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
    remainder = expense_model.total_amount - (amount_per_user * num_users)

    splits = []
    for i, user in enumerate(users_for_splitting):
        split_amount = amount_per_user
        if i == 0 and remainder != Decimal('0'):
            split_amount = round_money_func(amount_per_user + remainder)

        splits.append(ExpenseSplitModel(
            user_id=user.id,
            owed_amount=split_amount,
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))

    return splits
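Worked example of the remainder handling above: 100.00 split three ways gives amount_per_user = 33.33 and a remainder of 100.00 - 99.99 = 0.01, which is folded into the first split so the parts still sum to the total:

    >>> total, n = Decimal("100.00"), 3
    >>> per_user = _round_money(total / n)    # Decimal('33.33')
    >>> remainder = total - per_user * n      # Decimal('0.01')
    >>> [_round_money(per_user + remainder)] + [per_user] * (n - 1)
    [Decimal('33.34'), Decimal('33.33'), Decimal('33.33')]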
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates splits with exact amounts."""

    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    # Validate all users in splits exist
    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        if split_in.owed_amount <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")

        rounded_amount = round_money_func(split_in.owed_amount)
        current_total += rounded_amount

        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=rounded_amount,
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))

    if round_money_func(current_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
        )

    return splits


async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates splits based on percentages."""

    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    # Validate all users in splits exist
    await _validate_users_in_splits(db, expense_in.splits_in)

    total_percentage = Decimal("0.00")
    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")):
            raise InvalidOperationError(
                f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}."
            )

        total_percentage += split_in.share_percentage
        owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100")))
        current_total += owed_amount

        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=owed_amount,
            share_percentage=split_in.share_percentage,
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))

    if round_money_func(total_percentage) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")

    # Adjust for rounding differences
    if current_total != expense_model.total_amount and splits:
        diff = expense_model.total_amount - current_total
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)

    return splits


async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates splits based on shares."""

    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    # Validate all users in splits exist
    await _validate_users_in_splits(db, expense_in.splits_in)

    # Calculate total shares
    total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0)
    if total_shares == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    splits = []
    current_total = Decimal("0.00")

    for split_in in expense_in.splits_in:
        if not (split_in.share_units and split_in.share_units > 0):
            raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")

        share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
        owed_amount = round_money_func(expense_model.total_amount * share_ratio)
        current_total += owed_amount

        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=owed_amount,
            share_units=split_in.share_units,
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))

    # Adjust for rounding differences
    if current_total != expense_model.total_amount and splits:
        diff = expense_model.total_amount - current_total
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)

    return splits
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates splits based on items in a shopping list."""

    if not expense_model.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    # Build query to fetch relevant items
    items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
    if expense_model.item_id:
        items_query = items_query.where(ItemModel.id == expense_model.item_id)
    else:
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))

    # Load items with their adders
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
            if expense_model.item_id else
            f"List {expense_model.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    # Aggregate owed amounts by user
    calculated_total = Decimal("0.00")
    user_owed_amounts = defaultdict(Decimal)
    processed_items = 0

    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            if expense_model.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")

        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
        )

    # Validate total matches calculated total
    if round_money_func(calculated_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    # Create splits based on aggregated amounts
    splits = []
    for user_id, owed_amount in user_owed_amounts.items():
        splits.append(ExpenseSplitModel(
            user_id=user_id,
            owed_amount=round_money_func(owed_amount),
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))

    return splits


async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Validates that all users in the splits exist."""

    user_ids_in_split = [s.user_id for s in splits_in]
    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
    found_user_ids = {row[0] for row in user_results}

    if len(found_user_ids) != len(user_ids_in_split):
        missing_user_ids = set(user_ids_in_split) - found_user_ids
        raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    result = await db.execute(
        select(ExpenseModel)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities)),
            selectinload(ExpenseModel.paid_by_user),
            selectinload(ExpenseModel.list),
            selectinload(ExpenseModel.group),
            selectinload(ExpenseModel.item)
        )
        .where(ExpenseModel.id == expense_id)
    )
    return result.scalars().first()


async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities))
        )
    )
    return result.scalars().all()


async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities))
        )
    )
    return result.scalars().all()
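Both readers page with the same offset/limit convention, so a caller can walk a group's full history in fixed-size batches; the group id and page size below are arbitrary example values:

    page_size = 50
    skip = 0
    while True:
        batch = await get_expenses_for_group(db, group_id=7, skip=skip, limit=page_size)
        if not batch:
            break
        for expense in batch:
            ...  # render or aggregate the page
        skip += page_size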
async def get_user_accessible_expenses(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """
    Get all expenses that a user has access to:
    - Expenses they paid for
    - Expenses in groups they are members of
    - Expenses for lists they have access to
    """
    from app.models import UserGroup as UserGroupModel, List as ListModel  # Import here to avoid circular imports

    # Build the query for accessible expenses
    # 1. Expenses paid by the user
    paid_by_condition = ExpenseModel.paid_by_user_id == user_id

    # 2. Expenses in groups where user is a member
    group_member_subquery = select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
    group_expenses_condition = ExpenseModel.group_id.in_(group_member_subquery)

    # 3. Expenses for lists where user is creator or has access (simplified to creator for now)
    user_lists_subquery = select(ListModel.id).where(ListModel.created_by_id == user_id)
    list_expenses_condition = ExpenseModel.list_id.in_(user_lists_subquery)

    # Combine all conditions with OR
    combined_condition = paid_by_condition | group_expenses_condition | list_expenses_condition

    result = await db.execute(
        select(ExpenseModel)
        .where(combined_condition)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip).limit(limit)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.settlement_activities)),
            selectinload(ExpenseModel.paid_by_user),
            selectinload(ExpenseModel.list),
            selectinload(ExpenseModel.group)
        )
    )
    return result.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Updates an existing expense.
    Only allows updates to description, currency, and expense_date to avoid split complexities.
    Requires version matching for optimistic locking.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT  # This would be for the API layer to set
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})  # Exclude version itself from data

    # Fields that are safe to update without affecting splits or core logic
    allowed_to_update = {"description", "currency", "expense_date"}

    updated_something = False
    for field, value in update_data.items():
        if field in allowed_to_update:
            setattr(expense_db, field, value)
            updated_something = True
        else:
            # If any other field is present in the update payload, it's an invalid operation for this simple update
            raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.")

    if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update):
        # No actual updatable fields were provided in the payload, even if others (like version) were.
        # This could be a non-issue, or an indication of a misuse of the endpoint.
        # For now, if only version was sent, we still increment if it matched.
        pass  # Or raise InvalidOperationError("No updatable fields provided.")

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)  # Manually update timestamp
            # db.add(expense_db)  # Not strictly necessary as expense_db is already tracked by the session

            await db.flush()  # Persist changes to the DB and run constraints
            await db.refresh(expense_db)  # Refresh the object from the DB
            return expense_db
    except InvalidOperationError:  # Re-raise validation errors to be handled by the caller
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:  # Catch other SQLAlchemy errors
        logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
    # No generic Exception catch here; let other unexpected errors propagate if not SQLAlchemy related.
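The version check above is standard optimistic locking, and the commented-out status_code hints that the HTTP mapping belongs to the API layer. A sketch of how an endpoint might translate the mismatch into a 409 Conflict; the route shape, schema name, and dependencies are assumptions:

    @router.put("/expenses/{expense_id}", response_model=ExpensePublic)  # schema name assumed
    async def update_expense_endpoint(
        expense_id: int,
        payload: ExpenseUpdate,
        db: AsyncSession = Depends(get_transactional_session),
    ):
        expense = await get_expense_by_id(db, expense_id)
        if expense is None:
            raise HTTPException(status_code=404, detail="Expense not found")
        try:
            return await update_expense(db, expense, payload)
        except InvalidOperationError as exc:
            # Stale version: tell the client to refetch and retry.
            raise HTTPException(status_code=409, detail=str(exc))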
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Deletes an expense. Requires version matching if expected_version is provided.
    Associated ExpenseSplits are cascade deleted by the database foreign key constraint.
    """
    if expected_version is not None and expense_db.version != expected_version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            await db.delete(expense_db)
            await db.flush()  # Ensure the delete operation is sent to the database
    except InvalidOperationError:  # Re-raise validation errors
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:  # Catch other SQLAlchemy errors
        logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
    return None

# Note: The InvalidOperationError is a simple ValueError placeholder.
# For API endpoints, these should be translated to appropriate HTTPExceptions.
# Ensure app.core.exceptions has proper HTTP error classes if needed.
@ -2,122 +2,298 @@
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload # For eager loading members
|
from sqlalchemy.orm import selectinload # For eager loading members
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import delete, func
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
|
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
|
||||||
from app.schemas.group import GroupCreate
|
from app.schemas.group import GroupCreate
|
||||||
from app.models import UserRoleEnum # Import enum
|
from app.models import UserRoleEnum # Import enum
|
||||||
|
from app.core.exceptions import (
|
||||||
|
GroupOperationError,
|
||||||
|
GroupNotFoundError,
|
||||||
|
DatabaseConnectionError,
|
||||||
|
DatabaseIntegrityError,
|
||||||
|
DatabaseQueryError,
|
||||||
|
DatabaseTransactionError,
|
||||||
|
GroupMembershipError,
|
||||||
|
GroupPermissionError # Import GroupPermissionError
|
||||||
|
)
|
||||||
|
|
||||||
# --- Keep existing functions: get_user_by_email, create_user ---
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
# (These are actually user CRUD, should ideally be in user.py, but keep for now if working)
|
|
||||||
from app.core.security import hash_password
|
|
||||||
from app.schemas.user import UserCreate # Assuming create_user uses this
|
|
||||||
|
|
||||||
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
|
|
||||||
result = await db.execute(select(UserModel).filter(UserModel.email == email))
|
|
||||||
return result.scalars().first()
|
|
||||||
|
|
||||||
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
|
|
||||||
_hashed_password = hash_password(user_in.password)
|
|
||||||
db_user = UserModel(
|
|
||||||
email=user_in.email,
|
|
||||||
password_hash=_hashed_password, # Use correct keyword argument
|
|
||||||
name=user_in.name
|
|
||||||
)
|
|
||||||
db.add(db_user)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_user)
|
|
||||||
return db_user
|
|
||||||
# --- End User CRUD ---
|
|
||||||
|
|
||||||
|
|
||||||
# --- Group CRUD ---
|
# --- Group CRUD ---
|
||||||
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
|
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
|
||||||
"""Creates a group and adds the creator as the owner."""
|
"""Creates a group and adds the creator as the owner."""
|
||||||
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
|
try:
|
||||||
db.add(db_group)
|
# Use the composability pattern for transactions as per fastapi-db-strategy.
|
||||||
await db.flush() # Flush to get the db_group.id for the UserGroup entry
|
# This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
|
||||||
|
# or starts a new transaction if called outside of one (e.g., from a script).
|
||||||
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
|
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
|
||||||
|
db.add(db_group)
|
||||||
|
await db.flush() # Assigns ID to db_group
|
||||||
|
|
||||||
# Add creator as owner
|
db_user_group = UserGroupModel(
|
||||||
db_user_group = UserGroupModel(
|
user_id=creator_id,
|
||||||
user_id=creator_id,
|
group_id=db_group.id,
|
||||||
group_id=db_group.id,
|
role=UserRoleEnum.owner
|
||||||
role=UserRoleEnum.owner # Use the Enum member
|
)
|
||||||
)
|
db.add(db_user_group)
|
||||||
db.add(db_user_group)
|
await db.flush() # Commits user_group, links to group
|
||||||
|
|
||||||
await db.commit()
|
# After creation and linking, explicitly load the group with its member associations and users
|
||||||
await db.refresh(db_group)
|
stmt = (
|
||||||
return db_group
|
select(GroupModel)
|
||||||
|
.where(GroupModel.id == db_group.id)
|
||||||
|
.options(
|
||||||
|
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
loaded_group = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if loaded_group is None:
|
||||||
|
# This should not happen if we just created it, but as a safeguard
|
||||||
|
raise GroupOperationError("Failed to load group after creation.")
|
||||||
|
|
||||||
|
return loaded_group
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
|
||||||
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
|
||||||
|
|

async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
    """Gets all groups a user is a member of with optimized eager loading."""
    try:
        result = await db.execute(
            select(GroupModel)
            .join(UserGroupModel)
            .where(UserGroupModel.user_id == user_id)
            .options(
                selectinload(GroupModel.member_associations).options(
                    selectinload(UserGroupModel.user)
                )
            )
        )
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user groups: {str(e)}")


async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]:
    """Gets a single group by its ID, optionally loading members."""
    try:
        result = await db.execute(
            select(GroupModel)
            .where(GroupModel.id == group_id)
            .options(
                selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
            )
        )
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query group: {str(e)}")


async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Checks if a user is a member of a specific group."""
    try:
        result = await db.execute(
            select(UserGroupModel.id)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
            .limit(1)
        )
        return result.scalar_one_or_none() is not None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")


async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]:
    """Gets the role of a user in a specific group."""
    try:
        result = await db.execute(
            select(UserGroupModel.role)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        )
        return result.scalar_one_or_none()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user role: {str(e)}")


async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
    """Adds a user to a group if they aren't already a member."""
    try:
        # Check if user is already a member before starting a transaction
        existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        existing_result = await db.execute(existing_stmt)
        if existing_result.scalar_one_or_none():
            return None

        # Use a single transaction
        async with db.begin_nested() if db.in_transaction() else db.begin():
            db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
            db.add(db_user_group)
            await db.flush()  # Assigns ID to db_user_group

            # Eagerly load the 'user' and 'group' relationships for the response
            stmt = (
                select(UserGroupModel)
                .where(UserGroupModel.id == db_user_group.id)
                .options(
                    selectinload(UserGroupModel.user),
                    selectinload(UserGroupModel.group)
                )
            )
            result = await db.execute(stmt)
            loaded_user_group = result.scalar_one_or_none()

            if loaded_user_group is None:
                raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")

            return loaded_user_group
    except IntegrityError as e:
        logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
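Throughout these CRUD modules the transaction guard `async with db.begin_nested() if db.in_transaction() else db.begin():` opens a SAVEPOINT when the caller already holds a transaction (so a failure rolls back only this unit of work) and a plain transaction otherwise. A minimal runnable sketch of the idiom, assuming an in-memory SQLite engine purely for illustration:

```python
# Sketch of the savepoint-or-transaction idiom used by these helpers.
# Assumes the aiosqlite driver is installed; any async engine works.
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
Session = async_sessionmaker(engine, expire_on_commit=False)

async def unit_of_work(db: AsyncSession) -> None:
    # SAVEPOINT when nested in a caller's transaction, real transaction otherwise;
    # either way, an exception inside the block rolls this unit back.
    async with db.begin_nested() if db.in_transaction() else db.begin():
        pass  # issue statements via `db` here

async def main() -> None:
    async with Session() as db:
        async with db.begin():        # outer transaction held by the caller
            await unit_of_work(db)    # runs as a SAVEPOINT inside it
        await unit_of_work(db)        # standalone: opens and commits its own transaction

asyncio.run(main())
```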

async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Removes a user from a group."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            result = await db.execute(
                delete(UserGroupModel)
                .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
                .returning(UserGroupModel.id)
            )
            return result.scalar_one_or_none() is not None
    except OperationalError as e:
        logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")


async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
    """Counts the number of members in a group."""
    try:
        result = await db.execute(
            select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id)
        )
        return result.scalar_one()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to count group members: {str(e)}")


async def check_group_membership(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    action: str = "access this group"
) -> None:
    """
    Checks if a user is a member of a group. Raises exceptions if not found or not a member.

    Raises:
        GroupNotFoundError: If the group_id does not exist.
        GroupMembershipError: If the user_id is not a member of the group.
    """
    try:
        # Check group existence first
        group_exists = await db.get(GroupModel, group_id)
        if not group_exists:
            raise GroupNotFoundError(group_id)

        # Check membership
        membership = await db.execute(
            select(UserGroupModel.id)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
            .limit(1)
        )
        if membership.scalar_one_or_none() is None:
            raise GroupMembershipError(group_id, action=action)
        # If we reach here, the user is a member
        return None
    except GroupNotFoundError:  # Re-raise specific errors
        raise
    except GroupMembershipError:
        raise
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")


async def check_user_role_in_group(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    required_role: UserRoleEnum,
    action: str = "perform this action"
) -> None:
    """
    Checks if a user is a member of a group and has the required role (or higher).

    Raises:
        GroupNotFoundError: If the group_id does not exist.
        GroupMembershipError: If the user_id is not a member of the group.
        GroupPermissionError: If the user does not have the required role.
    """
    # First, ensure user is a member (this also checks group existence)
    await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")

    # Get the user's actual role
    actual_role = await get_user_role_in_group(db, group_id, user_id)

    # Define role hierarchy (assuming owner > member)
    role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}

    if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0):
        raise GroupPermissionError(
            group_id=group_id,
            action=f"{action} (requires at least '{required_role.value}' role)"
        )
    # If role is sufficient, return None
    return None


async def delete_group(db: AsyncSession, group_id: int) -> None:
    """
    Deletes a group and all its associated data (members, invites, lists, etc.).
    The cascade delete in the models will handle the deletion of related records.

    Raises:
        GroupNotFoundError: If the group doesn't exist.
        DatabaseError: If there's an error during deletion.
    """
    try:
        # Get the group first to ensure it exists
        group = await get_group_by_id(db, group_id)
        if not group:
            raise GroupNotFoundError(group_id)

        # Delete the group - cascading delete will handle related records
        await db.delete(group)
        await db.flush()

        logger.info(f"Group {group_id} deleted successfully")
    except OperationalError as e:
        logger.error(f"Database connection error while deleting group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error while deleting group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete group: {str(e)}")
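As a usage note, these checkers are written to be awaited at the top of a service call and to signal failure by raising, not by returning flags. A hedged sketch of the intended composition; the wrapper function below is hypothetical, and it assumes `UserRoleEnum` is importable from `app.models` as it is used above:

```python
# Hypothetical service-layer wrapper; only the two crud calls are from this diff.
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.group import check_user_role_in_group, delete_group
from app.models import UserRoleEnum  # assumed import location

async def delete_group_as_owner(db: AsyncSession, group_id: int, current_user_id: int) -> None:
    # Raises GroupNotFoundError / GroupMembershipError / GroupPermissionError,
    # which an API layer can map to 404/403 responses.
    await check_user_role_in_group(
        db,
        group_id,
        current_user_id,
        required_role=UserRoleEnum.owner,
        action="delete this group",
    )
    await delete_group(db, group_id)
```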
@@ -1,69 +1,199 @@

# app/crud/invite.py
import logging  # Add logging import
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload  # Ensure selectinload is imported
from sqlalchemy import delete  # Import delete statement
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional

from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel  # Import related models for selectinload
from app.core.exceptions import (
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    InviteOperationError  # Add new specific exception
)

logger = logging.getLogger(__name__)  # Initialize logger

# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5


async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
    """Deactivates all currently active invite codes for a specific group."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            stmt = (
                select(InviteModel)
                .where(InviteModel.group_id == group_id, InviteModel.is_active == True)
            )
            result = await db.execute(stmt)
            active_invites = result.scalars().all()

            if not active_invites:
                return  # No active invites to deactivate

            for invite in active_invites:
                invite.is_active = False
                db.add(invite)
            await db.flush()  # Flush changes within this transaction block

        # await db.flush() # Removed: Rely on caller to flush/commit
        # No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
    except OperationalError as e:
        logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")


async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]:  # Default to 100 years
    """Creates a new invite code for a group, deactivating any existing active ones for that group first."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # Deactivate existing active invites for this group
            await deactivate_all_active_invites_for_group(db, group_id)

            expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

            potential_code = None
            for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
                potential_code = secrets.token_urlsafe(16)
                existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
                existing_result = await db.execute(existing_check_stmt)
                if existing_result.scalar_one_or_none() is None:
                    break
                if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
                    raise InviteOperationError("Failed to generate a unique invite code after several attempts.")

            final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
            final_check_result = await db.execute(final_check_stmt)
            if final_check_result.scalar_one_or_none() is not None:
                raise InviteOperationError("Invite code collision detected just before creation attempt.")

            db_invite = InviteModel(
                code=potential_code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True
            )
            db.add(db_invite)
            await db.flush()

            stmt = (
                select(InviteModel)
                .where(InviteModel.id == db_invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator)
                )
            )
            result = await db.execute(stmt)
            loaded_invite = result.scalar_one_or_none()

            if loaded_invite is None:
                raise InviteOperationError("Failed to load invite after creation and flush.")

            return loaded_invite
    except InviteOperationError:  # Already specific, re-raise
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")


async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
    """Gets the currently active and non-expired invite for a specific group."""
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel).where(
                InviteModel.group_id == group_id,
                InviteModel.is_active == True,
                InviteModel.expires_at > now  # Still respect expiry, even if very long
            )
            .order_by(InviteModel.created_at.desc())  # Get the most recent one if multiple (should not happen)
            .limit(1)
            .options(
                selectinload(InviteModel.group),  # Eager load group
                selectinload(InviteModel.creator)  # Eager load creator
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
        raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")


async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Gets an active and non-expired invite by its code."""
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel).where(
                InviteModel.code == code,
                InviteModel.is_active == True,
                InviteModel.expires_at > now
            )
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator)
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}")


async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Marks an invite as inactive (used) and reloads with relationships."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            invite.is_active = False
            db.add(invite)  # Add to session to track change
            await db.flush()  # Persist is_active change

            # Re-fetch with relationships
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator)
                )
            )
            result = await db.execute(stmt)
            updated_invite = result.scalar_one_or_none()

            if updated_invite is None:  # Should not happen as invite is passed in
                raise InviteOperationError("Failed to load invite after deactivation.")

            return updated_invite
    except OperationalError as e:
        logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")


# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass

# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...
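For context, a hedged sketch of how these helpers might compose at redemption time; the orchestrating function is hypothetical, and since `create_invite` now issues one long-lived code per group, a real implementation may deliberately skip the deactivation step:

```python
# Hypothetical invite-redemption flow composing functions from this changeset.
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.group import add_user_to_group
from app.crud.invite import deactivate_invite, get_active_invite_by_code
from app.models import UserGroup as UserGroupModel

async def redeem_invite(db: AsyncSession, code: str, user_id: int) -> Optional[UserGroupModel]:
    invite = await get_active_invite_by_code(db, code)
    if invite is None:
        return None  # unknown, inactive, or expired code

    membership = await add_user_to_group(db, invite.group_id, user_id)
    if membership is None:
        return None  # already a member; leave the invite untouched

    # Single-use semantics (optional under the new per-group code design):
    await deactivate_invite(db, invite)
    return membership
```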
@@ -1,67 +1,209 @@

# app/crud/item.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload  # Ensure selectinload is imported
from sqlalchemy import delete as sql_delete, update as sql_update  # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from datetime import datetime, timezone
import logging  # Add logging import
from sqlalchemy import func

from app.models import Item as ItemModel, User as UserModel  # Import UserModel for type hints if needed for selectinload
from app.schemas.item import ItemCreate, ItemUpdate
from app.core.exceptions import (
    ItemNotFoundError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
    ItemOperationError  # Add if specific item operation errors are needed
)

logger = logging.getLogger(__name__)  # Initialize logger


async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
    """Creates a new item record for a specific list, setting its position."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            # Get the current max position in the list
            max_pos_stmt = select(func.max(ItemModel.position)).where(ItemModel.list_id == list_id)
            max_pos_result = await db.execute(max_pos_stmt)
            max_pos = max_pos_result.scalar_one_or_none() or 0

            db_item = ItemModel(
                name=item_in.name,
                quantity=item_in.quantity,
                list_id=list_id,
                added_by_id=user_id,
                is_complete=False,
                position=max_pos + 1  # Set the new position
            )
            db.add(db_item)
            await db.flush()  # Assigns ID

            # Re-fetch with relationships
            stmt = (
                select(ItemModel)
                .where(ItemModel.id == db_item.id)
                .options(
                    selectinload(ItemModel.added_by_user),
                    selectinload(ItemModel.completed_by_user)  # Will be None but loads relationship
                )
            )
            result = await db.execute(stmt)
            loaded_item = result.scalar_one_or_none()

            if loaded_item is None:
                # await transaction.rollback() # Redundant, context manager handles rollback on exception
                raise ItemOperationError("Failed to load item after creation.")  # Define ItemOperationError

            return loaded_item
    except IntegrityError as e:
        logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
    # Removed generic Exception block as SQLAlchemyError should cover DB issues,
    # and context manager handles rollback.


async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
    """Gets all items belonging to a specific list, ordered by creation time."""
    try:
        stmt = (
            select(ItemModel)
            .where(ItemModel.list_id == list_id)
            .options(
                selectinload(ItemModel.added_by_user),
                selectinload(ItemModel.completed_by_user)
            )
            .order_by(ItemModel.created_at.asc())
        )
        result = await db.execute(stmt)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query items: {str(e)}")


async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
    """Gets a single item by its ID."""
    try:
        stmt = (
            select(ItemModel)
            .where(ItemModel.id == item_id)
            .options(
                selectinload(ItemModel.added_by_user),
                selectinload(ItemModel.completed_by_user),
                selectinload(ItemModel.list)  # Often useful to get the parent list
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query item: {str(e)}")


async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
    """Updates an existing item record, checking for version conflicts and handling reordering."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            if item_db.version != item_in.version:
                raise ConflictError(
                    f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
                    f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
                )

            update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})

            # --- Handle Reordering ---
            if 'position' in update_data:
                new_position = update_data.pop('position')  # Remove from update_data to handle separately

                # We need the full list to reorder, making sure it's loaded and ordered
                list_id = item_db.list_id
                stmt = select(ItemModel).where(ItemModel.list_id == list_id).order_by(ItemModel.position.asc(), ItemModel.created_at.asc())
                result = await db.execute(stmt)
                items_in_list = result.scalars().all()

                # Find the item to move
                item_to_move = next((it for it in items_in_list if it.id == item_db.id), None)
                if item_to_move:
                    items_in_list.remove(item_to_move)
                    # Insert at the new position (adjust for 1-based index from frontend)
                    # Clamp position to be within bounds
                    insert_pos = max(0, min(new_position - 1, len(items_in_list)))
                    items_in_list.insert(insert_pos, item_to_move)

                    # Re-assign positions
                    for i, item in enumerate(items_in_list):
                        item.position = i + 1
            # --- End Handle Reordering ---

            if 'is_complete' in update_data:
                if update_data['is_complete'] is True:
                    if item_db.completed_by_id is None:
                        update_data['completed_by_id'] = user_id
                else:
                    update_data['completed_by_id'] = None

            for key, value in update_data.items():
                setattr(item_db, key, value)

            item_db.version += 1
            db.add(item_db)  # Mark as dirty
            await db.flush()

            # Re-fetch with relationships
            stmt = (
                select(ItemModel)
                .where(ItemModel.id == item_db.id)
                .options(
                    selectinload(ItemModel.added_by_user),
                    selectinload(ItemModel.completed_by_user),
                    selectinload(ItemModel.list)
                )
            )
            result = await db.execute(stmt)
            updated_item = result.scalar_one_or_none()

            if updated_item is None:  # Should not happen
                # Rollback will be handled by context manager on raise
                raise ItemOperationError("Failed to load item after update.")

            return updated_item
    except IntegrityError as e:
        logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
    except ConflictError:  # Re-raise ConflictError, rollback handled by context manager
        raise
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
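The reorder branch in `update_item` above turns the client's 1-based target position into a clamped 0-based insertion index, then renumbers every item so positions stay dense (1..N) even for out-of-range input. The arithmetic in isolation, as a pure-Python sketch:

```python
# Pure-Python mirror of the reorder arithmetic in update_item.
def reorder(ids: list[int], moved_id: int, new_position: int) -> list[int]:
    ids = [i for i in ids if i != moved_id]
    insert_pos = max(0, min(new_position - 1, len(ids)))  # clamp the 1-based input
    ids.insert(insert_pos, moved_id)
    return ids  # the loop then reassigns item.position = index + 1

assert reorder([10, 20, 30], 30, 1) == [30, 10, 20]
assert reorder([10, 20, 30], 10, 99) == [20, 30, 10]  # over-large target clamps to the end
```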

async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
    """Deletes an item record. Version check should be done by the caller (API endpoint)."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            await db.delete(item_db)
            # await transaction.commit() # Removed
        # No return needed for None
    except OperationalError as e:
        logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")


# Ensure ItemOperationError is defined in app.core.exceptions if used
# Example: class ItemOperationError(AppException): pass
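Both `update_item` here and `update_list` below implement optimistic concurrency: the client echoes the `version` it last saw, a mismatch raises `ConflictError`, and the server increments the version on every successful write. A hedged caller-side sketch; the wrapper and the exact `ItemUpdate` fields are assumptions:

```python
# Hypothetical caller showing how the version check surfaces to clients.
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ConflictError
from app.crud.item import get_item_by_id, update_item
from app.schemas.item import ItemUpdate

async def toggle_complete(db: AsyncSession, item_id: int, user_id: int):
    item = await get_item_by_id(db, item_id)
    if item is None:
        return None

    patch = ItemUpdate(is_complete=not item.is_complete, version=item.version)
    try:
        return await update_item(db, item, patch, user_id)
    except ConflictError:
        # A concurrent writer bumped the version first; report it (e.g. as
        # HTTP 409) so the client can refresh and retry with the new version.
        raise
```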
@ -2,150 +2,351 @@
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload, joinedload
|
from sqlalchemy.orm import selectinload, joinedload
|
||||||
from sqlalchemy import or_, and_, delete as sql_delete # Use alias for delete
|
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
|
||||||
from typing import Optional, List as PyList # Use alias for List
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from sqlalchemy import func as sql_func, desc # Import func and desc
|
from typing import Optional, List as PyList
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.schemas.list import ListStatus # Import the new schema
|
from app.schemas.list import ListStatus
|
||||||
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
|
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
|
||||||
from app.schemas.list import ListCreate, ListUpdate
|
from app.schemas.list import ListCreate, ListUpdate
|
||||||
|
from app.core.exceptions import (
|
||||||
|
ListNotFoundError,
|
||||||
|
ListPermissionError,
|
||||||
|
ListCreatorRequiredError,
|
||||||
|
DatabaseConnectionError,
|
||||||
|
DatabaseIntegrityError,
|
||||||
|
DatabaseQueryError,
|
||||||
|
DatabaseTransactionError,
|
||||||
|
ConflictError,
|
||||||
|
ListOperationError
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
||||||
"""Creates a new list record."""
|
"""Creates a new list record."""
|
||||||
db_list = ListModel(
|
try:
|
||||||
name=list_in.name,
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
description=list_in.description,
|
db_list = ListModel(
|
||||||
group_id=list_in.group_id,
|
name=list_in.name,
|
||||||
created_by_id=creator_id,
|
description=list_in.description,
|
||||||
is_complete=False # Default on creation
|
group_id=list_in.group_id,
|
||||||
)
|
created_by_id=creator_id,
|
||||||
db.add(db_list)
|
is_complete=False
|
||||||
await db.commit()
|
)
|
||||||
await db.refresh(db_list)
|
db.add(db_list)
|
||||||
return db_list
|
await db.flush() # Assigns ID
|
||||||
|
|
||||||
|
# Re-fetch with relationships for the response
|
||||||
|
stmt = (
|
||||||
|
select(ListModel)
|
||||||
|
.where(ListModel.id == db_list.id)
|
||||||
|
.options(
|
||||||
|
selectinload(ListModel.creator),
|
||||||
|
selectinload(ListModel.group)
|
||||||
|
# selectinload(ListModel.items) # Optionally add if items are always needed in response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
loaded_list = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if loaded_list is None:
|
||||||
|
raise ListOperationError("Failed to load list after creation.")
|
||||||
|
|
||||||
|
return loaded_list
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
|
||||||
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
|
||||||
|
|
||||||
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
|
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
|
||||||
"""
|
"""Gets all lists accessible by a user."""
|
||||||
Gets all lists accessible by a user:
|
try:
|
||||||
- Personal lists created by the user (group_id is NULL).
|
group_ids_result = await db.execute(
|
||||||
- Lists belonging to groups the user is a member of.
|
select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
|
||||||
"""
|
|
||||||
# Get IDs of groups the user is a member of
|
|
||||||
group_ids_result = await db.execute(
|
|
||||||
select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
|
|
||||||
)
|
|
||||||
user_group_ids = group_ids_result.scalars().all()
|
|
||||||
|
|
||||||
# Query for lists
|
|
||||||
query = select(ListModel).where(
|
|
||||||
or_(
|
|
||||||
# Personal lists
|
|
||||||
and_(ListModel.created_by_id == user_id, ListModel.group_id == None),
|
|
||||||
# Group lists where user is a member
|
|
||||||
ListModel.group_id.in_(user_group_ids)
|
|
||||||
)
|
)
|
||||||
).order_by(ListModel.updated_at.desc()) # Order by most recently updated
|
user_group_ids = group_ids_result.scalars().all()
|
||||||
|
|
||||||
result = await db.execute(query)
|
conditions = [
|
||||||
return result.scalars().all()
|
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
|
||||||
|
]
|
||||||
|
if user_group_ids:
|
||||||
|
conditions.append(ListModel.group_id.in_(user_group_ids))
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(ListModel)
|
||||||
|
.where(or_(*conditions))
|
||||||
|
.options(
|
||||||
|
selectinload(ListModel.creator),
|
||||||
|
selectinload(ListModel.group),
|
||||||
|
selectinload(ListModel.items).options(
|
||||||
|
joinedload(ItemModel.added_by_user),
|
||||||
|
joinedload(ItemModel.completed_by_user)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(ListModel.updated_at.desc())
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
except OperationalError as e:
|
||||||
|
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
raise DatabaseQueryError(f"Failed to query user lists: {str(e)}")
|
||||||
|
|
||||||
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
|
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
|
||||||
"""Gets a single list by ID, optionally loading its items."""
|
"""Gets a single list by ID, optionally loading its items."""
|
||||||
query = select(ListModel).where(ListModel.id == list_id)
|
try:
|
||||||
if load_items:
|
query = (
|
||||||
# Eager load items and their creators/completers if needed
|
select(ListModel)
|
||||||
query = query.options(
|
.where(ListModel.id == list_id)
|
||||||
selectinload(ListModel.items)
|
|
||||||
.options(
|
.options(
|
||||||
joinedload(ItemModel.added_by_user), # Use joinedload for simple FKs
|
selectinload(ListModel.creator),
|
||||||
joinedload(ItemModel.completed_by_user)
|
selectinload(ListModel.group)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = await db.execute(query)
|
if load_items:
|
||||||
return result.scalars().first()
|
query = query.options(
|
||||||
|
selectinload(ListModel.items).options(
|
||||||
|
joinedload(ItemModel.added_by_user),
|
||||||
|
joinedload(ItemModel.completed_by_user)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().first()
|
||||||
|
except OperationalError as e:
|
||||||
|
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
raise DatabaseQueryError(f"Failed to query list: {str(e)}")
|
||||||
|
|
||||||
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
|
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
|
||||||
"""Updates an existing list record."""
|
"""Updates an existing list record, checking for version conflicts."""
|
||||||
update_data = list_in.model_dump(exclude_unset=True) # Get only provided fields
|
try:
|
||||||
for key, value in update_data.items():
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
setattr(list_db, key, value)
|
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
|
||||||
db.add(list_db) # Add to session to track changes
|
raise ConflictError(
|
||||||
await db.commit()
|
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
|
||||||
await db.refresh(list_db)
|
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
|
||||||
return list_db
|
)
|
||||||
|
|
||||||
|
update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
|
||||||
|
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(list_db, key, value)
|
||||||
|
|
||||||
|
list_db.version += 1
|
||||||
|
|
||||||
|
db.add(list_db) # Add the already attached list_db to mark it dirty for the session
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
# Re-fetch with relationships for the response
|
||||||
|
stmt = (
|
||||||
|
select(ListModel)
|
||||||
|
.where(ListModel.id == list_db.id)
|
||||||
|
.options(
|
||||||
|
selectinload(ListModel.creator),
|
||||||
|
selectinload(ListModel.group)
|
||||||
|
# selectinload(ListModel.items) # Optionally add if items are always needed in response
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
updated_list = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if updated_list is None: # Should not happen
|
||||||
|
raise ListOperationError("Failed to load list after update.")
|
||||||
|
|
||||||
|
return updated_list
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
|
||||||
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
|
||||||
|
except ConflictError:
|
||||||
|
raise
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
||||||
|
|
||||||
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
||||||
"""Deletes a list record."""
|
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
|
||||||
# Items should be deleted automatically due to cascade="all, delete-orphan"
|
try:
|
||||||
# on List.items relationship and ondelete="CASCADE" on Item.list_id FK
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
|
||||||
await db.delete(list_db)
|
await db.delete(list_db)
|
||||||
await db.commit()
|
except OperationalError as e:
|
||||||
return None # Or return True/False if needed
|
logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
||||||
|
|
||||||
# --- Helper for Permission Checks ---
|
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
||||||
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> Optional[ListModel]:
|
"""Fetches a list and verifies user permission."""
|
||||||
|
try:
|
||||||
|
list_db = await get_list_by_id(db, list_id=list_id, load_items=True)
|
||||||
|
if not list_db:
|
||||||
|
raise ListNotFoundError(list_id)
|
||||||
|
|
||||||
|
is_creator = list_db.created_by_id == user_id
|
||||||
|
|
||||||
|
if require_creator:
|
||||||
|
if not is_creator:
|
||||||
|
raise ListCreatorRequiredError(list_id, "access")
|
||||||
|
return list_db
|
||||||
|
|
||||||
|
if is_creator:
|
||||||
|
return list_db
|
||||||
|
|
||||||
|
if list_db.group_id:
|
||||||
|
from app.crud.group import is_user_member
|
||||||
|
is_member = await is_user_member(db, group_id=list_db.group_id, user_id=user_id)
|
||||||
|
if not is_member:
|
||||||
|
raise ListPermissionError(list_id)
|
||||||
|
return list_db
|
||||||
|
else:
|
||||||
|
raise ListPermissionError(list_id)
|
||||||
|
except OperationalError as e:
|
||||||
|
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}")
|
||||||
|
|
||||||
|
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
|
||||||
|
"""Gets the update timestamps and item count for a list."""
|
||||||
|
try:
|
||||||
|
query = (
|
||||||
|
select(
|
||||||
|
ListModel.updated_at,
|
||||||
|
sql_func.count(ItemModel.id).label("item_count"),
|
||||||
|
sql_func.max(ItemModel.updated_at).label("latest_item_updated_at")
|
||||||
|
)
|
||||||
|
.select_from(ListModel)
|
||||||
|
.outerjoin(ItemModel, ItemModel.list_id == ListModel.id)
|
||||||
|
.where(ListModel.id == list_id)
|
||||||
|
.group_by(ListModel.id)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
status = result.first()
|
||||||
|
|
||||||
|
if status is None:
|
||||||
|
raise ListNotFoundError(list_id)
|
||||||
|
|
||||||
|
return ListStatus(
|
||||||
|
updated_at=status.updated_at,
|
||||||
|
item_count=status.item_count,
|
||||||
|
latest_item_updated_at=status.latest_item_updated_at
|
||||||
|
)
|
||||||
|
except OperationalError as e:
|
||||||
|
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
raise DatabaseQueryError(f"Failed to get list status: {str(e)}")
|
||||||
|
|
||||||
|
async def get_list_by_name_and_group(
|
||||||
|
db: AsyncSession,
|
||||||
|
name: str,
|
||||||
|
group_id: Optional[int],
|
||||||
|
user_id: int # user_id is for permission check, not direct list attribute
|
||||||
|
) -> Optional[ListModel]:
|
||||||
"""
|
"""
|
||||||
Fetches a list and verifies user permission.
|
Gets a list by name and group, ensuring the user has permission to access it.
|
||||||
|
Used for conflict resolution when creating lists.
|
||||||
Args:
|
|
||||||
db: Database session.
|
|
||||||
list_id: The ID of the list to check.
|
|
||||||
user_id: The ID of the user requesting access.
|
|
||||||
require_creator: If True, only allows the creator access.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The ListModel if found and permission granted, otherwise None.
|
|
||||||
(Raising exceptions might be better handled in the endpoint).
|
|
||||||
"""
|
"""
|
||||||
list_db = await get_list_by_id(db, list_id=list_id, load_items=True) # Load items for detail/update/delete context
|
try:
|
||||||
if not list_db:
|
# Base query for the list itself
|
||||||
return None # List not found
|
base_query = select(ListModel).where(ListModel.name == name)
|
||||||
|
|
||||||
# Check if user is the creator
|
if group_id is not None:
|
||||||
is_creator = list_db.created_by_id == user_id
|
base_query = base_query.where(ListModel.group_id == group_id)
|
||||||
|
else:
|
||||||
|
base_query = base_query.where(ListModel.group_id.is_(None))
|
||||||
|
|
||||||
if require_creator:
|
# Add eager loading for common relationships
|
||||||
return list_db if is_creator else None
|
base_query = base_query.options(
|
||||||
|
selectinload(ListModel.creator),
|
||||||
|
selectinload(ListModel.group)
|
||||||
|
)
|
||||||
|
|
||||||
# If not requiring creator, check membership if it's a group list
|
list_result = await db.execute(base_query)
|
||||||
if is_creator:
|
target_list = list_result.scalar_one_or_none()
|
||||||
return list_db # Creator always has access
|
|
||||||
|
|
||||||
if list_db.group_id:
|
if not target_list:
|
||||||
# Check if user is member of the list's group
|
return None
|
||||||
from app.crud.group import is_user_member # Avoid circular import at top level
|
|
||||||
is_member = await is_user_member(db, group_id=list_db.group_id, user_id=user_id)
|
# Permission check
|
||||||
return list_db if is_member else None
|
is_creator = target_list.created_by_id == user_id
|
||||||
else:
|
|
||||||
# Personal list, not the creator -> no access
|
if is_creator:
|
||||||
|
return target_list
|
||||||
|
|
||||||
|
if target_list.group_id:
|
||||||
|
from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction
|
||||||
|
is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
|
||||||
|
if is_member_of_group:
|
||||||
|
            return target_list

        # If not creator and (not a group list or not a member of the group list)
        return None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}")


async def get_list_status(db: AsyncSession, list_id: int) -> Optional[ListStatus]:
    """
    Gets the update timestamps and item count for a list.
    Returns None if the list itself doesn't exist.
    """
    # Fetch the list's updated_at time
    list_query = select(ListModel.updated_at).where(ListModel.id == list_id)
    list_result = await db.execute(list_query)
    list_updated_at = list_result.scalar_one_or_none()

    if list_updated_at is None:
        return None  # List not found

    # Fetch the latest item update time and the item count for that list
    item_status_query = (
        select(
            sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
            sql_func.count(ItemModel.id).label("item_count")
        )
        .where(ItemModel.list_id == list_id)
    )
    item_result = await db.execute(item_status_query)
    item_status = item_result.first()  # Use first(): an aggregate always returns one row

    return ListStatus(
        list_updated_at=list_updated_at,
        latest_item_updated_at=item_status.latest_item_updated_at if item_status else None,
        item_count=item_status.item_count if item_status else 0
    )


async def get_lists_statuses_by_ids(db: AsyncSession, list_ids: PyList[int], user_id: int) -> PyList[ListModel]:
    """
    Gets status for a list of lists if the user has permission.
    Status includes the list's updated_at and a count of its items.
    """
    if not list_ids:
        return []

    try:
        # First, get the groups the user is a member of
        group_ids_result = await db.execute(
            select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
        )
        user_group_ids = group_ids_result.scalars().all()

        # Build the permission logic
        permission_filter = or_(
            # User is the creator of the list
            and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)),
            # List belongs to a group the user is a member of
            ListModel.group_id.in_(user_group_ids)
        )

        # Main query to get list data and item counts
        query = (
            select(
                ListModel.id,
                ListModel.updated_at,
                sql_func.count(ItemModel.id).label("item_count"),
                sql_func.max(ItemModel.updated_at).label("latest_item_updated_at")
            )
            .outerjoin(ItemModel, ListModel.id == ItemModel.list_id)
            .where(
                and_(
                    ListModel.id.in_(list_ids),
                    permission_filter
                )
            )
            .group_by(ListModel.id)
        )

        result = await db.execute(query)

        # The result is rows of (id, updated_at, item_count, latest_item_updated_at).
        # The filter in the query already handles permissions, so every row returned
        # is a list the user is allowed to see.
        return result.all()  # Returns a list of Row objects
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get lists statuses: {str(e)}")
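For orientation, a minimal sketch of how a caller might consume the rows returned by get_lists_statuses_by_ids above; the helper name and dict shape are assumptions for illustration, not part of this diff.

# Hypothetical consumer of get_lists_statuses_by_ids; attribute names follow
# the labels used in the query above.
async def build_status_map(db, list_ids, user_id):
    rows = await get_lists_statuses_by_ids(db, list_ids, user_id)
    return {
        row.id: {
            "updated_at": row.updated_at,
            "item_count": row.item_count,
            "latest_item_updated_at": row.latest_item_updated_at,
        }
        for row in rows
    }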
be/app/crud/settlement.py (new file, 281 lines)

@@ -0,0 +1,281 @@
# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
import logging

from app.models import (
    Settlement as SettlementModel,
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import (
    UserNotFoundError,
    GroupNotFoundError,
    InvalidOperationError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    SettlementOperationError,
    ConflictError
)

logger = logging.getLogger(__name__)
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Creates a new settlement record."""
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            payer = await db.get(UserModel, settlement_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")

            payee = await db.get(UserModel, settlement_in.paid_to_user_id)
            if not payee:
                raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")

            if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
                raise InvalidOperationError("Payer and Payee cannot be the same user.")

            group = await db.get(GroupModel, settlement_in.group_id)
            if not group:
                raise GroupNotFoundError(settlement_in.group_id)

            # Permission check example (can live in the API layer too):
            # if current_user_id not in [payer.id, payee.id]:
            #     is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
            #     is_member_result = await db.execute(is_member_stmt)
            #     if not is_member_result.scalar_one_or_none():
            #         raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")

            db_settlement = SettlementModel(
                group_id=settlement_in.group_id,
                paid_by_user_id=settlement_in.paid_by_user_id,
                paid_to_user_id=settlement_in.paid_to_user_id,
                amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
                description=settlement_in.description,
                created_by_user_id=current_user_id
            )
            db.add(db_settlement)
            await db.flush()

            # Re-fetch with relationships
            stmt = (
                select(SettlementModel)
                .where(SettlementModel.id == db_settlement.id)
                .options(
                    selectinload(SettlementModel.payer),
                    selectinload(SettlementModel.payee),
                    selectinload(SettlementModel.group),
                    selectinload(SettlementModel.created_by_user)
                )
            )
            result = await db.execute(stmt)
            loaded_settlement = result.scalar_one_or_none()

            if loaded_settlement is None:
                raise SettlementOperationError("Failed to load settlement after creation.")

            return loaded_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError):
        # These are validation errors; re-raise them.
        # If a transaction was started, the context manager handles rollback.
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
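The `async with db.begin_nested() if db.in_transaction() else db.begin()` line above opens a SAVEPOINT when the session already has a transaction in progress and a regular transaction otherwise. A minimal sketch of that idiom in isolation; the function name and the work inside are placeholders.

from sqlalchemy.ext.asyncio import AsyncSession

async def do_atomically(db: AsyncSession) -> None:
    # SAVEPOINT if a transaction is already open, otherwise a new transaction.
    # Either way the context manager commits (or releases the savepoint) on
    # success and rolls back on exception.
    async with db.begin_nested() if db.in_transaction() else db.begin():
        ...  # mutate ORM objects; await db.flush() to get generated IDs early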
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    try:
        result = await db.execute(
            select(SettlementModel)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user)
            )
            .where(SettlementModel.id == settlement_id)
        )
        return result.scalars().first()
    except OperationalError as e:
        # Optional: log at warning/info level for read operations if needed
        raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    try:
        result = await db.execute(
            select(SettlementModel)
            .where(SettlementModel.group_id == group_id)
            .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
            .offset(skip).limit(limit)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user)
            )
        )
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}")
async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100
) -> Sequence[SettlementModel]:
    try:
        query = (
            select(SettlementModel)
            .where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
            .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
            .offset(skip).limit(limit)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user)
            )
        )
        if group_id:
            query = query.where(SettlementModel.group_id == group_id)

        result = await db.execute(query)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}")
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """
    Updates an existing settlement.
    Only allows updates to description and settlement_date.
    Requires version matching for optimistic locking.
    Assumes the SettlementUpdate schema includes a version field.
    Assumes SettlementModel has version and updated_at fields.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # Ensure the settlement_db passed in is managed by the current session.
            # This is usually true if it was fetched by an endpoint dependency using
            # the same session; if it is detached, `db.add(settlement_db)` may be
            # needed before modification.

            if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
                raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")

            if settlement_db.version != settlement_in.version:
                raise ConflictError(
                    f"Settlement (ID: {settlement_db.id}) has been modified. "
                    f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
                )

            update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
            allowed_to_update = {"description", "settlement_date"}
            updated_something = False

            for field, value in update_data.items():
                if field in allowed_to_update:
                    setattr(settlement_db, field, value)
                    updated_something = True
                # Fields not in the allow-list are silently ignored; alternatively:
                # else:
                #     raise InvalidOperationError(f"Field '{field}' cannot be updated.")

            if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
                # No updatable fields were actually provided, or they didn't change.
                # Still, we may want to return the re-loaded settlement if the version matched.
                pass

            settlement_db.version += 1
            settlement_db.updated_at = datetime.now(timezone.utc)  # Ensure the model has this field

            db.add(settlement_db)  # Mark as dirty
            await db.flush()

            # Re-fetch with relationships
            stmt = (
                select(SettlementModel)
                .where(SettlementModel.id == settlement_db.id)
                .options(
                    selectinload(SettlementModel.payer),
                    selectinload(SettlementModel.payee),
                    selectinload(SettlementModel.group),
                    selectinload(SettlementModel.created_by_user)
                )
            )
            result = await db.execute(stmt)
            updated_settlement = result.scalar_one_or_none()

            if updated_settlement is None:  # Should not happen
                raise SettlementOperationError("Failed to load settlement after update.")

            return updated_settlement
    except ConflictError:
        raise
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
    except OperationalError as e:
        logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
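A sketch of how an API layer might drive the optimistic-locking contract of update_settlement; mapping ConflictError to HTTP 409 is an assumption about the endpoint code, not something this diff defines.

# Hypothetical endpoint-side handling of a stale version.
from fastapi import HTTPException

async def patch_settlement(db, settlement_db, settlement_in):
    try:
        return await update_settlement(db, settlement_db, settlement_in)
    except ConflictError as e:
        # The client sent an outdated version; it should re-fetch and retry.
        raise HTTPException(status_code=409, detail=str(e))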
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """
    Deletes a settlement. Requires version matching if expected_version is provided.
    Assumes SettlementModel has a version field.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            if expected_version is not None:
                if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
                    raise ConflictError(
                        f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                        f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh."
                    )

            await db.delete(settlement_db)
    except ConflictError:
        raise
    except OperationalError as e:
        logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")

# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
# Example: class SettlementOperationError(AppException): pass
# Example: class ConflictError(AppException): status_code = 409
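Spelled out, the definitions suggested by those trailing comments might look like the following; the AppException base class in app.core.exceptions is an assumption.

# Hypothetical definitions matching the example comments above.
class SettlementOperationError(AppException):  # AppException is assumed to exist
    pass

class ConflictError(AppException):
    status_code = 409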
be/app/crud/settlement_activity.py (new file, 211 lines)

@@ -0,0 +1,211 @@
from typing import List, Optional
from decimal import Decimal
from datetime import datetime, timezone

from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    SettlementActivity,
    ExpenseSplit,
    Expense,
    User,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
)
# Placeholder for the Pydantic schema - the actual schema definition is a later step
# from app.schemas.settlement_activity import SettlementActivityCreate  # Assuming this path
from pydantic import BaseModel  # Using pydantic BaseModel directly for the placeholder


class SettlementActivityCreatePlaceholder(BaseModel):
    expense_split_id: int
    paid_by_user_id: int
    amount_paid: Decimal
    paid_at: Optional[datetime] = None

    class Config:
        orm_mode = True  # Pydantic V1 style orm_mode
        # from_attributes = True  # Pydantic V2 style
async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
    """
    Updates the status of an ExpenseSplit based on its settlement activities.
    Also updates the overall status of the parent Expense.
    """
    # Fetch the ExpenseSplit with its related settlement activities and parent expense
    result = await db.execute(
        select(ExpenseSplit)
        .options(
            selectinload(ExpenseSplit.settlement_activities),
            joinedload(ExpenseSplit.expense)  # To get expense_id easily
        )
        .where(ExpenseSplit.id == expense_split_id)
    )
    expense_split = result.scalar_one_or_none()

    if not expense_split:
        # Or raise an exception, depending on the desired error handling
        return None

    # Calculate total_paid from all settlement activities for that split
    total_paid = sum(activity.amount_paid for activity in expense_split.settlement_activities)
    total_paid = Decimal(total_paid).quantize(Decimal("0.01"))  # Ensure two decimal places

    # Compare total_paid with ExpenseSplit.owed_amount
    if total_paid >= expense_split.owed_amount:
        expense_split.status = ExpenseSplitStatusEnum.paid
        # Set paid_at to the latest relevant SettlementActivity, or the current time.
        # default=None guards against the case where no activity has a paid_at set.
        latest_paid_at = max(
            (act.paid_at for act in expense_split.settlement_activities if act.paid_at),
            default=None
        )
        expense_split.paid_at = latest_paid_at if latest_paid_at else datetime.now(timezone.utc)
    elif total_paid > 0:
        expense_split.status = ExpenseSplitStatusEnum.partially_paid
        expense_split.paid_at = None  # Clear paid_at if not fully paid
    else:  # total_paid == 0
        expense_split.status = ExpenseSplitStatusEnum.unpaid
        expense_split.paid_at = None  # Clear paid_at

    await db.flush()
    await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense'])  # Refresh to get updated data and the related expense

    return expense_split
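As a worked example of the thresholds above, with owed_amount = 30.00: activities of 10.00 and 20.00 sum to exactly 30.00 and mark the split paid, a single 10.00 payment yields partially_paid, and no activity yields unpaid. A standalone sketch of the same comparison, detached from the ORM for illustration:

from decimal import Decimal

def classify(total_paid: Decimal, owed: Decimal) -> str:
    # Mirrors the status logic above.
    total_paid = total_paid.quantize(Decimal("0.01"))
    if total_paid >= owed:
        return "paid"
    elif total_paid > 0:
        return "partially_paid"
    return "unpaid"

assert classify(Decimal("30.00"), Decimal("30.00")) == "paid"
assert classify(Decimal("10.00"), Decimal("30.00")) == "partially_paid"
assert classify(Decimal("0.00"), Decimal("30.00")) == "unpaid"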
async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
    """
    Updates the overall_status of an Expense based on the status of its splits.
    """
    # Fetch the Expense with its related splits
    result = await db.execute(
        select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
    )
    expense = result.scalar_one_or_none()

    if not expense:
        # Or raise an exception
        return None

    if not expense.splits:  # No splits; should not happen for a valid expense, but handle defensively
        expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid  # Or some other default/error state
        await db.flush()
        await db.refresh(expense)
        return expense

    num_splits = len(expense.splits)
    num_paid_splits = 0
    num_partially_paid_splits = 0
    num_unpaid_splits = 0

    for split in expense.splits:
        if split.status == ExpenseSplitStatusEnum.paid:
            num_paid_splits += 1
        elif split.status == ExpenseSplitStatusEnum.partially_paid:
            num_partially_paid_splits += 1
        else:  # unpaid
            num_unpaid_splits += 1

    if num_paid_splits == num_splits:
        expense.overall_settlement_status = ExpenseOverallStatusEnum.paid
    elif num_unpaid_splits == num_splits:
        expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid
    else:  # A mix of paid, partially paid, and unpaid splits
        expense.overall_settlement_status = ExpenseOverallStatusEnum.partially_paid

    await db.flush()
    await db.refresh(expense, attribute_names=['overall_settlement_status'])
    return expense
async def create_settlement_activity(
    db: AsyncSession,
    settlement_activity_in: SettlementActivityCreatePlaceholder,
    current_user_id: int
) -> Optional[SettlementActivity]:
    """
    Creates a new settlement activity, then updates the parent expense split and expense statuses.
    """
    # Validate the ExpenseSplit
    split_result = await db.execute(select(ExpenseSplit).where(ExpenseSplit.id == settlement_activity_in.expense_split_id))
    expense_split = split_result.scalar_one_or_none()
    if not expense_split:
        # Consider raising an HTTPException in an API layer
        return None  # ExpenseSplit not found

    # Validate the User (paid_by_user_id)
    user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id))
    paid_by_user = user_result.scalar_one_or_none()
    if not paid_by_user:
        return None  # User not found

    # Create the SettlementActivity instance
    db_settlement_activity = SettlementActivity(
        expense_split_id=settlement_activity_in.expense_split_id,
        paid_by_user_id=settlement_activity_in.paid_by_user_id,
        amount_paid=settlement_activity_in.amount_paid,
        paid_at=settlement_activity_in.paid_at if settlement_activity_in.paid_at else datetime.now(timezone.utc),
        created_by_user_id=current_user_id  # The user recording the activity
    )

    db.add(db_settlement_activity)
    await db.flush()  # Flush to get the ID for db_settlement_activity

    # Update statuses
    updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id)
    if updated_split and updated_split.expense_id:
        await update_expense_overall_status(db, expense_id=updated_split.expense_id)
    else:
        # update_expense_split_status returned None or expense_id was missing.
        # This could be a problem; consider logging or raising an error. If an
        # exception were raised here, the transaction would roll back; if nothing
        # is raised, the overall status update is simply skipped.
        pass

    await db.refresh(db_settlement_activity, attribute_names=['split', 'payer', 'creator'])  # Refresh to load relationships

    return db_settlement_activity
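A usage sketch for create_settlement_activity; the IDs and the amount below are invented for illustration.

# Hypothetical call recording a partial payment against split 42 by user 7.
activity_in = SettlementActivityCreatePlaceholder(
    expense_split_id=42,
    paid_by_user_id=7,
    amount_paid=Decimal("10.00"),
)
activity = await create_settlement_activity(db, activity_in, current_user_id=7)
if activity is None:
    # Either the expense split or the paying user did not exist.
    ...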
async def get_settlement_activity_by_id(
    db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
    """
    Fetches a single SettlementActivity by its ID, loading relationships.
    """
    result = await db.execute(
        select(SettlementActivity)
        .options(
            selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense),  # Load the split and its parent expense
            selectinload(SettlementActivity.payer),  # Load the user who paid
            selectinload(SettlementActivity.creator)  # Load the user who created the record
        )
        .where(SettlementActivity.id == settlement_activity_id)
    )
    return result.scalar_one_or_none()
async def get_settlement_activities_for_split(
    db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
    """
    Fetches the SettlementActivity records associated with a given expense_split_id.
    """
    result = await db.execute(
        select(SettlementActivity)
        .where(SettlementActivity.expense_split_id == expense_split_id)
        .options(
            selectinload(SettlementActivity.payer),  # Load the user who paid
            selectinload(SettlementActivity.creator)  # Load the user who created the record
        )
        .order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    return result.scalars().all()

# Further CRUD operations like update/delete can be added later if needed.
@@ -1,28 +1,90 @@
 # app/crud/user.py
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
+from sqlalchemy.orm import selectinload  # Ensure selectinload is imported
+from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
 from typing import Optional
+import logging
 
-from app.models import User as UserModel  # Alias to avoid name clash
+from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel  # Import related models for selectinload
 from app.schemas.user import UserCreate
 from app.core.security import hash_password
+from app.core.exceptions import (
+    UserCreationError,
+    EmailAlreadyRegisteredError,
+    DatabaseConnectionError,
+    DatabaseIntegrityError,
+    DatabaseQueryError,
+    DatabaseTransactionError,
+    UserOperationError  # Add if specific user operation errors are needed
+)
+
+logger = logging.getLogger(__name__)
 
 async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
-    """Fetches a user from the database by email."""
-    result = await db.execute(select(UserModel).filter(UserModel.email == email))
-    return result.scalars().first()
+    """Fetches a user from the database by email, with common relationships."""
+    try:
+        # db.begin() is not strictly necessary for a single read, but would ensure
+        # atomicity if more reads were added; for one select, the session handles
+        # the connection on its own.
+        stmt = (
+            select(UserModel)
+            .filter(UserModel.email == email)
+            .options(
+                selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),  # Groups the user is a member of
+                selectinload(UserModel.created_groups)  # Groups the user created
+                # Add other relationships as needed by the UserPublic schema
+            )
+        )
+        result = await db.execute(stmt)
+        return result.scalars().first()
+    except OperationalError as e:
+        logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
+        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
+    except SQLAlchemyError as e:
+        logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
+        raise DatabaseQueryError(f"Failed to query user: {str(e)}")
 
 async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
-    """Creates a new user record in the database."""
-    _hashed_password = hash_password(user_in.password)  # Keep local var name if you like
-    # Create SQLAlchemy model instance - explicitly map fields
-    db_user = UserModel(
-        email=user_in.email,
-        # Use the correct keyword argument matching the model column name
-        password_hash=_hashed_password,
-        name=user_in.name
-    )
-    db.add(db_user)
-    await db.commit()
-    await db.refresh(db_user)  # Refresh to get DB-generated values like ID, created_at
-    return db_user
+    """Creates a new user record in the database with common relationships loaded."""
+    try:
+        async with db.begin_nested() if db.in_transaction() else db.begin():
+            _hashed_password = hash_password(user_in.password)
+            db_user = UserModel(
+                email=user_in.email,
+                hashed_password=_hashed_password,  # Field name in the model is hashed_password
+                name=user_in.name
+            )
+            db.add(db_user)
+            await db.flush()  # Flush to get DB-generated values like ID
+
+            # Re-fetch with relationships
+            stmt = (
+                select(UserModel)
+                .where(UserModel.id == db_user.id)
+                .options(
+                    selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),
+                    selectinload(UserModel.created_groups)
+                    # Add other relationships as needed by the UserPublic schema
+                )
+            )
+            result = await db.execute(stmt)
+            loaded_user = result.scalar_one_or_none()
+
+            if loaded_user is None:
+                raise UserOperationError("Failed to load user after creation.")
+
+            return loaded_user
+    except IntegrityError as e:
+        logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
+        if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
+            raise EmailAlreadyRegisteredError(email=user_in.email)
+        raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
+    except OperationalError as e:
+        logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
+        raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
+    except SQLAlchemyError as e:
+        logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
+        raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
+
+# Ensure UserOperationError is defined in app.core.exceptions if used
+# Example: class UserOperationError(AppException): pass
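A sketch of how a registration endpoint might surface the duplicate-email case detected above; the endpoint shape and the 400 status choice are assumptions, not part of this diff.

# Hypothetical API-layer usage of create_user.
from fastapi import HTTPException

async def register(db, user_in):
    try:
        return await create_user(db, user_in)
    except EmailAlreadyRegisteredError:
        raise HTTPException(status_code=400, detail="Email already registered")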
@@ -11,9 +11,10 @@ if not settings.DATABASE_URL:
 # pool_recycle=3600 helps prevent stale connections on some DBs
 engine = create_async_engine(
     settings.DATABASE_URL,
-    echo=True,  # Log SQL queries (useful for debugging)
+    echo=False,  # Disable SQL query logging for production (use DEBUG log level to enable)
     future=True,  # Use SQLAlchemy 2.0 style features
-    pool_recycle=3600  # Optional: recycle connections after 1 hour
+    pool_recycle=3600,  # Optional: recycle connections after 1 hour
+    pool_pre_ping=True  # Ensure pooled connections are still live before use
 )
 
 # Create a configured "Session" class

@@ -30,18 +31,27 @@ AsyncSessionLocal = sessionmaker(
 Base = declarative_base()
 
 # Dependency to get DB session in path operations
-async def get_db() -> AsyncSession:  # type: ignore
+async def get_session() -> AsyncSession:  # type: ignore
     """
-    Dependency function that yields an AsyncSession.
+    Dependency function that yields an AsyncSession for read-only operations.
     Ensures the session is closed after the request.
     """
     async with AsyncSessionLocal() as session:
-        try:
-            yield session
-            # Optionally commit if your endpoints modify data directly
-            # await session.commit()  # Usually commit happens within endpoint logic
-        except Exception:
-            await session.rollback()
-            raise
-        finally:
-            await session.close()  # Not strictly necessary with async context manager, but explicit
+        yield session
+        # The 'async with' block handles session.close() automatically.
+
+async def get_transactional_session() -> AsyncSession:  # type: ignore
+    """
+    Dependency function that yields an AsyncSession and manages a transaction.
+    Commits the transaction if the request handler succeeds, otherwise rolls back.
+    Ensures the session is closed after the request.
+
+    This follows the FastAPI-DB strategy for endpoint-level transaction management.
+    """
+    async with AsyncSessionLocal() as session:
+        async with session.begin():
+            yield session
+            # Transaction is automatically committed on success or rolled back on exception
+
+# Alias for backward compatibility
+get_db = get_session
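A sketch of an endpoint wired to get_transactional_session; the route, the Item model, and the ItemCreate payload schema are assumptions for illustration. The point is that the handler never commits explicitly.

# Hypothetical endpoint: the dependency wraps the whole handler in one
# transaction, committed on success and rolled back on any exception.
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

@app.post("/items")  # 'app', 'Item', and 'ItemCreate' are assumed to exist
async def create_item(payload: ItemCreate, db: AsyncSession = Depends(get_transactional_session)):
    db.add(Item(**payload.model_dump()))
    return {"ok": True}  # No explicit commit: session.begin() commits on exit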
be/app/db/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
from app.db.session import async_session

__all__ = ["async_session"]
be/app/db/session.py (new file, 4 lines)

@@ -0,0 +1,4 @@
from app.database import AsyncSessionLocal

# Export the async session factory
async_session = AsyncSessionLocal
be/app/jobs/recurring_expenses.py (new file, 119 lines)

@@ -0,0 +1,119 @@
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
from app.models import Expense, RecurrencePattern
from app.crud.expense import create_expense
from app.schemas.expense import ExpenseCreate
import logging
from typing import Optional

logger = logging.getLogger(__name__)

async def generate_recurring_expenses(db: AsyncSession) -> None:
    """
    Background job to generate recurring expenses.
    Should be run daily to check for and create new recurring expenses.
    """
    try:
        # Get all active recurring expenses that need to be generated
        now = datetime.utcnow()
        query = select(Expense).join(RecurrencePattern).where(
            and_(
                Expense.is_recurring == True,
                Expense.next_occurrence <= now,
                # Check that we haven't reached max occurrences
                (
                    (RecurrencePattern.max_occurrences == None) |
                    (RecurrencePattern.max_occurrences > 0)
                ),
                # Check that we haven't reached the end date
                (
                    (RecurrencePattern.end_date == None) |
                    (RecurrencePattern.end_date > now)
                )
            )
        )

        result = await db.execute(query)
        recurring_expenses = result.scalars().all()

        for expense in recurring_expenses:
            try:
                await _generate_next_occurrence(db, expense)
            except Exception as e:
                logger.error(f"Error generating next occurrence for expense {expense.id}: {str(e)}")
                continue

    except Exception as e:
        logger.error(f"Error in generate_recurring_expenses job: {str(e)}")
        raise
async def _generate_next_occurrence(db: AsyncSession, expense: Expense) -> None:
    """Generate the next occurrence of a recurring expense."""
    pattern = expense.recurrence_pattern
    if not pattern:
        return

    # Calculate the next occurrence date
    next_date = _calculate_next_occurrence(expense.next_occurrence, pattern)
    if not next_date:
        return

    # Create a new expense based on the template
    new_expense = ExpenseCreate(
        description=expense.description,
        total_amount=expense.total_amount,
        currency=expense.currency,
        expense_date=next_date,
        split_type=expense.split_type,
        list_id=expense.list_id,
        group_id=expense.group_id,
        item_id=expense.item_id,
        paid_by_user_id=expense.paid_by_user_id,
        is_recurring=False,  # Generated expenses are not recurring
        splits_in=None  # Will be generated based on split_type
    )

    # Create the new expense
    created_expense = await create_expense(db, new_expense, expense.created_by_user_id)

    # Update the original expense
    expense.last_occurrence = next_date
    expense.next_occurrence = _calculate_next_occurrence(next_date, pattern)

    if pattern.max_occurrences:
        pattern.max_occurrences -= 1

    await db.flush()
def _calculate_next_occurrence(current_date: datetime, pattern: RecurrencePattern) -> Optional[datetime]:
    """Calculate the next occurrence date based on the pattern."""
    if not current_date:
        return None

    if pattern.type == 'daily':
        return current_date + timedelta(days=pattern.interval)

    elif pattern.type == 'weekly':
        if not pattern.days_of_week:
            return current_date + timedelta(weeks=pattern.interval)

        # Find the next day of the week
        current_weekday = current_date.weekday()
        next_weekday = min((d for d in pattern.days_of_week if d > current_weekday),
                           default=min(pattern.days_of_week))
        days_ahead = next_weekday - current_weekday
        if days_ahead <= 0:
            days_ahead += 7
        return current_date + timedelta(days=days_ahead)

    elif pattern.type == 'monthly':
        # Add months to the current date
        year = current_date.year + (current_date.month + pattern.interval - 1) // 12
        month = (current_date.month + pattern.interval - 1) % 12 + 1
        return current_date.replace(year=year, month=month)

    elif pattern.type == 'yearly':
        return current_date.replace(year=current_date.year + pattern.interval)

    return None
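A worked example of the monthly arithmetic above, plus a caveat it inherits from datetime.replace:

from datetime import datetime

# November 15 advanced by interval=3 months:
current = datetime(2024, 11, 15)
interval = 3
year = current.year + (current.month + interval - 1) // 12   # 2024 + 13 // 12 -> 2025
month = (current.month + interval - 1) % 12 + 1              # 13 % 12 + 1 -> 2
print(current.replace(year=year, month=month))               # 2025-02-15 00:00:00

# Caveat: replace() raises ValueError when the target month is shorter than
# the source day allows, e.g. Jan 31 + 1 month, since February has no day 31.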
be/app/main.py (266 lines)

@@ -1,55 +1,204 @@
 # app/main.py
 import logging
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException, Depends, status, Request
 from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+import sentry_sdk
+from sentry_sdk.integrations.fastapi import FastApiIntegration
+from fastapi_users.authentication import JWTStrategy
+from pydantic import BaseModel
+from jose import jwt, JWTError
+from sqlalchemy.ext.asyncio import AsyncEngine
+from alembic.config import Config
+from alembic import command
+import os
+import sys
 
-from app.api.api_router import api_router  # Import the main combined router
-# Import database and models if needed for startup/shutdown events later
-# from . import database, models
+from app.api.api_router import api_router
+from app.config import settings
+from app.core.api_config import API_METADATA, API_TAGS
+from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
+from app.models import User
+from app.api.auth.oauth import router as oauth_router
+from app.schemas.user import UserPublic, UserCreate, UserUpdate
+from app.core.scheduler import init_scheduler, shutdown_scheduler
+from app.database import get_session
+from sqlalchemy import select
+
+# Response model for the refresh endpoint
+class RefreshResponse(BaseModel):
+    access_token: str
+    refresh_token: str
+    token_type: str = "bearer"
+
+# Initialize Sentry only if a DSN is provided
+if settings.SENTRY_DSN:
+    sentry_sdk.init(
+        dsn=settings.SENTRY_DSN,
+        integrations=[
+            FastApiIntegration(),
+        ],
+        # Adjust traces_sample_rate for production
+        traces_sample_rate=0.1 if settings.is_production else 1.0,
+        environment=settings.ENVIRONMENT,
+        # Enable PII data only in development
+        send_default_pii=not settings.is_production
+    )
 
 # --- Logging Setup ---
-# Configure logging (can be more sophisticated later, e.g., using logging.yaml)
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+    level=getattr(logging, settings.LOG_LEVEL),
+    format=settings.LOG_FORMAT
+)
 logger = logging.getLogger(__name__)
 
 # --- FastAPI App Instance ---
+# Create API metadata with environment-dependent settings
+api_metadata = {
+    **API_METADATA,
+    "docs_url": settings.docs_url,
+    "redoc_url": settings.redoc_url,
+    "openapi_url": settings.openapi_url,
+}
+
 app = FastAPI(
-    title="Shared Lists API",
-    description="API for managing shared shopping lists, OCR, and cost splitting.",
-    version="0.1.0",
-    openapi_url="/api/openapi.json",  # Place OpenAPI spec under /api
-    docs_url="/api/docs",  # Place Swagger UI under /api
-    redoc_url="/api/redoc"  # Place ReDoc under /api
+    **api_metadata,
+    openapi_tags=API_TAGS
 )
+
+# Add session middleware for OAuth
+app.add_middleware(
+    SessionMiddleware,
+    secret_key=settings.SESSION_SECRET_KEY
+)
 
 # --- CORS Middleware ---
-# Define allowed origins. Be specific in production!
-# Use ["*"] for wide open access during early development if needed,
-# but restrict it as soon as possible.
-# SvelteKit default dev port is 5173
-origins = [
-    "http://localhost:5174",
-    "http://localhost:8000",  # Allow requests from the API itself (e.g., Swagger UI)
-    # Add your deployed frontend URL here later
-    # "https://your-frontend-domain.com",
-]
-
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=origins,  # List of origins that are allowed to make requests
-    allow_credentials=True,  # Allow cookies to be included in requests
-    allow_methods=["*"],  # Allow all methods (GET, POST, PUT, DELETE, etc.)
-    allow_headers=["*"],  # Allow all headers
+    allow_origins=settings.cors_origins_list,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["*"]
 )
 # --- End CORS Middleware ---
 
+# Refresh token endpoint
+@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
+async def refresh_jwt_token(
+    request: Request,
+    refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
+    access_strategy: JWTStrategy = Depends(get_jwt_strategy),
+):
+    """
+    Refresh the access token using a valid refresh token.
+    Send the refresh token in the Authorization header: Bearer <refresh_token>
+    """
+    try:
+        # Get the refresh token from the Authorization header
+        authorization = request.headers.get("Authorization")
+        if not authorization or not authorization.startswith("Bearer "):
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Refresh token missing or invalid format",
+                headers={"WWW-Authenticate": "Bearer"},
+            )
+
+        refresh_token = authorization.split(" ")[1]
+
+        # Validate the refresh token and get the user identifier
+        try:
+            payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"])
+            user_id = payload.get("sub")
+            if user_id is None:
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail="Invalid refresh token",
+                )
+        except JWTError:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid refresh token",
+            )
+
+        # Get the user from the database
+        async with get_session() as session:
+            result = await session.execute(select(User).where(User.id == int(user_id)))
+            user = result.scalar_one_or_none()
+
+            if not user or not user.is_active:
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail="User not found or inactive",
+                )
+
+        # Generate new tokens
+        new_access_token = await access_strategy.write_token(user)
+        new_refresh_token = await refresh_strategy.write_token(user)
+
+        return RefreshResponse(
+            access_token=new_access_token,
+            refresh_token=new_refresh_token,
+            token_type="bearer"
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error refreshing token: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid refresh token"
+        )
+
 # --- Include API Routers ---
-# All API endpoints will be prefixed with /api
-app.include_router(api_router, prefix="/api")
+# Include OAuth routes first (no auth required)
+app.include_router(oauth_router, prefix="/auth", tags=["auth"])
+
+# Include FastAPI-Users routes
+app.include_router(
+    fastapi_users.get_auth_router(auth_backend),
+    prefix="/auth/jwt",
+    tags=["auth"],
+)
+app.include_router(
+    fastapi_users.get_register_router(UserPublic, UserCreate),
+    prefix="/auth",
+    tags=["auth"],
+)
+app.include_router(
+    fastapi_users.get_reset_password_router(),
+    prefix="/auth",
+    tags=["auth"],
+)
+app.include_router(
+    fastapi_users.get_verify_router(UserPublic),
+    prefix="/auth",
+    tags=["auth"],
+)
+app.include_router(
+    fastapi_users.get_users_router(UserPublic, UserUpdate),
+    prefix="/users",
+    tags=["users"],
+)
+
+# Include the main API router
+app.include_router(api_router, prefix=settings.API_PREFIX)
 # --- End Include API Routers ---
 
+# Health check endpoint
+@app.get("/health", tags=["Health"])
+async def health_check():
+    """
+    Health check endpoint for load balancers and monitoring.
+    """
+    return {
+        "status": settings.HEALTH_STATUS_OK,
+        "environment": settings.ENVIRONMENT,
+        "version": settings.API_VERSION
+    }
+
 # --- Root Endpoint (Optional - outside the main API structure) ---
 @app.get("/", tags=["Root"])
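A client-side sketch of calling the refresh endpoint added above; the base URL is an assumption, and httpx is used purely for illustration.

import httpx

async def refresh_tokens(refresh_token: str) -> dict:
    # Hypothetical client call; the host and port are assumptions.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/auth/jwt/refresh",
            headers={"Authorization": f"Bearer {refresh_token}"},
        )
        resp.raise_for_status()
        # {"access_token": ..., "refresh_token": ..., "token_type": "bearer"}
        return resp.json()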
@@ -59,26 +208,53 @@ async def read_root():
     Useful for basic reachability checks.
     """
     logger.info("Root endpoint '/' accessed.")
-    # You could redirect to the docs or return a simple message
-    # from fastapi.responses import RedirectResponse
-    # return RedirectResponse(url="/api/docs")
-    return {"message": "Welcome to the Shared Lists API! Docs available at /api/docs"}
+    return {
+        "message": settings.ROOT_MESSAGE,
+        "environment": settings.ENVIRONMENT,
+        "version": settings.API_VERSION
+    }
 # --- End Root Endpoint ---
 
+async def run_migrations():
+    """Run database migrations."""
+    try:
+        logger.info("Running database migrations...")
+        # Get the absolute path to the alembic directory
+        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        alembic_path = os.path.join(base_path, 'alembic')
+
+        # Add the alembic directory to the Python path
+        if alembic_path not in sys.path:
+            sys.path.insert(0, alembic_path)
+
+        # Import and run migrations
+        from migrations import run_migrations as run_db_migrations
+        await run_db_migrations()
+
+        logger.info("Database migrations completed successfully.")
+    except Exception as e:
+        logger.error(f"Error running migrations: {e}")
+        raise
 
-# --- Application Startup/Shutdown Events (Optional) ---
-# @app.on_event("startup")
-# async def startup_event():
-#     logger.info("Application startup: Connecting to database...")
-#     # You might perform initial checks or warm-up here
-#     # await database.engine.connect()  # Example check (get_db handles sessions per request)
-#     logger.info("Application startup complete.")
-
-# @app.on_event("shutdown")
-# async def shutdown_event():
-#     logger.info("Application shutdown: Disconnecting from database...")
-#     # await database.engine.dispose()  # Close connection pool
-#     logger.info("Application shutdown complete.")
+@app.on_event("startup")
+async def startup_event():
+    """Initialize services on startup."""
+    logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")
+
+    # Run database migrations
+    # await run_migrations()
+
+    # Initialize the scheduler
+    init_scheduler()
+    logger.info("Application startup complete.")
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Cleanup services on shutdown."""
+    logger.info("Application shutdown: Disconnecting from database...")
+    # await database.engine.dispose()  # Close connection pool
+    shutdown_scheduler()
+    logger.info("Application shutdown complete.")
 # --- End Events ---
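A minimal smoke test of the health endpoint defined above, assuming FastAPI's TestClient is available in the test environment and that app.main is the import path shown in the file header:

from fastapi.testclient import TestClient
from app.main import app  # import path assumed from the file shown above

client = TestClient(app)
resp = client.get("/health")
assert resp.status_code == 200
assert {"status", "environment", "version"} <= resp.json().keys()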
be/app/models.py (295 lines)
@@ -19,9 +19,11 @@ from sqlalchemy import (
     func,
     text as sa_text,
     Text,  # <-- Add Text for description
-    Numeric  # <-- Add Numeric for price
+    Numeric,  # <-- Add Numeric for price
+    CheckConstraint,
+    Date  # Added Date for the Chore model
 )
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, backref
 
 from .database import Base
@@ -30,14 +32,56 @@ class UserRoleEnum(enum.Enum):
     owner = "owner"
     member = "member"
 
+class SplitTypeEnum(enum.Enum):
+    EQUAL = "EQUAL"  # Split equally among all involved users
+    EXACT_AMOUNTS = "EXACT_AMOUNTS"  # Specific amounts for each user (defined in ExpenseSplit)
+    PERCENTAGE = "PERCENTAGE"  # Percentage for each user (defined in ExpenseSplit)
+    SHARES = "SHARES"  # Proportional to shares/units (defined in ExpenseSplit)
+    ITEM_BASED = "ITEM_BASED"  # If an expense is derived directly from item prices and who added them
+    # Consider renaming to a more generic term like 'DERIVED' or 'ENTITY_DRIVEN'
+    # if expenses might be derived from other entities in the future.
+    # Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense)
+
+class ExpenseSplitStatusEnum(enum.Enum):
+    unpaid = "unpaid"
+    partially_paid = "partially_paid"
+    paid = "paid"
+
+class ExpenseOverallStatusEnum(enum.Enum):
+    unpaid = "unpaid"
+    partially_paid = "partially_paid"
+    paid = "paid"
+
+class RecurrenceTypeEnum(enum.Enum):
+    DAILY = "DAILY"
+    WEEKLY = "WEEKLY"
+    MONTHLY = "MONTHLY"
+    YEARLY = "YEARLY"
+    # Add more types as needed
+
+# Define ChoreFrequencyEnum
+class ChoreFrequencyEnum(enum.Enum):
+    one_time = "one_time"
+    daily = "daily"
+    weekly = "weekly"
+    monthly = "monthly"
+    custom = "custom"
+
+class ChoreTypeEnum(enum.Enum):
+    personal = "personal"
+    group = "group"
+
 # --- User Model ---
 class User(Base):
     __tablename__ = "users"
 
     id = Column(Integer, primary_key=True, index=True)
     email = Column(String, unique=True, index=True, nullable=False)
-    password_hash = Column(String, nullable=False)
+    hashed_password = Column(String, nullable=False)
     name = Column(String, index=True, nullable=True)
+    is_active = Column(Boolean, default=True, nullable=False)
+    is_superuser = Column(Boolean, default=False, nullable=False)
+    is_verified = Column(Boolean, default=False, nullable=False)
     created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
 
     # --- Relationships ---
@ -51,6 +95,20 @@ class User(Base):
|
|||||||
completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
|
completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
|
||||||
# --- End NEW Relationships ---
|
# --- End NEW Relationships ---
|
+    # --- Relationships for Cost Splitting ---
+    expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
+    expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
+    expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
+    settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
+    settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
+    settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan")
+    # --- End Relationships for Cost Splitting ---
+
+    # --- Relationships for Chores ---
+    created_chores = relationship("Chore", foreign_keys="[Chore.created_by_id]", back_populates="creator")
+    assigned_chores = relationship("ChoreAssignment", back_populates="assigned_user", cascade="all, delete-orphan")
+    # --- End Relationships for Chores ---
# --- Group Model ---
class Group(Base):

@@ -70,6 +128,15 @@ class Group(Base):
    lists = relationship("List", back_populates="group", cascade="all, delete-orphan")  # Link List.group_id -> Group
    # --- End NEW Relationship ---

+    # --- Relationships for Cost Splitting ---
+    expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan")
+    settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan")
+    # --- End Relationships for Cost Splitting ---
+
+    # --- Relationship for Chores ---
+    chores = relationship("Chore", back_populates="group", cascade="all, delete-orphan")
+    # --- End Relationship for Chores ---

# --- UserGroup Association Model ---
class UserGroup(Base):
@@ -117,16 +184,29 @@ class List(Base):
    is_complete = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+    version = Column(Integer, nullable=False, default=1, server_default='1')

    # --- Relationships ---
    creator = relationship("User", back_populates="created_lists")  # Link to User.created_lists
    group = relationship("Group", back_populates="lists")  # Link to Group.lists
-    items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at")  # Link to Item.list, cascade deletes
+    items = relationship(
+        "Item",
+        back_populates="list",
+        cascade="all, delete-orphan",
+        order_by="Item.position.asc(), Item.created_at.asc()"  # Default order by position, then creation
+    )
+
+    # --- Relationships for Cost Splitting ---
+    expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan")
+    # --- End Relationships for Cost Splitting ---
# === NEW: Item Model ===
class Item(Base):
    __tablename__ = "items"
+    __table_args__ = (
+        Index('ix_items_list_id_position', 'list_id', 'position'),
+    )

    id = Column(Integer, primary_key=True, index=True)
    list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False)  # Belongs to which list

@@ -134,12 +214,219 @@ class Item(Base):
    quantity = Column(String, nullable=True)  # Flexible quantity (e.g., "1", "2 lbs", "a bunch")
    is_complete = Column(Boolean, default=False, nullable=False)
    price = Column(Numeric(10, 2), nullable=True)  # For cost splitting later (e.g., 12345678.99)
+    position = Column(Integer, nullable=False, server_default='0')  # For ordering
    added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False)  # Who added this item
    completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True)  # Who marked it complete
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+    version = Column(Integer, nullable=False, default=1, server_default='1')

    # --- Relationships ---
    list = relationship("List", back_populates="items")  # Link to List.items
    added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items")  # Link to User.added_items
    completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items")  # Link to User.completed_items

+    # --- Relationships for Cost Splitting ---
+    # If an item directly results in an expense, or an expense can be tied to an item.
+    expenses = relationship("Expense", back_populates="item")  # An item might have multiple associated expenses
+    # --- End Relationships for Cost Splitting ---
+# === NEW Models for Advanced Cost Splitting ===
+
+class Expense(Base):
+    __tablename__ = "expenses"
+
+    id = Column(Integer, primary_key=True, index=True)
+    description = Column(String, nullable=False)
+    total_amount = Column(Numeric(10, 2), nullable=False)
+    currency = Column(String, nullable=False, default="USD")
+    expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False)
+
+    # Foreign Keys
+    list_id = Column(Integer, ForeignKey("lists.id"), nullable=True, index=True)
+    group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True)
+    item_id = Column(Integer, ForeignKey("items.id"), nullable=True)
+    paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+    version = Column(Integer, nullable=False, default=1, server_default='1')
+
+    # Relationships
+    paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
+    created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created")
+    list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
+    group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
+    item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
+    splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan")
+    parent_expense = relationship("Expense", remote_side=[id], back_populates="child_expenses")
+    child_expenses = relationship("Expense", back_populates="parent_expense")
+
+    overall_settlement_status = Column(SAEnum(ExpenseOverallStatusEnum, name="expenseoverallstatusenum", create_type=True), nullable=False, server_default=ExpenseOverallStatusEnum.unpaid.value, default=ExpenseOverallStatusEnum.unpaid)
+
+    # --- Recurrence fields ---
+    is_recurring = Column(Boolean, default=False, nullable=False)
+    recurrence_pattern_id = Column(Integer, ForeignKey("recurrence_patterns.id"), nullable=True)
+    recurrence_pattern = relationship("RecurrencePattern", back_populates="expenses", uselist=False)  # One-to-one
+    next_occurrence = Column(DateTime(timezone=True), nullable=True)  # For recurring expenses
+    parent_expense_id = Column(Integer, ForeignKey("expenses.id"), nullable=True)
+    last_occurrence = Column(DateTime(timezone=True), nullable=True)
+
+    __table_args__ = (
+        # Ensure at least one context is provided
+        CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
+    )
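Aside (not part of the diff): SplitTypeEnum implies an EQUAL mode in which the backend generates the splits, and dividing a Decimal total by the participant count usually leaves a remainder of a few cents that has to be parked somewhere deterministic. A minimal sketch of one common approach; the function name and shape are illustrative, not the repo's CRUD code:

    from decimal import Decimal, ROUND_DOWN

    def equal_splits(total_amount: Decimal, user_ids: list[int]) -> dict[int, Decimal]:
        """Split total_amount evenly, giving any leftover cents to the first users."""
        base = (total_amount / len(user_ids)).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
        leftover_cents = int((total_amount - base * len(user_ids)) / Decimal("0.01"))
        return {
            uid: base + (Decimal("0.01") if i < leftover_cents else Decimal("0.00"))
            for i, uid in enumerate(user_ids)
        }

    # equal_splits(Decimal("100.00"), [1, 2, 3]) -> {1: 33.34, 2: 33.33, 3: 33.33}

By construction the shares always sum exactly to total_amount, which keeps the splits consistent with the owed_amount column below.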
+class ExpenseSplit(Base):
+    __tablename__ = "expense_splits"
+    __table_args__ = (
+        UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),
+        Index('ix_expense_splits_user_id', 'user_id'),  # For looking up user's splits
+    )
+
+    id = Column(Integer, primary_key=True, index=True)
+    expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False)
+    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
+
+    owed_amount = Column(Numeric(10, 2), nullable=False)
+    share_percentage = Column(Numeric(5, 2), nullable=True)
+    share_units = Column(Integer, nullable=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+
+    # Relationships
+    expense = relationship("Expense", back_populates="splits")
+    user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits")
+    settlement_activities = relationship("SettlementActivity", back_populates="split", cascade="all, delete-orphan")
+
+    # New fields for tracking payment status
+    status = Column(SAEnum(ExpenseSplitStatusEnum, name="expensesplitstatusenum", create_type=True), nullable=False, server_default=ExpenseSplitStatusEnum.unpaid.value, default=ExpenseSplitStatusEnum.unpaid)
+    paid_at = Column(DateTime(timezone=True), nullable=True)  # Timestamp when the split was fully paid
+class Settlement(Base):
+    __tablename__ = "settlements"
+
+    id = Column(Integer, primary_key=True, index=True)
+    group_id = Column(Integer, ForeignKey("groups.id"), nullable=False, index=True)
+    paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+    paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+    amount = Column(Numeric(10, 2), nullable=False)
+    settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    description = Column(Text, nullable=True)
+    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+    version = Column(Integer, nullable=False, default=1, server_default='1')
+
+    # Relationships
+    group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
+    payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
+    payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
+    created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created")
+
+    __table_args__ = (
+        # Ensure payer and payee are different users
+        CheckConstraint('paid_by_user_id != paid_to_user_id', name='chk_settlement_different_users'),
+    )
+
+# Potential future: PaymentMethod model, etc.
+class SettlementActivity(Base):
+    __tablename__ = "settlement_activities"
+
+    id = Column(Integer, primary_key=True, index=True)
+    expense_split_id = Column(Integer, ForeignKey("expense_splits.id"), nullable=False, index=True)
+    paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)  # User who made this part of the payment
+    paid_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    amount_paid = Column(Numeric(10, 2), nullable=False)
+    created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)  # User who recorded this activity
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+
+    # --- Relationships ---
+    split = relationship("ExpenseSplit", back_populates="settlement_activities")
+    payer = relationship("User", foreign_keys=[paid_by_user_id], backref="made_settlement_activities")
+    creator = relationship("User", foreign_keys=[created_by_user_id], backref="created_settlement_activities")
+
+    __table_args__ = (
+        Index('ix_settlement_activity_expense_split_id', 'expense_split_id'),
+        Index('ix_settlement_activity_paid_by_user_id', 'paid_by_user_id'),
+        Index('ix_settlement_activity_created_by_user_id', 'created_by_user_id'),
+    )
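Aside (not part of the diff): with SettlementActivity rows recording partial payments against an ExpenseSplit, the split's status field is fully determined by the sum of its activities. A sketch of that mapping, using plain values rather than ORM objects:

    from decimal import Decimal

    def derive_split_status(owed_amount: Decimal, amounts_paid: list[Decimal]) -> str:
        """Map accumulated payments onto ExpenseSplitStatusEnum values."""
        total_paid = sum(amounts_paid, Decimal("0.00"))
        if total_paid <= Decimal("0.00"):
            return "unpaid"
        if total_paid < owed_amount:
            return "partially_paid"
        return "paid"  # reaching (or overshooting) owed_amount counts as fully paid

The Expense-level overall_settlement_status would presumably follow the same pattern, reading "paid" only once every one of its splits is.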
+# --- Chore Model ---
+class Chore(Base):
+    __tablename__ = "chores"
+
+    id = Column(Integer, primary_key=True, index=True)
+    type = Column(SAEnum(ChoreTypeEnum, name="choretypeenum", create_type=True), nullable=False)
+    group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=True, index=True)
+    name = Column(String, nullable=False, index=True)
+    description = Column(Text, nullable=True)
+    created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
+
+    frequency = Column(SAEnum(ChoreFrequencyEnum, name="chorefrequencyenum", create_type=True), nullable=False)
+    custom_interval_days = Column(Integer, nullable=True)  # Only if frequency is 'custom'
+
+    next_due_date = Column(Date, nullable=False)  # Changed to Date
+    last_completed_at = Column(DateTime(timezone=True), nullable=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+
+    # --- Relationships ---
+    group = relationship("Group", back_populates="chores")
+    creator = relationship("User", back_populates="created_chores")
+    assignments = relationship("ChoreAssignment", back_populates="chore", cascade="all, delete-orphan")
+# --- ChoreAssignment Model ---
+class ChoreAssignment(Base):
+    __tablename__ = "chore_assignments"
+
+    id = Column(Integer, primary_key=True, index=True)
+    chore_id = Column(Integer, ForeignKey("chores.id", ondelete="CASCADE"), nullable=False, index=True)
+    assigned_to_user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
+
+    due_date = Column(Date, nullable=False)  # Specific due date for this instance, changed to Date
+    is_complete = Column(Boolean, default=False, nullable=False)
+    completed_at = Column(DateTime(timezone=True), nullable=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+
+    # --- Relationships ---
+    chore = relationship("Chore", back_populates="assignments")
+    assigned_user = relationship("User", back_populates="assigned_chores")
+# === NEW: RecurrencePattern Model ===
+class RecurrencePattern(Base):
+    __tablename__ = "recurrence_patterns"
+
+    id = Column(Integer, primary_key=True, index=True)
+    type = Column(SAEnum(RecurrenceTypeEnum, name="recurrencetypeenum", create_type=True), nullable=False)
+    interval = Column(Integer, default=1, nullable=False)  # e.g., every 1 day, every 2 weeks
+    days_of_week = Column(String, nullable=True)  # For weekly recurrences, e.g., "MON,TUE,FRI"
+    # day_of_month = Column(Integer, nullable=True)  # For monthly on a specific day
+    # week_of_month = Column(Integer, nullable=True)  # For monthly on a specific week (e.g., 2nd week)
+    # month_of_year = Column(Integer, nullable=True)  # For yearly recurrences
+    end_date = Column(DateTime(timezone=True), nullable=True)
+    max_occurrences = Column(Integer, nullable=True)
+
+    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
+
+    # Relationship back to Expenses that use this pattern. If patterns could be
+    # shared, this would be one-to-many (one RecurrencePattern to many Expenses);
+    # the current CRUD implies one RecurrencePattern per recurring Expense, so for
+    # now this is effectively one-to-one, matching Expense.recurrence_pattern above.
+    expenses = relationship("Expense", back_populates="recurrence_pattern")
+
+# === END: RecurrencePattern Model ===
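Aside (not part of the diff): the code that advances Expense.next_occurrence is not shown in this excerpt, so the following is only one plausible reading of RecurrencePattern's type/interval fields; the month arithmetic is deliberately crude to stay dependency-free:

    from datetime import datetime, timedelta

    def advance_occurrence(current: datetime, rec_type: str, interval: int) -> datetime:
        """Next occurrence after `current` for DAILY/WEEKLY/MONTHLY/YEARLY patterns."""
        if rec_type == "DAILY":
            return current + timedelta(days=interval)
        if rec_type == "WEEKLY":
            return current + timedelta(weeks=interval)
        if rec_type == "MONTHLY":
            month_index = current.month - 1 + interval
            year = current.year + month_index // 12
            month = month_index % 12 + 1
            day = min(current.day, 28)  # clamp so e.g. Jan 31 + 1 month stays valid
            return current.replace(year=year, month=month, day=day)
        if rec_type == "YEARLY":
            return current.replace(year=current.year + interval, day=min(current.day, 28))
        raise ValueError(f"unknown recurrence type: {rec_type}")

A real implementation would also honor days_of_week, end_date, and max_occurrences before scheduling the next child expense.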
be/app/models/expense.py (new file, empty)
@@ -1,9 +1,11 @@
# app/schemas/auth.py
from pydantic import BaseModel, EmailStr
+from app.config import settings

class Token(BaseModel):
    access_token: str
-    token_type: str = "bearer"  # Default token type
+    refresh_token: str  # Added refresh token
+    token_type: str = settings.TOKEN_TYPE  # Use configured token type

# Optional: If you preferred not to use OAuth2PasswordRequestForm
# class UserLogin(BaseModel):
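Aside (not part of the diff): adding refresh_token to Token implies the login flow now issues a pair of JWTs. A hedged sketch with python-jose (pinned in requirements.txt further down); the key, algorithm, and lifetimes are placeholder assumptions, since real code would read them from settings:

    from datetime import datetime, timedelta, timezone
    from jose import jwt

    SECRET_KEY = "change-me"  # assumption: actually comes from settings
    ALGORITHM = "HS256"       # assumption

    def create_token_pair(user_id: int) -> dict:
        """Issue a short-lived access token and a longer-lived refresh token."""
        now = datetime.now(timezone.utc)
        access = jwt.encode(
            {"sub": str(user_id), "type": "access", "exp": now + timedelta(minutes=30)},
            SECRET_KEY, algorithm=ALGORITHM)
        refresh = jwt.encode(
            {"sub": str(user_id), "type": "refresh", "exp": now + timedelta(days=7)},
            SECRET_KEY, algorithm=ALGORITHM)
        return {"access_token": access, "refresh_token": refresh, "token_type": "bearer"}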
be/app/schemas/chore.py (new file, 111 lines)

@@ -0,0 +1,111 @@
from datetime import date, datetime
from typing import Optional, List
from pydantic import BaseModel, ConfigDict, field_validator

# Assuming ChoreFrequencyEnum is imported from models; adjust the import path if
# necessary based on your project structure, e.g., from app.models import ChoreFrequencyEnum.
from ..models import ChoreFrequencyEnum, ChoreTypeEnum, User as UserModel  # For UserPublic relation
from .user import UserPublic  # For embedding user information

# Chore Schemas
class ChoreBase(BaseModel):
    name: str
    description: Optional[str] = None
    frequency: ChoreFrequencyEnum
    custom_interval_days: Optional[int] = None
    next_due_date: date  # For creation, this will be the initial due date
    type: ChoreTypeEnum

    @field_validator('custom_interval_days', mode='before')
    @classmethod
    def check_custom_interval_days(cls, value, values):
        # Pydantic v2 exposes the other fields via `values.data`; older Pydantic
        # passes them directly as `values`. The goal is to ensure that
        # custom_interval_days is present when frequency is 'custom', but because
        # this runs in 'before' mode, 'frequency' may not be parsed yet. A truly
        # robust check belongs on the whole model (or in an 'after' validator);
        # for now this is a placeholder, and the validation may be better handled
        # at the service/CRUD layer for complex cases.
        return value
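Aside (not part of the diff): the placeholder above is right that a reliable cross-field check needs the whole model. In Pydantic v2 that is an after-mode model validator, which runs once all fields are parsed; a sketch of what it could look like (an illustrative stand-in, not the repo's code):

    from typing import Optional
    from pydantic import BaseModel, model_validator

    class ChoreBaseChecked(BaseModel):  # stand-in for ChoreBase
        frequency: str  # "custom", "daily", ...
        custom_interval_days: Optional[int] = None

        @model_validator(mode="after")
        def check_custom_interval(self):
            # All fields are parsed by now, so cross-field checks are reliable.
            if self.frequency == "custom" and self.custom_interval_days is None:
                raise ValueError("custom_interval_days is required when frequency is 'custom'")
            if self.frequency != "custom" and self.custom_interval_days is not None:
                raise ValueError("custom_interval_days is only allowed when frequency is 'custom'")
            return self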
class ChoreCreate(ChoreBase):
    group_id: Optional[int] = None

    @field_validator('group_id')
    @classmethod
    def validate_group_id(cls, v, values):
        if values.data.get('type') == ChoreTypeEnum.group and v is None:
            raise ValueError("group_id is required for group chores")
        if values.data.get('type') == ChoreTypeEnum.personal and v is not None:
            raise ValueError("group_id must be None for personal chores")
        return v

class ChoreUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    frequency: Optional[ChoreFrequencyEnum] = None
    custom_interval_days: Optional[int] = None
    next_due_date: Optional[date] = None  # Allow updating next_due_date directly if needed
    type: Optional[ChoreTypeEnum] = None
    group_id: Optional[int] = None
    # last_completed_at should generally not be updated directly by user

    @field_validator('group_id')
    @classmethod
    def validate_group_id(cls, v, values):
        if values.data.get('type') == ChoreTypeEnum.group and v is None:
            raise ValueError("group_id is required for group chores")
        if values.data.get('type') == ChoreTypeEnum.personal and v is not None:
            raise ValueError("group_id must be None for personal chores")
        return v

class ChorePublic(ChoreBase):
    id: int
    group_id: Optional[int] = None
    created_by_id: int
    last_completed_at: Optional[datetime] = None
    created_at: datetime
    updated_at: datetime
    creator: Optional[UserPublic] = None  # Embed creator UserPublic schema
    # group: Optional[GroupPublic] = None  # Embed GroupPublic schema if needed

    model_config = ConfigDict(from_attributes=True)

# Chore Assignment Schemas
class ChoreAssignmentBase(BaseModel):
    chore_id: int
    assigned_to_user_id: int
    due_date: date

class ChoreAssignmentCreate(ChoreAssignmentBase):
    pass

class ChoreAssignmentUpdate(BaseModel):
    # Only completion status and perhaps due_date can be updated for an assignment
    is_complete: Optional[bool] = None
    due_date: Optional[date] = None  # If rescheduling an existing assignment is allowed

class ChoreAssignmentPublic(ChoreAssignmentBase):
    id: int
    is_complete: bool
    completed_at: Optional[datetime] = None
    created_at: datetime
    updated_at: datetime
    # Embed ChorePublic and UserPublic for richer responses
    chore: Optional[ChorePublic] = None
    assigned_user: Optional[UserPublic] = None

    model_config = ConfigDict(from_attributes=True)

# To handle potential circular imports if ChorePublic needs GroupPublic and GroupPublic
# needs ChorePublic, we can update forward refs after all models are defined.
# ChorePublic.model_rebuild()  # If using Pydantic v2 and forward refs were used with strings
# ChoreAssignmentPublic.model_rebuild()
be/app/schemas/cost.py (new file, 55 lines)

@@ -0,0 +1,55 @@
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from decimal import Decimal

class UserCostShare(BaseModel):
    user_id: int
    user_identifier: str  # Name or email
    items_added_value: Decimal = Decimal("0.00")  # Total value of items this user added
    amount_due: Decimal  # The user's share of the total cost (for equal split, this is total_cost / num_users)
    balance: Decimal  # items_added_value - amount_due

    model_config = ConfigDict(from_attributes=True)

class ListCostSummary(BaseModel):
    list_id: int
    list_name: str
    total_list_cost: Decimal
    num_participating_users: int
    equal_share_per_user: Decimal
    user_balances: List[UserCostShare]

    model_config = ConfigDict(from_attributes=True)

class UserBalanceDetail(BaseModel):
    user_id: int
    user_identifier: str  # Name or email
    total_paid_for_expenses: Decimal = Decimal("0.00")
    total_share_of_expenses: Decimal = Decimal("0.00")
    total_settlements_paid: Decimal = Decimal("0.00")
    total_settlements_received: Decimal = Decimal("0.00")
    net_balance: Decimal = Decimal("0.00")  # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)

    model_config = ConfigDict(from_attributes=True)

class SuggestedSettlement(BaseModel):
    from_user_id: int
    from_user_identifier: str  # Name or email of payer
    to_user_id: int
    to_user_identifier: str  # Name or email of payee
    amount: Decimal

    model_config = ConfigDict(from_attributes=True)

class GroupBalanceSummary(BaseModel):
    group_id: int
    group_name: str
    overall_total_expenses: Decimal = Decimal("0.00")
    overall_total_settlements: Decimal = Decimal("0.00")
    user_balances: List[UserBalanceDetail]
    # Optional: could add a list of suggested settlements to zero out balances
    suggested_settlements: Optional[List[SuggestedSettlement]] = None

    model_config = ConfigDict(from_attributes=True)

# class SuggestedSettlement(BaseModel):
#     from_user_id: int
#     to_user_id: int
#     amount: Decimal
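Aside (not part of the diff): GroupBalanceSummary leaves suggested_settlements optional, and the computation is not shown in this excerpt. A common choice is greedy pairing of the largest debtor with the largest creditor until every net balance reaches zero; a sketch assuming positive net_balance means the user is owed money:

    from decimal import Decimal

    def suggest_settlements(net_balances: dict[int, Decimal]) -> list[tuple[int, int, Decimal]]:
        """Return (from_user_id, to_user_id, amount) transfers that zero out balances."""
        creditors = sorted(((u, b) for u, b in net_balances.items() if b > 0), key=lambda x: -x[1])
        debtors = sorted(((u, -b) for u, b in net_balances.items() if b < 0), key=lambda x: -x[1])
        transfers, i, j = [], 0, 0
        while i < len(debtors) and j < len(creditors):
            debtor, owes = debtors[i]
            creditor, due = creditors[j]
            amount = min(owes, due)
            transfers.append((debtor, creditor, amount))
            debtors[i] = (debtor, owes - amount)
            creditors[j] = (creditor, due - amount)
            if debtors[i][1] == 0:
                i += 1
            if creditors[j][1] == 0:
                j += 1
        return transfers

    # suggest_settlements({1: Decimal("50"), 2: Decimal("-30"), 3: Decimal("-20")})
    # -> [(2, 1, Decimal('30')), (3, 1, Decimal('20'))]

This greedy scheme needs at most n - 1 transfers for n users, which is usually good enough, though it is not guaranteed to be the global minimum.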
be/app/schemas/expense.py (new file, 180 lines)

@@ -0,0 +1,180 @@
# app/schemas/expense.py
from pydantic import BaseModel, ConfigDict, validator, Field
from typing import List, Optional, Dict, Any
from decimal import Decimal
from datetime import datetime

# These enums must be accessible here, e.g., from app.models or app.core.enums;
# whatever the import path, keep these schemas aligned with the SQLAlchemy enums
# defined in models.py.
from app.models import SplitTypeEnum, ExpenseSplitStatusEnum, ExpenseOverallStatusEnum  # Try importing directly
from app.schemas.user import UserPublic  # For user details in responses
from app.schemas.settlement_activity import SettlementActivityPublic  # For settlement activities

# --- ExpenseSplit Schemas ---
class ExpenseSplitBase(BaseModel):
    user_id: int
    owed_amount: Decimal
    share_percentage: Optional[Decimal] = None
    share_units: Optional[int] = None

class ExpenseSplitCreate(ExpenseSplitBase):
    pass  # All fields from base are needed for creation

class ExpenseSplitPublic(ExpenseSplitBase):
    id: int
    expense_id: int
    user: Optional[UserPublic] = None  # If we want to nest user details
    created_at: datetime
    updated_at: datetime
    status: ExpenseSplitStatusEnum  # New field
    paid_at: Optional[datetime] = None  # New field
    settlement_activities: List[SettlementActivityPublic] = []  # New field

    model_config = ConfigDict(from_attributes=True)

# --- Expense Schemas ---
class RecurrencePatternBase(BaseModel):
    type: str = Field(..., description="Type of recurrence: daily, weekly, monthly, yearly")
    interval: int = Field(..., description="Interval of recurrence (e.g., every X days/weeks/months/years)")
    days_of_week: Optional[List[int]] = Field(None, description="Days of week for weekly recurrence (0-6, Sunday-Saturday)")
    end_date: Optional[datetime] = Field(None, description="Optional end date for the recurrence")
    max_occurrences: Optional[int] = Field(None, description="Optional maximum number of occurrences")

class RecurrencePatternCreate(RecurrencePatternBase):
    pass

class RecurrencePatternUpdate(RecurrencePatternBase):
    pass

class RecurrencePatternInDB(RecurrencePatternBase):
    id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True

class ExpenseBase(BaseModel):
    description: str
    total_amount: Decimal
    currency: Optional[str] = "USD"
    expense_date: Optional[datetime] = None
    split_type: SplitTypeEnum
    list_id: Optional[int] = None
    group_id: Optional[int] = None  # Should be present if list_id is not, and vice-versa
    item_id: Optional[int] = None
    paid_by_user_id: int
    is_recurring: bool = Field(False, description="Whether this is a recurring expense")
    recurrence_pattern: Optional[RecurrencePatternCreate] = Field(None, description="Recurrence pattern for recurring expenses")

class ExpenseCreate(ExpenseBase):
    # For EQUAL split, splits are generated. For others, they might be provided.
    # This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE,
    # or SHARES, then 'splits_in' should be provided.
    splits_in: Optional[List[ExpenseSplitCreate]] = None

    @validator('total_amount')
    def total_amount_must_be_positive(cls, v):
        if v <= Decimal('0'):
            raise ValueError('Total amount must be positive')
        return v

    # Basic validation: if list_id is None, group_id must be provided.
    # More complex cross-field validation might be needed.
    @validator('group_id', always=True)
    def check_list_or_group_id(cls, v, values):
        if values.get('list_id') is None and v is None:
            raise ValueError('Either list_id or group_id must be provided for an expense')
        return v

    @validator('recurrence_pattern')
    def validate_recurrence_pattern(cls, v, values):
        if values.get('is_recurring') and not v:
            raise ValueError('Recurrence pattern is required for recurring expenses')
        if not values.get('is_recurring') and v:
            raise ValueError('Recurrence pattern should not be provided for non-recurring expenses')
        return v

class ExpenseUpdate(BaseModel):
    description: Optional[str] = None
    total_amount: Optional[Decimal] = None
    currency: Optional[str] = None
    expense_date: Optional[datetime] = None
    split_type: Optional[SplitTypeEnum] = None
    list_id: Optional[int] = None
    group_id: Optional[int] = None
    item_id: Optional[int] = None
    # paid_by_user_id is usually not updatable directly to maintain integrity.
    # Updating splits would be a more complex operation, potentially a separate endpoint or careful logic.
    version: int  # For optimistic locking
    is_recurring: Optional[bool] = None
    recurrence_pattern: Optional[RecurrencePatternUpdate] = None
    next_occurrence: Optional[datetime] = None

class ExpensePublic(ExpenseBase):
    id: int
    created_at: datetime
    updated_at: datetime
    version: int
    created_by_user_id: int
    splits: List[ExpenseSplitPublic] = []
    paid_by_user: Optional[UserPublic] = None  # If nesting user details
    overall_settlement_status: ExpenseOverallStatusEnum  # New field
    # list: Optional[ListPublic]  # If nesting list details
    # group: Optional[GroupPublic]  # If nesting group details
    # item: Optional[ItemPublic]  # If nesting item details
    is_recurring: bool
    next_occurrence: Optional[datetime]
    last_occurrence: Optional[datetime]
    recurrence_pattern: Optional[RecurrencePatternInDB]
    parent_expense_id: Optional[int]
    generated_expenses: List['ExpensePublic'] = []

    model_config = ConfigDict(from_attributes=True)

# --- Settlement Schemas ---
class SettlementBase(BaseModel):
    group_id: int
    paid_by_user_id: int
    paid_to_user_id: int
    amount: Decimal
    settlement_date: Optional[datetime] = None
    description: Optional[str] = None

class SettlementCreate(SettlementBase):
    @validator('amount')
    def amount_must_be_positive(cls, v):
        if v <= Decimal('0'):
            raise ValueError('Settlement amount must be positive')
        return v

    @validator('paid_to_user_id')
    def payer_and_payee_must_be_different(cls, v, values):
        if 'paid_by_user_id' in values and v == values['paid_by_user_id']:
            raise ValueError('Payer and payee cannot be the same user')
        return v

class SettlementUpdate(BaseModel):
    amount: Optional[Decimal] = None
    settlement_date: Optional[datetime] = None
    description: Optional[str] = None
    # group_id, paid_by_user_id, paid_to_user_id are typically not updatable.
    version: int  # For optimistic locking

class SettlementPublic(SettlementBase):
    id: int
    created_at: datetime
    updated_at: datetime
    version: int
    created_by_user_id: int
    # payer: Optional[UserPublic]  # If we want to include payer details
    # payee: Optional[UserPublic]  # If we want to include payee details
    # group: Optional[GroupPublic]  # If we want to include group details

    model_config = ConfigDict(from_attributes=True)

# Placeholder for nested schemas (e.g., UserPublic) if needed
# from app.schemas.user import UserPublic
# from app.schemas.list import ListPublic
# from app.schemas.group import GroupPublic
# from app.schemas.item import ItemPublic
@@ -1,5 +1,5 @@
# app/schemas/group.py
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, computed_field
from datetime import datetime
from typing import Optional, List

@@ -15,7 +15,25 @@ class GroupPublic(BaseModel):
    name: str
    created_by_id: int
    created_at: datetime
-    members: Optional[List[UserPublic]] = None  # Include members only in detailed view
+    member_associations: Optional[List["UserGroupPublic"]] = None
+
+    @computed_field
+    @property
+    def members(self) -> Optional[List[UserPublic]]:
+        if not self.member_associations:
+            return None
+        return [assoc.user for assoc in self.member_associations if assoc.user]
+
+    model_config = ConfigDict(from_attributes=True)
+
+# Properties for UserGroup association
+class UserGroupPublic(BaseModel):
+    id: int
+    user_id: int
+    group_id: int
+    role: str
+    joined_at: datetime
+    user: Optional[UserPublic] = None

    model_config = ConfigDict(from_attributes=True)
@@ -1,9 +1,10 @@
# app/schemas/health.py
from pydantic import BaseModel
+from app.config import settings

class HealthStatus(BaseModel):
    """
    Response model for the health check endpoint.
    """
-    status: str = "ok"  # Provide a default value
+    status: str = settings.HEALTH_STATUS_OK  # Use configured default value
    database: str
@@ -16,6 +16,7 @@ class ItemPublic(BaseModel):
    completed_by_id: Optional[int] = None
    created_at: datetime
    updated_at: datetime
+    version: int
    model_config = ConfigDict(from_attributes=True)

# Properties to receive via API on creation

@@ -31,4 +32,6 @@ class ItemUpdate(BaseModel):
    quantity: Optional[str] = None
    is_complete: Optional[bool] = None
    price: Optional[Decimal] = None  # Price added here for update
+    position: Optional[int] = None  # For reordering
+    version: int
    # completed_by_id will be set internally if is_complete is true
@@ -16,6 +16,7 @@ class ListUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    is_complete: Optional[bool] = None
+    version: int  # Client must provide the version for updates
    # Potentially add group_id update later if needed

@@ -28,6 +29,7 @@ class ListBase(BaseModel):
    is_complete: bool
    created_at: datetime
    updated_at: datetime
+    version: int  # Include version in responses

    model_config = ConfigDict(from_attributes=True)

@@ -40,6 +42,9 @@ class ListDetail(ListBase):
    items: List[ItemPublic] = []  # Include list of items

class ListStatus(BaseModel):
-    list_updated_at: datetime
+    updated_at: datetime
-    latest_item_updated_at: Optional[datetime] = None  # Can be null if list has no items
    item_count: int
+    latest_item_updated_at: Optional[datetime] = None
+
+class ListStatusWithId(ListStatus):
+    id: int
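Aside (not part of the diff): the version fields threaded through List, Item, Expense, and Settlement ("Client must provide the version for updates") are the classic optimistic-locking handshake: the UPDATE only applies while the row still carries the version the client last saw. A sketch in SQLAlchemy 2.x style; the function itself is illustrative, not the repo's CRUD code:

    from sqlalchemy import update
    from sqlalchemy.ext.asyncio import AsyncSession
    from app.models import List as ListModel

    async def rename_list(db: AsyncSession, list_id: int, expected_version: int, new_name: str) -> bool:
        """Apply the rename only if nobody bumped the version in the meantime."""
        result = await db.execute(
            update(ListModel)
            .where(ListModel.id == list_id, ListModel.version == expected_version)
            .values(name=new_name, version=expected_version + 1)
        )
        await db.commit()
        return result.rowcount == 1  # False -> stale version; caller typically returns 409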
be/app/schemas/settlement_activity.py (new file, 43 lines)

@@ -0,0 +1,43 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing import Optional, List
from decimal import Decimal
from datetime import datetime

from app.schemas.user import UserPublic  # Assuming UserPublic is defined here

class SettlementActivityBase(BaseModel):
    expense_split_id: int
    paid_by_user_id: int
    amount_paid: Decimal
    paid_at: Optional[datetime] = None

class SettlementActivityCreate(SettlementActivityBase):
    @field_validator('amount_paid')
    @classmethod
    def amount_must_be_positive(cls, v: Decimal) -> Decimal:
        if v <= Decimal("0"):
            raise ValueError("Amount paid must be a positive value.")
        return v

class SettlementActivityPublic(SettlementActivityBase):
    id: int
    created_by_user_id: int  # User who recorded this activity
    created_at: datetime
    updated_at: datetime

    payer: Optional[UserPublic] = None  # User who made this part of the payment
    creator: Optional[UserPublic] = None  # User who recorded this activity

    model_config = ConfigDict(from_attributes=True)

# Schema for updating a settlement activity (if needed in the future)
# class SettlementActivityUpdate(BaseModel):
#     amount_paid: Optional[Decimal] = None
#     paid_at: Optional[datetime] = None
#
#     @field_validator('amount_paid')
#     @classmethod
#     def amount_must_be_positive_if_provided(cls, v: Optional[Decimal]) -> Optional[Decimal]:
#         if v is not None and v <= Decimal("0"):
#             raise ValueError("Amount paid must be a positive value.")
#         return v
@@ -12,14 +12,27 @@ class UserBase(BaseModel):
class UserCreate(UserBase):
    password: str

-# Properties to receive via API on update (optional, add later if needed)
-# class UserUpdate(UserBase):
-#     password: Optional[str] = None
+    def create_update_dict(self):
+        return {
+            "email": self.email,
+            "name": self.name,
+            "password": self.password,
+            "is_active": True,
+            "is_superuser": False,
+            "is_verified": False
+        }
+
+# Properties to receive via API on update
+class UserUpdate(UserBase):
+    password: Optional[str] = None
+    is_active: Optional[bool] = None
+    is_superuser: Optional[bool] = None
+    is_verified: Optional[bool] = None

# Properties stored in DB
class UserInDBBase(UserBase):
    id: int
-    hashed_password: str
+    password_hash: str
    created_at: datetime
    model_config = ConfigDict(from_attributes=True)  # Use orm_mode in Pydantic v1
be/entrypoint.sh (new executable file, 10 lines)

@@ -0,0 +1,10 @@
#!/bin/sh
set -e

# Run database migrations
echo "Running database migrations..."
alembic upgrade head

# Execute the command passed as arguments to this script
echo "Starting application..."
exec "$@"
be/pytest.ini (new file, 5 lines)

@@ -0,0 +1,5 @@
[pytest]
pythonpath = .
testpaths = tests
python_files = test_*.py
asyncio_mode = auto
@@ -10,3 +10,18 @@ passlib[bcrypt]>=1.7.4
python-jose[cryptography]>=3.3.0
pydantic[email]
google-generativeai>=0.5.0
+sentry-sdk[fastapi]>=1.39.0
+python-multipart>=0.0.6  # Required for form data handling
+fastapi-users[sqlalchemy]>=12.1.2
+email-validator>=2.0.0
+fastapi-users[oauth]>=12.1.2
+authlib>=1.3.0
+itsdangerous>=2.1.2
+pytest>=7.4.0
+pytest-asyncio>=0.21.0
+pytest-cov>=4.1.0
+httpx>=0.24.0  # For async HTTP testing
+aiosqlite>=0.19.0  # For async SQLite support in tests
+
+# Scheduler
+APScheduler==3.10.4
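Aside (not part of the diff): APScheduler is pinned above, presumably to drive the recurring-expense fields (next_occurrence, last_occurrence) seen earlier. A sketch of the wiring; the job body and its timing are assumptions, since the actual scheduler module is not in this excerpt:

    from apscheduler.schedulers.asyncio import AsyncIOScheduler

    scheduler = AsyncIOScheduler()

    async def generate_due_recurring_expenses() -> None:
        """Assumed job: clone expenses whose next_occurrence is due, then advance it."""
        ...  # query Expense.is_recurring / next_occurrence and create child expenses

    def start_scheduler() -> None:
        # Call from an async context (e.g. a FastAPI startup hook) so the loop exists.
        scheduler.add_job(generate_due_recurring_expenses, "interval", hours=1)
        scheduler.start()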
be/tests/api/v1/endpoints/test_financials.py (new file, 649 lines; truncated in this excerpt)

@@ -0,0 +1,649 @@
import pytest
from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Callable, Dict, Any
from unittest.mock import patch, MagicMock

from app.models import User as UserModel, Group as GroupModel, List as ListModel
from app.schemas.expense import ExpenseCreate, ExpensePublic, ExpenseUpdate
# from app.config import settings  # Comment out the original import

# Helper to create a URL for an endpoint
# API_V1_STR = settings.API_V1_STR  # Comment out the original assignment

@pytest.fixture(scope="module")
def mock_settings_financials():
    mock_settings = MagicMock()
    mock_settings.API_V1_STR = "/api/v1"
    return mock_settings

# Patch the settings in the test module
@pytest.fixture(autouse=True)
def patch_settings_financials(mock_settings_financials):
    with patch("app.config.settings", mock_settings_financials):
        yield

def expense_url(endpoint: str = "") -> str:
    # Use the mocked API_V1_STR via the patched settings object
    from app.config import settings  # Import settings here to use the patched version
    return f"{settings.API_V1_STR}/financials/expenses{endpoint}"

def settlement_url(endpoint: str = "") -> str:
    # Use the mocked API_V1_STR via the patched settings object
    from app.config import settings  # Import settings here to use the patched version
    return f"{settings.API_V1_STR}/financials/settlements{endpoint}"

@pytest.mark.asyncio
async def test_create_new_expense_success_list_context(
    client: AsyncClient,
    db_session: AsyncSession,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_list_user_is_member: ListModel,
) -> None:
    expense_data = ExpenseCreate(
        description="Test Expense for List",
        amount=100.00,
        currency="USD",
        paid_by_user_id=test_user.id,
        list_id=test_list_user_is_member.id,
        group_id=None,
    )

    response = await client.post(
        expense_url(),
        headers=normal_user_token_headers,
        json=expense_data.model_dump(exclude_unset=True)
    )

    assert response.status_code == status.HTTP_201_CREATED
    content = response.json()
    assert content["description"] == expense_data.description
    assert content["amount"] == expense_data.amount
    assert content["currency"] == expense_data.currency
    assert content["paid_by_user_id"] == test_user.id
    assert content["list_id"] == test_list_user_is_member.id
    if test_list_user_is_member.group_id:
        assert content["group_id"] == test_list_user_is_member.group_id
    else:
        assert content["group_id"] is None
    assert "id" in content
    assert "created_at" in content
    assert "updated_at" in content
    assert "version" in content
    assert content["version"] == 1

@pytest.mark.asyncio
async def test_create_new_expense_success_group_context(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_group_user_is_member: GroupModel,
) -> None:
    expense_data = ExpenseCreate(
        description="Test Expense for Group",
        amount=50.00,
        currency="EUR",
        paid_by_user_id=test_user.id,
        group_id=test_group_user_is_member.id,
        list_id=None,
    )

    response = await client.post(
        expense_url(),
        headers=normal_user_token_headers,
        json=expense_data.model_dump(exclude_unset=True)
    )

    assert response.status_code == status.HTTP_201_CREATED
    content = response.json()
    assert content["description"] == expense_data.description
    assert content["paid_by_user_id"] == test_user.id
    assert content["group_id"] == test_group_user_is_member.id
    assert content["list_id"] is None
    assert content["version"] == 1

@pytest.mark.asyncio
async def test_create_new_expense_fail_no_list_or_group(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
) -> None:
    expense_data = ExpenseCreate(
        description="Test Invalid Expense",
        amount=10.00,
        currency="USD",
        paid_by_user_id=test_user.id,
        list_id=None,
        group_id=None,
    )

    response = await client.post(
        expense_url(),
        headers=normal_user_token_headers,
        json=expense_data.model_dump(exclude_unset=True)
    )

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    content = response.json()
    assert "Expense must be linked to a list_id or group_id" in content["detail"]

@pytest.mark.asyncio
async def test_create_new_expense_fail_paid_by_other_not_owner(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_group_user_is_member: GroupModel,
    another_user_in_group: UserModel,
) -> None:
    expense_data = ExpenseCreate(
        description="Expense paid by other",
        amount=75.00,
        currency="GBP",
        paid_by_user_id=another_user_in_group.id,
        group_id=test_group_user_is_member.id,
        list_id=None,
    )

    response = await client.post(
        expense_url(),
        headers=normal_user_token_headers,
        json=expense_data.model_dump(exclude_unset=True)
    )

    assert response.status_code == status.HTTP_403_FORBIDDEN
    content = response.json()
    assert "Only group owners can create expenses paid by others" in content["detail"]

@pytest.mark.asyncio
async def test_get_expense_success(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    created_expense: ExpensePublic,
) -> None:
    response = await client.get(
        expense_url(f"/{created_expense.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_200_OK
    content = response.json()
    assert content["id"] == created_expense.id
    assert content["description"] == created_expense.description

@pytest.mark.asyncio
async def test_get_expense_not_found(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
) -> None:
    response = await client.get(
        expense_url("/999"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    content = response.json()
    assert "Expense not found" in content["detail"]

@pytest.mark.asyncio
async def test_get_expense_forbidden_personal_expense_other_user(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    personal_expense_of_another_user: ExpensePublic,
) -> None:
    response = await client.get(
        expense_url(f"/{personal_expense_of_another_user.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    content = response.json()
    assert "You do not have permission to access this expense" in content["detail"]

@pytest.mark.asyncio
async def test_get_expense_forbidden_not_member_of_list_or_group(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    another_user: UserModel,
    expense_in_inaccessible_list_or_group: ExpensePublic,
) -> None:
    response = await client.get(
        expense_url(f"/{expense_in_inaccessible_list_or_group.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    content = response.json()
    assert "You do not have permission to access this expense" in content["detail"]

@pytest.mark.asyncio
async def test_get_expense_success_in_list_user_has_access(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    expense_in_accessible_list: ExpensePublic,
    test_list_user_is_member: ListModel,
) -> None:
    response = await client.get(
        expense_url(f"/{expense_in_accessible_list.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_200_OK
    content = response.json()
    assert content["id"] == expense_in_accessible_list.id
    assert content["list_id"] == test_list_user_is_member.id

@pytest.mark.asyncio
async def test_get_expense_success_in_group_user_has_access(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    expense_in_accessible_group: ExpensePublic,
    test_group_user_is_member: GroupModel,
) -> None:
    response = await client.get(
        expense_url(f"/{expense_in_accessible_group.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_200_OK
    content = response.json()
    assert content["id"] == expense_in_accessible_group.id
    assert content["group_id"] == test_group_user_is_member.id

@pytest.mark.asyncio
async def test_list_list_expenses_success(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
    test_list_user_is_member: ListModel,
) -> None:
    response = await client.get(
        expense_url(f"?list_id={test_list_user_is_member.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_200_OK
    content = response.json()
    assert isinstance(content, list)
    for expense in content:
        assert expense["list_id"] == test_list_user_is_member.id

@pytest.mark.asyncio
async def test_list_list_expenses_list_not_found(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
) -> None:
    response = await client.get(
        expense_url("?list_id=999"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND
    content = response.json()
    assert "List not found" in content["detail"]

@pytest.mark.asyncio
async def test_list_list_expenses_no_access(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_list_user_not_member: ListModel,
) -> None:
    response = await client.get(
        expense_url(f"?list_id={test_list_user_not_member.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_403_FORBIDDEN
    content = response.json()
    assert "You do not have permission to access this list" in content["detail"]

@pytest.mark.asyncio
async def test_list_list_expenses_empty(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_list_user_is_member_no_expenses: ListModel,
) -> None:
    response = await client.get(
        expense_url(f"?list_id={test_list_user_is_member_no_expenses.id}"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_200_OK
    content = response.json()
    assert isinstance(content, list)
    assert len(content) == 0

@pytest.mark.asyncio
async def test_list_list_expenses_pagination(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
    test_user: UserModel,
|
test_list_with_multiple_expenses: ListModel,
|
||||||
|
created_expenses_for_list: list[ExpensePublic],
|
||||||
|
) -> None:
|
||||||
|
# Test first page
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=0&limit=2"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["id"] == created_expenses_for_list[0].id
|
||||||
|
assert content[1]["id"] == created_expenses_for_list[1].id
|
||||||
|
|
||||||
|
# Test second page
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?list_id={test_list_with_multiple_expenses.id}&skip=2&limit=2"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["id"] == created_expenses_for_list[2].id
|
||||||
|
assert content[1]["id"] == created_expenses_for_list[3].id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_group_expenses_success(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
test_group_user_is_member: GroupModel,
|
||||||
|
) -> None:
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?group_id={test_group_user_is_member.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
for expense in content:
|
||||||
|
assert expense["group_id"] == test_group_user_is_member.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_group_expenses_group_not_found(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
response = await client.get(
|
||||||
|
expense_url("?group_id=999"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
content = response.json()
|
||||||
|
assert "Group not found" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_group_expenses_no_access(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_group_user_not_member: GroupModel,
|
||||||
|
) -> None:
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?group_id={test_group_user_not_member.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
content = response.json()
|
||||||
|
assert "You do not have permission to access this group" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_group_expenses_empty(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_group_user_is_member_no_expenses: GroupModel,
|
||||||
|
) -> None:
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?group_id={test_group_user_is_member_no_expenses.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
assert len(content) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_group_expenses_pagination(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
test_group_with_multiple_expenses: GroupModel,
|
||||||
|
created_expenses_for_group: list[ExpensePublic],
|
||||||
|
) -> None:
|
||||||
|
# Test first page
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=0&limit=2"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["id"] == created_expenses_for_group[0].id
|
||||||
|
assert content[1]["id"] == created_expenses_for_group[1].id
|
||||||
|
|
||||||
|
# Test second page
|
||||||
|
response = await client.get(
|
||||||
|
expense_url(f"?group_id={test_group_with_multiple_expenses.id}&skip=2&limit=2"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert isinstance(content, list)
|
||||||
|
assert len(content) == 2
|
||||||
|
assert content[0]["id"] == created_expenses_for_group[2].id
|
||||||
|
assert content[1]["id"] == created_expenses_for_group[3].id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_success_payer_updates_details(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
expense_paid_by_test_user: ExpensePublic,
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
description="Updated expense description",
|
||||||
|
version=expense_paid_by_test_user.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url(f"/{expense_paid_by_test_user.id}"),
|
||||||
|
headers=normal_user_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert content["description"] == update_data.description
|
||||||
|
assert content["version"] == expense_paid_by_test_user.version + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_success_group_owner_updates_others_expense(
|
||||||
|
client: AsyncClient,
|
||||||
|
group_owner_token_headers: Dict[str, str],
|
||||||
|
group_owner: UserModel,
|
||||||
|
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
|
||||||
|
another_user_in_group: UserModel,
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
description="Updated by group owner",
|
||||||
|
version=expense_paid_by_another_in_group_where_test_user_is_owner.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
|
||||||
|
headers=group_owner_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert content["description"] == update_data.description
|
||||||
|
assert content["version"] == expense_paid_by_another_in_group_where_test_user_is_owner.version + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_fail_not_payer_nor_group_owner(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
|
||||||
|
another_user_in_group: UserModel,
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
description="Attempted update by non-owner",
|
||||||
|
version=expense_paid_by_another_in_group_where_test_user_is_member.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
|
||||||
|
headers=normal_user_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
content = response.json()
|
||||||
|
assert "You do not have permission to update this expense" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_fail_not_found(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
description="Update attempt on non-existent expense",
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url("/999"),
|
||||||
|
headers=normal_user_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
content = response.json()
|
||||||
|
assert "Expense not found" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_fail_change_paid_by_user_not_owner(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
expense_paid_by_test_user_in_group: ExpensePublic,
|
||||||
|
another_user_in_same_group: UserModel,
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
paid_by_user_id=another_user_in_same_group.id,
|
||||||
|
version=expense_paid_by_test_user_in_group.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url(f"/{expense_paid_by_test_user_in_group.id}"),
|
||||||
|
headers=normal_user_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
content = response.json()
|
||||||
|
assert "Only group owners can change the payer of an expense" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_expense_success_owner_changes_paid_by_user(
|
||||||
|
client: AsyncClient,
|
||||||
|
group_owner_token_headers: Dict[str, str],
|
||||||
|
group_owner: UserModel,
|
||||||
|
expense_in_group_owner_group: ExpensePublic,
|
||||||
|
another_user_in_same_group: UserModel,
|
||||||
|
) -> None:
|
||||||
|
update_data = ExpenseUpdate(
|
||||||
|
paid_by_user_id=another_user_in_same_group.id,
|
||||||
|
version=expense_in_group_owner_group.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.put(
|
||||||
|
expense_url(f"/{expense_in_group_owner_group.id}"),
|
||||||
|
headers=group_owner_token_headers,
|
||||||
|
json=update_data.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
content = response.json()
|
||||||
|
assert content["paid_by_user_id"] == another_user_in_same_group.id
|
||||||
|
assert content["version"] == expense_in_group_owner_group.version + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_expense_success_payer(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
expense_paid_by_test_user: ExpensePublic,
|
||||||
|
) -> None:
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url(f"/{expense_paid_by_test_user.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_expense_success_group_owner(
|
||||||
|
client: AsyncClient,
|
||||||
|
group_owner_token_headers: Dict[str, str],
|
||||||
|
group_owner: UserModel,
|
||||||
|
expense_paid_by_another_in_group_where_test_user_is_owner: ExpensePublic,
|
||||||
|
) -> None:
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_owner.id}"),
|
||||||
|
headers=group_owner_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_expense_fail_not_payer_nor_group_owner(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
test_user: UserModel,
|
||||||
|
expense_paid_by_another_in_group_where_test_user_is_member: ExpensePublic,
|
||||||
|
) -> None:
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url(f"/{expense_paid_by_another_in_group_where_test_user_is_member.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
content = response.json()
|
||||||
|
assert "You do not have permission to delete this expense" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_expense_fail_not_found(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
) -> None:
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url("/999"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
content = response.json()
|
||||||
|
assert "Expense not found" in content["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_expense_idempotency(
|
||||||
|
client: AsyncClient,
|
||||||
|
normal_user_token_headers: Dict[str, str],
|
||||||
|
expense_paid_by_test_user: ExpensePublic,
|
||||||
|
) -> None:
|
||||||
|
# First delete
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url(f"/{expense_paid_by_test_user.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# Second delete should also succeed
|
||||||
|
response = await client.delete(
|
||||||
|
expense_url(f"/{expense_paid_by_test_user.id}"),
|
||||||
|
headers=normal_user_token_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
# GET /settlements/{settlement_id}
|
||||||
|
# POST /settlements
|
||||||
|
# GET /groups/{group_id}/settlements
|
||||||
|
# PUT /settlements/{settlement_id}
|
||||||
|
# DELETE /settlements/{settlement_id}
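
# A minimal sketch for the first TODO above (not part of the diff): the URL
# helper mirrors expense_url() by assumption, and the 404 contract is a guess
# until the settlement endpoints actually exist.
def settlement_url(suffix: str = "") -> str:
    # "/settlements" under the same API prefix is an assumption, not confirmed.
    return f"/api/v1/settlements{suffix}"


@pytest.mark.asyncio
async def test_get_settlement_not_found(
    client: AsyncClient,
    normal_user_token_headers: Dict[str, str],
) -> None:
    response = await client.get(
        settlement_url("/999"),
        headers=normal_user_token_headers
    )
    assert response.status_code == status.HTTP_404_NOT_FOUND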
be/tests/api/v1/test_costs.py (new file, +2282 lines)
File diff suppressed because it is too large
be/tests/api/v1/test_financials.py (new file, +411 lines)
@@ -0,0 +1,411 @@
import pytest
import httpx
from typing import List, Dict, Any
from decimal import Decimal
from datetime import datetime, timezone

from sqlalchemy.ext.asyncio import AsyncSession  # Needed for the db_session annotations below

from app.models import (
    User,
    Group,
    Expense,
    ExpenseSplit,
    SettlementActivity,
    UserRoleEnum,
    SplitTypeEnum,
    ExpenseOverallStatusEnum,
    ExpenseSplitStatusEnum
)
from app.schemas.settlement_activity import SettlementActivityPublic, SettlementActivityCreate
from app.schemas.expense import ExpensePublic, ExpenseSplitPublic
from app.core.config import settings  # For API prefix

# Assume db_session, event_loop, and client are provided by conftest.py or a
# similar setup. Basic user/auth fixtures are defined here in case conftest
# does not supply them.

@pytest.fixture
async def test_user1_api(db_session, client: httpx.AsyncClient) -> Dict[str, Any]:
    user = User(email="api.user1@example.com", name="API User 1", hashed_password="password1")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)

    # Simulate token login. In a real setup you would call a login endpoint;
    # here the user and headers are returned directly for mock authentication.
    # This would typically be handled by a dependency override in tests, so
    # that the current_active_user dependency resolves to this user whenever
    # these headers are used (or the dependency is mocked outright).
    return {"user": user, "headers": {"Authorization": f"Bearer token-for-{user.id}"}}
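
# Hedged illustration (not part of this diff) of the dependency override the
# comment above refers to. The import path and the name `current_active_user`
# are assumptions; adjust to wherever the app actually declares its auth
# dependency.
#
# from app.main import app
# from app.api.dependencies import current_active_user  # assumed path
#
# @pytest.fixture
# def authenticate_as_user1(test_user1_api):
#     app.dependency_overrides[current_active_user] = lambda: test_user1_api["user"]
#     yield
#     app.dependency_overrides.pop(current_active_user, None)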

@pytest.fixture
async def test_user2_api(db_session, client: httpx.AsyncClient) -> Dict[str, Any]:
    user = User(email="api.user2@example.com", name="API User 2", hashed_password="password2")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return {"user": user, "headers": {"Authorization": f"Bearer token-for-{user.id}"}}


@pytest.fixture
async def test_group_user1_owner_api(db_session, test_user1_api: Dict[str, Any]) -> Group:
    user1 = test_user1_api["user"]
    group = Group(name="API Test Group", created_by_id=user1.id)
    db_session.add(group)
    await db_session.flush()  # Get group.id

    # Add user1 as owner
    from app.models import UserGroup
    user_group_assoc = UserGroup(user_id=user1.id, group_id=group.id, role=UserRoleEnum.owner)
    db_session.add(user_group_assoc)
    await db_session.commit()
    await db_session.refresh(group)
    return group


@pytest.fixture
async def test_expense_in_group_api(db_session, test_user1_api: Dict[str, Any], test_group_user1_owner_api: Group) -> Expense:
    user1 = test_user1_api["user"]
    expense = Expense(
        description="Group API Expense",
        total_amount=Decimal("50.00"),
        currency="USD",
        group_id=test_group_user1_owner_api.id,
        paid_by_user_id=user1.id,
        created_by_user_id=user1.id,
        split_type=SplitTypeEnum.EQUAL,
        overall_settlement_status=ExpenseOverallStatusEnum.unpaid
    )
    db_session.add(expense)
    await db_session.commit()
    await db_session.refresh(expense)
    return expense


@pytest.fixture
async def test_expense_split_for_user2_api(db_session, test_expense_in_group_api: Expense, test_user1_api: Dict[str, Any], test_user2_api: Dict[str, Any]) -> ExpenseSplit:
    user1 = test_user1_api["user"]
    user2 = test_user2_api["user"]

    # Split for User 1 (payer)
    split1 = ExpenseSplit(
        expense_id=test_expense_in_group_api.id,
        user_id=user1.id,
        owed_amount=Decimal("25.00"),
        status=ExpenseSplitStatusEnum.unpaid
    )
    # Split for User 2 (owes)
    split2 = ExpenseSplit(
        expense_id=test_expense_in_group_api.id,
        user_id=user2.id,
        owed_amount=Decimal("25.00"),
        status=ExpenseSplitStatusEnum.unpaid
    )
    db_session.add_all([split1, split2])

    # Add user2 to the group as a member for permission checks
    from app.models import UserGroup
    user_group_assoc = UserGroup(user_id=user2.id, group_id=test_expense_in_group_api.group_id, role=UserRoleEnum.member)
    db_session.add(user_group_assoc)

    await db_session.commit()
    await db_session.refresh(split1)
    await db_session.refresh(split2)
    return split2  # Return the split that user2 owes

# --- Tests for POST /expense_splits/{expense_split_id}/settle ---

@pytest.mark.asyncio
async def test_settle_expense_split_by_self_success(
    client: httpx.AsyncClient,
    test_user2_api: Dict[str, Any],  # User2 will settle their own split
    test_expense_split_for_user2_api: ExpenseSplit,
    db_session: AsyncSession  # To verify db changes
):
    user2 = test_user2_api["user"]
    user2_headers = test_user2_api["headers"]
    split_to_settle = test_expense_split_for_user2_api

    payload = SettlementActivityCreate(
        expense_split_id=split_to_settle.id,
        paid_by_user_id=user2.id,  # User2 is paying
        amount_paid=split_to_settle.owed_amount  # Full payment
    )

    response = await client.post(
        f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
        json=payload.model_dump(mode='json'),  # Pydantic v2
        headers=user2_headers
    )

    assert response.status_code == 201
    activity_data = response.json()
    assert activity_data["amount_paid"] == str(split_to_settle.owed_amount)  # Compare as string due to JSON
    assert activity_data["paid_by_user_id"] == user2.id
    assert activity_data["expense_split_id"] == split_to_settle.id
    assert "id" in activity_data

    # Verify DB state
    await db_session.refresh(split_to_settle)
    assert split_to_settle.status == ExpenseSplitStatusEnum.paid
    assert split_to_settle.paid_at is not None

    # Verify the parent expense status. Fully paid requires the other splits
    # to be paid too; since only user2's split is settled here, the expected
    # state is 'partially_paid' unless user1's share has also been paid.
    parent_expense_id = split_to_settle.expense_id
    parent_expense = await db_session.get(Expense, parent_expense_id)
    await db_session.refresh(parent_expense, attribute_names=['splits'])  # Load splits to check status

    all_splits_paid = all(s.status == ExpenseSplitStatusEnum.paid for s in parent_expense.splits)
    if all_splits_paid:
        assert parent_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid
    else:
        assert parent_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid
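
# Hedged follow-up (not in the diff): to exercise the fully-paid branch above
# deterministically, user1's own split would have to be settled first. The
# fixture currently returns only split2, so this remains a sketch:
#
# payload1 = SettlementActivityCreate(
#     expense_split_id=split1.id,
#     paid_by_user_id=user1.id,
#     amount_paid=split1.owed_amount
# )
# await client.post(
#     f"{settings.API_V1_STR}/expense_splits/{split1.id}/settle",
#     json=payload1.model_dump(mode='json'),
#     headers=test_user1_api["headers"]
# )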

@pytest.mark.asyncio
async def test_settle_expense_split_by_group_owner_success(
    client: httpx.AsyncClient,
    test_user1_api: Dict[str, Any],  # User1 is group owner
    test_user2_api: Dict[str, Any],  # User2 owes the split
    test_expense_split_for_user2_api: ExpenseSplit,
    db_session: AsyncSession
):
    user1_headers = test_user1_api["headers"]
    user_who_owes = test_user2_api["user"]
    split_to_settle = test_expense_split_for_user2_api

    payload = SettlementActivityCreate(
        expense_split_id=split_to_settle.id,
        paid_by_user_id=user_who_owes.id,  # User1 (owner) records that User2 has paid
        amount_paid=split_to_settle.owed_amount
    )

    response = await client.post(
        f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
        json=payload.model_dump(mode='json'),
        headers=user1_headers  # Authenticated as group owner
    )
    assert response.status_code == 201
    activity_data = response.json()
    assert activity_data["paid_by_user_id"] == user_who_owes.id
    assert activity_data["created_by_user_id"] == test_user1_api["user"].id  # Activity created by owner

    await db_session.refresh(split_to_settle)
    assert split_to_settle.status == ExpenseSplitStatusEnum.paid


@pytest.mark.asyncio
async def test_settle_expense_split_path_body_id_mismatch(
    client: httpx.AsyncClient, test_user2_api: Dict[str, Any], test_expense_split_for_user2_api: ExpenseSplit
):
    user2_headers = test_user2_api["headers"]
    split_to_settle = test_expense_split_for_user2_api
    payload = SettlementActivityCreate(
        expense_split_id=split_to_settle.id + 1,  # Mismatch
        paid_by_user_id=test_user2_api["user"].id,
        amount_paid=split_to_settle.owed_amount
    )
    response = await client.post(
        f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
        json=payload.model_dump(mode='json'), headers=user2_headers
    )
    assert response.status_code == 400  # As per API endpoint logic


@pytest.mark.asyncio
async def test_settle_expense_split_not_found(
    client: httpx.AsyncClient, test_user2_api: Dict[str, Any]
):
    user2_headers = test_user2_api["headers"]
    payload = SettlementActivityCreate(expense_split_id=9999, paid_by_user_id=test_user2_api["user"].id, amount_paid=Decimal("10.00"))
    response = await client.post(
        f"{settings.API_V1_STR}/expense_splits/9999/settle",
        json=payload.model_dump(mode='json'), headers=user2_headers
    )
    assert response.status_code == 404  # ItemNotFoundError


@pytest.mark.asyncio
async def test_settle_expense_split_insufficient_permissions(
    client: httpx.AsyncClient,
    test_user1_api: Dict[str, Any],  # User1 owns the group; pulled in only so the fixture chain exists
    test_user2_api: Dict[str, Any],
    test_expense_split_for_user2_api: ExpenseSplit,  # User2 owes this
    db_session: AsyncSession
):
    # Create a new user (user3) who is not involved and not an owner
    user3 = User(email="api.user3@example.com", name="API User 3", hashed_password="password3")
    db_session.add(user3)
    await db_session.commit()
    user3_headers = {"Authorization": f"Bearer token-for-{user3.id}"}

    split_owner = test_user2_api["user"]  # User2 owns the split
    split_to_settle = test_expense_split_for_user2_api

    payload = SettlementActivityCreate(
        expense_split_id=split_to_settle.id,
        paid_by_user_id=split_owner.id,  # User2 is paying
        amount_paid=split_to_settle.owed_amount
    )
    # User3 (neither payer nor group owner) tries to record User2's payment
    response = await client.post(
        f"{settings.API_V1_STR}/expense_splits/{split_to_settle.id}/settle",
        json=payload.model_dump(mode='json'),
        headers=user3_headers  # Authenticated as User3
    )
    assert response.status_code == 403


# --- Tests for GET /expense_splits/{expense_split_id}/settlement_activities ---

@pytest.mark.asyncio
async def test_get_settlement_activities_success(
    client: httpx.AsyncClient,
    test_user1_api: Dict[str, Any],  # Group owner / expense creator
    test_user2_api: Dict[str, Any],  # User who owes and pays
    test_expense_split_for_user2_api: ExpenseSplit,
    db_session: AsyncSession
):
    user1_headers = test_user1_api["headers"]
    user2 = test_user2_api["user"]
    split = test_expense_split_for_user2_api

    # Create a settlement activity first
    activity_payload = SettlementActivityCreate(expense_split_id=split.id, paid_by_user_id=user2.id, amount_paid=Decimal("10.00"))
    await client.post(
        f"{settings.API_V1_STR}/expense_splits/{split.id}/settle",
        json=activity_payload.model_dump(mode='json'), headers=test_user2_api["headers"]  # User2 settles
    )

    # User1 (group owner) fetches activities
    response = await client.get(
        f"{settings.API_V1_STR}/expense_splits/{split.id}/settlement_activities",
        headers=user1_headers
    )
    assert response.status_code == 200
    activities_data = response.json()
    assert isinstance(activities_data, list)
    assert len(activities_data) == 1
    assert activities_data[0]["amount_paid"] == "10.00"
    assert activities_data[0]["paid_by_user_id"] == user2.id


@pytest.mark.asyncio
async def test_get_settlement_activities_split_not_found(
    client: httpx.AsyncClient, test_user1_api: Dict[str, Any]
):
    user1_headers = test_user1_api["headers"]
    response = await client.get(
        f"{settings.API_V1_STR}/expense_splits/9999/settlement_activities",
        headers=user1_headers
    )
    assert response.status_code == 404


@pytest.mark.asyncio
async def test_get_settlement_activities_no_permission(
    client: httpx.AsyncClient,
    test_expense_split_for_user2_api: ExpenseSplit,  # Belongs to group of user1/user2
    db_session: AsyncSession
):
    # Create a new user (user3) who is not in the group
    user3 = User(email="api.user3.other@example.com", name="API User 3 Other", hashed_password="password3")
    db_session.add(user3)
    await db_session.commit()
    user3_headers = {"Authorization": f"Bearer token-for-{user3.id}"}

    response = await client.get(
        f"{settings.API_V1_STR}/expense_splits/{test_expense_split_for_user2_api.id}/settlement_activities",
        headers=user3_headers  # Authenticated as User3
    )
    assert response.status_code == 403


# --- Test existing expense endpoints for new fields ---
@pytest.mark.asyncio
async def test_get_expense_by_id_includes_new_fields(
    client: httpx.AsyncClient,
    test_user1_api: Dict[str, Any],  # User in group
    test_expense_in_group_api: Expense,
    test_expense_split_for_user2_api: ExpenseSplit  # one of the splits
):
    user1_headers = test_user1_api["headers"]
    expense_id = test_expense_in_group_api.id

    response = await client.get(f"{settings.API_V1_STR}/expenses/{expense_id}", headers=user1_headers)
    assert response.status_code == 200
    expense_data = response.json()

    assert "overall_settlement_status" in expense_data
    assert expense_data["overall_settlement_status"] == ExpenseOverallStatusEnum.unpaid.value  # Initial state

    assert "splits" in expense_data
    assert len(expense_data["splits"]) > 0

    found_split = False
    for split_json in expense_data["splits"]:
        if split_json["id"] == test_expense_split_for_user2_api.id:
            found_split = True
            assert "status" in split_json
            assert split_json["status"] == ExpenseSplitStatusEnum.unpaid.value  # Initial state
            assert "paid_at" in split_json  # Should be null initially
            assert split_json["paid_at"] is None
            assert "settlement_activities" in split_json
            assert isinstance(split_json["settlement_activities"], list)
            assert len(split_json["settlement_activities"]) == 0  # No activities yet
            break
    assert found_split, "The specific test split was not found in the expense data."
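
# Illustrative only (not asserted verbatim above): each split in the JSON is
# expected to look roughly like
#   {"id": ..., "status": "unpaid", "paid_at": None, "settlement_activities": []}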

# Placeholder for conftest.py content if needed for local execution understanding
"""
# conftest.py (example structure)
import pytest
import asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from app.main import app  # Your FastAPI app
from app.database import Base, get_transactional_session  # Your DB setup

TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db"

engine = create_async_engine(TEST_DATABASE_URL, echo=True)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession)

@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()

@pytest.fixture(scope="session", autouse=True)
async def setup_db():
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)

@pytest.fixture
async def db_session() -> AsyncSession:
    async with TestingSessionLocal() as session:
        # Transaction is handled by the get_transactional_session override or test logic
        yield session
        # Roll back changes after each test if not using transactional tests per case
        # await session.rollback()  # Or rely on the chosen test isolation method

@pytest.fixture
async def client(db_session) -> AsyncClient:  # Depends on db_session to ensure DB is ready
    async def override_get_transactional_session():
        # Provide the test session, potentially managing transactions per test.
        # This is a simplified version; a real setup might involve nested
        # transactions, or ensure each test runs in its own transaction that
        # is rolled back afterwards.
        try:
            yield db_session
            # await db_session.commit()  # Or commit if the test should persist, then roll back globally
        except Exception:
            # await db_session.rollback()
            raise
        # finally:
        #     await db_session.rollback()  # Ensure rollback after each test using this fixture

    app.dependency_overrides[get_transactional_session] = override_get_transactional_session
    async with AsyncClient(app=app, base_url="http://test") as c:
        yield c
    del app.dependency_overrides[get_transactional_session]  # Clean up
"""
be/tests/conftest.py (new file, +56 lines)
@@ -0,0 +1,56 @@
import pytest
import asyncio
from typing import AsyncGenerator
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.models import Base
from app.database import get_db
from app.config import settings

# Create the test database engine. StaticPool keeps a single connection alive
# so the in-memory SQLite database survives across sessions within the run.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)


@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
async def test_db():
    """Create the test database and tables, and drop them afterwards."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def db_session(test_db) -> AsyncGenerator[AsyncSession, None]:
    """Create a fresh database session for each test."""
    async with TestingSessionLocal() as session:
        yield session


@pytest.fixture
async def client(db_session) -> AsyncGenerator[TestClient, None]:
    """Create a test client wired to the test database session."""
    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    with TestClient(app) as test_client:
        yield test_client
    app.dependency_overrides.clear()
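
# Usage sketch (illustrative, not part of the diff): any test can depend on
# `client` and hit the app against the in-memory database. The route below is
# an assumption for illustration only.
#
# def test_example(client):
#     response = client.get("/api/v1/health")  # assumed endpoint
#     assert response.status_code == 200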
be/tests/core/__init__.py (new file, +1 line, empty)
@@ -0,0 +1 @@
be/tests/core/test_exceptions.py (new file, +208 lines)
@@ -0,0 +1,208 @@
import pytest
from fastapi import status
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    GroupNotFoundError,
    GroupPermissionError,
    GroupMembershipError,
    GroupOperationError,
    GroupValidationError,
    ItemNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseTransactionError,
    DatabaseQueryError,
    OCRServiceUnavailableError,
    OCRServiceConfigError,
    OCRUnexpectedError,
    OCRQuotaExceededError,
    InvalidFileTypeError,
    FileTooLargeError,
    OCRProcessingError,
    EmailAlreadyRegisteredError,
    UserCreationError,
    InviteNotFoundError,
    InviteExpiredError,
    InviteAlreadyUsedError,
    InviteCreationError,
    ListStatusNotFoundError,
    ConflictError,
    InvalidCredentialsError,
    NotAuthenticatedError,
    JWTError,
    JWTUnexpectedError
)


def test_list_not_found_error():
    list_id = 123
    with pytest.raises(ListNotFoundError) as excinfo:
        raise ListNotFoundError(list_id=list_id)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"List {list_id} not found"


def test_list_permission_error():
    list_id = 456
    action = "delete"
    with pytest.raises(ListPermissionError) as excinfo:
        raise ListPermissionError(list_id=list_id, action=action)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}"


def test_list_permission_error_default_action():
    list_id = 789
    with pytest.raises(ListPermissionError) as excinfo:
        raise ListPermissionError(list_id=list_id)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"You do not have permission to access list {list_id}"


def test_list_creator_required_error():
    list_id = 101
    action = "update"
    with pytest.raises(ListCreatorRequiredError) as excinfo:
        raise ListCreatorRequiredError(list_id=list_id, action=action)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}"


def test_group_not_found_error():
    group_id = 202
    with pytest.raises(GroupNotFoundError) as excinfo:
        raise GroupNotFoundError(group_id=group_id)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"Group {group_id} not found"


def test_group_permission_error():
    group_id = 303
    action = "invite"
    with pytest.raises(GroupPermissionError) as excinfo:
        raise GroupPermissionError(group_id=group_id, action=action)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"You do not have permission to {action} in group {group_id}"


def test_group_membership_error():
    group_id = 404
    action = "post"
    with pytest.raises(GroupMembershipError) as excinfo:
        raise GroupMembershipError(group_id=group_id, action=action)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}"


def test_group_membership_error_default_action():
    group_id = 505
    with pytest.raises(GroupMembershipError) as excinfo:
        raise GroupMembershipError(group_id=group_id)
    assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
    assert excinfo.value.detail == f"You must be a member of group {group_id} to access"


def test_group_operation_error():
    detail_msg = "Failed to perform group operation."
    with pytest.raises(GroupOperationError) as excinfo:
        raise GroupOperationError(detail=detail_msg)
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == detail_msg


def test_group_validation_error():
    detail_msg = "Invalid group data."
    with pytest.raises(GroupValidationError) as excinfo:
        raise GroupValidationError(detail=detail_msg)
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == detail_msg


def test_item_not_found_error():
    item_id = 606
    with pytest.raises(ItemNotFoundError) as excinfo:
        raise ItemNotFoundError(item_id=item_id)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"Item {item_id} not found"


def test_user_not_found_error_no_identifier():
    with pytest.raises(UserNotFoundError) as excinfo:
        raise UserNotFoundError()
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == "User not found."


def test_user_not_found_error_with_id():
    user_id = 707
    with pytest.raises(UserNotFoundError) as excinfo:
        raise UserNotFoundError(user_id=user_id)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"User with ID {user_id} not found."


def test_user_not_found_error_with_identifier_string():
    identifier = "test_user"
    with pytest.raises(UserNotFoundError) as excinfo:
        raise UserNotFoundError(identifier=identifier)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"User with identifier '{identifier}' not found."


def test_invalid_operation_error():
    detail_msg = "This operation is not allowed."
    with pytest.raises(InvalidOperationError) as excinfo:
        raise InvalidOperationError(detail=detail_msg)
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == detail_msg


def test_invalid_operation_error_custom_status():
    detail_msg = "This operation is forbidden."
    custom_status = status.HTTP_403_FORBIDDEN
    with pytest.raises(InvalidOperationError) as excinfo:
        raise InvalidOperationError(detail=detail_msg, status_code=custom_status)
    assert excinfo.value.status_code == custom_status
    assert excinfo.value.detail == detail_msg


def test_email_already_registered_error():
    with pytest.raises(EmailAlreadyRegisteredError) as excinfo:
        raise EmailAlreadyRegisteredError()
    assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
    assert excinfo.value.detail == "Email already registered."


def test_user_creation_error():
    with pytest.raises(UserCreationError) as excinfo:
        raise UserCreationError()
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == "An error occurred during user creation."


def test_invite_not_found_error():
    invite_code = "TESTCODE123"
    with pytest.raises(InviteNotFoundError) as excinfo:
        raise InviteNotFoundError(invite_code=invite_code)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"Invite code {invite_code} not found"


def test_invite_expired_error():
    invite_code = "EXPIREDCODE"
    with pytest.raises(InviteExpiredError) as excinfo:
        raise InviteExpiredError(invite_code=invite_code)
    assert excinfo.value.status_code == status.HTTP_410_GONE
    assert excinfo.value.detail == f"Invite code {invite_code} has expired"


def test_invite_already_used_error():
    invite_code = "USEDCODE"
    with pytest.raises(InviteAlreadyUsedError) as excinfo:
        raise InviteAlreadyUsedError(invite_code=invite_code)
    assert excinfo.value.status_code == status.HTTP_410_GONE
    assert excinfo.value.detail == f"Invite code {invite_code} has already been used"


def test_invite_creation_error():
    group_id = 909
    with pytest.raises(InviteCreationError) as excinfo:
        raise InviteCreationError(group_id=group_id)
    assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert excinfo.value.detail == f"Failed to create invite for group {group_id}"


def test_list_status_not_found_error():
    list_id = 808
    with pytest.raises(ListStatusNotFoundError) as excinfo:
        raise ListStatusNotFoundError(list_id=list_id)
    assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
    assert excinfo.value.detail == f"Status for list {list_id} not found"


def test_conflict_error():
    detail_msg = "Resource version mismatch."
    with pytest.raises(ConflictError) as excinfo:
        raise ConflictError(detail=detail_msg)
    assert excinfo.value.status_code == status.HTTP_409_CONFLICT
    assert excinfo.value.detail == detail_msg
be/tests/core/test_gemini.py (new file, +300 lines)
@@ -0,0 +1,300 @@
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock

import google.generativeai as genai
from google.api_core import exceptions as google_exceptions

# Modules to test
from app.core import gemini
from app.core.exceptions import (
    OCRServiceUnavailableError,
    OCRServiceConfigError,
    OCRUnexpectedError,
    OCRQuotaExceededError
)


# Default mock settings
@pytest.fixture
def mock_gemini_settings():
    settings_mock = MagicMock()
    settings_mock.GEMINI_API_KEY = "test_api_key"
    settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision"
    settings_mock.GEMINI_SAFETY_SETTINGS = {
        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    }
    settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
    settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
    return settings_mock


@pytest.fixture
def mock_generative_model_instance():
    model_instance = AsyncMock(spec=genai.GenerativeModel)
    model_instance.generate_content_async = AsyncMock()
    return model_instance


@pytest.fixture
def patch_google_ai_client(mock_generative_model_instance):
    with patch('google.generativeai.GenerativeModel', return_value=mock_generative_model_instance) as mock_generative_model, \
         patch('google.generativeai.configure') as mock_configure:
        yield mock_configure, mock_generative_model, mock_generative_model_instance


# --- Test Gemini client initialization (global client) ---

# Parametrize to test different scenarios for the global client init
@pytest.mark.parametrize(
    "api_key_present, configure_raises, model_init_raises, expected_error_message_part",
    [
        (True, None, None, None),  # Success
        (False, None, None, "GEMINI_API_KEY not configured"),  # API key missing
        (True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"),  # genai.configure error
        (True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"),  # GenerativeModel error
    ]
)
@patch('app.core.gemini.genai')  # Patch genai within the gemini module
def test_global_gemini_client_initialization(
    mock_genai_module,
    mock_gemini_settings,
    api_key_present,
    configure_raises,
    model_init_raises,
    expected_error_message_part
):
    """Tests the global gemini_flash_client initialization logic in app.core.gemini."""
    # The module must be reloaded to re-trigger its top-level initialization
    # code. This is awkward; a cleaner pattern is to put the init logic in a
    # function. For now, simulate it by controlling mocks before module access.

    with patch('app.core.gemini.settings', mock_gemini_settings):
        if not api_key_present:
            mock_gemini_settings.GEMINI_API_KEY = None

        mock_genai_module.configure = MagicMock()
        mock_genai_module.GenerativeModel = MagicMock()
        mock_genai_module.types = genai.types  # Keep original types
        mock_genai_module.HarmCategory = genai.HarmCategory
        mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold

        if configure_raises:
            mock_genai_module.configure.side_effect = configure_raises
        if model_init_raises:
            mock_genai_module.GenerativeModel.side_effect = model_init_raises

        # Python modules are singletons: re-running top-level code requires a
        # reload, which is generally discouraged; an explicit init function
        # would be better. Here we check the state variables set by the
        # module's import-time code.
        import importlib
        importlib.reload(gemini)  # Re-runs the try/except block at the top of gemini.py

        if expected_error_message_part:
            assert gemini.gemini_initialization_error is not None
            assert expected_error_message_part in gemini.gemini_initialization_error
            assert gemini.gemini_flash_client is None
        else:
            assert gemini.gemini_initialization_error is None
            assert gemini.gemini_flash_client is not None
            mock_genai_module.configure.assert_called_once_with(api_key="test_api_key")
            mock_genai_module.GenerativeModel.assert_called_once()
            # Could add more assertions about safety_settings and generation_config here

    # Clean up after reload for other tests (outside the settings patch, so
    # the module reloads against the real settings)
    importlib.reload(gemini)
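
# Hedged sketch (a proposal, not existing code) of the explicit init function
# the comments above recommend; with it, tests could call the function
# directly instead of resorting to importlib.reload():
#
# def init_gemini_client(settings):
#     if not settings.GEMINI_API_KEY:
#         raise RuntimeError("GEMINI_API_KEY not configured")
#     genai.configure(api_key=settings.GEMINI_API_KEY)
#     return genai.GenerativeModel(settings.GEMINI_MODEL_NAME)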

# --- Test get_gemini_client ---
# Note: mock.patch passes a mock into the test function only for
# new_callable/DEFAULT patches; patches given an explicit replacement value
# (None, a string) inject no argument, so the signatures below only accept
# the mocks that are actually injected.

@patch('app.core.gemini.gemini_flash_client', new_callable=MagicMock)
@patch('app.core.gemini.gemini_initialization_error', None)
def test_get_gemini_client_success(mock_client):
    # The patched module-level client simulates an initialized client
    client = gemini.get_gemini_client()
    assert client is not None


@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', "Test init error")
def test_get_gemini_client_init_error():
    with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
        gemini.get_gemini_client()


@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', None)  # No init error, but client is None
def test_get_gemini_client_none_client_unknown_issue():
    with pytest.raises(RuntimeError, match=r"Gemini client is not available \(unknown initialization issue\)\."):
        gemini.get_gemini_client()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tests for extract_items_from_image_gemini --- (simplified for brevity; needs more cases)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_success(
    mock_gemini_settings,
    mock_generative_model_instance,
    patch_google_ai_client
):
    mock_response = MagicMock()
    mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
    mock_candidate = MagicMock()
    mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
    mock_candidate.finish_reason = 'STOP'
    mock_candidate.safety_ratings = []
    mock_response.candidates = [mock_candidate]

    mock_generative_model_instance.generate_content_async.return_value = mock_response

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        image_bytes = b"dummy_image_bytes"
        mime_type = "image/png"

        items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)

        mock_generative_model_instance.generate_content_async.assert_called_once_with([
            mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
            {"mime_type": mime_type, "data": image_bytes}
        ])
        assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]

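# The expected list above implies the extraction step strips whitespace and
# drops blank lines from the model's text response. A sketch of that assumed
# post-processing (the helper name _parse_items is hypothetical, not necessarily
# what app.core.gemini calls it):
#
# def _parse_items(raw_text: str) -> list[str]:
#     return [line.strip() for line in raw_text.splitlines() if line.strip()]
#
# _parse_items("Item 1\nItem 2\n Item 3 \n\nAnother Item")
# # -> ["Item 1", "Item 2", "Item 3", "Another Item"]
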
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(mock_gemini_settings):
    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', None), \
         patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):

        image_bytes = b"dummy_image_bytes"
        with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
            await gemini.extract_items_from_image_gemini(image_bytes)


@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_api_quota_error(
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        image_bytes = b"dummy_image_bytes"
        with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
            await gemini.extract_items_from_image_gemini(image_bytes)

# --- Tests for GeminiOCRService --- (example tests; more needed)

@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel')
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
    MockGenerativeModel.return_value = mock_generative_model_instance
    with patch('app.core.gemini.settings', mock_gemini_settings):
        service = gemini.GeminiOCRService()
        MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
        MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
        assert service.model == mock_generative_model_instance
        # Could add assertions for safety_settings and generation_config if they are set directly on the model


@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
    with patch('app.core.gemini.settings', mock_gemini_settings):
        with pytest.raises(OCRServiceConfigError):
            gemini.GeminiOCRService()


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_response = MagicMock()
    mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
    mock_candidate = MagicMock()
    mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
    mock_candidate.finish_reason = 'STOP'
    mock_candidate.safety_ratings = []
    mock_response.candidates = [mock_candidate]

    mock_generative_model_instance.generate_content_async.return_value = mock_response

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        service = gemini.GeminiOCRService()
        image_bytes = b"dummy_image_bytes"
        mime_type = "image/png"

        items = await service.extract_items(image_bytes, mime_type)

        mock_generative_model_instance.generate_content_async.assert_called_once_with([
            mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
            {"mime_type": mime_type, "data": image_bytes}
        ])
        assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        service = gemini.GeminiOCRService()
        image_bytes = b"dummy_image_bytes"

        with pytest.raises(OCRQuotaExceededError):
            await service.extract_items(image_bytes)


@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ServiceUnavailable("Service unavailable")

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        service = gemini.GeminiOCRService()
        image_bytes = b"dummy_image_bytes"

        with pytest.raises(OCRServiceUnavailableError):
            await service.extract_items(image_bytes)

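# The two tests above pin down an exception-translation layer inside
# GeminiOCRService. A sketch of the assumed mapping (exception names taken from
# the assertions; the real method body may differ):
#
# async def extract_items(self, image_bytes, mime_type="image/jpeg"):
#     try:
#         response = await self.model.generate_content_async([...])
#     except google_exceptions.ResourceExhausted as e:
#         raise OCRQuotaExceededError(str(e)) from e
#     except google_exceptions.ServiceUnavailable as e:
#         raise OCRServiceUnavailableError(str(e)) from e
#     ...
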
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(
    mock_gemini_settings,
    mock_generative_model_instance
):
    mock_response = MagicMock()
    mock_response.text = ""
    mock_candidate = MagicMock()
    mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
    mock_candidate.finish_reason = 'STOP'
    mock_candidate.safety_ratings = []
    mock_response.candidates = [mock_candidate]

    mock_generative_model_instance.generate_content_async.return_value = mock_response

    with patch('app.core.gemini.settings', mock_gemini_settings), \
         patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
         patch('app.core.gemini.gemini_initialization_error', None):

        service = gemini.GeminiOCRService()
        image_bytes = b"dummy_image_bytes"

        items = await service.extract_items(image_bytes)
        assert items == []
be/tests/core/test_security.py (new file, 216 lines)
@@ -0,0 +1,216 @@
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta, timezone

from jose import jwt, JWTError
from passlib.context import CryptContext

from app.core.security import (
    verify_password,
    hash_password,
    # create_access_token,
    # create_refresh_token,
    # verify_access_token,
    # verify_refresh_token,
    pwd_context,  # Imported for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
# from app.config import settings

# --- Tests for Password Hashing ---

def test_hash_password():
    password = "securepassword123"
    hashed = hash_password(password)
    assert isinstance(hashed, str)
    assert hashed != password
    # Check that the default scheme (bcrypt) is used by verifying the hash prefix;
    # bcrypt hashes start with $2b$, $2a$, or $2y$.
    assert hashed.startswith(("$2b$", "$2a$", "$2y$"))


def test_verify_password_correct():
    password = "testpassword"
    hashed_password = pwd_context.hash(password)  # Use the same context for consistency
    assert verify_password(password, hashed_password) is True


def test_verify_password_incorrect():
    password = "testpassword"
    wrong_password = "wrongpassword"
    hashed_password = pwd_context.hash(password)
    assert verify_password(wrong_password, hashed_password) is False


def test_verify_password_invalid_hash_format():
    password = "testpassword"
    invalid_hash = "notarealhash"
    assert verify_password(password, invalid_hash) is False

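# For context: the functions under test are presumably thin wrappers around the
# passlib CryptContext imported above. A minimal sketch of that assumption (not
# necessarily the exact app.core.security implementation):
#
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
#
# def hash_password(password: str) -> str:
#     return pwd_context.hash(password)
#
# def verify_password(plain: str, hashed: str) -> bool:
#     try:
#         return pwd_context.verify(plain, hashed)
#     except (ValueError, TypeError):
#         return False  # malformed hash, as exercised by the test above
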
# --- Tests for JWT Creation ---
# Mock settings for JWT tests
# @pytest.fixture(scope="module")
# def mock_jwt_settings():
#     mock_settings = MagicMock()
#     mock_settings.SECRET_KEY = "testsecretkey"
#     mock_settings.ALGORITHM = "HS256"
#     mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
#     mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080  # 7 days
#     return mock_settings

# @patch('app.core.security.settings')
# def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

#     subject = "user@example.com"
#     token = create_access_token(subject)
#     assert isinstance(token, str)

#     decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
#     assert decoded_payload["sub"] == subject
#     assert decoded_payload["type"] == "access"
#     assert "exp" in decoded_payload
#     # Check that the expiry is roughly correct (within a small delta)
#     expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
#     assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)

# @patch('app.core.security.settings')
# def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     # ACCESS_TOKEN_EXPIRE_MINUTES is not used here because of the custom delta

#     subject = 123  # Subject can be an int
#     custom_delta = timedelta(hours=1)
#     token = create_access_token(subject, expires_delta=custom_delta)
#     assert isinstance(token, str)

#     decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
#     assert decoded_payload["sub"] == str(subject)
#     assert decoded_payload["type"] == "access"
#     expected_expiry = datetime.now(timezone.utc) + custom_delta
#     assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)

# @patch('app.core.security.settings')
# def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES

#     subject = "refresh_subject"
#     token = create_refresh_token(subject)
#     assert isinstance(token, str)

#     decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
#     assert decoded_payload["sub"] == subject
#     assert decoded_payload["type"] == "refresh"
#     expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
#     assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)

# --- Tests for JWT Verification --- (more tests to be added here)

# @patch('app.core.security.settings')
# def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

#     subject = "test_user_valid_access"
#     token = create_access_token(subject)
#     payload = verify_access_token(token)
#     assert payload is not None
#     assert payload["sub"] == subject
#     assert payload["type"] == "access"

# @patch('app.core.security.settings')
# def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

#     subject = "test_user_invalid_sig"
#     # Create a token with the correct key
#     token = create_access_token(subject)

#     # Try to verify it with the wrong key
#     mock_settings_global.SECRET_KEY = "wrongsecretkey"
#     payload = verify_access_token(token)
#     assert payload is None

# @patch('app.core.security.settings')
# @patch('app.core.security.datetime')  # Mock datetime to control token expiry
# def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1  # Expire in 1 minute

#     # Set the current time for token creation
#     now = datetime.now(timezone.utc)
#     mock_datetime.now.return_value = now
#     mock_datetime.fromtimestamp = datetime.fromtimestamp  # Ensure the original fromtimestamp is used by jwt.decode
#     mock_datetime.timedelta = timedelta  # Ensure the original timedelta is used

#     subject = "test_user_expired"
#     token = create_access_token(subject)

#     # Advance time beyond expiry for verification
#     mock_datetime.now.return_value = now + timedelta(minutes=5)
#     payload = verify_access_token(token)
#     assert payload is None

# @patch('app.core.security.settings')
# def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES  # For refresh token creation

#     subject = "test_user_wrong_type"
#     # Create a refresh token
#     refresh_token = create_refresh_token(subject)

#     # Try to verify it as an access token
#     payload = verify_access_token(refresh_token)
#     assert payload is None

# @patch('app.core.security.settings')
# def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES

#     subject = "test_user_valid_refresh"
#     token = create_refresh_token(subject)
#     payload = verify_refresh_token(token)
#     assert payload is not None
#     assert payload["sub"] == subject
#     assert payload["type"] == "refresh"

# @patch('app.core.security.settings')
# @patch('app.core.security.datetime')
# def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1  # Expire in 1 minute

#     now = datetime.now(timezone.utc)
#     mock_datetime.now.return_value = now
#     mock_datetime.fromtimestamp = datetime.fromtimestamp
#     mock_datetime.timedelta = timedelta

#     subject = "test_user_expired_refresh"
#     token = create_refresh_token(subject)

#     mock_datetime.now.return_value = now + timedelta(minutes=5)
#     payload = verify_refresh_token(token)
#     assert payload is None

# @patch('app.core.security.settings')
# def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
#     mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
#     mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
#     mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES

#     subject = "test_user_wrong_type_refresh"
#     access_token = create_access_token(subject)

#     payload = verify_refresh_token(access_token)
#     assert payload is None

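# The commented-out tests above assume a token factory roughly like the sketch
# below (the "sub", "type", and "exp" claim names come from the assertions; the
# real app.core.security implementation may differ):
#
# def create_access_token(subject, expires_delta: timedelta | None = None) -> str:
#     expire = datetime.now(timezone.utc) + (
#         expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
#     )
#     payload = {"sub": str(subject), "type": "access", "exp": expire}
#     return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)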
be/tests/crud/__init__.py (new file, 1 line)
@@ -0,0 +1 @@

be/tests/crud/test_expense.py (new file, 369 lines)
@@ -0,0 +1,369 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense,
    delete_expense,
    get_users_for_splitting
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum,
    ExpenseOverallStatusEnum,  # Added
    ExpenseSplitStatusEnum  # Added
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    ExpenseNotFoundError,
    DatabaseTransactionError,
    ConflictError
)

# General fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    session.begin_nested = AsyncMock()  # For nested transactions within functions
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    # Mock session.begin() to return an async context manager
    mock_transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=mock_transaction_context)
    return session

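# Note on the fixture above: session.begin is a plain (sync) MagicMock so that
# `session.begin()` immediately returns mock_transaction_context, an AsyncMock
# that supports `async with`. That matches the presumed usage in the CRUD code
# under test, along these lines:
#
# async with db.begin():   # begin() is a sync call; the context is entered via __aenter__
#     db.add(expense)
#     await db.flush()
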
@pytest.fixture
def basic_user_model():
    return UserModel(id=1, name="Test User", email="test@example.com", version=1)


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com", version=1)


@pytest.fixture
def basic_group_model(basic_user_model, another_user_model):
    group = GroupModel(id=1, name="Test Group", version=1)
    # Simulate member_associations for get_users_for_splitting if needed directly:
    # group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model), UserGroupModel(user_id=2, group_id=1, user=another_user_model)]
    return group


@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, created_by_id=basic_user_model.id, creator=basic_user_model, version=1)


@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc).date(),
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,  # Derived from the list
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None
    )


@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc).date(),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None
    )


@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        expense_date=datetime.now(timezone.utc).date(),
        currency="USD",
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[
            ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
            ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
        ]
    )


@pytest.fixture
def expense_update_data():
    return ExpenseUpdate(
        description="Updated Dinner",
        total_amount=Decimal("120.00"),
        version=1  # Ensure version is provided for updates
    )


@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model):
    expense = ExpenseModel(
        id=1,
        description=expense_create_data_equal_split_group_ctx.description,
        total_amount=expense_create_data_equal_split_group_ctx.total_amount,
        currency=expense_create_data_equal_split_group_ctx.currency,
        expense_date=expense_create_data_equal_split_group_ctx.expense_date,
        split_type=expense_create_data_equal_split_group_ctx.split_type,
        list_id=expense_create_data_equal_split_group_ctx.list_id,
        group_id=expense_create_data_equal_split_group_ctx.group_id,
        group=basic_group_model,  # Link to the group fixture
        item_id=expense_create_data_equal_split_group_ctx.item_id,
        paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,
        created_by_user=basic_user_model,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )
    # Simulate splits for an existing expense (user_id=2 is another_user_model)
    expense.splits = [
        ExpenseSplitModel(id=1, expense_id=1, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
        ExpenseSplitModel(id=2, expense_id=1, user_id=2, owed_amount=Decimal("25.00"), version=1)
    ]
    return expense

# --- get_users_for_splitting Tests ---
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]

    mock_db_session.get.return_value = basic_group_model  # Mock get for the group

    users = await get_users_for_splitting(mock_db_session, expense_group_id=basic_group_model.id, expense_list_id=None, expense_paid_by_user_id=basic_user_model.id)
    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
    basic_list_model.group = basic_group_model  # Ensure the list is associated with the group

    mock_db_session.get.return_value = basic_list_model  # Mock get for the list

    users = await get_users_for_splitting(mock_db_session, expense_group_id=None, expense_list_id=basic_list_model.id, expense_paid_by_user_id=basic_user_model.id)
    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    # Set up mocks
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # paid_by user, then the group

    # Mock get_users_for_splitting directly
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]

        async def mock_refresh(instance, attribute_names=None, with_for_update=None):
            if isinstance(instance, ExpenseModel):
                instance.id = 1  # Simulate ID assignment after flush
                instance.version = 1
                instance.created_at = datetime.now(timezone.utc)
                instance.updated_at = datetime.now(timezone.utc)
                # Simulate splits being added to the session and linked by refresh
                instance.splits = [
                    ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
                    ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1)
                ]
            return None
        mock_db_session.refresh.side_effect = mock_refresh

        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)

        mock_db_session.add.assert_called()
        mock_db_session.flush.assert_called_once()
        mock_db_session.refresh.assert_called_once()
        assert created_expense is not None
        assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
        assert created_expense.split_type == SplitTypeEnum.EQUAL
        assert len(created_expense.splits) == 2

        expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        for split in created_expense.splits:
            assert split.owed_amount == expected_amount_per_user
            assert split.status == ExpenseSplitStatusEnum.unpaid  # Verify the initial split status

        assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # Verify the initial expense status

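# Worked example of the equal-split arithmetic asserted above: 50.00 split two
# ways quantizes cleanly to 25.00 each. For totals that do not divide evenly,
# ROUND_HALF_UP alone leaves a remainder, so implementations typically push the
# leftover cents onto one participant. A sketch of that assumption (not
# necessarily what app.crud.expense does):
#
# def equal_split(total: Decimal, n: int) -> list[Decimal]:
#     share = (total / n).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
#     shares = [share] * n
#     shares[0] += total - share * n  # absorb any rounding remainder in one share
#     return shares
#
# equal_split(Decimal("100.00"), 3)  # -> [33.34, 33.33, 33.33]
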
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model]  # Payer, group, user 1 in split, user 2 in split

    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        if isinstance(instance, ExpenseModel):
            instance.id = 2
            instance.version = 1
            instance.splits = [
                ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
                ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("40.00"))
            ]
        return None
    mock_db_session.refresh.side_effect = mock_refresh

    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
    for split in created_expense.splits:
        assert split.status == ExpenseSplitStatusEnum.unpaid  # Verify the initial split status

    assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # Verify the initial expense status


@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    mock_db_session.get.side_effect = [None]  # Payer not found, so the group lookup never happens
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 999)  # current_user_id is the creator; paid_by_user_id comes from the schema


@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    mock_db_session.get.return_value = basic_user_model  # Payer found
    expense_data = expense_create_data_equal_split_group_ctx.model_copy()
    expense_data.list_id = None
    expense_data.group_id = None
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, basic_user_model.id)

# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    mock_db_session.get.return_value = db_expense_model
    expense = await get_expense_by_id(mock_db_session, db_expense_model.id)
    assert expense is not None
    assert expense.id == db_expense_model.id
    # The eager-loading options are loader objects that do not compare equal
    # across constructions, so match them with ANY rather than fresh mocks.
    mock_db_session.get.assert_called_once_with(ExpenseModel, db_expense_model.id, options=ANY)


@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    mock_db_session.get.return_value = None
    expense = await get_expense_by_id(mock_db_session, 999)
    assert expense is None

# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
    # execute() is awaited, but the Result it returns exposes sync methods, so
    # the result itself must be a MagicMock (an AsyncMock would turn scalars()
    # into a coroutine and break the chained .all() call).
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()


# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
    mock_result = MagicMock()  # Sync Result object, as above
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()

# --- update_expense Tests ---
@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    expense_update_data.version = db_expense_model.version  # Match the current version

    # Simulate db_expense_model being returned by session.get
    mock_db_session.get.return_value = db_expense_model

    updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.add.assert_called_once_with(db_expense_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_expense_model)
    assert updated_expense.description == expense_update_data.description
    assert updated_expense.total_amount == expense_update_data.total_amount
    assert updated_expense.version == db_expense_model.version  # Version incremented by the function


@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
    mock_db_session.get.return_value = None  # Expense not found
    with pytest.raises(ExpenseNotFoundError):
        await update_expense(mock_db_session, 999, expense_update_data, basic_user_model.id)


@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    expense_update_data.version = db_expense_model.version + 1  # Create a version mismatch
    mock_db_session.get.return_value = db_expense_model
    with pytest.raises(ConflictError):
        await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
    mock_db_session.rollback.assert_called_once()

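# The two tests above encode optimistic locking. A sketch of the version guard
# assumed inside update_expense (simplified; the real function does more work):
#
# if expense_update.version != db_expense.version:
#     raise ConflictError(
#         f"Expense version mismatch: expected {db_expense.version}, "
#         f"got {expense_update.version}"
#     )
# db_expense.version += 1  # bump the version on a successful update
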
# --- delete_expense Tests ---
@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
    mock_db_session.get.return_value = db_expense_model  # Simulate the expense being found

    await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    mock_db_session.delete.assert_called_once_with(db_expense_model)
    # Assuming delete_expense uses session.begin() and commits; begin() returns
    # the same mock transaction context on every call, so this inspects it.
    mock_db_session.begin().commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
    mock_db_session.get.return_value = None  # Expense not found
    with pytest.raises(ExpenseNotFoundError):
        await delete_expense(mock_db_session, 999, basic_user_model.id)
    mock_db_session.rollback.assert_not_called()  # A rollback may still happen on the begin() context manager exit


@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
    mock_db_session.get.return_value = db_expense_model
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
    with pytest.raises(DatabaseTransactionError):
        await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)
    mock_db_session.begin().rollback.assert_called_once()  # Rollback from the transaction context
be/tests/crud/test_group.py (new file, 354 lines)
@@ -0,0 +1,354 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY  # ANY is used to match eager-loading options
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import delete, func

from app.crud.group import (
    create_group,
    get_user_groups,
    get_group_by_id,
    is_user_member,
    get_user_role_in_group,
    add_user_to_group,
    remove_user_from_group,
    get_group_member_count,
    check_group_membership,
    check_user_role_in_group,
    update_group_member_role  # Assuming this will be added
)
from app.schemas.group import GroupCreate, GroupUpdate  # Added GroupUpdate
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
from app.core.exceptions import (
    GroupOperationError,
    GroupNotFoundError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    GroupMembershipError,
    GroupPermissionError,
    UserNotFoundError,  # For adding a user to a group
    ConflictError  # For updates
)

# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    mock_transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=mock_transaction_context)
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    return session


@pytest.fixture
def group_create_data():
    return GroupCreate(name="Test Group")


@pytest.fixture
def group_update_data():
    return GroupUpdate(name="Updated Test Group", version=1)


@pytest.fixture
def creator_user_model():
    return UserModel(id=1, name="Creator User", email="creator@example.com", version=1)


@pytest.fixture
def member_user_model():
    return UserModel(id=2, name="Member User", email="member@example.com", version=1)


@pytest.fixture
def non_member_user_model():
    return UserModel(id=3, name="Non Member User", email="nonmember@example.com", version=1)


@pytest.fixture
def db_group_model(creator_user_model):
    return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model, version=1)


@pytest.fixture
def db_user_group_owner_assoc(db_group_model, creator_user_model):
    return UserGroupModel(id=1, user_id=creator_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.owner, user=creator_user_model, group=db_group_model, version=1)


@pytest.fixture
def db_user_group_member_assoc(db_group_model, member_user_model):
    return UserGroupModel(id=2, user_id=member_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.member, user=member_user_model, group=db_group_model, version=1)

# --- create_group Tests ---
@pytest.mark.asyncio
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        if isinstance(instance, GroupModel):
            instance.id = 1  # Simulate ID assignment by the DB
            instance.version = 1
            # Simulate the UserGroup association being added and refreshed via relationship back_populates
            instance.members = [UserGroupModel(user_id=creator_user_model.id, group_id=instance.id, role=UserRoleEnum.owner, version=1)]
        elif isinstance(instance, UserGroupModel):
            instance.id = 1  # Simulate an ID for the UserGroupModel
            instance.version = 1
        return None
    mock_db_session.refresh.side_effect = mock_refresh

    # Mock the user get for the creator
    mock_db_session.get.return_value = creator_user_model

    created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id)

    assert mock_db_session.add.call_count == 2  # Group and UserGroup
    mock_db_session.flush.assert_called()
    assert mock_db_session.refresh.call_count >= 1  # Called for the group, possibly for the UserGroup too
    assert created_group is not None
    assert created_group.name == group_create_data.name
    assert created_group.created_by_id == creator_user_model.id
    assert len(created_group.members) == 1
    assert created_group.members[0].role == UserRoleEnum.owner


@pytest.mark.asyncio
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
    mock_db_session.get.return_value = creator_user_model  # Creator user found
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_group(mock_db_session, group_create_data, creator_user_model.id)
    mock_db_session.rollback.assert_called_once()

# --- get_user_groups Tests ---
@pytest.mark.asyncio
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
    # Mock the execute call that fetches groups for a user. The Result object
    # is sync, hence MagicMock rather than AsyncMock.
    mock_result_groups = MagicMock()
    mock_result_groups.scalars.return_value.all.return_value = [db_group_model]
    mock_db_session.execute.return_value = mock_result_groups

    groups = await get_user_groups(mock_db_session, creator_user_model.id)
    assert len(groups) == 1
    assert groups[0].name == db_group_model.name
    mock_db_session.execute.assert_called_once()


# --- get_group_by_id Tests ---
@pytest.mark.asyncio
async def test_get_group_by_id_found(mock_db_session, db_group_model):
    mock_db_session.get.return_value = db_group_model
    group = await get_group_by_id(mock_db_session, db_group_model.id)
    assert group is not None
    assert group.id == db_group_model.id
    mock_db_session.get.assert_called_once_with(GroupModel, db_group_model.id, options=ANY)  # options are for eager loading


@pytest.mark.asyncio
async def test_get_group_by_id_not_found(mock_db_session):
    mock_db_session.get.return_value = None
    group = await get_group_by_id(mock_db_session, 999)
    assert group is None


# --- is_user_member Tests ---
@pytest.mark.asyncio
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = db_user_group_owner_assoc.id
    mock_db_session.execute.return_value = mock_result
    is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id)
    assert is_member is True


@pytest.mark.asyncio
async def test_is_user_member_false(mock_db_session, db_group_model, non_member_user_model):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_result
    is_member = await is_user_member(mock_db_session, db_group_model.id, non_member_user_model.id)
    assert is_member is False


# --- get_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner
    mock_db_session.execute.return_value = mock_result
    role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)
    assert role == UserRoleEnum.owner

# --- add_user_to_group Tests ---
@pytest.mark.asyncio
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model, non_member_user_model):
    # Mock is_user_member to return False initially
    with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False
        # Mock get for the user to be added
        mock_db_session.get.return_value = non_member_user_model

        async def mock_refresh_user_group(instance, attribute_names=None, with_for_update=None):
            instance.id = 100
            instance.version = 1
            return None
        mock_db_session.refresh.side_effect = mock_refresh_user_group

        user_group_assoc = await add_user_to_group(mock_db_session, db_group_model, non_member_user_model.id, UserRoleEnum.member)

        mock_db_session.add.assert_called_once()
        mock_db_session.flush.assert_called_once()
        mock_db_session.refresh.assert_called_once()
        assert user_group_assoc is not None
        assert user_group_assoc.user_id == non_member_user_model.id
        assert user_group_assoc.group_id == db_group_model.id
        assert user_group_assoc.role == UserRoleEnum.member


@pytest.mark.asyncio
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model):
    with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True  # User is already a member
        # No need to mock session.get for the user when is_user_member returns True first

        user_group_assoc = await add_user_to_group(mock_db_session, db_group_model, creator_user_model.id)
        assert user_group_assoc is None  # Should return None if the user is already a member
        mock_db_session.add.assert_not_called()


@pytest.mark.asyncio
async def test_add_user_to_group_user_not_found(mock_db_session, db_group_model):
    with patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False  # Not a member initially
        mock_db_session.get.return_value = None  # User to be added not found

        with pytest.raises(UserNotFoundError):
            await add_user_to_group(mock_db_session, db_group_model, 999, UserRoleEnum.member)
        mock_db_session.add.assert_not_called()

# --- remove_user_from_group Tests ---
@pytest.mark.asyncio
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model, db_user_group_member_assoc):
    # Mock get_user_role_in_group to confirm the user is not an owner
    with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
        mock_get_role.return_value = UserRoleEnum.member

        # Mock the execute call for the delete statement (sync Result, so MagicMock)
        mock_delete_result = MagicMock()
        mock_delete_result.rowcount = 1  # Simulate one row affected/deleted
        mock_db_session.execute.return_value = mock_delete_result

        removed = await remove_user_from_group(mock_db_session, db_group_model, member_user_model.id)
        assert removed is True
        mock_db_session.execute.assert_called_once()
        # Matching the exact SQLAlchemy delete object passed to execute is more
        # involved; assert_called_once() at least confirms a statement was issued.


@pytest.mark.asyncio
async def test_remove_user_from_group_owner_last_member(mock_db_session, db_group_model, creator_user_model):
    with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role, \
         patch('app.crud.group.get_group_member_count', new_callable=AsyncMock) as mock_member_count:

        mock_get_role.return_value = UserRoleEnum.owner
        mock_member_count.return_value = 1  # This user is the last member

        with pytest.raises(GroupOperationError, match="Cannot remove the sole owner of a group. Delete the group instead."):
            await remove_user_from_group(mock_db_session, db_group_model, creator_user_model.id)
        mock_db_session.execute.assert_not_called()  # The delete should not be issued

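# Sketch of the guard the test above exercises (assumed shape of the check
# inside remove_user_from_group; the real function may structure it differently):
#
# role = await get_user_role_in_group(db, group.id, user_id)
# if role == UserRoleEnum.owner and await get_group_member_count(db, group.id) == 1:
#     raise GroupOperationError(
#         "Cannot remove the sole owner of a group. Delete the group instead."
#     )
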
@pytest.mark.asyncio
async def test_remove_user_from_group_not_member(mock_db_session, db_group_model, non_member_user_model):
    # get_user_role_in_group returning None means the user is not a member (or
    # the role was not found), which for removal purposes amounts to the same thing.
    with patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
        mock_get_role.return_value = None

        # If the function short-circuits on the missing role, execute is never
        # reached; if it attempts the delete anyway, it affects 0 rows, which is
        # what rowcount = 0 simulates.
        mock_delete_result = MagicMock()
        mock_delete_result.rowcount = 0
        mock_db_session.execute.return_value = mock_delete_result

        with pytest.raises(GroupMembershipError, match="User is not a member of the group or cannot be removed."):
            await remove_user_from_group(mock_db_session, db_group_model, non_member_user_model.id)

        # Whether execute is called depends on the implementation: a pre-check
        # skips it, a blind delete hits it. Do not assert on it until the
        # function's behaviour is pinned down.
        # mock_db_session.execute.assert_called_once()

# --- get_group_member_count Tests ---
@pytest.mark.asyncio
async def test_get_group_member_count_success(mock_db_session, db_group_model):
    mock_result_count = MagicMock()  # Sync Result object
    mock_result_count.scalar_one.return_value = 5  # Example count
    mock_db_session.execute.return_value = mock_result_count

    count = await get_group_member_count(mock_db_session, db_group_model.id)
    assert count == 5


# --- check_group_membership Tests ---
@pytest.mark.asyncio
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
    # Mock get_group_by_id and is_user_member
    with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group, \
         patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:

        mock_get_group.return_value = db_group_model
        mock_is_member.return_value = True

        group = await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
        assert group is db_group_model


@pytest.mark.asyncio
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
    with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group:
        mock_get_group.return_value = None
        with pytest.raises(GroupNotFoundError):
            await check_group_membership(mock_db_session, 999, creator_user_model.id)


@pytest.mark.asyncio
async def test_check_group_membership_not_member(mock_db_session, db_group_model, non_member_user_model):
    with patch('app.crud.group.get_group_by_id', new_callable=AsyncMock) as mock_get_group, \
         patch('app.crud.group.is_user_member', new_callable=AsyncMock) as mock_is_member:

        mock_get_group.return_value = db_group_model
        mock_is_member.return_value = False

        with pytest.raises(GroupMembershipError, match="User is not a member of the specified group"):
            await check_group_membership(mock_db_session, db_group_model.id, non_member_user_model.id)

# --- check_user_role_in_group (standalone check, not just membership) ---
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
|
||||||
|
# This test assumes check_group_membership is called internally first, or similar logic applies
|
||||||
|
with patch('app.crud.group.check_group_membership', new_callable=AsyncMock) as mock_check_membership, \
|
||||||
|
patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
|
||||||
|
|
||||||
|
mock_check_membership.return_value = db_group_model # Group exists and user is member
|
||||||
|
mock_get_role.return_value = UserRoleEnum.owner
|
||||||
|
|
||||||
|
# Check if owner has owner role (should pass)
|
||||||
|
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.owner)
|
||||||
|
# Check if owner has member role (should pass, as owner is implicitly a member with higher privileges)
|
||||||
|
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
|
||||||
|
with patch('app.crud.group.check_group_membership', new_callable=AsyncMock) as mock_check_membership, \
|
||||||
|
patch('app.crud.group.get_user_role_in_group', new_callable=AsyncMock) as mock_get_role:
|
||||||
|
|
||||||
|
mock_check_membership.return_value = db_group_model
|
||||||
|
mock_get_role.return_value = UserRoleEnum.member
|
||||||
|
|
||||||
|
with pytest.raises(GroupPermissionError, match="User does not have the required role in the group."):
|
||||||
|
await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)
|
||||||
|
|
||||||
|
# Future test ideas, to be moved to a proper test planning tool or issue tracker.
|
||||||
|
# Consider these during major refactors or when expanding test coverage.
|
||||||
|
|
||||||
|
# Example of a DB operational error test (can be adapted for other functions)
|
||||||
|
# @pytest.mark.asyncio
|
||||||
|
# async def test_create_group_operational_error(mock_db_session, group_create_data, creator_user_model):
|
||||||
|
# mock_db_session.get.return_value = creator_user_model
|
||||||
|
# mock_db_session.flush.side_effect = OperationalError("mock operational error", "params", "orig")
|
||||||
|
# with pytest.raises(DatabaseConnectionError): # Assuming OperationalError maps to this
|
||||||
|
# await create_group(mock_db_session, group_create_data, creator_user_model.id)
|
||||||
|
# mock_db_session.rollback.assert_called_once()
|

be/tests/crud/test_invite.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError  # Assuming these might be raised
from datetime import datetime, timedelta, timezone
import secrets

from app.crud.invite import (
    create_invite,
    get_active_invite_by_code,
    deactivate_invite,
    MAX_CODE_GENERATION_ATTEMPTS
)
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel  # For context
# No specific schemas for invite CRUD usually, but models are used.


# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.execute = AsyncMock()
    return session
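
# `session.execute` is an async call, but the Result object it returns is synchronous,
# so result mocks throughout these tests are plain MagicMocks rather than AsyncMocks.
# A small hypothetical helper (not used below, shown only to make the pattern explicit):
def make_scalar_result(value):
    """Build a MagicMock mimicking a Result whose scalar_one_or_none() returns `value`."""
    result = MagicMock()
    result.scalar_one_or_none.return_value = value
    return result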

@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def user_model():  # Creator
    return UserModel(id=1, name="Creator User")


@pytest.fixture
def db_invite_model(group_model, user_model):
    return InviteModel(
        id=1,
        code="test_invite_code_123",
        group_id=group_model.id,
        created_by_id=user_model.id,
        expires_at=datetime.now(timezone.utc) + timedelta(days=7),
        is_active=True
    )


# --- create_invite Tests ---
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')  # Patch code generation
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
    generated_code = "unique_code_123"
    mock_token_urlsafe.return_value = generated_code

    # Mock the DB execute for the uniqueness check (first attempt, no existing code)
    mock_existing_check_result = MagicMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None
    mock_db_session.execute.return_value = mock_existing_check_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)

    mock_token_urlsafe.assert_called_once_with(16)
    mock_db_session.execute.assert_called_once()  # The uniqueness check
    mock_db_session.add.assert_called_once()
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(invite)

    assert invite is not None
    assert invite.code == generated_code
    assert invite.group_id == group_model.id
    assert invite.created_by_id == user_model.id
    assert invite.is_active is True
    assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4)  # Expiry is roughly five days out


@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
    colliding_code = "colliding_code"
    unique_code = "finally_unique_code"
    mock_token_urlsafe.side_effect = [colliding_code, unique_code]  # First call collides, second is unique

    # Mock the DB execute calls for the uniqueness checks
    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # Simulate a collision (an ID is found)

    mock_no_collision_check_result = MagicMock()
    mock_no_collision_check_result.scalar_one_or_none.return_value = None  # No collision

    mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert mock_token_urlsafe.call_count == 2
    assert mock_db_session.execute.call_count == 2
    assert invite is not None
    assert invite.code == unique_code


@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
    mock_token_urlsafe.return_value = "always_colliding_code"

    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # Always collide
    mock_db_session.execute.return_value = mock_collision_check_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert invite is None
    assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
    assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
    mock_db_session.add.assert_not_called()
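
# The three tests above pin down the retry contract of create_invite: generate a code,
# check it for uniqueness, retry on collision, and give up (returning None, with nothing
# added to the session) after MAX_CODE_GENERATION_ATTEMPTS. A simplified sketch of the
# loop they assume -- the real implementation lives in app.crud.invite:
#
#     for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
#         code = secrets.token_urlsafe(16)
#         existing = (await db.execute(
#             select(InviteModel.id).where(InviteModel.code == code)
#         )).scalar_one_or_none()
#         if existing is None:
#             invite = InviteModel(code=code, ...)
#             db.add(invite)
#             await db.commit()
#             await db.refresh(invite)
#             return invite
#     return None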

# --- get_active_invite_by_code Tests ---
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_invite_model
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
    assert invite is not None
    assert invite.code == db_invite_model.code
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result
    invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")
    assert invite is None


@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
    db_invite_model.is_active = False  # Inactive
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # Filtered out by the query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
    assert invite is None


@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1)  # Expired

    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # Filtered out by the query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
    assert invite is None
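
# The inactive/expired cases above assume the activity and expiry filters live in the
# SQL query itself rather than in Python, which is why the mocks return None instead of
# the model. A sketch of the select these tests assume (an assumption, simplified):
#
#     stmt = select(InviteModel).where(
#         InviteModel.code == code,
#         InviteModel.is_active.is_(True),
#         InviteModel.expires_at > datetime.now(timezone.utc),
#     )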

# --- deactivate_invite Tests ---
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
    db_invite_model.is_active = True  # Ensure it starts active

    deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model)

    mock_db_session.add.assert_called_once_with(db_invite_model)
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_invite_model)
    assert deactivated_invite.is_active is False


# DB error cases (OperationalError, etc.) could be added per function if they gain
# their own try-except blocks; invite.py currently relies on the caller/framework for
# some of that. create_invite does its own DB interaction inside the retry loop, so
# that path is covered; get_active_invite_by_code and deactivate_invite are simpler DB ops.

be/tests/crud/test_item.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from datetime import datetime, timezone

from app.crud.item import (
    create_item,
    get_items_by_list_id,
    get_item_by_id,
    update_item,
    delete_item
)
from app.schemas.item import ItemCreate, ItemUpdate
from app.models import Item as ItemModel, User as UserModel, List as ListModel
from app.core.exceptions import (
    ItemNotFoundError,  # Not raised directly by the CRUD layer, but useful for API-layer tests
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError
)


# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    # begin() is a sync method that returns an async context manager
    session.begin = MagicMock(return_value=AsyncMock())
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Not used directly in item.py, but kept for consistency
    session.flush = AsyncMock()
    return session


@pytest.fixture
def item_create_data():
    return ItemCreate(name="Test Item", quantity="1 pack")


@pytest.fixture
def item_update_data():
    return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)


@pytest.fixture
def user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def list_model():
    return ListModel(id=1, name="Test List")


@pytest.fixture
def db_item_model(list_model, user_model):
    return ItemModel(
        id=1,
        name="Existing Item",
        quantity="1 unit",
        list_id=list_model.id,
        added_by_id=user_model.id,
        is_complete=False,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )


# --- create_item Tests ---
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
    async def mock_refresh(instance):
        instance.id = 10  # Simulate ID assignment
        instance.version = 1  # Simulate version initialization
        instance.created_at = datetime.now(timezone.utc)
        instance.updated_at = datetime.now(timezone.utc)
        return None
    mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)

    created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(created_item)
    assert created_item is not None
    assert created_item.name == item_create_data.name
    assert created_item.list_id == list_model.id
    assert created_item.added_by_id == user_model.id
    assert created_item.is_complete is False
    assert created_item.version == 1


@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
    mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
    mock_db_session.rollback.assert_called_once()


# --- get_items_by_list_id Tests ---
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_item_model]
    mock_db_session.execute.return_value = mock_result

    items = await get_items_by_list_id(mock_db_session, list_model.id)
    assert len(items) == 1
    assert items[0].id == db_item_model.id
    mock_db_session.execute.assert_called_once()


# --- get_item_by_id Tests ---
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_item_model
    mock_db_session.execute.return_value = mock_result
    item = await get_item_by_id(mock_db_session, db_item_model.id)
    assert item is not None
    assert item.id == db_item_model.id


@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result
    item = await get_item_by_id(mock_db_session, 999)
    assert item is None


# --- update_item Tests ---
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
    original_version = db_item_model.version
    item_update_data.version = original_version  # Matching versions allow the update
    item_update_data.name = "Newly Updated Name"
    item_update_data.is_complete = True  # Exercise the completion logic

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    mock_db_session.add.assert_called_once_with(db_item_model)  # add is used for existing objects too
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_item_model)
    assert updated_item.name == "Newly Updated Name"
    assert updated_item.version == original_version + 1  # update_item increments the version
    assert updated_item.is_complete is True
    assert updated_item.completed_by_id == user_model.id


@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
    item_update_data.version = db_item_model.version + 1  # Create a version mismatch
    with pytest.raises(ConflictError):
        await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
    mock_db_session.rollback.assert_called_once()
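
# The success/conflict pair above documents the optimistic-locking contract of
# update_item: the caller sends the version it last read, and the update only applies
# while that version still matches. A sketch of the check these tests assume
# (simplified; the real implementation lives in app.crud.item):
#
#     if item_in.version != db_item.version:
#         await db.rollback()
#         raise ConflictError(...)
#     db_item.version += 1  # applied together with the field updates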

@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
    db_item_model.is_complete = True  # Start as complete
    db_item_model.completed_by_id = user_model.id
    db_item_model.version = 1

    item_update_data.version = 1
    item_update_data.is_complete = False
    item_update_data.name = db_item_model.name  # No name change in this test
    item_update_data.quantity = db_item_model.quantity

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
    assert updated_item.is_complete is False
    assert updated_item.completed_by_id is None
    assert updated_item.version == 2


# --- delete_item Tests ---
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
    result = await delete_item(mock_db_session, db_item_model)
    assert result is None
    mock_db_session.delete.assert_called_once_with(db_item_model)
    # This assumes delete_item runs within a transaction that the caller commits;
    # if delete_item commits itself, also assert:
    # mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
    with pytest.raises(DatabaseTransactionError):  # The CRUD maps delete failures to DatabaseTransactionError
        await delete_item(mock_db_session, db_item_model)
    mock_db_session.rollback.assert_called_once()


# TODO: Add more specific DB error tests (OperationalError, SQLAlchemyError) for each function.

be/tests/crud/test_list.py (new file, 351 lines)
@@ -0,0 +1,351 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import func as sql_func  # For get_list_status
from datetime import datetime, timezone

from app.crud.list import (
    create_list,
    get_lists_for_user,
    get_list_by_id,
    update_list,
    delete_list,
    check_list_permission,
    get_list_status
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError
)


# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()  # Overall session mock

    # session.begin() and session.begin_nested() are sync methods that return an
    # async context manager; the returned AsyncMock plays that role.
    mock_transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=mock_transaction_context)
    session.begin_nested = MagicMock(return_value=mock_transaction_context)  # Could also be a separate mock

    # Async methods on the session itself
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()

    # Sync methods on the session
    session.add = MagicMock()
    session.delete = MagicMock()
    session.in_transaction = MagicMock(return_value=False)
    return session
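
# Why begin/begin_nested are MagicMocks returning an AsyncMock: code under test like
# `async with session.begin(): ...` first calls begin() synchronously, then awaits
# __aenter__/__aexit__ on whatever begin() returned. AsyncMock auto-implements those
# async dunders, so this pattern works against the fixture (a sketch, assuming the
# CRUD uses this transaction style):
#
#     async with session.begin():
#         session.add(obj)
#         await session.flush()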

@pytest.fixture
def list_create_data():
    return ListCreate(name="New Shopping List", description="Groceries for the week")


@pytest.fixture
def list_update_data():
    return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)


@pytest.fixture
def user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com")


@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def db_list_personal_model(user_model):
    return ListModel(
        id=1, name="Personal List", created_by_id=user_model.id, creator=user_model,
        version=1, updated_at=datetime.now(timezone.utc), items=[]
    )


@pytest.fixture
def db_list_group_model(user_model, group_model):
    return ListModel(
        id=2, name="Group List", created_by_id=user_model.id, creator=user_model,
        group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[]
    )


# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
    async def mock_refresh(instance):
        instance.id = 100
        instance.version = 1
        instance.updated_at = datetime.now(timezone.utc)
        return None

    mock_db_session.refresh.side_effect = mock_refresh
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = ListModel(
        id=100,
        name=list_create_data.name,
        description=list_create_data.description,
        created_by_id=user_model.id,
        version=1,
        updated_at=datetime.now(timezone.utc)
    )
    mock_db_session.execute.return_value = mock_result

    created_list = await create_list(mock_db_session, list_create_data, user_model.id)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_list.name == list_create_data.name
    assert created_list.created_by_id == user_model.id


# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
    # First execute call: the user's group IDs
    mock_group_ids_scalar_result = MagicMock()
    mock_group_ids_scalar_result.all.return_value = [db_list_group_model.group_id]
    mock_group_ids_execute_result = MagicMock()
    mock_group_ids_execute_result.scalars.return_value = mock_group_ids_scalar_result

    # Second execute call: the lists themselves
    mock_lists_scalar_result = MagicMock()
    mock_lists_scalar_result.all.return_value = [db_list_personal_model, db_list_group_model]
    mock_lists_execute_result = MagicMock()
    mock_lists_execute_result.scalars.return_value = mock_lists_scalar_result

    mock_db_session.execute.side_effect = [mock_group_ids_execute_result, mock_lists_execute_result]

    lists = await get_lists_for_user(mock_db_session, user_model.id)
    assert len(lists) == 2
    assert db_list_personal_model in lists
    assert db_list_group_model in lists
    assert mock_db_session.execute.call_count == 2


# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_personal_model
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)
    assert found_list is not None
    assert found_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
    db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_personal_model
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)
    assert found_list is not None
    assert len(found_list.items) == 1


# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
    original_version = db_list_personal_model.version
    list_update_data.version = original_version

    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = db_list_personal_model
    mock_db_session.execute.return_value = mock_result

    updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)
    assert updated_list.name == list_update_data.name
    assert updated_list.version == original_version + 1
    mock_db_session.add.assert_called_once_with(db_list_personal_model)
    mock_db_session.flush.assert_called_once()


@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
    list_update_data.version = db_list_personal_model.version + 1  # Simulate a version mismatch

    # A version mismatch must raise ConflictError
    with pytest.raises(ConflictError):
        await update_list(mock_db_session, db_list_personal_model, list_update_data)

    # Whether session.rollback() should be asserted here depends on whether update_list
    # handles rollback itself or leaves it to the caller; since ConflictError is raised
    # before anything is persisted, no rollback assertion is made.


# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
    await delete_list(mock_db_session, db_list_personal_model)
    mock_db_session.delete.assert_called_once_with(db_list_personal_model)
    # mock_db_session.flush.assert_called_once()  # delete usually implies a flush


# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_personal_model
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
    assert ret_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_group_model
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True
        ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
        assert ret_list.id == db_list_group_model.id
        mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_group_model
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False
        with pytest.raises(ListPermissionError):
            await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_creator_required_fail(mock_db_session, db_list_group_model, another_user_model):
    # another_user_model (id=2) is not the creator of db_list_group_model (created_by_id=1)
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = db_list_group_model  # The list is found
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    # No need to mock is_user_member: with require_creator=True, a non-creator fails first

    with pytest.raises(ListCreatorRequiredError):
        await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id, require_creator=True)


@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = None  # Simulate list not found
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    with pytest.raises(ListNotFoundError):
        await check_list_permission(mock_db_session, 999, user_model.id)
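
# Taken together, the tests above pin down the decision order in check_list_permission:
# fetch the list (else ListNotFoundError); if require_creator and the user is not the
# creator, raise ListCreatorRequiredError; the creator always has access; otherwise a
# group list grants access via is_user_member, else ListPermissionError. A sketch of
# that flow as the tests assume it (simplified; the real logic lives in app.crud.list):
#
#     if db_list is None: raise ListNotFoundError(list_id)
#     if require_creator and db_list.created_by_id != user_id: raise ListCreatorRequiredError(...)
#     if db_list.created_by_id == user_id: return db_list
#     if db_list.group_id and await is_user_member(db, group_id=db_list.group_id, user_id=user_id):
#         return db_list
#     raise ListPermissionError(...)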

# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
    # get_list_status issues three queries; they are mocked here in call order.

    # First execute call: find the list by ID
    mock_list_scalar = MagicMock()
    mock_list_scalar.first.return_value = db_list_personal_model
    mock_list_execute = MagicMock()
    mock_list_execute.scalars.return_value = mock_list_scalar

    # Second execute call: count total items
    mock_total_items_scalar = MagicMock()
    mock_total_items_scalar.one.return_value = 5
    mock_total_items_execute = MagicMock()
    mock_total_items_execute.scalars.return_value = mock_total_items_scalar

    # Third execute call: count completed items
    mock_completed_items_scalar = MagicMock()
    mock_completed_items_scalar.one.return_value = 2
    mock_completed_items_execute = MagicMock()
    mock_completed_items_execute.scalars.return_value = mock_completed_items_scalar

    mock_db_session.execute.side_effect = [
        mock_list_execute,
        mock_total_items_execute,
        mock_completed_items_execute
    ]

    status = await get_list_status(mock_db_session, db_list_personal_model.id)
    assert status.list_id == db_list_personal_model.id
    assert status.total_items == 5
    assert status.completed_items == 2
    assert status.name == db_list_personal_model.name
    assert mock_db_session.execute.call_count == 3


@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
    mock_scalar_result = MagicMock()
    mock_scalar_result.first.return_value = None  # List not found
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value = mock_scalar_result
    mock_db_session.execute.return_value = mock_execute_result

    with pytest.raises(ListNotFoundError):
        await get_list_status(mock_db_session, 999)


# TODO: Add more specific DB error tests (OperationalError, SQLAlchemyError, IntegrityError) for each function.
# TODO: Cover more check_list_permission cases with require_creator=True.

be/tests/crud/test_settlement.py (new file, 277 lines)
@@ -0,0 +1,277 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList

from app.crud.settlement import (
    create_settlement,
    get_settlement_by_id,
    get_settlements_for_group,
    get_settlements_involving_user,
    update_settlement,
    delete_settlement
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError, ConflictError


# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    # begin()/begin_nested() are sync methods returning an async context manager
    session.begin = MagicMock(return_value=AsyncMock())
    session.begin_nested = MagicMock(return_value=AsyncMock())
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    return session


@pytest.fixture
def settlement_create_data():
    return SettlementCreate(
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Test settlement"
    )


@pytest.fixture
def settlement_update_data():
    return SettlementUpdate(
        description="Updated settlement description",
        settlement_date=datetime.now(timezone.utc),
        version=1  # The version is required for optimistic-locking updates
    )


@pytest.fixture
def db_settlement_model():
    return SettlementModel(
        id=1,
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Original settlement",
        created_by_user_id=1,
        version=1,  # Initial version
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        payer=UserModel(id=1, name="Payer User"),
        payee=UserModel(id=2, name="Payee User"),
        group=GroupModel(id=1, name="Test Group"),
        created_by_user=UserModel(id=1, name="Payer User")  # Same as payer for simplicity
    )


@pytest.fixture
def payer_user_model():
    return UserModel(id=1, name="Payer User", email="payer@example.com")


@pytest.fixture
def payee_user_model():
    return UserModel(id=2, name="Payee User", email="payee@example.com")


@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = SettlementModel(
        id=1,
        group_id=settlement_create_data.group_id,
        paid_by_user_id=settlement_create_data.paid_by_user_id,
        paid_to_user_id=settlement_create_data.paid_to_user_id,
        amount=settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        settlement_date=settlement_create_data.settlement_date,
        description=settlement_create_data.description,
        created_by_user_id=1,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc)
    )
    mock_db_session.execute.return_value = mock_result

    created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_settlement is not None
    assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
    assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
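
# Amounts are normalized to two decimal places with ROUND_HALF_UP before being compared
# or stored, so exact halves round away from zero. For example:
#
#     >>> Decimal("10.505").quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
#     Decimal('10.51')
#     >>> Decimal("10.504").quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
#     Decimal('10.50')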

@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data, payee_user_model, group_model):
    mock_db_session.get.side_effect = [None, payee_user_model, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, None, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payee" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
    with pytest.raises(GroupNotFoundError):
        await create_settlement(mock_db_session, settlement_create_data, 1)


@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
    settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
    mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer and Payee cannot be the same user" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Failed to save settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 1)
    assert settlement is not None
    assert settlement.id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 999)
    assert settlement is None


# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_for_group(mock_db_session, group_id=1)
    assert len(settlements) == 1
    assert settlements[0].group_id == 1
    mock_db_session.execute.assert_called_once()


# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
    assert len(settlements) == 1
    assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
    assert len(settlements) == 1
    mock_db_session.execute.assert_called_once()


# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
    original_version = db_settlement_model.version
    settlement_update_data.version = original_version

    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    mock_db_session.add.assert_called_once_with(db_settlement_model)
    mock_db_session.flush.assert_called_once()
    assert updated_settlement.description == settlement_update_data.description
    assert updated_settlement.settlement_date == settlement_update_data.settlement_date
    assert updated_settlement.version == original_version + 1


@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
    settlement_update_data.version = db_settlement_model.version + 1
    with pytest.raises(ConflictError):
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    mock_db_session.rollback.assert_called_once()


@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
    # An update payload with a field that must not change, e.g. 'amount'
    invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
    assert "Field 'amount' cannot be updated" in str(excinfo.value)


@pytest.mark.asyncio
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
    settlement_update_data.version = db_settlement_model.version
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "Failed to update settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for delete_settlement
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
    await delete_settlement(mock_db_session, db_settlement_model)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
    await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
    db_settlement_model.version = 2
    with pytest.raises(ConflictError):
        await delete_settlement(mock_db_session, db_settlement_model, expected_version=1)
    mock_db_session.rollback.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await delete_settlement(mock_db_session, db_settlement_model)
    assert "Failed to delete settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()

be/tests/crud/test_settlement_activity.py (new file, 369 lines)
@@ -0,0 +1,369 @@
import pytest
from decimal import Decimal
from datetime import datetime, timezone
from typing import AsyncGenerator, List

from sqlalchemy.ext.asyncio import AsyncSession

from app.models import (
    User,
    Group,
    Expense,
    ExpenseSplit,
    SettlementActivity,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
    SplitTypeEnum,
    UserRoleEnum
)
from app.crud.settlement_activity import (
    create_settlement_activity,
    get_settlement_activity_by_id,
    get_settlement_activities_for_split,
    update_expense_split_status,  # For direct testing if needed
    update_expense_overall_status  # For direct testing if needed
)
from app.schemas.settlement_activity import SettlementActivityCreate as SettlementActivityCreateSchema


@pytest.fixture
async def test_user1(db_session: AsyncSession) -> User:
    user = User(email="user1@example.com", name="Test User 1", hashed_password="password1")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return user


@pytest.fixture
async def test_user2(db_session: AsyncSession) -> User:
    user = User(email="user2@example.com", name="Test User 2", hashed_password="password2")
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return user


@pytest.fixture
async def test_group(db_session: AsyncSession, test_user1: User) -> Group:
    group = Group(name="Test Group", created_by_id=test_user1.id)
    db_session.add(group)
    await db_session.commit()
    # Add user1 as owner and user2 as member (can be done in specific tests if needed)
    await db_session.refresh(group)
    return group


@pytest.fixture
async def test_expense(db_session: AsyncSession, test_user1: User, test_group: Group) -> Expense:
    expense = Expense(
        description="Test Expense for Settlement",
        total_amount=Decimal("20.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        group_id=test_group.id,
        paid_by_user_id=test_user1.id,
        created_by_user_id=test_user1.id,
        overall_settlement_status=ExpenseOverallStatusEnum.unpaid  # Initial status
    )
    db_session.add(expense)
    await db_session.commit()
    await db_session.refresh(expense)
    return expense


@pytest.fixture
async def test_expense_split_user2_owes(db_session: AsyncSession, test_expense: Expense, test_user2: User) -> ExpenseSplit:
    # User2 owes 10.00 to User1 (who paid the expense)
    split = ExpenseSplit(
        expense_id=test_expense.id,
        user_id=test_user2.id,
        owed_amount=Decimal("10.00"),
        status=ExpenseSplitStatusEnum.unpaid  # Initial status
    )
    db_session.add(split)
    await db_session.commit()
    await db_session.refresh(split)
    return split


@pytest.fixture
async def test_expense_split_user1_owes_self_for_completeness(db_session: AsyncSession, test_expense: Expense, test_user1: User) -> ExpenseSplit:
    # User1's own share (owes 10.00 to self, effectively settled).
    # This is often how splits are represented, even for the payer.
    split = ExpenseSplit(
        expense_id=test_expense.id,
        user_id=test_user1.id,
        owed_amount=Decimal("10.00"),  # User1's share of the 20.00 expense
        status=ExpenseSplitStatusEnum.unpaid  # Initial status, though the payer's own share might be considered paid by some logic
    )
    db_session.add(split)
    await db_session.commit()
    await db_session.refresh(split)
    return split


# --- Tests for create_settlement_activity ---

@pytest.mark.asyncio
async def test_create_settlement_activity_full_payment(
    db_session: AsyncSession,
    test_user1: User,  # Creator of activity, payer of the expense
    test_user2: User,  # Payer of this settlement activity (settling their debt)
    test_expense: Expense,
    test_expense_split_user2_owes: ExpenseSplit,
    test_expense_split_user1_owes_self_for_completeness: ExpenseSplit  # User1's own share
):
    # Scenario: User2 fully pays their 10.00 share. This test focuses on User2's
    # split, but the overall expense only becomes paid once User1's split is also
    # paid, so we manually mark User1's split as paid for this test case.
    test_expense_split_user1_owes_self_for_completeness.status = ExpenseSplitStatusEnum.paid
    test_expense_split_user1_owes_self_for_completeness.paid_at = datetime.now(timezone.utc)
    db_session.add(test_expense_split_user1_owes_self_for_completeness)
    await db_session.commit()
    await db_session.refresh(test_expense_split_user1_owes_self_for_completeness)
    await db_session.refresh(test_expense)  # Refresh expense to reflect the split status change

    activity_data = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,  # User2 is paying their share
        amount_paid=Decimal("10.00")
    )

    created_activity = await create_settlement_activity(
        db=db_session,
        settlement_activity_in=activity_data,
        current_user_id=test_user2.id  # User2 is recording their own payment
    )

    assert created_activity is not None
    assert created_activity.expense_split_id == test_expense_split_user2_owes.id
    assert created_activity.paid_by_user_id == test_user2.id
    assert created_activity.amount_paid == Decimal("10.00")
    assert created_activity.created_by_user_id == test_user2.id

    await db_session.refresh(test_expense_split_user2_owes)
    await db_session.refresh(test_expense)  # Refresh to get the updated overall status

    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
    assert test_expense_split_user2_owes.paid_at is not None

    # Check the parent expense status; it is fully paid only when all splits are paid.
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid


@pytest.mark.asyncio
async def test_create_settlement_activity_partial_payment(
    db_session: AsyncSession,
    test_user1: User,  # Creator of activity
    test_user2: User,  # Payer of this settlement activity
    test_expense: Expense,
    test_expense_split_user2_owes: ExpenseSplit
):
    activity_data = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,
        amount_paid=Decimal("5.00")
    )

    created_activity = await create_settlement_activity(
        db=db_session,
        settlement_activity_in=activity_data,
        current_user_id=test_user2.id  # User2 records their payment
    )

    assert created_activity is not None
    assert created_activity.amount_paid == Decimal("5.00")

    await db_session.refresh(test_expense_split_user2_owes)
    await db_session.refresh(test_expense)

    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.partially_paid
    assert test_expense_split_user2_owes.paid_at is None
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid  # Assuming other splits are unpaid or partially paid


@pytest.mark.asyncio
async def test_create_settlement_activity_multiple_payments_to_full(
    db_session: AsyncSession,
    test_user1: User,
    test_user2: User,
    test_expense: Expense,
    test_expense_split_user2_owes: ExpenseSplit,
    test_expense_split_user1_owes_self_for_completeness: ExpenseSplit  # User1's own share
):
    # Assume user1's share is already 'paid' so the overall expense status can be tested
    test_expense_split_user1_owes_self_for_completeness.status = ExpenseSplitStatusEnum.paid
    test_expense_split_user1_owes_self_for_completeness.paid_at = datetime.now(timezone.utc)
    db_session.add(test_expense_split_user1_owes_self_for_completeness)
    await db_session.commit()

    # First partial payment
    activity_data1 = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,
        amount_paid=Decimal("3.00")
    )
    await create_settlement_activity(db=db_session, settlement_activity_in=activity_data1, current_user_id=test_user2.id)

    await db_session.refresh(test_expense_split_user2_owes)
    await db_session.refresh(test_expense)
    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.partially_paid
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid

    # Second payment completing the amount
    activity_data2 = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,
        amount_paid=Decimal("7.00")  # 3.00 + 7.00 = 10.00
    )
    await create_settlement_activity(db=db_session, settlement_activity_in=activity_data2, current_user_id=test_user2.id)

    await db_session.refresh(test_expense_split_user2_owes)
    await db_session.refresh(test_expense)
    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
    assert test_expense_split_user2_owes.paid_at is not None
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.paid


@pytest.mark.asyncio
async def test_create_settlement_activity_invalid_split_id(
    db_session: AsyncSession, test_user1: User
):
    activity_data = SettlementActivityCreateSchema(
        expense_split_id=99999,  # Non-existent
        paid_by_user_id=test_user1.id,
        amount_paid=Decimal("10.00")
    )
    # The CRUD function returns None when related objects are not found
    result = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user1.id)
    assert result is None


@pytest.mark.asyncio
async def test_create_settlement_activity_invalid_paid_by_user_id(
    db_session: AsyncSession, test_user1: User, test_expense_split_user2_owes: ExpenseSplit
):
    activity_data = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=99999,  # Non-existent
        amount_paid=Decimal("10.00")
    )
    result = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user1.id)
    assert result is None


# --- Tests for get_settlement_activity_by_id ---
@pytest.mark.asyncio
async def test_get_settlement_activity_by_id_found(
    db_session: AsyncSession, test_user2: User, test_expense_split_user2_owes: ExpenseSplit
):
    activity_data = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,
        amount_paid=Decimal("5.00")
    )
    created = await create_settlement_activity(db=db_session, settlement_activity_in=activity_data, current_user_id=test_user2.id)
    assert created is not None

    fetched = await get_settlement_activity_by_id(db=db_session, settlement_activity_id=created.id)
    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.amount_paid == Decimal("5.00")


@pytest.mark.asyncio
async def test_get_settlement_activity_by_id_not_found(db_session: AsyncSession):
    fetched = await get_settlement_activity_by_id(db=db_session, settlement_activity_id=99999)
    assert fetched is None


# --- Tests for get_settlement_activities_for_split ---
@pytest.mark.asyncio
async def test_get_settlement_activities_for_split_multiple_found(
    db_session: AsyncSession, test_user2: User, test_expense_split_user2_owes: ExpenseSplit
):
    act1_data = SettlementActivityCreateSchema(expense_split_id=test_expense_split_user2_owes.id, paid_by_user_id=test_user2.id, amount_paid=Decimal("2.00"))
    act2_data = SettlementActivityCreateSchema(expense_split_id=test_expense_split_user2_owes.id, paid_by_user_id=test_user2.id, amount_paid=Decimal("3.00"))

    await create_settlement_activity(db=db_session, settlement_activity_in=act1_data, current_user_id=test_user2.id)
    await create_settlement_activity(db=db_session, settlement_activity_in=act2_data, current_user_id=test_user2.id)

    activities: List[SettlementActivity] = await get_settlement_activities_for_split(db=db_session, expense_split_id=test_expense_split_user2_owes.id)
    assert len(activities) == 2
    amounts = sorted([act.amount_paid for act in activities])
    assert amounts == [Decimal("2.00"), Decimal("3.00")]


@pytest.mark.asyncio
async def test_get_settlement_activities_for_split_none_found(
    db_session: AsyncSession, test_expense_split_user2_owes: ExpenseSplit  # A split with no activities
):
    activities: List[SettlementActivity] = await get_settlement_activities_for_split(db=db_session, expense_split_id=test_expense_split_user2_owes.id)
    assert len(activities) == 0


# Note: Direct tests for the helper functions update_expense_split_status and
# update_expense_overall_status could be added if complex logic within them isn't
# fully covered by the create_settlement_activity tests; their effects are
# currently validated through the main CRUD function here.
# For example, to test update_expense_split_status directly:
# 1. Create an ExpenseSplit.
# 2. Create one or more SettlementActivity instances directly in the DB session for that split.
# 3. Call await update_expense_split_status(db_session, expense_split_id=split.id).
# 4. Assert that split.status and split.paid_at are as expected.
# Similar for update_expense_overall_status, by setting up multiple splits.
# For now, we rely on indirect testing via create_settlement_activity.
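#
# A minimal commented sketch of such a direct test, following the recipe above.
# The SettlementActivity constructor fields are assumed from the create schema
# (expense_split_id, paid_by_user_id, amount_paid) plus created_by_user_id;
# adjust to the actual model if they differ.
#
# @pytest.mark.asyncio
# async def test_update_expense_split_status_direct(
#     db_session: AsyncSession, test_user2: User, test_expense_split_user2_owes: ExpenseSplit
# ):
#     activity = SettlementActivity(
#         expense_split_id=test_expense_split_user2_owes.id,
#         paid_by_user_id=test_user2.id,
#         amount_paid=Decimal("10.00"),
#         created_by_user_id=test_user2.id,
#     )
#     db_session.add(activity)
#     await db_session.commit()
#
#     await update_expense_split_status(db_session, expense_split_id=test_expense_split_user2_owes.id)
#     await db_session.refresh(test_expense_split_user2_owes)
#     assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
#     assert test_expense_split_user2_owes.paid_at is not None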
# More tests could be added for edge cases, such as:
# - Overpayment (current logic in update_expense_split_status treats >= owed_amount as 'paid').
# - Different users creating the activity vs. paying for it (permission aspects, though that's more for API tests).
# - Interactions with different expense split types, if that affects status updates.
# - Ensuring `overall_settlement_status` correctly reflects mixed states, e.g. one split paid
#   while another is unpaid (test_expense_split_user1_owes_self_for_completeness is unpaid initially);
#   one 'paid' split plus one 'unpaid' split should yield 'partially_paid' for the expense.


@pytest.mark.asyncio
async def test_create_settlement_activity_overall_status_becomes_partially_paid(
    db_session: AsyncSession,
    test_user1: User,
    test_user2: User,
    test_expense: Expense,  # Overall status is initially unpaid
    test_expense_split_user2_owes: ExpenseSplit,  # User2's split, initially unpaid
    test_expense_split_user1_owes_self_for_completeness: ExpenseSplit  # User1's split, also initially unpaid
):
    # Sanity check: both splits and the expense are unpaid initially
    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.unpaid
    assert test_expense_split_user1_owes_self_for_completeness.status == ExpenseSplitStatusEnum.unpaid
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid

    # User2 fully pays their 10.00 share.
    activity_data = SettlementActivityCreateSchema(
        expense_split_id=test_expense_split_user2_owes.id,
        paid_by_user_id=test_user2.id,  # User2 is paying their share
        amount_paid=Decimal("10.00")
    )

    await create_settlement_activity(
        db=db_session,
        settlement_activity_in=activity_data,
        current_user_id=test_user2.id  # User2 is recording their own payment
    )

    await db_session.refresh(test_expense_split_user2_owes)
    await db_session.refresh(test_expense_split_user1_owes_self_for_completeness)  # Ensure its status is current
    await db_session.refresh(test_expense)

    assert test_expense_split_user2_owes.status == ExpenseSplitStatusEnum.paid
    assert test_expense_split_user1_owes_self_for_completeness.status == ExpenseSplitStatusEnum.unpaid  # User1's split is still unpaid

    # One split paid and the other unpaid, so the overall expense status should be partially_paid
    assert test_expense.overall_settlement_status == ExpenseOverallStatusEnum.partially_paid


# Example of a placeholder for the db_session fixture if not provided by conftest.py.
# It needs to be implemented based on your test database setup, e.g. using a
# test-specific database and creating a new session per test:
#
# @pytest.fixture
# async def db_session() -> AsyncGenerator[AsyncSession, None]:
#     from app.database import SessionLocal  # Assuming SessionLocal is your session factory
#     async with SessionLocal() as session:
#         async with session.begin():  # Start a transaction
#             yield session
#         # The transaction is rolled back here after the test
be/tests/crud/test_user.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.exc import IntegrityError, OperationalError

from app.crud.user import get_user_by_email, create_user
from app.schemas.user import UserCreate
from app.models import User as UserModel
from app.core.exceptions import (
    UserCreationError,
    EmailAlreadyRegisteredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError
)


# Fixtures
@pytest.fixture
def mock_db_session():
    session = AsyncMock()
    session.begin = AsyncMock()
    session.begin_nested = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    return session


@pytest.fixture
def user_create_data():
    return UserCreate(email="test@example.com", password="password123", name="Test User")


@pytest.fixture
def existing_user_data():
    return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")


# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
    mock_result = AsyncMock()
    mock_result.scalars.return_value.first.return_value = existing_user_data
    mock_db_session.execute.return_value = mock_result

    user = await get_user_by_email(mock_db_session, "exists@example.com")
    assert user is not None
    assert user.email == "exists@example.com"
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
    mock_result = AsyncMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
    assert user is None
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_user_by_email_db_connection_error(mock_db_session):
    mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await get_user_by_email(mock_db_session, "test@example.com")


@pytest.mark.asyncio
async def test_get_user_by_email_db_query_error(mock_db_session):
    # Simulate a generic SQLAlchemyError that is not an OperationalError
    # (IntegrityError is used here as a convenient SQLAlchemyError subclass)
    mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig")
    with pytest.raises(DatabaseQueryError):
        await get_user_by_email(mock_db_session, "test@example.com")


# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
    mock_result = AsyncMock()
    mock_result.scalar_one_or_none.return_value = UserModel(
        id=1,
        email=user_create_data.email,
        name=user_create_data.name,
        password_hash="hashed_password"  # This would be set by the actual hash_password function
    )
    mock_db_session.execute.return_value = mock_result

    created_user = await create_user(mock_db_session, user_create_data)
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_user is not None
    assert created_user.email == user_create_data.email
    assert created_user.name == user_create_data.name
    assert created_user.id == 1


@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
    mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig")
    with pytest.raises(EmailAlreadyRegisteredError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
    # Simulate an IntegrityError that is not related to a unique constraint
    mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig")
    with pytest.raises(DatabaseIntegrityError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
    mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)
    # Also test OperationalError on flush
    mock_db_session.begin.side_effect = None  # Reset side effect
    mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
    with pytest.raises(DatabaseConnectionError):
        await create_user(mock_db_session, user_create_data)


@pytest.mark.asyncio
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
    # Simulate a generic error on flush that is neither an IntegrityError nor an
    # OperationalError (UserCreationError stands in for any other SQLAlchemyError)
    mock_db_session.flush.side_effect = UserCreationError("Simulated non-specific SQLAlchemyError")
    with pytest.raises(DatabaseTransactionError):
        await create_user(mock_db_session, user_create_data)
docker-compose.prod.yml (new file, 28 lines)
@@ -0,0 +1,28 @@
services:
  backend:
    container_name: fastapi_backend_prod
    build:
      context: ./be
      dockerfile: Dockerfile.prod
      target: production
    environment:
      - DATABASE_URL=${DATABASE_URL}
      - GEMINI_API_KEY=${GEMINI_API_KEY}
      - SECRET_KEY=${SECRET_KEY}
      - SESSION_SECRET_KEY=${SESSION_SECRET_KEY}
      - SENTRY_DSN=${SENTRY_DSN}
      - LOG_LEVEL=INFO
      - ENVIRONMENT=production
      - CORS_ORIGINS=${CORS_ORIGINS}
      - FRONTEND_URL=${FRONTEND_URL}
    restart: unless-stopped

  frontend:
    container_name: frontend_prod
    build:
      context: ./fe
      dockerfile: Dockerfile.prod
      target: production
    environment:
      - VITE_API_URL=https://mitlistbe.mohamad.dev
    restart: unless-stopped
docker-compose.yml (modified)
@@ -1,24 +1,21 @@
-# docker-compose.yml (in project root)
-version: '3.8'
-
 services:
   db:
-    image: postgres:15 # Use a specific PostgreSQL version
+    image: postgres:17 # Use a specific PostgreSQL version
     container_name: postgres_db
     environment:
-      POSTGRES_USER: dev_user # Define DB user
-      POSTGRES_PASSWORD: dev_password # Define DB password
-      POSTGRES_DB: dev_db # Define Database name
+      POSTGRES_USER: xxx # Define DB user
+      POSTGRES_PASSWORD: xxx # Define DB password
+      POSTGRES_DB: xxx # Define Database name
     volumes:
       - postgres_data:/var/lib/postgresql/data # Persist data using a named volume
     ports:
       - "5432:5432" # Expose PostgreSQL port to host (optional, for direct access)
     healthcheck:
       test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"]
       interval: 10s
       timeout: 5s
       retries: 5
       start_period: 10s
     restart: unless-stopped

   backend:
@@ -36,30 +33,38 @@ services:
       # Pass the database URL to the backend container
       # Uses the service name 'db' as the host, and credentials defined above
       # IMPORTANT: Use the correct async driver prefix if your app needs it!
-      - DATABASE_URL=postgresql+asyncpg://dev_user:dev_password@db:5432/dev_db
+      - DATABASE_URL=postgresql+asyncpg://mitlist_owner:npg_p0SkmyJ6BPWO@ep-small-sound-a9ketcef-pooler.gwc.azure.neon.tech/testnewmig
+      - GEMINI_API_KEY=xxx
+      - SECRET_KEY=xxx
       # Add other environment variables needed by the backend here
       # - SOME_OTHER_VAR=some_value
     depends_on:
-      db: # Wait for the db service to be healthy before starting backend
+      db:
+        # Wait for the db service to be healthy before starting backend
         condition: service_healthy
-    command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] # Override CMD for development reload
+    command: [
+        "uvicorn",
+        "app.main:app",
+        "--host",
+        "0.0.0.0",
+        "--port",
+        "8000",
+        "--reload",
+      ] # Override CMD for development reload
     restart: unless-stopped

-  pgadmin: # Optional service for database administration
-    image: dpage/pgadmin4:latest
-    container_name: pgadmin4_server
-    environment:
-      PGADMIN_DEFAULT_EMAIL: admin@example.com # Change as needed
-      PGADMIN_DEFAULT_PASSWORD: admin_password # Change to a secure password
-      PGADMIN_CONFIG_SERVER_MODE: 'False' # Run in Desktop mode for easier local dev server setup
-    volumes:
-      - pgadmin_data:/var/lib/pgadmin # Persist pgAdmin configuration
-    ports:
-      - "5050:80" # Map container port 80 to host port 5050
-    depends_on:
-      - db # Depends on the database service
+  frontend:
+    container_name: vite_frontend
+    build:
+      context: ./fe
+      dockerfile: Dockerfile
+    ports:
+      - "80:80"
+    depends_on:
+      - backend
     restart: unless-stopped

-volumes: # Define named volumes for data persistence
+volumes:
+  # Define named volumes for data persistence
   postgres_data:
   pgadmin_data:
docs/PRODUCTION.md (new file, 196 lines)
@@ -0,0 +1,196 @@
# Production Deployment Guide (Gitea Actions)

This guide covers deploying the mitlist application to a production environment using Docker Compose and Gitea Actions for CI/CD.

## 🚀 Quick Start

1. **Clone the repository** (if not already done):

   ```bash
   git clone <your-repo>
   cd mitlist
   ```

2. **Configure Gitea Secrets**:
   In your Gitea repository settings, go to "Secrets" and add the following secrets. These will be used by the `deploy-prod.yml` workflow.

   * `DOCKER_USERNAME`: Your Docker Hub username (or username for your container registry).
   * `DOCKER_PASSWORD`: Your Docker Hub password (or token for your container registry).
   * `SERVER_HOST`: IP address or hostname of your production server.
   * `SERVER_USERNAME`: Username for SSH access to your production server.
   * `SSH_PRIVATE_KEY`: Your private SSH key for accessing the production server.
   * `SERVER_PORT`: (Optional) SSH port for your server (defaults to 22).
   * `POSTGRES_USER`: Production database username.
   * `POSTGRES_PASSWORD`: Production database password.
   * `POSTGRES_DB`: Production database name.
   * `DATABASE_URL`: Production database connection string.
   * `SECRET_KEY`: FastAPI application secret key.
   * `SESSION_SECRET_KEY`: FastAPI session secret key.
   * `GEMINI_API_KEY`: API key for Gemini.
   * `REDIS_PASSWORD`: Password for Redis.
   * `SENTRY_DSN`: (Optional) Sentry DSN for backend error tracking.
   * `CORS_ORIGINS`: Comma-separated list of allowed CORS origins for production (e.g., `https://yourdomain.com`).
   * `FRONTEND_URL`: The public URL of your frontend (e.g., `https://yourdomain.com`).
   * `VITE_API_URL`: The public API URL for the frontend (e.g., `https://yourdomain.com/api`).
   * `VITE_SENTRY_DSN`: (Optional) Sentry DSN for frontend error tracking.

3. **Prepare your Production Server**:
   * Install Docker and Docker Compose (see the Prerequisites section below).
   * Ensure your server can be accessed via SSH using the key you added to Gitea secrets.
   * Create the deployment directory on your server (e.g., `/srv/mitlist`).
   * Copy the `docker-compose.prod.yml` file to this directory on your server.

4. **Push to the `main` branch**:
   Once the Gitea workflows (`.gitea/workflows/build-test.yml` and `.gitea/workflows/deploy-prod.yml`) are in your repository, pushing to the `main` branch will automatically trigger the deployment workflow.

## 📋 Prerequisites (Server Setup)

### System Requirements
- **OS**: Ubuntu 20.04+ / CentOS 8+ / Debian 11+
- **RAM**: Minimum 2GB, recommended 4GB+
- **Storage**: Minimum 20GB free space
- **CPU**: 2+ cores recommended

### Software Dependencies (on Production Server)
- Docker 20.10+
- Docker Compose 2.0+

### Installation Commands

**Ubuntu/Debian:**
```bash
# Update system
sudo apt update && sudo apt upgrade -y

# Install Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
sudo usermod -aG docker $USER # Add your deployment user to the docker group

# Install Docker Compose
sudo apt install docker-compose-plugin

# Reboot or log out/in to apply group changes
# sudo reboot
```

**CentOS/RHEL:**
```bash
# Update system
sudo yum update -y

# Install Docker
sudo yum install -y yum-utils
sudo yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo
sudo yum install docker-ce docker-ce-cli containerd.io docker-compose-plugin
sudo systemctl start docker
sudo systemctl enable docker
sudo usermod -aG docker $USER # Add your deployment user to the docker group

# Reboot or log out/in to apply group changes
# sudo reboot
```

## 🔧 Configuration Overview

* **`docker-compose.prod.yml`**: Defines the production services (database, backend, frontend, redis). This file needs to be on your production server in the deployment directory.
* **`.gitea/workflows/build-test.yml`**: Gitea workflow that builds and runs tests on every push to `main` or `develop`, and on pull requests to these branches.
* **`.gitea/workflows/deploy-prod.yml`**: Gitea workflow that triggers on pushes to the `main` branch. It builds and pushes Docker images to your container registry, then SSHes into your production server to update environment variables and restart services using `docker-compose`.
* **`env.production.template`**: A template file showing the environment variables needed. These are now set directly in the Gitea deployment workflow via secrets.

## 🚀 Deployment Process (via Gitea Actions)

1. **Code Push**: Developer pushes code to the `main` branch.
2. **Build & Test Workflow**: (Optional, if you keep `build-test.yml` active on `main` as well) The `build-test.yml` workflow runs, ensuring code quality.
3. **Deploy Workflow Trigger**: The `deploy-prod.yml` workflow is triggered.
4. **Checkout Code**: The workflow checks out the latest code.
5. **Login to Registry**: Logs into your specified Docker container registry.
6. **Build & Push Images**: Builds the production Docker images for the backend and frontend and pushes them to the registry.
7. **SSH to Server**: Connects to your production server via SSH.
8. **Set Environment Variables**: Creates/updates the `.env.production` file on the server using the Gitea secrets.
9. **Pull New Images**: Runs `docker-compose pull` to fetch the newly pushed images.
10. **Restart Services**: Runs `docker-compose up -d` to restart the services with the new images and configuration.
11. **Prune Images**: Cleans up old, unused Docker images on the server.

## 🏗️ Simplified Architecture

With the removal of nginx as a reverse proxy, the architecture is simpler:

```
[ User / Internet ]
        |
        v
[ Frontend Service (Port 80) ]  <-- Serves the Vue.js app (e.g., via `serve`)
        |
        v (API calls)
[ Backend Service (Internal Port 8000) ]  <-- FastAPI
        |            |
        v            v
[ PostgreSQL ]   [ Redis ]
 (Database)      (Cache)
```

* The **Frontend** service now directly exposes port 80 (or another port you configure) to the internet.
* The **Backend** service is still internal and is accessed by the frontend via its Docker network name (`backend:8000`).

**Note on SSL/HTTPS**: Since nginx is removed, SSL termination is not handled by this setup. You would typically handle SSL at a higher level, for example:
* Using a cloud provider's load balancer with SSL termination.
* Placing another reverse proxy (like Caddy, Traefik, or a dedicated nginx instance) in front of your Docker setup on the server, configured for SSL.
* Using services like Cloudflare that can provide SSL for your domain.

## 📊 Monitoring & Logging

### Health Checks
* **Backend**: `http://<your_server_ip_or_domain>/api/health` (assuming your backend health endpoint is accessible if you map its port in `docker-compose.prod.yml`, or if the frontend proxies it).
* **Frontend**: The `serve` package used by the frontend doesn't have a dedicated health check endpoint by default; check that the main page loads.

### Log Access
```bash
# On your production server, in the deployment directory
docker-compose -f docker-compose.prod.yml logs -f

# Specific service logs
docker-compose -f docker-compose.prod.yml logs -f backend
docker-compose -f docker-compose.prod.yml logs -f frontend
```

## 🔄 Maintenance

### Database Backups
Manual backups can still be performed on the server:
```bash
# Ensure your .env.production file is sourced or the vars are available
docker exec postgres_db_prod pg_dump -U $POSTGRES_USER $POSTGRES_DB > backup-$(date +%Y%m%d).sql
```
Consider automating this with a cron job on your server.

### Updates
Updates are handled by pushing to the `main` branch, which triggers the Gitea deployment workflow.

## 🐛 Troubleshooting

### Gitea Workflow Failures
* Check the Gitea Actions logs for the specific workflow run to identify errors.
* Ensure all secrets are correctly configured in Gitea.
* Verify Docker Hub/registry credentials.
* Check SSH connectivity to your server from the Gitea runner (if using self-hosted runners, ensure network access).

### Service Not Starting on Server
* SSH into your server.
* Navigate to your deployment directory (e.g., `/srv/mitlist`).
* Check logs: `docker-compose -f docker-compose.prod.yml logs <service_name>`
* Ensure `.env.production` has the correct values.
* Check `docker ps` to see running containers.

### Frontend Not Accessible
* Verify the frontend service is running (`docker ps`).
* Check frontend logs: `docker-compose -f docker-compose.prod.yml logs frontend`.
* Ensure the port mapping in `docker-compose.prod.yml` for the frontend service (e.g., `80:3000`) is correct and not blocked by a firewall on your server.

## 📝 Changelog

### v1.1.0 (Gitea Actions Deployment)
- Removed the nginx reverse proxy and related shell scripts.
- Frontend now served directly using `serve`.
- Added Gitea Actions workflows for CI (build/test) and CD (deploy to production).
- Updated deployment documentation to reflect the Gitea Actions strategy.
- Simplified `docker-compose.prod.yml`.
docs/expense-system.md (new file, 368 lines)
@@ -0,0 +1,368 @@
# Expense System Documentation

## Overview

The expense system is a core feature that allows users to track shared expenses, split them among group members, and manage settlements. The system supports various split types and integrates with lists, groups, and items.

## Core Components

### 1. Expenses

An expense represents a shared cost that needs to be split among multiple users.

#### Key Properties

- `id`: Unique identifier
- `description`: Description of the expense
- `total_amount`: Total cost of the expense (Decimal)
- `currency`: Currency code (defaults to "USD")
- `expense_date`: When the expense occurred
- `split_type`: How the expense should be divided
- `list_id`: Optional reference to a shopping list
- `group_id`: Optional reference to a group
- `item_id`: Optional reference to a specific item
- `paid_by_user_id`: User who paid for the expense
- `created_by_user_id`: User who created the expense record
- `version`: For optimistic locking
- `overall_settlement_status`: Overall payment status

#### Status Types

```typescript
enum ExpenseOverallStatusEnum {
  UNPAID = "unpaid",
  PARTIALLY_PAID = "partially_paid",
  PAID = "paid",
}
```

### 2. Expense Splits

Splits represent how an expense is divided among users.

#### Key Properties

- `id`: Unique identifier
- `expense_id`: Reference to the parent expense
- `user_id`: User who owes this portion
- `owed_amount`: Amount owed by the user
- `share_percentage`: Percentage share (for percentage-based splits)
- `share_units`: Number of shares (for share-based splits)
- `status`: Current payment status
- `paid_at`: When the split was paid
- `settlement_activities`: List of payment activities

#### Status Types

```typescript
enum ExpenseSplitStatusEnum {
  UNPAID = "unpaid",
  PARTIALLY_PAID = "partially_paid",
  PAID = "paid",
}
```

### 3. Settlement Activities

Settlement activities track individual payments made towards expense splits.

#### Key Properties

- `id`: Unique identifier
- `expense_split_id`: Reference to the split being paid
- `paid_by_user_id`: User making the payment
- `amount_paid`: Amount being paid
- `paid_at`: When the payment was made
- `created_by_user_id`: User who recorded the payment

## Split Types

The system supports multiple ways to split expenses:

### 1. Equal Split

- Divides the total amount equally among all participants
- Handles rounding differences by adding the remainder to the first split
- No additional data required

### 2. Exact Amounts

- Users specify exact amounts for each person
- The sum of the amounts must equal the total expense
- Requires `splits_in` data with exact amounts

### 3. Percentage Based

- Users specify percentage shares
- Percentages must sum to 100%
- Requires `splits_in` data with percentages

### 4. Share Based

- Users specify a number of shares
- The amount is divided proportionally to the shares
- Requires `splits_in` data with share units

### 5. Item Based

- Splits based on items in a shopping list
- Each item's cost is assigned to the user who added it
- Requires `list_id` and optionally `item_id`
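
To make the split arithmetic concrete, here is a minimal sketch of the equal, percentage, and share-based strategies. The function names are illustrative assumptions, not the backend's actual code; only the equal-split remainder rule is stated above, and a real implementation would need a similar remainder adjustment for the other variants.

```python
from decimal import Decimal, ROUND_HALF_UP

CENT = Decimal("0.01")

def equal_split(total: Decimal, user_ids: list[int]) -> dict[int, Decimal]:
    # Divide equally; the rounding remainder (positive or negative) goes to
    # the first split so the amounts always sum exactly to the total.
    base = (total / len(user_ids)).quantize(CENT, rounding=ROUND_HALF_UP)
    amounts = {uid: base for uid in user_ids}
    amounts[user_ids[0]] += total - base * len(user_ids)
    return amounts

def percentage_split(total: Decimal, percents: dict[int, Decimal]) -> dict[int, Decimal]:
    # percents maps user_id -> percentage share; shares must sum to 100%.
    assert sum(percents.values()) == Decimal("100")
    return {uid: (total * p / 100).quantize(CENT, rounding=ROUND_HALF_UP)
            for uid, p in percents.items()}

def share_split(total: Decimal, units: dict[int, int]) -> dict[int, Decimal]:
    # units maps user_id -> number of shares; amounts are proportional to shares.
    total_units = sum(units.values())
    return {uid: (total * n / total_units).quantize(CENT, rounding=ROUND_HALF_UP)
            for uid, n in units.items()}

assert equal_split(Decimal("20.00"), [1, 2]) == {1: Decimal("10.00"), 2: Decimal("10.00")}
```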

## Integration Points

### 1. Lists

- Expenses can be associated with shopping lists
- Item-based splits use list items to determine splits
- A list's group context can determine split participants

### 2. Groups

- Expenses can be directly associated with groups
- Group membership determines who can be included in splits
- A group context is required if no list is specified

### 3. Items

- Expenses can be linked to specific items
- Item prices are used for item-based splits
- Items must belong to a list

### 4. Users

- Users can be payers, debtors, or payment recorders
- User relationships are tracked in splits and settlements
- User context is required for all financial operations

## Key Operations

### 1. Creating Expenses

1. Validate the context (list/group)
2. Create the expense record
3. Generate splits based on the split type
4. Validate that the totals match (see the validation sketch below)
5. Save all records in a single transaction
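
Step 4 deserves emphasis: because split generation rounds, the creation flow must verify the invariant before committing. A minimal sketch, assuming each split exposes an `owed_amount` (the function name and error type are illustrative):

```python
from decimal import Decimal

def validate_split_totals(total_amount: Decimal, owed_amounts: list[Decimal]) -> None:
    # Guard for step 4: generated splits must sum exactly to the expense total,
    # so the transaction in step 5 never persists an inconsistent expense.
    split_sum = sum(owed_amounts, Decimal("0"))
    if split_sum != total_amount:
        raise ValueError(f"splits sum to {split_sum}, expected {total_amount}")
```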

### 2. Updating Expenses

- Limited to non-financial fields:
  - Description
  - Currency
  - Expense date
- Uses optimistic locking via the version field
- Splits cannot be modified after creation
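
A minimal sketch of this version-checked update flow, with stand-in exception classes (the backend raises its own `ConflictError`/`InvalidOperationError`, as seen in the tests above); field and function names are assumptions:

```python
class ConflictError(Exception):
    """Stands in for the backend's optimistic-lock conflict error."""

class InvalidOperationError(Exception):
    """Stands in for the backend's disallowed-update error."""

UPDATABLE_FIELDS = {"description", "currency", "expense_date"}

async def apply_expense_update(db, expense, update_data: dict, expected_version: int):
    # Optimistic locking: reject the update if someone else changed the row first.
    if expense.version != expected_version:
        raise ConflictError(
            f"Expense {expense.id}: expected version {expected_version}, found {expense.version}"
        )
    for name, value in update_data.items():
        if name not in UPDATABLE_FIELDS:  # financial fields are immutable after creation
            raise InvalidOperationError(f"Field '{name}' cannot be updated")
        setattr(expense, name, value)
    expense.version += 1  # bump so a concurrent writer fails the check above
    await db.commit()
```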

### 3. Recording Payments

1. Create a settlement activity
2. Update the split status
3. Recalculate the expense's overall status (a sketch of steps 2-3 follows)
4. Perform all operations in a single transaction
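
A minimal sketch of steps 2 and 3, assuming the relationships shown earlier (`settlement_activities` on a split, `splits` on an expense). Statuses are written as plain strings here instead of the enums, and overpayment (paid >= owed) is treated as fully paid, consistent with the note in the settlement-activity tests above:

```python
from decimal import Decimal
from datetime import datetime, timezone

def recompute_split_status(split) -> None:
    # Step 2: derive a split's status from the sum of its settlement activities.
    paid = sum((a.amount_paid for a in split.settlement_activities), Decimal("0"))
    if paid == 0:
        split.status, split.paid_at = "unpaid", None
    elif paid >= split.owed_amount:  # overpayment still counts as paid
        split.status = "paid"
        split.paid_at = split.paid_at or datetime.now(timezone.utc)
    else:
        split.status, split.paid_at = "partially_paid", None

def recompute_expense_status(expense) -> None:
    # Step 3: the expense is paid/unpaid only when every split agrees;
    # any mixed or partial state yields partially_paid.
    statuses = {s.status for s in expense.splits}
    if statuses == {"paid"}:
        expense.overall_settlement_status = "paid"
    elif statuses == {"unpaid"}:
        expense.overall_settlement_status = "unpaid"
    else:
        expense.overall_settlement_status = "partially_paid"
```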

### 4. Deleting Expenses

- Requires version matching
- Cascades to splits and settlements
- All operations in a single transaction

## Best Practices

1. **Data Integrity**

   - Always use transactions for multi-step operations
   - Validate that totals match before saving
   - Use optimistic locking for updates

2. **Error Handling**

   - Handle database errors appropriately
   - Validate user permissions
   - Check for concurrent modifications

3. **Performance**

   - Use appropriate indexes
   - Load relationships efficiently
   - Batch operations when possible

4. **Security**

   - Validate user permissions
   - Sanitize input data
   - Use proper access controls

## Common Use Cases

1. **Group Dinner**

   - Create an expense with the total amount
   - Use an equal split or exact amounts
   - Record payments as they occur

2. **Shopping List**

   - Create an item-based expense
   - The system automatically splits based on items
   - Track payments per person

3. **Rent Sharing**

   - Create an expense with the total rent
   - Use a percentage or share-based split
   - Record monthly payments

4. **Trip Expenses**

   - Create multiple expenses
   - Mix different split types
   - Track overall balances

## Recurring Expenses

Recurring expenses repeat at regular intervals. They are useful for regular payments like rent, utilities, or subscription services.

### Recurrence Types

1. **Daily**

   - Repeats every X days
   - Example: daily parking fee

2. **Weekly**

   - Repeats every X weeks on specific days
   - Example: weekly cleaning service

3. **Monthly**

   - Repeats every X months on the same date
   - Example: monthly rent payment

4. **Yearly**

   - Repeats every X years on the same date
   - Example: annual insurance premium

### Implementation Details

1. **Recurrence Pattern**

   ```typescript
   interface RecurrencePattern {
     type: "daily" | "weekly" | "monthly" | "yearly";
     interval: number; // Every X days/weeks/months/years
     daysOfWeek?: number[]; // For weekly recurrence (0-6, Sunday-Saturday)
     endDate?: string; // Optional end date for the recurrence
     maxOccurrences?: number; // Optional maximum number of occurrences
   }
   ```

2. **Recurring Expense Properties**

   - All standard expense properties
   - `recurrence_pattern`: Defines how the expense repeats
   - `next_occurrence`: When the next expense will be created
   - `last_occurrence`: When the last expense was created
   - `is_recurring`: Boolean flag to identify recurring expenses

3. **Generation Process**

   - The system automatically creates new expenses based on the pattern
   - Each generated expense is a regular expense with its own splits
   - The original recurring expense serves as a template
   - Generated expenses can be modified individually
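
   As a concrete illustration of the generation step, here is a minimal sketch of advancing `next_occurrence` by one interval. It mirrors the `RecurrencePattern` interface above but omits `daysOfWeek`, `endDate`, and `maxOccurrences` handling; the clamping rule for short months is an assumption:

   ```python
   from dataclasses import dataclass
   from datetime import date, timedelta
   from calendar import monthrange

   @dataclass
   class RecurrencePattern:  # mirrors the TypeScript interface above (simplified)
       type: str
       interval: int

   def next_occurrence(current: date, pattern: RecurrencePattern) -> date:
       if pattern.type == "daily":
           return current + timedelta(days=pattern.interval)
       if pattern.type == "weekly":
           return current + timedelta(weeks=pattern.interval)
       if pattern.type == "monthly":
           months = current.month - 1 + pattern.interval
           year, month = current.year + months // 12, months % 12 + 1
           day = min(current.day, monthrange(year, month)[1])  # clamp Jan 31 -> Feb 28/29
           return date(year, month, day)
       if pattern.type == "yearly":
           year = current.year + pattern.interval
           day = min(current.day, monthrange(year, current.month)[1])  # clamp Feb 29
           return date(year, current.month, day)
       raise ValueError(f"unknown recurrence type: {pattern.type}")

   # e.g. next_occurrence(date(2024, 1, 31), RecurrencePattern("monthly", 1)) -> 2024-02-29
   ```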

4. **Management Features**

   - Pause/resume recurrence
   - Modify future occurrences
   - Skip specific occurrences
   - End recurrence early
   - View all generated expenses

### Best Practices for Recurring Expenses

1. **Data Management**

   - Keep the original recurring expense as a template
   - Generate new expenses in advance
   - Clean up old generated expenses periodically

2. **User Experience**

   - Clear indication of recurring expenses
   - Easy way to modify future occurrences
   - Option to handle exceptions

3. **Performance**

   - Batch-process expense generation
   - Index recurring expense queries
   - Cache frequently accessed patterns

### Example Use Cases

1. **Monthly Rent**

   ```json
   {
     "description": "Monthly Rent",
     "total_amount": "2000.00",
     "split_type": "PERCENTAGE",
     "recurrence_pattern": {
       "type": "monthly",
       "interval": 1,
       "endDate": "2024-12-31"
     }
   }
   ```

2. **Weekly Cleaning Service**

   ```json
   {
     "description": "Weekly Cleaning",
     "total_amount": "150.00",
     "split_type": "EQUAL",
     "recurrence_pattern": {
       "type": "weekly",
       "interval": 1,
       "daysOfWeek": [1]
     }
   }
   ```

   (`"daysOfWeek": [1]` means every Monday; JSON does not allow inline comments.)

## API Considerations

1. **Decimal Handling**

   - Use a string representation for decimals in the API
   - Convert to Decimal for calculations
   - Round to 2 decimal places for money
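
   A small worked example of this convention (`ROUND_HALF_UP` is an assumption; use whichever rounding mode the backend standardizes on):

   ```python
   from decimal import Decimal, ROUND_HALF_UP

   payload = {"total_amount": "2000.00"}      # decimals travel as strings in the API
   amount = Decimal(payload["total_amount"])  # convert for exact arithmetic
   third = (amount / 3).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
   assert str(third) == "666.67"              # rounded to 2 decimal places for money
   ```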

2. **Date Handling**

   - Use ISO format for dates
   - Store in UTC
   - Convert to local time for display

3. **Status Updates**

   - Update the split status on payment
   - Recalculate the overall status
   - Notify relevant users

## Future Considerations

1. **Potential Enhancements**

   - Bulk operations
   - Advanced reporting
   - Currency conversion

2. **Scalability**

   - Handle large groups
   - Optimize for frequent updates
   - Consider caching strategies

3. **Integration**

   - Payment providers
   - Accounting systems
   - Export capabilities
docs/mitlist_doc.md (new file, 177 lines)
@@ -0,0 +1,177 @@
## Project Documentation: Shared Household Management PWA

**Version:** 1.1 (Tech Stack Update)
**Date:** 2025-04-22

### 1. Project Overview

**1.1. Concept:**
Develop a Progressive Web App (PWA) designed to streamline household coordination and shared responsibilities. The application enables users within defined groups (e.g., households, roommates, families) to collaboratively manage shopping lists, track and split expenses with historical accuracy, and manage recurring or one-off household chores.

**1.2. Goals:**

- Simplify the creation, management, and sharing of shopping lists.
- Provide an efficient way to add items via image capture and OCR (using Gemini 1.5 Flash).
- Enable transparent and traceable tracking and splitting of shared expenses related to shopping lists.
- Offer a clear system for managing and assigning recurring or single-instance household chores.
- Deliver a seamless, near-native user experience across devices through PWA technologies, including robust offline capabilities.
- Foster better communication and coordination within shared living environments.

**1.3. Target Audience:**

- Roommates sharing household expenses and chores.
- Families coordinating grocery shopping and household tasks.
- Couples managing shared finances and responsibilities.
- Groups organizing events or trips involving shared purchases.

### 2. Key Features (V1 Scope)

The Minimum Viable Product (V1) focuses on delivering the core functionalities with a high degree of polish and reliability:

- **User Authentication & Group Management (using `fastapi-users`):**
  - Secure email/password signup, login, password reset, and email verification (leveraging `fastapi-users` features).
  - Ability to create user groups (e.g., "Home", "Trip").
  - Invite members to groups via unique, shareable codes/links.
  - Basic role distinction (Owner, Member) for group administration.
  - Ability for users to view groups and leave groups.
- **Shared Shopping List Management:**
  - CRUD operations for shopping lists (Create, Read, Update, Delete).
  - Option to create personal lists or share lists with specific groups.
  - Real-time (or near real-time via polling/basic WebSocket) updates for shared lists.
  - CRUD operations for items within lists (name, quantity, notes).
  - Ability to mark items as purchased.
  - Attribution for who added/completed items in shared lists.
- **OCR Integration (Gemini 1.5 Flash):**
  - Capture images (receipts, handwritten lists) via the browser (`input capture` / `getUserMedia`).
  - Backend processing using the Google AI API (Gemini 1.5 Flash model) with tailored prompts to extract item names.
  - User review and edit screen for confirming/correcting extracted items before adding them to the list.
  - Clear progress indicators and error handling.
- **Cost Splitting (Traceable):**
  - Ability to add prices to completed items on a list, recording who added the price and when.
  - Functionality to trigger an expense calculation for a list based on items with prices.
  - Creation of immutable `ExpenseRecord` entries detailing the total amount, participants, and calculation time/user.
  - Generation of `ExpenseShare` entries detailing the amount owed per participant for each `ExpenseRecord`.
  - Ability for participants to mark their specific `ExpenseShare` as paid, logged via a `SettlementActivity` record for full traceability (see the sketch after this feature list).
  - View displaying historical expense records and their settlement status for each list.
  - V1 focuses on equal splitting among all group members associated with the list at the time of calculation.
- **Chore Management (Recurring & Assignable):**
  - CRUD operations for chores within a group context.
  - Ability to define chores as one-time or recurring (daily, weekly, monthly, custom intervals).
  - The system calculates `next_due_date` based on the frequency.
  - Manual assignment of chores (specific instances/due dates) to group members via `ChoreAssignments`.
  - Ability for assigned users to mark their specific `ChoreAssignment` as complete.
  - Automatic update of the parent chore's `last_completed_at` and recalculation of `next_due_date` upon completion of recurring chores.
  - Dedicated view for users to see their pending assigned chores ("My Chores").
- **PWA Core Functionality:**
  - Installable on user devices via `manifest.json`.
  - Offline access to cached data (lists, items, chores, basic expense info) via Service Workers and IndexedDB.
  - Background synchronization queue for actions performed offline (adding items, marking complete, adding prices, completing chores).
  - Basic conflict resolution strategy (e.g., last-write-wins with user notification) for offline data sync.
|
||||||
|
|
||||||
|
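As a rough sketch of the OCR step described above (not the project's actual implementation): the backend sends the captured image to Gemini 1.5 Flash with a tailored prompt and returns candidate item names for the review screen. The prompt text and the `extract_items` helper are illustrative, and the `google-generativeai` package is assumed.

```python
import os

import google.generativeai as genai
from PIL import Image

genai.configure(api_key=os.environ["GEMINI_API_KEY"])  # from env, never hard-coded
model = genai.GenerativeModel("gemini-1.5-flash-latest")

def extract_items(image_path: str) -> list[str]:
    """Ask Gemini to read a receipt or handwritten list and return item names."""
    prompt = (
        "Extract the shopping list items from this image. "
        "Return one item name per line with no extra commentary."
    )
    response = model.generate_content([prompt, Image.open(image_path)])
    # results are only candidates: the user review/edit screen confirms them
    return [line.strip() for line in response.text.splitlines() if line.strip()]
```
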
### 3. User Experience (UX) Philosophy

- **User-Centered & Collaborative:** Focus on intuitive workflows for both individual task management and seamless group collaboration. Minimize friction in common tasks like adding items, splitting costs, and completing chores.
- **Native-like PWA Experience:** Leverage Service Workers, caching (IndexedDB), and `manifest.json` to provide fast loading, reliable offline functionality, and installability, mimicking a native app experience.
- **Clarity & Accessibility:** Prioritize clear information hierarchy, legible typography, sufficient contrast, and adherence to WCAG accessibility standards for usability by all users. Utilize **Valerie UI** components designed with accessibility in mind.
- **Informative Feedback:** Provide immediate visual feedback for user actions (loading states, confirmations, animations). Clearly communicate offline status, sync progress, OCR processing status, and data conflicts.

### 4. Architecture & Technology Stack

- **Frontend:**
  - **Framework:** Vue.js (Vue 3 with Composition API, built with Vite).
  - **Styling & UI Components:** **Valerie UI** (as the primary component library and design system).
  - **State Management:** Pinia (official state management library for Vue).
  - **PWA:** Vite PWA plugin (leveraging Workbox.js under the hood) for Service Worker generation, manifest management, and caching strategies. IndexedDB for offline data storage.
- **Backend:**
  - **Framework:** FastAPI (Python, high-performance, async support, automatic docs).
  - **Database:** PostgreSQL (reliable relational database with JSONB support).
  - **ORM:** SQLAlchemy (version 2.0+ with native async support).
  - **Migrations:** Alembic (for managing database schema changes).
  - **Authentication & User Management:** **`fastapi-users`** (handles user models, password hashing, JWT/cookie authentication, and core auth endpoints like signup, login, password reset, email verification; see the wiring sketch after this list).
- **Cloud Services & APIs:**
  - **OCR:** Google AI API (using the `gemini-1.5-flash-latest` model).
  - **Hosting (Backend):** Containerized deployment (Docker) on cloud platforms like Google Cloud Run, AWS Fargate, or DigitalOcean App Platform.
  - **Hosting (Frontend):** Static hosting platforms like Vercel, Netlify, or Cloudflare Pages (optimized for Vite-built Vue apps).
- **DevOps & Monitoring:**
  - **Version Control:** Git (hosted on GitHub, GitLab, etc.).
  - **Containerization:** Docker & Docker Compose (for local development and deployment consistency).
  - **CI/CD:** GitHub Actions (or similar) for automated testing and deployment pipelines (using Vite build commands for the frontend).
  - **Error Tracking:** Sentry (or similar) for real-time error monitoring.
  - **Logging:** Standard Python logging configured within FastAPI.

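To indicate roughly how `fastapi-users` slots into FastAPI, here is a condensed wiring sketch following the library's documented v10+ API. The `User` model, `get_user_manager` dependency, and `UserRead`/`UserCreate` schemas are app-specific pieces assumed to exist (per the fastapi-users docs); the import paths are hypothetical.

```python
import os
import uuid

from fastapi import FastAPI
from fastapi_users import FastAPIUsers
from fastapi_users.authentication import (
    AuthenticationBackend, BearerTransport, JWTStrategy,
)

# app-defined pieces, assumed here: see the fastapi-users documentation
from app.users import User, get_user_manager
from app.schemas import UserCreate, UserRead

bearer_transport = BearerTransport(tokenUrl="auth/jwt/login")

def get_jwt_strategy() -> JWTStrategy:
    return JWTStrategy(secret=os.environ["SECRET_KEY"], lifetime_seconds=3600)

auth_backend = AuthenticationBackend(
    name="jwt", transport=bearer_transport, get_strategy=get_jwt_strategy,
)
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])

app = FastAPI()
# login/logout, registration, password reset, and email verification routers
app.include_router(fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"])
app.include_router(fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
app.include_router(fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"])
```
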
### 5. Data Model Highlights

Key database tables supporting the application's features (two are sketched in code after the list):

- `Users`: Stores user account information. The schema will align with `fastapi-users` requirements (e.g., `id`, `email`, `hashed_password`, `is_active`, `is_superuser`, `is_verified`), with potential custom fields added as needed.
- `Groups`: Defines shared groups (name, owner).
- `UserGroups`: Many-to-many relationship linking users to groups with roles (owner/member).
- `Lists`: Stores shopping list details (name, description, creator, associated group, completion status).
- `Items`: Stores individual shopping list items (name, quantity, price, completion status, list association, user attribution for adding/pricing).
- `ExpenseRecords`: Logs each instance of a cost split calculation for a list (total amount, participants, calculation time/user, overall settlement status).
- `ExpenseShares`: Details the amount owed by each participant for a specific `ExpenseRecord` (links to user and record, amount, paid status).
- `SettlementActivities`: Records every action taken to mark an `ExpenseShare` as paid (links to record, payer, affected user, timestamp).
- `Chores`: Defines chore templates (name, description, group association, recurrence rules, next due date).
- `ChoreAssignments`: Tracks specific instances of chores assigned to users (links to chore, user, due date, completion status).

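A sketch of how two of these tables might look in SQLAlchemy 2.0's typed declarative style. Column names are inferred from the descriptions above and are assumptions, not the actual schema.

```python
from datetime import datetime
from decimal import Decimal

from sqlalchemy import ForeignKey, Numeric
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class ExpenseRecord(Base):
    """One cost-split calculation for a list; immutable once created."""
    __tablename__ = "expense_records"

    id: Mapped[int] = mapped_column(primary_key=True)
    list_id: Mapped[int] = mapped_column(ForeignKey("lists.id"))
    total_amount: Mapped[Decimal] = mapped_column(Numeric(10, 2))
    calculated_at: Mapped[datetime]
    shares: Mapped[list["ExpenseShare"]] = relationship(back_populates="record")

class ExpenseShare(Base):
    """Amount owed by one participant for one ExpenseRecord."""
    __tablename__ = "expense_shares"

    id: Mapped[int] = mapped_column(primary_key=True)
    record_id: Mapped[int] = mapped_column(ForeignKey("expense_records.id"))
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
    amount_owed: Mapped[Decimal] = mapped_column(Numeric(10, 2))
    is_paid: Mapped[bool] = mapped_column(default=False)
    record: Mapped[ExpenseRecord] = relationship(back_populates="shares")
```
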
### 6. Core User Flows (Summarized)

- **Onboarding:** Signup/Login (via the `fastapi-users` flow) -> Optional guided tour -> Create/Join first group -> Dashboard.
- **List Creation & Sharing:** Create List -> Choose Personal or Share with Group -> List appears on the dashboard (and shared members' dashboards).
- **Adding Items (Manual):** Open List -> Type item name -> Item added.
- **Adding Items (OCR):** Open List -> Tap "Add via Photo" -> Capture/Select Image -> Upload/Process (Gemini) -> Review/Edit extracted items -> Confirm -> Items added to list.
- **Shopping & Price Entry:** Open List -> Check off items -> Enter price for completed items -> Price saved.
- **Cost Splitting Cycle:** View List -> Click "Calculate Split" -> Backend creates traceable `ExpenseRecord` & `ExpenseShares` -> View Expense History -> Participants mark their shares paid (creating `SettlementActivity`).
- **Chore Cycle:** Create Chore (define recurrence) -> Chore appears in group list -> (Manual Assignment) Assign chore instance to user -> User views "My Chores" -> User marks assignment complete -> Backend updates status and recalculates next due date for recurring chores.
- **Offline Usage:** Open app offline -> View cached lists/chores -> Add/complete items/chores -> Changes queued -> Go online -> Background sync processes queue -> UI updates, conflicts notified.

### 7. Development Roadmap (Phase Summary)

1. **Phase 1: Planning & Design:** User stories, flows, sharing/sync models, tech stack, architecture, schema design.
2. **Phase 2: Core App Setup:** Project initialization (Git, **Vue.js with Vite**, FastAPI), DB connection (SQLAlchemy/Alembic), basic PWA config (**Vite PWA plugin**, manifest, SW), **Valerie UI integration**, **Pinia setup**, Docker setup, CI checks.
3. **Phase 3: User Auth & Group Management:** Backend: Integrate **`fastapi-users`**, configure its routers, adapt the user model. Frontend: Implement auth pages using **Vue components** and **Pinia for auth state**, calling `fastapi-users` endpoints. Implement group management features.
4. **Phase 4: Shared Shopping List CRUD:** Backend/Frontend for List/Item CRUD, permissions, basic real-time updates (polling), offline sync refinement for lists/items.
5. **Phase 5: OCR Integration (Gemini Flash):** Backend integration with the Google AI SDK, image capture/upload UI, OCR processing endpoint, review/edit screen, integration with list items.
6. **Phase 6: Cost Splitting (Traceable):** Backend/Frontend for adding prices, calculating splits (creating historical records), viewing expense history, marking shares paid (with activity logging).
7. **Phase 7: Chore Management Module:** Backend/Frontend for Chore CRUD (including recurrence), manual assignment, completion tracking, "My Chores" view, recurrence handling logic.
8. **Phase 8: Testing, Refinement & Beta Launch:** Comprehensive E2E testing, usability testing, accessibility checks, performance tuning, deployment to a beta environment, feedback collection.
9. **Phase 9: Final Release & Post-Launch Monitoring:** Address beta feedback, final deployment to production, set up monitoring (errors, performance, costs).

_(Estimated Total Duration: Approx. 17-19 Weeks for V1)_

### 8. Risk Management & Mitigation

- **Collaboration Complexity:** (Risk) Permissions and real-time sync can be complex. (Mitigation) Start simple, test permissions thoroughly, use clear data models.
- **OCR Accuracy/Cost (Gemini):** (Risk) OCR isn't perfect; API calls have costs/quotas. (Mitigation) Use a capable model (Gemini Flash), a mandatory user review step, clear error feedback; monitor API usage/costs; secure API keys.
- **Offline Sync Conflicts:** (Risk) Concurrent offline edits can clash. (Mitigation) Implement the defined strategy (last-write-wins + notify), robust queue processing, thorough testing of conflict scenarios.
- **PWA Consistency:** (Risk) Behavior varies across browsers/OS (esp. iOS). (Mitigation) Rigorous cross-platform testing, use standard tools (Vite PWA plugin/Workbox), follow best practices.
- **Traceability Overhead:** (Risk) Storing detailed history increases DB size/complexity. (Mitigation) Design efficient queries, use appropriate indexing, plan for potential data archiving later.
- **User Adoption:** (Risk) Users might not consistently use groups/features. (Mitigation) Smooth onboarding, clear value proposition, reliable core features.
- **Valerie UI Maturity/Flexibility:** (Risk, if "Valerie UI" is niche or custom) Potential limitations in component availability or customization. (Mitigation) Thoroughly evaluate Valerie UI early, have fallback styling strategies if needed, or contribute to/extend the library.

### 9. Testing Strategy

- **Unit Tests:** Backend logic (calculations, permissions, recurrence); frontend component logic (**Vue Test Utils** for Vue components, Pinia store testing).
- **Integration Tests:** Backend API endpoints interacting with the DB and external APIs (Gemini, mocked); a minimal example follows this list.
- **End-to-End (E2E) Tests:** (Playwright/Cypress) Simulate full user flows across features.
- **PWA Testing:** Manual and automated checks for installability, offline functionality (caching, sync queue), cross-browser/OS compatibility.
- **Accessibility Testing:** Automated tools (axe-core) plus manual checks (keyboard nav, screen readers), leveraging **Valerie UI's** accessibility features.
- **Usability Testing:** Regular sessions with target users throughout development.
- **Security Testing:** Basic checks (OWASP Top 10 awareness), dependency scanning, secure handling of secrets/tokens (relying on `fastapi-users` security practices).
- **Manual Testing:** Exploratory testing, edge case validation, testing diverse OCR inputs.

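For flavor, a minimal backend integration test in the style implied above: pytest plus httpx calling the FastAPI app in-process (assuming the anyio pytest plugin). The endpoint path, payload, and `app.main` import are hypothetical.

```python
import pytest
from httpx import ASGITransport, AsyncClient

from app.main import app  # hypothetical import path for the FastAPI instance

@pytest.mark.anyio
async def test_create_list_requires_auth() -> None:
    transport = ASGITransport(app=app)  # no network: calls the ASGI app directly
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.post("/api/lists", json={"name": "Groceries"})
        assert resp.status_code == 401  # unauthenticated requests are rejected
```
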
### 10. Future Enhancements (Post-V1)

- Advanced Cost Splitting (by item, percentage, unequal splits).
- Payment Integration (Stripe Connect for settling debts).
- Real-time Collaboration (WebSockets for instant updates).
- Push Notifications (reminders for chores, expenses, list updates).
- Advanced Chore Features (assignment algorithms, calendar view).
- Enhanced OCR (handling more formats, potential fine-tuning).
- User Profile Customization (avatars, etc., extending the `fastapi-users` model).
- Analytics Dashboard (spending insights, chore completion stats).
- Recipe Integration / Pantry Inventory Tracking.

### 11. Conclusion

This project aims to deliver a modern, user-friendly PWA that effectively addresses common household coordination challenges. By combining collaborative list management, intelligent OCR, traceable expense splitting, and flexible chore tracking with a robust offline-first PWA architecture built on **Vue.js, Pinia, Valerie UI, and FastAPI with `fastapi-users`**, the application will provide significant value to roommates, families, and other shared living groups. The focus on a well-defined V1, traceable data, and a solid technical foundation sets the stage for future growth and feature expansion.

env.production.template (new file, 46 lines)
@@ -0,0 +1,46 @@

# Production Environment Variables Template
# Copy this file to .env.production and fill in the actual values
# NEVER commit the actual .env.production file to version control

# Database Configuration
POSTGRES_USER=mitlist_user
POSTGRES_PASSWORD=your_secure_database_password_here
POSTGRES_DB=mitlist_prod
DATABASE_URL=postgresql+asyncpg://mitlist_user:your_secure_database_password_here@db:5432/mitlist_prod

# Security Keys (Generate with: openssl rand -hex 32)
SECRET_KEY=your_secret_key_here_minimum_32_characters_long
SESSION_SECRET_KEY=your_session_secret_key_here_minimum_32_characters_long

# API Keys
GEMINI_API_KEY=your_gemini_api_key_here

# Redis Configuration
REDIS_PASSWORD=your_redis_password_here

# Sentry Configuration (Optional but recommended)
SENTRY_DSN=your_sentry_dsn_here

# CORS Configuration
CORS_ORIGINS=https://yourdomain.com,https://www.yourdomain.com
FRONTEND_URL=https://yourdomain.com

# Frontend Build Variables
VITE_API_URL=https://yourdomain.com/api
VITE_SENTRY_DSN=your_frontend_sentry_dsn_here
VITE_ROUTER_MODE=history

# Google OAuth Configuration - Replace with your actual credentials
GOOGLE_CLIENT_ID="YOUR_GOOGLE_CLIENT_ID_HERE"
GOOGLE_CLIENT_SECRET="YOUR_GOOGLE_CLIENT_SECRET_HERE"
GOOGLE_REDIRECT_URI=https://yourdomain.com/auth/google/callback

APPLE_CLIENT_ID=your_apple_client_id
APPLE_TEAM_ID=your_apple_team_id
APPLE_KEY_ID=your_apple_key_id
APPLE_PRIVATE_KEY=your_apple_private_key
APPLE_REDIRECT_URI=https://yourdomain.com/auth/apple/callback

# Production Settings
ENVIRONMENT=production
LOG_LEVEL=INFO

fe/.editorconfig (new file, 9 lines)
@@ -0,0 +1,9 @@

[*.{js,jsx,mjs,cjs,ts,tsx,mts,cts,vue,css,scss,sass,less,styl}]
charset = utf-8
indent_size = 2
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true

end_of_line = lf
max_line_length = 100

fe/.gitattributes (vendored, new file, 1 line)
@@ -0,0 +1 @@

* text=auto eol=lf

fe/.gitignore (vendored)
@@ -1,23 +1,38 @@

-node_modules
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
 
-# Output
-.output
-.vercel
-.netlify
-.wrangler
-/.svelte-kit
-/build
+**/node_modules/
 
-# OS
 .DS_Store
-Thumbs.db
+dist
+dist-ssr
+coverage
+*.local
+
+/cypress/videos/
+/cypress/screenshots/
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
+
+*.tsbuildinfo
+*.sw.js
+
+test-results/
+playwright-report/
 
-# Env
 .env
-.env.*
-!.env.example
-!.env.test
+*storybook.log
+storybook-static
 
-# Vite
-vite.config.js.timestamp-*
-vite.config.ts.timestamp-*