# Space status: Runtime error
import functools
import shlex
import subprocess
import sys

import gradio as gr
from transformers import pipeline
def install(package, index=None):
    """Install *package* into the running interpreter's environment with pip.

    Args:
        package: Requirement handed to ``pip install`` (e.g. ``"natten"``).
        index: Optional extra pip options given as one string, such as
            ``"-f https://host/wheels/index.html"``. The string is split
            shell-style into separate argv entries; ``None`` or ``""`` adds
            nothing.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    # Use the current interpreter's pip so the package lands in this environment.
    command = [sys.executable, "-m", "pip", "install", package]
    if index:
        # Appending "-f URL" as a single argv element makes optparse hand pip
        # the value " URL" (with a leading space), which is not a valid URL.
        # Splitting yields ["-f", "URL"], which is what pip expects.
        command.extend(shlex.split(index))
    subprocess.check_call(command)
# Install NATTEN from the SHI Labs CPU wheel index built for torch 1.13.
# NOTE(review): presumably required by the shi-labs NAT/DiNAT checkpoints listed
# below; the wheel index must match the installed torch version — confirm.
# "--find-links=URL" keeps the option and its value in one self-contained argv
# token, so pip receives the URL intact however install() assembles its argv.
install(
    "natten",
    "--find-links=https://shi-labs.com/natten/wheels/cpu/torch1.13/index.html",
)

# Hub ids of the ImageNet-1k image classifiers offered in the model dropdown.
model_names = [
    "facebook/deit-base-patch16-224",
    "facebook/convnext-base-224",
    "google/vit-base-patch16-224",
    "microsoft/resnet-50",
    "microsoft/swin-base-patch4-window7-224",
    "microsoft/beit-base-patch16-224",
    "nvidia/mit-b0",
    "shi-labs/nat-base-in1k-224",
    "shi-labs/dinat-base-in1k-224",
]
@functools.lru_cache(maxsize=3)
def _get_classifier(model_name):
    """Return the image-classification pipeline for *model_name*, reusing a cached one.

    Only a few recently used models stay in memory (``maxsize``); the others are
    rebuilt on demand. Base-size checkpoints are presumably hundreds of MB
    each, so caching every model could exhaust a small Space — tune if needed.
    """
    return pipeline("image-classification", model=model_name)


def process(image_file, top_k, model_name):
    """Classify an image and return the highest-scoring labels with their scores.

    Args:
        image_file: Image handed to the pipeline; the UI supplies a file path
            (``gr.Image(type="filepath")``).
        top_k: Number of classes to return. Cast to ``int`` because the
            slider value may arrive as a float, which cannot be used to slice.
        model_name: Hub model id of the classifier to use.

    Returns:
        dict mapping label -> score for up to ``top_k`` classes, the shape
        ``gr.Label`` displays.
    """
    k = int(top_k)
    classifier = _get_classifier(model_name)
    # The pipeline returns only 5 predictions unless told otherwise, so top_k
    # must be forwarded or slider values above 5 are silently truncated.
    predictions = classifier(image_file, top_k=k)
    return {x["label"]: x["score"] for x in predictions[:k]}
# ---- Input components (order must match process(image_file, top_k, model_name))
image = gr.Image(label="Upload an image", type="filepath")
top_k = gr.Slider(
    label="Top k classes",
    minimum=1,
    maximum=10,
    step=1,
    value=5,
)
model_selection = gr.Dropdown(
    choices=model_names,
    label="Pick a model",
    value="google/vit-base-patch16-224",
)

# ---- Output component: renders the {label: score} dict returned by process()
labels = gr.Label()

description = (
    "This Space lets you quickly compare the most popular image classifier "
    "models available on the hub. All of them have been fine-tuned on the "
    "ImageNet-1k dataset. Anecdotally, the three sample images have been "
    "generated with a Stable Diffusion model :)"
)

# One row per example, in `inputs` order: (image path, top k, model id).
# NOTE(review): the image files are presumably shipped alongside this script.
examples = [
    ["bike.jpg", 5, "google/vit-base-patch16-224"],
    ["car.jpg", 5, "microsoft/swin-base-patch4-window7-224"],
    ["food.jpg", 5, "facebook/convnext-base-224"],
]

# NOTE(review): `theme="huggingface"` and `allow_flagging` are Gradio-3-era
# options that newer Gradio releases deprecate — confirm the pinned version.
iface = gr.Interface(
    fn=process,
    inputs=[image, top_k, model_selection],
    outputs=[labels],
    examples=examples,
    description=description,
    theme="huggingface",
    allow_flagging="never",
)
iface.launch()