File size: 2,848 Bytes
9189e38
 
 
b72dd6f
9189e38
ecbcfc4
9189e38
 
 
 
 
 
 
 
 
 
 
 
 
ecbcfc4
 
9189e38
 
ecbcfc4
9189e38
 
 
3f68dec
 
 
 
 
 
b72dd6f
 
3f68dec
 
eccaeb3
 
 
 
f7bc87b
eccaeb3
 
 
ecbcfc4
 
 
 
 
 
 
 
 
 
b72dd6f
eccaeb3
 
 
 
 
 
 
 
 
b72dd6f
 
eccaeb3
 
b72dd6f
9189e38
 
 
 
 
 
 
 
 
 
 
 
f7bc87b
9189e38
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import random
import numpy as np
import torch
import logging
from transformers import pipeline
from autocorrect import Speller
# NOTE: no SBERT model is loaded in this file; classification below uses the
# roberta-large-mnli zero-shot pipeline.

# Set seeds for reproducibility of zero-shot classification
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG and
    PyTorch (including all CUDA devices), then switches cuDNN to its
    deterministic mode with benchmark mode turned off.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once at import time, before the model below is loaded.
set_seed(1)



# Load a pre-trained model and tokenizer
# NOTE: both objects are created at import time and shared module-wide.
# `classifier` is the zero-shot pipeline called by the classify_* helpers;
# `spell` is the autocorrect speller used for the spelling-correction retry.
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
spell = Speller()

# Classify item as food or non-food
def classify_as_food_nonfood(item):
    """Classify ``item`` as "food" or "non-food" with the zero-shot model.

    The item is stripped and lower-cased, then scored against the labels
    "food" and "non-food". When the first verdict is "non-food", two rescue
    attempts are made, in order:

    1. Drinks count as food: if ``classify_as_drink_nondrink`` returns
       "drink" with a score of at least 0.7, the item is reported as food.
    2. Spelling: the cleaned text is run through the spell checker and, if
       the corrected text differs, it is classified again; a "food" verdict
       with a score of at least 0.7 is accepted.

    Any failure while classifying is logged and treated as "non-food" with
    score 0.0 (best effort; this function does not raise for bad input).

    Args:
        item: Text to classify. Non-string or empty values fall back to
            ("non-food", 0.0).

    Returns:
        A ``(label, score)`` tuple where ``label`` is "food" or "non-food"
        and ``score`` is the confidence of the verdict that was accepted.
    """
    # Pre-bind so the later stages can tell that cleaning failed instead of
    # hitting an unbound local when item.strip() raises.
    cleaned_item = None
    try:
        cleaned_item = item.strip().lower()
        result = classifier(cleaned_item, candidate_labels=["food", "non-food"])
        label = result["labels"][0]
        score = result["scores"][0]
    except Exception as e:
        logging.info(f"Error: {e}")
        logging.info(f"item is: {item}")
        label = "non-food"
        score = 0.0

    if label != "non-food":
        return label, score

    # check if the item is a drink
    drink_label, drink_score = classify_as_drink_nondrink(item)
    if drink_label == "drink" and drink_score >= 0.7:
        return "food", drink_score

    # Nothing to spell-check when cleaning failed or left an empty string.
    if not cleaned_item:
        return label, score

    # try correcting the spelling
    try:
        spell_fix_item = spell(cleaned_item)
        # Re-classifying identical text cannot change the verdict, so only
        # pay for a second model call when the speller changed something.
        if spell_fix_item == cleaned_item:
            return label, score
        result = classifier(spell_fix_item, candidate_labels=["food", "non-food"])
        food_label = result["labels"][0]
        food_score = result["scores"][0]
    except Exception as e:
        logging.info(f"Error: {e}")
        logging.info(f"item is: {item}")
        return label, score

    if food_label == "food" and food_score >= 0.7:
        return "food", food_score

    return label, score

def classify_as_drink_nondrink(item):
    """Classify ``item`` as "drink" or "non-drink" with the zero-shot model.

    The item is stripped and lower-cased before being scored against the
    labels "drink" and "non-drink". If anything goes wrong (for example the
    item is not a string), the error is logged and the pessimistic default
    ("non-drink", 0.0) is returned.

    Args:
        item: Text to classify.

    Returns:
        A ``(label, score)`` tuple holding the top-ranked label and its score.
    """
    try:
        normalized = item.strip().lower()
        prediction = classifier(normalized, candidate_labels=["drink", "non-drink"])
        top_label = prediction["labels"][0]
        top_score = prediction["scores"][0]
    except Exception as e:
        logging.info(f"Error: {e}")
        logging.info(f"item is: {item}")
        return "non-drink", 0.0
    else:
        return top_label, top_score

def pessimistic_food_nonfood_score(food_nonfood, similarity_score):
    """Decide whether an item is food, only rejecting it when confident.

    An item is treated as non-food only when both conditions hold:

    1. its label is not 'food', and
    2. the classification score is at least 0.7.

    In every other case the item is treated as food.

    Args:
        food_nonfood: Indexable pair ``(label, score)``.
        similarity_score: Accepted for the caller's convenience; not used.

    Returns:
        ``(is_food, score)`` where ``score`` is ``food_nonfood[1]`` unchanged.
    """
    predicted_label = food_nonfood[0]
    confidence = food_nonfood[1]

    labelled_nonfood = not (predicted_label == 'food')
    # Short-circuit keeps the score comparison off the 'food' path.
    confidently_nonfood = labelled_nonfood and confidence >= 0.7

    return (not confidently_nonfood), confidence