# Unit tests for app.crud.invite (source listing metadata: 174 lines, 7.2 KiB, Python).
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised
|
|
from datetime import datetime, timedelta, timezone
|
|
import secrets
|
|
|
|
from app.crud.invite import (
|
|
create_invite,
|
|
get_active_invite_by_code,
|
|
deactivate_invite,
|
|
MAX_CODE_GENERATION_ATTEMPTS
|
|
)
|
|
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context
|
|
# No specific schemas for invite CRUD usually, but models are used.
|
|
|
|
# Fixtures
|
|
@pytest.fixture
def mock_db_session():
    """Provide a mocked async DB session for the CRUD functions under test.

    Methods the CRUD code awaits are AsyncMocks; ``add`` is a plain MagicMock
    because it is called without ``await``.
    """
    session = AsyncMock()
    # Awaited session methods get their own explicit AsyncMock instances.
    for awaited_method in ("commit", "rollback", "refresh", "execute"):
        setattr(session, awaited_method, AsyncMock())
    # Synchronous method: must not return a coroutine when called.
    session.add = MagicMock()
    return session
|
|
|
|
@pytest.fixture
def group_model():
    """Return an in-memory ``Group`` (id=1) that invites are created for."""
    group = GroupModel(id=1, name="Test Group")
    return group
|
|
|
|
@pytest.fixture
def user_model():
    """Return an in-memory ``User`` (id=1) acting as the invite creator."""
    creator = UserModel(id=1, name="Creator User")
    return creator
|
|
|
|
@pytest.fixture
def db_invite_model(group_model, user_model):
    """Return an in-memory ``Invite`` for ``group_model`` created by ``user_model``.

    The invite starts with ``is_active=True`` and expires seven days after the
    fixture is built; tests override these fields to model other states.
    """
    one_week_from_now = datetime.now(timezone.utc) + timedelta(days=7)
    invite_fields = dict(
        id=1,
        code="test_invite_code_123",
        group_id=group_model.id,
        created_by_id=user_model.id,
        expires_at=one_week_from_now,
        is_active=True,
    )
    return InviteModel(**invite_fields)
|
|
|
|
# --- create_invite Tests ---
|
|
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')  # Patch secrets.token_urlsafe
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """A non-colliding code is persisted on the first attempt.

    Also pins the expiry to the requested ``expires_in_days=5``.
    """
    generated_code = "unique_code_123"
    mock_token_urlsafe.return_value = generated_code

    # ``await session.execute(...)`` yields a Result whose accessors are
    # synchronous, so the result must be a MagicMock. An AsyncMock would make
    # scalar_one_or_none() return a (truthy) coroutine instead of None.
    mock_existing_check_result = MagicMock()
    mock_existing_check_result.scalar_one_or_none.return_value = None  # no existing code
    mock_db_session.execute.return_value = mock_existing_check_result

    before = datetime.now(timezone.utc)
    invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)
    after = datetime.now(timezone.utc)

    assert invite is not None
    mock_token_urlsafe.assert_called_once_with(16)
    mock_db_session.execute.assert_called_once()  # the uniqueness check
    mock_db_session.add.assert_called_once_with(invite)
    mock_db_session.commit.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(invite)

    assert invite.code == generated_code
    assert invite.group_id == group_model.id
    assert invite.created_by_id == user_model.id
    assert invite.is_active is True

    # Bound the expiry from both sides. A lower bound alone would also accept
    # an implementation that ignores expires_in_days (e.g. a 7-day default).
    tolerance = timedelta(seconds=1)
    requested = timedelta(days=5)
    assert before + requested - tolerance <= invite.expires_at <= after + requested + tolerance
|
|
|
|
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """A colliding code is retried and the next, unique code is persisted."""
    colliding_code = "colliding_code"
    unique_code = "finally_unique_code"
    mock_token_urlsafe.side_effect = [colliding_code, unique_code]  # first collides, second is unique

    # Result objects returned by the awaited execute() are synchronous -> MagicMock.
    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # simulate collision (ID found)

    mock_no_collision_check_result = MagicMock()
    mock_no_collision_check_result.scalar_one_or_none.return_value = None  # no collision

    mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert mock_token_urlsafe.call_count == 2
    assert mock_db_session.execute.call_count == 2
    assert invite is not None
    assert invite.code == unique_code
    # Only the final, unique invite is persisted.
    mock_db_session.add.assert_called_once_with(invite)
    mock_db_session.commit.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
    """create_invite gives up (returns None) after MAX_CODE_GENERATION_ATTEMPTS collisions."""
    mock_token_urlsafe.return_value = "always_colliding_code"

    # Must be a MagicMock: with an AsyncMock, scalar_one_or_none() returns a
    # truthy coroutine, so this test would pass without a real collision.
    mock_collision_check_result = MagicMock()
    mock_collision_check_result.scalar_one_or_none.return_value = 1  # always collide
    mock_db_session.execute.return_value = mock_collision_check_result

    invite = await create_invite(mock_db_session, group_model.id, user_model.id)

    assert invite is None
    assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
    assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
    # Nothing is persisted when no unique code could be generated.
    mock_db_session.add.assert_not_called()
    mock_db_session.commit.assert_not_called()
|
|
|
|
# --- get_active_invite_by_code Tests ---
|
|
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
    """An active, unexpired invite returned by the query is passed through to the caller."""
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    # Result.scalars().first() is synchronous; an AsyncMock here would make
    # scalars() return a coroutine that has no .first().
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_invite_model
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is db_invite_model
    assert invite.code == db_invite_model.code
    mock_db_session.execute.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
    """An unknown code yields None."""
    # Synchronous Result accessors -> MagicMock, not AsyncMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")

    assert invite is None
    mock_db_session.execute.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
    """Inactive invites are excluded by the lookup query."""
    db_invite_model.is_active = False  # inactive
    db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)

    # Synchronous Result accessors -> MagicMock, not AsyncMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # should not be found by query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is None
    # The mocked DB decides the result, so None alone proves nothing. Also
    # check that the executed statement actually filters on is_active.
    mock_db_session.execute.assert_called_once()
    executed_stmt = mock_db_session.execute.call_args.args[0]
    assert "is_active" in str(executed_stmt)
|
|
|
|
@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
    """Expired invites are excluded by the lookup query."""
    db_invite_model.is_active = True
    db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1)  # expired

    # Synchronous Result accessors -> MagicMock, not AsyncMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None  # should not be found by query
    mock_db_session.execute.return_value = mock_result

    invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)

    assert invite is None
    # The mocked DB decides the result, so None alone proves nothing. Also
    # check that the executed statement actually constrains expires_at.
    mock_db_session.execute.assert_called_once()
    executed_stmt = mock_db_session.execute.call_args.args[0]
    assert "expires_at" in str(executed_stmt)
|
|
|
|
# --- deactivate_invite Tests ---
|
|
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
    """deactivate_invite flips is_active to False and persists the invite."""
    invite = db_invite_model
    invite.is_active = True  # start from a known-active invite

    result = await deactivate_invite(mock_db_session, invite)

    # The same instance is staged, committed and refreshed.
    session = mock_db_session
    session.add.assert_called_once_with(invite)
    session.commit.assert_called_once()
    session.refresh.assert_called_once_with(invite)
    assert result.is_active is False
|
|
|
|
# Not covered here: database failures (e.g. OperationalError, IntegrityError).
# app.crud.invite appears to leave most error handling to its callers, so such
# tests only make sense for functions that add their own try/except blocks.
# create_invite's code-collision retry loop is exercised above;
# get_active_invite_by_code and deactivate_invite are simpler DB operations.