myfirst / miniclip /imageWrangle.py
hen's picture
.
fc0f846
raw history blame
No virus
1.68 kB
from PIL import Image
import numpy as np
from streamlit.logger import update_formatter
import torch
from matplotlib import cm
def min_max_norm(array):
    """Linearly rescale ``array`` so its minimum maps to 0 and its maximum to ~1.

    A ``1e-10`` term is added to the value range before dividing, so an input
    whose elements are all equal yields zeros instead of a division by zero.

    The input is never modified; a new object is returned. The scaling is
    done out of place so that integer inputs are promoted to floating point
    instead of failing on an in-place multiply that cannot change dtype.

    Args:
        array: A non-empty ``torch.Tensor``. Any object exposing ``min()``,
            ``max()`` and element-wise arithmetic (e.g. a NumPy array) works
            as well.

    Returns:
        The rescaled values, with the same shape as ``array``.
    """
    low, high = array.min(), array.max()
    # Epsilon keeps the scale finite when every element is identical.
    scale = 1 / (1.e-10 + (high - low))
    return (array - low) * scale
def torch_to_rgba(img):
    """Turn a channel-first image tensor into a normalised channel-last RGBA tensor.

    The tensor is passed through ``min_max_norm``, permuted from
    ``(C, H, W)`` to ``(H, W, C)`` and moved to the CPU. If it has three
    channels, a fourth channel filled with ones (fully opaque alpha) is
    appended; a four-channel tensor is returned as is.

    Args:
        img: A 3-D ``torch.Tensor`` whose first dimension holds 3 or 4
            channels.

    Returns:
        A CPU tensor of shape ``(H, W, 4)``.

    Raises:
        AssertionError: If the result does not have exactly four channels.
    """
    normalised = min_max_norm(img)
    channels_last = normalised.permute(1, 2, 0).cpu()
    if channels_last.shape[2] == 3:
        height, width = channels_last.shape[:2]
        # new_ones inherits the image's dtype (and device), so the
        # concatenation never mixes dtypes for non-float32 images.
        opaque = channels_last.new_ones((height, width, 1))
        channels_last = torch.cat((channels_last, opaque), dim=2)
    assert channels_last.shape[2] == 4, (
        "expected 3 or 4 channels, got %d" % channels_last.shape[2]
    )
    return channels_last
def numpy_to_image(img, size):
    """Build a square ``size`` x ``size`` PIL image from a float array.

    ``img`` is expected to be an RGBA array normalised to [0, 1]. It is
    scaled by 255, cast to ``uint8``, wrapped in a PIL image and resized.

    Args:
        img: NumPy array with values in [0, 1].
        size: Edge length, in pixels, of the square output image.

    Returns:
        The resized ``PIL.Image.Image`` with [0, 255] channel values.
    """
    as_bytes = (img * 255.).astype(np.uint8)
    picture = Image.fromarray(as_bytes)
    target = (size, size)
    return picture.resize(target)
def upscale_pytorch(img: np.ndarray, size):
    """Resize a channel-last image array with ``torch.nn.Upsample``.

    The array is wrapped as a tensor, given a leading batch dimension and
    permuted from ``(H, W, C)`` to ``(1, C, H, W)``, which is the layout
    ``Upsample`` operates on. The result is permuted back to ``(H, W, C)``
    and returned as a NumPy array. ``Upsample`` is used with its default
    interpolation mode.

    Args:
        img: 3-D NumPy array in ``(H, W, C)`` layout.
        size: Target spatial size, forwarded to ``torch.nn.Upsample(size=...)``.

    Returns:
        The upsampled image as a NumPy array in ``(H, W, C)`` layout.
    """
    batch = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1, 2)
    upsampler = torch.nn.Upsample(size=size)
    return upsampler(batch)[0].permute(1, 2, 0).cpu().numpy()
def heatmap(image: torch.Tensor, heatmap: torch.Tensor, size=None, alpha=.6):
    """Overlay a heat map on an image and return the blend as a PIL image.

    The image is converted to a normalised RGBA array via ``torch_to_rgba``.
    The heat map is min-max normalised and coloured with matplotlib's ``hot``
    colormap. Both are resized to ``size`` x ``size`` with ``numpy_to_image``
    and mixed as ``alpha * heat_map + (1 - alpha) * image``.

    Args:
        image: Channel-first image tensor (see ``torch_to_rgba``).
        heatmap: Tensor of heat-map values; presumably 2-D ``(H, W)`` so the
            colormap yields an ``(H, W, 4)`` array (confirm against callers).
            Inside this function the name shadows the function itself.
        size: Edge length of the square output. When falsy (``None`` or 0),
            ``image.shape[1]`` is used.
        alpha: Weight of the heat map in the blend; the image gets
            ``1 - alpha``.

    Returns:
        The blended ``PIL.Image.Image``.
    """
    if not size:
        size = image.shape[1]

    # detach() + cpu() before numpy(): tensors that track gradients or live on
    # an accelerator cannot be converted to NumPy directly.
    img = torch_to_rgba(image).detach().cpu().numpy()  # [0...1] rgba array
    hm = cm.hot(min_max_norm(heatmap).detach().cpu().numpy())  # [0...1] rgba array

    # Round-trip through PIL to resize both layers to the same square size;
    # the results are uint8 arrays in [0...255].
    img = np.array(numpy_to_image(img, size))
    hm = np.array(numpy_to_image(hm, size))

    return Image.fromarray((alpha * hm + (1 - alpha) * img).astype(np.uint8))