Spaces:
Paused
Paused
File size: 2,848 Bytes
9189e38 b72dd6f 9189e38 ecbcfc4 9189e38 ecbcfc4 9189e38 ecbcfc4 9189e38 3f68dec b72dd6f 3f68dec eccaeb3 f7bc87b eccaeb3 ecbcfc4 b72dd6f eccaeb3 b72dd6f eccaeb3 b72dd6f 9189e38 f7bc87b 9189e38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import random
import numpy as np
import torch
import logging
from transformers import pipeline
from autocorrect import Speller
# Seed every RNG the zero-shot classifier may depend on so results are reproducible.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # all CUDA devices, not just the current one
    )
    for seeder in seeders:
        seeder(seed)
    # Benchmark mode lets cuDNN pick algorithms at runtime, which may vary
    # between runs; turn it off and request deterministic kernels instead.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed all RNGs before the pipeline is built so loading and inference start
# from a fixed state.
set_seed(1)
# Load a pre-trained model and tokenizer: a zero-shot classification pipeline
# backed by roberta-large-mnli, shared by the classify_* functions below.
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
# Spelling corrector used by classify_as_food_nonfood as a fallback when an
# item is first classified as non-food.
spell = Speller()
# Classify item as food or non-food
def classify_as_food_nonfood(item):
    """Classify ``item`` as ``"food"`` or ``"non-food"`` using zero-shot NLI.

    The stripped, lower-cased item is scored against the candidate labels
    ``["food", "non-food"]``. When the top label is ``"non-food"``, two
    fallbacks may upgrade it to ``"food"``. Each applies only if its own top
    score is at least 0.7:

    1. ``classify_as_drink_nondrink`` says it is a drink (drinks count as food).
    2. The spell-corrected text (via ``spell``) is classified as food.

    Args:
        item: Text to classify. If the first classification raises (e.g.
            ``item`` is not a string), the error is logged and the item
            starts from ``("non-food", 0.0)`` before the fallbacks run.

    Returns:
        ``(label, score)``. ``label`` is ``"food"`` or ``"non-food"``.
        ``score`` is the confidence from whichever classification produced
        the final label.
    """
    cleaned_item = None  # stays None if ``item`` cannot be normalised
    try:
        cleaned_item = item.strip().lower()
        result = classifier(cleaned_item, candidate_labels=["food", "non-food"])
        label = result["labels"][0]
        score = result["scores"][0]
    except Exception as e:
        logging.info(f"Error: {e}")
        logging.info(f"item is: {item}")
        label = "non-food"
        score = 0.0
    if label == "non-food":
        # Drinks are counted as food.
        drink_label, drink_score = classify_as_drink_nondrink(item)
        if drink_label == "drink" and drink_score >= 0.7:
            label = "food"
            score = drink_score
    # Retry on the spell-corrected text. Skip the retry in two cases:
    # - normalisation failed, so there is nothing to correct;
    # - the speller changed nothing, so re-classifying identical text
    #   would add no information.
    if label == "non-food" and cleaned_item is not None:
        try:
            spell_fix_item = spell(cleaned_item)
            if spell_fix_item != cleaned_item:
                result = classifier(spell_fix_item, candidate_labels=["food", "non-food"])
                food_label = result["labels"][0]
                food_score = result["scores"][0]
                if food_label == "food" and food_score >= 0.7:
                    label = "food"
                    score = food_score
        except Exception as e:
            # Best-effort fallback: on failure keep the verdict reached so far.
            logging.info(f"Error: {e}")
            logging.info(f"item is: {item}")
    return label, score
def classify_as_drink_nondrink(item):
    """Classify ``item`` as ``"drink"`` or ``"non-drink"`` with the zero-shot classifier.

    Args:
        item: Text to classify; stripped and lower-cased before scoring.

    Returns:
        ``(label, score)`` for the top-ranked candidate label. If anything
        raises (e.g. ``item`` is not a string), the error is logged and
        ``("non-drink", 0.0)`` is returned.
    """
    try:
        text = item.strip().lower()
        prediction = classifier(text, candidate_labels=["drink", "non-drink"])
        best_label, best_score = prediction["labels"][0], prediction["scores"][0]
    except Exception as err:
        logging.info(f"Error: {err}")
        logging.info(f"item is: {item}")
        return "non-drink", 0.0
    return best_label, best_score
def pessimistic_food_nonfood_score(food_nonfood, similarity_score, threshold=0.7):
    """Decide whether an item is food, requiring confidence before calling it non-food.

    For us to truly believe that the word is non-food, we need to be
    confident that it is non-food. Two conditions must both hold:

    1. The label in ``food_nonfood`` is not ``"food"``.
    2. Its score is at least ``threshold``.

    Anything else (a ``"food"`` label, or a low-confidence non-food verdict)
    is treated as food.

    Args:
        food_nonfood: A ``(label, score)`` pair, e.g. as returned by
            ``classify_as_food_nonfood``.
        similarity_score: Currently unused; kept so existing callers that
            pass it continue to work.
        threshold: Minimum score for a non-food verdict to be trusted.
            Defaults to 0.7.

    Returns:
        ``(is_food, food_nonfood_score)``. ``is_food`` is a bool and the score
        is passed through unchanged.
    """
    label, food_nonfood_score = food_nonfood[0], food_nonfood[1]
    confidently_nonfood = label != "food" and food_nonfood_score >= threshold
    return not confidently_nonfood, food_nonfood_score
|