# Web-page residue from the hosting site's file view (not part of the module):
# hakansivuk's picture
# Final commit
# 087921f
# raw
# history blame contribute delete
# No virus
# 2.28 kB
import math
import yaml
import torch.nn.init as init
import torch
import numpy as np
def get_config(config):
    """Load a YAML configuration file and return its parsed content.

    Parameters:
        config (str) -- path of the YAML file to read

    Returns:
        The Python object produced by ``yaml.load`` for the document.
    """
    # SECURITY NOTE: yaml.Loader is the full (unsafe) loader and can construct
    # arbitrary Python objects. Only pass trusted config files here; if the
    # configs contain plain YAML types only, yaml.SafeLoader is the safer choice.
    #
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(config, 'r', encoding='utf-8') as stream:
        return yaml.load(stream, Loader=yaml.Loader)
def weights_init(init_type='gaussian'):
    """Build a weight-initialisation callback for a given scheme.

    The returned function takes a single module ``m`` (it is suitable for
    ``Module.apply``). It only touches modules whose class name starts with
    ``Conv`` or ``Linear`` and that have a ``weight`` attribute; for those it
    initialises the weight according to ``init_type`` and sets a non-None bias
    to zero.

    Parameters:
        init_type (str) -- one of 'gaussian', 'xavier', 'kaiming',
                           'orthogonal' or 'default' ('default' leaves the
                           weight untouched)

    Returns:
        init_fun (callable) -- function ``init_fun(m)`` that initialises ``m`` in place

    Raises:
        AssertionError -- raised by ``init_fun`` when it reaches a Conv/Linear
                          module and ``init_type`` is not a supported name
    """
    def init_fun(m):
        classname = m.__class__.__name__
        # Prefix match on the class name: Conv1d/Conv2d/ConvTranspose2d/Linear...
        if not (classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight')):
            return

        if init_type == 'gaussian':
            init.normal_(m.weight.data, 0.0, 0.02)
        elif init_type == 'xavier':
            init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
        elif init_type == 'kaiming':
            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(m.weight.data, gain=math.sqrt(2))
        elif init_type == 'default':
            pass
        else:
            # Explicit raise instead of `assert 0`: assert statements are
            # removed under `python -O`, which would silently accept a typo.
            raise AssertionError("Unsupported initialization: {}".format(init_type))

        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias.data, 0.0)

    return init_fun
def tensor2im(input_image, imtype=np.uint8, no_fg=True):
    """Converts a Tensor array into a numpy image array.

    Only the first element along dimension 0 of a tensor input is converted.
    It is transposed with axes (1, 2, 0), scaled to the 0..255 range, clipped
    to [0, 255] and cast to ``imtype``.

    Parameters:
        input_image (tensor) -- the input image tensor array; a numpy array is
                                only cast to ``imtype``, and any other object
                                is returned unchanged
        imtype (type)        -- the desired type of the converted numpy array
        no_fg (bool)         -- if True, values are mapped with (x + 1) / 2 * 255
                                (presumably for inputs in [-1, 1]); if False,
                                values are only multiplied by 255 (presumably
                                for binary / [0, 1] inputs)

    Returns:
        The converted numpy array, or ``input_image`` itself when it is
        neither a tensor nor a numpy array.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy array: no rescaling, just the dtype cast.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # Take the first output only and move it to host memory as float.
    image_numpy = input_image.detach()[0].cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    if no_fg:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    # Clip in both modes: casting an out-of-range float to an integer type
    # would otherwise wrap around (e.g. 256 -> 0 for uint8).
    image_numpy = np.clip(image_numpy, 0, 255)
    return image_numpy.astype(imtype)