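# NOTE: the lines below are web-page residue that was captured together with
# the file; they are kept as comments so the module remains valid Python.
#   beweinreich's picture
#   first
#   9189e38
#   raw
#   history blame
#   No virus
#   2.73 kB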
import pandas as pd
from transformers import RobertaTokenizer, RobertaModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import re
# Load the pre-trained RoBERTa tokenizer and encoder once at import time; both
# objects are reused for every embedding computed in this module.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')

# Load the reference dictionary from the 'description' column of the CSV file.
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
# Empty cells are read by pandas as NaN (a float), which has no ``.lower()``
# and would crash the substring filter in ``match_word``. Drop those rows and
# coerce the remaining values to ``str`` so every entry is safe to match on.
dictionary = df['description'].dropna().astype(str).tolist()
def compute_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embeddings.

    Both arguments are passed unchanged to scikit-learn's
    ``cosine_similarity``, which produces a pairwise score matrix; only the
    entry at row 0, column 0 of that matrix is returned.

    Args:
        embedding1: First embedding (assumed 2-D, one row per sample — the
            shape produced by ``get_roberta_embedding``; confirm for other
            callers).
        embedding2: Second embedding, same layout as ``embedding1``.

    Returns:
        The similarity score of the first row of each input.
    """
    pairwise_scores = cosine_similarity(embedding1, embedding2)
    first_row = pairwise_scores[0]
    return first_row[0]
def get_roberta_embedding(text):
    """Return a mean-pooled RoBERTa embedding of ``text`` as a NumPy array.

    The text is tokenized (with truncation enabled), run through the
    module-level ``model``, and ``last_hidden_state`` is averaged over
    dimension 1.

    Args:
        text: Input handed straight to the tokenizer. Callers in this file
            pass a single string.

    Returns:
        numpy.ndarray: The pooled embedding. Presumably shaped
        (batch, hidden_size), i.e. one row for a single string — confirm
        against the transformers output documentation.

    Note:
        The mean is taken over all positions without consulting the attention
        mask. That is harmless for a single string (no padding is added), but
        padded positions would be averaged in if a batch of texts of
        different lengths were passed.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: disable autograd so no computation graph is built and
    # retained for each of the many dictionary entries that get embedded.
    with torch.no_grad():
        outputs = model(**inputs)
    # Tensors created under no_grad do not require grad, so they convert to
    # NumPy directly without an explicit detach().
    pooled = outputs.last_hidden_state.mean(dim=1)
    return pooled.numpy()
def match_word(input_word, dictionary, threshold=0.7):
    """Find the dictionary entry that best matches ``input_word``.

    The dictionary is first narrowed to entries that contain (as a
    case-insensitive substring) at least one word of the input. Each surviving
    entry is then embedded with ``get_roberta_embedding`` and scored against
    the input embedding with ``compute_cosine_similarity``.

    Args:
        input_word: Free-text item name to look up.
        dictionary: Iterable of candidate descriptions. Entries that are not
            strings (e.g. NaN from empty CSV cells, or None) are ignored.
        threshold: Minimum similarity the best candidate must strictly exceed
            to be returned. Defaults to 0.7.

    Returns:
        The best-scoring entry, or ``None`` when no candidate passes the word
        filter or the best score does not exceed ``threshold``.
    """
    # Extract words from the input
    words = re.findall(r'\w+', input_word.lower())

    # Filter dictionary based on words. Lower-case each description once
    # rather than once per input word, and skip non-string entries that would
    # otherwise raise AttributeError on ``.lower()``.
    filtered_dictionary = []
    for desc in dictionary:
        if not isinstance(desc, str):
            continue
        desc_lower = desc.lower()
        if any(word in desc_lower for word in words):
            filtered_dictionary.append(desc)

    print(f"Filtered dictionary size: {len(filtered_dictionary)}")

    # Nothing to compare against: bail out before paying for a model forward
    # pass on the input.
    if not filtered_dictionary:
        return None

    # Proceed with RoBERTa embeddings and cosine similarity on the filtered
    # dictionary. Entries are embedded one at a time on purpose, so each
    # embedding is computed exactly as it would be for a lone string.
    input_embedding = get_roberta_embedding(input_word)
    similarities = [
        (entry, compute_cosine_similarity(input_embedding, get_roberta_embedding(entry)))
        for entry in filtered_dictionary
    ]

    # max() keeps the first entry on ties, matching dictionary order.
    best_entry, best_score = max(similarities, key=lambda pair: pair[1])
    return best_entry if best_score > threshold else None
# Example usage
#
# Recorded sample output:
#   Filtered dictionary size: 224
#   Input word: Squash - Yellow Squash
#   Matched entry: Squash, Indian, raw (Navajo)
#
#   Filtered dictionary size: 37
#   Input word: Cauliflower
#   Matched entry: Cauliflower, raw
#
#   Filtered dictionary size: 185
#   Input word: Pepper - Habanero Pepper
#   Matched entry: Stuffed green pepper, Puerto Rican style
#
#   Filtered dictionary size: 414
#   Input word: Bananas (12 lbs)
#   Matched entry: Bananas, slightly ripe
#
#   Filtered dictionary size: 784
#   Input word: Squash mix italian/yellow (30 lbs )
#   Matched entry: Squash, Indian, cooked, boiled (Navajo)

# Inputs for the demo loop, taken from the "Input word" lines of the recorded
# sample output above. Without this definition the loop raises NameError.
input_words = [
    "Squash - Yellow Squash",
    "Cauliflower",
    "Pepper - Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash mix italian/yellow (30 lbs )",
]

# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    for input_word in input_words:
        matched_entry = match_word(input_word, dictionary)
        print("Input word:", input_word)
        print("Matched entry:", matched_entry)
        print()