# NOTE: stray paste artifact ("Spaces: / Paused / Paused") removed — not part of the program.
import os
import pickle
import re
import time

import pandas as pd
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from db.db_utils import get_connection, initialize_db, get_mapping_from_db, store_mapping_to_db
from utils import generate_embedding, cosine_similarity, clean_word, calculate_confidence_and_similar_words_str
# Alternative, smaller/faster model kept for reference:
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'

# Filesystem-safe variant of the model name ('/' is not a valid path character).
filename = model_name.replace('/', '-')

# BUG FIX: the path previously contained the literal string '(unknown)' (a broken
# template substitution) and never used `filename`, so every model shared one wrong
# cache file. Key the embedding cache by the sanitized model name instead.
pickle_file_path = f'./embeddings/slow/{filename}.pkl'
class SimilaritySlow:
    """Brute-force semantic matcher: maps a free-text word to the closest
    dictionary description by cosine similarity over sentence-transformer
    embeddings of every preprocessed dictionary variant.

    "Slow" because each lookup scans all cached embeddings linearly.
    """

    def __init__(self, db_cursor, db_conn):
        """Load the model, fetch all dictionary descriptions from the DB,
        and build (or load from cache) their embeddings."""
        self.db_cursor = db_cursor
        self.db_conn = db_conn
        self.model = SentenceTransformer(model_name)
        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = [row[0] for row in self.db_cursor.fetchall()]
        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)

    def preprocess_dictionary_word(self, text):
        """Return four textual variants of a dictionary entry so that word-order
        and comma differences in user input can still match.

        Strips ", raw"/" raw"/", nfs"/" nfs" qualifiers first.
        """
        text = text.strip().lower().replace(", raw", "").replace(" raw", "").replace(", nfs", "").replace(" nfs", "")
        words = text.split()
        return [
            ' '.join(reversed(words)).replace(',', ''),  # word order reversed, commas stripped
            ', '.join(reversed(text.split(', '))),       # comma-separated segments reversed
            text,                                        # cleaned original
            ' '.join(words).replace(',', '')             # original order, commas stripped
        ]

    def load_dictionary_embeddings(self, dictionary):
        """Return {variant: {'v': embedding, 'd': original description}}.

        Uses a pickle cache on disk; generates and caches embeddings on a miss.
        """
        if os.path.exists(pickle_file_path):
            # NOTE: pickle is only acceptable here because the cache file is
            # produced locally by this same process — never load untrusted pickles.
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)
        dictionary_embeddings = {}
        for dictionary_word in tqdm(dictionary, desc="Generating embeddings for dictionary words"):
            for preprocessed_word in self.preprocess_dictionary_word(dictionary_word):
                dictionary_embeddings[preprocessed_word] = {
                    'v': generate_embedding(self.model, preprocessed_word),
                    'd': dictionary_word
                }
        # BUG FIX: ensure the cache directory exists before writing — otherwise
        # open(..., 'wb') raises FileNotFoundError on a fresh checkout.
        os.makedirs(os.path.dirname(pickle_file_path), exist_ok=True)
        with open(pickle_file_path, 'wb') as f:
            pickle.dump(dictionary_embeddings, f)
        return dictionary_embeddings

    def calculate_similarity_score(self, input_word_clean):
        """Scan every cached variant embedding and return the best match as
        (variant, dictionary_word, score, confidence, similar_words_str).

        Raises ValueError (from max) if the embedding cache is empty.
        """
        input_embedding = generate_embedding(self.model, input_word_clean)
        similarities = [
            (key, val['d'], cosine_similarity(input_embedding, val['v']))
            for key, val in self.dictionary_embeddings.items()
        ]
        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])
        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)
        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word):
        """Return a mapping dict describing the closest dictionary match for
        `input_word`, or None when the input is not a non-empty string
        (e.g. NaN floats coming out of pandas)."""
        if not isinstance(input_word, str) or not input_word:
            return None
        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = \
            self.calculate_similarity_score(input_word_clean)
        return {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'matching_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }