"""Shared pytest fixtures: an in-memory SQLite test database and a FastAPI TestClient."""
import asyncio
from typing import AsyncGenerator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.models import Base
from app.database import get_db
from app.config import settings

# Create test database engine.
# StaticPool keeps exactly one connection, so every session below talks to the
# same in-memory SQLite database (a new connection would get a fresh, empty one).
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
engine = create_async_engine(
    TEST_DATABASE_URL,
    # Allow the single pooled connection to be used from threads other than
    # the one that opened it.
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)

# expire_on_commit=False keeps ORM objects readable after commit without a
# reload round-trip, which is convenient in async tests.
TestingSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)


@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop shared by the whole test session.

    The loop is session-scoped so that the session-scoped async ``test_db``
    fixture and the per-test async fixtures can all run on it. It is closed
    when the session ends.

    NOTE(review): overriding ``event_loop`` is deprecated in newer
    pytest-asyncio releases — confirm the pinned version still supports it.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
async def test_db():
    """Create all tables once for the test session and drop them afterwards.

    After dropping the tables, the engine is disposed so that the pooled
    connection is closed before the event loop shuts down.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()


@pytest.fixture
async def db_session(test_db) -> AsyncGenerator[AsyncSession, None]:
    """Yield a new ``AsyncSession`` for each test; it is closed at teardown.

    NOTE(review): data committed by a test is not rolled back, so it stays in
    the shared in-memory database for later tests — confirm this is intended.
    """
    async with TestingSessionLocal() as session:
        yield session


@pytest.fixture
async def client(db_session) -> AsyncGenerator[TestClient, None]:
    """Yield a ``TestClient`` whose ``get_db`` dependency returns ``db_session``.

    Dependency overrides are cleared on teardown, even if app startup inside
    ``TestClient`` fails.

    NOTE(review): TestClient likely drives the app on its own event loop, while
    ``db_session`` was created on the pytest loop; sharing an async session
    across loops can fail — consider an async HTTP client if that happens.
    """

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Always reset overrides so a failed startup cannot leak the test DB
        # override into subsequent tests.
        app.dependency_overrides.clear()