# Credits: https://github.com/ThomasAitken/demo-fastapi-async-sqlalchemy/blob/main/backend/app/conftest.py
"""Shared pytest fixtures: app/client access, test-DB migrations, and
per-test transactional sessions wired into the FastAPI dependency graph."""
import asyncio
from contextlib import ExitStack

import pytest
from alembic.config import Config
from alembic.migration import MigrationContext
from alembic.operations import Operations
from alembic.script import ScriptDirectory
from asyncpg import Connection
from fastapi.testclient import TestClient

from config.index import config as settings
from app.engine.postgresdb import Base, get_db_session, postgresdb as sessionmanager
from main import app as actual_app


@pytest.fixture(autouse=True)
def app():
    """Yield the application object under test."""
    # The ExitStack is currently empty; it is a hook for registering
    # per-test cleanup callbacks around the app.
    with ExitStack():
        yield actual_app


@pytest.fixture
def client(app):
    """Yield a ``TestClient`` bound to ``app``, closed when the test ends."""
    with TestClient(app) as c:
        yield c


@pytest.fixture(scope="session")
def event_loop(request):
    """Provide one event loop for the whole test session and close it after."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


def run_migrations(connection: Connection):
    """Upgrade the database behind ``connection`` to the Alembic ``head``.

    Called through ``connection.run_sync(...)`` in ``setup_database``.

    NOTE(review): the annotation uses ``asyncpg.Connection``, but ``run_sync``
    presumably passes a synchronous SQLAlchemy connection -- confirm and
    adjust the annotation if so.
    """
    config = Config("alembic.ini")
    config.set_main_option("script_location", "app/migration")
    config.set_main_option("sqlalchemy.url", settings.SQLALCHEMY_TEST_DATABASE_URL)
    script = ScriptDirectory.from_config(config)

    def upgrade(rev, context):
        # Compute the revision steps from the current ``rev`` up to "head".
        return script._upgrade_revs("head", rev)

    context = MigrationContext.configure(
        connection,
        opts={"target_metadata": Base.metadata, "fn": upgrade},
    )

    with context.begin_transaction():
        with Operations.context(context):
            context.run_migrations()


@pytest.fixture(scope="session", autouse=True)
async def setup_database():
    """Migrate the test database once per session; close the manager after."""
    # Run alembic migrations on test DB
    async with sessionmanager.connect() as connection:
        await connection.run_sync(run_migrations)

    yield

    # Teardown
    await sessionmanager.close()


# Each test function is a clean slate
@pytest.fixture(scope="function", autouse=True)
async def transactional_session():
    """Yield a session with an open transaction that is always rolled back."""
    async with sessionmanager.session() as session:
        try:
            await session.begin()
            yield session
        finally:
            await session.rollback()  # Rolls back the outer transaction


@pytest.fixture(scope="function")
async def db_session(transactional_session):
    """Expose the per-test transactional session under the name ``db_session``."""
    yield transactional_session


@pytest.fixture(scope="function", autouse=True)
async def session_override(app, db_session):
    """Make the app's ``get_db_session`` dependency return the test session.

    The override is removed again on teardown so it does not linger on the
    shared application object after the test finishes.
    """

    async def get_db_session_override():
        # Yield the session object itself; it is not a sequence, so it must
        # not be indexed.
        yield db_session

    app.dependency_overrides[get_db_session] = get_db_session_override
    try:
        yield
    finally:
        app.dependency_overrides.pop(get_db_session, None)