brightly-ai / item_or_category.py
beweinreich's picture
added in some templates
e5de092
raw
history blame
No virus
1.01 kB
import random
import numpy as np
import torch
import logging
from transformers import pipeline
from autocorrect import Speller
# Seed every random-number generator used by this module so that zero-shot
# classification results are reproducible across runs.
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Seed value passed unchanged to each generator (Python
            ``random``, NumPy's global RNG, PyTorch CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    # Ask for deterministic cuDNN kernels and switch off cuDNN benchmark mode.
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed all RNGs once, at import time, so zero-shot classification is reproducible.
set_seed(1)
# Build a zero-shot-classification pipeline backed by the pre-trained
# roberta-large-mnli model (the pipeline loads the model and its tokenizer).
# NOTE(review): this runs at import time and presumably downloads the model on
# first use — confirm that import-time loading is acceptable for all callers.
classifier = pipeline("zero-shot-classification", model="roberta-large-mnli")
# Spell checker applied to item text before it is classified.
spell = Speller()
# Classify a piece of text as either a single food item or a food category.
def classify_as_item_or_category(item):
    """Classify *item* as a single food item or a food category.

    The text is stripped, lower-cased and spell-corrected, then scored by the
    module-level zero-shot ``classifier`` against the candidate labels
    ``"single food item"`` and ``"food category"``.

    Args:
        item: Raw item text, e.g. an ingredient or product name.

    Returns:
        A ``(label, score)`` tuple taken from the first entry of the
        classifier's ``labels`` and ``scores`` lists (per the transformers
        documentation these are sorted by descending score, so this is the
        best-matching label).

    Raises:
        ValueError: If *item* is empty or contains only whitespace.
    """
    cleaned_item = item.strip().lower()
    if not cleaned_item:
        # Fail fast with a clear message instead of spell-checking an empty
        # string and letting the pipeline reject it further down.
        raise ValueError("item must contain non-whitespace text")

    spell_fix_item = spell(cleaned_item)
    result = classifier(
        spell_fix_item,
        candidate_labels=["single food item", "food category"],
    )
    label = result["labels"][0]
    score = result["scores"][0]
    # Lazy %-formatting: the message is only built if DEBUG logging is enabled.
    logging.getLogger(__name__).debug(
        "Item: %s, Label: %s, Score: %s", item, label, score
    )
    return label, score