# Hugging Face Spaces app (status: Running)
from collections import Counter

import einops
import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
def colorize(tensor, cmap_fn=cm.turbo):
    """Map a batch of scalar maps to RGB uint8 images through a colormap.

    Args:
        tensor: float tensor of shape (B, H, W) or (B, 1, H, W). Values are
            expected in [0, 1]; anything outside is clamped to the end colors
            and NaN is treated as 0.
        cmap_fn: colormap callable taking an array of floats in [0, 1] and
            returning one color row per value (only the first 3 columns,
            RGB, are used; e.g. matplotlib colormaps return RGBA).

    Returns:
        uint8 tensor of shape (B, 3, H, W) on the same device as ``tensor``.
    """
    # 256-entry RGB lookup table. `.to(tensor)` matches the input's device and
    # dtype so F.embedding can index it directly.
    lut = cmap_fn(np.linspace(0, 1, 256))[:, :3]
    lut = torch.from_numpy(lut).to(tensor)
    if tensor.ndim == 4:
        tensor = tensor.squeeze(1)
    # NaN survives clamp(), and casting NaN to an integer is undefined (it can
    # yield an out-of-range embedding index), so map it to the low end first.
    tensor = torch.nan_to_num(tensor, nan=0.0)
    ids = (tensor * 256).clamp(0, 255).long()
    # (B, H, W) indices -> (B, H, W, 3) colors -> channels-first (B, 3, H, W)
    rgb = F.embedding(ids, lut).permute(0, 3, 1, 2)
    return rgb.mul(255).clamp(0, 255).byte()
# Class names, one line per class index. A line may list comma-separated
# synonyms; only the first one (lower-cased) is used as the display label.
with open("classes.txt", encoding="utf-8") as f:
    _names = [line.split(",")[0].strip().lower() for line in f.read().splitlines()]

# Labels are used as dict keys (the gr.Label scores built in `predict` and
# `label2id` used by `gradcam`), so they must be unique. ImageNet-style lists
# repeat some first synonyms (e.g. "crane" for both the bird and the machine —
# TODO confirm for this classes.txt); a repeat would silently hide one class
# and make Grad-CAM target the wrong index. Disambiguate repeats by appending
# the class index; unique names are left unchanged.
_name_counts = Counter(_names)
id2label = [
    name if _name_counts[name] == 1 else f"{name} ({i})"
    for i, name in enumerate(_names)
]
label2id = {label: i for i, label in enumerate(id2label)}
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# torchvision ResNet-50 with its default pretrained weights, switched to
# inference behaviour via eval() and placed on `device`. Both eval() and to()
# return the module itself, so the calls can be chained.
model = torchvision.models.resnet50(weights="DEFAULT").eval().to(device)
# Per-layer caches keyed by the name of a top-level child module of `model`:
# the most recent forward output and the gradient w.r.t. that output.
fmap_pool = {}
grad_pool = {}


def forward_hook(name):
    """Return a forward hook that caches the module's output under `name`.

    When autograd is tracking the output, a tensor hook is also attached so
    that the gradient w.r.t. this exact output is stored in
    ``grad_pool[name]`` during the next backward pass.
    """
    def _hook(module, input, output):
        fmap_pool[name] = output.detach()
        # register_hook raises on tensors that do not require grad (e.g. a
        # forward run under torch.no_grad()), so only attach it when needed.
        if output.requires_grad:
            output.register_hook(_grad_hook(name))

    return _hook


def _grad_hook(name):
    """Return a tensor hook that stores the incoming gradient under `name`."""
    def _hook(grad):
        grad_pool[name] = grad.detach()

    return _hook


def backward_hook(name):
    """Return a legacy module backward hook storing grad_out[0] under `name`.

    Kept for API compatibility only and no longer registered:
    ``Module.register_backward_hook`` is deprecated, and gradients are now
    captured by the tensor hook that `forward_hook` attaches instead.
    """
    def _hook(module, grad_in, grad_out):
        grad_pool[name] = grad_out[0].detach()

    return _hook


# Hook every top-level child; their names are the selectable Grad-CAM layers.
layer_choices = []
for n, m in model.named_children():
    layer_choices.append(n)
    m.register_forward_hook(forward_hook(n))
# Input pipeline for the uploaded image (presumably an H x W x C uint8 array
# from gr.Image(type="numpy")):
#   ToTensor  -> C x H x W float tensor, scaled to [0, 1] for uint8 input
#   Resize    -> squashed to 224 x 224 (aspect ratio is not preserved, so the
#                whole image maps onto the Grad-CAM output)
#   Normalize -> the standard ImageNet channel mean/std
preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        # Explicit antialias: on tensors the default was no antialiasing
        # before torchvision 0.17, which aliases large downscaled uploads.
        torchvision.transforms.Resize((224, 224), antialias=True),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
def predict(image):
    """Classify an uploaded image and clear the previous Grad-CAM result.

    Args:
        image: numpy image from the ``gr.Image`` input, or None when the
            input was cleared.

    Returns:
        ``(scores, None)``: a ``{class name: probability}`` dict for
        ``gr.Label`` (None when there is no image), and None to reset the
        result image.
    """
    if image is None:
        return None, None
    batch = preprocess(image)[None].to(device)
    # Classification needs no gradients; skipping graph construction saves
    # memory. Grad mode is thread-local and restored on exit.
    with torch.no_grad():
        probs = model(batch).softmax(dim=1)[0]
    # One device->host transfer instead of one sync per class.
    result = dict(zip(id2label, probs.cpu().tolist()))
    return result, None
def gradcam(image_orig, layer, event: gr.SelectData):
    """Compute a colorized Grad-CAM heatmap for the class the user clicked.

    Args:
        image_orig: numpy input image from ``gr.Image`` (height and width are
            taken from its first two dimensions).
        layer: name of a top-level child module of ``model`` (dropdown value).
        event: selection event from ``gr.Label``; ``event.value`` is the
            clicked class name, looked up in ``label2id``.

    Returns:
        H x W x 3 uint8 numpy heatmap with the same H, W as ``image_orig``.

    Raises:
        gr.Error: if there is no image, no activations/gradients were
            recorded for ``layer``, or the layer's output is not a 4-D
            (spatial) feature map (e.g. ``fc``).
    """
    if image_orig is None:
        raise gr.Error("Upload an image first.")

    # forward & backward: seed the backward pass with a one-hot vector so the
    # hooked gradients are those of the selected class score only.
    target_class = torch.tensor([label2id[event.value]], device=device)
    gradient = F.one_hot(target_class, num_classes=len(label2id)).float()
    image = preprocess(image_orig)[None].to(device)  # must match model's device
    model(image).backward(gradient=gradient)
    # Only activation gradients (captured by hooks) are used; drop parameter
    # .grad buffers so they neither accumulate across calls nor stay resident.
    model.zero_grad(set_to_none=True)

    fmaps = fmap_pool.get(layer)
    grads = grad_pool.get(layer)
    if fmaps is None or grads is None:
        raise gr.Error(f"No activations recorded for layer {layer!r}.")
    if fmaps.ndim != 4:
        raise gr.Error(
            f"Grad-CAM needs a spatial layer; {layer!r} produced shape "
            f"{tuple(fmaps.shape)}."
        )

    # Grad-CAM: channel weights are the spatially averaged gradients
    weights = F.adaptive_avg_pool2d(grads, 1)
    gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
    gcam = F.relu(gcam)

    # post-process: upsample to the input size, then min-max normalize each map
    gcam = F.interpolate(
        gcam, size=image_orig.shape[:2], mode="bicubic", antialias=True
    )
    gcam -= einops.reduce(gcam, "b c h w -> b () () ()", "min")
    # A constant map (e.g. all zero after ReLU) has max 0 here; clamping the
    # divisor avoids 0/0 = NaN reaching colorize.
    gcam /= einops.reduce(gcam, "b c h w -> b () () ()", "max").clamp(min=1e-8)
    gcam = colorize(gcam)[0].permute(1, 2, 0).cpu().numpy()
    return gcam
# UI: layer picker, input image, and top-10 predictions on the left; the
# Grad-CAM heatmap on the right.
with gr.Blocks(title="Grad-CAM") as demo:
    gr.Markdown(
        """
        # Grad-CAM
        Unofficial re-implementation of Grad-CAM (https://arxiv.org/abs/1610.02391).<br>
        Upload an image and select a prediction to show the Grad-CAM heatmap.
        """
    )
    with gr.Row():
        with gr.Column():
            layer = gr.Dropdown(layer_choices, label="ResNet-50", value="layer4")
            image = gr.Image(label="input", type="numpy")
            label = gr.Label(num_top_classes=10, label="top-10 predictions")
            exmpl = gr.Examples(["cat_dog.png"], image)
        with gr.Column():
            img_out = gr.Image(type="numpy", label="result")
    # New or cleared image -> refresh predictions and clear the old heatmap.
    image.change(predict, inputs=[image], outputs=[label, img_out])
    # Clicking a predicted class -> Grad-CAM for it at the chosen layer.
    label.select(gradcam, inputs=[image, layer], outputs=[img_out])

# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()