i135e1fi414i41tqe / storage.py
serhan's picture
Upload 16 files
14e11d6
import os.path
from abc import ABC, abstractmethod
import faiss
import numpy as np
import pandas as pd
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, declarative_base
from config import Config
# Declarative base that the ORM model classes in this module inherit from.
Base = declarative_base()
class Storage(ABC):
    """Common interface of the embedding stores, plus a factory that picks one."""

    @staticmethod
    def create_storage(cfg: Config) -> 'Storage':
        """Build the storage backend selected by ``cfg.use_postgres``.

        A truthy flag yields the PostgreSQL-backed storage, anything else the
        index-file-backed one; ``cfg`` is handed to the chosen constructor.
        """
        backend = _PostgresStorage if cfg.use_postgres else _IndexStorage
        return backend(cfg)

    @abstractmethod
    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Store several ``(text, embedding)`` pairs under ``name``."""

    @abstractmethod
    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to ``limit`` stored texts of ``name`` matching ``embedding``."""

    @abstractmethod
    def get_all_embeddings(self, name: str):
        """Return everything stored under ``name``."""

    @abstractmethod
    def clear(self, name: str):
        """Remove everything stored under ``name``."""

    @abstractmethod
    def been_indexed(self, name: str) -> bool:
        """Tell whether data has already been stored under ``name``."""
class _IndexStorage(Storage):
    """Storage that keeps texts in a CSV file and vectors in a FAISS index file.

    For every ``name`` two files are used inside ``cfg.index_path``:
    ``<name>.csv`` (columns ``index`` and ``text``) and ``<name>.bin``
    (the serialized FAISS index). The id given to a vector equals the row
    position of its text, which is what ``get_texts`` relies on.
    """

    # Vector dimensionality used when a brand-new FAISS index is created.
    _DIMENSION = 1536

    def __init__(self, cfg: Config):
        """Keep the configuration; ``cfg.index_path`` is read on each file access."""
        self._cfg = cfg

    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Append ``(text, embedding)`` pairs to the data stored under ``name``.

        An empty ``embeddings`` list is a no-op: nothing is loaded or written.
        """
        if not embeddings:
            # An empty batch has no 2-D vector array to hand to FAISS.
            return
        texts, index = self._load(name)
        first_id = len(texts)
        # FAISS ids are 64-bit integers; continue numbering after existing rows.
        ids = np.arange(first_id, first_id + len(embeddings), dtype=np.int64)
        new_rows = pd.DataFrame({'index': ids, 'text': [text for text, _ in embeddings]})
        if texts.empty:
            texts = new_rows
        else:
            # ignore_index keeps row position identical to the FAISS id.
            texts = pd.concat([texts, new_rows], ignore_index=True)
        # FAISS stores float32 vectors; convert explicitly instead of relying on the wrapper.
        vectors = np.array([emb for _, emb in embeddings], dtype=np.float32)
        index.add_with_ids(vectors, ids)
        self._save(texts, index, name)

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to ``limit`` texts of ``name`` closest to ``embedding``.

        Each entry is formatted as ``'paragraph <index>: <text>'``.
        """
        texts, index = self._load(name)
        query = np.array([embedding], dtype=np.float32)
        _, found = index.search(query, limit)
        # Negative ids mark unfilled result slots (fewer than ``limit`` hits).
        positions = [i for i in found[0] if i >= 0]
        hits = texts.iloc[positions]
        return [f'paragraph {p}: {t}' for p, t in zip(hits['index'], hits['text'])]

    def get_all_embeddings(self, name: str):
        """Return a list of ``(text, vector)`` pairs for everything under ``name``."""
        texts, index = self._load(name)
        passages = texts['text'].tolist()
        vectors = index.reconstruct_n(0, len(passages))
        return list(zip(passages, vectors))

    def clear(self, name: str):
        """Delete the files stored for ``name``."""
        self._delete(name)

    def been_indexed(self, name: str) -> bool:
        """Return True only when both the CSV and the index file of ``name`` exist."""
        return os.path.exists(self._path(name, '.csv')) and os.path.exists(self._path(name, '.bin'))

    def _path(self, name: str, suffix: str) -> str:
        """Build the path of the ``name`` file with ``suffix`` inside ``cfg.index_path``."""
        return os.path.join(self._cfg.index_path, f'{name}{suffix}')

    def _save(self, texts, index, name: str):
        """Write the texts table and the FAISS index of ``name`` to disk."""
        # index=False: writing the row index added one extra column per
        # save/load cycle, which broke reading after a second add_all.
        texts.to_csv(self._path(name, '.csv'), index=False)
        faiss.write_index(index, self._path(name, '.bin'))

    def _load(self, name: str):
        """Return ``(texts, index)`` for ``name``, or fresh empty ones if not indexed."""
        if self.been_indexed(name):
            # usecols drops stray row-index columns found in files written by
            # earlier versions; keep_default_na stops texts such as "" or "NA"
            # from being turned into NaN.
            texts = pd.read_csv(self._path(name, '.csv'), usecols=['index', 'text'],
                                keep_default_na=False)
            index = faiss.read_index(self._path(name, '.bin'))
        else:
            texts = pd.DataFrame(columns=['index', 'text'])
            # IDMap2 with Flat
            index = faiss.index_factory(self._DIMENSION, "IDMap2,Flat", faiss.METRIC_INNER_PRODUCT)
        return texts, index

    def _delete(self, name: str):
        """Remove the CSV and index files of ``name``; missing files are ignored."""
        # Handle each file separately so a missing CSV does not leave the
        # index file behind.
        for suffix in ('.csv', '.bin'):
            try:
                os.remove(self._path(name, suffix))
            except FileNotFoundError:
                pass
def singleton(cls):
    """Class decorator that replaces ``cls`` with a one-instance factory.

    The returned function builds ``cls(cfg)`` on its first call and hands back
    that same object on every later call; the ``cfg`` passed to later calls
    is not used.
    """
    created = {}

    def get_instance(cfg):
        if cls in created:
            return created[cls]
        instance = cls(cfg)
        created[cls] = instance
        return instance

    return get_instance
@singleton
class _PostgresStorage(Storage):
    """Storage that keeps texts and their embeddings in a PostgreSQL table.

    The class is wrapped by ``singleton``, so one instance (and therefore one
    engine and one session) is reused for the life of the process.
    """

    def __init__(self, cfg: Config):
        """Connect to ``cfg.postgres_url``, create missing tables and open a session."""
        self._postgresql = cfg.postgres_url
        self._engine = create_engine(self._postgresql)
        Base.metadata.create_all(self._engine)
        session_factory = sessionmaker(bind=self._engine)
        self._session = session_factory()

    def add_all(self, embeddings: list[tuple[str, list[float]]], name: str):
        """Insert one row per ``(text, embedding)`` pair, tagged with ``name``.

        If the insert fails the session is rolled back and the error re-raised.
        """
        data = [self.EmbeddingEntity(text=text, embedding=embedding, name=name) for text, embedding in embeddings]
        try:
            self._session.add_all(data)
            self._session.commit()
        except Exception:
            # The session is long-lived and shared; without a rollback it
            # would stay in a failed transaction and reject later calls.
            self._session.rollback()
            raise

    def get_texts(self, embedding: list[float], name: str, limit=100) -> list[str]:
        """Return up to ``limit`` texts of ``name`` ordered by cosine distance to ``embedding``.

        Each entry is formatted as ``'paragraph <id>: <text>'``.
        """
        entity = self.EmbeddingEntity
        result = (
            self._session.query(entity)
            .where(entity.name == name)
            .order_by(entity.embedding.cosine_distance(embedding))
            .limit(limit)
            .all()
        )
        return [f'paragraph {s.id}: {s.text}' for s in result]

    def get_all_embeddings(self, name: str):
        """Return a list of ``(text, embedding)`` pairs for every row of ``name``."""
        entity = self.EmbeddingEntity
        result = self._session.query(entity).where(entity.name == name).all()
        return [(s.text, s.embedding) for s in result]

    def clear(self, name: str):
        """Delete every row stored under ``name``.

        If the delete fails the session is rolled back and the error re-raised.
        """
        entity = self.EmbeddingEntity
        try:
            self._session.query(entity).where(entity.name == name).delete()
            self._session.commit()
        except Exception:
            # Same reasoning as in add_all: keep the shared session usable.
            self._session.rollback()
            raise

    def been_indexed(self, name: str) -> bool:
        """Return True when at least one row exists for ``name``."""
        return self._session.query(self.EmbeddingEntity).filter_by(name=name).first() is not None

    def __del__(self):
        """Close the session, if one was opened."""
        # __init__ may have failed before _session was assigned; in that case
        # there is nothing to close.
        session = getattr(self, '_session', None)
        if session is not None:
            session.close()

    class EmbeddingEntity(Base):
        """ORM model of one stored text together with its embedding vector."""

        __tablename__ = 'embedding'
        # Integer primary key; reported as the paragraph number by get_texts.
        id = Column(Integer, primary_key=True)
        # Label that groups rows; every query of the storage filters on it.
        name = Column(String)
        # The stored text itself.
        text = Column(String)
        # pgvector column declared with 1536 dimensions.
        embedding = Column(Vector(1536))