File size: 3,689 Bytes
9189e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import re

import pandas as pd
import torch
from tqdm import tqdm
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Run on Apple's Metal (MPS) backend when this PyTorch build and machine
# support it; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Device:", device)

# Base Llama 3 8B checkpoint: tokenizer plus causal-LM weights, loaded in
# bfloat16 and moved onto the device chosen above.
model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
).to(device)

# Load the reference dictionary: one candidate description per row of the
# CSV's 'description' column. Set the DICTIONARY_CSV environment variable to
# point at a different file without editing the source.
csv_file_path = os.environ.get(
    'DICTIONARY_CSV', '/Users/bw/Webstuff/btest/test/dictionary.csv'
)
df = pd.read_csv(csv_file_path)
# pandas reads empty cells as NaN (a float, not a string). Drop them and
# coerce the remaining values to str so every entry supports string methods.
dictionary = df['description'].dropna().astype(str).tolist()

# Prompt template used by get_similarity_score. The two `{}` slots receive the
# texts being compared, and the template stops right where the model is
# expected to write its numeric score.
compare_prompt = (
    "Compare the following two texts and rate their similarity on a scale from 0 to 1. "
    "Text 1: {} Text 2: {}. Similarity score: "
)

def generate_embedding(sentence):
    """Return a mean-pooled vector for *sentence* computed from the model's logits.

    Runs a single forward pass of the causal LM; no text is generated. The
    output logits are averaged over the token positions, so the vector has
    one entry per vocabulary item rather than per hidden unit. The result is
    moved to the CPU and keeps the logits' dtype.

    NOTE(review): logits are next-token scores, not hidden states. Confirm
    that this pooling is what callers want from an "embedding".
    """
    encoded = tokenizer(sentence, return_tensors='pt').to(device)
    with torch.no_grad():
        forward_out = model(**encoded)
    pooled = forward_out.logits.mean(dim=1)  # average over the sequence axis
    return pooled.squeeze().cpu()

def _parse_similarity_score(generated_text):
    """Extract the first number in *generated_text* as a similarity score.

    Accepts integers ("1"), decimals ("0.85") and leading-dot decimals (".9").
    Returns 0.0 when no number is present, or when the first number lies
    outside the requested 0-1 range. A malformed answer therefore can never
    outrank a well-formed one.
    """
    match = re.search(r"\d*\.\d+|\d+", generated_text)
    if match is None:
        return 0.0
    score = float(match.group())
    return score if 0.0 <= score <= 1.0 else 0.0


def get_similarity_score(text1, text2, max_new_tokens=50):
    """Ask the language model to rate how similar *text1* and *text2* are.

    The two texts are substituted into ``compare_prompt`` and the model
    greedily continues the prompt. Only the newly generated tokens are
    decoded and parsed, so numbers inside the prompt itself (for example,
    quantities in a dictionary entry) are never mistaken for the score.
    Any example the model goes on to invent after its answer is ignored
    too, because only the first number is used.

    Args:
        text1: First text to compare.
        text2: Second text to compare.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A float in [0, 1], or 0.0 if the model's answer could not be parsed.
    """
    input_text = compare_prompt.format(text1, text2)
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    prompt_length = inputs['input_ids'].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding keeps scores reproducible. The checkpoint's
            # generation config may otherwise enable sampling.
            do_sample=False,
            # Presumably the tokenizer defines no pad token. Passing one
            # explicitly avoids generate()'s per-call warning.
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    generated_text = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    score = _parse_similarity_score(generated_text)

    print(text1, text2, score)

    return score

def match_word(input_word, dictionary, *, threshold=0.7, raw_boost=0.1,
               scorer=None, show_progress=True):
    """Return the dictionary entry that best matches *input_word*, or None.

    Text in parentheses is stripped from the input, then the remaining words
    serve as a cheap substring pre-filter. Only candidate entries that
    contain at least one input word are sent to the (expensive) scorer.

    Args:
        input_word: Raw product name, e.g. ``"Bananas (12 lbs)"``.
        dictionary: Iterable of entry descriptions. Non-string items (such as
            NaN read from an empty CSV cell) are skipped.
        threshold: A match is returned only if its (boosted) score is
            strictly greater than this value.
        raw_boost: Added to the score of entries containing the word "raw"
            when the cleaned input consists of a single word.
        scorer: ``callable(text1, text2) -> float`` used to rate each
            candidate. Defaults to ``get_similarity_score``.
        show_progress: Wrap the scoring loop in a tqdm progress bar.

    Returns:
        ``(entry, score)`` for the highest-scoring entry if its score
        exceeds *threshold*, otherwise None. On ties the earliest entry wins.
    """
    if scorer is None:
        scorer = get_similarity_score

    # Remove text in parentheses (e.g. "(12 lbs)") before matching.
    input_word_clean = re.sub(r'\(.*?\)', '', input_word).strip()
    words = re.findall(r'\w+', input_word_clean.lower())
    filtered_dictionary = [
        desc for desc in dictionary
        if isinstance(desc, str) and any(word in desc.lower() for word in words)
    ]
    print(f"Filtered dictionary size: {len(filtered_dictionary)}")

    candidates = (
        tqdm(filtered_dictionary, desc="Processing Entries")
        if show_progress else filtered_dictionary
    )
    # Boost entries describing the raw version when the input is a single
    # word. The \b anchors keep words that merely contain "raw" (e.g.
    # "strawberry", "crawfish") from receiving the boost.
    apply_raw_boost = len(words) == 1
    similarities = []
    for entry in candidates:
        score = scorer(input_word_clean, entry)
        if apply_raw_boost and re.search(r'\braw\b', entry, re.IGNORECASE):
            score += raw_boost
        similarities.append((entry, score))

    if not similarities:
        return None
    best_match = max(similarities, key=lambda pair: pair[1])
    return best_match if best_match[1] > threshold else None

# Example usage: sample produce names to match against the dictionary. Many
# follow a "Category - Variety" pattern or carry a pack size in parentheses.
input_words = [
    "Pepper - Habanero Pepper",
    "Bananas (12 lbs)",
    "Squash - Yellow Squash",
    "Cauliflower",
    "Squash mix italian/yellow (30 lbs)",
    "Tomato - Roma Tomato",
    "Tomato - Grape Tomato",
    "Squash - Mexican Squash",
    "Pepper - Bell Pepper",
    "Squash - Italian Squash",
    "Pepper - Red Fresno Pepper",
    "Tomato - Cherry Tomato",
    "Pepper - Serrano Pepper",
    "Kale ( 5 lbs)",
    "Tomato - Beefsteak Tomato",
    "Pepper - Anaheim Pepper",
    "Banana - Burro Banana",
    "Squash - Butternut Squash",
    "Apricot ( 10 lbs)",
    "Squash - Acorn Squash",
    "Tomato - Heirloom Tomato",
    "Pepper - Pasilla Pepper",
    "Pepper - Jalapeno Pepper",
    "carrot (10 lbs )",
]

for input_word in tqdm(input_words, desc="Matching Words"):
    print("Input word:", input_word)
    result = match_word(input_word, dictionary)
    if result is None:
        print("Matched entry: None")
    else:
        best_entry, best_score = result
        print("Matched entry:", best_entry)
        print("Similarity score:", best_score)
    # Blank line separates the report for each input word.
    print()