# brightly-ai / old_experiments / llama3-gpu-instructions.py
# Uploaded by beweinreich · commit 9189e38 ("first")
# (Scraped Hugging Face file-view header: raw · history · blame · "No virus" · 3.9 kB)
import os
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# falling back to CPU. (Checking only MPS meant a CUDA machine silently ran
# the 8B model on the CPU.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Device:", device)

# Load model and tokenizer; weights are loaded in bfloat16.
# NOTE(review): bfloat16 support on MPS depends on the installed torch/macOS
# version — confirm on the target machine.
model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# Load the candidate descriptions from CSV. The DICTIONARY_CSV environment
# variable overrides the path; the default is the original author's location.
csv_file_path = os.environ.get('DICTIONARY_CSV', '/Users/bw/Webstuff/btest/test/dictionary.csv')
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as NaN (a float). Drop them and coerce the rest to
# str so later desc.lower() calls cannot raise AttributeError.
dictionary = df['description'].dropna().astype(str).tolist()

# Instruction text prepended to every sentence before it is fed to the model.
prompt = "The text sometimes comes hyphenated, where the part before the hyphen is the general category, and the item after the hyphen is the more specific item. Please generate an embedding for the following text: "
# Embed a sentence by mean-pooling the causal LM's output logits.
def generate_embedding(sentence):
    """Return a pooled vector for *sentence* computed from the language model.

    The module-level ``prompt`` is prepended to *sentence*. The combined text
    is tokenized, run through ``model`` on ``device`` without gradient
    tracking, and the output logits are averaged over dim 1.

    Args:
        sentence: Text to embed.

    Returns:
        A float32 CPU tensor, squeezed. Its length is the last dimension of
        ``outputs.logits`` (presumably the vocabulary size for a causal-LM
        head — confirm against the transformers model output docs).
    """
    # Combine the prompt with the sentence.
    input_text = prompt + sentence
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Upcast to float32 *before* pooling. The model runs in bfloat16, whose
    # ~3 significant digits quantize the cosine scores computed later, so
    # near-identical candidates tie and max() picks among them arbitrarily.
    # NOTE(review): dim 1 is presumably the token axis, so this mean also
    # covers the long, shared prompt tokens. That likely pulls every embedding
    # toward the same vector; consider pooling only over the sentence's tokens.
    embeddings = outputs.logits.float().mean(dim=1).squeeze().cpu()
    return embeddings
# Cosine similarity between two embedding tensors.
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of the two tensors along dim 0 as a Python number."""
    cos = torch.nn.functional.cosine_similarity
    score = cos(embedding1, embedding2, dim=0)
    return score.item()
# Custom scoring function: cosine similarity plus an optional keyword boost.
def custom_score(input_word, input_embedding, entry_embedding, entry_text, keyword='raw', boost=0.1):
    """Score one dictionary entry against the input.

    Args:
        input_word: Cleaned input text; used here only to count its words.
        input_embedding: Embedding of *input_word*.
        entry_embedding: Embedding of the dictionary entry.
        entry_text: The dictionary entry's text.
        keyword: Word that, when present in *entry_text* as a whole word
            (case-insensitive), earns the boost. Defaults to 'raw'.
        boost: Amount added to the similarity when the boost applies.

    Returns:
        ``cosine_similarity(input_embedding, entry_embedding)``, plus *boost*
        when *input_word* is a single word and *entry_text* contains *keyword*.
    """
    similarity_score = cosine_similarity(input_embedding, entry_embedding)
    # Single-word inputs (e.g. "Cauliflower") get nudged toward keyword entries.
    is_single_word = len(re.findall(r'\w+', input_word.lower())) == 1
    # Whole-word match: a plain substring test also fired on words such as
    # "strawberry", "drawn" or "crawfish".
    has_keyword = re.search(rf'\b{re.escape(keyword)}\b', entry_text, re.IGNORECASE) is not None
    if is_single_word and has_keyword:
        similarity_score += boost
    return similarity_score
# Method to find the best match for the input word in the dictionary.
def match_word(input_word, dictionary, threshold=0.7):
    """Return the best-scoring dictionary entry for *input_word*, or None.

    Parenthesised text (e.g. "(12 lbs)") is stripped from the input. The
    dictionary is narrowed to entries containing at least one of the input's
    words as a case-insensitive substring. Each surviving entry is embedded
    and scored with ``custom_score`` against the input's embedding.

    Args:
        input_word: Raw input text such as "Bananas (12 lbs)".
        dictionary: Candidate entry strings. Non-str items (e.g. NaN from an
            empty CSV cell) are ignored.
        threshold: A match is returned only if its score is strictly greater
            than this value.

    Returns:
        ``(entry, score)`` for the highest-scoring entry when its score
        exceeds *threshold*, otherwise None.
    """
    # Remove text in parentheses, e.g. quantities such as "(12 lbs)".
    input_word_clean = re.sub(r'\(.*?\)', '', input_word).strip()
    words = re.findall(r'\w+', input_word_clean.lower())
    # Substring (not whole-word) test, so e.g. "tomato" also hits "tomatoes".
    filtered_dictionary = [
        desc for desc in dictionary
        if isinstance(desc, str) and any(word in desc.lower() for word in words)
    ]
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")
    # Nothing to compare against: skip the model forward pass for the input.
    if not filtered_dictionary:
        return None
    input_embedding = generate_embedding(input_word_clean)
    similarities = []
    for entry in tqdm(filtered_dictionary, desc="Processing Entries"):
        entry_embedding = generate_embedding(entry)
        score = custom_score(input_word_clean, input_embedding, entry_embedding, entry)
        similarities.append((entry, score))
    best_match = max(similarities, key=lambda x: x[1])
    return best_match if best_match[1] > threshold else None
# Example usage: match a sample of produce descriptions against the dictionary
# and report the best entry (or None) with its score for each one.
input_words = [
    "Pepper - Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash - Yellow Squash",
    "Cauliflower",
    "Squash mix italian/yellow (30 lbs)",
    "Tomato - Roma Tomato",
    "Tomato - Grape Tomato",
    "Squash - Mexican Squash",
    "Pepper - Bell Pepper",
    "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper",
    "Tomato - Cherry Tomato",
    "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)",
    "Tomato - Beefsteak Tomato",
    "Pepper - Anaheim Pepper",
    "Banana - Burro Banana",
    "Squash - Butternut Squash",
    "Apricot ( 10 lbs)",
    "Squash - Acorn Squash",
    "Tomato - Heirloom Tomato",
    "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper",
    "carrot (10 lbs )",
]

for input_word in tqdm(input_words, desc="Matching Words"):
    print("Input word:", input_word)
    matched_entry = match_word(input_word, dictionary)
    # match_word returns either None or a non-empty (entry, score) tuple.
    if matched_entry is None:
        print("Matched entry: None")
    else:
        matched_text, matched_score = matched_entry
        print("Matched entry:", matched_text)
        print("Similarity score:", matched_score)
    print()