import torch
import torch.nn as nn

from .base_model import BaseModel
from .fs_networks_fix import Generator_Adain_Upsample

from pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    """R1 regularizer: per-sample squared L2 norm of the gradient of d_out w.r.t. the input x_in."""
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg

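
# A minimal usage sketch of compute_grad2 as an R1 gradient penalty. It is illustrative only and is
# never called in this file; `netD`, the single-argument call, and `r1_weight` are assumptions about
# the training script, not taken from this model.
def _r1_penalty_example(netD, img_real, r1_weight=10.0):
    img_real = img_real.detach().requires_grad_(True)  # gradients must flow back to the input
    d_real = netD(img_real)                            # real-image logits
    return r1_weight * compute_grad2(d_real, img_real).mean()
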

class fsModel(BaseModel):
    """Face-swap model built on the BaseModel interface: an AdaIN upsampling generator, a projected
    discriminator (training only) and a frozen identity network loaded from opt.Arc_path."""

    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        self.isTrain = opt.isTrain

        # Generator: AdaIN-based encoder/decoder conditioned on a 512-d identity latent.
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Identity network: a fully serialized module loaded from opt.Arc_path, moved to GPU,
        # frozen and kept in eval mode; it is only used to extract identity embeddings.
        netArc_checkpoint = opt.Arc_path
        netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        self.netArc = netArc_checkpoint
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)

        # Inference only needs the generator weights.
        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return

        # Discriminator (training only): projected discriminator from pg_modules.
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        self.netD.cuda()

        if self.isTrain:
            # Feature-matching and reconstruction losses both use L1.
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # Generator optimizer.
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # Discriminator optimizer.
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # Starting point for the linear decay performed by update_learning_rate().
            self.old_lr = opt.lr

        # Optionally resume network weights and optimizer states from a previous run.
        if opt.continue_train:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        """Row-wise cosine similarity between two batches of embedding vectors."""
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

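    # A minimal sketch of how cosin_metric and the frozen netArc are typically combined into an
    # identity loss. It is illustrative only and is not called in this file; the 112x112 resize and
    # the embedding normalization are assumptions about the training script.
    def _identity_loss_example(self, img_fake, latent_id):
        img_fake_down = nn.functional.interpolate(img_fake, size=(112, 112), mode='bicubic')
        latent_fake = self.netArc(img_fake_down)                       # identity embedding of the swap
        latent_fake = nn.functional.normalize(latent_fake, p=2, dim=1)
        return (1 - self.cosin_metric(latent_fake, latent_id)).mean()  # 0 when identities match
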
    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

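    # Resume sketch (an assumption about the surrounding training script, not code in this file):
    # the weights and optimizer states written by save(epoch) are reloaded in initialize() when
    # opt.continue_train is set and opt.which_epoch names the same epoch tag.
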
    def update_fixed_params(self):
        # Legacy hook from the BaseModel training interface: rebuild the generator optimizer once
        # previously fixed layers are unfrozen. This model never defines gen_features/netE, so the
        # encoder branch is guarded to avoid an AttributeError.
        params = list(self.netG.parameters())
        if getattr(self, 'gen_features', False):
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        # Linearly decay the learning rate of both optimizers by lr / niter_decay per call.
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
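
# Epoch-loop sketch (assumptions about the surrounding training script, not part of this model):
#
#   model = fsModel()
#   model.initialize(opt)                  # opt comes from the project's option parser
#   for epoch in range(start_epoch, total_epochs + 1):
#       ...                                # per-batch G/D updates
#       if epoch > opt.niter:
#           model.update_learning_rate()   # linear decay over the final opt.niter_decay epochs
#       model.save(epoch)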