File size: 3,947 Bytes
9189e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ad617
 
 
9189e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e210682
 
 
9189e38
e210682
 
 
 
 
 
 
 
 
 
9189e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import pickle
import re
import urllib.request

import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

from db.db_utils import get_connection, initialize_db, store_mapping_to_db
from food_nonfood import classify_as_food_nonfood
from utils import generate_embedding, clean_word, cosine_similarity, calculate_confidence_and_similar_words_str, generate_embedded_dictionary

# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/all-mpnet-base-v2'
# Filesystem-safe variant of the model name ('/' is not valid in a filename).
filename = model_name.replace('/', '-')
# BUG FIX: the path previously contained a corrupted literal placeholder and
# never used `filename`. Deriving the cache path from the model name means
# switching models automatically points at a different cache file (and matches
# the pre-built pickle name downloaded from S3).
pickle_file_path = f'./embeddings/fast/{filename}.pkl'


class SimilarityFast:
    """Match free-text food descriptions against a dictionary using sentence embeddings.

    Dictionary embeddings are loaded from a local pickle cache
    (``pickle_file_path``); on a cache miss a pre-built pickle is downloaded
    from S3. Scoring is brute-force cosine similarity over every entry.
    """

    def __init__(self, db_cursor):
        """Load the model and the (cached) dictionary embeddings.

        :param db_cursor: open DB cursor exposing a ``dictionary`` table with a
            ``description`` column.
        """
        self.db_cursor = db_cursor
        self.model = SentenceTransformer(model_name)

        # Fetch every dictionary description; only needed by the (disabled)
        # on-demand embedding path, but kept to preserve the interface.
        self.db_cursor.execute("SELECT description FROM dictionary")
        dictionary = [row[0] for row in self.db_cursor.fetchall()]

        self.dictionary_embeddings = self.load_dictionary_embeddings(dictionary)

    def preprocess_dictionary_word(self, text):
        """Normalize a dictionary description prior to embedding.

        Lowercases, strips the ", raw"/" raw" and ", nfs"/" nfs" qualifiers,
        and reverses comma-separated parts so e.g. "beans, green" becomes
        "green beans".
        """
        text = text.strip().lower()
        text = text.replace(", raw", "").replace(" raw", "")
        text = text.replace(", nfs", "").replace(" nfs", "")
        if ',' in text:
            parts = [part.strip() for part in text.split(',')]
            text = ' '.join(reversed(parts))
        return text.strip()  # strip again in case there were multiple commas

    def load_dictionary_embeddings(self, dictionary):
        """Return the dictionary embeddings, downloading the pickle cache if absent.

        :param dictionary: list of descriptions; unused on the download path,
            retained for interface stability (it fed the removed on-demand
            ``generate_embedded_dictionary`` fallback).
        :returns: mapping of preprocessed key -> {'v': vector, 'd': description}
            (structure as consumed by ``calculate_similarity_score``).
        """
        if not os.path.exists(pickle_file_path):
            # BUG FIX: original called requests.get() but `requests` was never
            # imported (NameError on cache miss); use stdlib urllib instead.
            url = "https://s3.amazonaws.com/www.brianweinreich.com/tmp/sentence-transformers-all-mpnet-base-v2.pkl"
            # Ensure the cache directory exists before writing.
            os.makedirs(os.path.dirname(pickle_file_path), exist_ok=True)
            with urllib.request.urlopen(url) as response:
                payload = response.read()
            with open(pickle_file_path, 'wb') as f:
                f.write(payload)

        # NOTE: pickle is only safe here because the file comes from our own
        # S3 bucket / local cache — never unpickle untrusted data.
        with open(pickle_file_path, 'rb') as f:
            return pickle.load(f)

    def calculate_similarity_score(self, input_word_clean):
        """Score ``input_word_clean`` against every dictionary embedding.

        Entries sharing the "cooked" qualifier with the input get a 5% boost.

        :returns: (most_similar_word, dictionary_word, highest_score,
            confidence_score, similar_words_str)
        :raises ValueError: if the embedding dictionary is empty (max() of an
            empty sequence).
        """
        input_embedding = generate_embedding(self.model, input_word_clean)
        similarities = []

        contains_cooked = 'cooked' in input_word_clean.lower()

        for key, val in self.dictionary_embeddings.items():
            similarity_score = cosine_similarity(input_embedding, val['v'])
            # Prefer "cooked" entries when the input itself says "cooked".
            if contains_cooked and 'cooked' in key.lower():
                similarity_score *= 1.05
            similarities.append((key, val['d'], similarity_score))

        most_similar_word, dictionary_word, highest_score = max(similarities, key=lambda x: x[2])

        confidence_score, similar_words_str = calculate_confidence_and_similar_words_str(similarities, highest_score)

        return most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str

    def find_most_similar_word(self, input_word):
        """Map ``input_word`` to its best dictionary match.

        :param input_word: raw description; non-string or empty input yields None.
        :returns: dict with input/cleaned/matched words and scoring details,
            or None when the input is unusable.
        """
        if not isinstance(input_word, str) or not input_word:
            return None

        input_word_clean = clean_word(input_word)
        most_similar_word, dictionary_word, highest_score, confidence_score, similar_words_str = self.calculate_similarity_score(input_word_clean)

        return {
            'input_word': input_word,
            'cleaned_word': input_word_clean,
            'matching_word': most_similar_word,
            'dictionary_word': dictionary_word,
            'similarity_score': highest_score,
            'confidence_score': confidence_score,
            'similar_words': similar_words_str,
        }