# brightly-ai / similarity_fast.py
import os
import re
import pickle
import pandas as pd
import requests
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary
from add_mappings_to_embeddings import run_mappings_to_embeddings

# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'
filename = model_name.replace('/', '-')
pickle_file_path = f'./embeddings/fast/{filename}.pkl'


class SimilarityFast:
    """Match an input food description against precomputed dictionary embeddings."""

    def __init__(self, db_cursor):
        self.db_cursor = db_cursor
        self.model = SentenceTransformer(model_name)

        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = self.db_cursor.fetchall()
        dictionary = [item[0] for item in dictionary]

        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)

    def preprocess_dictionary_word(self, text):
        """Normalize a dictionary description, e.g. "Chicken, breast, raw" -> "breast chicken"."""
        text = text.strip().lower()
        text = text.replace(", raw", "").replace(" raw", "")
        text = text.replace(", nfs", "").replace(" nfs", "")
        if ',' in text:
            # reverse the comma-separated parts: "chicken, breast" -> "breast chicken"
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(reversed(parts))
        text = text.strip()  # strip again in case there were multiple commas
        return text

    def load_dictionary_embeddings(self, dictionary):
        """Load cached dictionary embeddings from disk, generating them if the cache is missing."""
        if os.path.exists(pickle_file_path):
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)
        else:
            # url = "https://s3.amazonaws.com/www.brianweinreich.com/tmp/sentence-transformers-all-mpnet-base-v2.pkl"
            # response = requests.get(url)
            # with open(pickle_file_path, 'wb') as f:
            #     f.write(response.content)
            # with open(pickle_file_path, 'rb') as f:
            #     return pickle.load(f)

            # the prebuilt pickle download above is disabled; build the embeddings locally instead
            dictionary_embeddings = generate_embedded_dictionary(dictionary, self.model, self.preprocess_dictionary_word)

            with open(pickle_file_path, 'wb') as f:
                pickle.dump(dictionary_embeddings, f)

            new_entries = run_mappings_to_embeddings(self.model)

            # merge the new entries with the dictionary embeddings
            dictionary_embeddings.update(new_entries)

            return dictionary_embeddings
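
    # Note (assumption): the per-entry structure sketched below is inferred from how
    # the embeddings are consumed in calculate_similarity_score (val['v'] holds the
    # embedding vector, val['d'] the original dictionary description); the exact
    # shape is defined by generate_embedded_dictionary in utils.
    #
    #   dictionary_embeddings = {
    #       "breast chicken": {"v": <embedding vector>, "d": "Chicken, breast, raw"},
    #       ...
    #   }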

    def calculate_similarity_score(self, input_word_clean):
        input_embedding = generate_embedding(self.model, input_word_clean)
        similarities = []

        contains_cooked = 'cooked' in input_word_clean.lower()

        for key, val in self.dictionary_embeddings.items():
            similarity_score = cosine_similarity(input_embedding, val['v'])

            # slightly favor "cooked" dictionary entries when the input mentions "cooked"
            if contains_cooked and 'cooked' in key.lower():
                similarity_score *= 1.05

            similarities.append((key, val['d'], similarity_score))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])
        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)

        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word):
        if not isinstance(input_word, str) or not input_word:
            return None

        input_word_clean = clean_word(input_word)

        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean)

        mapping = {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'matching_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }

        return mapping
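

# A minimal usage sketch (assumption: get_connection() returns a DB connection
# whose cursor() method yields the cursor SimilarityFast expects; the actual
# setup lives in db.db_utils and may differ).
if __name__ == "__main__":
    conn = get_connection()
    cursor = conn.cursor()

    similarity = SimilarityFast(cursor)
    result = similarity.find_most_similar_word("organic baby spinach")
    print(result)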