Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import numpy as np | |
| import torch | |
| from .base_model import BaseModel | |
| from . import networks | |
| from .patchnce import PatchNCELoss | |
| import util.util as util | |
| class CUTModel(BaseModel): | |
| """ This class implements CUT and FastCUT model, described in the paper | |
| Contrastive Learning for Unpaired Image-to-Image Translation | |
| Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu | |
| ECCV, 2020 | |
| The code borrows heavily from the PyTorch implementation of CycleGAN | |
| https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix | |
| """ | |
| def modify_commandline_options(parser, is_train=True): | |
| """ Configures options specific for CUT model | |
| """ | |
| parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)') | |
| parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN lossοΌGAN(G(X))') | |
| parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)') | |
| parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') | |
| parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') | |
| parser.add_argument('--nce_includes_all_negatives_from_minibatch', | |
| type=util.str2bool, nargs='?', const=True, default=False, | |
| help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') | |
| parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') | |
| parser.add_argument('--netF_nc', type=int, default=256) | |
| parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss') | |
| parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer') | |
| parser.add_argument('--flip_equivariance', | |
| type=util.str2bool, nargs='?', const=True, default=False, | |
| help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT") | |
| parser.set_defaults(pool_size=0) # no image pooling | |
| opt, _ = parser.parse_known_args() | |
| # Set default parameters for CUT and FastCUT | |
| if opt.CUT_mode.lower() == "cut": | |
| parser.set_defaults(nce_idt=True, lambda_NCE=1.0) | |
| elif opt.CUT_mode.lower() == "fastcut": | |
| parser.set_defaults( | |
| nce_idt=False, lambda_NCE=10.0, flip_equivariance=True, | |
| n_epochs=150, n_epochs_decay=50 | |
| ) | |
| else: | |
| raise ValueError(opt.CUT_mode) | |
| return parser | |
| def __init__(self, opt): | |
| BaseModel.__init__(self, opt) | |
| # specify the training losses you want to print out. | |
| # The training/test scripts will call <BaseModel.get_current_losses> | |
| self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE'] | |
| self.visual_names = ['real_A', 'fake_B', 'real_B'] | |
| self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')] | |
| if opt.nce_idt and self.isTrain: | |
| self.loss_names += ['NCE_Y'] | |
| self.visual_names += ['idt_B'] | |
| if self.isTrain: | |
| self.model_names = ['G', 'F', 'D'] | |
| else: # during test time, only load G | |
| self.model_names = ['G'] | |
| # define networks (both generator and discriminator) | |
| self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) | |
| self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) | |
| if self.isTrain: | |
| self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) | |
| # define loss functions | |
| self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) | |
| self.criterionNCE = [] | |
| for nce_layer in self.nce_layers: | |
| self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) | |
| self.criterionIdt = torch.nn.L1Loss().to(self.device) | |
| self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) | |
| self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) | |
| self.optimizers.append(self.optimizer_G) | |
| self.optimizers.append(self.optimizer_D) | |
| def data_dependent_initialize(self, data): | |
| """ | |
| The feature network netF is defined in terms of the shape of the intermediate, extracted | |
| features of the encoder portion of netG. Because of this, the weights of netF are | |
| initialized at the first feedforward pass with some input images. | |
| Please also see PatchSampleF.create_mlp(), which is called at the first forward() call. | |
| """ | |
| self.set_input(data) | |
| bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1) | |
| self.real_A = self.real_A[:bs_per_gpu] | |
| self.real_B = self.real_B[:bs_per_gpu] | |
| self.forward() # compute fake images: G(A) | |
| if self.opt.isTrain: | |
| self.compute_D_loss().backward() # calculate gradients for D | |
| self.compute_G_loss().backward() # calculate graidents for G | |
| if self.opt.lambda_NCE > 0.0: | |
| self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2)) | |
| self.optimizers.append(self.optimizer_F) | |
| def optimize_parameters(self): | |
| # forward | |
| self.forward() | |
| # update D | |
| self.set_requires_grad(self.netD, True) | |
| self.optimizer_D.zero_grad() | |
| self.loss_D = self.compute_D_loss() | |
| self.loss_D.backward() | |
| self.optimizer_D.step() | |
| # update G | |
| self.set_requires_grad(self.netD, False) | |
| self.optimizer_G.zero_grad() | |
| if self.opt.netF == 'mlp_sample': | |
| self.optimizer_F.zero_grad() | |
| self.loss_G = self.compute_G_loss() | |
| self.loss_G.backward() | |
| self.optimizer_G.step() | |
| if self.opt.netF == 'mlp_sample': | |
| self.optimizer_F.step() | |
| def set_input(self, input): | |
| """Unpack input data from the dataloader and perform necessary pre-processing steps. | |
| Parameters: | |
| input (dict): include the data itself and its metadata information. | |
| The option 'direction' can be used to swap domain A and domain B. | |
| """ | |
| AtoB = self.opt.direction == 'AtoB' | |
| self.real_A = input['A' if AtoB else 'B'].to(self.device) | |
| self.real_B = input['B' if AtoB else 'A'].to(self.device) | |
| self.image_paths = input['A_paths' if AtoB else 'B_paths'] | |
| def forward(self): | |
| """Run forward pass; called by both functions <optimize_parameters> and <test>.""" | |
| self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A | |
| if self.opt.flip_equivariance: | |
| self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5) | |
| if self.flipped_for_equivariance: | |
| self.real = torch.flip(self.real, [3]) | |
| self.fake = self.netG(self.real) | |
| self.fake_B = self.fake[:self.real_A.size(0)] | |
| if self.opt.nce_idt: | |
| self.idt_B = self.fake[self.real_A.size(0):] | |
| def compute_D_loss(self): | |
| """Calculate GAN loss for the discriminator""" | |
| fake = self.fake_B.detach() | |
| # Fake; stop backprop to the generator by detaching fake_B | |
| pred_fake = self.netD(fake) | |
| self.loss_D_fake = self.criterionGAN(pred_fake, False).mean() | |
| # Real | |
| self.pred_real = self.netD(self.real_B) | |
| loss_D_real = self.criterionGAN(self.pred_real, True) | |
| self.loss_D_real = loss_D_real.mean() | |
| # combine loss and calculate gradients | |
| self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 | |
| return self.loss_D | |
| def compute_G_loss(self): | |
| """Calculate GAN and NCE loss for the generator""" | |
| fake = self.fake_B | |
| # First, G(A) should fake the discriminator | |
| if self.opt.lambda_GAN > 0.0: | |
| pred_fake = self.netD(fake) | |
| self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN | |
| else: | |
| self.loss_G_GAN = 0.0 | |
| if self.opt.lambda_NCE > 0.0: | |
| self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B) | |
| else: | |
| self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0 | |
| if self.opt.nce_idt and self.opt.lambda_NCE > 0.0: | |
| self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B) | |
| loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5 | |
| else: | |
| loss_NCE_both = self.loss_NCE | |
| self.loss_G = self.loss_G_GAN + loss_NCE_both | |
| return self.loss_G | |
| def calculate_NCE_loss(self, src, tgt): | |
| n_layers = len(self.nce_layers) | |
| feat_q = self.netG(tgt, self.nce_layers, encode_only=True) | |
| if self.opt.flip_equivariance and self.flipped_for_equivariance: | |
| feat_q = [torch.flip(fq, [3]) for fq in feat_q] | |
| feat_k = self.netG(src, self.nce_layers, encode_only=True) | |
| feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None) | |
| feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids) | |
| total_nce_loss = 0.0 | |
| for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers): | |
| loss = crit(f_q, f_k) * self.opt.lambda_NCE | |
| total_nce_loss += loss.mean() | |
| return total_nce_loss / n_layers | |