|
|
import torch |
|
|
from torch.nn import init |
|
|
|
|
|
|
|
|
def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights in place.

    Every submodule of `net` is visited via `net.apply`:

    * Modules whose class name contains "Conv" or "Linear" and that carry a
      weight tensor get their weight initialized with the method selected by
      `init_type`; their bias (if present) is set to zero.
    * Modules whose class name contains "BatchNorm2d" get their weight drawn
      from N(1.0, init_gain) and their bias set to zero. Layers built with
      `affine=False` have no weight/bias and are left untouched.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
                             (unused by kaiming).

    Raises:
        NotImplementedError -- if `init_type` is not one of the names above
                               and a Conv/Linear layer is encountered.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """

    def init_func(m):
        classname = m.__class__.__name__
        # `hasattr` alone is not enough: the attribute may exist but be None
        # (e.g. BatchNorm2d(affine=False)), so check the value itself.
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)

        if weight is not None and ("Conv" in classname or "Linear" in classname):
            if init_type == "normal":
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # The BatchNorm weight is a per-channel scale vector, not a matrix,
            # so it is always drawn from a normal distribution around 1.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)
|
|
|
|
|
|
|
|
def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. move it to the requested GPU (if any); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2.
                              None or an empty list leaves the network on its
                              current device. Only the first id is used here;
                              the network is not wrapped for multi-GPU execution.

    Raises:
        AssertionError -- if `gpu_ids` is non-empty but CUDA is not available.

    Return an initialized network (the same object that was passed in).
    """
    # `None` replaces the former mutable `[]` default; both mean "no GPU requested".
    if gpu_ids is not None and len(gpu_ids) > 0:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not torch.cuda.is_available():
            raise AssertionError("gpu_ids=%r was requested but CUDA is not available" % (gpu_ids,))
        net.to(gpu_ids[0])

    init_weights(net, init_type, init_gain=init_gain)
    return net
|
|
|