|
|
|
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import torch
from transformers import pipeline
|
|
|
# Model identifier handed to ``transformers.pipeline(model=...)`` (via
# ``food_not_food_classifier``) by the Gradio callback below.
huggingface_model_path = (
    "Suraj-Yadav/"
    "learn_hf_food_not_food_text_classifier-distilbert-base-uncased"
)
|
|
|
|
|
def set_device() -> torch.device:
    """
    Pick the preferred torch device for inference.

    Preference order is CUDA, then Apple's MPS backend (only when it is both
    available and built into this torch install), then CPU as the fallback.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
|
|
|
|
|
@lru_cache(maxsize=8)
def _load_text_classifier(model_path: str, batch_size: int, device: Union[str, torch.device]):
    """Build a text-classification pipeline, reusing it for repeated identical arguments."""
    # Cached because loading the model is expensive and this is hit once per
    # Gradio request; maxsize bounds how many loaded models stay in memory.
    return pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None,  # return a score for every label, not only the best one
    )


def _scores_to_dict(predictions: List[Dict[str, Any]]) -> Dict[str, float]:
    """Collapse a list of ``{'label': ..., 'score': ...}`` records into ``{label: score}``."""
    return {prediction["label"]: prediction["score"] for prediction in predictions}


def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None,
    get_classifier: bool = False,
) -> Union[Dict[str, float], List[Dict[str, float]], Any]:
    """
    Classify whether text is about food, returning a score for every label.

    The underlying pipeline is built once per (model_path, batch_size, device)
    combination and reused on later calls instead of reloading the model.

    Args:
        text (Union[str, list]): A single text, or a list of texts, to classify.
            Ignored when ``get_classifier`` is True.
        model_path (str): Model id or path passed to ``transformers.pipeline(model=...)``.
        batch_size (int): Batch size used by the pipeline. Default is 32.
        device (str | torch.device, optional): Device to run inference on
            (e.g. 'cuda', 'cpu'). Default None auto-selects via ``set_device()``.
        get_classifier (bool): If True, return the pipeline itself instead of
            classifying ``text``.

    Returns:
        For a single string: ``Dict[str, float]`` mapping each label to its score.
        For a list of texts: a list of such dicts, one per input text.
        If ``get_classifier`` is True: the pipeline object.
    """
    if device is None:
        device = set_device()

    classifier = _load_text_classifier(model_path, batch_size, device)

    if get_classifier:
        return classifier

    results = classifier(text)

    if isinstance(text, str):
        # A single string comes back as a one-element list whose only item is
        # the list of per-label predictions.
        return _scores_to_dict(results[0])

    # List input: one list of per-label predictions per text. (Previously only
    # the first text's scores were returned and the rest silently dropped.)
    return [_scores_to_dict(predictions) for predictions in results]
|
|
|
|
|
def gradio_food_classifier(text: str) -> dict:
    """
    Gradio callback: classify ``text`` with the app's fine-tuned model.

    Args:
        text (str): The sentence entered in the Gradio text box.

    Returns:
        dict: Mapping of each label to its score, the shape ``gr.Label`` displays.
    """
    # Delegate to food_not_food_classifier so the prediction -> {label: score}
    # conversion lives in one place instead of being duplicated here.
    return food_not_food_classifier(text=text, model_path=huggingface_model_path)
|
|
|
|
|
|
|
description = """
A text classifier to determine if a sentence is about food or not food.

Fine-tuned from [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) on a [small dataset of food and not food text](https://huggingface.co/datasets/mrdbourke/learn_hf_food_not_food_image_captions).

See [source code](https://github.com/mrdbourke/learn-huggingface/blob/main/notebooks/hugging_face_text_classification_tutorial.ipynb).
"""

# Gradio UI: one text box in, the top-2 label scores out.
demo = gr.Interface(
    fn=gradio_food_classifier,
    inputs="text",
    outputs=gr.Label(num_top_classes=2),
    # Emoji restored: the title had been mojibake ("ππ«π₯"), i.e. the UTF-8
    # bytes of these emoji mis-decoded as ISO-8859-7.
    title="🍗🚫🥑 Food or Not Food Text Classifier",
    description=description,
    examples=[
        ["I whipped up a fresh batch of code, but it seems to have a syntax error."],
        ["A delicious photo of a plate of scrambled eggs, bacon and toast."],
    ],
)

if __name__ == "__main__":
    demo.launch()
|
|