jeffliulab's picture
Initial deploy
66de2dc verified
"""
Image Classification — Compare ResNet-50 / ViT-base / MobileNetV3
Course: 100 Deep Learning ch2
"""
import json
import urllib.request
import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import timm
import gradio as gr
from PIL import Image
# All inference in this app runs on the CPU.
device = torch.device("cpu")
# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Display name -> classifier instance created with pretrained weights.
# Insertion order is kept, and it is the order the UI dropdown lists them in.
model_registry = {}
model_registry["ResNet-50"] = models.resnet50(
    weights=models.ResNet50_Weights.IMAGENET1K_V1,
)
model_registry["MobileNetV3-Small"] = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
)
model_registry["ViT-Base (timm)"] = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
)

# Put every network into evaluation mode and move it onto the target device.
for m in model_registry.values():
    m.eval()
    m.to(device)
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Shared input pipeline: resize, centre-crop to 224, convert to a tensor and
# normalise per channel with the mean/std values below.
# NOTE(review): the same normalisation statistics are applied to every model,
# including the timm ViT — confirm that checkpoint expects these values.
preprocess = T.Compose(
    transforms=[
        T.Resize(size=256),
        T.CenterCrop(size=224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ImageNet labels: human-readable class names, indexed by class id.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    # Bound the wait so an unreachable or stalled host cannot hang startup.
    with urllib.request.urlopen(LABELS_URL, timeout=10) as resp:
        LABELS = json.loads(resp.read().decode("utf-8"))
    # classify() indexes this list by class id, so anything other than a
    # 1000-entry list would fail later at prediction time; reject it here.
    if not isinstance(LABELS, list) or len(LABELS) != 1000:
        raise ValueError("label file is not a list of 1000 entries")
except Exception as exc:
    # Best-effort: the app stays usable offline, showing class indices
    # instead of names. Report the reason rather than failing silently.
    print(f"Could not load ImageNet labels ({exc!r}); using class indices.")
    LABELS = [str(i) for i in range(1000)]
# ---------------------------------------------------------------------------
# Classify
# ---------------------------------------------------------------------------
def classify(image: Image.Image, model_name: str):
    """Return the top-5 predictions of one registered model for *image*.

    Args:
        image: PIL image to classify. ``None`` produces an empty result.
        model_name: Key into ``model_registry`` selecting the network.

    Returns:
        Dict mapping a label from ``LABELS`` to its softmax probability as a
        Python float, for the five highest-probability classes. Returns ``{}``
        when *image* is ``None``. Because the result is keyed by label text,
        identical label strings would share a single entry.

    Raises:
        KeyError: If *model_name* is not present in ``model_registry``.
    """
    if image is None:
        return {}

    rgb = image.convert("RGB")
    # unsqueeze(0) adds the leading batch axis before the model call.
    batch = preprocess(rgb).unsqueeze(0).to(device)
    net = model_registry[model_name]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        scores = F.softmax(net(batch), dim=1)[0]
        best = torch.topk(scores, 5)

    predictions = {}
    for class_id, p in zip(best.indices.tolist(), best.values.tolist()):
        predictions[LABELS[class_id]] = p
    return predictions
def compare_all(image: Image.Image):
    """Classify *image* with each of the three registered models.

    Args:
        image: PIL image to classify, or ``None``.

    Returns:
        A 3-tuple of ``classify`` results ordered ResNet-50,
        MobileNetV3-Small, ViT-Base (timm). Three empty dicts are returned
        when *image* is ``None``.
    """
    if image is None:
        return {}, {}, {}

    model_names = ("ResNet-50", "MobileNetV3-Small", "ViT-Base (timm)")
    return tuple(classify(image, name) for name in model_names)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Build the Gradio interface: one tab that runs a single selected model and
# one tab that runs all three models on the same image.
with gr.Blocks(title="Image Classification") as demo:
# Page header shown above the tabs.
gr.Markdown(
"# Image Classification\n"
"Upload an image to compare predictions from different architectures.\n"
"*Course: 100 Deep Learning ch2 — CNN*"
)
# --- Tab 1: classify with one model chosen from the dropdown ---
with gr.Tab("Single Model"):
with gr.Row():
with gr.Column():
# type="pil" means the callback receives a PIL image.
img_single = gr.Image(type="pil", label="Upload Image")
# Choices are the registry keys; ResNet-50 is preselected.
model_choice = gr.Dropdown(
list(model_registry.keys()), value="ResNet-50", label="Model"
)
btn_single = gr.Button("Classify", variant="primary")
with gr.Column():
out_single = gr.Label(num_top_classes=5, label="Top-5 Predictions")
# Button click -> classify(image, model_name) -> label widget.
btn_single.click(classify, [img_single, model_choice], out_single)
# --- Tab 2: run every model on the same image, side by side ---
with gr.Tab("Compare All Models"):
with gr.Row():
img_compare = gr.Image(type="pil", label="Upload Image")
btn_compare = gr.Button("Compare All", variant="primary")
with gr.Row():
out_resnet = gr.Label(num_top_classes=5, label="ResNet-50")
out_mobile = gr.Label(num_top_classes=5, label="MobileNetV3-Small")
out_vit = gr.Label(num_top_classes=5, label="ViT-Base")
# The three outputs are listed in the same order as compare_all's return tuple.
btn_compare.click(compare_all, [img_compare], [out_resnet, out_mobile, out_vit])
# Example images that populate the single-model image input.
# NOTE(review): these paths presumably resolve relative to the app's working
# directory — verify the files are present in the deployment.
gr.Examples(
examples=[
["examples/cat.jpg"],
["examples/dog.jpg"],
["examples/car.jpg"],
],
inputs=[img_single],
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
demo.launch()