hieupt's picture
Update utils.py
d989b7f
raw
history blame contribute delete
905 Bytes
import torch
from PIL import Image
import numpy as np
# Per-channel normalization statistics, one entry per channel (3 channels).
# deprocess() hands these to UnNormalize to undo the normalization;
# presumably they are the dataset mean/std used at preprocessing time.
mean = [
    0.4763,
    0.4507,
    0.4094,
]
std = [
    0.2702,
    0.2652,
    0.2811,
]
class UnNormalize(object):
    """Undo per-channel normalization ``t = (x - m) / s`` in place.

    Every channel ``t`` is mapped back to ``t * s + m``, i.e. the inverse of
    ``t.sub_(m).div_(s)``.
    """

    def __init__(self, mean, std):
        """
        Args:
            mean (sequence of float): per-channel means used to normalize.
            std (sequence of float): per-channel standard deviations used to
                normalize.
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Un-normalize ``tensor`` in place.

        Args:
            tensor (Tensor): normalized image of size (C, H, W), or a batch of
                such images of size (N, C, H, W).

        Returns:
            Tensor: the same tensor object, un-normalized in place.
        """
        if tensor.dim() == 4:
            # Batched input: without this, zip() below would pair batch items
            # (not channels) with the per-channel stats. Iterating yields
            # views, so the in-place ops still write through to ``tensor``.
            for image in tensor:
                self(image)
            return tensor
        # zip() stops at the shortest input: channels beyond len(self.mean)
        # are left untouched.
        for channel, m, s in zip(tensor, self.mean, self.std):
            channel.mul_(s).add_(m)
        return tensor
def deprocess(image_tensor):
    """Denormalize an image tensor and convert it to an (H, W, C) uint8 array.

    The module-level ``mean``/``std`` normalization is undone with
    ``UnNormalize``. The result is then scaled by 255, clamped to [0, 255] and
    cast to ``np.uint8`` (the cast truncates fractional values).

    Args:
        image_tensor (Tensor): normalized image of shape (C, H, W). It is not
            modified.

    Returns:
        np.ndarray: array of shape (H, W, C) with dtype ``np.uint8``.
    """
    # Work on a detached CPU copy. UnNormalize and ``*=`` operate in place and
    # would otherwise overwrite the caller's tensor. ``.numpy()`` also rejects
    # tensors that require grad or live on a non-CPU device.
    img = image_tensor.detach().cpu().clone()
    UnNormalize(mean=mean, std=std)(img)
    img *= 255
    image_np = torch.clamp(img, 0, 255).numpy().astype(np.uint8)
    # Channels-first (C, H, W) -> channels-last (H, W, C).
    return image_np.transpose(1, 2, 0)