Spaces:
Sleeping
Sleeping
| """ | |
| Image Classification — Compare ResNet-50 / ViT-base / MobileNetV3 | |
| Course: 100 Deep Learning ch2 | |
| """ | |
| import json | |
| import urllib.request | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| import torchvision.transforms as T | |
| import timm | |
| import gradio as gr | |
| from PIL import Image | |
device = torch.device("cpu")
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Pretrained ImageNet-1k classifiers keyed by the display name used in the UI.
# Insertion order matters: the dropdown lists the keys in this order.
model_registry = dict(
    [
        ("ResNet-50", models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)),
        (
            "MobileNetV3-Small",
            models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
        ),
        ("ViT-Base (timm)", timm.create_model("vit_base_patch16_224", pretrained=True)),
    ]
)

# Put every model in inference mode (disables dropout / uses running BN stats)
# and move it onto the target device.
for _net in model_registry.values():
    _net.eval()
    _net.to(device)
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Per-channel RGB statistics used for normalization below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Eval-time pipeline: resize the short side to 256, center-crop to 224x224,
# convert to a float tensor in [0, 1], then normalize per channel.
_preprocess_steps = [
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = T.Compose(_preprocess_steps)
| # ImageNet labels | |
| LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json" | |
| try: | |
| with urllib.request.urlopen(LABELS_URL) as resp: | |
| LABELS = json.loads(resp.read().decode()) | |
| except Exception: | |
| LABELS = [str(i) for i in range(1000)] | |
# ---------------------------------------------------------------------------
# Classify
# ---------------------------------------------------------------------------
# Per-model preprocessing pipelines, resolved lazily on first use.
_TRANSFORM_CACHE = {}


def _transform_for(model_name, model):
    """Return the input pipeline that matches how *model* was trained.

    timm models carry their pretrained data config (mean/std, resize
    interpolation, crop_pct). timm ViT checkpoints are typically normalized
    differently from the ImageNet statistics baked into ``preprocess``.
    Feeding them the torchvision pipeline runs without error but skews the
    predictions. Models without a timm config (the torchvision ones) keep
    ``preprocess``.
    """
    transform = _TRANSFORM_CACHE.get(model_name)
    if transform is None:
        if hasattr(model, "pretrained_cfg") or hasattr(model, "default_cfg"):
            # Local import: only needed when a timm model is present.
            from timm.data import create_transform, resolve_data_config

            transform = create_transform(**resolve_data_config({}, model=model))
        else:
            transform = preprocess
        _TRANSFORM_CACHE[model_name] = transform
    return transform


def classify(image: Image.Image, model_name: str, top_k: int = 5):
    """Classify *image* with the registry model named *model_name*.

    Args:
        image: Uploaded PIL image, or ``None`` when the input is empty.
        model_name: Key into ``model_registry``.
        top_k: Number of highest-probability classes to return.

    Returns:
        ``{label: probability}`` for the top ``top_k`` classes, which is the
        mapping ``gr.Label`` displays. Returns an empty dict when *image* is
        ``None``.

    Raises:
        KeyError: if *model_name* is not in ``model_registry``.
    """
    if image is None:
        return {}
    model = model_registry[model_name]
    transform = _transform_for(model_name, model)
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    probs = F.softmax(logits, dim=1)[0]
    top = torch.topk(probs, min(top_k, probs.numel()))

    results = {}
    # .tolist() yields plain Python floats/ints (no 0-d tensors as list indices).
    for prob, idx in zip(top.values.tolist(), top.indices.tolist()):
        label = LABELS[idx] if idx < len(LABELS) else str(idx)
        # Two classes may share a display name; keep both rather than letting
        # the later entry overwrite the earlier one in the dict.
        if label in results:
            label = f"{label} ({idx})"
        results[label] = prob
    return results
def compare_all(image: Image.Image):
    """Classify *image* with every model.

    Returns a 3-tuple of ``{label: probability}`` dicts in the order
    (ResNet-50, MobileNetV3-Small, ViT-Base). When no image is given, returns
    three empty dicts.
    """
    model_order = ("ResNet-50", "MobileNetV3-Small", "ViT-Base (timm)")
    if image is None:
        return tuple({} for _ in model_order)
    return tuple(classify(image, name) for name in model_order)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Sample images shipped alongside the app. They are offered in both tabs so
# the Compare view can be tried without an upload too.
EXAMPLE_IMAGES = [
    ["examples/cat.jpg"],
    ["examples/dog.jpg"],
    ["examples/car.jpg"],
]

with gr.Blocks(title="Image Classification") as demo:
    gr.Markdown(
        "# Image Classification\n"
        "Upload an image to compare predictions from different architectures.\n"
        "*Course: 100 Deep Learning ch2 — CNN*"
    )

    with gr.Tab("Single Model"):
        with gr.Row():
            with gr.Column():
                img_single = gr.Image(type="pil", label="Upload Image")
                model_choice = gr.Dropdown(
                    list(model_registry.keys()), value="ResNet-50", label="Model"
                )
                btn_single = gr.Button("Classify", variant="primary")
            with gr.Column():
                out_single = gr.Label(num_top_classes=5, label="Top-5 Predictions")
        btn_single.click(classify, [img_single, model_choice], out_single)
        # Placed inside the tab whose input it fills.
        gr.Examples(examples=EXAMPLE_IMAGES, inputs=[img_single])

    with gr.Tab("Compare All Models"):
        with gr.Row():
            img_compare = gr.Image(type="pil", label="Upload Image")
            btn_compare = gr.Button("Compare All", variant="primary")
        with gr.Row():
            out_resnet = gr.Label(num_top_classes=5, label="ResNet-50")
            out_mobile = gr.Label(num_top_classes=5, label="MobileNetV3-Small")
            out_vit = gr.Label(num_top_classes=5, label="ViT-Base")
        # Output order must match the tuple order returned by compare_all.
        btn_compare.click(compare_all, [img_compare], [out_resnet, out_mobile, out_vit])
        gr.Examples(examples=EXAMPLE_IMAGES, inputs=[img_compare])

if __name__ == "__main__":
    demo.launch()