import os
import pickle
import re

import pandas as pd
import requests
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from add_mappings_to_embeddings import run_mappings_to_embeddings
from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary
# Alternative (smaller, faster, less accurate) model:
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'

# Cache file for the precomputed dictionary embeddings. Derived from the
# model name (slashes are not filename-safe) so that switching models
# naturally invalidates the cache. Previously `filename` was computed but
# unused and the path was a broken literal.
filename = model_name.replace('/', '-')
pickle_file_path = f'./embeddings/fast/{filename}.pkl'
class SimilarityFast:
    """Map free-text food descriptions to dictionary entries via embeddings.

    On construction, every description in the ``dictionary`` DB table is
    embedded with a SentenceTransformer (loaded from a pickle cache when
    available). Incoming words are then scored against all entries with
    cosine similarity plus small keyword-based boosts/penalties.
    """

    # Preparation-state keywords: agreement between key and input boosts the
    # score; a key-only occurrence is penalized (except 'raw', which instead
    # gets a small input-side boost when nothing else matched).
    _STATE_KEYWORDS = ('cooked', 'frozen', 'canned', 'raw')
    _PENALIZED_KEYWORDS = ('cooked', 'frozen', 'canned')
    _MATCH_BOOST = 1.07
    _MISMATCH_PENALTY = 0.95
    _RAW_INPUT_BOOST = 1.02

    def __init__(self, db_cursor):
        """Load the model and precompute embeddings for all dictionary rows.

        :param db_cursor: open DB cursor with access to the ``dictionary`` table.
        """
        self.db_cursor = db_cursor
        self.model = SentenceTransformer(model_name)

        self.db_cursor.execute("SELECT description FROM dictionary")
        rows = self.db_cursor.fetchall()
        dictionary = [row[0] for row in rows]

        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)

    def preprocess_dictionary_word(self, text):
        """Normalize a dictionary description before embedding.

        Lowercases, strips ``raw``/``nfs`` qualifiers, and reverses
        comma-separated parts so "chicken, breast" becomes "breast chicken"
        (most-specific word first).
        """
        text = text.strip().lower()
        text = text.replace(", raw", "").replace(" raw", "")
        text = text.replace(", nfs", "").replace(" nfs", "")
        if ',' in text:
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(reversed(parts))
            text = text.strip()  # strip again in case there were multiple commas
        return text

    def load_dictionary_embeddings(self, dictionary):
        """Return {key: {'v': vector, 'd': description}} for the dictionary.

        Loads from the pickle cache when present; otherwise generates the
        embeddings, caches them, and merges extra entries from
        ``run_mappings_to_embeddings``.
        """
        if os.path.exists(pickle_file_path):
            # NOTE(review): on this cache-hit path the extra entries from
            # run_mappings_to_embeddings are never merged (and the dump below
            # happens BEFORE the merge), so first and subsequent runs can see
            # different dictionaries — confirm whether that is intended.
            with open(pickle_file_path, 'rb') as f:
                return pickle.load(f)

        dictionary_embeddings = generate_embedded_dictionary(
            dictionary, self.model, self.preprocess_dictionary_word
        )
        with open(pickle_file_path, 'wb') as f:
            pickle.dump(dictionary_embeddings, f)

        # Merge the manually-curated mapping entries with the generated ones.
        new_entries = run_mappings_to_embeddings(self.model)
        dictionary_embeddings.update(new_entries)
        return dictionary_embeddings

    def calculate_similarity_score(self, input_word_clean):
        """Score ``input_word_clean`` against every dictionary embedding.

        :returns: (most_similar_word, dictionary_word, highest_score,
                   confidence_score, similar_words_str)
        """
        input_embedding = generate_embedding(self.model, input_word_clean)

        similarities = []
        for key, val in self.dictionary_embeddings.items():
            similarity_score = cosine_similarity(input_embedding, val['v'])
            key_lower = key.lower()  # hoisted: was recomputed per keyword check

            # NOTE(review): these are substring tests, so e.g. 'raw' matches
            # 'strawberries' — preserved as-is; confirm before tightening to
            # word-boundary matching, as it would change scores.
            adjustment_made = False
            for keyword in self._STATE_KEYWORDS:
                if keyword in key_lower and keyword in input_word_clean:
                    adjustment_made = True
                    similarity_score *= self._MATCH_BOOST

            if not adjustment_made:
                for keyword in self._PENALIZED_KEYWORDS:
                    if keyword in key_lower:
                        similarity_score *= self._MISMATCH_PENALTY
                # No keyword agreement at all: slight preference for 'raw' inputs.
                if 'raw' in input_word_clean:
                    similarity_score *= self._RAW_INPUT_BOOST

            similarities.append((key, val['d'], similarity_score))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])
        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)
        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word):
        """Return a mapping dict for ``input_word``, or None if it is not a
        non-empty string."""
        if not isinstance(input_word, str) or not input_word:
            return None

        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean)

        return {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'matching_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }