"""Experiment script: classify free-text item lists as food / non-food and
assign each list a food category using SBERT similarity against a dictionary
of known categories. Results are plotted with t-SNE and written to a CSV file.
"""
import re
import csv
import random

import numpy as np
import torch
from tqdm import tqdm
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from db.db_utils import get_connection, get_mapping_from_db

# Load a pre-trained SBERT model
model = SentenceTransformer('all-MiniLM-L6-v2')


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Used to make the zero-shot classification reproducible between runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1)

# Load a pre-trained zero-shot classification pipeline
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")


def load_food_categories(csv_file):
    """Return the distinct values of the ``food_category`` column of *csv_file*.

    Duplicates are removed while keeping first-seen order, so the resulting
    list (and therefore the row order of the category embeddings) is the same
    on every run instead of depending on set iteration order.
    """
    with open(csv_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        # dict.fromkeys de-duplicates and preserves insertion order.
        return list(dict.fromkeys(row['food_category'] for row in reader))


# Path to the CSV file with food categories
csv_file_path = 'dictionary/dictionary.csv'
food_categories = load_food_categories(csv_file_path)

# Precompute embeddings for food categories
category_embeddings = model.encode(food_categories, convert_to_tensor=True)


def classify_item(item, cursor):
    """Classify *item* as ``"food"`` or ``"non-food"``.

    The database mapping (looked up by the stripped, lower-cased item) takes
    precedence when it carries a non-None ``is_food`` value; in that case the
    score is reported as 1.0. Otherwise the zero-shot classifier decides.

    Returns:
        A ``(label, score)`` tuple.
    """
    # Check database for item
    cleaned_item = item.strip().lower()
    mapping = get_mapping_from_db(cursor, cleaned_item)
    if mapping and 'is_food' in mapping:
        is_food = mapping['is_food']
        if is_food is not None:
            print(f"Item: {item} found in database with is_food: {is_food}")
            return ("food" if is_food else "non-food"), 1.0

    # If not found in database, classify using the model
    result = classifier(item, candidate_labels=["food", "non-food"])
    label = result["labels"][0]
    score = result["scores"][0]
    print(f"Item: {item}, Label: {label}, Score: {score}")
    return label, score


def determine_category(item):
    """Return the food category most similar to *item*.

    The item is embedded with the SBERT model and compared by cosine
    similarity against the precomputed ``category_embeddings``. The top
    matches (up to three) are printed for inspection; the single best match
    is returned.
    """
    item_embedding = model.encode(item, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(item_embedding, category_embeddings)

    # Pick the winner before any diagnostic printing so that the loop below
    # cannot overwrite it.
    best_category = food_categories[int(similarities.argmax())]

    # Never ask topk for more entries than there are categories.
    top_k = min(3, len(food_categories))
    top_scores, top_indices = torch.topk(similarities, top_k)
    ranked = [
        (food_categories[idx], score)
        for idx, score in zip(top_indices[0].tolist(), top_scores[0].tolist())
    ]

    print("=========================================")
    print(f"item: {item}")
    for name, score in ranked:
        print(f"Category: {name}, Similarity Score: {score:.4f}")

    return best_category


def visualize_embeddings(items, categories, item_embeddings, category_embeddings):
    """Scatter-plot a 2-D t-SNE projection of item and category embeddings.

    *items* and *categories* are lists of labels; their concatenation must
    line up row-for-row with ``item_embeddings`` followed by
    ``category_embeddings``.
    """
    tsne = TSNE(n_components=2, random_state=1)
    embeddings = torch.cat([item_embeddings, category_embeddings], dim=0)
    tsne_embeddings = tsne.fit_transform(embeddings.detach().cpu().numpy())

    plt.figure(figsize=(10, 10))
    for i, label in enumerate(items + categories):
        x, y = tsne_embeddings[i]
        plt.scatter(x, y)
        plt.text(x + 0.1, y + 0.1, label, fontsize=9)
    plt.show()


def categorize_food_items(items):
    """Return one label describing the categories of all *items*.

    Returns the shared category when every item maps to the same one,
    ``"heterogeneous mixture"`` when they map to several, and ``"food"``
    when *items* is empty.
    """
    categories_found = {determine_category(item) for item in items}
    print(f"Categories found: {categories_found}")

    if len(categories_found) == 1:
        return next(iter(categories_found))
    if len(categories_found) > 1:
        return "heterogeneous mixture"
    return "food"


# List of items to test
items_to_test = [
    "Swiss Cheese, Provolone cheese, cheddar, mozzarella"
]

# Initialize database connection
conn = get_connection()

# Collect results and embeddings for visualization
results = []
item_embeddings = []
items_list = []
all_items = []  # labels for every embedding collected, across all entries

try:
    cursor = conn.cursor()
    for items in items_to_test:
        # Split items by both "/" and "," and strip extra spaces
        items_list = [item.strip().lower() for item in re.split(r'[\/,]', items)]
        item_labels = [classify_item(item, cursor) for item in items_list]
        non_food_items = [
            item
            for item, (label, _) in zip(items_list, item_labels)
            if label == "non-food"
        ]

        # Keep labels and embeddings in lockstep so the plot stays aligned
        # even when several entries are tested.
        item_embeddings.extend(model.encode(items_list, convert_to_tensor=True))
        all_items.extend(items_list)

        list_label = categorize_food_items(items_list)
        results.append([items, list_label, ", ".join(non_food_items)])
finally:
    # Close the SQLite connection even if classification fails part-way.
    conn.close()

# Visualize embeddings (torch.stack rejects an empty list, so skip if none)
if item_embeddings:
    visualize_embeddings(
        all_items, food_categories, torch.stack(item_embeddings), category_embeddings
    )

# Write results to a CSV file
with open('multi-item-experiments/classification_results2.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(['Item List', 'Category', 'Non-Food Items'])
    csvwriter.writerows(results)