# Page-header residue from the original capture: "Spaces: Runtime error"
from typing import Callable

import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    """Return a callable that initializes ``module.weight`` in place.

    Args:
        method: Name of the initialization scheme:
            ``'none'``      -- leave the module untouched;
            ``'he'``        -- ``kaiming_normal_`` with its default arguments;
            ``'xavier'``    -- ``xavier_normal_`` with its default gain;
            ``'dcgan'``     -- normal distribution, mean 0.0, std 0.02;
            ``'dcgan_001'`` -- normal distribution, mean 0.0, std 0.01;
            ``'zero'``      -- fill the weight with zeros.

    Returns:
        A function that takes a module, initializes its ``weight`` attribute
        (for every method except ``'none'``) and returns the same module.

    Raises:
        ValueError: raised by the *returned* function when it is called and
            ``method`` is not one of the names above (validation is deferred
            to call time, as before).
    """

    def init(module: Module) -> Module:
        if method == 'none':
            return module
        # Mutate the weight outside autograd tracking; the in-place zero_ on a
        # Parameter that requires grad needs this, and one scope keeps every
        # branch uniform.
        with torch.no_grad():
            if method == 'he':
                kaiming_normal_(module.weight)
            elif method == 'xavier':
                xavier_normal_(module.weight)
            elif method == 'dcgan':
                normal_(module.weight, 0.0, 0.02)
            elif method == 'dcgan_001':
                normal_(module.weight, 0.0, 0.01)
            elif method == 'zero':
                zero_(module.weight)
            else:
                # Previously a bare string was raised, which Python 3 turns
                # into an unrelated TypeError; raise a real exception instead.
                raise ValueError(f"Invalid initialization method {method}")
        return module

    return init
class HeInitialization:
    """Kaiming (He) normal initializer for ``module.weight``.

    Callable wrapper around ``torch.nn.init.kaiming_normal_`` that stores its
    keyword arguments so the same configuration can be applied to many modules.
    """

    def __init__(self, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        """Store the arguments forwarded to ``kaiming_normal_``.

        Args:
            a: ``a`` (negative slope) argument of ``kaiming_normal_``. It is
                annotated ``float`` because slopes are fractional (e.g. 0.2);
                ints are still accepted.
            mode: ``mode`` argument of ``kaiming_normal_``.
            nonlinearity: ``nonlinearity`` argument of ``kaiming_normal_``.
        """
        self.nonlinearity = nonlinearity
        self.mode = mode
        self.a = a

    def __call__(self, module: Module) -> Module:
        """Initialize ``module.weight`` in place and return ``module``."""
        with torch.no_grad():
            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        return module
class NormalInitialization:
    """Fill ``module.weight`` with samples from N(mean, std**2).

    Attributes:
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean = mean
        self.std = std

    def __call__(self, module: Module) -> Module:
        """Resample ``module.weight`` in place and hand back ``module``."""
        target = module.weight
        with torch.no_grad():
            normal_(target, mean=self.mean, std=self.std)
        return module
class XavierInitialization:
    """Xavier (Glorot) normal initializer for ``module.weight``.

    Attributes:
        gain: scaling factor forwarded to ``xavier_normal_``.
    """

    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        """Re-initialize ``module.weight`` in place; returns the same module."""
        target = module.weight
        with torch.no_grad():
            xavier_normal_(target, gain=self.gain)
        return module
class ZeroInitialization:
    """Initializer that fills ``module.weight`` with zeros in place."""

    def __call__(self, module: Module) -> Module:
        """Zero ``module.weight`` and return the same module.

        Args:
            module: a module exposing a ``weight`` tensor attribute.

        Returns:
            ``module`` itself, after its weight has been zeroed.
        """
        # ``torch.no_grad`` must be instantiated. The previous
        # ``with torch.no_grad:`` fails on every call because the class object
        # is not a context manager. Disabling autograd allows the in-place
        # write on a Parameter that requires grad.
        with torch.no_grad():
            zero_(module.weight)
        return module
class NoInitialization:
    """Identity initializer: returns the module it is given, unchanged."""

    def __call__(self, module: Module) -> Module:
        # Deliberately a no-op, so "no initialization" can be passed wherever
        # an initializer callable is expected.
        untouched = module
        return untouched