"""Unit tests for the list CRUD helpers in ``app.crud.list``.

Every test runs against a mocked ``AsyncSession`` (see ``mock_db_session``);
no real database is touched. Query results are faked by configuring the object
returned from ``await session.execute(...)``.
"""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy import func as sql_func  # For get_list_status
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select

from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
)
from app.crud.list import (
    create_list,
    get_lists_for_user,
    get_list_by_id,
    update_list,
    delete_list,
    check_list_permission,
    get_list_status,
)
from app.models import (
    List as ListModel,
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus


# --- Helpers ---


def _scalars_result(**scalar_method_returns):
    """Build a fake ``Result`` as returned by ``await session.execute(...)``.

    Each keyword names a method on the object returned by ``result.scalars()``
    (e.g. ``first``, ``all``, ``one``) and the value that method should return.
    ``Result`` and ``ScalarResult`` methods are synchronous, so both layers are
    plain ``MagicMock``s.
    """
    scalars = MagicMock()
    for method_name, value in scalar_method_returns.items():
        getattr(scalars, method_name).return_value = value
    result = MagicMock()
    result.scalars.return_value = scalars
    return result


# --- Fixtures ---


@pytest.fixture
def mock_db_session():
    """Return an ``AsyncMock`` standing in for an SQLAlchemy ``AsyncSession``.

    Coroutine methods (``commit``, ``rollback``, ``refresh``, ``execute``,
    ``get``, ``flush``, ``delete``) are ``AsyncMock``s so they can be awaited.
    Synchronous methods (``add``, ``in_transaction``, ``begin``,
    ``begin_nested``) are ``MagicMock``s.
    """
    session = AsyncMock()

    # begin()/begin_nested() are called synchronously and their result is used
    # with ``async with``. An AsyncMock supports the async context-manager
    # protocol, and its __aexit__ returns False, so exceptions still propagate.
    transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=transaction_context)
    session.begin_nested = MagicMock(return_value=transaction_context)

    # Awaitable session methods.
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    # AsyncSession.delete() is a coroutine (unlike Session.delete()), so the
    # mock must be awaitable as well.
    session.delete = AsyncMock()

    # Synchronous session methods.
    session.add = MagicMock()
    session.in_transaction = MagicMock(return_value=False)
    return session


@pytest.fixture
def list_create_data():
    return ListCreate(name="New Shopping List", description="Groceries for the week")


@pytest.fixture
def list_update_data():
    return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)


@pytest.fixture
def user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com")


@pytest.fixture
def group_model():
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def db_list_personal_model(user_model):
    """A list owned by ``user_model`` with no group."""
    return ListModel(
        id=1,
        name="Personal List",
        created_by_id=user_model.id,
        creator=user_model,
        version=1,
        updated_at=datetime.now(timezone.utc),
        items=[],
    )


@pytest.fixture
def db_list_group_model(user_model, group_model):
    """A list created by ``user_model`` and shared with ``group_model``."""
    return ListModel(
        id=2,
        name="Group List",
        created_by_id=user_model.id,
        creator=user_model,
        group_id=group_model.id,
        group=group_model,
        version=1,
        updated_at=datetime.now(timezone.utc),
        items=[],
    )


# --- create_list Tests ---


@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
    """create_list adds and flushes a new list owned by the requesting user."""

    async def fake_refresh(instance):
        # Emulate the database populating server-side columns on refresh.
        instance.id = 100
        instance.version = 1
        instance.updated_at = datetime.now(timezone.utc)

    mock_db_session.refresh.side_effect = fake_refresh

    # Result.scalar_one_or_none() is synchronous, so the Result stand-in must be
    # a MagicMock; an AsyncMock would hand back a coroutine instead of the row.
    execute_result = MagicMock()
    execute_result.scalar_one_or_none.return_value = ListModel(
        id=100,
        name=list_create_data.name,
        description=list_create_data.description,
        created_by_id=user_model.id,
        version=1,
        updated_at=datetime.now(timezone.utc),
    )
    mock_db_session.execute.return_value = execute_result

    created_list = await create_list(mock_db_session, list_create_data, user_model.id)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_list.name == list_create_data.name
    assert created_list.created_by_id == user_model.id


# --- get_lists_for_user Tests ---


@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
    """Both personal lists and lists from the user's groups are returned."""
    mock_db_session.execute.side_effect = [
        # First query: ids of the groups the user belongs to.
        _scalars_result(all=[db_list_group_model.group_id]),
        # Second query: the lists themselves.
        _scalars_result(all=[db_list_personal_model, db_list_group_model]),
    ]

    lists = await get_lists_for_user(mock_db_session, user_model.id)

    assert len(lists) == 2
    assert db_list_personal_model in lists
    assert db_list_group_model in lists
    assert mock_db_session.execute.call_count == 2


# --- get_list_by_id Tests ---


@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
    mock_db_session.execute.return_value = _scalars_result(first=db_list_personal_model)

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)

    assert found_list is not None
    assert found_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
    db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
    mock_db_session.execute.return_value = _scalars_result(first=db_list_personal_model)

    found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)

    assert found_list is not None
    assert len(found_list.items) == 1


# --- update_list Tests ---


@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
    """A matching version applies the update and bumps the version by one."""
    list_update_data.version = db_list_personal_model.version

    # Synchronous Result API -> MagicMock (see test_create_list_success).
    execute_result = MagicMock()
    execute_result.scalar_one_or_none.return_value = db_list_personal_model
    mock_db_session.execute.return_value = execute_result

    updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)

    assert updated_list.name == list_update_data.name
    assert updated_list.version == db_list_personal_model.version + 1
    mock_db_session.add.assert_called_once_with(db_list_personal_model)
    mock_db_session.flush.assert_called_once()


@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
    """A stale version number makes update_list raise ConflictError."""
    list_update_data.version = db_list_personal_model.version + 1  # Simulate version mismatch

    with pytest.raises(ConflictError):
        await update_list(mock_db_session, db_list_personal_model, list_update_data)

    # NOTE(review): rollback is deliberately not asserted. Whether update_list
    # rolls back itself or leaves that to its caller is not pinned down here.


# --- delete_list Tests ---


@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
    await delete_list(mock_db_session, db_list_personal_model)

    mock_db_session.delete.assert_called_once_with(db_list_personal_model)
    # NOTE(review): a flush after delete is not asserted; confirm whether
    # delete_list is expected to flush before tightening this test.


# --- check_list_permission Tests ---


@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
    mock_db_session.execute.return_value = _scalars_result(first=db_list_personal_model)

    ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)

    assert ret_list.id == db_list_personal_model.id


@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
    mock_db_session.execute.return_value = _scalars_result(first=db_list_group_model)

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = True

        ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)

        assert ret_list.id == db_list_group_model.id
        mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
    mock_db_session.execute.return_value = _scalars_result(first=db_list_group_model)

    with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
        mock_is_member.return_value = False

        with pytest.raises(ListPermissionError):
            await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)


@pytest.mark.asyncio
async def test_check_list_permission_creator_required_fail(mock_db_session, db_list_group_model, another_user_model):
    """require_creator=True rejects a user who did not create the list."""
    # db_list_group_model.created_by_id is user_model.id (1); another_user_model.id is 2.
    mock_db_session.execute.return_value = _scalars_result(first=db_list_group_model)

    # is_user_member is not patched: the creator check is expected to fail first.
    with pytest.raises(ListCreatorRequiredError):
        await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id, require_creator=True)


@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
    mock_db_session.execute.return_value = _scalars_result(first=None)  # Simulate list not found

    with pytest.raises(ListNotFoundError):
        await check_list_permission(mock_db_session, 999, user_model.id)


# --- get_list_status Tests ---


@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
    # NOTE(review): assumes get_list_status issues three queries in this order
    # and reads the counts via ``.scalars().one()``. If it uses
    # ``.scalar_one()`` instead, configure that method on the Result directly.
    mock_db_session.execute.side_effect = [
        _scalars_result(first=db_list_personal_model),  # list lookup
        _scalars_result(one=5),  # total item count
        _scalars_result(one=2),  # completed item count
    ]

    status = await get_list_status(mock_db_session, db_list_personal_model.id)

    assert status.list_id == db_list_personal_model.id
    assert status.total_items == 5
    assert status.completed_items == 2
    assert status.name == db_list_personal_model.name
    assert mock_db_session.execute.call_count == 3


@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
    mock_db_session.execute.return_value = _scalars_result(first=None)  # List not found

    with pytest.raises(ListNotFoundError):
        await get_list_status(mock_db_session, 999)


# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
# TODO: Add the require_creator=True success case (creator accessing their own list);
#       the non-creator failure case is covered above.