File size: 1,680 Bytes
fc0f846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from PIL import Image
import numpy as np
from streamlit.logger import update_formatter
import torch
from matplotlib import cm



def min_max_norm(array):
    """Linearly rescale ``array`` so its values span roughly [0, 1].

    The minimum maps to 0 and the maximum maps to (just under) 1. A small
    epsilon in the denominator prevents division by zero, so a constant
    input maps to all zeros.

    Works on torch tensors (integer dtypes are promoted to floating point)
    and on NumPy arrays. The input is never modified.

    Args:
        array: Non-empty tensor or array of any shape.

    Returns:
        A new tensor/array of the same shape holding floating-point values.
    """
    lo, hi = array.min(), array.max()
    # Out-of-place arithmetic: an in-place ``mul_`` raises for integer tensors
    # (the float result cannot be cast back) and does not exist on ndarrays.
    return (array - lo) * (1 / (1.e-10 + (hi - lo)))

def torch_to_rgba(img):
    """Convert a channels-first image tensor into an (H, W, 4) RGBA tensor.

    The tensor is min-max normalized with ``min_max_norm`` (across all
    channels at once), permuted from (C, H, W) to (H, W, C) and moved to the
    CPU. Channels are then handled as follows:

    * 1 channel  - grayscale, replicated into R, G and B, then an opaque
      alpha plane is appended;
    * 3 channels - RGB, an opaque alpha plane (all ones) is appended;
    * 4 channels - passed through unchanged.

    Args:
        img: Tensor of shape (C, H, W) with C in {1, 3, 4}.

    Returns:
        CPU tensor of shape (H, W, 4).

    Raises:
        ValueError: if the channel count is not 1, 3 or 4.
    """
    rgba_im = min_max_norm(img).permute(1, 2, 0).cpu()
    channels = rgba_im.shape[2]
    if channels == 1:
        rgba_im = rgba_im.expand(-1, -1, 3)
        channels = 3
    if channels == 3:
        # Match the image dtype so the concatenation never has to promote.
        alpha = torch.ones(*rgba_im.shape[:2], 1, dtype=rgba_im.dtype)
        rgba_im = torch.cat((rgba_im, alpha), dim=2)
    elif channels != 4:
        # Explicit error instead of ``assert``, which is stripped under -O.
        raise ValueError(f"expected 1, 3 or 4 channels, got {channels}")
    return rgba_im


def numpy_to_image(img, size):
    """Convert a [0, 1]-normalized array into a resized 8-bit PIL image.

    Values are scaled to [0, 255], rounded to the nearest integer rather than
    truncated, and clipped so values slightly outside [0, 1] cannot wrap
    around when cast to ``uint8``.

    Args:
        img: Array with values in [0, 1] whose shape ``Image.fromarray``
            accepts once cast to uint8 (e.g. (H, W, 4) for RGBA).
        size: Edge length in pixels of a square output, or a
            ``(width, height)`` pair as taken by ``Image.resize``.

    Returns:
        The resized ``PIL.Image.Image``.
    """
    if np.ndim(size) == 0:
        # Scalar edge length (int, NumPy int, ...) -> square output.
        size = (size, size)
    pixels = np.clip(np.rint(np.asarray(img) * 255.), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels).resize(size)

def upscale_pytorch(img: np.ndarray, size):
    """Resize an image array with PyTorch nearest-neighbour interpolation.

    Args:
        img: Channels-last array of shape (H, W, C), or a single-channel
            (H, W) array.
        size: Output spatial size: an int for a square output or an
            ``(out_h, out_w)`` pair, as accepted by
            ``torch.nn.functional.interpolate``.

    Returns:
        NumPy array of shape (out_h, out_w, C), or (out_h, out_w) when ``img``
        was 2-D.
    """
    is_2d = img.ndim == 2
    if is_2d:
        img = img[..., np.newaxis]
    # (H, W, C) -> (1, C, H, W): interpolate expects a batched channels-first tensor.
    torch_img = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)
    # Same computation as ``torch.nn.Upsample(size=size)`` (default mode
    # "nearest") without constructing a module.
    resized = torch.nn.functional.interpolate(torch_img, size=size)
    out = resized[0].permute(1, 2, 0).numpy()
    return out[..., 0] if is_2d else out


def heatmap(image: torch.Tensor, heatmap: torch.Tensor, size=None, alpha=.6):
    """Blend a "hot"-colormapped heatmap over an image.

    Both inputs are normalized to [0, 1], converted to RGBA, and resized to a
    common ``size`` x ``size`` square before the blend.

    Args:
        image: Channels-first image tensor (C, H, W) accepted by
            ``torch_to_rgba``. It may require grad.
        heatmap: 2-D tensor of scores. A leading singleton channel dimension
            (1, H', W') is also accepted. It may require grad or live on the
            GPU.
        size: Edge length in pixels of the square output. When falsy (the
            default), ``image.shape[1]`` is used.
        alpha: Weight of the heatmap in the blend; the image gets
            ``1 - alpha``.

    Returns:
        RGBA ``PIL.Image.Image`` of size ``size`` x ``size``.
    """
    if not size:
        size = image.shape[1]
    # .numpy() refuses tensors that require grad or live on the GPU, which is
    # typical for model outputs, so detach (and move the heatmap to CPU) first.
    image = image.detach()
    heatmap = heatmap.detach().cpu()
    if heatmap.dim() == 3 and heatmap.shape[0] == 1:
        heatmap = heatmap[0]

    img = torch_to_rgba(image).numpy()          # [0...1] rgba numpy "image"
    hm = cm.hot(min_max_norm(heatmap).numpy())  # [0...1] rgba numpy "image"

    img = np.array(numpy_to_image(img, size))
    hm = np.array(numpy_to_image(hm, size))

    blended = alpha * hm + (1 - alpha) * img
    # Round instead of truncating; clip so an alpha outside [0, 1] can't wrap uint8.
    return Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8))