# Source listing metadata: 312 lines, 13 KiB, Python.
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy import func as sql_func # For get_list_status
|
|
from datetime import datetime, timezone
|
|
|
|
from app.crud.list import (
|
|
create_list,
|
|
get_lists_for_user,
|
|
get_list_by_id,
|
|
update_list,
|
|
delete_list,
|
|
check_list_permission,
|
|
get_list_status
|
|
)
|
|
from app.schemas.list import ListCreate, ListUpdate, ListStatus
|
|
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
|
|
from app.core.exceptions import (
|
|
ListNotFoundError,
|
|
ListPermissionError,
|
|
ListCreatorRequiredError,
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
ConflictError
|
|
)
|
|
|
|
# Fixtures
@pytest.fixture
def mock_db_session():
    """Return an ``AsyncMock`` standing in for an SQLAlchemy ``AsyncSession``.

    Coroutine methods of ``AsyncSession`` (``execute``, ``get``, ``flush``,
    ``commit``, ``rollback``, ``refresh``, ``delete``) are ``AsyncMock`` so the
    code under test can ``await`` them. Plain synchronous methods (``add``,
    ``add_all``, ``expunge``, ``in_transaction``, ``begin``, ``begin_nested``)
    are ``MagicMock`` so they return immediately instead of yielding a
    coroutine that nobody awaits.
    """
    session = AsyncMock()  # Overall session mock

    # begin()/begin_nested() are synchronous calls whose return value is used
    # as an async context manager (``async with session.begin(): ...``);
    # AsyncMock implements __aenter__/__aexit__. Separate objects let a test
    # distinguish the outer transaction from a nested one.
    session.begin = MagicMock(return_value=AsyncMock())
    session.begin_nested = MagicMock(return_value=AsyncMock())

    # Coroutine methods on AsyncSession.
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    # AsyncSession.delete is a coroutine method; with a MagicMock here,
    # ``await db.delete(obj)`` would raise TypeError. Call assertions
    # (assert_called_once_with) still work on an AsyncMock.
    session.delete = AsyncMock()

    # Synchronous methods on AsyncSession. Set explicitly, because attributes
    # auto-created on an AsyncMock are themselves AsyncMocks.
    session.add = MagicMock()
    session.add_all = MagicMock()
    session.expunge = MagicMock()
    session.in_transaction = MagicMock(return_value=False)

    return session


@pytest.fixture
def list_create_data():
    """Creation payload for a new shopping list."""
    payload = ListCreate(
        name="New Shopping List",
        description="Groceries for the week",
    )
    return payload


@pytest.fixture
def list_update_data():
    """Update payload that renames the list and targets version 1."""
    payload = ListUpdate(
        name="Updated Shopping List",
        description="Weekend Groceries",
        version=1,
    )
    return payload


@pytest.fixture
def user_model():
    """Primary test user (id=1); the list fixtures use it as their creator."""
    user = UserModel(
        id=1,
        name="Test User",
        email="test@example.com",
    )
    return user


@pytest.fixture
def another_user_model():
    """Second test user (id=2), distinct from the list creator."""
    other = UserModel(
        id=2,
        name="Another User",
        email="another@example.com",
    )
    return other


@pytest.fixture
def group_model():
    """Test group (id=1) that owns the group list fixture."""
    group = GroupModel(id=1, name="Test Group")
    return group


@pytest.fixture
def db_list_personal_model(user_model):
    """Persisted-looking list (id=1) created by ``user_model``, with no group.

    Starts at version 1 with an empty ``items`` collection and a UTC
    ``updated_at`` timestamp.
    """
    timestamp = datetime.now(timezone.utc)
    personal_list = ListModel(
        id=1,
        name="Personal List",
        created_by_id=user_model.id,
        creator=user_model,
        version=1,
        updated_at=timestamp,
        items=[],
    )
    return personal_list


@pytest.fixture
def db_list_group_model(user_model, group_model):
    """Persisted-looking list (id=2) created by ``user_model`` in ``group_model``.

    Starts at version 1 with an empty ``items`` collection and a UTC
    ``updated_at`` timestamp.
    """
    timestamp = datetime.now(timezone.utc)
    group_list = ListModel(
        id=2,
        name="Group List",
        created_by_id=user_model.id,
        creator=user_model,
        group_id=group_model.id,
        group=group_model,
        version=1,
        updated_at=timestamp,
        items=[],
    )
    return group_list


# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
    """create_list adds and flushes a new list and returns it with the requested fields."""

    async def mock_refresh(instance):
        # Simulate the database populating server-side columns on refresh.
        instance.id = 100
        instance.version = 1
        instance.updated_at = datetime.now(timezone.utc)

    mock_db_session.refresh.side_effect = mock_refresh

    reloaded_list = ListModel(
        id=100,
        name=list_create_data.name,
        description=list_create_data.description,
        created_by_id=user_model.id,
        version=1,
        updated_at=datetime.now(timezone.utc),
    )

    # ``await session.execute(...)`` yields a synchronous Result object, so it
    # must be a MagicMock. Methods on an AsyncMock return coroutines, which the
    # code under test would never await.
    mock_result = MagicMock()
    # NOTE(review): which Result accessor create_list uses is not visible here;
    # configure the common ones so the test does not hinge on that detail.
    mock_result.scalar_one_or_none.return_value = reloaded_list
    mock_result.scalar_one.return_value = reloaded_list
    mock_result.scalars.return_value.first.return_value = reloaded_list
    mock_db_session.execute.return_value = mock_result

    created_list = await create_list(mock_db_session, list_create_data, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_list.name == list_create_data.name
    assert created_list.created_by_id == user_model.id


# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
    """get_lists_for_user returns both the user's personal list and the group list.

    The mocked session answers two queries in order: first the user's group
    ids, then the lists themselves.
    """
    # First execute(): result.scalars().all() -> [group id]
    group_ids_result = MagicMock()
    group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]

    # Second execute(): result.scalars().all() -> both lists
    lists_result = MagicMock()
    lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]

    mock_db_session.execute.side_effect = [group_ids_result, lists_result]

    lists = await get_lists_for_user(mock_db_session, user_model.id)

    assert len(lists) == 2
    assert db_list_personal_model in lists
    assert db_list_group_model in lists
    assert mock_db_session.execute.call_count == 2


# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
    """get_list_by_id returns the matching list when items are not requested."""
    # await session.execute(...) -> result; result.scalars().first() -> the list
    execute_result = MagicMock()
    execute_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = execute_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)

    assert found_list is not None
    assert found_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
    """get_list_by_id with load_items=True returns the list together with its items."""
    db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]

    # await session.execute(...) -> result; result.scalars().first() -> the list
    execute_result = MagicMock()
    execute_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = execute_result

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)

    assert found_list is not None
    assert len(found_list.items) == 1


# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
    """update_list applies the payload and increments the list's version by one."""
    list_update_data.version = db_list_personal_model.version
    # Capture before the call. The mocked query below returns this very
    # instance, so after update_list runs, ``db_list_personal_model.version``
    # is the already-incremented value. Comparing against it + 1 could never
    # pass.
    original_version = db_list_personal_model.version

    # ``await session.execute(...)`` yields a synchronous Result object, so it
    # must be a MagicMock. An AsyncMock would make scalar_one_or_none() return
    # a coroutine.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = db_list_personal_model
    mock_result.scalar_one.return_value = db_list_personal_model
    mock_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = mock_result

    updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)

    assert updated_list.name == list_update_data.name
    assert updated_list.version == original_version + 1
    mock_db_session.add.assert_called_once_with(db_list_personal_model)
    mock_db_session.flush.assert_called_once()


@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
    """A version mismatch in the payload raises ConflictError and triggers a rollback."""
    mismatched_version = db_list_personal_model.version + 1
    list_update_data.version = mismatched_version

    with pytest.raises(ConflictError):
        await update_list(mock_db_session, db_list_personal_model, list_update_data)

    mock_db_session.rollback.assert_called_once()


# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
    """delete_list passes the list instance to session.delete exactly once."""
    target = db_list_personal_model

    await delete_list(mock_db_session, target)

    mock_db_session.delete.assert_called_once_with(target)


# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
    """The creator of a personal list is granted access and receives the list."""
    # await session.execute(...) -> result; result.scalars().first() -> the list
    lookup_result = MagicMock()
    lookup_result.scalars.return_value.first.return_value = db_list_personal_model
    mock_db_session.execute.return_value = lookup_result

    ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)

    assert ret_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
    """A group member who did not create the group list is still granted access."""
    # await session.execute(...) -> result; result.scalars().first() -> the list
    mock_execute_result = MagicMock()
    mock_execute_result.scalars.return_value.first.return_value = db_list_group_model
    mock_db_session.execute.return_value = mock_execute_result

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True
        ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)

    assert ret_list.id == db_list_group_model.id
    # Check that the membership check was *awaited*, not merely called. An
    # un-awaited call yields a truthy coroutine, which would grant access
    # regardless of actual membership, yet still satisfy assert_called_*.
    mock_is_member.assert_awaited_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
    """A user who is neither the creator nor a group member gets ListPermissionError."""
    # await session.execute(...) -> result; result.scalars().first() -> the list
    lookup_result = MagicMock()
    lookup_result.scalars.return_value.first.return_value = db_list_group_model
    mock_db_session.execute.return_value = lookup_result

    # The membership check is stubbed to report "not a member".
    with patch('app.crud.list.is_user_member', new_callable=AsyncMock, return_value=False), \
            pytest.raises(ListPermissionError):
        await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
    """check_list_permission raises ListNotFoundError when the lookup finds nothing."""
    # await session.execute(...) -> result; result.scalars().first() -> None
    empty_result = MagicMock()
    empty_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = empty_result

    missing_list_id = 999
    with pytest.raises(ListNotFoundError):
        await check_list_permission(mock_db_session, missing_list_id, user_model.id)


# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
    """get_list_status returns a ListStatus for an existing list.

    NOTE(review): the result mock is configured for the
    ``result.scalars().first()`` access pattern used elsewhere in this module.
    If get_list_status issues several queries or reads other row attributes
    (e.g. ``latest_item_updated_at``), give ``execute`` a ``side_effect`` list
    with one result mock per query. Confirm against the implementation.
    """
    mock_list_scalar_result = MagicMock()
    mock_list_scalar_result.first.return_value = db_list_personal_model
    mock_list_execute_result = MagicMock()
    mock_list_execute_result.scalars.return_value = mock_list_scalar_result
    # Every awaited execute() call returns this same synchronous result mock.
    mock_db_session.execute.return_value = mock_list_execute_result

    # Do not patch sql_func. execute() is mocked, so the query is built from
    # real SQLAlchemy constructs and never sent anywhere. Stubbing
    # sql_func.max to return a plain string would break query construction
    # (for example a ``.label(...)`` call on a str) before execute() is reached.
    status = await get_list_status(mock_db_session, db_list_personal_model.id)

    assert isinstance(status, ListStatus)


@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
    """get_list_status raises ListNotFoundError for an unknown list id."""
    # await session.execute(...) -> result; result.scalars().first() -> None
    empty_result = MagicMock()
    empty_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = empty_result

    missing_list_id = 999
    with pytest.raises(ListNotFoundError):
        await get_list_status(mock_db_session, missing_list_id)


# TODO: Add DB error tests (OperationalError, IntegrityError, generic SQLAlchemyError) for each CRUD function.
# TODO: Test check_list_permission with require_creator=True.