Spaces:
Sleeping
Sleeping
import logging | |
import contextlib | |
from typing import Any, AsyncIterator | |
from sqlalchemy.ext.asyncio import ( | |
AsyncConnection, | |
AsyncSession, | |
async_sessionmaker, | |
create_async_engine, | |
) | |
from sqlalchemy.orm import declarative_base | |
from config.index import config as env | |
# Shared declarative base: mapped classes that subclass it register their
# tables in ``Base.metadata``.
Base = declarative_base()

# Module-scoped logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class PostgresDatabase:
    """Own an async SQLAlchemy engine and hand out connections and sessions.

    Args:
        host: Database URL passed to ``create_async_engine``.
        engine_kwargs: Extra keyword arguments forwarded to
            ``create_async_engine``. ``None`` (the default) means no extras.

    After :meth:`close` the engine and session factory are cleared, and every
    other method raises ``Exception`` until a new instance is created.
    """

    def __init__(self, host: str, engine_kwargs: "dict[str, Any] | None" = None):
        # Use None instead of a shared mutable ``{}`` default; normalise here.
        if engine_kwargs is None:
            engine_kwargs = {}
        self._engine = create_async_engine(host, **engine_kwargs)
        self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)

    async def close(self) -> None:
        """Dispose of the engine and drop the engine and session factory.

        Raises:
            Exception: If the manager has already been closed.
        """
        if self._engine is None:
            raise Exception("DatabaseSessionManager is not initialized")
        await self._engine.dispose()
        self._engine = None
        self._sessionmaker = None

    # The decorator turns this async generator into an async context manager,
    # so callers can write ``async with db.connect() as conn:``.
    @contextlib.asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncConnection]:
        """Yield a connection opened inside ``engine.begin()``.

        Per SQLAlchemy's ``begin()`` semantics, the transaction commits when
        the ``async with`` body exits normally. If the body raises, the
        connection is rolled back and the exception is re-raised.

        Raises:
            Exception: If the manager has been closed.
        """
        if self._engine is None:
            raise Exception("DatabaseSessionManager is not initialized")
        async with self._engine.begin() as connection:
            try:
                yield connection
            except Exception:
                await connection.rollback()
                raise

    # Required for ``async with postgresdb.session() as s:``. A bare async
    # generator does not implement __aenter__/__aexit__.
    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a fresh ``AsyncSession`` from the session factory.

        If the body raises, the session is rolled back and the exception is
        re-raised. The session is always closed on exit.

        Raises:
            Exception: If the manager has been closed.
        """
        if self._sessionmaker is None:
            raise Exception("DatabaseSessionManager is not initialized")
        session = self._sessionmaker()
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    def get_engine(self):
        """Return the underlying async engine, or ``None`` after :meth:`close`."""
        return self._engine
# Process-wide database manager built from the configured URL.
# "echo": True makes the engine log every SQL statement it emits.
postgresdb = PostgresDatabase(
    env.SQLALCHEMY_DATABASE_URL,
    engine_kwargs={"echo": True, "future": True},
)
async def get_db_session():
    """Yield one session from ``postgresdb`` for the duration of the caller.

    This is an async generator: the session stays open while the consumer
    holds it and is cleaned up by ``postgresdb.session()`` afterwards.
    It is presumably used as a per-request dependency (e.g. FastAPI
    ``Depends``); confirm against its callers.
    """
    session_scope = postgresdb.session()
    async with session_scope as db_session:
        yield db_session