# brightly-ai/similarity_slow.py
import time
import os
import pickle
import re
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, get_mapping_from_db, store_mapping_to_db
from utils import generate_embedding, cosine_similarity, clean_word, calculate_confidence_and_similar_words_str

# all-mpnet-base-v2 is slower than all-MiniLM-L6-v2 but generally scores higher
# on sentence-similarity benchmarks, hence the "slow" variant of this matcher.
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'
filename = model_name.replace('/', '-')
pickle_file_path = f'./embeddings/slow/{filename}.pkl'


class SimilaritySlow:
    """Maps an input word or phrase to the most similar dictionary description
    using sentence-transformer embeddings and cosine similarity."""

    def __init__(self, db_cursor, db_conn):
        self.db_cursor = db_cursor
        self.db_conn = db_conn
        self.model = SentenceTransformer(model_name)

        # Pull every description from the dictionary table and build (or load)
        # its embedding index.
        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = self.db_cursor.fetchall()
        dictionary = [item[0] for item in dictionary]
        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)
    def preprocess_dictionary_word(self, text):
        # Strip "raw"/"nfs" qualifiers, then return several variants of the phrase
        # (original, word order reversed, commas removed) so that either ordering
        # of the input can match a cached embedding. For example,
        # "Spinach, baby, raw" yields ["baby spinach", "baby, spinach",
        # "spinach, baby", "spinach baby"].
        text = text.strip().lower().replace(", raw", "").replace(" raw", "").replace(", nfs", "").replace(" nfs", "")
        words = text.split()
        return [
            ' '.join(reversed(words)).replace(',', ''),
            ', '.join(reversed(text.split(', '))),
            text,
            ' '.join(words).replace(',', '')
        ]
    def load_dictionary_embeddings(self, dictionary):
        # Reuse the pickled embedding cache if it exists; otherwise embed every
        # preprocessed variant of every dictionary word and cache the result.
        if os.path.exists(pickle_file_path):
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)

        dictionary_embeddings = {}
        for dictionary_word in tqdm(dictionary, desc="Generating embeddings for dictionary words"):
            preprocessed_words = self.preprocess_dictionary_word(dictionary_word)
            for preprocessed_word in preprocessed_words:
                dictionary_embeddings[preprocessed_word] = {
                    'v': generate_embedding(self.model, preprocessed_word),
                    'd': dictionary_word
                }

        # Make sure the cache directory exists before writing the pickle.
        os.makedirs(os.path.dirname(pickle_file_path), exist_ok=True)
        with open(pickle_file_path, 'wb') as f:
            pickle.dump(dictionary_embeddings, f)
        return dictionary_embeddings
    def calculate_similarity_score(self, input_word_clean):
        # Score the input against every cached dictionary variant, then derive a
        # confidence score and a summary of the closest alternatives.
        input_embedding = generate_embedding(self.model, input_word_clean)
        similarities = []
        for key, val in self.dictionary_embeddings.items():
            similarities.append((key, val['d'], cosine_similarity(input_embedding, val['v'])))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])
        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)
        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str
    def find_most_similar_word(self, input_word):
        if not isinstance(input_word, str) or not input_word:
            return None

        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean)
        mapping = {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }
        return mapping
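

# Example usage (not part of the original module): a minimal sketch of how
# SimilaritySlow might be driven end to end. It assumes db.db_utils.get_connection()
# returns a DB-API style connection to a database whose dictionary table is already
# populated; the exact shape of that helper is an assumption, not documented here.
if __name__ == '__main__':
    conn = get_connection()   # assumed to return an open DB-API connection
    cursor = conn.cursor()

    matcher = SimilaritySlow(cursor, conn)
    result = matcher.find_most_similar_word('baby spinach')
    if result is not None:
        print(result['dictionary_word'], result['similarity_score'])

    cursor.close()
    conn.close()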