# 1. Import the required packages
from functools import lru_cache
from typing import Dict, List, Optional, Union

import gradio as gr
import torch
from transformers import Pipeline, pipeline

huggingface_model_path = "Suraj-Yadav/learn_hf_food_not_food_text_classifier-distilbert-base-uncased"


def _select_device() -> str:
    """Return the best available torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists on newer torch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


@lru_cache(maxsize=4)
def _load_classifier(model_path: str, batch_size: int, device: str) -> Pipeline:
    """Build a text-classification pipeline once per (model, batch size, device) and reuse it.

    Loading a model is expensive, and the Gradio handler runs on every request,
    so the pipeline is cached instead of being rebuilt per call. Callers that
    receive it share the same instance.
    """
    return pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None,  # keep the score for every label, not just the best one
    )


# 2. Define function to use our model on given text
def food_not_food_classifier(
    text: Union[str, List[str]],
    model_path: str,
    batch_size: int = 32,
    device: Optional[str] = None,
    get_classifier: bool = False,
) -> Union[Dict[str, float], Pipeline]:
    """Classify whether text is about food, returning each label's score.

    Args:
        text: The input text, or a list of texts, to classify.
        model_path: Hugging Face model id or local path of the classifier.
        batch_size: Batch size used by the pipeline. Default is 32.
        device: Device to run inference on (e.g. "cuda", "cpu"). Default is
            None, which picks "cuda", then "mps", then "cpu".
        get_classifier: If True, return the (cached) pipeline itself instead of
            running it on ``text``.

    Returns:
        A dict mapping each label to its score, or the pipeline when
        ``get_classifier`` is True. NOTE: for a list input only the first
        text's scores are returned.
    """
    if device is None:
        device = _select_device()

    classifier = _load_classifier(model_path, batch_size, device)
    if get_classifier:
        return classifier

    # Example output for a single text:
    # [[{'label': 'food', 'score': 0.9500328898429871}, {'label': 'not_food', 'score': 0.04996709153056145}]]
    results = classifier(text)
    return {prediction["label"]: prediction["score"] for prediction in results[0]}


def gradio_food_classifier(text: str) -> Dict[str, float]:
    """Gradio entry point: classify ``text`` with the food_not_food_classifier function.

    Args:
        text: The input text to classify.

    Returns:
        A dict mapping each label to its score, as expected by ``gr.Label``.
    """
    return food_not_food_classifier(text=text, model_path=huggingface_model_path)
# 3. Create a Gradio interface with details about our app
description = """
A text classifier to determine if a sentence is about food or not food.

Fine-tuned from [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) on a [small dataset of food and not food text](https://huggingface.co/datasets/mrdbourke/learn_hf_food_not_food_image_captions).

See [source code](https://github.com/mrdbourke/learn-huggingface/blob/main/notebooks/hugging_face_text_classification_tutorial.ipynb).
"""

# One example that only sounds food-related, and one that really is about food.
example_inputs = [
    ["I whipped up a fresh batch of code, but it seems to have a syntax error."],
    ["A delicious photo of a plate of scrambled eggs, bacon and toast."],
]

# Show both labels ("food" / "not_food") with their scores.
label_output = gr.Label(num_top_classes=2)
app_title = "🍗🚫🥑 Food or Not Food Text Classifier"

demo = gr.Interface(
    fn=gradio_food_classifier,
    inputs="text",
    outputs=label_output,
    title=app_title,
    description=description,
    examples=example_inputs,
)

# 4. Launch the interface
if __name__ == "__main__":
    demo.launch()