# CLIP_/example.py (CLIPasso)
import torch
import clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
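
# Visualizes which image regions drive CLIP's image-text similarity, using
# gradient-weighted attention rollout over the visual transformer's attention
# maps. Note: `attn_probs` and `attn_grad` are not attributes of stock CLIP
# attention blocks; this script assumes the modified CLIP bundled in this
# repo, which caches attention maps and their gradients.
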
def interpret(image, text, model, device, index=None):
    # Forward pass: similarity logits between the image and each text prompt.
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    if index is None:
        index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
    # Backpropagate only the selected prompt's logit via a one-hot mask, so
    # the attention gradients reflect that prompt alone.
    one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.to(device) * logits_per_image)
    model.zero_grad()
    one_hot.backward(retain_graph=True)
    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
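    # Gradient-weighted attention rollout: for each block, weight the attention
    # map by the gradient of the selected logit, keep the positive part,
    # average over heads, and accumulate the result into R.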
    for blk in image_attn_blocks:
        grad = blk.attn_grad
        cam = blk.attn_probs
        # Collapse the batch/head dimensions: (heads, tokens, tokens).
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    # Row 0 of R is the CLS token's relevance; zero out its self-relevance and
    # keep its relevance to the image patch tokens.
    R[0, 0] = 0
    image_relevance = R[0, 1:]
    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam
    # ViT-B/32 at 224x224 yields a 7x7 grid of patch tokens; upsample the
    # relevance map back to the input resolution.
    image_relevance = image_relevance.reshape(1, 1, 7, 7)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    plt.imshow(vis)
    plt.show()
    print("Label probs:", probs)
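
# interpret() relies on `attn_probs` and `attn_grad` being cached on each
# attention block. A minimal sketch of how such caching can work (illustrative
# only, not this repo's exact API): keep the attention map on the module and
# register a tensor hook to capture its gradient during backward.
def _cache_attention(module, attn):
    module.attn_probs = attn
    attn.register_hook(lambda grad: setattr(module, "attn_grad", grad))
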
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
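    # Each entry pairs a sample image (assumed to be in the working directory)
    # with two text prompts; interpret() is run once per prompt.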
    examples = [
        ("catdog.png", ["a dog", "a cat"]),
        ("el1.png", ["an elephant", "a zebra"]),
        ("el2.png", ["an elephant", "a zebra"]),
        ("el3.png", ["an elephant", "a zebra"]),
        ("el4.png", ["an elephant", "a zebra"]),
        ("dogbird.png", ["a basset hound", "a parrot"]),
    ]
    for image_path, prompts in examples:
        image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
        text = clip.tokenize(prompts).to(device)
        for index in range(len(prompts)):
            interpret(model=model, image=image, text=text, device=device, index=index)
if __name__ == "__main__":
    main()