# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import cv2
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits
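
# Example invocation (illustrative; the script name and paths below are
# assumptions, not part of the original file):
#
#   python visualize_attention.py --arch vit_small --patch_size 8 \
#       --image_path path/to/image.jpg --output_dir ./attn_out --threshold 0.6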


def apply_mask(image, mask, color, alpha=0.5):
    # Alpha-blend `color` into `image` on every pixel where `mask` is nonzero.
    for c in range(3):
        image[:, :, c] = (
            image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
        )
    return image


def random_colors(N, bright=True):
    """
    Generate N random colors: evenly spaced hues in HSV, converted to RGB
    and shuffled so adjacent instances get visually distinct colors.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors
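
# For example, random_colors(3) returns the RGB triples for hues 0, 1/3 and 2/3
# (pure red, green and blue at full saturation and brightness), in shuffled order.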


def display_instances(
    image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5
):
    # Overlay a single binary mask on `image`, optionally drawing its contour,
    # and save the resulting figure to `fname`.
    fig = plt.figure(figsize=figsize, frameon=False)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis("off")
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        if blur:
            _mask = cv2.blur(_mask, (10, 10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect="auto")
    fig.savefig(fname)
    print(f"{fname} saved.")
    return
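
# Note: display_instances expects `image` as an HxWx3 uint8 array (as produced by
# skimage.io.imread below) and `mask` as an HxW float array with values in [0, 1].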


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Visualize Self-Attention maps")
    parser.add_argument(
        "--arch",
        default="vit_small",
        type=str,
        choices=["vit_tiny", "vit_small", "vit_base"],
        help="Architecture (only ViT variants are supported at the moment).",
    )
    parser.add_argument(
        "--patch_size", default=8, type=int, help="Patch resolution of the model."
    )
    parser.add_argument(
        "--pretrained_weights",
        default="",
        type=str,
        help="Path to pretrained weights to load.",
    )
    parser.add_argument(
        "--checkpoint_key",
        default="teacher",
        type=str,
        help='Key to use in the checkpoint (example: "teacher").',
    )
    parser.add_argument(
        "--image_path", default=None, type=str, help="Path of the image to load."
    )
    parser.add_argument(
        "--image_size", default=(480, 480), type=int, nargs="+", help="Resize image."
    )
    parser.add_argument(
        "--output_dir", default=".", help="Path where to save visualizations."
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="""We visualize masks obtained by thresholding the self-attention maps
        to keep the given fraction of the total attention mass.""",
    )
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # build model
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Taking key {args.checkpoint_key} from the provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                args.pretrained_weights, msg
            )
        )
    else:
        print(
            "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
        )
        url = None
        if args.arch == "vit_small" and args.patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif args.arch == "vit_small" and args.patch_size == 8:
            # model used for visualizations in our paper
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        elif args.arch == "vit_base" and args.patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif args.arch == "vit_base" and args.patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print(
                "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
            )
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/" + url
            )
            model.load_state_dict(state_dict, strict=True)
        else:
            print(
                "No reference weights are available for this model => we use random weights."
            )

    # open image
    if args.image_path is None:
        # user has not specified any image - we use our own image
        print(
            "Please use the `--image_path` argument to indicate the path of the image you wish to visualize."
        )
        print(
            "Since no image path has been provided, we take the first image in our paper."
        )
        response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
        img = Image.open(BytesIO(response.content))
        img = img.convert("RGB")
    elif os.path.isfile(args.image_path):
        with open(args.image_path, "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
    else:
        print(f"Provided image path {args.image_path} is not valid.")
        sys.exit(1)
    transform = pth_transforms.Compose(
        [
            pth_transforms.Resize(args.image_size),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    img = transform(img)

    # make the image dimensions divisible by the patch size
    # (img is CxHxW after ToTensor)
    w, h = (
        img.shape[1] - img.shape[1] % args.patch_size,
        img.shape[2] - img.shape[2] % args.patch_size,
    )
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // args.patch_size
    h_featmap = img.shape[-1] // args.patch_size
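
    # For example, a 480x480 input with patch_size=8 yields a 60x60 grid of
    # patches, so each attention map below has 60 * 60 = 3600 patch entries.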

    attentions = model.get_last_selfattention(img.to(device))

    nh = attentions.shape[1]  # number of heads

    # we keep only the attention from the [CLS] token to the patch tokens
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    if args.threshold is not None:
        # we keep only a certain fraction of the total attention mass: sort each
        # head's attention, normalize it to sum to 1, and flag the largest values
        # whose cumulative sum covers `threshold` of the mass
        val, idx = torch.sort(attentions)
        val /= torch.sum(val, dim=1, keepdim=True)
        cumval = torch.cumsum(val, dim=1)
        th_attn = cumval > (1 - args.threshold)
        idx2 = torch.argsort(idx)
        for head in range(nh):
            th_attn[head] = th_attn[head][idx2[head]]
        th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
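
        # Worked example (hypothetical numbers): with sorted, normalized
        # attention values [0.1, 0.2, 0.3, 0.4] and --threshold 0.6, cumval is
        # [0.1, 0.3, 0.6, 1.0]; cumval > 0.4 keeps the last two entries, i.e.
        # the smallest set of patches holding at least 60% of the mass.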
        # upsample the thresholded maps from the patch grid to pixel resolution
        th_attn = (
            nn.functional.interpolate(
                th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
            )[0]
            .cpu()
            .numpy()
        )

    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = (
        nn.functional.interpolate(
            attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
        )[0]
        .cpu()
        .numpy()
    )

    # save attentions heatmaps
    os.makedirs(args.output_dir, exist_ok=True)
    torchvision.utils.save_image(
        torchvision.utils.make_grid(img, normalize=True, scale_each=True),
        os.path.join(args.output_dir, "img.png"),
    )
    for j in range(nh):
        fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
        plt.imsave(fname=fname, arr=attentions[j], format="png")
        print(f"{fname} saved.")

    if args.threshold is not None:
        image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
        for j in range(nh):
            display_instances(
                image,
                th_attn[j],
                fname=os.path.join(
                    args.output_dir,
                    "mask_th" + str(args.threshold) + "_head" + str(j) + ".png",
                ),
                blur=False,
            )
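
    # Outputs in args.output_dir: img.png (the preprocessed input), one
    # attn-head{j}.png heatmap per head, and, when --threshold is set, one
    # mask_th{threshold}_head{j}.png overlay per head.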