# Spaces: Sleeping (hosting-page status banner captured with this file; not part of the program)
import gradio as gr | |
import torch | |
import os | |
import sys | |
import cv2 | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
from PIL import ImageFont | |
from PIL import ImageDraw | |
from scipy.stats import rankdata | |
import torch | |
import torch.nn as nn | |
import torchvision | |
from torchvision import transforms as pth_transforms | |
import torchvision.transforms as transforms | |
from einops import rearrange, repeat | |
import vision_transformer as vits | |
def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cpu')):
    r"""
    Builds ViT-256 Model.

    The model is created with ``patch_size=16`` and ``num_classes=0``, frozen
    (``requires_grad = False`` on every parameter), switched to eval mode and
    moved to ``device``. If ``pretrained_weights`` points to an existing file,
    its state dict is loaded non-strictly.

    Args:
    - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
    - arch (str): Which model architecture (a factory name looked up in the
      ``vision_transformer`` module).
    - device (torch.device): Torch device to place the model on.

    Returns:
    - model256 (torch.nn): Initialized model.
    """
    checkpoint_key = 'teacher'

    # Honour the caller-supplied device. The previous code overwrote it with
    # a conditional whose two branches were both CPU, so the argument was
    # silently ignored.
    model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
    for p in model256.parameters():
        p.requires_grad = False
    model256.eval()
    model256.to(device)

    if os.path.isfile(pretrained_weights):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        # strict=False: mismatched keys are reported in `msg` rather than raising.
        msg = model256.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
    else:
        # Make the fallback visible: without a checkpoint the model keeps its
        # initial weights.
        print('No pretrained weights found at {}; model keeps its initial weights.'.format(pretrained_weights))

    return model256
def cmap_map(function, cmap):
    r"""
    Builds a new colormap by passing every colour of ``cmap`` through ``function``.

    ``function`` receives and returns a length-3 [r, g, b] vector.
    Discontinuous points of the source colormap are not preserved.

    Args:
    - function (function): Maps an [r, g, b] vector to a new [r, g, b] vector.
    - cmap (matplotlib.colormap): Segment-based colormap (its ``_segmentdata`` is read).

    Returns:
    - matplotlib.colormap: ``LinearSegmentedColormap`` named 'colormap' with 1024 entries.
    """
    channels = ('red', 'green', 'blue')
    segments = cmap._segmentdata

    # First collect, per channel, the positions where segments start or end.
    breakpoints = {}
    for channel in channels:
        breakpoints[channel] = [entry[0] for entry in segments[channel]]
    all_steps = []
    for channel in channels:
        all_steps = all_steps + breakpoints[channel]
    all_steps = np.array(list(set(all_steps)))

    # Sample the colormap at every breakpoint, then transform each sample.
    original_lut = np.array([np.array(cmap(step)[0:3]) for step in all_steps])
    mapped_lut = np.array([function(rgb) for rgb in original_lut])

    # Rebuild a minimal segment definition from the transformed samples.
    new_segments = {}
    for col, channel in enumerate(channels):
        anchors = {}
        for row, step in enumerate(all_steps):
            # Keep the step if it belonged to this channel, or if the
            # transformation changed the channel value at that step.
            if step in breakpoints[channel] or mapped_lut[row, col] != original_lut[row, col]:
                anchors[step] = mapped_lut[row, col]
        # Each entry is (position, value_below, value_above) with equal values.
        new_segments[channel] = sorted(
            (position, value, value) for position, value in anchors.items()
        )

    return matplotlib.colors.LinearSegmentedColormap('colormap', new_segments, 1024)
def identity(x):
    r"""
    Returns its argument untouched.

    Args:
    - x: Any object.

    Returns:
    - x: The very same object that was passed in.
    """
    unchanged = x
    return unchanged
def tensorbatch2im(input_image, imtype=np.uint8):
    r"""
    Turns a batched image tensor into a channels-last numpy image batch.

    A non-numpy input is moved to CPU, cast to float, transposed from
    (B, C, W, H) to (B, W, H, C) and rescaled with ``(x + 1) / 2 * 255``.
    A numpy input skips all of that and is only cast to ``imtype``.

    Args:
    - input_image (torch.Tensor): (B, C, W, H) Torch Tensor (or a numpy array).
    - imtype (type): Desired dtype of the returned numpy array.

    Returns:
    - image_numpy (np.array): (B, W, H, C) Numpy Array of dtype ``imtype``.
    """
    # Numpy input: no transpose or rescale, just the dtype cast.
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)

    batch = input_image.cpu().float().numpy()
    channels_last = np.transpose(batch, (0, 2, 3, 1))
    # Map values from [-1, 1] to [0, 255].
    rescaled = (channels_last + 1) / 2.0 * 255.0
    return rescaled.astype(imtype)
def getConcatImage(imgs, how='horizontal', gap=0):
    r"""
    Function to concatenate list of images (vertical or horizontal).

    Vertical stacking uses the widest image as canvas width. Horizontal
    stacking uses the *shortest* image as canvas height, so taller images are
    cropped at the bottom. The canvas is RGBA and transparent white.

    Args:
    - imgs (list of PIL.Image): List of PIL Images to concatenate (non-empty).
    - how (str): How the images are concatenated (either 'horizontal' or 'vertical')
    - gap (int): Gap (in px) between images

    Return:
    - dst (PIL.Image): Concatenated image result.

    Raises:
    - ValueError: If ``imgs`` is empty or ``how`` is not a supported direction.
    """
    # Fail early with a clear message; previously an unknown direction reached
    # `return dst` with `dst` never assigned.
    if how not in ('horizontal', 'vertical'):
        raise ValueError("how must be 'horizontal' or 'vertical', got {!r}".format(how))
    imgs = list(imgs)
    if not imgs:
        raise ValueError("imgs must contain at least one image")

    gap_dist = (len(imgs) - 1) * gap
    widths = [img.width for img in imgs]
    heights = [img.height for img in imgs]

    # Built-in max/min/sum keep the canvas size as plain Python ints.
    if how == 'vertical':
        w, h = max(widths), sum(heights) + gap_dist
        dst = Image.new('RGBA', (w, h), color=(255, 255, 255, 0))
        curr_h = 0
        for img in imgs:
            dst.paste(img, (0, curr_h))
            curr_h += img.height + gap
    else:
        w, h = sum(widths) + gap_dist, min(heights)
        dst = Image.new('RGBA', (w, h), color=(255, 255, 255, 0))
        curr_w = 0
        for img in imgs:
            dst.paste(img, (curr_w, 0))
            curr_w += img.width + gap

    return dst
def add_margin(pil_img, top, right, bottom, left, color):
    r"""
    Pads a PIL.Image with a solid-colour border of independent side widths.

    Args:
    - pil_img (PIL.Image): Image to pad; its mode is reused for the result.
    - top, right, bottom, left (int): Margin (in px) added on each side.
    - color: Fill colour of the margin, passed to ``Image.new``.

    Returns:
    - PIL.Image: New image with ``pil_img`` pasted at (left, top).
    """
    orig_w, orig_h = pil_img.size
    canvas_size = (orig_w + right + left, orig_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
def concat_scores256(attns, size=(256,256)):
    r"""
    Percentile-ranks each attention map and tiles the results into a 16 x 16 grid.

    Every map is flattened, ranked with ``rankdata`` and rescaled to (0, 100],
    then reshaped to ``size``. Blocks 0-15 form the first grid row, 16-31 the
    second, and so on for 256 blocks.

    Args:
    - attns (iterable of np.array): Attention maps; the first 256 are used.
    - size (tuple): Shape each ranked map is reshaped to.

    Returns:
    - color_hm (np.array): Grid of ranked maps, shape (16*size[0], 16*size[1]).
    """
    def to_percentile(values):
        return rankdata(values) * 100 / len(values)

    blocks = []
    for attn in attns:
        blocks.append(to_percentile(attn.flatten()).reshape(size))

    # Join 16 consecutive blocks side by side, then stack the rows.
    grid_rows = []
    for start in range(0, 256, 16):
        grid_rows.append(np.concatenate(blocks[start:(start + 16)], axis=1))
    return np.concatenate(grid_rows)
def get_scores256(attns, size=(256,256)):
    r"""
    Percentile-ranks every attention map and returns the first ranked map.

    Each map is flattened, ranked with ``rankdata`` and rescaled to (0, 100],
    then reshaped to ``size``.

    Args:
    - attns (iterable of np.array): Attention maps (iterated along the first axis).
    - size (tuple): Shape each ranked map is reshaped to.

    Returns:
    - color_block (np.array): Ranked scores of the first map, with shape ``size``.
    """
    def to_percentile(values):
        return rankdata(values) * 100 / len(values)

    ranked = []
    for attn in attns:
        ranked.append(to_percentile(attn.flatten()).reshape(size))
    return ranked[0]
def get_patch_attention_scores(patch, model256, scale=1, device256=torch.device('cpu')):
    r"""
    Runs a patch through the 256-level ViT and returns its last-layer attention maps.

    The patch is converted to a tensor and normalised with mean 0.5 / std 0.5
    per channel. For each head, the attention in row 0 over columns 1 onward
    (presumably the CLS token attending to the patch tokens - confirm against
    the model) is reshaped to a 16 x 16 grid and upsampled with
    nearest-neighbour interpolation by a factor of ``int(16 / scale)``.

    Args:
    - patch (PIL.Image): Input patch. The 16 x 16 reshape requires exactly 256
      non-first tokens, i.e. a 256 x 256 image for a patch size of 16.
    - model256 (torch.nn): Model exposing ``get_last_selfattention``.
    - scale (int): Down-scaling factor applied to both returned arrays.
    - device256 (torch.device): Device the input batch is moved to.

    Returns:
    - (np.array, np.array): The de-normalised image batch from ``tensorbatch2im``
      and the attention array of shape (1, heads, H, W), where H = W =
      16 * int(16 / scale).
    """
    t = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        )
    ])

    with torch.no_grad():
        batch_256 = t(patch).unsqueeze(0)
        batch_256 = batch_256.to(device256, non_blocking=True)

        # Only the attention is needed. The previous extra `model256(batch_256)`
        # call produced features that were never used and doubled the cost.
        attention_256 = model256.get_last_selfattention(batch_256)
        nh = attention_256.shape[1]  # number of heads

        # Reshape straight to the 16 x 16 token grid; the former intermediate
        # reshape(256, nh, -1) did not change the element order.
        attention_256 = attention_256[:, :, 0, 1:].reshape(1, nh, 16, 16)
        attention_256 = nn.functional.interpolate(
            attention_256, scale_factor=int(16 / scale), mode="nearest"
        ).cpu().numpy()

        if scale != 1:
            batch_256 = nn.functional.interpolate(batch_256, scale_factor=(1 / scale), mode="nearest")

    return tensorbatch2im(batch_256), attention_256
def _merged_head_scores(a256_1, a256_2, head, size, offset):
    r"""
    Averages one head's ranked scores from the original and the shifted patch.
    """
    score_1 = get_scores256(a256_1[:, head, :, :], size=(size,) * 2)
    score_2 = get_scores256(a256_2[:, head, :, :], size=(size,) * 2)

    # Put the shifted patch's scores back at their position in the original frame.
    shifted = np.zeros_like(score_2)
    shifted[offset:size, offset:size] = score_2[:(size - offset), :(size - offset)]

    # Divisor: 100 where only one view contributes, 200 where both overlap,
    # which turns the 0-100 ranks into a 0-1 average.
    overlay = np.ones_like(score_2) * 100
    overlay[offset:size, offset:size] += 100
    return (score_1 + shifted) / overlay


def create_patch_heatmaps_concat(patch, model256, output_dir=None, fname=None, threshold=None,
                                 offset=16, alpha=0.5, cmap=plt.get_cmap('coolwarm')):
    r"""
    Creates patch heatmaps (concatenated for easy comparison)

    Attention is computed twice: on the patch itself and on a copy shifted by
    ``offset`` pixels (cropped at the top-left, white-padded at the
    bottom-right). The two ranked score maps are averaged where they overlap.
    Six heads are rendered and laid out as two rows of three.

    Args:
    - patch (PIL.Image): 256 x 256 Image
    - model256 (torch.nn): 256-Level ViT
    - output_dir (str): Save directory / subdirectory (used only when ``threshold`` is set)
    - fname (str): Naming structure of files (used only when ``threshold`` is set)
    - threshold (float): If not None, also saves thresholded overlays to
      ``<output_dir>/<fname>_256th.png``
    - offset (int): How much to offset (from top-left corner with zero-padding) the region by for blending
    - alpha (float): Image blending factor for cv2.addWeighted
    - cmap (matplotlib.pyplot): Colormap for creating heatmaps

    Returns:
    - PIL.Image: Blended heatmaps of the six heads (2 rows x 3 columns, 10 px gaps).
    """
    num_heads = 6
    per_row = 3
    s = 256

    patch1 = patch.copy()
    # Shift by the same `offset` used when the scores are pasted back below
    # (the crop used to hard-code 16 regardless of `offset`).
    patch2 = add_margin(patch.crop((offset, offset, s, s)),
                        top=0, left=0, bottom=offset, right=offset, color=(255, 255, 255))
    _, a256_1 = get_patch_attention_scores(patch1, model256)
    _, a256_2 = get_patch_attention_scores(patch2, model256)
    save_region = np.array(patch.copy())

    head_scores = [_merged_head_scores(a256_1, a256_2, i, s, offset) for i in range(num_heads)]

    if threshold is not None:
        ths = []
        for score256 in head_scores:
            # `>=` puts every pixel in exactly one partition; a score equal to
            # the threshold used to fall into neither.
            above = score256 >= threshold
            mask256 = np.where(above, 0.95, 0.0)
            color_block256 = (cmap(mask256) * 255)[:, :, :3].astype(np.uint8)
            region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1 - alpha, 0, save_region.copy())
            # Heatmap where the score passes the threshold, original pixels elsewhere.
            region256_hm[~above] = 0
            img_inverse = save_region.copy()
            img_inverse[above] = 0
            ths.append(region256_hm + img_inverse)

        ths = [Image.fromarray(img) for img in ths]
        # Rows are heads 0-2 and 3-5 (the old [4:6] slice dropped head 3).
        getConcatImage([getConcatImage(ths[0:per_row]),
                        getConcatImage(ths[per_row:num_heads])], how='vertical').save(os.path.join(output_dir, '%s_256th.png' % (fname)))

    hms = []
    for score256 in head_scores:
        color_block256 = (cmap(score256) * 255)[:, :, :3].astype(np.uint8)
        region256_hm = cv2.addWeighted(color_block256, alpha, save_region.copy(), 1 - alpha, 0, save_region.copy())
        hms.append(region256_hm)

    hms = [Image.fromarray(img) for img in hms]
    # Rows are heads 0-2 and 3-5 (the old [4:6] slice dropped head 3).
    return getConcatImage([getConcatImage(hms[0:per_row], how='horizontal', gap=10),
                           getConcatImage(hms[per_row:num_heads], how='horizontal', gap=10)], how='vertical', gap=10)
def demo_patch_heatmaps(input_image):
    r"""
    Gradio callback: renders the attention heatmaps for an uploaded patch.

    Args:
    - input_image (PIL.Image): Patch passed in by the interface.

    Returns:
    - demo_heatmap (PIL.Image): Result of ``create_patch_heatmaps_concat``
      drawn with a lightened 'jet' colormap.
    """
    # Halve every channel and add 0.5 to get a lighter version of 'jet'.
    light_jet = cmap_map(lambda x: x/2 + 0.5, matplotlib.cm.jet)

    # Build the model once and keep it on the function object; reloading the
    # checkpoint from disk on every request is wasted work.
    model256 = getattr(demo_patch_heatmaps, '_model256', None)
    if model256 is None:
        model256 = get_vit256(pretrained_weights=pretrained_weights256)
        demo_patch_heatmaps._model256 = model256

    demo_heatmap = create_patch_heatmaps_concat(input_image, model256, cmap=light_jet)
    return demo_heatmap
# Checkpoint path read by demo_patch_heatmaps.
pretrained_weights256 = './model.pt'

# Text shown on the demo page.
title = "Demo for 11604"
description = (
    "To use, upload a 256 x 256 patch (20X magnification). "
    "The output will generate attention results from 6 attention heads."
)

# One image in, one image out; flagging is turned off.
iface = gr.Interface(
    fn=demo_patch_heatmaps,
    inputs=gr.inputs.Image(type='pil'),
    outputs="image",
    title=title,
    description=description,
    allow_flagging=False,
)
iface.launch()