Spaces:
Runtime error
Runtime error
import torch | |
from torch.nn import init | |
def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized; every submodule is visited via ``net.apply``
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal (not used by kaiming).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Submodules whose class name contains "Conv" or "Linear" get their weight drawn from the
    chosen method and their bias (if present) set to 0. Submodules whose class name contains
    "BatchNorm2d" get weight ~ N(1.0, init_gain) and bias = 0; when such a layer has no
    affine parameters (weight/bias are None) it is left untouched. All other submodules keep
    their existing initialization.

    Raises:
        NotImplementedError -- if ``init_type`` is not one of the supported names. This is
                               checked before any module is touched.
    """
    weight_initializers = {
        "normal": lambda w: init.normal_(w, 0.0, init_gain),
        "xavier": lambda w: init.xavier_normal_(w, gain=init_gain),
        "kaiming": lambda w: init.kaiming_normal_(w, a=0, mode="fan_in"),
        "orthogonal": lambda w: init.orthogonal_(w, gain=init_gain),
    }
    # Fail fast: validate the method name even if the net has no Conv/Linear layers,
    # so a typo is never silently ignored.
    if init_type not in weight_initializers:
        raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
    init_weight = weight_initializers[init_type]

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if weight is not None and ("Conv" in classname or "Linear" in classname):
            init_weight(weight.data)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm's weight is not a matrix; only the normal distribution applies.
            # BatchNorm2d(affine=False) has weight/bias set to None, so guard both.
            # NOTE: only BatchNorm2d is matched; BatchNorm1d/3d keep PyTorch's defaults.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. move it to a GPU if requested; 2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., [0, 1, 2]. Only the first id
                              is used. When None or empty, the network is not moved.

    Return the initialized network (the same object that was passed in).

    Raises:
        RuntimeError -- if ``gpu_ids`` is non-empty but CUDA is not available.
        NotImplementedError -- may propagate from ``init_weights`` for an unsupported ``init_type``.
    """
    # None instead of a mutable [] default; an empty list behaves the same way.
    if gpu_ids:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError("gpu_ids=%s requested but CUDA is not available" % (gpu_ids,))
        net.to(gpu_ids[0])
        # NOTE: no multi-GPU wrapping (torch.nn.DataParallel) is performed; only gpu_ids[0] is used.
    init_weights(net, init_type, init_gain=init_gain)
    return net