brightly-ai / similarity_fast.py
beweinreich's picture
add in torch
9ca54a4
raw
history blame
No virus
5.61 kB
import os
import re
import torch
import pickle
import pandas as pd
import requests
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary
from add_mappings_to_embeddings import run_mappings_to_embeddings
# Sentence-transformers checkpoint used for every embedding in this module.
# Previously tried alternative: 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'

# The model name contains a '/', so flatten it before using it as a file name.
filename = '-'.join(model_name.split('/'))

# Pickle caches: one for the full dictionary, one for the category subset.
all_pickle_file_path = './embeddings/fast/' + filename + '.pkl'
category_pickle_file_path = './embeddings/fast/' + filename + '-categories.pkl'
class SimilarityFast:
    """Find the dictionary entry most similar to a free-text input word.

    Descriptions are read from the ``dictionary`` table, embedded with a
    sentence-transformers model and cached on disk as pickles. Lookups compare
    the embedding of the cleaned input against every cached entry and apply a
    few multiplicative adjustments for preparation words (cooked, frozen,
    canned, raw).
    """

    # Preparation words that influence scoring. They are matched as whole
    # words: a plain substring test would also fire on e.g. "strawberry" or
    # "prawn" (both contain "raw") and on "uncooked" (contains "cooked").
    _PREPARATION_PATTERNS = (
        ('cooked', re.compile(r'\bcooked\b')),
        ('frozen', re.compile(r'\bfrozen\b')),
        ('canned', re.compile(r'\bcanned\b')),
        ('raw', re.compile(r'\braw\b')),
    )
    # Applied once per preparation word present in both the key and the input.
    _SHARED_STATE_BOOST = 1.07
    # Applied once per cooked/frozen/canned word present in the key when the
    # key and the input share no preparation word at all.
    _UNSHARED_STATE_PENALTY = 0.95
    # Applied when the input says "raw" and nothing was boosted.
    _RAW_INPUT_BOOST = 1.02

    def __init__(self, db_cursor):
        """Load the model and the dictionary/category embeddings.

        Args:
            db_cursor: Cursor exposing ``execute`` and ``fetchall``; it is used
                to read the ``description`` column of the ``dictionary`` table.
        """
        self.db_cursor = db_cursor

        device_available = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_available)
        print(f"Using device: {device_available}")

        self.model = SentenceTransformer(model_name).to(self.device)

        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = [row[0] for row in self.db_cursor.fetchall()]

        # Rows with fdc_id >= 9999000 are the ones treated as categories.
        self.db_cursor.execute("SELECT description FROM dictionary where fdc_id >= 9999000")
        categories = [row[0] for row in self.db_cursor.fetchall()]

        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary, all_pickle_file_path)
        self.category_embeddings = self.load_dictionary_embeddings(categories, category_pickle_file_path)

    def preprocess_dictionary_word(self, text):
        """Normalise a dictionary description before it is embedded.

        Lower-cases the text, removes the "raw" and "nfs" markers, and turns
        comma-separated descriptions around so that the qualifiers come first
        ("beans, green" becomes "green beans").

        Args:
            text: Dictionary description.

        Returns:
            The normalised description.
        """
        text = text.strip().lower()
        # NOTE(review): these are substring replacements, so " raw" is also
        # removed from the front of a longer word (e.g. " rawhide"). Left as-is
        # because changing it would alter the text embedded for new caches.
        for marker in ("raw", "nfs"):
            text = text.replace(f", {marker}", "").replace(f" {marker}", "")

        if ',' in text:
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(reversed(parts))

        # Strip again: empty parts (repeated or trailing commas) and removed
        # markers can leave whitespace at the edges.
        return text.strip()

    def load_dictionary_embeddings(self, data, file_path):
        """Return the embeddings for ``data``, using ``file_path`` as a cache.

        If the cache file exists it is loaded and returned unchanged.
        Otherwise the embeddings are generated, written to the cache, and then
        extended with the entries from ``run_mappings_to_embeddings``.

        Args:
            data: Descriptions to embed when no cache file exists.
            file_path: Location of the pickle cache.

        Returns:
            The embeddings mapping (unpickled, or freshly generated).
        """
        if os.path.exists(file_path):
            # pickle can execute arbitrary code while loading; the cache file
            # must only ever come from a trusted source (i.e. this program).
            with open(file_path, 'rb') as f:
                return pickle.load(f)

        dictionary_embeddings = generate_embedded_dictionary(
            data, self.model, self.preprocess_dictionary_word
        )

        # Create the cache directory on first use so the write below does not
        # fail on a fresh checkout.
        cache_dir = os.path.dirname(file_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(file_path, 'wb') as f:
            pickle.dump(dictionary_embeddings, f)

        # NOTE(review): the mapping entries are merged after the dump, so this
        # method does not write them to the cache file. Confirm whether
        # run_mappings_to_embeddings persists them itself.
        new_entries = run_mappings_to_embeddings(self.model)
        dictionary_embeddings.update(new_entries)
        return dictionary_embeddings

    @classmethod
    def _preparation_states(cls, text):
        """Return the preparation words that occur as whole words in ``text``."""
        return frozenset(
            word for word, pattern in cls._PREPARATION_PATTERNS if pattern.search(text)
        )

    @classmethod
    def _adjust_for_preparation(cls, similarity_score, key, input_states):
        """Scale a raw similarity by how key and input agree on preparation.

        Args:
            similarity_score: Unadjusted similarity between input and ``key``.
            key: Key of the embeddings entry being scored.
            input_states: Result of ``_preparation_states`` for the input.

        Returns:
            The adjusted (not yet clamped) similarity score.
        """
        key_states = cls._preparation_states(key.lower())
        shared = key_states & input_states

        if shared:
            # One boost per preparation word both sides agree on.
            for _ in shared:
                similarity_score *= cls._SHARED_STATE_BOOST
            return similarity_score

        # Nothing in common: penalise keys that state a preparation ("raw" in
        # the key is never penalised) ...
        for _ in key_states - {'raw'}:
            similarity_score *= cls._UNSHARED_STATE_PENALTY
        # ... and give a slight boost when the input asked for raw.
        if 'raw' in input_states:
            similarity_score *= cls._RAW_INPUT_BOOST
        return similarity_score

    def calculate_similarity_score(self, input_word_clean, embeddings):
        """Score the input against every entry and pick the best match.

        Args:
            input_word_clean: Cleaned input text.
            embeddings: Iterable of ``(key, value)`` pairs where ``value['v']``
                is the entry's embedding and ``value['d']`` is returned as the
                dictionary word.

        Returns:
            Tuple ``(most_similar_word, dictionary_word, highest_score,
            confidence_score, similar_words_str)``; ``highest_score`` is
            clamped to the range 0..1.

        Raises:
            ValueError: If ``embeddings`` yields no entries.
        """
        input_embedding = generate_embedding(self.model, input_word_clean)
        # The input is the same for every entry, so inspect it only once.
        input_states = self._preparation_states(input_word_clean)

        similarities = []
        for key, val in embeddings:
            raw_score = cosine_similarity(input_embedding, val['v'])
            adjusted = self._adjust_for_preparation(raw_score, key, input_states)
            similarities.append((key, val['d'], adjusted))

        if not similarities:
            raise ValueError("cannot calculate similarity: no embeddings to compare against")

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])

        # The multipliers can push a score above 1, and cosine similarity can
        # be negative; keep the reported score within 0..1.
        highest_score = max(0, min(1, highest_score))

        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(
            similarities, highest_score
        )
        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word, only_categories=False):
        """Map ``input_word`` to its most similar dictionary entry.

        Args:
            input_word: Free-text word or phrase to look up.
            only_categories: When true, search only the category embeddings.

        Returns:
            A dict describing the best match, or ``None`` when ``input_word``
            is not a non-empty string.
        """
        if not isinstance(input_word, str) or not input_word:
            return None

        source = self.category_embeddings if only_categories else self.dictionary_embeddings
        input_word_clean = clean_word(input_word)
        (
            most_similar_word,
            dictionary_word,
            highest_score,
            confidence_score,
            similar_words_str,
        ) = self.calculate_similarity_score(input_word_clean, source.items())

        return {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'most_similar_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }