|
import duckdb |
|
import lancedb |
|
|
|
from src.setting import AVAILABLE_WORDS |
|
|
|
|
|
class VectorDatabaseHandler:
    """Score a guessed word against a target word using a LanceDB embeddings table.

    The table is expected to expose ``word`` and ``vector`` columns (both are
    selected by ``QUERY_TEMPLATE``); vector-search results are expected to carry
    LanceDB's ``_distance`` column.
    """

    # The word is bound as a query parameter (``?``) rather than spliced into the
    # SQL text, so words containing quotes (e.g. "don't") cannot break or alter
    # the query. Only the table identifier is formatted in, because identifiers
    # cannot be bound as parameters.
    QUERY_TEMPLATE = "SELECT word, vector FROM {table_name} WHERE word = ?"

    def __init__(self, db_path: str, table_name: str, metrics_cfg: dict):
        """Open the LanceDB database and the embeddings table.

        Args:
            db_path: Location passed to ``lancedb.connect``.
            table_name: Name of the embeddings table to open.
            metrics_cfg: Config holding the search metric under ``metric``.
                Either a plain dict (read as ``cfg["metric"]``) or an
                attribute-style config object (read as ``cfg.metric``).
        """
        db = lancedb.connect(db_path)
        self.metrics_cfg = metrics_cfg
        self.embeddings_tbl = db.open_table(table_name)

    def _metric_name(self):
        """Return the search metric configured in ``self.metrics_cfg``.

        Plain dicts are read by key; any other object keeps the original
        attribute access (``cfg.metric``).
        """
        cfg = self.metrics_cfg
        if isinstance(cfg, dict):
            return cfg["metric"]
        return cfg.metric

    def __call__(self, guessed_word: str, supposed_word: str) -> dict:
        """Compare ``guessed_word`` with ``supposed_word``.

        Runs a vector search around the guessed word's embedding, limited to
        ``len(AVAILABLE_WORDS)`` results, using the configured metric.

        Returns:
            dict with:
              - ``score``: distance from the guessed word to ``supposed_word``.
              - ``rating``: number of search results strictly closer to the
                guessed word than ``supposed_word``. NOTE(review): for a wrong
                guess this presumably counts the guessed word itself (distance
                ~0); confirm that is intended.
              - ``percentage``: ``100 - rating / number_of_results * 100``.
              - ``closest_word``: the first result that is not the guessed word
                when ``rating`` is non-zero, otherwise ``supposed_word``.

        Raises:
            IndexError: If either word is missing from the table or search results.
        """
        arrow_table = self.embeddings_tbl.to_arrow()
        # Hand the Arrow table over explicitly instead of relying on DuckDB to find
        # a variable named "arrow_table" in this (caller) frame: newer DuckDB
        # versions by default only search the frame that issues the query.
        word_embedding = self.get_word_vector(guessed_word, "arrow_table", table=arrow_table)

        df_emb = (
            self.embeddings_tbl.search(word_embedding)
            .metric(self._metric_name())
            .limit(len(AVAILABLE_WORDS))
            .to_df()
        )

        supposed_rows = df_emb[df_emb["word"] == supposed_word]
        if supposed_rows.empty:
            raise IndexError(f"word {supposed_word!r} not found in search results")
        distance = supposed_rows.iloc[0]["_distance"]

        words_between_count = len(df_emb[df_emb["_distance"] < distance])

        # Relies on search results being ordered nearest-first, so the first row
        # that is not the guessed word is its closest neighbour.
        if words_between_count:
            closest_word = df_emb[df_emb["word"] != guessed_word].iloc[0]["word"]
        else:
            closest_word = supposed_word

        return {
            "score": distance,
            "rating": words_between_count,
            "percentage": 100 - words_between_count / len(df_emb) * 100,
            "closest_word": closest_word,
        }

    def get_word_vector(self, word: str, table_name: str, table=None):
        """Return the embedding vector stored for ``word``.

        Args:
            word: Word to look up; bound as a SQL parameter, never formatted in.
            table_name: Table identifier formatted (unquoted) into
                ``QUERY_TEMPLATE``.
            table: Optional object to query (e.g. a ``pyarrow.Table``). When
                given, it is registered under ``table_name`` on a private
                in-memory DuckDB connection. When omitted, ``table_name`` is
                resolved by DuckDB's default connection as before (a catalog
                table, or a Python variable found by replacement scan).

        Returns:
            The ``vector`` value of the first row whose ``word`` matches.

        Raises:
            IndexError: If no row matches ``word``.
        """
        sql = self.QUERY_TEMPLATE.format(table_name=table_name)
        if table is None:
            result = duckdb.execute(sql, [word]).fetchdf()
        else:
            con = duckdb.connect()
            try:
                con.register(table_name, table)
                result = con.execute(sql, [word]).fetchdf()
            finally:
                con.close()

        vectors = result["vector"].values
        if len(vectors) == 0:
            raise IndexError(f"word {word!r} not found in table {table_name!r}")
        return vectors[0]
|
|