File size: 5,404 Bytes
9189e38
 
 
 
88cdfd7
9189e38
 
 
 
a1c159e
9189e38
 
 
 
71df5fb
 
9189e38
 
 
 
 
 
 
22ad617
 
 
9189e38
71df5fb
 
 
 
 
 
9189e38
 
 
 
 
 
 
 
 
 
 
71df5fb
 
 
9189e38
 
71df5fb
 
b053f2d
a1c159e
 
 
 
 
b053f2d
9189e38
71df5fb
9189e38
 
 
71df5fb
9189e38
7e5979c
 
4afe3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9189e38
 
 
58b533d
 
 
 
9189e38
 
 
 
71df5fb
9189e38
 
 
71df5fb
 
9189e38
71df5fb
9189e38
 
 
 
e5de092
9189e38
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import re
import pickle
import pandas as pd
import requests
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary
from add_mappings_to_embeddings import run_mappings_to_embeddings

# Name of the sentence-transformers model that SimilarityFast loads.
# Alternative model kept for reference:
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'
# File-system friendly form of the model name ('/' replaced by '-'), so each
# model gets its own cache files.
filename = model_name.replace('/', '-')
# Pickle cache for the embeddings of every dictionary description.
all_pickle_file_path = f'./embeddings/fast/{filename}.pkl'
# Pickle cache for the embeddings of the category-only subset.
category_pickle_file_path = f'./embeddings/fast/{filename}-categories.pkl'


class SimilarityFast:
    """Find the dictionary entry most similar to an input word via embeddings.

    On construction the dictionary descriptions are read through the supplied
    DB cursor and their embeddings are loaded from a pickle cache, or generated
    and cached when no cache file exists yet.
    """

    # Preparation-state words, matched as WHOLE words only. Plain substring
    # tests would treat "strawberry"/"prawn" as containing "raw" and
    # "uncooked" as containing "cooked".
    _PREPARATION_RE = re.compile(r'\b(cooked|frozen|canned|raw)\b')

    # " raw" / " nfs" qualifiers (optionally preceded by a comma) that are
    # dropped from dictionary descriptions; the trailing \b keeps words such
    # as "rawhide" intact.
    _QUALIFIER_RE = re.compile(r',? (?:raw|nfs)\b')

    def __init__(self, db_cursor):
        """Load the model and the dictionary/category embeddings.

        Args:
            db_cursor: DB-API style cursor used to read the ``dictionary`` table.
        """
        self.db_cursor = db_cursor
        self.model = SentenceTransformer(model_name)

        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = [row[0] for row in self.db_cursor.fetchall()]

        # Rows with fdc_id >= 9999000 are the ones treated as categories here.
        self.db_cursor.execute("SELECT description FROM dictionary where fdc_id >= 9999000")
        categories = [row[0] for row in self.db_cursor.fetchall()]

        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary, all_pickle_file_path)
        self.category_embeddings = self.load_dictionary_embeddings(categories, category_pickle_file_path)

    def preprocess_dictionary_word(self, text):
        """Normalize a dictionary description before it is embedded.

        Lower-cases the text, removes the "raw" and "nfs" qualifiers, and
        reverses comma-separated parts so that "beans, black" becomes
        "black beans".

        Args:
            text: Raw dictionary description.

        Returns:
            The normalized description.
        """
        text = text.strip().lower()
        text = self._QUALIFIER_RE.sub('', text)
        if ',' in text:
            # Drop empty parts (e.g. from "a,, b" or a trailing comma) so the
            # joined result never contains doubled or dangling spaces.
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(part for part in reversed(parts) if part)
        return text.strip()

    def load_dictionary_embeddings(self, data, file_path):
        """Return the embeddings for *data*, using *file_path* as a pickle cache.

        Args:
            data: Descriptions to embed when no cache file exists.
            file_path: Location of the pickle cache.

        Returns:
            The embeddings mapping, either loaded from the cache or freshly
            generated (and then merged with the mapping-derived entries).
        """
        if os.path.exists(file_path):
            # SECURITY: pickle.load executes arbitrary code from the file; the
            # cache must only ever be a file this application wrote itself.
            with open(file_path, 'rb') as f:
                return pickle.load(f)

        dictionary_embeddings = generate_embedded_dictionary(data, self.model, self.preprocess_dictionary_word)

        # Make sure the cache directory exists so the first run does not fail.
        cache_dir = os.path.dirname(file_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(file_path, 'wb') as f:
            pickle.dump(dictionary_embeddings, f)

        # NOTE(review): the entries merged below are not part of the file
        # written above, so a later cached load only contains them if
        # run_mappings_to_embeddings persists them itself -- confirm.
        new_entries = run_mappings_to_embeddings(self.model)
        # merge the new entries with the dictionary embeddings
        dictionary_embeddings.update(new_entries)

        return dictionary_embeddings

    @classmethod
    def _preparation_terms_in(cls, text):
        """Return the set of preparation terms that occur in *text* as whole words."""
        return set(cls._PREPARATION_RE.findall(text))

    @staticmethod
    def _adjust_for_preparation(score, key_terms, input_terms):
        """Boost or penalize *score* based on preparation terms.

        Each term shared by the candidate key and the input multiplies the
        score by 1.07. When nothing is shared, each of cooked/frozen/canned
        present in the key multiplies it by 0.95, and an input mentioning
        "raw" multiplies it by 1.02.
        """
        shared = [term for term in ('cooked', 'frozen', 'canned', 'raw')
                  if term in key_terms and term in input_terms]
        if shared:
            for _ in shared:
                score *= 1.07
            return score

        for term in ('cooked', 'frozen', 'canned'):
            if term in key_terms:
                score *= 0.95

        # if we haven't made any adjustments, we can make a slight boost for raw
        if 'raw' in input_terms:
            score *= 1.02
        return score

    def calculate_similarity_score(self, input_word_clean, embeddings):
        """Score *input_word_clean* against every candidate and pick the best.

        Args:
            input_word_clean: Already cleaned input word.
            embeddings: Iterable of ``(key, value)`` pairs where ``value['v']``
                is the stored embedding and ``value['d']`` the value reported
                as the dictionary word.

        Returns:
            Tuple ``(most_similar_word, dictionary_word, highest_score,
            confidence_score, similar_words_str)``; ``highest_score`` is
            clamped to the range [0, 1].

        Raises:
            ValueError: If *embeddings* is empty.
        """
        candidates = list(embeddings)
        if not candidates:
            # Fail early with a clear message instead of max() on an empty list,
            # and before paying for the input embedding.
            raise ValueError("no embeddings available to compare against")

        input_embedding = generate_embedding(self.model, input_word_clean)
        # The input's terms do not change per candidate, so detect them once.
        input_terms = self._preparation_terms_in(input_word_clean)

        similarities = []
        for key, val in candidates:
            similarity_score = cosine_similarity(input_embedding, val['v'])
            key_terms = self._preparation_terms_in(key.lower())
            similarity_score = self._adjust_for_preparation(similarity_score, key_terms, input_terms)
            similarities.append((key, val['d'], similarity_score))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])

        # ensure highest_score is not negative nor greater than 1
        highest_score = max(0, min(1, highest_score))

        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)

        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word, only_categories=False):
        """Map *input_word* to its most similar dictionary (or category) entry.

        Args:
            input_word: Word to look up.
            only_categories: When true, search only the category embeddings.

        Returns:
            A dict describing the match, or ``None`` when *input_word* is not
            a non-empty string.
        """
        if not isinstance(input_word, str) or not input_word:
            return None

        embeddings = self.category_embeddings.items() if only_categories else self.dictionary_embeddings.items()

        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean, embeddings)

        mapping = {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'most_similar_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }

        return mapping