import re
import csv
import random
import numpy as np
import torch
from tqdm import tqdm
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
from db.db_utils import get_connection, initialize_db, get_mapping_from_db, store_mapping_to_db
# Load the pre-trained Sentence-BERT encoder ('all-MiniLM-L6-v2') once at
# import time. The same instance embeds the food categories below and every
# item passed to determine_category.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Set seeds for reproducibility of zero-shot classification
def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and torch,
    then switches cuDNN to deterministic mode and disables its autotuner.

    Args:
        seed: Integer seed. NumPy's legacy global generator only accepts
            values in the range ``0 .. 2**32 - 1``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Autotuning (benchmark) may pick different kernels between runs, so it
    # is turned off together with enabling deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Build the zero-shot classification pipeline (roberta-large-mnli) once at
# import time. classify_as_food_nonfood calls it with the candidate labels
# "food" and "non-food".
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
# Load food categories from CSV
def load_food_categories(csv_file, column="Food Category"):
    """Return the distinct, non-empty category names found in a CSV file.

    Args:
        csv_file: Path to a CSV file whose first row is a header.
        column: Header name of the column holding the category.
            NOTE(review): the default is an assumption; confirm it against
            the real header of the dictionary CSV.

    Returns:
        A sorted list of unique category strings. Sorting keeps the
        index-to-category mapping stable between runs. A file with no
        header row at all yields an empty list.

    Raises:
        KeyError: If the file has a header but it does not contain
            ``column``.
    """
    with open(csv_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        header = reader.fieldnames
        if header is None:
            # Completely empty file: there is nothing to collect.
            return []
        if column not in header:
            raise KeyError(
                f"Column {column!r} not found in {csv_file}; "
                f"available columns: {header}"
            )
        # Short rows give None and blank cells give ''; skip both.
        food_categories = {row[column] for row in reader if row[column]}
    return sorted(food_categories)
# Path to the CSV file with food categories
csv_file_path = 'dictionary/dictionary.csv'
# Category names loaded once at import time; determine_category indexes this
# list with positions taken from the similarity scores.
food_categories = load_food_categories(csv_file_path)
# Precompute embeddings for food categories so that each lookup only has to
# embed the item itself. Entry i corresponds to food_categories[i].
category_embeddings = model.encode(food_categories, convert_to_tensor=True)
# Classify item as food or non-food
def classify_as_food_nonfood(item, cursor):
    """Label a single item as "food" or "non-food".

    Resolution order:
      1. The literal placeholder "Non-Food Item" (matched ignoring case and
         surrounding whitespace) is always non-food.
      2. A database record for the normalised item whose ``is_food`` value
         is not None decides the label.
      3. Otherwise the zero-shot classifier decides.

    Args:
        item: Item text to classify.
        cursor: Database cursor handed to ``get_mapping_from_db``.

    Returns:
        A ``(label, score, db_record)`` tuple. ``score`` is 1.0 for the
        placeholder and database paths, otherwise the classifier score of
        the top label. ``db_record`` is whatever ``get_mapping_from_db``
        returned (None on the placeholder path).
    """
    cleaned_item = item.strip().lower()

    # Compare the normalised text: callers lowercase items before calling,
    # so an exact-case comparison would never match the placeholder.
    if cleaned_item == 'non-food item':
        return "non-food", 1.0, None

    db_record = get_mapping_from_db(cursor, cleaned_item)
    if db_record and 'is_food' in db_record:
        is_food = db_record['is_food']
        # A stored None means "not decided yet": fall through to the model.
        if is_food is not None:
            print(f"Item: {item} found in database with is_food: {is_food}")
            return ("food" if is_food else "non-food"), 1.0, db_record

    result = classifier(item, candidate_labels=["food", "non-food"])
    # Take the first label/score pair as the top prediction.
    label = result["labels"][0]
    score = result["scores"][0]
    print(f"Item: {item}, Label: {label}, Score: {score}")
    return label, score, db_record
# Determine the category of a food item
def determine_category(item):
    """Return the food category whose embedding is most similar to ``item``.

    The item is embedded with the shared SBERT model and compared against
    the precomputed ``category_embeddings`` by cosine similarity. The top
    matches (up to three) are printed for inspection; only the best one is
    returned.

    Args:
        item: Item text to categorise.

    Returns:
        The entry of ``food_categories`` with the highest similarity score.
    """
    item_embedding = model.encode(item, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(item_embedding, category_embeddings)

    # Keep the winner under its own name so the reporting loop below cannot
    # overwrite it.
    best_category = food_categories[int(similarities.argmax())]

    # One topk call provides both indices and scores. k is capped so that a
    # category list with fewer than three entries does not raise.
    k = min(3, similarities.shape[-1])
    top_k = torch.topk(similarities, k)
    top_indices = top_k.indices[0].tolist()
    top_scores = top_k.values[0].tolist()

    print(f"item: {item}")
    for idx, score in zip(top_indices, top_scores):
        print(f"Category: {food_categories[idx]}, Similarity Score: {score:.4f}")

    return best_category
# Categorize food items
def categorize_food_items(items):
    """Summarise a list of items as a single category label.

    Each item is mapped to a category with ``determine_category`` and the
    distinct categories are collected.

    Args:
        items: Iterable of item strings.

    Returns:
        The shared category when all items map to the same one,
        "heterogeneous mixture" when they map to several, and "food" when
        ``items`` is empty.
    """
    categories_found = set()
    for item in items:
        # Record each item's category; without this the set stays empty and
        # every list would be labelled "food".
        categories_found.add(determine_category(item))

    print(f"Categories found: {categories_found}")

    if len(categories_found) == 1:
        return next(iter(categories_found))
    if categories_found:
        return "heterogeneous mixture"
    return "food"
# List of items to test. Each entry is one raw multi-item description whose
# parts are separated by "/" or ",".
items_to_test = [
    "Misc Grocery/Yogurt/Coffee/Baking Chips/Coffeemate Creamer/Baby Food/Sports",
    "Drinks/Cond Milk/Evap Milk/Coffee Syrup/Soup",
    "Plastic Bags, Water, Snacks, Grocery, Meat, Candy",
    "Mixed Vegetables/Lettuce Assorted",
    "Rolls/Fruit/Vegetables/Herbs/Butter/Oatmeal/Bread/Salad Greens",
    "Swiss Cheese, Provolone cheese, cheddar, mozzarella",
]
# Initialize database connection
conn = get_connection()
cursor = conn.cursor()

# Collect results: one [raw text, list-level label, non-food items] row per
# entry of items_to_test.
results = []
for items in items_to_test:
    # Split the raw text on "/" or "," and normalise every fragment.
    items_list = [item.strip().lower() for item in re.split(r'[\/,]', items)]

    # All fragments are classified before any store_mapping_to_db call is
    # made for this entry.
    item_labels = [classify_as_food_nonfood(item, cursor) for item in items_list]

    # Keep only fragments confidently labelled as non-food.
    non_food_items = []
    for item, (label, score, db_record) in zip(items_list, item_labels):
        if label == "non-food" and score > 0.75:
            non_food_items.append((item, score, db_record))

    # Store non-food items in the database, skipping those that already
    # came back with a record.
    for item, score, db_record in non_food_items:
        if db_record is not None:
            continue
        mapping = (item, item, "Non-Food Item", score, None, None, False)
        store_mapping_to_db(cursor, conn, mapping)

    list_label = categorize_food_items(items_list)
    non_food_items_str = ", ".join(entry[0] for entry in non_food_items)
    results.append([items, list_label, non_food_items_str])
# Write results to a CSV file
with open('multi-item-experiments/classification_results2.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(['Item List', 'Category', 'Non-Food Items'])
    # Write the collected rows; without this only the header row reaches
    # the file.
    csvwriter.writerows(results)
# Close the SQLite connection