text2live/models/networks.py
from torch.optim import lr_scheduler

from models.backbone.skip import skip


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler for `optimizer` according to `opt.lr_policy`."""
    if opt.lr_policy == "linear":
        # Keep the initial LR for the first `opt.n_epochs` epochs, then decay
        # it linearly to zero over the following `opt.n_epochs_decay` epochs.
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == "step":
        # Multiply the LR by 0.1 every `opt.lr_decay_iters` epochs.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == "plateau":
        # Reduce the LR 5x once the monitored metric plateaus.
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == "cosine":
        # Anneal the LR to zero along a cosine curve over `opt.n_epochs` epochs.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        raise NotImplementedError("learning rate policy [%s] is not implemented" % opt.lr_policy)
    return scheduler
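

# Usage sketch (illustrative, not part of the original file): `opt` is
# typically an argparse Namespace carrying the attributes read above; the
# values below are hypothetical placeholders, chosen only to show a
# well-formed call.
#
#     import torch
#     from argparse import Namespace
#
#     opt = Namespace(lr_policy="linear", epoch_count=1, n_epochs=100, n_epochs_decay=100)
#     optimizer = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=2e-4)
#     scheduler = get_scheduler(optimizer, opt)
#     for epoch in range(opt.n_epochs + opt.n_epochs_decay):
#         ...                # run the training iterations for this epoch
#         scheduler.step()   # then advance the LR schedule once per epoch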


def define_G(cfg):
    """Build the generator: a skip (encoder-decoder) backbone mapping a
    3-channel RGB input to a 4-channel output (an RGBA edit layer).

    Each of `skip_n33d`, `skip_n33u`, and `skip_n11` may be a single int,
    replicated across `num_scales` levels, or an explicit per-scale list.
    """
    netG = skip(
        3,  # input channels (RGB)
        4,  # output channels (RGB + alpha)
        num_channels_down=[cfg["skip_n33d"]] * cfg["num_scales"]
        if isinstance(cfg["skip_n33d"], int)
        else cfg["skip_n33d"],
        num_channels_up=[cfg["skip_n33u"]] * cfg["num_scales"]
        if isinstance(cfg["skip_n33u"], int)
        else cfg["skip_n33u"],
        num_channels_skip=[cfg["skip_n11"]] * cfg["num_scales"]
        if isinstance(cfg["skip_n11"], int)
        else cfg["skip_n11"],
        need_bias=True,
    )
    return netG
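

# Minimal smoke test (illustrative sketch, not part of the original file).
# The config keys match those read by define_G above, but the channel counts
# and number of scales are hypothetical values, not the repo's defaults.
if __name__ == "__main__":
    import torch

    cfg = {"skip_n33d": 128, "skip_n33u": 128, "skip_n11": 4, "num_scales": 5}
    netG = define_G(cfg)
    out = netG(torch.randn(1, 3, 64, 64))  # batch of one RGB image
    # Expected: torch.Size([1, 4, 64, 64]), assuming the skip backbone
    # preserves spatial resolution (input dims divisible by 2**num_scales).
    print(out.shape)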