import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from focalnet import FocalNet, build_transforms, build_transforms4display

# Download human-readable labels for ImageNet.
# A timeout keeps startup from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being parsed as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
labels = response.text.split("\n")

# ---------------------------------------------------------------------------
# Build model
# ---------------------------------------------------------------------------
model = FocalNet(
    depths=[12],
    patch_size=16,
    embed_dim=768,
    focal_levels=[3],
    use_layerscale=True,
    use_postln=True,
)
# The weights are read from a local file. A remote copy was previously
# referenced at:
#   https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth
# (loadable with torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load("./focalnet_base_iso_16.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()

# ---------------------------------------------------------------------------
# Build data transforms
# ---------------------------------------------------------------------------
eval_transforms = build_transforms(224, center_crop=False)
display_transforms = build_transforms4display(224, center_crop=False)

# Zero-based indices of the blocks in model.layers[0] whose modulators are
# visualized, deepest first. They line up, in order, with the
# "Modulator at layer 12/9/6/3" outputs of the interface below.
VISUALIZED_BLOCKS = (11, 8, 5, 2)


# Adapted from:
# https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay a CAM mask on an image as a heatmap.

    The mask is scaled by 255, colorized with ``cv2.applyColorMap`` and blended
    with the image using equal weights (0.5 each). The blend is not
    renormalized by its maximum.

    :param img: The base image in RGB or BGR format, with values no greater
        than 1 (only the upper bound is checked).
    :param mask: The cam mask; it is multiplied by 255 and cast to uint8, so
        it is expected to lie in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap. Set this to True if
        ``img`` is in RGB format. By default the heatmap is in BGR format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the cam overlay, as a uint8 array.
    :raises ValueError: If ``np.max(img)`` is greater than 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        # ValueError is a subclass of Exception, so existing handlers still match.
        raise ValueError(
            "The input image should np.float32 in the range [0, 1]")

    cam = 0.5 * heatmap + 0.5 * img
    return np.uint8(255 * cam)


def _normalize_to_unit_range(values: np.ndarray) -> np.ndarray:
    """Min-max scale ``values`` to [0, 1]; a constant array maps to all zeros."""
    lowest = values.min()
    span = values.max() - lowest
    if span > 0:
        return (values - lowest) / span
    # A zero range would divide by zero and push NaNs into the uint8 heatmap.
    return np.zeros_like(values)


def _modulator_overlay(block_index, image_rgb, size):
    """Render the modulator of one block as a heatmap over ``image_rgb``.

    Reads ``model.layers[0].blocks[block_index].modulation.modulator``, so the
    model must already have been run on the image being visualized.

    :param block_index: Zero-based index into ``model.layers[0].blocks``.
    :param image_rgb: Display image passed to ``show_cam_on_image`` (RGB,
        values no greater than 1).
    :param size: Target spatial size the modulator map is upsampled to.
    :returns: uint8 overlay image produced by ``show_cam_on_image``.
    """
    modulator = model.layers[0].blocks[block_index].modulation.modulator
    # L2 norm over dim 1 (assumed to be the channel axis of a
    # (batch, channel, height, width) tensor - TODO confirm against FocalNet).
    magnitude = modulator.norm(2, 1, keepdim=True)
    upsampled = nn.Upsample(size=size, mode='bilinear')(magnitude)
    heat = upsampled.squeeze(1).detach().permute(1, 2, 0).numpy()
    return show_cam_on_image(image_rgb, _normalize_to_unit_range(heat), use_rgb=True)


def classify_image(inp):
    """Classify an image and visualize the modulators of selected blocks.

    :param inp: The input image as delivered by the Gradio image input; it is
        passed unchanged to ``eval_transforms`` and ``display_transforms``.
    :returns: A 6-tuple of

        1. a dict mapping each label to its softmax probability,
        2. - 5. PIL images with the modulator heatmaps of the blocks listed in
           ``VISUALIZED_BLOCKS`` (same order),
        6. the display-transformed input as a PIL image.
    """
    img_t = eval_transforms(inp)
    img_d = display_transforms(inp).permute(1, 2, 0).numpy()

    # Inference only: skip autograd bookkeeping for the forward pass and for
    # the heatmap post-processing.
    with torch.no_grad():
        prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()
        # The forward pass above is what populates the per-block modulators.
        overlays = [
            _modulator_overlay(block_index, img_d, img_t.shape[1:])
            for block_index in VISUALIZED_BLOCKS
        ]

    # One bulk conversion instead of a tensor index + float() per class.
    confidences = {
        labels[i]: probability
        for i, probability in enumerate(prediction.tolist())
    }
    return (
        confidences,
        *(Image.fromarray(overlay) for overlay in overlays),
        Image.fromarray(np.uint8(255 * img_d)),
    )


# NOTE(review): gr.inputs / gr.outputs look like the legacy Gradio component
# namespaces - confirm against the pinned Gradio version before upgrading.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)

gr.Interface(
    description="Image classification and visualizations with FocalNet (https://github.com/microsoft/FocalNet)",
    fn=classify_image,
    inputs=image,
    outputs=[
        label,
        gr.outputs.Image(type="pil", label="Modulator at layer 12"),
        gr.outputs.Image(type="pil", label="Modulator at layer 9"),
        gr.outputs.Image(type="pil", label="Modulator at layer 6"),
        gr.outputs.Image(type="pil", label="Modulator at layer 3"),
        gr.outputs.Image(type="pil", label="Cropped Input"),
    ],
    examples=[
        ["./donut.png"],
        ["./horses.png"],
        ["./pencil.png"],
        ["./ILSVRC2012_val_00031987.JPEG"],
    ],
).launch()