File size: 1,560 Bytes
94e8fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import duckdb
import lancedb

from src.setting import AVAILABLE_WORDS


class VectorDatabaseHandler:
    """Rate how close a guessed word is to a target word using stored embeddings.

    The handler opens one LanceDB table that holds at least the columns
    ``word`` and ``vector``. Calling the instance looks up the embedding of
    the guessed word, runs a vector search with it, and reports where the
    target ("supposed") word lands in the result list.
    """

    QUERY_TEMPLATE = "SELECT word, vector FROM {table_name} WHERE word = '{user_word}'"

    # Name under which the Arrow snapshot of the embeddings table is exposed
    # to DuckDB. Kept equal to the name the SQL has always referenced.
    _ARROW_VIEW_NAME = "arrow_table"

    def __init__(self, db_path: str, table_name: str, metrics_cfg: dict):
        """Connect to LanceDB at ``db_path`` and open ``table_name``.

        Args:
            db_path: Location handed to ``lancedb.connect``.
            table_name: Name of the table containing the embeddings.
            metrics_cfg: Configuration object carrying the distance metric
                under ``metric``; attribute access (``cfg.metric``) and key
                access (``cfg["metric"]``) are both accepted.
        """
        db = lancedb.connect(db_path)

        self.metrics_cfg = metrics_cfg
        self.embeddings_tbl = db.open_table(table_name)

    def __call__(self, guessed_word: str, supposed_word: str) -> dict:
        """Score ``guessed_word`` against ``supposed_word``.

        Args:
            guessed_word: Word whose embedding is used as the search query.
            supposed_word: Target word whose position in the results is rated.

        Returns:
            A dict with the keys ``score`` (the ``_distance`` of the target
            word), ``rating`` (number of result rows strictly closer than the
            target), ``percentage`` (``100 - rating / rows * 100``) and
            ``closest_word``.

        Raises:
            IndexError: If either word cannot be found.
        """
        arrow_table = self.embeddings_tbl.to_arrow()
        # Register the snapshot explicitly instead of relying on DuckDB
        # discovering a local variable of this frame by name: that implicit
        # lookup depends on DuckDB's frame-scanning settings and breaks as
        # soon as the seemingly unused local is renamed or removed.
        duckdb.register(self._ARROW_VIEW_NAME, arrow_table)
        word_embedding = self.get_word_vector(guessed_word, self._ARROW_VIEW_NAME)

        df_emb = (
            self.embeddings_tbl.search(word_embedding)
            .metric(self._metric_name())
            .limit(len(AVAILABLE_WORDS))
            .to_df()
        )

        return self._score_guess(df_emb, guessed_word, supposed_word)

    def get_word_vector(self, word: str, table_name: str):
        """Return the stored vector of ``word`` from the relation ``table_name``.

        Args:
            word: Word to look up; single quotes are escaped before the word
                is placed into the SQL text.
            table_name: Name of a relation DuckDB can resolve. It is inserted
                into the query verbatim, so it must come from trusted code.

        Returns:
            The ``vector`` value of the first matching row.

        Raises:
            IndexError: If no row matches ``word`` (same exception type the
                previous implementation produced, with a clearer message).
        """
        query = self.QUERY_TEMPLATE.format(
            table_name=table_name,
            user_word=self._escape_sql_literal(word),
        )
        matches = duckdb.query(query).to_df()
        if matches.empty:
            raise IndexError(f"word {word!r} not found in {table_name!r}")
        return matches["vector"].values[0]

    def _metric_name(self):
        """Read the metric from ``metrics_cfg`` via attribute or key access."""
        cfg = self.metrics_cfg
        try:
            return cfg.metric
        except AttributeError:
            # Plain dicts (as the constructor annotation suggests) only
            # support subscript access.
            return cfg["metric"]

    @staticmethod
    def _escape_sql_literal(value: str) -> str:
        """Escape ``value`` for use inside a single-quoted SQL string literal."""
        return value.replace("'", "''")

    @staticmethod
    def _score_guess(df_emb, guessed_word: str, supposed_word: str) -> dict:
        """Build the result dict from a search-result frame.

        ``df_emb`` must provide the columns ``word`` and ``_distance``.
        NOTE(review): the first row is treated as the nearest match, which
        assumes rows are ordered by ascending ``_distance`` - confirm against
        the search backend.
        """
        target_rows = df_emb[df_emb["word"] == supposed_word]
        if target_rows.empty:
            raise IndexError(
                f"word {supposed_word!r} is missing from the search results"
            )
        distance = target_rows.iloc[0]["_distance"]

        # Rows strictly closer to the guess than the target word is.
        closer_count = int((df_emb["_distance"] < distance).sum())

        if closer_count:
            # Best-ranked row that is not the guess itself.
            closest_word = df_emb[df_emb["word"] != guessed_word].iloc[0]["word"]
        else:
            closest_word = supposed_word

        return {
            "score": distance,
            "rating": closer_count,
            "percentage": 100 - closer_count / len(df_emb) * 100,
            "closest_word": closest_word,
        }