# brightly-ai / similarity_fast.py
import os
import re
import pickle
import requests
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary

# Precomputed dictionary embeddings for the selected model are cached on disk as a pickle.
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'
filename = model_name.replace('/', '-')
pickle_file_path = f'./embeddings/fast/{filename}.pkl'


class SimilarityFast:
    def __init__(self, db_cursor):
        self.db_cursor = db_cursor
        self.model = SentenceTransformer(model_name)

        # Load every dictionary description from the database, then load (or fetch) its embeddings.
        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = self.db_cursor.fetchall()
        dictionary = [item[0] for item in dictionary]
        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)

    def preprocess_dictionary_word(self, text):
        # Normalize a dictionary description, e.g. "Beans, black, raw" -> "black beans".
        text = text.strip().lower()
        text = text.replace(", raw", "").replace(" raw", "")
        text = text.replace(", nfs", "").replace(" nfs", "")
        if ',' in text:
            # Reverse the comma-separated parts so the qualifiers precede the head noun.
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(reversed(parts))
        text = text.strip()  # strip again in case there were multiple commas
        return text

    def load_dictionary_embeddings(self, dictionary):
        # Prefer the locally cached pickle; otherwise download the precomputed embeddings from S3.
        if os.path.exists(pickle_file_path):
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)
        else:
            url = "https://s3.amazonaws.com/www.brianweinreich.com/tmp/sentence-transformers-all-mpnet-base-v2.pkl"
            response = requests.get(url)
            response.raise_for_status()  # fail loudly rather than caching a bad download
            with open(pickle_file_path, 'wb') as f:
                f.write(response.content)
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)

        # don't generate the embeddings on demand anymore
        # dictionary_embeddings = generate_embedded_dictionary(dictionary, self.model, self.preprocess_dictionary_word)
        # with open(pickle_file_path, 'wb') as f:
        #     pickle.dump(dictionary_embeddings, f)
        # return dictionary_embeddings

    def calculate_similarity_score(self, input_word_clean):
        input_embedding = generate_embedding(self.model, input_word_clean)
        similarities = []
        contains_cooked = 'cooked' in input_word_clean.lower()
        for key, val in self.dictionary_embeddings.items():
            similarity_score = cosine_similarity(input_embedding, val['v'])
            # Nudge dictionary entries that also mention "cooked" when the input does.
            if contains_cooked and 'cooked' in key.lower():
                similarity_score *= 1.05
            similarities.append((key, val['d'], similarity_score))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])
        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)
        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word):
        if not isinstance(input_word, str) or not input_word:
            return None

        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean)

        mapping = {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'matching_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }

        return mapping
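

# A minimal usage sketch, not part of the original file. It assumes the db_utils helpers
# imported above expose a cursor via get_connection() with no arguments and that the
# `dictionary` table is already populated; treat it as illustrative only.
if __name__ == "__main__":
    conn = get_connection()          # assumed to return a DB-API style connection
    cursor = conn.cursor()

    similarity = SimilarityFast(cursor)
    result = similarity.find_most_similar_word("black beans, cooked")
    if result is not None:
        print(result['matching_word'], result['similarity_score'], result['confidence_score'])

    cursor.close()
    conn.close()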