Spaces:
Runtime error
Runtime error
import os.path | |
from abc import ABC, abstractmethod | |
import faiss | |
import numpy as np | |
import pandas as pd | |
from pgvector.sqlalchemy import Vector | |
from sqlalchemy import create_engine, Column, Integer, String | |
from sqlalchemy.orm import sessionmaker, declarative_base | |
from config import Config | |
# Declarative base class that the SQLAlchemy ORM models in this module inherit
# from; its metadata is what `create_all` uses to create the tables.
Base = declarative_base()
class Storage(ABC):
    """Abstract interface for a store of ``(text, embedding)`` pairs.

    Entries are grouped under a ``name``. Use :meth:`create_storage` to obtain
    a concrete backend; this class itself cannot be instantiated.
    """

    # factory method
    @staticmethod
    def create_storage(cfg: Config) -> 'Storage':
        """Create the storage backend selected by *cfg*.

        Returns a ``_PostgresStorage`` when ``cfg.use_postgres`` is truthy,
        otherwise an ``_IndexStorage``.
        """
        if cfg.use_postgres:
            return _PostgresStorage(cfg)
        return _IndexStorage(cfg)

    @abstractmethod
    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Add multiple ``(text, embedding)`` pairs under *name*."""

    @abstractmethod
    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to *limit* stored texts for the provided embedding."""

    @abstractmethod
    def get_all_embeddings(self, name: str):
        """Return all ``(text, embedding)`` pairs stored under *name*."""

    @abstractmethod
    def clear(self, name: str):
        """Remove everything stored under *name*."""

    @abstractmethod
    def been_indexed(self, name: str) -> bool:
        """Return whether anything has been stored under *name*."""
class _IndexStorage(Storage):
    """File-based storage: one CSV of texts and one FAISS index per name.

    For each ``name`` two files live under ``cfg.index_path``:
    ``<name>.csv`` with the columns ``index`` and ``text``, and ``<name>.bin``
    holding the serialized FAISS index. Vector ids in the FAISS index equal
    the row position of the matching text in the CSV.
    """

    # Dimension of a newly created FAISS index.
    # NOTE(review): presumably the embedding model's output size - confirm.
    _DIMENSION = 1536

    def __init__(self, cfg: Config):
        """Initialize the storage with the configuration providing ``index_path``."""
        self._cfg = cfg

    def add_all(self, embeddings: list[tuple[str, list[float]]], name):
        """Append ``(text, embedding)`` pairs to the store for *name* and save it."""
        if not embeddings:
            # Nothing to add; an empty batch would otherwise produce a 1-D
            # array that FAISS cannot accept.
            return
        texts, index = self._load(name)
        start = len(texts)
        # New ids continue right after the rows already stored.
        ids = np.arange(start, start + len(embeddings), dtype=np.int64)
        new_rows = pd.DataFrame({'index': ids, 'text': [text for text, _ in embeddings]})
        if texts.empty:
            texts = new_rows
        else:
            texts = pd.concat([texts, new_rows], ignore_index=True)
        # FAISS stores float32 vectors; convert explicitly instead of relying
        # on the wrapper to coerce float64 input.
        vectors = np.array([emb for _, emb in embeddings], dtype=np.float32)
        index.add_with_ids(vectors, ids)
        self._save(texts, index, name)

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to *limit* stored texts nearest to *embedding*.

        Each entry is formatted as ``'paragraph <index>: <text>'``.
        """
        texts, index = self._load(name)
        query = np.array([embedding], dtype=np.float32)
        _, neighbours = index.search(query, limit)
        # Negative labels mean "no result" (fewer than `limit` vectors stored).
        positions = [i for i in neighbours[0] if i >= 0]
        rows = texts.iloc[positions]
        # Select columns by name so extra columns can never break the unpacking.
        return [f'paragraph {p}: {t}' for p, t in zip(rows['index'], rows['text'])]

    def get_all_embeddings(self, name: str):
        """Return every stored ``(text, embedding)`` pair for *name*."""
        texts, index = self._load(name)
        text_list = texts['text'].tolist()
        vectors = index.reconstruct_n(0, len(text_list))
        return list(zip(text_list, vectors))

    def clear(self, name: str):
        """Clear the database."""
        self._delete(name)

    def been_indexed(self, name: str) -> bool:
        """Return True when both the CSV and the index file exist for *name*."""
        csv_path, bin_path = self._paths(name)
        return os.path.exists(csv_path) and os.path.exists(bin_path)

    def _paths(self, name: str):
        """Return the ``(csv_path, bin_path)`` pair used for *name*."""
        directory = self._cfg.index_path
        return (os.path.join(directory, f'{name}.csv'),
                os.path.join(directory, f'{name}.bin'))

    def _save(self, texts, index, name: str):
        """Write the texts and the FAISS index for *name* to disk."""
        directory = self._cfg.index_path
        if directory:
            os.makedirs(directory, exist_ok=True)
        csv_path, bin_path = self._paths(name)
        # index=False: do not write the pandas row index, otherwise every
        # save/load cycle adds another unnamed column to the CSV.
        texts.to_csv(csv_path, index=False)
        faiss.write_index(index, bin_path)

    def _load(self, name: str):
        """Load the texts and FAISS index for *name*, or create empty ones."""
        if self.been_indexed(name):
            csv_path, bin_path = self._paths(name)
            # Keep only the data columns; CSVs saved together with the pandas
            # row index also contain unnamed index columns.
            texts = pd.read_csv(csv_path)[['index', 'text']]
            index = faiss.read_index(bin_path)
        else:
            texts = pd.DataFrame(columns=['index', 'text'])
            # IDMap2 with Flat
            index = faiss.index_factory(self._DIMENSION, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)
        return texts, index

    def _delete(self, name: str):
        """Remove the files stored for *name*; missing files are ignored."""
        for path in self._paths(name):
            # Handle each file on its own so a missing CSV does not leave the
            # index file behind.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def singleton(cls):
    """Class decorator that creates at most one instance of *cls*.

    The returned callable takes the single ``cfg`` argument and passes it to
    ``cls`` on the first call. Later calls return that same instance and
    ignore their ``cfg`` argument.
    """
    import functools

    instances = {}

    # updated=() copies only the metadata (__name__, __doc__, ...) and leaves
    # the class __dict__ out of the wrapper function.
    @functools.wraps(cls, updated=())
    def get_instance(cfg):
        if cls not in instances:
            instances[cls] = cls(cfg)
        return instances[cls]

    return get_instance
class _PostgresStorage(Storage):
    """Storage backed by a PostgreSQL table with a pgvector column.

    All names share the single ``embedding`` table; rows are told apart by
    their ``name`` column.
    """

    # NOTE(review): `singleton` is defined right above this class but is not
    # applied in the visible source - confirm whether that is intended.

    def __init__(self, cfg: Config):
        """Connect to ``cfg.postgres_url``, create the tables and open a session."""
        self._postgresql = cfg.postgres_url
        self._engine = create_engine(self._postgresql)
        Base.metadata.create_all(self._engine)
        session = sessionmaker(bind=self._engine)
        self._session = session()

    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Insert ``(text, embedding)`` pairs under *name* and commit."""
        data = [self.EmbeddingEntity(text=text, embedding=embedding, name=name) for text, embedding in embeddings]
        self._session.add_all(data)
        self._commit()

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to *limit* texts for *name*, ordered by cosine distance.

        Each entry is formatted as ``'paragraph <id>: <text>'``.
        """
        entity = self.EmbeddingEntity
        result = (
            self._session.query(entity)
            .where(entity.name == name)
            .order_by(entity.embedding.cosine_distance(embedding))
            .limit(limit)
            .all()
        )
        return [f'paragraph {s.id}: {s.text}' for s in result]

    def get_all_embeddings(self, name: str):
        """Return every stored ``(text, embedding)`` pair for *name*."""
        result = self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).all()
        return [(s.text, s.embedding) for s in result]

    def clear(self, name: str):
        """Delete all rows stored under *name* and commit."""
        self._session.query(self.EmbeddingEntity).where(self.EmbeddingEntity.name == name).delete()
        self._commit()

    def been_indexed(self, name: str) -> bool:
        """Return True when at least one row exists for *name*."""
        return self._session.query(self.EmbeddingEntity).filter_by(name=name).first() is not None

    def _commit(self):
        """Commit the session, rolling back first if the commit fails."""
        try:
            self._session.commit()
        except Exception:
            # Without a rollback the session stays in a failed state and every
            # later operation on it raises as well.
            self._session.rollback()
            raise

    def __del__(self):
        """Close the session."""
        # `_session` is missing when __init__ raised before assigning it.
        session = getattr(self, '_session', None)
        if session is not None:
            session.close()

    class EmbeddingEntity(Base):
        """ORM model for one stored text and its embedding vector."""

        __tablename__ = 'embedding'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        text = Column(String)
        embedding = Column(Vector(1536))