# Spaces:
# Paused
# Paused
import json
import os
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# falling back to CPU. (On a Mac without CUDA this is the same MPS/CPU choice
# as before.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Device:", device)

# Load model and tokenizer once at import time; both are used as globals by
# the matching functions below.
model_id = "meta-llama/Meta-Llama-3-8B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Load the dictionary of candidate descriptions from CSV. The path can be
# overridden with the DICTIONARY_CSV environment variable; the default is the
# hard-coded path this script previously required.
csv_file_path = os.environ.get('DICTIONARY_CSV', '/Users/bw/Webstuff/btest/test/dictionary.csv')
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as float NaN, which has no .lower() and would crash
# the keyword filter; drop them and make sure every entry is a str.
dictionary = df['description'].dropna().astype(str).tolist()
# Prompt template sent to the model. It is filled with str.format using two
# positional placeholders: first the item to match, then the candidate entries
# joined by "\n- " (the template itself supplies the leading "- ").
compare_prompt = (
    "Instructions: Provide a one item answer, do not response with a sentence. "
    "Respond in JSON format. "
    "Q: What food item from this list is most similar to: '{}'?\n"
    "List:\n"
    "- {}\n"
    "A:"
)
# Method to get similarity score using Llama model's comprehension | |
def get_similarity_score(input_word, dictionary): | |
dictionary_list_str = "\n- ".join(dictionary) | |
input_text = compare_prompt.format(input_word, dictionary_list_str) | |
inputs = tokenizer(input_text, return_tensors='pt').to(device) | |
with torch.no_grad(): | |
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.9, temperature=0.7) | |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Extract the matched word from the generated text | |
print("Generated text:", generated_text) | |
match = re.search(r'Q:.*?\nA:\s*(.*)', generated_text) | |
result = match.group(1).strip() if match else None | |
# Format the result as JSON | |
response_json = json.dumps({"input_word": input_word, "matched_entry": result}) | |
return response_json | |
# Method to find the best match for the input word in the dictionary | |
def match_word(input_word, dictionary): | |
# Remove text in parentheses and split into words | |
input_word_clean = re.sub(r'\(.*?\)', '', input_word).strip() | |
words = re.findall(r'\w+', input_word_clean.lower()) | |
# Filter dictionary entries containing any of the words | |
filtered_dictionary = [desc for desc in dictionary if any(word in desc.lower() for word in words)] | |
# remove duplicate entries | |
filtered_dictionary = list(set(filtered_dictionary)) | |
print(f"Filtered dictionary size: {len(filtered_dictionary)}") | |
if not filtered_dictionary: | |
return None | |
# Get similarity score | |
result = get_similarity_score(input_word_clean, filtered_dictionary) | |
return result | |
# Example usage: run each sample product name through match_word and print
# its JSON result, or the text "None" when nothing matched.
input_words = [
    "Pepper - Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash - Yellow Squash",
    "Cauliflower",
    "Squash mix italian/yellow (30 lbs)",
    "Tomato - Roma Tomato",
    "Tomato - Grape Tomato",
    "Squash - Mexican Squash",
    "Pepper - Bell Pepper",
    "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper",
    "Tomato - Cherry Tomato",
    "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)",
    "Tomato - Beefsteak Tomato",
    "Pepper - Anaheim Pepper",
    "Banana - Burro Banana",
    "Squash - Butternut Squash",
    "Apricot ( 10 lbs)",
    "Squash - Acorn Squash",
    "Tomato - Heirloom Tomato",
    "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper",
    "carrot (10 lbs )",
]
for input_word in tqdm(input_words, desc="Matching Words"):
    print("Input word:", input_word)
    matched_entry = match_word(input_word, dictionary)
    # Trailing blank line separates consecutive results in the output.
    print("Matched entry (JSON):", matched_entry or "None", end="\n\n")