"""Gradio demo: classify an uploaded image with a choice of pretrained backbones."""

import time

import gradio as gr
from PIL import Image
from transformers import pipeline

# Dropdown display name -> Hugging Face Hub model id.
MODEL_MAP = {
    "ViT (Base/16, 224)": "google/vit-base-patch16-224",
    "ResNet-50": "microsoft/resnet-50",
    "EfficientNet-B0": "google/efficientnet-b0",
}

# Lazy-load to keep startup fast: a pipeline is only built the first time its
# model id is requested, then reused for every later request.
_pipes = {}


def get_pipe(model_id: str):
    """Return the cached top-5 image-classification pipeline for *model_id*.

    The pipeline is constructed on the first call for a given id and stored in
    ``_pipes``; later calls with the same id return that same object.
    """
    if model_id not in _pipes:
        _pipes[model_id] = pipeline("image-classification", model=model_id, top_k=5)
    return _pipes[model_id]


def predict(img: Image.Image, model_name: str):
    """Classify *img* with the backbone selected by *model_name*.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.
        model_name: A key of ``MODEL_MAP`` (the dropdown's display name).

    Returns:
        A ``(scores, info)`` pair feeding the ``gr.Label`` and ``gr.Markdown``
        outputs. ``scores`` maps each predicted label to its score rounded to
        3 decimals; ``info`` is a one-line status with the model name and the
        measured inference latency. When the image or model selection is
        missing, ``scores`` is ``None`` (clearing the label) and ``info``
        carries a prompt for the user.
    """
    if img is None:
        # The prompt belongs in the status pane; None clears any stale
        # predictions still shown in the Label from a previous run.
        return None, "Upload an image."

    model_id = MODEL_MAP.get(model_name)
    if model_id is None:
        # Guard against a missing/unknown selection instead of surfacing a
        # raw KeyError in the UI.
        return None, "Choose a backbone."

    pipe = get_pipe(model_id)

    # Normalise to 3-channel RGB so RGBA / greyscale / palette images (e.g.
    # from direct, non-UI callers) reach the model in a consistent mode.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows the wall clock and can jump while we are measuring.
    t0 = time.perf_counter()
    preds = pipe(img)
    latency_ms = int((time.perf_counter() - t0) * 1000)

    # Clean top-k dict for Gradio Label
    scores = {p["label"]: round(float(p["score"]), 3) for p in preds}
    return scores, f"{model_name} • ~{latency_ms} ms"


with gr.Blocks(title="Image Classifier – Multi-Model") as demo:
    gr.Markdown(
        "# 🐢🐱 Image Classifier (Multi-Model)\n"
        "Upload an image, choose a backbone, see top-5 predictions."
    )
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="pil", label="Image")
            model = gr.Dropdown(
                list(MODEL_MAP.keys()),
                value="ViT (Base/16, 224)",
                label="Backbone",
            )
            btn = gr.Button("Predict")
        with gr.Column():
            out = gr.Label(label="Top-5")
            info = gr.Markdown()
    btn.click(fn=predict, inputs=[img, model], outputs=[out, info])


if __name__ == "__main__":
    demo.launch()