File size: 1,560 Bytes
94e8fb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import duckdb
import lancedb
from src.setting import AVAILABLE_WORDS
class VectorDatabaseHandler:
    """Score a guessed word against a target word by embedding distance.

    Embeddings live in a LanceDB table. The guessed word's vector is looked up
    with DuckDB and then used as the query vector for a LanceDB
    nearest-neighbour search.
    """

    # ``{table_name}`` is filled in with ``str.format`` (an identifier cannot be
    # bound as a parameter). The word itself is bound through ``?``, so quotes
    # in user input cannot break or alter the SQL.
    QUERY_TEMPLATE = "SELECT word, vector FROM {table_name} WHERE word = ?"

    def __init__(self, db_path: str, table_name: str, metrics_cfg: dict):
        """Open the LanceDB embeddings table.

        Args:
            db_path: Location passed to ``lancedb.connect``.
            table_name: Name of the LanceDB table to open; it is expected to
                have ``word`` and ``vector`` columns (see ``QUERY_TEMPLATE``).
            metrics_cfg: Config supplying the search metric, either as a
                mapping key ``"metric"`` or as a ``.metric`` attribute.
        """
        db = lancedb.connect(db_path)
        self.metrics_cfg = metrics_cfg
        self.embeddings_tbl = db.open_table(table_name)

    def _metric_name(self):
        """Return the configured metric from ``metrics_cfg``.

        Accepts plain mappings (``cfg["metric"]``) as well as attribute-style
        config objects (``cfg.metric``).
        """
        from collections.abc import Mapping

        cfg = self.metrics_cfg
        if isinstance(cfg, Mapping):
            return cfg["metric"]
        return cfg.metric

    def __call__(self, guessed_word: str, supposed_word: str) -> dict:
        """Rank ``supposed_word`` among the neighbours of ``guessed_word``.

        Returns:
            A dict with:
            - ``"score"``: the ``_distance`` between the two words under the
              configured metric.
            - ``"rating"``: how many search hits are strictly closer to
              ``guessed_word`` than ``supposed_word`` is. This normally
              includes ``guessed_word`` itself when the two words differ.
            - ``"percentage"``: ``100 - rating / n * 100``, where ``n`` is
              the number of hits returned.
            - ``"closest_word"``: the first hit that is not ``guessed_word``,
              or ``supposed_word`` when ``rating`` is 0.

        Raises:
            IndexError: if ``guessed_word`` is not in the table, or
                ``supposed_word`` is not among the search hits.
        """
        arrow_table = self.embeddings_tbl.to_arrow()
        # Hand the Arrow table over explicitly. Relying on DuckDB to resolve
        # the local name "arrow_table" by scanning caller stack frames breaks
        # on DuckDB versions that only scan the current frame.
        word_embedding = self.get_word_vector(
            guessed_word, "arrow_table", source=arrow_table
        )
        df_emb = (
            self.embeddings_tbl.search(word_embedding)
            .metric(self._metric_name())
            .limit(len(AVAILABLE_WORDS))
            .to_df()
        )

        matches = df_emb[df_emb["word"] == supposed_word]
        if matches.empty:
            raise IndexError(
                f"word {supposed_word!r} not found among the search results"
            )
        distance = matches.iloc[0]["_distance"]
        words_between_count = int((df_emb["_distance"] < distance).sum())
        # Assumes hits come back ordered by ascending ``_distance``, so the
        # first non-guess row is the guess's nearest neighbour.
        closest_word = (
            df_emb[df_emb["word"] != guessed_word].iloc[0]["word"]
            if words_between_count
            else supposed_word
        )
        return {
            "score": distance,
            "rating": words_between_count,
            "percentage": 100 - words_between_count / len(df_emb) * 100,
            "closest_word": closest_word,
        }

    def get_word_vector(self, word: str, table_name: str, source=None):
        """Return the ``vector`` value of the first row whose ``word`` equals *word*.

        Args:
            word: Word to look up. It is bound as a SQL parameter and never
                interpolated into the query text.
            table_name: Name under which the table is queried.
            source: Optional Arrow table / DataFrame (anything
                ``DuckDBPyConnection.register`` accepts). When given, it is
                registered as ``table_name`` on a private in-memory connection,
                so the lookup does not depend on DuckDB locating a Python
                variable by name. When omitted, ``table_name`` is resolved on
                DuckDB's default connection.

        Raises:
            IndexError: if no row matches *word*.
        """
        query = self.QUERY_TEMPLATE.format(table_name=table_name)
        if source is None:
            result_df = duckdb.execute(query, [word]).df()
        else:
            con = duckdb.connect()
            try:
                con.register(table_name, source)
                result_df = con.execute(query, [word]).df()
            finally:
                con.close()
        if result_df.empty:
            raise IndexError(f"word {word!r} not found in {table_name!r}")
        return result_df["vector"].values[0]
|