from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import lancedb
import torch
import pyarrow as pa
import pandas as pd
import numpy as np
import tqdm


class VectorDB:
    vector_column = "vector"
    description_column = "description"
    name_column = "name"
    table_name = "pimcore_actions"
    emb_model = ''
    db_location = ''

    def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
        self.emb_model = emb_model
        self.db_location = db_location

        # Read the embedding dimension from the model config and check that it
        # can be split evenly into the requested number of sub-vectors.
        emb_config = AutoConfig.from_pretrained(emb_model)
        emb_dimension = emb_config.hidden_size
        assert emb_dimension % num_sub_vectors == 0, \
            "Embedding size must be divisible by the number of sub-vectors"

        print(f"Loading model: {emb_model}")
        model = SentenceTransformer(emb_model)
        model.eval()

        # Prefer Apple MPS, then CUDA, and fall back to CPU.
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        print(f"Device: {device}")

        # Create (or overwrite) the LanceDB table with a fixed-size vector column.
        db = lancedb.connect(db_location)
        schema = pa.schema(
            [
                pa.field(self.vector_column, pa.list_(pa.float32(), emb_dimension)),
                pa.field(self.description_column, pa.string()),
                pa.field(self.name_column, pa.string())
            ]
        )
        tbl = db.create_table(self.table_name, schema=schema, mode="overwrite")

        # The CSV is expected to hold one action per row: (name, description).
        df = pd.read_csv(actions_list_file_path)
        sentences = df.values

        print("Starting vector generation")
        for i in tqdm.tqdm(range(int(np.ceil(len(sentences) / batch_size)))):
            try:
                # Skip rows with missing values instead of trying to encode them.
                batch = [row for row in sentences[i * batch_size:(i + 1) * batch_size] if not pd.isnull(row).any()]
                to_encode = [entry[1] for entry in batch]
                names = [entry[0] for entry in batch]
                # Normalized embeddings keep distance-based and cosine rankings consistent at query time.
                encoded = model.encode(to_encode, normalize_embeddings=True, device=device)
                encoded = [list(vec) for vec in encoded]
                tbl.add(pd.DataFrame({
                    self.vector_column: encoded,
                    self.description_column: to_encode,
                    self.name_column: names
                }))
            except Exception as exc:
                print(f"Batch {i} was skipped: {exc}")
        print("Vector generation done.")

    def get_embedding_db_as_pandas(self):
        db = lancedb.connect(self.db_location)
        tbl = db.open_table(self.table_name)
        return tbl.to_pandas()

    def retrieve_prefiltered_hits(self, query, k):
        db = lancedb.connect(self.db_location)
        table = db.open_table(self.table_name)
        retriever = SentenceTransformer(self.emb_model)
        # Normalize the query the same way the stored vectors were normalized.
        query_vec = retriever.encode(query, normalize_embeddings=True)
        documents = table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
        names = [doc[self.name_column] for doc in documents]
        descriptions = [doc[self.description_column] for doc in documents]
        return names, descriptions
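

# Minimal usage sketch. The model name, CSV path, query, and parameter values below
# are assumptions for illustration, not part of the original Space; adjust to your setup.
if __name__ == "__main__":
    vector_db = VectorDB(
        emb_model="sentence-transformers/all-MiniLM-L6-v2",  # assumed example model (384-dim embeddings)
        db_location=".lancedb",
        actions_list_file_path="actions.csv",  # assumed CSV with (name, description) rows
        num_sub_vectors=96,                    # 384 % 96 == 0, so the divisibility assert passes
        batch_size=32,
    )
    names, descriptions = vector_db.retrieve_prefiltered_hits("export products as csv", k=5)
    for name, description in zip(names, descriptions):
        print(name, "->", description)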