|
from typing import List

import gradio as gr
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
|
|
|
# Newline-separated list of ImageNet class names (one label per line);
# these become the candidate texts for zero-shot scoring.
IMAGENET_CLASSES_FILE = "imagenet-classes.txt"

# Sample images shown in the Gradio Examples widget; paths are relative
# to the working directory the app is launched from.
EXAMPLES = ["dog.jpeg", "car.png"]



# Header text rendered at the top of the demo page.
MARKDOWN = """

# Zero-Shot Image Classification with MetaCLIP



This is the demo for a zero-shot image classification model based on

[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper

[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671) that formalizes CLIP data

curation as a simple algorithm.

"""
|
|
|
|
|
def load_text_lines(file_path: str) -> List[str]:
    """Read a text file and return its lines without trailing whitespace.

    Args:
        file_path: Path to a UTF-8 text file (one entry per line).

    Returns:
        One list element per line, with the trailing newline and any other
        trailing whitespace stripped. Blank lines are kept as empty strings.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding: without it, open() falls back to the platform
    # default (e.g. cp1252 on Windows) and non-ASCII labels would break.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Iterate the handle directly instead of readlines() — no
        # intermediate list of raw lines is materialized.
        return [line.rstrip() for line in file]
|
|
|
|
|
# Load the MetaCLIP checkpoint once at import time (downloads weights on the
# first run, then uses the local Hugging Face cache).
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m")

processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

# Candidate labels used by classify_image for zero-shot scoring.
imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
|
|
|
|
|
def classify_image(input_image) -> str:
    """Zero-shot classify an image against the ImageNet class names.

    Args:
        input_image: Input image as delivered by the Gradio ``Image``
            component (a PIL image in RGB mode).

    Returns:
        The ImageNet class name whose text embedding best matches the image.
    """
    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True)
    # Inference only: disable autograd so the forward pass does not record
    # a gradient graph (saves memory and time on every request).
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image is (1, num_classes); softmax over the class axis.
    probs = outputs.logits_per_image.softmax(dim=1)
    # Take the argmax directly on the tensor — no detach/NumPy round-trip.
    class_index = int(probs.argmax(dim=-1).item())
    return imagenet_classes[class_index]
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
with gr.Blocks() as demo:

    gr.Markdown(MARKDOWN)

    with gr.Row():

        # Uploaded image is normalized to RGB and passed to the callback
        # as a PIL image (matches what classify_image expects).
        image = gr.Image(image_mode='RGB', type='pil')

        output_text = gr.Textbox(label="Output")

    submit_button = gr.Button("Submit")



    submit_button.click(classify_image, inputs=[image], outputs=output_text)



    # Clickable example images; cache_examples=True precomputes their
    # outputs at startup so clicking an example is instant.
    gr.Examples(

        examples=EXAMPLES,

        fn=classify_image,

        inputs=[image],

        outputs=[output_text],

        cache_examples=True,

        run_on_click=True

    )



# Queue requests (at most 64 waiting) so concurrent users don't overlap
# model forward passes; blocks here serving the app.
demo.queue(max_size=64).launch(debug=False)
|
|