Michael-Geis
reorganized
7cc8002
from sentence_transformers import util
from sklearn.base import BaseEstimator, TransformerMixin
import os
import pandas as pd
class Search(BaseEstimator, TransformerMixin):
    """Look up the library entries most similar to a query embedding.

    The library is a directory containing two feather files:
    ``metadata.feather`` and ``embeddings.feather``. Rows of the metadata
    table are selected by the positional ``corpus_id`` values returned by
    ``sentence_transformers.util.semantic_search`` run against the
    embeddings table, so the two files are presumably row-aligned
    (TODO confirm against whatever writes the library).

    Args:
        path_to_library: Directory holding the two feather files.
        top_k: Number of matches requested from the semantic search.
            Defaults to 5, the value that was previously hard-coded.
    """

    def __init__(self, path_to_library, top_k: int = 5) -> None:
        super().__init__()
        # Init parameters are stored unmodified under their own names.
        self.path_to_library = path_to_library
        self.top_k = top_k

    def fit(self, X=None, y=None):
        """Do nothing and return ``self``.

        ``X`` and ``y`` are accepted (and ignored) so the estimator can be
        driven by callers that pass data to ``fit``, such as
        ``fit_transform``; calling ``fit()`` with no arguments still works.
        """
        return self

    def transform(self, X, y=None):
        """Return the metadata rows of the best matches for the query.

        Args:
            X: Query embedding(s), passed as ``query_embeddings`` to
                ``util.semantic_search``.
            y: Ignored.

        Returns:
            The rows of the library metadata table selected with ``iloc``
            by the matched ``corpus_id`` values, in the order the search
            returned them. Only the matches of the first query
            (``matches[0]``) are used.
        """
        # Both tables are re-read from disk on every call.
        library_metadata = pd.read_feather(
            os.path.join(self.path_to_library, "metadata.feather")
        )
        library_embeddings = pd.read_feather(
            os.path.join(self.path_to_library, "embeddings.feather")
        ).values

        matches = util.semantic_search(
            query_embeddings=X,
            corpus_embeddings=library_embeddings,
            top_k=self.top_k,
        )

        # Each match is a mapping with a "corpus_id" entry; only the first
        # query's matches are consumed.
        recommended_indices = [match["corpus_id"] for match in matches[0]]
        return library_metadata.iloc[recommended_indices]