Spaces:
Runtime error
Runtime error
import requests | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision import transforms | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from timm.data import create_transform | |
from focalnet import FocalNet, build_transforms, build_transforms4display | |
# Download human-readable labels for ImageNet (one class name per line).
# A timeout keeps app startup from hanging forever on a stalled connection.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Fail loudly on an HTTP error instead of silently using an error page as labels.
response.raise_for_status()
# splitlines() also strips "\r" from CRLF files and, unlike split("\n"), does not
# leave an empty trailing entry after a final newline.
labels = response.text.splitlines()
# Build the model: a single-stage FocalNet (12 blocks, 16x16 patches,
# 768-dim embeddings, 3 focal levels) matching the focalnet_base_iso_16 checkpoint.
model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], use_layerscale=True, use_postln=True)
# The released checkpoint is available at
# https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth
# (torch.hub.load_state_dict_from_url can fetch it instead of shipping the file).
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only host instead of failing during deserialization.
checkpoint = torch.load("./focalnet_base_iso_16.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
# Inference mode (affects layers such as dropout/normalization, if present).
model.eval()
# Data transforms. eval_transforms prepares the network input; display_transforms
# prepares the image the modulator heatmaps are drawn on. Both are built with the
# same resolution and no center crop, presumably so that the two outputs stay
# pixel-aligned -- confirm against focalnet.build_transforms*.
_INPUT_RESOLUTION = 224
eval_transforms = build_transforms(_INPUT_RESOLUTION, center_crop=False)
display_transforms = build_transforms4display(_INPUT_RESOLUTION, center_crop=False)
# Adapted from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.3) -> np.ndarray:
    """Overlay a CAM mask on an image as a colour heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, float32 with values in [0, 1].
    :param mask: The cam mask, expected in [0, 1]. NaN entries are treated as 0
        and values outside [0, 1] are clipped before colour mapping.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if
        'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: Weight of 'img' in the blend; the heatmap gets
        1 - image_weight. The default 0.3 gives the original 0.7/0.3 blend.
    :returns: The blended image as uint8 with values in [0, 255].
    :raises ValueError: If 'img' has values above 1, or 'image_weight' is
        outside [0, 1].
    """
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")
    if not 0 <= image_weight <= 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1], got {image_weight}")
    # np.uint8() on NaN or out-of-range floats wraps around or is undefined,
    # so sanitize the mask before scaling it to 0..255.
    mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    cam = (1 - image_weight) * heatmap + image_weight * img
    # Guard against slightly negative image values wrapping in the uint8 cast.
    return np.uint8(255 * np.clip(cam, 0.0, 1.0))
def _modulator_heatmap(block_index, size, img_d):
    """Overlay the per-location L2 norm of one block's modulator on img_d.

    Reads model.layers[0].blocks[block_index].modulation.modulator.
    NOTE(review): this assumes FocalNet caches the modulator tensor there during
    the most recent forward pass -- confirm in focalnet.py.
    """
    modulator = model.layers[0].blocks[block_index].modulation.modulator
    # L2 norm over dim 1 (presumably channels) -> one value per spatial location.
    energy = modulator.norm(2, 1, keepdim=True)
    # Same bilinear resize nn.Upsample(size=..., mode='bilinear') performed.
    energy = nn.functional.interpolate(energy, size=size, mode='bilinear')
    mask = energy.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    lo, hi = mask.min(), mask.max()
    # Min-max normalize to [0, 1]. A flat map has no contrast; map it to zeros
    # rather than dividing by zero (which would produce NaN).
    mask = (mask - lo) / (hi - lo) if hi > lo else np.zeros_like(mask)
    return show_cam_on_image(img_d, mask, use_rgb=True)


def classify_image(inp):
    """Classify 'inp' with FocalNet and visualize its focal modulators.

    :param inp: The image delivered by the Gradio input; it is passed unchanged
        to eval_transforms and display_transforms.
    :returns: Four PIL images (modulator overlays of blocks 11, 8, 5 and 2,
        i.e. layers 12, 9, 6 and 3 counting from 1), followed by a dict mapping
        each label to its softmax probability.
    """
    img_t = eval_transforms(inp)
    # Presumably (C, H, W) -> (H, W, C): the numpy background for the heatmaps.
    img_d = display_transforms(inp).permute(1, 2, 0).numpy()
    size = img_t.shape[1:]  # spatial size of the network input
    # Inference only: no autograd graph is needed for prediction or heatmaps.
    with torch.no_grad():
        prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()
        cams = [_modulator_heatmap(idx, size, img_d) for idx in (11, 8, 5, 2)]
    # zip() pairs each label with its class score without assuming exactly 1000.
    probs = {name: float(p) for name, p in zip(labels, prediction.tolist())}
    return (*[Image.fromarray(cam) for cam in cams], probs)
# Gradio UI: one image input; outputs are four modulator overlays plus the
# top-3 predicted labels. The gr.Image / gr.Label components replace the
# gr.inputs / gr.outputs namespaces, which were deprecated in Gradio 3.x and
# removed in 4.0 (where using them fails at startup).
image = gr.Image()
label = gr.Label(num_top_classes=3)
gr.Interface(
    description="Image classification and visualizations with FocalNet (https://github.com/microsoft/FocalNet)",
    fn=classify_image,
    inputs=image,
    outputs=[
        # Order must match classify_image's return: layers 12, 9, 6, 3, then labels.
        *[gr.Image(type="pil", label=f"Modulator at layer {layer}") for layer in (12, 9, 6, 3)],
        label,
    ],
    examples=[["./donut.png"], ["./horses.png"], ["./pencil.png"]],
).launch()