# NOTE(review): the lines "Spaces: / Paused / Paused" were non-code residue from
# the hosting page this file was scraped from; replaced with this comment so the
# module parses.
import re | |
import csv | |
import random | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer, util | |
from sklearn.manifold import TSNE | |
import matplotlib.pyplot as plt | |
from db.db_utils import get_connection, get_mapping_from_db | |
# Load a pre-trained SBERT model; its embeddings drive the cosine-similarity
# category matching in determine_category().
model = SentenceTransformer('all-MiniLM-L6-v2')
# Set seeds for reproducibility of zero-shot classification | |
def set_seed(seed):
    """Seed every RNG in play (Python, NumPy, Torch CPU and CUDA) with *seed*.

    Also pins cuDNN to deterministic kernels (disabling its autotuner) so
    repeated runs produce identical results.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs before any model is instantiated so runs are reproducible.
set_seed(1)
# Zero-shot NLI classifier: fallback food/non-food labeller for items that
# have no cached verdict in the database (see classify_item).
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
# Load food categories from CSV | |
def load_food_categories(csv_file):
    """Read the dictionary CSV and return its unique food categories.

    Args:
        csv_file: path to a CSV with (at least) a ``food_category`` column.

    Returns:
        Sorted list of distinct category names.  Sorting matters: the
        original returned ``list(set)``, whose order changes between runs
        under hash randomization, which would silently reorder the rows of
        ``category_embeddings`` and make results non-reproducible.
    """
    food_categories = set()
    with open(csv_file, newline='', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            food_categories.add(row['food_category'])
    return sorted(food_categories)
# Path to the CSV file with food categories (relative to the working directory).
csv_file_path = 'dictionary/dictionary.csv'
food_categories = load_food_categories(csv_file_path)
# Precompute one SBERT embedding per category; determine_category() indexes
# back into food_categories by row, so the two must stay in the same order.
category_embeddings = model.encode(food_categories, convert_to_tensor=True)
# Classify item as food or non-food | |
def classify_item(item, cursor):
    """Label *item* as "food" or "non-food".

    Prefers a cached verdict from the database (confidence 1.0); otherwise
    falls back to the zero-shot classifier's top label and score.

    Returns:
        (label, score) tuple.
    """
    # Fast path: a previously stored is_food flag in the database.
    cleaned_item = item.strip().lower()
    mapping = get_mapping_from_db(cursor, cleaned_item)
    if mapping and 'is_food' in mapping:
        is_food = mapping['is_food']
        if is_food is not None:
            print(f"Item: {item} found in database with is_food: {is_food}")
            label = "food" if is_food else "non-food"
            return label, 1.0
    # Slow path: ask the zero-shot model; labels/scores come back sorted
    # best-first, so index 0 is the winning label.
    result = classifier(item, candidate_labels=["food", "non-food"])
    label, score = result["labels"][0], result["scores"][0]
    print(f"Item: {item}, Label: {label}, Score: {score}")
    return label, score
# Determine the category of a food item | |
def determine_category(item):
    """Return the food category whose SBERT embedding is most similar to *item*.

    Also prints the top-3 candidate categories with their cosine-similarity
    scores for debugging.

    Uses module-level ``model``, ``category_embeddings`` and
    ``food_categories``.
    """
    item_embedding = model.encode(item, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(item_embedding, category_embeddings)
    # Compute topk once (the original called torch.topk twice).
    top_3 = torch.topk(similarities, 3)
    top_3_indices = top_3.indices[0].tolist()
    top_3_scores = top_3.values[0].tolist()
    top_3_categories = [(food_categories[idx], score)
                        for idx, score in zip(top_3_indices, top_3_scores)]
    print("=========================================")
    print(f"item: {item}")
    for candidate, candidate_score in top_3_categories:
        print(f"Category: {candidate}, Similarity Score: {candidate_score:.4f}")
    # BUG FIX: the original reused the name `category` as its print-loop
    # variable, clobbering the argmax result and returning the 3rd-best
    # category. Return the true best match (first of the topk list).
    return top_3_categories[0][0]
# Visualize embeddings | |
def visualize_embeddings(items, categories, item_embeddings, category_embeddings):
    """Project item and category embeddings to 2-D with t-SNE and scatter-plot
    them, labelling every point with its item/category text."""
    all_embeddings = torch.cat([item_embeddings, category_embeddings], dim=0)
    # Fixed random_state keeps the t-SNE layout reproducible across runs.
    projected = TSNE(n_components=2, random_state=1).fit_transform(
        all_embeddings.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    for (x, y), label in zip(projected, items + categories):
        plt.scatter(x, y)
        plt.text(x + 0.1, y + 0.1, label, fontsize=9)
    plt.show()
# Categorize food items | |
def categorize_food_items(items):
    """Collapse a list of items into a single label.

    Returns the shared category when every item maps to the same one,
    "heterogeneous mixture" when the items span several categories, and the
    generic "food" when no category was found (empty input).
    """
    categories_found = {determine_category(item) for item in items}
    print(f"Categories found: {categories_found}")
    if not categories_found:
        return "food"
    if len(categories_found) > 1:
        return "heterogeneous mixture"
    return next(iter(categories_found))
# List of items to test | |
# Comma/slash-separated item lists to classify; each string is one "basket".
items_to_test = [
    "Swiss Cheese, Provolone cheese, cheddar, mozzarella"
]
# Initialize database connection (closed at the bottom of the script).
conn = get_connection()
cursor = conn.cursor()
# Collect per-list results and embeddings for visualization.
results = []
item_embeddings = []
items_list = []
for items in items_to_test:
    # Split items by both "/" and "," and strip extra spaces.
    items_list = [item.strip().lower() for item in re.split(r'[\/,]', items)]
    item_labels = [classify_item(item, cursor) for item in items_list]
    # Keep the items the classifier judged non-food for the report column.
    non_food_items = [item for item, (label, _) in zip(items_list, item_labels) if label == "non-food"]
    # Accumulate embeddings for visualization (one row per item).
    item_embeddings.extend(model.encode(items_list, convert_to_tensor=True))
    list_label = categorize_food_items(items_list)
    results.append([items, list_label, ", ".join(non_food_items)])
# Visualize all collected item embeddings alongside the category embeddings.
# NOTE(review): items_list here holds only the LAST basket's items, while
# item_embeddings spans every basket — with more than one entry in
# items_to_test the point labels would misalign; confirm intent.
visualize_embeddings(items_list, food_categories, torch.stack(item_embeddings), category_embeddings)
# Write results to a CSV file: one row per input basket.
with open('multi-item-experiments/classification_results2.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(['Item List', 'Category', 'Non-Food Items'])
    csvwriter.writerows(results)
# Close the SQLite connection.
conn.close()