from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import lancedb
import torch
import pyarrow as pa
import pandas as pd
import numpy as np
import tqdm


class VectorDB:
    """Build and query a LanceDB table of embedded action descriptions.

    Constructing an instance reads a CSV of actions (first column: name,
    second column: description), embeds every description with a
    SentenceTransformer model, and (re)creates the table ``table_name`` at
    ``db_location`` with one row per action.
    """

    vector_column = "vector"
    description_column = "description"
    name_column = "name"
    table_name = "pimcore_actions"
    emb_model = ''
    db_location = ''

    def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
        """Embed the actions CSV and write it to a fresh LanceDB table.

        Args:
            emb_model: Model name/path accepted by ``SentenceTransformer`` and
                ``AutoConfig.from_pretrained``.
            db_location: Location passed to ``lancedb.connect``.
            actions_list_file_path: CSV file; column 0 is the action name,
                column 1 is the description that gets embedded.
            num_sub_vectors: Only validated here: the embedding dimension must
                be divisible by it.
            batch_size: Number of rows encoded and inserted per batch.

        Raises:
            AssertionError: If the embedding dimension is not divisible by
                ``num_sub_vectors``.
            ValueError: If the CSV has fewer than two columns.
        """
        self.emb_model = emb_model
        self.db_location = db_location

        print(emb_model)
        model = SentenceTransformer(emb_model)
        model.eval()
        print('Model loaded...')

        # The vector column must match what the model actually emits; this can
        # differ from the transformer hidden size when the model has a
        # projection head. Fall back to the config when the model can't tell.
        emb_dimension = model.get_sentence_embedding_dimension()
        if emb_dimension is None:
            emb_dimension = AutoConfig.from_pretrained(emb_model).hidden_size
        assert emb_dimension % num_sub_vectors == 0, \
            "Embedding size must be divisible by the num of sub vectors"

        device = self._select_device()
        print(f"Device: {device}")

        # Cache the model so queries don't reload it from disk on every call.
        self._model = model
        self._device = device

        db = lancedb.connect(db_location)
        schema = pa.schema(
            [
                pa.field(self.vector_column, pa.list_(pa.float32(), emb_dimension)),
                pa.field(self.description_column, pa.string()),
                pa.field(self.name_column, pa.string()),
            ]
        )
        tbl = db.create_table(self.table_name, schema=schema, mode="overwrite")

        actions = pd.read_csv(actions_list_file_path)
        if actions.shape[1] < 2:
            raise ValueError(
                f"{actions_list_file_path!r} must have at least two columns (name, description)"
            )
        # A row with a missing name or description can't be stored; drop such
        # rows individually instead of letting one of them sink a whole batch.
        actions = actions.iloc[:, :2].dropna()
        names = actions.iloc[:, 0].tolist()
        descriptions = actions.iloc[:, 1].tolist()

        print("Starting vector generation")
        for start in tqdm.tqdm(range(0, len(descriptions), batch_size)):
            batch_names = names[start:start + batch_size]
            batch_descriptions = descriptions[start:start + batch_size]
            try:
                vectors = model.encode(batch_descriptions, normalize_embeddings=True, device=device)
                tbl.add(pd.DataFrame({
                    self.vector_column: [vec.tolist() for vec in vectors],
                    self.description_column: batch_descriptions,
                    self.name_column: batch_names,
                }))
            except Exception as err:
                # Best-effort ingestion: report the failing batch and continue.
                print(f"batch {start // batch_size} was skipped: {err!r}")
        print("Vector generation done.")

    @staticmethod
    def _select_device():
        """Return the preferred torch device name: "mps", then "cuda", else "cpu"."""
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _get_retriever(self):
        """Return ``(model, device)``, loading and caching the model if needed.

        ``device`` is None when the model was loaded lazily here, in which case
        ``encode`` uses the model's own device.
        """
        model = getattr(self, "_model", None)
        if model is None:
            model = SentenceTransformer(self.emb_model)
            model.eval()
            self._model = model
            self._device = None
        return model, getattr(self, "_device", None)

    def get_embedding_db_as_pandas(self):
        """Return the whole embeddings table at ``db_location`` as a DataFrame."""
        db = lancedb.connect(self.db_location)
        tbl = db.open_table(self.table_name)
        return tbl.to_pandas()

    def retrieve_prefiltered_hits(self, query, k):
        """Return the ``k`` stored actions nearest to ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of hits.

        Returns:
            A ``(names, descriptions)`` pair of lists, in search-result order.
        """
        # Query the same database the table was built in (previously this was
        # hard-coded to ".lancedb").
        db = lancedb.connect(self.db_location)
        table = db.open_table(self.table_name)
        retriever, device = self._get_retriever()
        # Stored vectors are unit-normalized; normalize the query the same way.
        query_vec = retriever.encode(query, normalize_embeddings=True, device=device)
        documents = table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
        names = [doc[self.name_column] for doc in documents]
        descriptions = [doc[self.description_column] for doc in documents]
        return names, descriptions