"""Implementation of a record management layer in SQLAlchemy. | |
The management layer uses SQLAlchemy to track upserted records. | |
Currently, this layer only works with SQLite; hopwever, should be adaptable | |
to other SQL implementations with minimal effort. | |
Currently, includes an implementation that uses SQLAlchemy which should | |
allow it to work with a variety of SQL as a backend. | |
* Each key is associated with an updated_at field. | |
* This filed is updated whenever the key is updated. | |
* Keys can be listed based on the updated at field. | |
* Keys can be deleted. | |
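
A minimal usage sketch (illustrative; assumes the synchronous API and a local
SQLite file):

    manager = SQLRecordManager(
        "my_namespace", db_url="sqlite:///record_manager_cache.db"
    )
    manager.create_schema()
    manager.update(["key1", "key2"], group_ids=["group1", "group1"])
    manager.exists(["key1", "key3"])  # -> [True, False]
    manager.list_keys(group_ids=["group1"])  # -> ["key1", "key2"]
    manager.delete_keys(["key1"])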
""" | |
import contextlib
import decimal
import uuid
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Union,
)
from sqlalchemy import (
    URL,
    Column,
    Engine,
    Float,
    Index,
    String,
    UniqueConstraint,
    and_,
    create_engine,
    delete,
    select,
    text,
)
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, declarative_base, sessionmaker

from langchain.indexes.base import RecordManager

Base = declarative_base()


class UpsertionRecord(Base):  # type: ignore[valid-type,misc]
    """Table used to keep track of when a key was last updated."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "upsertion_record"

    uuid = Column(
        String,
        index=True,
        default=lambda: str(uuid.uuid4()),
        primary_key=True,
        nullable=False,
    )
    key = Column(String, index=True)
    # Using a non-normalized representation to handle the `namespace` attribute.
    # If the need arises, this attribute can be pulled into a separate Collection
    # table at a later point.
    namespace = Column(String, index=True, nullable=False)
    group_id = Column(String, index=True, nullable=True)

    # The timestamp associated with the last record upsertion.
    updated_at = Column(Float, index=True)

    __table_args__ = (
        UniqueConstraint("key", "namespace", name="uix_key_namespace"),
        Index("ix_key_namespace", "key", "namespace"),
    )
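

# For reference, with the SQLite dialect create_schema() emits DDL roughly
# like the following (illustrative, not verbatim):
#
#   CREATE TABLE upsertion_record (
#       uuid VARCHAR NOT NULL,
#       key VARCHAR,
#       namespace VARCHAR NOT NULL,
#       group_id VARCHAR,
#       updated_at FLOAT,
#       PRIMARY KEY (uuid),
#       CONSTRAINT uix_key_namespace UNIQUE (key, namespace)
#   )
#   CREATE INDEX ix_key_namespace ON upsertion_record (key, namespace)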


class SQLRecordManager(RecordManager):
    """A SQLAlchemy-based implementation of the record manager."""

    def __init__(
        self,
        namespace: str,
        *,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        db_url: Union[None, str, URL] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: bool = False,
    ) -> None:
        """Initialize the SQLRecordManager.

        This class serves as a persistence layer that uses a SQL backend to
        track upserted records. You should specify either a db_url to create
        an engine or provide an existing engine.

        Args:
            namespace: The namespace associated with this record manager.
            engine: An already existing SQLAlchemy engine.
                Default is None.
            db_url: A database connection string used to create
                a SQLAlchemy engine. Default is None.
            engine_kwargs: Additional keyword arguments
                to be passed when creating the engine. Default is an empty
                dictionary.
            async_mode: Whether to create an async engine.
                The driver should support async operations.
                Only applies if db_url is provided. Default is False.

        Raises:
            ValueError: If both db_url and engine are provided, or neither.
            AssertionError: If something unexpected happens during engine
                configuration.
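
        Example:
            A minimal construction sketch (illustrative; the async variant
            assumes an async-capable driver such as aiosqlite is installed):

            .. code-block:: python

                # Synchronous engine backed by a local SQLite file.
                manager = SQLRecordManager(
                    "my_namespace", db_url="sqlite:///record_manager_cache.db"
                )

                # Async engine; requires async_mode=True and an async driver.
                amanager = SQLRecordManager(
                    "my_namespace",
                    db_url="sqlite+aiosqlite:///record_manager_cache.db",
                    async_mode=True,
                )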
""" | |
        super().__init__(namespace=namespace)
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode:
                _engine = create_async_engine(db_url, **(engine_kwargs or {}))
            else:
                _engine = create_engine(db_url, **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        _session_factory: Union[
            sessionmaker[Session], async_sessionmaker[AsyncSession]
        ]
        if isinstance(_engine, AsyncEngine):
            _session_factory = async_sessionmaker(bind=_engine)
        else:
            _session_factory = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_factory = _session_factory

    def create_schema(self) -> None:
        """Create the database schema."""
        if isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for async engines.")

        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the database schema."""
        if not isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for sync engines.")

        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    @contextlib.contextmanager
    def _make_session(self) -> Generator[Session, None, None]:
        """Create a session and close it after use."""
        if isinstance(self.session_factory, async_sessionmaker):
            raise AssertionError("This method is not supported for async engines.")

        session = self.session_factory()
        try:
            yield session
        finally:
            session.close()

    @contextlib.asynccontextmanager
    async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Create a session and close it after use."""
        if not isinstance(self.session_factory, async_sessionmaker):
            raise AssertionError("This method is not supported for sync engines.")

        async with self.session_factory() as session:
            yield session

    def get_time(self) -> float:
        """Get the current server time as a timestamp.

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.
        """
        with self._make_session() as session:
            # * SQLite specific implementation; can be changed based on dialect.
            # * For SQLite, unlike unixepoch(), julianday() works with older
            #   versions of SQLite.
            # ----
            # julianday('now'): Julian day number for the current date and time.
            #   The Julian day is a continuous count of days, starting from a
            #   reference date (Julian day number 0).
            # 2440587.5 - constant representing the Julian day number for
            #   January 1, 1970.
            # 86400.0 - constant representing the number of seconds
            #   in a day (24 hours * 60 minutes * 60 seconds).
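            # Illustrative check of the arithmetic (not part of the query):
            # julianday('1970-01-02 00:00:00') is 2440588.5, so
            # (2440588.5 - 2440587.5) * 86400.0 = 86400.0, i.e. exactly one
            # day (in seconds) after the Unix epoch.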
            if self.dialect == "sqlite":
                query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
            elif self.dialect == "postgresql":
                query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
            else:
                raise NotImplementedError(f"Not implemented for dialect {self.dialect}")

            dt = session.execute(query).scalar()
            if isinstance(dt, decimal.Decimal):
                dt = float(dt)
            if not isinstance(dt, float):
                raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
            return dt

    async def aget_time(self) -> float:
        """Get the current server time as a timestamp.

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.
        """
        async with self._amake_session() as session:
            # See get_time() for an explanation of the dialect-specific
            # queries and the constants used below.
            if self.dialect == "sqlite":
                query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
            elif self.dialect == "postgresql":
                query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
            else:
                raise NotImplementedError(f"Not implemented for dialect {self.dialect}")

            dt = (await session.execute(query)).scalar_one_or_none()
            if isinstance(dt, decimal.Decimal):
                dt = float(dt)
            if not isinstance(dt, float):
                raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
            return dt

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database."""
        if group_ids is None:
            group_ids = [None] * len(keys)

        if len(keys) != len(group_ids):
            raise ValueError(
                f"Number of keys ({len(keys)}) does not match number of "
                f"group_ids ({len(group_ids)})"
            )

        # Get the current time from the server.
        # This makes an extra round trip to the server, which should not be
        # a big deal if the batch size is large enough.
        # Getting the time here helps us compare it against the time_at_least
        # and raise an error if there is a time sync issue.
        # Here, we're just being extra careful to minimize the chance of
        # data loss due to incorrectly deleting records.
        update_time = self.get_time()

        if time_at_least and update_time < time_at_least:
            # Safeguard against time sync issues
            raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")

        records_to_upsert = [
            {
                "key": key,
                "namespace": self.namespace,
                "updated_at": update_time,
                "group_id": group_id,
            }
            for key, group_id in zip(keys, group_ids)
        ]

        with self._make_session() as session:
            if self.dialect == "sqlite":
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert

                # Note: uses SQLite insert to make on_conflict_do_update work.
                # This code needs to be generalized a bit to work with more
                # dialects.
                insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                    [UpsertionRecord.key, UpsertionRecord.namespace],
                    set_=dict(
                        # attr-defined type ignore
                        updated_at=insert_stmt.excluded.updated_at,  # type: ignore
                        group_id=insert_stmt.excluded.group_id,  # type: ignore
                    ),
                )
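
                # For reference, the statement above compiles to SQL roughly
                # like the following (illustrative, not verbatim):
                #   INSERT INTO upsertion_record (key, namespace, updated_at, group_id)
                #   VALUES (?, ?, ?, ?)
                #   ON CONFLICT (key, namespace)
                #   DO UPDATE SET updated_at = excluded.updated_at,
                #                 group_id = excluded.group_id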
            elif self.dialect == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert

                # Note: uses PostgreSQL insert to make on_conflict_do_update
                # work. This code needs to be generalized a bit to work with
                # more dialects.
                insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                    "uix_key_namespace",  # Name of the constraint
                    set_=dict(
                        # attr-defined type ignore
                        updated_at=insert_stmt.excluded.updated_at,  # type: ignore
                        group_id=insert_stmt.excluded.group_id,  # type: ignore
                    ),
                )
            else:
                raise NotImplementedError(f"Unsupported dialect {self.dialect}")

            session.execute(stmt)
            session.commit()

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database."""
        if group_ids is None:
            group_ids = [None] * len(keys)

        if len(keys) != len(group_ids):
            raise ValueError(
                f"Number of keys ({len(keys)}) does not match number of "
                f"group_ids ({len(group_ids)})"
            )

        # Get the current time from the server; see update() for why the
        # extra round trip is worth it as a safeguard against data loss.
        update_time = await self.aget_time()

        if time_at_least and update_time < time_at_least:
            # Safeguard against time sync issues
            raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")

        records_to_upsert = [
            {
                "key": key,
                "namespace": self.namespace,
                "updated_at": update_time,
                "group_id": group_id,
            }
            for key, group_id in zip(keys, group_ids)
        ]

        async with self._amake_session() as session:
            if self.dialect == "sqlite":
                from sqlalchemy.dialects.sqlite import insert as sqlite_insert

                # Note: uses SQLite insert to make on_conflict_do_update work.
                # This code needs to be generalized a bit to work with more
                # dialects.
                insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                    [UpsertionRecord.key, UpsertionRecord.namespace],
                    set_=dict(
                        # attr-defined type ignore
                        updated_at=insert_stmt.excluded.updated_at,  # type: ignore
                        group_id=insert_stmt.excluded.group_id,  # type: ignore
                    ),
                )
            elif self.dialect == "postgresql":
                from sqlalchemy.dialects.postgresql import insert as pg_insert

                # Note: uses PostgreSQL insert to make on_conflict_do_update
                # work. This code needs to be generalized a bit to work with
                # more dialects.
                insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
                stmt = insert_stmt.on_conflict_do_update(  # type: ignore[attr-defined]
                    "uix_key_namespace",  # Name of the constraint
                    set_=dict(
                        # attr-defined type ignore
                        updated_at=insert_stmt.excluded.updated_at,  # type: ignore
                        group_id=insert_stmt.excluded.group_id,  # type: ignore
                    ),
                )
            else:
                raise NotImplementedError(f"Unsupported dialect {self.dialect}")

            await session.execute(stmt)
            await session.commit()

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the database."""
        with self._make_session() as session:
            records = (
                # mypy does not recognize .all()
                session.query(UpsertionRecord.key)  # type: ignore[attr-defined]
                .filter(
                    and_(
                        UpsertionRecord.key.in_(keys),
                        UpsertionRecord.namespace == self.namespace,
                    )
                )
                .all()
            )
            found_keys = {r.key for r in records}
            return [k in found_keys for k in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the database."""
        async with self._amake_session() as session:
            result = await session.execute(
                select(UpsertionRecord.key).where(
                    and_(
                        UpsertionRecord.key.in_(keys),
                        UpsertionRecord.namespace == self.namespace,
                    )
                )
            )
            found_keys = set(result.scalars().all())
            return [k in found_keys for k in keys]

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List keys in the database, optionally filtered by time range,
        group ids, and a limit.
        """
        with self._make_session() as session:
            query = session.query(UpsertionRecord).filter(
                UpsertionRecord.namespace == self.namespace
            )

            # mypy does not recognize .all() or .filter()
            if after:
                query = query.filter(  # type: ignore[attr-defined]
                    UpsertionRecord.updated_at > after
                )
            if before:
                query = query.filter(  # type: ignore[attr-defined]
                    UpsertionRecord.updated_at < before
                )
            if group_ids:
                query = query.filter(  # type: ignore[attr-defined]
                    UpsertionRecord.group_id.in_(group_ids)
                )
            if limit:
                query = query.limit(limit)  # type: ignore[attr-defined]
            records = query.all()  # type: ignore[attr-defined]
            return [r.key for r in records]

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List keys in the database, optionally filtered by time range,
        group ids, and a limit.
        """
        async with self._amake_session() as session:
            query = select(UpsertionRecord.key).filter(
                UpsertionRecord.namespace == self.namespace
            )

            if after:
                query = query.filter(UpsertionRecord.updated_at > after)
            if before:
                query = query.filter(UpsertionRecord.updated_at < before)
            if group_ids:
                query = query.filter(UpsertionRecord.group_id.in_(group_ids))
            if limit:
                query = query.limit(limit)
            records = (await session.execute(query)).scalars().all()
            return list(records)

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Delete records from the database."""
        with self._make_session() as session:
            # mypy does not recognize .delete()
            session.query(UpsertionRecord).filter(
                and_(
                    UpsertionRecord.key.in_(keys),
                    UpsertionRecord.namespace == self.namespace,
                )
            ).delete()  # type: ignore[attr-defined]
            session.commit()

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Delete records from the database."""
        async with self._amake_session() as session:
            await session.execute(
                delete(UpsertionRecord).where(
                    and_(
                        UpsertionRecord.key.in_(keys),
                        UpsertionRecord.namespace == self.namespace,
                    )
                )
            )
            await session.commit()