| """Model class template
|
|
|
| This module provides a template for users to implement custom models.
|
| You can specify '--model template' to use this model.
|
| The class name should be consistent with both the filename and its model option.
|
The filename should be <model>_model.py

The class name should be <Model>Model
|
| It implements a simple image-to-image translation baseline based on regression loss.
|
| Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
|
| min_<netG> ||netG(data_A) - data_B||_1
|
| You need to implement the following functions:
|
| <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
|
| <__init__>: Initialize this model class.
|
| <set_input>: Unpack input data and perform data pre-processing.
|
| <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
|
| <optimize_parameters>: Update network weights; it will be called in every training iteration.
|
| """
|
| import numpy as np
|
| import torch
|
| from .base_model import BaseModel
|
| from . import networks
|
|
|
|
|
| class TemplateModel(BaseModel):
|
| @staticmethod
|
| def modify_commandline_options(parser, is_train=True):
|
| """Add new model-specific options and rewrite default values for existing options.
|
|
|
| Parameters:
|
| parser -- the option parser
|
| is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
|
|
| Returns:
|
| the modified parser.
|
| """
|
| parser.set_defaults(dataset_mode='aligned')
|
| if is_train:
|
| parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')
|
|
|
| return parser
|
|
|
| def __init__(self, opt):
|
| """Initialize this model class.
|
|
|
| Parameters:
|
| opt -- training/test options
|
|
|
| A few things can be done here.
|
| - (required) call the initialization function of BaseModel
|
| - define loss function, visualization images, model names, and optimizers
|
| """
|
| BaseModel.__init__(self, opt)
|
|
|
| self.loss_names = ['loss_G']
|
|
|
| self.visual_names = ['data_A', 'data_B', 'output']
|
|
|
|
|
| self.model_names = ['G']
|
|
|
| self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
|
| if self.isTrain:
|
|
|
|
|
| self.criterionLoss = torch.nn.L1Loss()
|
|
|
|
|
| self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
| self.optimizers = [self.optimizer]
|
|
|
|
|
|
|
| def set_input(self, input):
|
| """Unpack input data from the dataloader and perform necessary pre-processing steps.
|
|
|
| Parameters:
|
| input: a dictionary that contains the data itself and its metadata information.
|
| """
|
| AtoB = self.opt.direction == 'AtoB'
|
| self.data_A = input['A' if AtoB else 'B'].to(self.device)
|
| self.data_B = input['B' if AtoB else 'A'].to(self.device)
|
| self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
|
|
| def forward(self):
|
| """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
|
| self.output = self.netG(self.data_A)
|
|
|
| def backward(self):
|
| """Calculate losses, gradients, and update network weights; called in every training iteration"""
|
|
|
|
|
| self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
|
| self.loss_G.backward()
|
|
|
| def optimize_parameters(self):
|
| """Update network weights; it will be called in every training iteration."""
|
| self.forward()
|
| self.optimizer.zero_grad()
|
| self.backward()
|
| self.optimizer.step()
|
|
|