File size: 891 Bytes
3d2ca49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sentence_transformers import util
from sklearn.base import BaseEstimator, TransformerMixin
import os
import pandas as pd


class Search(BaseEstimator, TransformerMixin):
    """Recommend library items whose embeddings are most similar to a query.

    The library is loaded from two feather files inside ``path_to_library``:

    * ``metadata.feather`` -- one row per library item;
    * ``embeddings.feather`` -- one embedding per row.

    Rows of the two files must be aligned: the i-th embedding row must
    describe the i-th metadata row. This is because hit indices returned
    by the search are used positionally (``iloc``) on the metadata.

    Parameters
    ----------
    path_to_library : str or os.PathLike
        Directory containing ``metadata.feather`` and ``embeddings.feather``.
    top_k : int, default 5
        Number of library items to return per query.
    """

    def __init__(self, path_to_library, top_k=5) -> None:
        super().__init__()

        # Attribute names mirror the constructor arguments so that sklearn's
        # get_params()/clone() can introspect them.
        self.path_to_library = path_to_library
        self.top_k = top_k

    def fit(self, X=None, y=None):
        """Do nothing; this transformer has no trainable state.

        ``X`` and ``y`` are accepted (and ignored) so the estimator can be
        used inside sklearn pipelines and via ``fit_transform``, which call
        ``fit(X, y)`` / ``fit(X)``.

        Returns
        -------
        Search
            ``self``, per the sklearn estimator convention.
        """
        return self

    def transform(self, X, y=None):
        """Return the metadata rows of the ``top_k`` library items closest to ``X``.

        Parameters
        ----------
        X : array-like or tensor
            Query embedding(s), passed straight to
            ``util.semantic_search(query_embeddings=...)``. Presumably a single
            query or a batch of queries with the same dimensionality as the
            library embeddings -- TODO confirm against callers.
        y : ignored
            Present for sklearn API compatibility.

        Returns
        -------
        pandas.DataFrame
            Rows of ``metadata.feather`` for the hits of the FIRST query only,
            in the order returned by ``util.semantic_search``.

        Notes
        -----
        Both feather files are re-read on every call; nothing is cached.
        """
        library_metadata = pd.read_feather(
            os.path.join(self.path_to_library, "metadata.feather")
        )
        library_embeddings = pd.read_feather(
            os.path.join(self.path_to_library, "embeddings.feather")
        ).values

        matches = util.semantic_search(
            query_embeddings=X,
            corpus_embeddings=library_embeddings,
            top_k=self.top_k,
        )

        # semantic_search returns one list of hits per query. Only the first
        # query's hits are used here. Each hit's "corpus_id" is a row position
        # in library_embeddings, and therefore in library_metadata.
        recommended_indices = [hit["corpus_id"] for hit in matches[0]]

        return library_metadata.iloc[recommended_indices]