# Captured from a Spaces page; Space status at capture time: Paused.
import re

import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import RobertaModel, RobertaTokenizer
# Load the pre-trained RoBERTa tokenizer and encoder. Both come from one
# checkpoint name so they cannot drift apart.
MODEL_NAME = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
model = RobertaModel.from_pretrained(MODEL_NAME)
# Inference only: make sure dropout is disabled so embeddings are deterministic.
model.eval()
# Load the description dictionary from a CSV file.
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as NaN (a float). Left in, NaN would crash the
# case-insensitive substring filter in match_word, which calls desc.lower().
# So drop missing values and coerce everything else to str.
dictionary = df['description'].dropna().astype(str).tolist()
def compute_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of row 0 of each embedding matrix.

    Both arguments are handed unchanged to scikit-learn's
    ``cosine_similarity``, which compares every row of the first input with
    every row of the second. Only the (0, 0) entry of that pairwise matrix
    is returned.
    """
    sim_matrix = cosine_similarity(embedding1, embedding2)
    return sim_matrix[0, 0]
def get_roberta_embedding(text):
    """Return the mean-pooled RoBERTa embedding of ``text``.

    Args:
        text: A string, or a list of strings to embed as one padded batch.

    Returns:
        A NumPy array with one row per input (a single row for a single
        string). Each row is the average of the model's last hidden layer
        over that input's non-padding tokens.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: skip autograd bookkeeping rather than detaching afterwards.
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Average over real tokens only, so padding added for batched input does
    # not dilute the embedding. For a single string the mask is all ones, and
    # this reduces to a plain mean over the sequence, as before.
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return (summed / counts).cpu().numpy()
def match_word(input_word, dictionary, threshold=0.7):
    """Return the dictionary entry most similar to ``input_word``, or None.

    Matching runs in two stages:

    1. Candidate filter: keep entries that contain at least one word-character
       run from the lowercased input as a case-insensitive substring.
    2. Ranking: score each candidate by the cosine similarity between its
       RoBERTa embedding and the input's embedding.

    Args:
        input_word: Free-text query string.
        dictionary: Iterable of description strings. Non-string entries,
            such as NaN from an empty CSV cell, are ignored.
        threshold: The best candidate is returned only when its similarity
            is strictly greater than this value. Defaults to 0.7.

    Returns:
        The best-matching entry. Returns None when no entry passes the
        filter, or when the best similarity does not exceed ``threshold``.
    """
    words = re.findall(r'\w+', input_word.lower())
    filtered_dictionary = _filter_candidates(words, dictionary)
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    # Nothing to rank: avoid the model forward pass for the query.
    if not filtered_dictionary:
        return None

    input_embedding = get_roberta_embedding(input_word)
    # max() keeps the first of equally scored entries, matching dictionary order.
    best_entry, best_score = max(
        (
            (entry, compute_cosine_similarity(input_embedding, get_roberta_embedding(entry)))
            for entry in filtered_dictionary
        ),
        key=lambda pair: pair[1],
    )
    # NOTE(review): mean-pooled RoBERTa embeddings are often highly similar to
    # one another, so the default threshold may be a weak filter. Confirm it
    # on labelled examples.
    return best_entry if best_score > threshold else None


def _filter_candidates(words, dictionary):
    """Keep string entries that contain any of ``words`` (case-insensitive)."""
    return [
        desc
        for desc in dictionary
        if isinstance(desc, str) and any(word in desc.lower() for word in words)
    ]
# Example usage: run only when executed as a script, not on import.
if __name__ == "__main__":
    # Queries from the sample run recorded below. The original loop referenced
    # an undefined `input_words` and raised NameError.
    input_words = [
        "Squash - Yellow Squash",
        "Cauliflower",
        "Pepper - Habanero Pepper",
        "Bananas (12 lbs)",
        "Squash mix italian/yellow (30 lbs )",
    ]

    # Sample output from a previous run:
    # Filtered dictionary size: 224
    # Input word: Squash - Yellow Squash
    # Matched entry: Squash, Indian, raw (Navajo)
    # Filtered dictionary size: 37
    # Input word: Cauliflower
    # Matched entry: Cauliflower, raw
    # Filtered dictionary size: 185
    # Input word: Pepper - Habanero Pepper
    # Matched entry: Stuffed green pepper, Puerto Rican style
    # Filtered dictionary size: 414
    # Input word: Bananas (12 lbs)
    # Matched entry: Bananas, slightly ripe
    # Filtered dictionary size: 784
    # Input word: Squash mix italian/yellow (30 lbs )
    # Matched entry: Squash, Indian, cooked, boiled (Navajo)
    for input_word in input_words:
        matched_entry = match_word(input_word, dictionary)
        print("Input word:", input_word)
        print("Matched entry:", matched_entry)
        print()