|
|
import logging
from contextlib import asynccontextmanager
from urllib.parse import quote

from langchain_postgres import PGVector
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from core.embeddings import get_embeddings
from core.settings import settings
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def validate_postgres_config() -> None:
    """
    Check that the PostgreSQL settings required for persistence are present and sane.

    Raises:
        ValueError: if any of POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST,
            POSTGRES_PORT or POSTGRES_DB is absent or falsy on ``settings``, or if
            POSTGRES_MIN_CONNECTIONS_PER_POOL exceeds POSTGRES_MAX_CONNECTIONS_PER_POOL.
    """
    required = (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    )
    # A setting counts as missing when it is absent or falsy (None, "", 0, ...).
    absent = [name for name in required if not getattr(settings, name, None)]
    if absent:
        raise ValueError(
            f"Missing required PostgreSQL configuration: {', '.join(absent)}. "
            "All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
        )

    lower = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    upper = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if lower > upper:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({lower}) must be less than or equal to POSTGRES_MAX_CONNECTIONS_PER_POOL ({upper})"
        )
|
|
|
|
|
|
|
|
def get_postgres_connection_string() -> str:
    """
    Build the PostgreSQL connection URI used by the checkpointer/store pools.

    The user name and password are percent-encoded so that characters such as
    ``@``, ``:``, ``/`` or ``%`` inside the credentials cannot corrupt the URI.

    Returns:
        A ``postgresql://user:password@host:port/db`` connection URI.

    Raises:
        ValueError: if ``settings.POSTGRES_PASSWORD`` is not set.
    """
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    # safe="" also escapes "/", which would otherwise be read as the start of the path.
    # str() keeps the previous rendering of a non-string user value (e.g. None -> "None").
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")

    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}"
    )
|
|
|
|
|
|
|
|
@asynccontextmanager
async def _postgres_pool(role: str):
    """
    Open a psycopg async connection pool for one persistence component.

    The configuration is validated first. Connections use autocommit and
    ``dict_row`` rows, and are tagged with the application name
    ``settings.POSTGRES_APPLICATION_NAME + "-" + role`` so each component's
    connections can be told apart server-side.

    Args:
        role: Suffix appended to the configured application name (e.g. "saver").

    Yields:
        An open ``AsyncConnectionPool``. It is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + role

    # open=False: let `async with` open the pool inside the running event loop
    # instead of opening it implicitly from the constructor (deprecated for
    # async pools in psycopg_pool).
    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        check=AsyncConnectionPool.check_connection,
        open=False,
    ) as pool:
        # Leaving the `async with` closes the pool, so no explicit close() is needed.
        yield pool


@asynccontextmanager
async def get_postgres_saver():
    """
    Yield a LangGraph ``AsyncPostgresSaver`` backed by its own connection pool.

    ``setup()`` is awaited before the saver is yielded. The pool (application
    name suffix ``-saver``) is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    async with _postgres_pool("saver") as pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer


@asynccontextmanager
async def get_postgres_store():
    """
    Yield a LangGraph ``AsyncPostgresStore`` backed by its own connection pool.

    Use as ``async with get_postgres_store() as store: ...``. ``setup()`` is
    awaited before the store is yielded. The pool (application name suffix
    ``-store``) is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    async with _postgres_pool("store") as pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store
|
|
|
|
|
def get_pgvector_connection_string(*, sslmode: str = "require") -> str:
    """
    Build the SQLAlchemy-style connection URL used by the PGVector store.

    The URL uses the ``postgresql+psycopg`` driver. The user name and password
    are percent-encoded so special characters cannot corrupt the URL.

    Args:
        sslmode: Value for the ``sslmode`` query parameter (default ``"require"``).

    Returns:
        A ``postgresql+psycopg://user:password@host:port/db?sslmode=...`` URL.

    Raises:
        ValueError: if ``settings.POSTGRES_PASSWORD`` is not set.
    """
    # Same guard as the checkpointer/store URI: fail with a clear ValueError
    # instead of an AttributeError on None.get_secret_value().
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")

    return (
        f"postgresql+psycopg://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode={sslmode}"
    )
|
|
|
|
|
def load_pgvector_store():
    """
    Construct a PGVector store for the configured vector collection.

    The PostgreSQL settings are validated first. The store is then wired to the
    pgvector connection URL, ``settings.VECTOR_STORE_COLLECTION_NAME`` and the
    embeddings returned for ``settings.DEFAULT_EMBEDDING_MODEL``.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    validate_postgres_config()

    url = get_pgvector_connection_string()
    collection = settings.VECTOR_STORE_COLLECTION_NAME
    embedder = get_embeddings(settings.DEFAULT_EMBEDDING_MODEL)

    return PGVector(connection=url, collection_name=collection, embeddings=embedder)
|
|
|
|
|
def load_pgvector_retriever(k: int = 6, *, fetch_k: int = 20, lambda_mult: float = 0.6):
    """
    Return a maximal-marginal-relevance (MMR) retriever over the PGVector store.

    Args:
        k: Number of documents the retriever returns.
        fetch_k: Number of candidate documents fetched before MMR selection.
            Raised to ``k`` when smaller, since MMR can only pick ``k`` results
            from at least ``k`` candidates.
        lambda_mult: MMR relevance/diversity weight, forwarded unchanged as
            ``search_kwargs["lambda_mult"]``.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent
            (raised while building the store).
    """
    store = load_pgvector_store()
    return store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            # Without this, a k above fetch_k would silently cap the results at fetch_k.
            "fetch_k": max(fetch_k, k),
            "lambda_mult": lambda_mult,
        },
    )