from typing import List

import gradio as gr
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel

IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
EXAMPLES = ["dog.jpeg", "car.png"]

MARKDOWN = """
# Zero-Shot Image Classification with MetaCLIP

This is the demo for a zero-shot image classification model based on
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper
[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671) that formalizes CLIP data
curation as a simple algorithm.
"""


def load_text_lines(file_path: str) -> List[str]:
    """Read a UTF-8 text file and return its lines with trailing whitespace removed.

    Args:
        file_path: Path of the text file to read.

    Returns:
        One entry per line of the file, in file order. Blank lines are kept
        as empty strings.
    """
    # Explicit encoding so the result does not depend on the host locale.
    with open(file_path, "r", encoding="utf-8") as file:
        return [line.rstrip() for line in file]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

# Skip blank lines (e.g. a trailing empty line in the file) so an empty string
# is never offered to the model as a candidate class prompt.
imagenet_classes = [name for name in load_text_lines(IMAGENET_CLASSES_FILE) if name]


def classify_image(input_image) -> str:
    """Return the class name from ``imagenet_classes`` that MetaCLIP scores highest.

    Every entry of ``imagenet_classes`` is used as a text prompt, and the image
    is assigned the prompt with the largest image-text logit.

    Args:
        input_image: Value of the ``gr.Image(type='pil')`` component (a PIL
            image), or ``None`` when the user submits without an image.

    Returns:
        The winning class name.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a processor stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Inference only: disabling autograd avoids retaining activations for every
    # text prompt for a backward pass that never happens.
    # NOTE(review): the class prompts are re-tokenized and re-encoded on every
    # call; caching the text embeddings once would make each request cheaper.
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax is monotonic, so the argmax of the raw logits picks the same
    # class as the argmax of the probabilities. A single image is passed, so
    # the flattened argmax is the index of the best prompt.
    logits = outputs.logits_per_image.cpu().numpy()
    class_index = int(np.argmax(logits))
    return imagenet_classes[class_index]


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        image = gr.Image(image_mode='RGB', type='pil')
        output_text = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")

    submit_button.click(classify_image, inputs=[image], outputs=output_text)

    gr.Examples(
        examples=EXAMPLES,
        fn=classify_image,
        inputs=[image],
        outputs=[output_text],
        cache_examples=True,
        run_on_click=True,
    )

demo.launch(debug=False)