# Scraped file-page header (converted to comments so this module parses):
# author: beweinreich; commit 9189e38 ("first"); size 2.79 kB
import pandas as pd
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import re
# Pre-trained BERT checkpoint; the tokenizer and the encoder must come from
# the same checkpoint so token ids line up with the model's vocabulary.
_BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
model = BertModel.from_pretrained(_BERT_CHECKPOINT)
# Load the reference dictionary from CSV; only its 'description' column is
# used for matching.
# NOTE: this relative path is resolved against the current working
# directory, not against the location of this file.
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as NaN (a float). Drop them and coerce the rest to
# str so downstream string operations (e.g. .lower()) only ever see strings.
dictionary = df['description'].dropna().astype(str).tolist()
def compute_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between the first rows of two embeddings.

    Both arguments are handed to sklearn's ``cosine_similarity``, which
    expects 2-D array-likes of shape (n_samples, n_features). It builds the
    full pairwise similarity matrix; only entry [0][0] (row 0 of
    ``embedding1`` vs. row 0 of ``embedding2``) is returned.
    """
    similarity_matrix = cosine_similarity(embedding1, embedding2)
    first_row = similarity_matrix[0]
    return first_row[0]
def get_bert_embedding(text):
    """Return a mean-pooled BERT embedding for ``text`` as a NumPy array.

    Args:
        text: A string (or a list of strings), passed directly to the
            module-level ``tokenizer`` with truncation and padding enabled.

    Returns:
        ``numpy.ndarray`` of shape (batch, hidden_size), i.e. (1, hidden_size)
        for a single string. Each row is the average of the model's
        ``last_hidden_state`` over that input's non-padding tokens.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: disable autograd rather than building a graph and
    # detaching it afterwards.
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Weight the mean by the attention mask so padding tokens (only present
    # when a batch of different-length texts is passed) do not dilute it.
    # For a single string the mask is all ones, i.e. a plain mean over tokens.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()
def match_word(input_word, dictionary, *, threshold=0.7):
    """Return the dictionary entry most similar to ``input_word``, or None.

    Candidates are first narrowed to entries containing, as a case-insensitive
    substring, at least one alphanumeric token of ``input_word``. Each
    remaining candidate is then scored by cosine similarity between the BERT
    embeddings of the input and of the entry.

    Args:
        input_word: Free-text query, e.g. "Bananas (12 lbs)".
        dictionary: Iterable of candidate description strings. Non-string
            entries (such as NaN from an empty CSV cell) are skipped.
        threshold: The best candidate's score must be strictly greater than
            this to be returned. Defaults to 0.7.

    Returns:
        The best-scoring entry (the first one on ties) if its score exceeds
        ``threshold``; otherwise None. Also None when no entry passes the
        lexical filter.
    """
    # Lowercase alphanumeric tokens of the query.
    words = re.findall(r'\w+', input_word.lower())
    # Cheap lexical pre-filter. This is substring (not whole-word) matching,
    # so short tokens such as "12" or "mix" can admit many unrelated entries.
    filtered_dictionary = []
    for desc in dictionary:
        if not isinstance(desc, str):
            continue
        lowered = desc.lower()
        if any(word in lowered for word in words):
            filtered_dictionary.append(desc)
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    # Nothing to rank: skip the expensive BERT pass on the input.
    if not filtered_dictionary:
        return None
    input_embedding = get_bert_embedding(input_word)
    scored = (
        (entry, compute_cosine_similarity(input_embedding, get_bert_embedding(entry)))
        for entry in filtered_dictionary
    )
    # max() keeps the first maximal pair, i.e. dictionary order breaks ties.
    best_entry, best_score = max(scored, key=lambda pair: pair[1])
    return best_entry if best_score > threshold else None
# Example usage: run this file as a script to match the sample inputs below.
input_words = [
    "Yellow Squash",
    "Cauliflower",
    "Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash mix italian/yellow (30 lbs )",
]

# Sample output recorded in the original script (the "Filtered dictionary
# size" line is printed by match_word itself, before each result):
# Filtered dictionary size: 224
# Input word: Yellow Squash
# Matched entry: None
# Filtered dictionary size: 37
# Input word: Cauliflower
# Matched entry: Fried cauliflower
# Filtered dictionary size: 185
# Input word: Habanero Pepper
# Matched entry: Peppers, ancho, dried
# Filtered dictionary size: 414
# Input word: Bananas (12 lbs)
# Matched entry: Bananas, raw
# Filtered dictionary size: 784
# Input word: Squash mix italian/yellow (30 lbs )
# Matched entry: Nutritional powder mix (Slim Fast)

# Guarded so that importing this module does not run the (slow) example.
if __name__ == "__main__":
    for input_word in input_words:
        matched_entry = match_word(input_word, dictionary)
        print("Input word:", input_word)
        print("Matched entry:", matched_entry)
        print()