File size: 3,325 Bytes
cddd431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d54c0f8
cddd431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image 
import random
import cv2
import io
from ssl_models.simclr2 import get_simclr2_model
from ssl_models.barlow_twins import get_barlow_twins_model
from ssl_models.simsiam import get_simsiam
from ssl_models.dino import get_dino_model_without_loss, get_dino_model_with_loss

def get_ssl_model(network, variant):
    """Load a pretrained self-supervised backbone and switch it to eval mode.

    Args:
        network: One of ``'simclrv2'``, ``'barlow_twins'``, ``'simsiam'``,
            ``'dino'`` or ``'dino+loss'``.
        variant: Only consulted for ``'simclrv2'``. ``'1x'`` selects the
            ``r50_1x_sk0_ema.pth`` checkpoint; any other value selects
            ``r50_2x_sk0_ema.pth``.

    Returns:
        The object returned by the matching loader, after calling ``.eval()``
        on it.

    Raises:
        ValueError: If ``network`` is not one of the supported names.
    """
    if network == 'simclrv2':
        checkpoint = 'r50_1x_sk0_ema.pth' if variant == '1x' else 'r50_2x_sk0_ema.pth'
        ssl_model = get_simclr2_model(checkpoint)
    elif network == 'barlow_twins':
        ssl_model = get_barlow_twins_model()
    elif network == 'simsiam':
        ssl_model = get_simsiam()
    elif network == 'dino':
        ssl_model = get_dino_model_without_loss()
    elif network == 'dino+loss':
        # The loader returns a pair; the second element is not used here.
        ssl_model, _ = get_dino_model_with_loss()
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError.
        raise ValueError(
            f"Unknown SSL network {network!r}; expected one of "
            "'simclrv2', 'barlow_twins', 'simsiam', 'dino', 'dino+loss'."
        )

    return ssl_model.eval()

def overlay_heatmap(img, heatmap, denormalize = False):
    """Blend a JET-colored heatmap 50/50 over an image tensor.

    Args:
        img: Channel-first image tensor, optionally with a leading size-1
            batch axis (dropped via ``squeeze(0)``). Blending with the
            3-channel colormap requires 3 channels. Values are expected in
            [0, 1] unless ``denormalize`` is set.
        heatmap: 2-D array-like of scores. It is scaled by its maximum,
            clipped to [0, 1] and resized to the image's H x W.
        denormalize: If True, undo ImageNet mean/std normalization first.

    Returns:
        ``(H, W, 3)`` uint8 RGB array.
    """
    loaded_img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        loaded_img = std * loaded_img + mean

    loaded_img = (loaded_img.clip(0, 1) * 255).astype(np.uint8)

    cam = np.asarray(heatmap, dtype=np.float32)
    peak = cam.max()
    # An all-zero (or all-negative) map would otherwise divide by zero -> NaN.
    if peak > 0:
        cam = cam / peak
    # np.uint8 wraps out-of-range values, so keep the map inside [0, 1].
    cam = np.clip(cam, 0, 1)
    # Match the image size (previously hard-coded to 224x224, which broke
    # addWeighted for other sizes). cv2.resize takes dsize as (width, height).
    height, width = loaded_img.shape[:2]
    cam = cv2.resize(cam, (width, height))
    cam = np.uint8(255 * cam)
    cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)   # jet: blue --> red
    cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(cam, 0.5, loaded_img, 0.5, 0)
    return added_image

def viz_map(img_path, heatmap):
    """Overlay a JET-colored heatmap on a 224x224 RGB copy of an image.

    Used for pixel-invariance visualizations.

    Args:
        img_path: Path to an image file (``str`` or path-like), or an
            already-loaded ``PIL.Image.Image``. Any mode is converted to RGB
            so the blend with the 3-channel colormap always lines up.
        heatmap: 2-D torch tensor of scores. It is detached, moved to CPU,
            scaled by its maximum, clipped to [0, 1] and resized to the
            image size.

    Returns:
        ``(224, 224, 3)`` uint8 RGB array equal to
        ``0.5 * heatmap + 0.7 * image``. ``cv2.addWeighted`` saturates at 255
        because the weights sum to more than 1.
    """
    if isinstance(img_path, Image.Image):
        pil_img = img_path.convert('RGB').resize((224, 224))
    else:
        # Close the file handle PIL opens; convert() forces the pixel load
        # before the context exits.
        with Image.open(img_path) as opened:
            pil_img = opened.convert('RGB').resize((224, 224))
    img = np.array(pil_img)
    # NumPy image shape is (rows, cols, channels) == (height, width, 3).
    height, width = img.shape[:2]

    cam = heatmap.detach().cpu().numpy().astype(np.float32)
    peak = cam.max()
    # Avoid NaN from dividing an all-zero map by its maximum.
    if peak > 0:
        cam = cam / peak
    cam = np.clip(cam, 0, 1)
    # cv2.resize takes dsize as (width, height).
    cam = cv2.resize(cam, (width, height))

    colored = np.uint8(255 * cam)
    colored = cv2.applyColorMap(colored, cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(colored, 0.5, img, 0.7, 0)
    return added_image

def show_image(x, squeeze = True, denormalize = False):
    """Turn a channel-first image tensor into an HWC float array for display.

    Args:
        x: Image tensor; ``squeeze(0)`` is applied first when ``squeeze`` is
            True (a no-op on a torch tensor whose first dim is not 1).
        squeeze: Whether to drop a leading size-1 axis.
        denormalize: If True, undo ImageNet mean/std normalization.

    Returns:
        NumPy array transposed to ``(H, W, C)`` with values clipped to [0, 1].
    """
    tensor = x.squeeze(0) if squeeze else x
    hwc = np.transpose(tensor.cpu().numpy(), (1, 2, 0))

    if not denormalize:
        return hwc.clip(0, 1)

    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    return np.clip(hwc * channel_std + channel_mean, 0, 1)

def deprocess(inp, to_numpy = True, to_PIL = False, denormalize = False):
    """Convert a channel-first image to an HWC uint8 image.

    Args:
        inp: Image shaped ``(C, H, W)`` or ``(1, C, H, W)``. It must be a torch
            tensor when ``to_numpy`` is True, otherwise a NumPy array.
        to_numpy: Detach ``inp``, move it to CPU and convert it to NumPy first.
        to_PIL: Return a ``PIL.Image.Image`` instead of a NumPy array.
        denormalize: Undo ImageNet mean/std normalization before clipping.

    Returns:
        ``(H, W, C)`` uint8 array, with values clipped to [0, 1], scaled by
        255 and truncated. If ``to_PIL`` is set, the same data as a PIL image.
    """
    if to_numpy:
        inp = inp.detach().cpu().numpy()

    # Drop only a size-1 batch axis. NumPy's squeeze(0) (unlike torch's)
    # raises on an unbatched (C, H, W) array whose first axis is not 1.
    if inp.ndim == 4 and inp.shape[0] == 1:
        inp = inp[0]
    inp = inp.transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean

    inp = (inp.clip(0, 1) * 255).astype(np.uint8)

    if to_PIL:
        return Image.fromarray(inp)
    return inp

def fig2img(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL Image.

    The figure is saved tightly cropped (``bbox_inches='tight'``,
    ``pad_inches=0``). The returned image reads from an in-memory
    ``io.BytesIO`` buffer.
    """
    buf = io.BytesIO()
    # Pin the format: with a file-like target and no format, savefig falls
    # back to rcParams["savefig.format"], which may be set to something PIL
    # cannot decode (e.g. 'pdf' or 'svg').
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf)
    return img