import torch
from torch import nn


def get_activation_name(activation):
    """Given a string or a `torch.nn.modules.activation`, return the name of the activation."""
    if isinstance(activation, str):
        return activation

    # Map activation modules to the names expected by `nn.init.calculate_gain`.
    # `nn.Softmax` is mapped to "sigmoid" because `calculate_gain` has no softmax entry.
    mapper = {nn.LeakyReLU: "leaky_relu", nn.ReLU: "relu", nn.Tanh: "tanh",
              nn.Sigmoid: "sigmoid", nn.Softmax: "sigmoid"}
    for k, v in mapper.items():
        if isinstance(activation, k):
            return v  # return the name, not the class

    raise ValueError("Unknown given activation type: {}".format(activation))
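
# Illustrative examples (not part of the original file):
#   get_activation_name(nn.LeakyReLU(0.2)) -> "leaky_relu"
#   get_activation_name("relu")            -> "relu"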


def get_gain(activation):
    """Given an object of `torch.nn.modules.activation` or an activation name,
    return the correct gain."""
    if activation is None:
        return 1

    activation_name = get_activation_name(activation)
    # Strings carry no `negative_slope` attribute, so fall back to the
    # default leaky_relu slope in that case.
    if activation_name == "leaky_relu" and not isinstance(activation, str):
        param = activation.negative_slope
    else:
        param = None
    gain = nn.init.calculate_gain(activation_name, param)
    return gain
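
# Reference gains returned by `nn.init.calculate_gain`:
#   get_gain(nn.ReLU()) -> sqrt(2) ~= 1.414
#   get_gain("tanh")    -> 5/3
#   get_gain(None)      -> 1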


def linear_init(layer, activation="relu"):
    """Initialize a linear layer.

    Args:
        layer (nn.Linear): parameters to initialize.
        activation (`torch.nn.modules.activation` or str, optional): activation that
            will be used on the `layer`.
    """
    x = layer.weight

    if activation is None:
        return nn.init.xavier_uniform_(x)

    activation_name = get_activation_name(activation)

    if activation_name == "leaky_relu":
        a = 0 if isinstance(activation, str) else activation.negative_slope
        return nn.init.kaiming_uniform_(x, a=a, nonlinearity='leaky_relu')
    elif activation_name == "relu":
        return nn.init.kaiming_uniform_(x, nonlinearity='relu')
    elif activation_name in ["sigmoid", "tanh"]:
        return nn.init.xavier_uniform_(x, gain=get_gain(activation))


def weights_init(module):
    """Initialize the weights of a module, dispatching on the layer type."""
    if isinstance(module, torch.nn.modules.conv._ConvNd):
        # TODO: check the literature for convolution-specific initialization
        linear_init(module)
    elif isinstance(module, nn.Linear):
        linear_init(module)
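

# --- Usage sketch (illustrative; the network below is a hypothetical example) ---
# `nn.Module.apply` visits every submodule recursively, so a single call
# initializes all conv and linear layers with the scheme above.
if __name__ == "__main__":
    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(16 * 30 * 30, 10),
    )
    net.apply(weights_init)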