# MetaCLIP / app.py
from typing import List
import gradio as gr
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
EXAMPLES = ["dog.jpeg", "car.png"]
MARKDOWN = """
# Zero-Shot Image Classification with MetaCLIP
This is the demo for a zero-shot image classification model based on
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper
[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671) that formalizes CLIP data
curation as a simple algorithm.
"""
def load_text_lines(file_path: str) -> List[str]:
    """Read a text file with one class name per line and return the stripped lines."""
    with open(file_path, 'r') as file:
        lines = file.readlines()
    return [line.rstrip() for line in lines]
# Load the MetaCLIP model and processor, plus the ImageNet class names used as zero-shot labels.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
def classify_image(input_image) -> str:
    # Tokenize all class names and preprocess the image in a single batch.
    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True).to(device)
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the image-to-text logits and return the highest-scoring class name.
    probs = outputs.logits_per_image.softmax(dim=1)
    class_index = np.argmax(probs.detach().cpu().numpy())
    return imagenet_classes[class_index]
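

# Optional local sanity check: a minimal sketch, assuming an example image such as
# "dog.jpeg" (listed in EXAMPLES) sits next to this script. It is not called by the
# Gradio app; invoke it manually from a Python shell if you want to test the model
# without launching the UI.
def _sanity_check(image_path: str = "dog.jpeg") -> None:
    from PIL import Image
    print(f"{image_path} -> {classify_image(Image.open(image_path))}")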
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        image = gr.Image(image_mode='RGB', type='pil')
        output_text = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")
    submit_button.click(classify_image, inputs=[image], outputs=output_text)
    gr.Examples(
        examples=EXAMPLES,
        fn=classify_image,
        inputs=[image],
        outputs=[output_text],
        cache_examples=True,
        run_on_click=True
    )

demo.launch(debug=False)