xai-cl / utils.py
Annonymous
Update utils.py
d54c0f8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import random
import cv2
import io
from ssl_models.simclr2 import get_simclr2_model
from ssl_models.barlow_twins import get_barlow_twins_model
from ssl_models.simsiam import get_simsiam
from ssl_models.dino import get_dino_model_without_loss, get_dino_model_with_loss
def get_ssl_model(network, variant):
    """Load a self-supervised model and switch it to eval mode.

    Args:
        network: which model to load; one of 'simclrv2', 'barlow_twins',
            'simsiam', 'dino' or 'dino+loss'.
        variant: only consulted for 'simclrv2'. '1x' selects the
            'r50_1x_sk0_ema.pth' checkpoint; any other value selects
            'r50_2x_sk0_ema.pth'.

    Returns:
        The model produced by the matching ``ssl_models`` loader, after
        calling ``.eval()`` on it.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    if network == 'simclrv2':
        checkpoint = 'r50_1x_sk0_ema.pth' if variant == '1x' else 'r50_2x_sk0_ema.pth'
        ssl_model = get_simclr2_model(checkpoint)
    elif network == 'barlow_twins':
        ssl_model = get_barlow_twins_model()
    elif network == 'simsiam':
        ssl_model = get_simsiam()
    elif network == 'dino':
        ssl_model = get_dino_model_without_loss()
    elif network == 'dino+loss':
        # This loader returns a pair; only the model is used here and the
        # second element is discarded.
        ssl_model, _ = get_dino_model_with_loss()
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `ssl_model`.
        raise ValueError(
            f"Unknown network {network!r}; expected one of 'simclrv2', "
            "'barlow_twins', 'simsiam', 'dino', 'dino+loss'"
        )
    return ssl_model.eval()
def overlay_heatmap(img, heatmap, denormalize = False):
    """Blend a jet-colormapped heatmap 50/50 over an image tensor.

    Args:
        img: image tensor that is squeezed on dim 0 and transposed
            (1, 2, 0), i.e. presumably (1, C, H, W) channels-first with
            C == 3 (the heatmap overlay is 3-channel RGB).
        heatmap: 2-D array-like of scores (assumed non-negative). It is
            divided by its maximum, clipped to [0, 1] and resized to the
            image's height and width.
        denormalize: if True, undo per-channel mean/std normalization of
            ``img`` (presumably the ImageNet statistics) before blending.

    Returns:
        uint8 RGB array with the same height/width as ``img``.
    """
    loaded_img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0))
    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        loaded_img = std * loaded_img + mean
    loaded_img = (loaded_img.clip(0, 1) * 255).astype(np.uint8)

    heatmap = np.asarray(heatmap)
    peak = heatmap.max()
    # Guard against 0/0 -> NaN for an empty/all-zero map.
    cam = heatmap / peak if peak > 0 else np.zeros(heatmap.shape, dtype=np.float32)
    # Negative scores would wrap around when cast to uint8.
    cam = np.clip(cam, 0, 1)
    # cv2 expects dsize as (width, height); match the image so addWeighted
    # receives equally sized inputs for any resolution, not just 224x224.
    height, width = loaded_img.shape[:2]
    cam = cv2.resize(cam, (width, height))
    cam = np.uint8(255 * cam)
    cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)  # jet: blue --> red
    # applyColorMap yields BGR; convert so it matches the RGB image.
    cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(cam, 0.5, loaded_img, 0.5, 0)
    return added_image
def viz_map(img_path, heatmap):
    """For pixel invariance: overlay a heatmap on an image resized to 224x224.

    Args:
        img_path: path to an image file (str), or an already-loaded PIL
            image. It is resized to 224x224 and converted to RGB, so
            grayscale, palette and RGBA inputs are accepted.
        heatmap: torch tensor holding a 2-D map (assumed). It is detached,
            moved to CPU, divided by its maximum, clipped to [0, 1] and
            resized to the image size.

    Returns:
        uint8 RGB array: jet-colormapped heatmap (weight 0.5) plus the image
        (weight 0.7). The weights sum above 1, so OpenCV saturates bright
        regions at 255.
    """
    if isinstance(img_path, str):
        # Close the file handle once the pixels have been copied out.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.resize((224, 224)).convert('RGB'))
    else:
        img = np.array(img_path.resize((224, 224)).convert('RGB'))
    # numpy image arrays are laid out (rows, cols, channels).
    height, width = img.shape[:2]
    cam = heatmap.detach().cpu().numpy()
    peak = cam.max()
    # Guard against 0/0 -> NaN for an all-zero map; clip so negative scores
    # do not wrap around when cast to uint8.
    cam = cam / peak if peak > 0 else np.zeros_like(cam)
    cam = np.clip(cam, 0, 1)
    # cv2 expects dsize as (width, height).
    cam = cv2.resize(cam, (width, height))
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(heatmap, 0.5, img, 0.7, 0)
    return added_image
def show_image(x, squeeze = True, denormalize = False):
    """Return a channels-last numpy view of an image tensor, clipped to [0, 1].

    Args:
        x: image tensor, assumed channels-first: (1, C, H, W) when
            ``squeeze`` is True, (C, H, W) otherwise.
        squeeze: drop dimension 0 (the batch axis) before converting.
        denormalize: reverse a per-channel mean/std normalization
            (presumably the ImageNet statistics) before clipping.

    Returns:
        numpy array of shape (H, W, C) with values in [0, 1].
    """
    tensor = x.squeeze(0) if squeeze else x
    hwc = np.transpose(tensor.cpu().numpy(), (1, 2, 0))
    if denormalize:
        channel_std = np.array([0.229, 0.224, 0.225])
        channel_mean = np.array([0.485, 0.456, 0.406])
        hwc = hwc * channel_std + channel_mean
    return np.clip(hwc, 0, 1)
def deprocess(inp, to_numpy = True, to_PIL = False, denormalize = False):
    """Convert a channels-first image to a uint8 (H, W, C) array or PIL image.

    Args:
        inp: torch tensor (when ``to_numpy`` is True) or numpy array laid
            out as (C, H, W) or (1, C, H, W).
        to_numpy: detach ``inp`` and copy it to a CPU numpy array first.
        to_PIL: return a ``PIL.Image`` instead of a numpy array.
        denormalize: reverse a per-channel mean/std normalization
            (presumably the ImageNet statistics) before scaling.

    Returns:
        uint8 array of shape (H, W, C), with values clipped to [0, 1] and
        scaled to 0..255, or the same data as a PIL image if ``to_PIL``.
    """
    if to_numpy:
        inp = inp.detach().cpu().numpy()
    # Drop the batch axis only when one is present; numpy's squeeze(0)
    # raises on a plain (C, H, W) array whose first axis is not 1.
    if inp.ndim == 4:
        inp = inp.squeeze(0)
    inp = inp.transpose((1, 2, 0))
    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
    inp = (inp.clip(0, 1) * 255).astype(np.uint8)
    if to_PIL:
        return Image.fromarray(inp)
    return inp
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    The figure is rendered into an in-memory PNG, with tight bounding box
    and no padding, and then opened with PIL.

    Args:
        fig: a Matplotlib figure (anything with a compatible ``savefig``).

    Returns:
        PIL image backed by the in-memory buffer.
    """
    buf = io.BytesIO()
    # Pin the format: for file-like targets savefig otherwise falls back to
    # rcParams["savefig.format"], which may be a vector format PIL cannot open.
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf)
    return img