import logging
import random

import numpy as np
import torch
from autocorrect import Speller
from transformers import pipeline

logger = logging.getLogger(__name__)

# Minimum top-label score used in three places:
#   * a "drink" verdict must reach it to overturn an initial "non-food" verdict;
#   * a spell-corrected "food" verdict must reach it to overturn "non-food";
#   * pessimistic_food_nonfood_score only accepts "non-food" at or above it.
CONFIDENCE_THRESHOLD = 0.7

_FOOD_LABELS = ["food", "non-food"]
_DRINK_LABELS = ["drink", "non-drink"]


def set_seed(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    Done so repeated runs of the zero-shot classifier give the same results.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(1)

# Load the pre-trained zero-shot classification model (and its tokenizer)
# once at import time; every classification below reuses it.
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
spell = Speller()


def _classify_top_label(text, candidate_labels, fallback_label):
    """Return the pipeline's first-ranked ``(label, score)`` for *text*.

    *text* is stripped and lower-cased before classification. On any failure
    (e.g. *text* is not a string, or the pipeline raises), the error is logged
    and ``(fallback_label, 0.0)`` is returned so callers degrade gracefully
    instead of crashing.
    """
    try:
        cleaned = text.strip().lower()
        result = classifier(cleaned, candidate_labels=candidate_labels)
        return result["labels"][0], result["scores"][0]
    except Exception as e:
        # Best-effort by design, but surface the failure at WARNING so it is
        # visible under the default logging configuration.
        logger.warning("Error: %s", e)
        logger.warning("item is: %s", text)
        return fallback_label, 0.0


def classify_as_food_nonfood(item):
    """Classify *item* as ``"food"`` or ``"non-food"``.

    The initial zero-shot verdict is "non-food" when the pipeline fails.
    A "non-food" verdict is overturned to "food" when either of these
    reaches ``CONFIDENCE_THRESHOLD``:

    1. the drink classifier says ``"drink"``;
    2. the spell-corrected text is classified ``"food"``.

    Returns:
        tuple[str, float]: ``(label, score)``. The score is the one that
        produced the final label; it is 0.0 when classification failed.
    """
    label, score = _classify_top_label(item, _FOOD_LABELS, "non-food")

    if label == "non-food":
        # Beverages are frequently scored "non-food"; accept confident drinks.
        drink_label, drink_score = classify_as_drink_nondrink(item)
        if drink_label == "drink" and drink_score >= CONFIDENCE_THRESHOLD:
            label, score = "food", drink_score

    # Retry on the spell-corrected text. Guard on str: a non-string item
    # already failed above and cannot be spell-checked.
    if label == "non-food" and isinstance(item, str):
        cleaned_item = item.strip().lower()
        corrected_item = spell(cleaned_item)
        # Re-classifying unchanged text cannot change the verdict, so skip
        # the (expensive) model call in that case.
        if corrected_item != cleaned_item:
            food_label, food_score = _classify_top_label(
                corrected_item, _FOOD_LABELS, "non-food"
            )
            if food_label == "food" and food_score >= CONFIDENCE_THRESHOLD:
                label, score = "food", food_score

    return label, score


def classify_as_drink_nondrink(item):
    """Classify *item* as ``"drink"`` or ``"non-drink"``.

    Returns:
        tuple[str, float]: the pipeline's first-ranked ``(label, score)``,
        or ``("non-drink", 0.0)`` if classification failed.
    """
    return _classify_top_label(item, _DRINK_LABELS, "non-drink")


def pessimistic_food_nonfood_score(food_nonfood, similarity_score):
    """Decide whether an item is food, believing "non-food" only when confident.

    The item counts as non-food only when both conditions hold:

    1. the label ``food_nonfood[0]`` is not ``"food"``;
    2. the score ``food_nonfood[1]`` is at least ``CONFIDENCE_THRESHOLD``.

    Otherwise it counts as food.

    Args:
        food_nonfood: ``(label, score)`` as returned by
            ``classify_as_food_nonfood``.
        similarity_score: Currently unused; accepted for caller compatibility.

    Returns:
        tuple[bool, float]: ``(is_food, food_nonfood_score)``.
    """
    label = food_nonfood[0]
    food_nonfood_score = food_nonfood[1]
    confidently_nonfood = label != "food" and food_nonfood_score >= CONFIDENCE_THRESHOLD
    return not confidently_nonfood, food_nonfood_score