beweinreich's picture
first
9189e38
raw
history blame
No virus
3.35 kB
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import re
# Load the pre-trained SBERT model used to embed both the query text and the
# dictionary entries (chosen as a larger, more accurate checkpoint).
model = SentenceTransformer('all-mpnet-base-v2')

# Load the candidate descriptions from CSV. The path is relative, so it is
# resolved against the current working directory of the process.
csv_file_path = './dictionary/dictionary.csv'
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as NaN (a float). Drop those rows and coerce the
# rest to str so downstream string operations such as desc.lower() are safe.
dictionary = df['description'].dropna().astype(str).tolist()
def preprocess_input(input_word):
    """Clean a raw item label and split off a leading category if present.

    Every parenthesised group such as "(10 lbs)" is removed and the result is
    stripped. If the cleaned text contains " - ", it is split at the first
    occurrence and ``(specific_item, broad_category)`` is returned, i.e. the
    part after the separator first. Otherwise the cleaned string is returned.
    """
    # Non-greedy, so "a (b) c (d)" loses "(b)" and "(d)" but keeps " c ".
    cleaned = re.sub(r'\(.*?\)', '', input_word).strip()
    category, separator, item = cleaned.partition(' - ')
    if not separator:
        return cleaned
    return item.strip(), category.strip()
def create_regex_pattern(input_word):
    """Turn the words of *input_word* into a lower-case regex alternation.

    The text is lower-cased and split into runs of word characters. Each run
    is regex-escaped, and the runs are joined with "|". Text without any word
    characters produces the empty pattern '', which re.search matches
    against every string.
    """
    tokens = re.findall(r'\w+', input_word.lower())
    return '|'.join(map(re.escape, tokens))
def match_word(input_word, dictionary, threshold=0.7):
    """Return the dictionary entry most semantically similar to *input_word*.

    The dictionary is first narrowed with a keyword filter built from the
    cleaned input (via ``preprocess_input`` and ``create_regex_pattern``).
    The surviving entries are then embedded with the module-level ``model``
    and scored by cosine similarity against the embedding of the raw input.

    Args:
        input_word: Free-text item label, e.g. "Pepper - Habanero Pepper".
        dictionary: Iterable of candidate descriptions. Non-string entries
            (such as NaN from an empty CSV cell) are skipped.
        threshold: A match is returned only if its score is strictly greater
            than this value. Defaults to 0.7, the previously hard-coded value.

    Returns:
        ``(entry, score)`` for the best-scoring entry when ``score > threshold``,
        otherwise ``None``. ``None`` is also returned when no entry survives
        the keyword filter.
    """
    processed_input = preprocess_input(input_word)
    # "Broad - Specific" inputs come back as a tuple; an entry may match the
    # words of either part.
    parts = processed_input if isinstance(processed_input, tuple) else (processed_input,)
    # Drop empty patterns: re.search('', s) matches every string, so one empty
    # part would otherwise disable filtering for the other part. If every part
    # is empty, the joined pattern is '' and the whole dictionary is kept.
    patterns = [pattern for pattern in map(create_regex_pattern, parts) if pattern]
    # One alternation is equivalent to OR-ing a search per part.
    keyword_re = re.compile('|'.join(patterns))
    filtered_dictionary = [
        desc for desc in dictionary
        if isinstance(desc, str) and keyword_re.search(desc.lower())
    ]
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    if not filtered_dictionary:
        return None

    input_embedding = model.encode(input_word, convert_to_tensor=True)
    # Encode all candidates in one batched call instead of one call per entry.
    entry_embeddings = model.encode(filtered_dictionary, convert_to_tensor=True)
    # pytorch_cos_sim returns a (1, n) similarity matrix here; take its row.
    scores = util.pytorch_cos_sim(input_embedding, entry_embeddings)[0]
    # argmax picks the first maximal index, matching max() on ties.
    best_index = int(scores.argmax())
    best_score = scores[best_index].item()
    if best_score > threshold:
        return filtered_dictionary[best_index], best_score
    return None
# Example usage: sample produce line items, either "Item (quantity)" or
# "Category - Variety", each matched against the loaded dictionary.
input_words = [
    "Carrot (10 lbs )",
    "Pepper - Habanero Pepper", "Bananas (12 lbs)", "Squash - Yellow Squash", "Cauliflower",
    "Squash mix italian/yellow (30 lbs)", "Tomato - Roma Tomato", "Tomato - Grape Tomato",
    "Squash - Mexican Squash", "Pepper - Bell Pepper", "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper", "Tomato - Cherry Tomato", "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)", "Tomato - Beefsteak Tomato", "Pepper - Anaheim Pepper",
    "Banana - Burro Banana", "Squash - Butternut Squash", "Apricot ( 10 lbs)",
    "Squash - Acorn Squash", "Tomato - Heirloom Tomato", "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper"
]

if __name__ == "__main__":
    # Guarded so that importing this module does not run the demo queries.
    for input_word in input_words:
        print("Input word:", input_word)
        matched_entry = match_word(input_word, dictionary)
        if matched_entry is not None:
            matched_text, similarity_score = matched_entry
            print("Matched entry:", matched_text)
            print("Similarity score:", similarity_score)
        else:
            print("Matched entry: None")
        print()