|
from PIL import Image |
|
import numpy as np |
|
from streamlit.logger import update_formatter |
|
import torch |
|
from matplotlib import cm |
|
|
|
|
|
|
|
def min_max_norm(array):
    """Linearly rescale *array* so its minimum maps to 0 and its maximum to ~1.

    Works for torch tensors and numpy arrays. Integer inputs are promoted to
    floating point instead of failing. The input itself is not modified.

    Args:
        array: tensor or ndarray to normalize.

    Returns:
        A new tensor/ndarray of the same shape. A tiny epsilon (1e-10) is added
        to the value range, so a constant input maps to all zeros instead of
        dividing by zero.
    """
    lo, hi = array.min(), array.max()
    scale = 1 / (1.e-10 + (hi - lo))
    # Out-of-place multiply (instead of the in-place ``mul_``) lets integer
    # tensors promote to float and also works for numpy arrays.
    return (array - lo) * scale
|
|
|
def torch_to_rgba(img):
    """Convert a channel-first image tensor to a min-max normalized RGBA tensor.

    Args:
        img: image tensor laid out as (C, H, W) with C in {1, 3, 4}. A bare
            2-D (H, W) tensor is treated as single-channel. May be on any device.

    Returns:
        A CPU tensor of shape (H, W, 4) with values normalized to roughly
        [0, 1]. Three-channel input gets a fully opaque alpha channel.
        Single-channel input is first replicated into gray R, G and B.
    """
    img = min_max_norm(img)
    if img.dim() == 2:
        # Treat a bare (H, W) map as a single-channel image.
        img = img.unsqueeze(0)
    # (C, H, W) -> (H, W, C), moved to the CPU so callers can call .numpy().
    rgba_im = img.permute(1, 2, 0).cpu()
    if rgba_im.shape[2] == 1:
        # Grayscale: replicate the single channel into R, G and B.
        rgba_im = rgba_im.expand(-1, -1, 3)
    if rgba_im.shape[2] == 3:
        # Opaque alpha channel with the same dtype/device as the image.
        alpha = torch.ones_like(rgba_im[..., :1])
        rgba_im = torch.cat((rgba_im, alpha), dim=2)
    assert rgba_im.shape[2] == 4, (
        f"expected 1, 3 or 4 channels, got {rgba_im.shape[2]}"
    )
    return rgba_im
|
|
|
|
|
def numpy_to_image(img, size):
    """Convert a [0, 1]-normalized array into a resized 8-bit PIL image.

    Args:
        img: float array with values in [0, 1], presumably (H, W, 4) RGBA as
            produced elsewhere in this module.
        size: output size. An int gives a square (size, size) image; a
            (width, height) pair is passed to PIL as-is.

    Returns:
        ``PIL.Image.Image`` whose channel values are in 0..255.
    """
    # Clip so out-of-range values saturate instead of wrapping in the uint8
    # cast, and round instead of truncating (e.g. 0.9999999 -> 255, not 254).
    scaled = np.clip(np.asarray(img) * 255., 0, 255)
    pixels = np.rint(scaled).astype(np.uint8)
    if np.ndim(size) == 0:
        size = (size, size)
    return Image.fromarray(pixels).resize(tuple(int(s) for s in size))
|
|
|
def upscale_pytorch(img: np.ndarray, size):
    """Resize a channel-last numpy image with nearest-neighbour interpolation.

    Args:
        img: array laid out as (H, W, C).
        size: output spatial size, as accepted by ``torch.nn.Upsample``
            (an int for a square output, or an (H, W) pair).

    Returns:
        numpy array of shape (out_H, out_W, C).
    """
    # torch.from_numpy rejects negative strides (e.g. np.flip views); this
    # only copies when the array is not already contiguous.
    torch_img = torch.from_numpy(np.ascontiguousarray(img))
    # (H, W, C) -> (1, C, H, W): Upsample expects batched channel-first input.
    batch = torch_img.unsqueeze(0).permute(0, 3, 1, 2)
    upsampler = torch.nn.Upsample(size=size)  # default mode is 'nearest'
    return upsampler(batch)[0].permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
def heatmap(image: torch.Tensor, heatmap: torch.Tensor, size=None, alpha=.6):
    """Overlay a heat map, colored with matplotlib's ``hot`` colormap, on an image.

    Args:
        image: image tensor laid out as (C, H, W).
        heatmap: score tensor; it is min-max normalized before coloring.
            Presumably 2-D (H', W') — TODO confirm callers never pass a
            channel dimension.
        size: side length of the square output. Falsy values fall back to
            image.shape[1] (the image height).
        alpha: blend weight of the heat map; the image is weighted 1 - alpha.

    Returns:
        ``PIL.Image.Image`` of size (size, size) holding the blended pixels.
    """
    if not size:
        size = image.shape[1]

    # detach() + cpu() so tensors that require grad or live on a GPU can be
    # converted to numpy below.
    image = image.detach()
    heatmap = heatmap.detach().cpu()

    img = torch_to_rgba(image).numpy()
    # cm.hot appends an RGBA axis of floats in [0, 1].
    hm = cm.hot(min_max_norm(heatmap).numpy())

    img = np.array(numpy_to_image(img, size))
    hm = np.array(numpy_to_image(hm, size))

    blended = alpha * hm + (1 - alpha) * img
    return Image.fromarray(blended.astype(np.uint8))
|
|