Spaces:
Runtime error
Runtime error
File size: 4,915 Bytes
7f59780 f20be77 7f59780 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from focalnet import FocalNet, build_transforms, build_transforms4display
# Download human-readable labels for ImageNet.
# The timeout keeps start-up from hanging forever on a stalled connection, and
# raise_for_status() prevents an HTTP error page from being parsed as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
labels = response.text.split("\n")

# ---------------------------------------------------------------------------
# Build model
# ---------------------------------------------------------------------------
# Use the GPU when one is available; otherwise stay on the CPU instead of
# crashing at import time on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], use_layerscale=True, use_postln=True)
url = 'https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth'
# Weights are deserialised on the CPU first and moved to the target device below.
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint["model"])
model = model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm statistic updates

# ---------------------------------------------------------------------------
# Build data transforms
# ---------------------------------------------------------------------------
# eval_transforms produces the tensor fed to the model; display_transforms
# produces the image the heatmaps are drawn on.
eval_transforms = build_transforms(224, center_crop=False)
display_transforms = build_transforms4display(224, center_crop=False)

# No shared upsampler is built here: classify_image resizes each modulator
# map to the input resolution itself.
# Adapted from pytorch-grad-cam:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay a CAM mask on an image as a heatmap.

    The mask is colour-mapped with OpenCV and blended with the image using
    fixed weights (70% heatmap, 30% image). By default the heatmap is in BGR
    format.

    :param img: The base image in RGB or BGR format, with values no larger
        than 1.
    :param mask: The CAM mask, expected in the range [0, 1]; values outside
        that range are clipped.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if
        ``img`` is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The blended image as ``uint8`` values in [0, 255].
    :raises ValueError: If ``img`` contains a value greater than 1.
    """
    # Clip before the uint8 cast so out-of-range mask values saturate instead
    # of wrapping around modulo 256.
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(mask, 0, 1)), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise ValueError(
            "The input image should np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # A negative pixel in img could push the blend below 0; clip so the final
    # cast cannot wrap around.
    cam = np.clip(cam, 0, 1)
    return np.uint8(255 * cam)
# Zero-based indices of the blocks in model.layers[0] whose modulators are
# visualised (the 3rd, 6th, 9th and 12th blocks).
_MODULATOR_BLOCK_INDICES = (2, 5, 8, 11)


def _modulator_overlay(block_index, size, base_image):
    """Render one block's modulator as a heatmap blended over ``base_image``.

    Reads the ``modulator`` tensor stored on the given block of
    ``model.layers[0]``, so the model must already have been run on the image.

    :param block_index: Index into ``model.layers[0].blocks``.
    :param size: Target spatial size the modulator map is resized to.
    :param base_image: Image array passed to ``show_cam_on_image``.
    :returns: The overlay as a PIL image.
    """
    modulator = model.layers[0].blocks[block_index].modulation.modulator
    # L2 norm over dim 1 (presumably the channel axis -- TODO confirm),
    # leaving a single-channel strength map.
    strength = modulator.norm(2, 1, keepdim=True)
    strength = nn.functional.interpolate(strength, size=size, mode='bilinear')
    mask = strength.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()

    # Min-max normalise to [0, 1]. A constant map has zero range; use an
    # all-zero mask in that case instead of dividing by zero and getting NaNs.
    lowest, highest = mask.min(), mask.max()
    span = highest - lowest
    if span > 0:
        mask = (mask - lowest) / span
    else:
        mask = np.zeros_like(mask)

    return Image.fromarray(show_cam_on_image(base_image, mask, use_rgb=True))


def classify_image(inp):
    """Classify an image and visualise the modulators of four blocks.

    :param inp: The input image, in whatever form ``eval_transforms`` and
        ``display_transforms`` accept.
    :returns: A 5-tuple: four PIL images with the modulator heatmaps of the
        blocks listed in ``_MODULATOR_BLOCK_INDICES`` (in that order),
        followed by a dict mapping each label to its softmax probability.
    """
    img_t = eval_transforms(inp)
    img_d = display_transforms(inp).permute(1, 2, 0).cpu().numpy()

    # Run on whatever device the model lives on rather than assuming CUDA.
    device = next(model.parameters()).device
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        prediction = model(img_t.unsqueeze(0).to(device)).softmax(-1).flatten()

    size = tuple(img_t.shape[1:])
    cams = [
        _modulator_overlay(block_index, size, img_d)
        for block_index in _MODULATOR_BLOCK_INDICES
    ]

    # tolist() copies all probabilities to the host in one transfer instead of
    # one float() conversion per class.
    confidences = dict(zip(labels, prediction.tolist()))
    return (*cams, confidences)
# Gradio UI wiring: one image input, four heatmap panels and a top-3 label
# panel, matching the 5-tuple returned by classify_image.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces --
# confirm they exist in the Gradio version this app is pinned to.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)

# One PIL image panel per visualised modulator, in the order classify_image
# returns them.
modulator_panels = [
    gr.outputs.Image(type="pil", label=caption)
    for caption in (
        "Modulator at layer 3",
        "Modulator at layer 6",
        "Modulator at layer 9",
        "Modulator at layer 12",
    )
]

demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=[*modulator_panels, label],
    # examples=[["images/aiko.jpg"], ["images/pencils.jpg"], ["images/donut.png"]],
)
demo.launch()
|