"""Async SQLAlchemy engine / session management for the Postgres database."""

import contextlib
import logging
from typing import Any, AsyncIterator

from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import declarative_base

# NOTE(review): `env` is not used in this module; the import is kept in case
# importing it has side effects or other modules re-import it from here.
from config.index import config as env  # noqa: F401

logger = logging.getLogger(__name__)

# Declarative base for ORM models; create_all/drop_all below act on its metadata.
Base = declarative_base()


class PostgresDatabase:
    """Own an ``AsyncEngine`` and the ``async_sessionmaker`` bound to it.

    Instances start uninitialized: call :meth:`init` before :meth:`connect`
    or :meth:`session`, and :meth:`close` to dispose of the engine.
    """

    def __init__(self) -> None:
        self._engine: AsyncEngine | None = None
        self._sessionmaker: async_sessionmaker[AsyncSession] | None = None

    def init(self, host: str, engine_kwargs: dict[str, Any] | None = None) -> None:
        """Create the engine and the session factory bound to it.

        Args:
            host: database URL passed to ``create_async_engine``.
            engine_kwargs: extra keyword arguments for ``create_async_engine``;
                ``None`` (the default) means no extra arguments.
        """
        # ``None`` instead of a ``{}`` default: a mutable default is a single
        # dict object shared by every call of the method.
        # NOTE(review): calling init() again replaces the engine without
        # disposing of the previous one.
        self._engine = create_async_engine(host, **(engine_kwargs or {}))
        self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)

    async def close(self) -> None:
        """Dispose of the engine and return to the uninitialized state.

        Raises:
            Exception: if the manager is not initialized (``init`` never
                called, or ``close`` already ran).
        """
        if self._engine is None:
            raise Exception("DatabaseSessionManager is not initialized")
        await self._engine.dispose()
        self._engine = None
        self._sessionmaker = None

    @contextlib.asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncConnection]:
        """Yield a connection opened with ``engine.begin()``.

        If the caller's block raises, the connection is rolled back and the
        exception is re-raised.

        Raises:
            Exception: if the manager is not initialized.
        """
        if self._engine is None:
            raise Exception("DatabaseSessionManager is not initialized")
        async with self._engine.begin() as connection:
            try:
                yield connection
            except Exception:
                await connection.rollback()
                raise

    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a fresh ``AsyncSession`` that is always closed afterwards.

        The session is rolled back if the caller's block raises. Nothing is
        committed here, so callers must commit explicitly.

        Raises:
            Exception: if the manager is not initialized.
        """
        if self._sessionmaker is None:
            raise Exception("DatabaseSessionManager is not initialized")
        session = self._sessionmaker()
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    def get_engine(self) -> AsyncEngine | None:
        """Return the engine, or ``None`` before ``init`` / after ``close``."""
        return self._engine

    # Used for testing
    async def create_all(self, connection: AsyncConnection) -> None:
        """Create all tables registered on ``Base.metadata`` using *connection*."""
        await connection.run_sync(Base.metadata.create_all)

    async def drop_all(self, connection: AsyncConnection) -> None:
        """Drop all tables registered on ``Base.metadata`` using *connection*."""
        await connection.run_sync(Base.metadata.drop_all)


# Shared module-level instance; must be initialized via ``postgresdb.init(...)``.
postgresdb = PostgresDatabase()


async def get_db_session() -> AsyncIterator[AsyncSession]:
    """Yield a session from ``postgresdb`` for the duration of the consumer.

    Presumably used as a framework dependency (e.g. FastAPI ``Depends``) -
    confirm against callers.
    """
    async with postgresdb.session() as session:
        yield session