"""Multi-item food classification experiment.

For each free-text item list, the script:

1. splits the text on ``/`` and ``,`` into individual items,
2. labels each item as food / non-food (database lookup first, zero-shot
   classifier as fallback),
3. stores confidently non-food items that have no database record,
4. assigns the whole list a category based on SBERT similarity between each
   item and the food categories from the dictionary CSV,
5. writes one result row per list to a CSV file.
"""

import re
import csv
import random
import numpy as np
import torch
from tqdm import tqdm
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, get_mapping_from_db, store_mapping_to_db

# Load a pre-trained SBERT model
model = SentenceTransformer('all-MiniLM-L6-v2')


# Set seeds for reproducibility of zero-shot classification
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1)

# Load a pre-trained model and tokenizer
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")


# Load food categories from CSV
def load_food_categories(csv_file):
    """Return the distinct ``food_category`` values found in *csv_file*.

    The values are returned sorted so that the position of each category
    (and therefore the row order of the precomputed embeddings) is the same
    on every run; a bare ``list(set)`` would vary with string hash
    randomization.

    Raises:
        KeyError: if a row has no ``food_category`` column.
        OSError: if the file cannot be opened.
    """
    with open(csv_file, newline='') as csvfile:
        distinct = {row['food_category'] for row in csv.DictReader(csvfile)}
    return sorted(distinct)


# Path to the CSV file with food categories
csv_file_path = 'dictionary/dictionary.csv'
food_categories = load_food_categories(csv_file_path)

# Precompute embeddings for food categories
category_embeddings = model.encode(food_categories, convert_to_tensor=True)

# Separators used to break an item list into individual items.
_ITEM_SEPARATOR = re.compile(r'[\/,]')


# Classify item as food or non-food
def classify_as_food_nonfood(item, cursor):
    """Label *item* as ``"food"`` or ``"non-food"``.

    A database record with a non-``None`` ``is_food`` value takes precedence
    and is reported with a score of 1.0; otherwise the zero-shot classifier
    decides.

    Args:
        item: the item text to classify.
        cursor: database cursor passed through to ``get_mapping_from_db``.

    Returns:
        A ``(label, score, db_record)`` tuple. ``db_record`` is whatever
        ``get_mapping_from_db`` returned, or ``None`` for the literal
        ``'Non-Food Item'`` shortcut (no lookup is performed in that case).
    """
    if item == 'Non-Food Item':
        return "non-food", 1.0, None

    cleaned_item = item.strip().lower()
    db_record = get_mapping_from_db(cursor, cleaned_item)

    if db_record and 'is_food' in db_record:
        is_food = db_record['is_food']
        if is_food is not None:
            print(f"Item: {item} found in database with is_food: {is_food}")
            return ("food" if is_food else "non-food"), 1.0, db_record

    # No usable database answer: fall back to the zero-shot classifier and
    # take its top-ranked label.
    result = classifier(item, candidate_labels=["food", "non-food"])
    label = result["labels"][0]
    score = result["scores"][0]
    print(f"Item: {item}, Label: {label}, Score: {score}")
    return label, score, db_record


# Determine the category of a food item
def determine_category(item):
    """Return the food category most similar to *item*.

    Similarity is the cosine similarity between the SBERT embedding of
    *item* and the precomputed category embeddings. The best matches (up to
    three) are printed for inspection.
    """
    item_embedding = model.encode(item, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(item_embedding, category_embeddings)

    # Resolve the winner before the reporting loop and keep it under its own
    # name, so the loop variables below cannot overwrite the return value.
    best_category = food_categories[int(similarities.argmax())]

    # topk raises if k exceeds the number of categories, so cap it; compute
    # it once and reuse both the indices and the values.
    k = min(3, len(food_categories))
    top_k = torch.topk(similarities, k)
    top_indices = top_k.indices[0].tolist()
    top_scores = top_k.values[0].tolist()
    top_categories = [
        (food_categories[idx], score)
        for idx, score in zip(top_indices, top_scores)
    ]

    print("=========================================")
    print(f"item: {item}")
    for candidate, score in top_categories:
        print(f"Category: {candidate}, Similarity Score: {score:.4f}")

    return best_category


# Categorize food items
def categorize_food_items(items):
    """Return a single label describing the categories of *items*.

    Returns the shared category when every item maps to the same one,
    ``"heterogeneous mixture"`` when more than one category is found, and
    ``"food"`` when *items* is empty.
    """
    categories_found = {determine_category(item) for item in items}
    print(f"Categories found: {categories_found}")

    if len(categories_found) == 1:
        return next(iter(categories_found))
    if categories_found:
        return "heterogeneous mixture"
    return "food"


# List of items to test
items_to_test = [
    "Misc Grocery/Yogurt/Coffee/Baking Chips/Coffeemate Creamer/Baby Food/Sports",
    "Drinks/Cond Milk/Evap Milk/Coffee Syrup/Soup",
    "Plastic Bags, Water, Snacks, Grocery, Meat, Candy",
    "Bread/BakeryDesserts/Deli/Seafood",
    "Mixed Vegetables/Lettuce Assorted",
    "Rolls/Fruit/Vegetables/Herbs/Butter/Oatmeal/Bread/Salad Greens",
    "Breast,wings,thighs,legs,tenders",
    "Swiss Cheese, Provolone cheese, cheddar, mozzarella"
]

# Initialize database connection
conn = get_connection()
cursor = conn.cursor()

# Collect results
results = []

try:
    for items in items_to_test:
        # Drop empty fragments (e.g. from a trailing or doubled separator) so
        # they are never sent to the classifier or the embedder.
        items_list = [
            fragment.strip().lower()
            for fragment in _ITEM_SEPARATOR.split(items)
            if fragment.strip()
        ]
        item_labels = [classify_as_food_nonfood(item, cursor) for item in items_list]

        # Only confident non-food predictions are reported and stored.
        non_food_items = [
            (item, score, db_record)
            for item, (label, score, db_record) in zip(items_list, item_labels)
            if label == "non-food" and score > 0.75
        ]

        for (item, score, db_record) in non_food_items:
            # Store non-food items in the database
            if db_record is None:
                mapping = (item, item, "Non-Food Item", score, None, None, False)
                store_mapping_to_db(cursor, conn, mapping)

        list_label = categorize_food_items(items_list)
        non_food_items_str = ", ".join([item for item, _, _ in non_food_items])
        results.append([items, list_label, non_food_items_str])

    # Write results to a CSV file
    with open('multi-item-experiments/classification_results2.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Item List', 'Category', 'Non-Food Items'])
        csvwriter.writerows(results)
finally:
    # Close the SQLite connection even if classification or writing failed.
    conn.close()