"""Gradio demo comparing ImageNet-1k image classifiers from the Hugging Face hub."""

import functools
import shlex
import subprocess
import sys

import gradio as gr
from transformers import pipeline


def install(package, index):
    """Install *package* with pip, forwarding extra pip options.

    Args:
        package: Requirement specifier handed to ``pip install``.
        index: Extra pip command-line options as a single string, e.g.
            ``"-f https://example.com/wheels/index.html"``.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Tokenise the options so a flag and its value reach pip as separate
    # argv entries instead of one fused "-f <url>" argument.
    extra_args = shlex.split(index)
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", package, *extra_args]
    )


# NOTE(review): this runs after `transformers` has been imported; confirm the
# installed transformers version detects natten lazily rather than at import.
install("natten", "-f https://shi-labs.com/natten/wheels/cpu/torch1.13/index.html")

model_names = [
    "facebook/deit-base-patch16-224",
    "facebook/convnext-base-224",
    "google/vit-base-patch16-224",
    "microsoft/resnet-50",
    "microsoft/swin-base-patch4-window7-224",
    "microsoft/beit-base-patch16-224",
    "nvidia/mit-b0",
    "shi-labs/nat-base-in1k-224",
    "shi-labs/dinat-base-in1k-224",
]

# Upper bound on simultaneously cached pipelines, to keep memory use bounded.
_PIPELINE_CACHE_SIZE = 4


@functools.lru_cache(maxsize=_PIPELINE_CACHE_SIZE)
def _load_pipeline(model_name):
    """Return a cached image-classification pipeline for *model_name*."""
    return pipeline("image-classification", model=model_name)


def process(image_file, top_k, model_name):
    """Classify an image and return the best-scoring labels.

    Args:
        image_file: Path of the image to classify.
        top_k: Number of classes to return.
        model_name: Hub identifier of the model to run.

    Returns:
        A dict mapping each predicted label to its score.
    """
    # The slider value may arrive as a float; the pipeline needs an int.
    k = int(top_k)
    classifier = _load_pipeline(model_name)
    # Request k results explicitly: without this the pipeline returns only
    # its default number of predictions, so larger slider values had no effect.
    predictions = classifier(image_file, top_k=k)
    return {entry["label"]: entry["score"] for entry in predictions}


# Inputs
image = gr.Image(type="filepath", label="Upload an image")
top_k = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Top k classes")
model_selection = gr.Dropdown(
    model_names, value="google/vit-base-patch16-224", label="Pick a model"
)

# Output
labels = gr.Label()

description = "This Space lets you quickly compare the most popular image classifier models available on the hub. All of them have been fine-tuned on the ImageNet-1k dataset. Anecdotally, the three sample images have been generated with a Stable Diffusion model :)"

iface = gr.Interface(
    theme="huggingface",
    description=description,
    fn=process,
    inputs=[image, top_k, model_selection],
    outputs=[labels],
    examples=[
        ["bike.jpg", 5, "google/vit-base-patch16-224"],
        ["car.jpg", 5, "microsoft/swin-base-patch4-window7-224"],
        ["food.jpg", 5, "facebook/convnext-base-224"],
    ],
    allow_flagging="never",
)

iface.launch()