# SBERT-based matcher: maps produce names to the closest description in a CSV dictionary.
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import re

# Load pre-trained SBERT model (all-mpnet-base-v2: larger and more accurate
# than the MiniLM variants, at the cost of slower encoding).
model = SentenceTransformer('all-mpnet-base-v2')

# Load the description dictionary from CSV; expects a 'description' column.
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
dictionary = df['description'].tolist()
def preprocess_input(input_word):
    """Normalize an input word for dictionary matching.

    Strips any parenthesized text (e.g. weights like "(10 lbs)").  When the
    cleaned string contains " - " it is interpreted as
    "<broad category> - <specific item>" and a
    ``(specific_item, broad_category)`` tuple is returned; otherwise the
    cleaned string itself is returned.

    Note the dual return type (str or tuple); callers must check with
    ``isinstance(result, tuple)``.
    """
    # Remove text within parentheses, e.g. "Carrot (10 lbs)" -> "Carrot".
    input_word = re.sub(r'\(.*?\)', '', input_word).strip()
    # Handle broad category and specific item separately when there's a hyphen.
    if ' - ' in input_word:
        broad_category, specific_item = input_word.split(' - ', 1)
        return specific_item.strip(), broad_category.strip()
    return input_word
def create_regex_pattern(input_word):
    """Build a regex alternation (``word1|word2|...``) from the word tokens
    of *input_word*, lowercased and escaped so each token matches literally.

    Returns an empty string when *input_word* contains no word characters
    (beware: an empty pattern matches every string).
    """
    words = re.findall(r'\w+', input_word.lower())
    return '|'.join(re.escape(word) for word in words)
def match_word(input_word, dictionary):
    """Find the dictionary description that best matches *input_word*.

    The dictionary is first narrowed with a cheap regex filter built from
    the preprocessed input tokens, then the surviving entries are ranked by
    SBERT cosine similarity against the full (unpreprocessed) input string.

    Returns a ``(description, score)`` tuple for the best candidate when its
    similarity exceeds 0.7, otherwise ``None``.
    """
    processed_input = preprocess_input(input_word)
    if isinstance(processed_input, tuple):
        specific_item, broad_category = processed_input
        # Keep entries matching either the specific item or the broad category.
        pattern = re.compile(
            create_regex_pattern(specific_item)
            + '|'
            + create_regex_pattern(broad_category)
        )
    else:
        pattern = re.compile(create_regex_pattern(processed_input))
    # Compile once and search once per entry (the original searched up to two
    # raw patterns per entry).
    filtered_dictionary = [desc for desc in dictionary if pattern.search(desc.lower())]
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    if not filtered_dictionary:
        return None
    # Batch-encode all candidates in one call -- far faster than encoding
    # each entry individually inside a Python loop.
    input_embedding = model.encode(input_word, convert_to_tensor=True)
    entry_embeddings = model.encode(filtered_dictionary, convert_to_tensor=True)
    scores = util.pytorch_cos_sim(input_embedding, entry_embeddings)[0]
    best_idx = int(scores.argmax())
    best_score = float(scores[best_idx])
    # Only accept reasonably confident matches (same 0.7 threshold as before).
    return (filtered_dictionary[best_idx], best_score) if best_score > 0.7 else None
# Example usage: a batch of raw produce labels in the formats the
# preprocessor understands (parenthesized weights, "Category - Item").
input_words = [
    "Carrot (10 lbs )",
    "Pepper - Habanero Pepper", "Bananas (12 lbs)", "Squash - Yellow Squash", "Cauliflower",
    "Squash mix italian/yellow (30 lbs)", "Tomato - Roma Tomato", "Tomato - Grape Tomato",
    "Squash - Mexican Squash", "Pepper - Bell Pepper", "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper", "Tomato - Cherry Tomato", "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)", "Tomato - Beefsteak Tomato", "Pepper - Anaheim Pepper",
    "Banana - Burro Banana", "Squash - Butternut Squash", "Apricot ( 10 lbs)",
    "Squash - Acorn Squash", "Tomato - Heirloom Tomato", "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper",
]

# Guard the demo so importing this module doesn't run it.
if __name__ == "__main__":
    for input_word in input_words:
        print("Input word:", input_word)
        matched_entry = match_word(input_word, dictionary)
        if matched_entry:
            print("Matched entry:", matched_entry[0])
            print("Similarity score:", matched_entry[1])
        else:
            print("Matched entry: None")
        print()