# brightly-ai / old_experiments / llama3-gpu-compare2.py
# Author: beweinreich · commit 9189e38 ("first") · 3.5 kB
# (Copied from the Hugging Face file view: raw | history | blame | "No virus")
import json
import os
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
# Run on Apple's Metal (MPS) backend when this PyTorch build reports it as
# available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Device:", device)

# Load the instruct checkpoint's tokenizer and weights (bfloat16), then move the
# model onto the selected device.
# NOTE(review): the Hub repo is published as "Meta-Llama-3-8B-Instruct"
# (capital I); confirm this lowercase id resolves to the intended checkpoint.
model_id = "meta-llama/Meta-Llama-3-8B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
# Load the candidate descriptions from the CSV's 'description' column.
# The path defaults to the original developer location; set the
# DICTIONARY_CSV environment variable to use a different file.
csv_file_path = os.environ.get('DICTIONARY_CSV', '/Users/bw/Webstuff/btest/test/dictionary.csv')
df = pd.read_csv(csv_file_path)
# Empty cells are read as NaN floats, which would break the string-only
# processing downstream (.lower() in match_word, "\n- ".join in
# get_similarity_score). Drop them and coerce every remaining value to str.
dictionary = df['description'].dropna().astype(str).tolist()
# Prompt template for the comparison. It takes two positional placeholders:
#   1. the item to match;
#   2. the candidate list, pre-joined with "\n- ". The template already
#      supplies the first "- " bullet.
compare_prompt = (
    "Instructions: Provide a one item answer, do not respond with a sentence. "
    "Respond in JSON format. "
    "Q: What food item from this list is most similar to: '{}'?\n"
    "List:\n- {}\nA:"
)
def get_similarity_score(input_word, dictionary):
    """Ask the loaded language model which ``dictionary`` entry best matches ``input_word``.

    The candidates are rendered into ``compare_prompt`` as a bulleted list and
    sent through the module-level ``tokenizer``/``model`` on ``device``.
    Despite the name, no numeric score is computed: the model's generated
    answer is returned as text.

    Args:
        input_word: The (already cleaned) item name to match.
        dictionary: Candidate description strings to choose from.

    Returns:
        A JSON string ``{"input_word": ..., "matched_entry": ...}``.
        ``matched_entry`` is the first non-empty line of the model's
        continuation, stripped, or ``None`` if the model produced only
        whitespace.
    """
    dictionary_list_str = "\n- ".join(dictionary)
    input_text = compare_prompt.format(input_word, dictionary_list_str)
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # sampling: the answer can differ between runs
            top_k=50,
            top_p=0.9,
            temperature=0.7,
        )
    # generate() returns the prompt tokens followed by the new tokens. Decode
    # only the continuation so the answer is not searched for inside the
    # multi-line prompt: a single-line pattern anchored at "Q:" cannot reach
    # the "A:" that follows the "List:" lines.
    prompt_length = inputs['input_ids'].shape[-1]
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print("Generated text:", generated_text)
    # The answer is the first non-empty line the model produced.
    result = next(
        (line.strip() for line in generated_text.splitlines() if line.strip()),
        None,
    )
    return json.dumps({"input_word": input_word, "matched_entry": result})
def match_word(input_word, dictionary):
    """Find the dictionary entry that best matches ``input_word``.

    The lookup works in three steps:

    1. Parenthesised text (e.g. quantities such as "(12 lbs)") is removed, and
       the rest is split into lowercase ``\\w+`` tokens.
    2. Only dictionary entries that contain at least one token as a
       case-insensitive substring are kept.
    3. That shortlist is handed to ``get_similarity_score``.

    Args:
        input_word: Raw item name, e.g. "Bananas (12 lbs)".
        dictionary: Candidate descriptions. Non-string entries (such as NaN
            from empty CSV cells) are skipped.

    Returns:
        The JSON string produced by ``get_similarity_score``, or ``None`` if
        no dictionary entry shares a token with the input.
    """
    input_word_clean = re.sub(r'\(.*?\)', '', input_word).strip()
    words = re.findall(r'\w+', input_word_clean.lower())

    # A dict works as an insertion-ordered set. It de-duplicates like set()
    # but keeps first-seen order, so the candidate order in the prompt, and
    # hence the model's answer, does not vary with string-hash randomisation.
    matches = {}
    for desc in dictionary:
        if not isinstance(desc, str):
            # Skip non-text entries (e.g. NaN floats) instead of crashing on .lower().
            continue
        desc_lower = desc.lower()
        if any(word in desc_lower for word in words):
            matches[desc] = None
    filtered_dictionary = list(matches)
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    if not filtered_dictionary:
        return None

    return get_similarity_score(input_word_clean, filtered_dictionary)
# Demo run: match each sample produce line item against the loaded dictionary
# and print the model's JSON answer (or "None" when nothing matched).
input_words = [
    "Pepper - Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash - Yellow Squash",
    "Cauliflower",
    "Squash mix italian/yellow (30 lbs)",
    "Tomato - Roma Tomato",
    "Tomato - Grape Tomato",
    "Squash - Mexican Squash",
    "Pepper - Bell Pepper",
    "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper",
    "Tomato - Cherry Tomato",
    "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)",
    "Tomato - Beefsteak Tomato",
    "Pepper - Anaheim Pepper",
    "Banana - Burro Banana",
    "Squash - Butternut Squash",
    "Apricot ( 10 lbs)",
    "Squash - Acorn Squash",
    "Tomato - Heirloom Tomato",
    "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper",
    "carrot (10 lbs )",
]
for input_word in tqdm(input_words, desc="Matching Words"):
    print("Input word:", input_word)
    matched_entry = match_word(input_word, dictionary)
    # Any falsy result (None or "") is reported as the literal text "None".
    print("Matched entry (JSON):", matched_entry or "None")
    print()