# image2sketch/models/triplet_model.py
# sharazAhm890's picture
# init
# b4f7b8c verified
import torch
from .base_model import BaseModel
from . import networks
from util.image_pool import ImagePool
class TripletModel(BaseModel):
    """Model that optimizes a single network ``netG`` with a triplet margin loss.

    ``netG`` is called with three inputs (``real_A``, ``real_B``, ``real_C``)
    and returns three outputs (``x``, ``y``, ``z``). Those outputs are passed,
    in that order, to ``torch.nn.TripletMarginLoss`` as its
    (anchor, positive, negative) arguments.

    NOTE(review): the image pools, ``criterionGAN`` and ``criterionL1`` built
    in ``__init__`` and the ``--lambda_L1`` option are not used anywhere in
    this class; they are kept because other code may rely on them.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Install this model's option defaults on ``parser`` and return it.

        Parameters:
            parser   -- option parser; must support ``set_defaults`` and
                        ``add_argument``
            is_train -- when true, the training-only defaults and the
                        ``--lambda_L1`` option are registered as well

        Returns:
            The same ``parser`` object that was passed in.
        """
        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
        if not is_train:
            return parser

        parser.set_defaults(pool_size=0, gan_mode='vanilla')
        parser.add_argument(
            '--lambda_L1',
            type=float,
            default=100.0,
            help='weight for L1 loss',
        )
        return parser

    def __init__(self, opt):
        """Build the network and, when training, the losses and optimizer.

        Parameters:
            opt -- option object; this method reads ``ngf``, ``netG``,
                   ``norm``, ``no_dropout``, ``init_type``, ``init_gain`` and,
                   in training mode, ``pool_size``, ``gan_mode``, ``lr`` and
                   ``beta1``
        """
        BaseModel.__init__(self, opt)

        # Name lists; presumably consumed by BaseModel for loss reporting,
        # visual output and checkpointing -- confirm against base_model.
        self.loss_names = ['G_triplet']
        self.visual_names = ['x', 'y']
        # Same single entry whether or not the model is in training mode.
        self.model_names = ['G']

        # The two leading 1s are presumably the input/output channel counts
        # expected by networks.define_G -- TODO confirm.
        use_dropout = not opt.no_dropout
        self.netG = networks.define_G(
            1,
            1,
            opt.ngf,
            opt.netG,
            opt.norm,
            use_dropout,
            opt.init_type,
            opt.init_gain,
            self.gpu_ids,
        )

        if not self.isTrain:
            return

        # Everything below is only set up for training.
        self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images

        self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
        self.criterionL1 = torch.nn.L1Loss()
        self.triplet = torch.nn.TripletMarginLoss(margin=3.0)

        adam_betas = (opt.beta1, 0.999)
        self.optimizer_G = torch.optim.Adam(
            self.netG.parameters(), lr=opt.lr, betas=adam_betas
        )
        self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Move one batch onto ``self.device`` and store it on the model.

        Parameters:
            input (dict) -- must contain the keys 'A', 'B' and 'C' (values
                            supporting ``.to(device)``) plus 'A_paths' or
                            'B_paths', whichever matches the source side

        When ``opt.direction`` is 'AtoB', ``real_A`` comes from 'A' and
        ``real_B`` from 'B'; for any other direction the two are swapped.
        ``real_C`` always comes from 'C'.
        """
        if self.opt.direction == 'AtoB':
            source_key, target_key = 'A', 'B'
        else:
            source_key, target_key = 'B', 'A'

        self.real_A = input[source_key].to(self.device)
        self.real_B = input[target_key].to(self.device)
        self.real_C = input['C'].to(self.device)
        # Paths are taken from the same side that feeds real_A.
        self.image_paths = input[source_key + '_paths']

    def forward(self):
        """Run ``netG`` on the stored batch and keep its three outputs."""
        outputs = self.netG(self.real_A, self.real_B, self.real_C)
        self.x, self.y, self.z = outputs

    def backward_G(self):
        """Compute the triplet loss on ``x``, ``y``, ``z`` and backpropagate.

        ``x`` is passed as the anchor, ``y`` as the positive and ``z`` as the
        negative of ``TripletMarginLoss``. The same loss tensor is stored
        under ``loss_G_triplet_1``, ``loss_G_triplet`` and ``loss_G``.
        """
        triplet_loss = self.triplet(self.x, self.y, self.z)
        self.loss_G_triplet_1 = triplet_loss
        self.loss_G_triplet = triplet_loss
        self.loss_G = triplet_loss
        triplet_loss.backward()

    def optimize_parameters(self):
        """Perform one optimizer step on ``netG``.

        NOTE(review): ``forward()`` is not called here, yet ``backward_G``
        reads ``self.x``/``self.y``/``self.z``. The caller presumably runs
        ``forward()`` first -- confirm against the training loop.
        """
        optimizer = self.optimizer_G
        optimizer.zero_grad()  # clear gradients left from the previous step
        self.backward_G()      # fills gradients for netG
        optimizer.step()       # apply the update