nikunjkdtechnoland commited on
Commit
575c901
1 Parent(s): 518eb4f

trainer add

Browse files
Files changed (1) hide show
  1. trainer.py +155 -0
trainer.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import autograd
5
+ from model.networks import Generator, LocalDis, GlobalDis
6
+
7
+
8
+ from utils.tools import get_model_list, local_patch, spatial_discounting_mask
9
+ from utils.logger import get_logger
10
+
11
+ logger = get_logger()
12
+
13
+
14
+ class Trainer(nn.Module):
15
+ def __init__(self, config):
16
+ super(Trainer, self).__init__()
17
+ self.config = config
18
+ self.use_cuda = self.config['cuda']
19
+ self.device_ids = self.config['gpu_ids']
20
+
21
+ self.netG = Generator(self.config['netG'], self.use_cuda, self.device_ids)
22
+ self.localD = LocalDis(self.config['netD'], self.use_cuda, self.device_ids)
23
+ self.globalD = GlobalDis(self.config['netD'], self.use_cuda, self.device_ids)
24
+
25
+ self.optimizer_g = torch.optim.Adam(self.netG.parameters(), lr=self.config['lr'],
26
+ betas=(self.config['beta1'], self.config['beta2']))
27
+ d_params = list(self.localD.parameters()) + list(self.globalD.parameters())
28
+ self.optimizer_d = torch.optim.Adam(d_params, lr=config['lr'],
29
+ betas=(self.config['beta1'], self.config['beta2']))
30
+ if self.use_cuda:
31
+ self.netG.to(self.device_ids[0])
32
+ self.localD.to(self.device_ids[0])
33
+ self.globalD.to(self.device_ids[0])
34
+
35
+ def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
36
+ self.train()
37
+ l1_loss = nn.L1Loss()
38
+ losses = {}
39
+
40
+ x1, x2, offset_flow = self.netG(x, masks)
41
+ local_patch_gt = local_patch(ground_truth, bboxes)
42
+ x1_inpaint = x1 * masks + x * (1. - masks)
43
+ x2_inpaint = x2 * masks + x * (1. - masks)
44
+ local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
45
+ local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)
46
+
47
+ # D part
48
+ # wgan d loss
49
+ local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
50
+ self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
51
+ global_real_pred, global_fake_pred = self.dis_forward(
52
+ self.globalD, ground_truth, x2_inpaint.detach())
53
+ losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
54
+ torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
55
+ # gradients penalty loss
56
+ local_penalty = self.calc_gradient_penalty(
57
+ self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
58
+ global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth, x2_inpaint.detach())
59
+ losses['wgan_gp'] = local_penalty + global_penalty
60
+
61
+ # G part
62
+ if compute_loss_g:
63
+ sd_mask = spatial_discounting_mask(self.config)
64
+ losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
65
+ self.config['coarse_l1_alpha'] + \
66
+ l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
67
+ losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
68
+ self.config['coarse_l1_alpha'] + \
69
+ l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))
70
+
71
+ # wgan g loss
72
+ local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
73
+ self.localD, local_patch_gt, local_patch_x2_inpaint)
74
+ global_real_pred, global_fake_pred = self.dis_forward(
75
+ self.globalD, ground_truth, x2_inpaint)
76
+ losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
77
+ torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']
78
+
79
+ return losses, x2_inpaint, offset_flow
80
+
81
+ def dis_forward(self, netD, ground_truth, x_inpaint):
82
+ assert ground_truth.size() == x_inpaint.size()
83
+ batch_size = ground_truth.size(0)
84
+ batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
85
+ batch_output = netD(batch_data)
86
+ real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)
87
+
88
+ return real_pred, fake_pred
89
+
90
+ # Calculate gradient penalty
91
+ def calc_gradient_penalty(self, netD, real_data, fake_data):
92
+ batch_size = real_data.size(0)
93
+ alpha = torch.rand(batch_size, 1, 1, 1)
94
+ alpha = alpha.expand_as(real_data)
95
+ if self.use_cuda:
96
+ alpha = alpha.cuda()
97
+
98
+ interpolates = alpha * real_data + (1 - alpha) * fake_data
99
+ interpolates = interpolates.requires_grad_().clone()
100
+
101
+ disc_interpolates = netD(interpolates)
102
+ grad_outputs = torch.ones(disc_interpolates.size())
103
+
104
+ if self.use_cuda:
105
+ grad_outputs = grad_outputs.cuda()
106
+
107
+ gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
108
+ grad_outputs=grad_outputs, create_graph=True,
109
+ retain_graph=True, only_inputs=True)[0]
110
+
111
+ gradients = gradients.view(batch_size, -1)
112
+ gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
113
+
114
+ return gradient_penalty
115
+
116
+ def inference(self, x, masks):
117
+ self.eval()
118
+ x1, x2, offset_flow = self.netG(x, masks)
119
+ # x1_inpaint = x1 * masks + x * (1. - masks)
120
+ x2_inpaint = x2 * masks + x * (1. - masks)
121
+
122
+ return x2_inpaint, offset_flow
123
+
124
+ def save_model(self, checkpoint_dir, iteration):
125
+ # Save generators, discriminators, and optimizers
126
+ gen_name = os.path.join(checkpoint_dir, 'gen_%08d.pt' % iteration)
127
+ dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iteration)
128
+ opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
129
+ torch.save(self.netG.state_dict(), gen_name)
130
+ torch.save({'localD': self.localD.state_dict(),
131
+ 'globalD': self.globalD.state_dict()}, dis_name)
132
+ torch.save({'gen': self.optimizer_g.state_dict(),
133
+ 'dis': self.optimizer_d.state_dict()}, opt_name)
134
+
135
+ def resume(self, checkpoint_dir, iteration=0, test=False):
136
+ # Load generators
137
+ last_model_name = get_model_list(checkpoint_dir, "gen", iteration=iteration)
138
+ self.netG.load_state_dict(torch.load(last_model_name))
139
+ iteration = int(last_model_name[-11:-3])
140
+
141
+ if not test:
142
+ # Load discriminators
143
+ last_model_name = get_model_list(checkpoint_dir, "dis", iteration=iteration)
144
+ state_dict = torch.load(last_model_name)
145
+ self.localD.load_state_dict(state_dict['localD'])
146
+ self.globalD.load_state_dict(state_dict['globalD'])
147
+ # Load optimizers
148
+ state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
149
+ self.optimizer_d.load_state_dict(state_dict['dis'])
150
+ self.optimizer_g.load_state_dict(state_dict['gen'])
151
+
152
+ print("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
153
+ logger.info("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
154
+
155
+ return iteration