Spaces:
Sleeping
Sleeping
# Credits: https://github.com/ThomasAitken/demo-fastapi-async-sqlalchemy/blob/main/backend/app/conftest.py | |
import asyncio | |
from contextlib import ExitStack | |
import pytest | |
from alembic.config import Config | |
from alembic.migration import MigrationContext | |
from alembic.operations import Operations | |
from alembic.script import ScriptDirectory | |
from config.index import config as settings | |
from app.engine.postgresdb import Base, get_db_session, postgresdb as sessionmanager | |
from main import init_app | |
from asyncpg import Connection | |
from fastapi.testclient import TestClient | |
from pytest_postgresql import factories | |
from pytest_postgresql.factories.noprocess import postgresql_noproc | |
from pytest_postgresql.janitor import DatabaseJanitor | |
from sqlalchemy.testing.entities import ComparableEntity | |
from config.index import config as env | |
# pytest-postgresql process fixture consumed below via the parameter name
# ``test_db``: PostgreSQL on port 5433 with a database called "test_db".
test_db = factories.postgresql_proc(
    port=5433,
    dbname="test_db",
)
def app():
    """Yield the application returned by ``init_app(init_db=False)``.

    Database initialisation is skipped deliberately: the connection is set
    up manually (see ``connection_test`` in this file) so that the test
    database can be created first.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this
    function, yet it yields like a fixture and ``client`` takes a parameter
    named ``app`` -- confirm a decorator is intended.
    """
    application = init_app(init_db=False)
    # NOTE(review): nothing is ever registered on this ExitStack, so it
    # currently has no effect beyond scoping the yield.
    with ExitStack():
        yield application
def client(app):
    """Yield a ``TestClient`` wrapping ``app``.

    The client is used as a context manager: it is entered before the
    yield and exited once the generator finishes.

    NOTE(review): no ``@pytest.fixture`` decorator is visible here although
    the ``app`` parameter looks like a fixture request -- confirm.
    """
    with TestClient(app) as test_client:
        yield test_client
def event_loop(request):
    """Yield a brand-new asyncio event loop and close it afterwards.

    Args:
        request: Not used in the body (presumably the pytest ``request``
            fixture, kept for signature compatibility -- confirm).

    Yields:
        A newly created event loop. It is closed when the generator
        finishes, and also when the generator is closed early.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this
    function -- confirm the intended scope if one is added.
    """
    # Equivalent to get_event_loop_policy().new_event_loop(), without going
    # through the event-loop policy API that Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Without the finally, closing the generator at the yield (e.g.
        # gen.close()) would skip this and leak the loop.
        loop.close()
async def connection_test(test_db, event_loop):
    """Point ``sessionmanager`` at the test database for one yield.

    Enters a pytest-postgresql ``DatabaseJanitor`` for the database
    described by ``test_db``. Per pytest-postgresql, the janitor creates the
    database on enter and drops it on exit. Inside the janitor, this
    initialises ``sessionmanager`` with a ``postgresql+psycopg`` URL, yields
    once, and then closes ``sessionmanager``, all before the janitor exits.

    Args:
        test_db: Object exposing ``host``, ``port``, ``user``, ``dbname``,
            ``password`` and ``version``. This is presumably the process
            fixture built by ``factories.postgresql_proc``.
        event_loop: Not used in the body; presumably requested only so an
            event loop exists first -- confirm.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this
    function although it consumes ``test_db`` like a fixture -- confirm.
    """
    from urllib.parse import quote_plus

    pg_host = test_db.host
    pg_port = test_db.port
    pg_user = test_db.user
    pg_db = test_db.dbname
    pg_password = test_db.password
    with DatabaseJanitor(
        user=pg_user,
        host=pg_host,
        port=pg_port,
        dbname=pg_db,
        version=test_db.version,
        password=pg_password,
    ):
        # Include the password in the URL so password-protected servers
        # accept the connection. It is percent-encoded so characters such as
        # "@", ":" or "/" cannot corrupt the URL. A missing password still
        # gives the password-less "user:@host" form.
        password_part = quote_plus(pg_password or "")
        connection_str = (
            f"postgresql+psycopg://{pg_user}:{password_part}"
            f"@{pg_host}:{pg_port}/{pg_db}"
        )
        sessionmanager.init(
            connection_str,
            # {"echo": True, "future": True}
        )
        try:
            yield
        finally:
            # Close the session manager even if the generator is closed
            # early, before the janitor tears the database down.
            await sessionmanager.close()
async def create_tables(connection_test):
    """Run ``sessionmanager.drop_all`` then ``create_all`` on one connection.

    Args:
        connection_test: Not used in the body; presumably requested so that
            ``sessionmanager`` has already been initialised -- confirm.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this
    function -- confirm whether one (and autouse) is intended.
    """
    async with sessionmanager.connect() as conn:
        # Drop before creating, presumably so every run starts from a clean
        # schema.
        await sessionmanager.drop_all(conn)
        await sessionmanager.create_all(conn)
async def session_override(app, connection_test):
    """Make ``app`` resolve ``get_db_session`` via ``sessionmanager``.

    Stores an async-generator dependency in ``app.dependency_overrides``
    under the ``get_db_session`` key. That dependency yields a session
    opened with ``sessionmanager.session()``.

    Args:
        app: Object with a ``dependency_overrides`` mapping, presumably the
            application yielded by ``app`` in this file.
        connection_test: Not used in the body; presumably requested so the
            session manager is initialised first -- confirm.

    NOTE(review): the override is never removed. That is harmless if each
    test receives a fresh app -- confirm. No ``@pytest.fixture`` decorator
    is visible here either.
    """

    async def _override_db_session():
        async with sessionmanager.session() as db_session:
            yield db_session

    overrides = app.dependency_overrides
    overrides[get_db_session] = _override_db_session
async def get_db_session_fixture():
    """Yield one session obtained from ``sessionmanager.session()``.

    The session's async context manager is exited when the generator
    finishes.

    NOTE(review): the name suggests a pytest fixture, but no
    ``@pytest.fixture`` decorator is visible -- confirm.
    """
    session_cm = sessionmanager.session()
    async with session_cm as db_session:
        yield db_session