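# NOTE: the lines below appear to be file-viewer page residue (uploader,
# commit label/hash, file size), not part of the program. They are kept as
# comments so the file remains valid Python.
# beweinreich's picture
# first
# 9189e38
# raw
# history blame
# No virus
# 2.22 kB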
import pandas as pd
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import re
# Load pre-trained BERT model and tokenizer (both from the same checkpoint so
# the tokenizer's vocabulary matches the model's embeddings).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Load dictionary from CSV file
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
# pandas reads blank cells as float NaN (and purely numeric cells as numbers).
# Drop the missing values and coerce the rest to str so every dictionary entry
# supports string operations such as .lower() during matching.
dictionary = df['description'].dropna().astype(str).tolist()
# Method to compute cosine similarity between two embeddings
def compute_cosine_similarity(embedding1, embedding2):
    """Return a single cosine-similarity score for two embeddings.

    Both arguments are passed unchanged to scikit-learn's
    ``cosine_similarity``, which produces a pairwise similarity matrix.
    Only the entry at row 0, column 0 of that matrix is returned, i.e. the
    score between the first row of each input.
    """
    pairwise_scores = cosine_similarity(embedding1, embedding2)
    first_row = pairwise_scores[0]
    return first_row[0]
# Method to get BERT embeddings
def get_bert_embedding(text):
    """Return a mean-pooled BERT embedding of ``text`` as a NumPy array.

    The text is tokenized (with truncation and padding enabled) into PyTorch
    tensors, run through the module-level ``model``, and the final hidden
    states are averaged over dimension 1.

    Args:
        text: The text to embed; passed directly to the module-level
            ``tokenizer``.

    Returns:
        A NumPy array holding the averaged last hidden state. For a single
        string this is presumably of shape (1, hidden_size) -- confirm
        against the model configuration.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # This is inference only: disable autograd so no computation graph is
    # built and kept alive for every embedding that gets computed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over dim=1 to collapse the per-token states into one vector.
    return outputs.last_hidden_state.mean(dim=1).numpy()
# Method to find the best match for the input word in the dictionary
def match_word(input_word, dictionary, *, threshold=0.7):
    """Find the dictionary entry that best matches ``input_word``.

    The dictionary is first narrowed to entries containing (as a
    case-insensitive substring) at least one ``\\w+`` token of the input.
    Each remaining entry is then scored against the input by cosine
    similarity of their BERT embeddings, and the highest-scoring entry is
    returned if its score is strictly greater than ``threshold``.

    Diagnostic information (the filtered candidates and their scores) is
    printed to standard output.

    Args:
        input_word: The free-text query string.
        dictionary: Iterable of candidate description strings. Entries that
            are not ``str`` (e.g. NaN from a CSV with blank cells) are
            skipped rather than raising.
        threshold: Minimum similarity (exclusive) the best candidate must
            exceed to be accepted. Defaults to 0.7.

    Returns:
        The best-matching dictionary entry, or ``None`` when no candidate
        shares a token with the input or the best score does not exceed
        ``threshold``.
    """
    # Extract words from the input
    words = re.findall(r'\w+', input_word.lower())

    # Filter dictionary based on words
    filtered_dictionary = []
    for desc in dictionary:
        # Non-string entries cannot be lower-cased or searched; skip them.
        if not isinstance(desc, str):
            continue
        # Lower-case once per entry instead of once per query word.
        lowered_desc = desc.lower()
        if any(word in lowered_desc for word in words):
            filtered_dictionary.append(desc)

    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    print(f"Filtered dictionary: {filtered_dictionary}")

    # Proceed with BERT embeddings and cosine similarity on the filtered
    # dictionary. The input embedding is only computed when there is at
    # least one candidate to compare it against.
    similarities = []
    if filtered_dictionary:
        input_embedding = get_bert_embedding(input_word)
        for entry in filtered_dictionary:
            entry_embedding = get_bert_embedding(entry)
            similarity_score = compute_cosine_similarity(input_embedding, entry_embedding)
            similarities.append((entry, similarity_score))

    print(similarities)

    if not similarities:
        return None

    best_entry, best_score = max(similarities, key=lambda pair: pair[1])
    return best_entry if best_score > threshold else None
# Example usage: look up each sample input in the loaded dictionary and
# report the match (or None), followed by a blank separator line.
input_words = ["Pepper - Habanero Pepper", "Bananas (12 lbs)"]
for input_word in input_words:
    matched_entry = match_word(input_word, dictionary)
    print("Input word:", input_word)
    # end="\n\n" emits the trailing blank line in the same call.
    print("Matched entry:", matched_entry, end="\n\n")