184 lines
7.2 KiB
Python
184 lines
7.2 KiB
Python
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
|
|
from datetime import datetime, timezone
|
|
|
|
from app.crud.item import (
|
|
create_item,
|
|
get_items_by_list_id,
|
|
get_item_by_id,
|
|
update_item,
|
|
delete_item
|
|
)
|
|
from app.schemas.item import ItemCreate, ItemUpdate
|
|
from app.models import Item as ItemModel, User as UserModel, List as ListModel
|
|
from app.core.exceptions import (
|
|
ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
ConflictError
|
|
)
|
|
|
|
# Fixtures
|
|
@pytest.fixture
def mock_db_session():
    """Return a mock standing in for an SQLAlchemy ``AsyncSession``.

    Each attribute uses the sync/async flavour of the matching ``AsyncSession``
    method so the CRUD code can run against it without a database:

    * ``begin()`` is a plain call that returns an async context manager
      (``async with db.begin(): ...``). It must be a ``MagicMock``. An
      ``AsyncMock`` would return a coroutine, which cannot be used with
      ``async with``.
    * ``add()`` is synchronous.
    * ``commit``/``rollback``/``refresh``/``flush``/``execute``/``get``/``delete``
      are coroutines and are therefore ``AsyncMock``s.
    """
    session = AsyncMock()
    # MagicMock provides __aenter__/__aexit__ out of the box. Its __aexit__
    # returns False, so exceptions raised inside the block still propagate.
    session.begin = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete is a coroutine; a sync MagicMock cannot be awaited.
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()  # Not directly used by item CRUD; kept for consistency
    session.flush = AsyncMock()
    return session
|
|
|
|
@pytest.fixture
def item_create_data():
    """Creation payload for a new item named "Test Item" (quantity "1 pack")."""
    payload = ItemCreate(name="Test Item", quantity="1 pack")
    return payload
|
|
|
|
@pytest.fixture
def item_update_data():
    """Update payload that targets version 1 and leaves the item incomplete.

    Individual tests mutate this object to shape the update they exercise.
    """
    fields = {
        "name": "Updated Test Item",
        "quantity": "2 packs",
        "version": 1,
        "is_complete": False,
    }
    return ItemUpdate(**fields)
|
|
|
|
@pytest.fixture
def user_model():
    """In-memory ``User`` instance with id 1 (never added to a session)."""
    user = UserModel(id=1, name="Test User", email="test@example.com")
    return user
|
|
|
|
@pytest.fixture
def list_model():
    """In-memory ``List`` instance with id 1 (never added to a session)."""
    shopping_list = ListModel(id=1, name="Test List")
    return shopping_list
|
|
|
|
@pytest.fixture
def db_item_model(list_model, user_model):
    """Existing ``Item`` (id 1, version 1, not complete) on ``list_model``.

    The item is attributed to ``user_model`` via ``added_by_id``. Both
    timestamps come from a single clock reading, so the row looks freshly
    created (``created_at == updated_at``) rather than differing by the few
    microseconds between two ``datetime.now`` calls.
    """
    now = datetime.now(timezone.utc)
    return ItemModel(
        id=1,
        name="Existing Item",
        quantity="1 unit",
        list_id=list_model.id,
        added_by_id=user_model.id,
        is_complete=False,
        version=1,
        created_at=now,
        updated_at=now,
    )
|
|
|
|
# --- create_item Tests ---
|
|
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
    """create_item adds, flushes and refreshes a new item built from the payload."""

    async def fake_refresh(obj):
        # Stand-in for the DB-populated columns that a real refresh would load.
        obj.id = 10
        obj.version = 1
        obj.created_at = datetime.now(timezone.utc)
        obj.updated_at = datetime.now(timezone.utc)

    mock_db_session.refresh = AsyncMock(side_effect=fake_refresh)

    new_item = await create_item(
        mock_db_session, item_create_data, list_model.id, user_model.id
    )

    # Session interaction.
    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(new_item)

    # Resulting item state.
    assert new_item is not None
    assert new_item.name == item_create_data.name
    assert new_item.list_id == list_model.id
    assert new_item.added_by_id == user_model.id
    assert new_item.is_complete is False
    assert new_item.version == 1
|
|
|
|
@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
    """An IntegrityError during flush surfaces as DatabaseIntegrityError and rolls back."""
    integrity_failure = IntegrityError("mock integrity error", "params", "orig")
    mock_db_session.flush.side_effect = integrity_failure

    with pytest.raises(DatabaseIntegrityError):
        await create_item(
            mock_db_session, item_create_data, list_model.id, user_model.id
        )

    mock_db_session.rollback.assert_called_once()
|
|
|
|
# --- get_items_by_list_id Tests ---
|
|
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
    """get_items_by_list_id returns whatever ``result.scalars().all()`` yields."""
    # ``Result.scalars()`` and ``ScalarResult.all()`` are synchronous, so the
    # result object must be a MagicMock. With an AsyncMock, ``scalars()`` would
    # return an un-awaited coroutine and ``.all()`` would raise AttributeError.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_item_model]
    mock_db_session.execute.return_value = mock_result

    items = await get_items_by_list_id(mock_db_session, list_model.id)

    assert len(items) == 1
    assert items[0].id == db_item_model.id
    mock_db_session.execute.assert_called_once()
|
|
|
|
# --- get_item_by_id Tests ---
|
|
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
    """get_item_by_id returns the first scalar row when one exists."""
    # The Result API (``scalars()``/``first()``) is synchronous: use MagicMock,
    # not AsyncMock, or ``scalars()`` returns a coroutine without ``.first``.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_item_model
    mock_db_session.execute.return_value = mock_result

    item = await get_item_by_id(mock_db_session, db_item_model.id)

    assert item is not None
    assert item.id == db_item_model.id
|
|
|
|
@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
    """get_item_by_id returns None when the query yields no row."""
    # Synchronous Result API -> MagicMock (an AsyncMock's ``scalars()`` would
    # return a coroutine and ``.first()`` would raise AttributeError).
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    item = await get_item_by_id(mock_db_session, 999)

    assert item is None
|
|
|
|
# --- update_item Tests ---
|
|
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
    """A matching version applies the update, bumps the version and records the completer."""
    original_version = db_item_model.version
    item_update_data.version = original_version  # Match versions for a successful update
    item_update_data.name = "Newly Updated Name"
    item_update_data.is_complete = True  # Exercise the completion logic

    updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)

    mock_db_session.add.assert_called_once_with(db_item_model)  # add() is also used for existing objects
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_item_model)
    assert updated_item.name == "Newly Updated Name"
    # update_item works on ``db_item_model`` itself (it is the object passed to
    # add() and refresh()). Comparing against ``db_item_model.version`` after
    # the call would therefore always pass. Compare with the pre-update value.
    assert updated_item.version == original_version + 1
    assert updated_item.is_complete is True
    assert updated_item.completed_by_id == user_model.id
|
|
|
|
@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
    """A version that differs from the stored one is rejected with ConflictError."""
    mismatched_version = db_item_model.version + 1
    item_update_data.version = mismatched_version

    with pytest.raises(ConflictError):
        await update_item(
            mock_db_session, db_item_model, item_update_data, user_model.id
        )

    mock_db_session.rollback.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
    """Un-completing an item clears ``completed_by_id`` and bumps the version."""
    # Arrange: the stored item is currently complete, at version 1.
    db_item_model.is_complete = True
    db_item_model.completed_by_id = user_model.id
    db_item_model.version = 1

    # The update only flips completion; name and quantity are left as-is.
    item_update_data.version = 1
    item_update_data.is_complete = False
    item_update_data.name = db_item_model.name
    item_update_data.quantity = db_item_model.quantity

    result = await update_item(
        mock_db_session, db_item_model, item_update_data, user_model.id
    )

    assert result.is_complete is False
    assert result.completed_by_id is None
    assert result.version == 2
|
|
|
|
# --- delete_item Tests ---
|
|
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
    """delete_item deletes the given item inside a transaction and returns None."""
    result = await delete_item(mock_db_session, db_item_model)

    assert result is None
    mock_db_session.delete.assert_called_once_with(db_item_model)
    # The commit happens implicitly when the ``async with db.begin()`` block
    # exits. A mocked transaction context never forwards that to
    # ``session.commit``, so asserting on ``commit`` cannot pass. Assert that
    # the transaction context was opened instead.
    mock_db_session.begin.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
    """An OperationalError while deleting maps to DatabaseConnectionError and rolls back."""
    connection_failure = OperationalError("mock op error", "params", "orig")
    mock_db_session.delete.side_effect = connection_failure

    with pytest.raises(DatabaseConnectionError):
        await delete_item(mock_db_session, db_item_model)

    mock_db_session.rollback.assert_called_once()
|
|
|
|
# TODO: Add more specific DB error tests (OperationalError, generic SQLAlchemyError) for each CRUD function.