YoonaAI committed
Commit 2b5b023
1 Parent(s): be7be52

Update apps/Normal.py

Files changed (1): apps/Normal.py (+31 -38)
apps/Normal.py CHANGED
@@ -10,13 +10,9 @@ import pytorch_lightning as pl
 torch.backends.cudnn.benchmark = True
 
 logging.getLogger("lightning").setLevel(logging.ERROR)
-import warnings
-
-warnings.filterwarnings("ignore")
 
 
 class Normal(pl.LightningModule):
-
     def __init__(self, cfg):
         super(Normal, self).__init__()
         self.cfg = cfg
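(Reviewer note: the deleted warnings.filterwarnings("ignore") silenced every warning process-wide, so dropping it is usually an improvement. If one noisy category still needs muting, a scoped sketch follows; the UserWarning choice is illustrative, not from this repo.)

import warnings

# Mute a single category, and only inside this block, rather than
# blanket-ignoring all warnings for the whole process.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    ...  # calls that emit the noisy warning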
@@ -42,28 +38,26 @@ class Normal(pl.LightningModule):
         weight_decay = self.cfg.weight_decay
         momentum = self.cfg.momentum
 
-        optim_params_N_F = [{
-            "params": self.netG.netF.parameters(),
-            "lr": self.lr_N
-        }]
-        optim_params_N_B = [{
-            "params": self.netG.netB.parameters(),
-            "lr": self.lr_N
-        }]
+        optim_params_N_F = [
+            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
+        optim_params_N_B = [
+            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]
 
-        optimizer_N_F = torch.optim.Adam(optim_params_N_F,
-                                         lr=self.lr_N,
-                                         weight_decay=weight_decay)
+        optimizer_N_F = torch.optim.Adam(
+            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
+        )
 
-        optimizer_N_B = torch.optim.Adam(optim_params_N_B,
-                                         lr=self.lr_N,
-                                         weight_decay=weight_decay)
+        optimizer_N_B = torch.optim.Adam(
+            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
+        )
 
         scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )
 
         scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma)
+            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
+        )
 
         self.schedulers = [scheduler_N_F, scheduler_N_B]
         optims = [optimizer_N_F, optimizer_N_B]
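(Reviewer note: the optims / self.schedulers pairing matches Lightning's configure_optimizers contract of parallel optimizer and scheduler lists. Two small observations: the per-group "lr" already overrides the lr= default passed to Adam, so one of the two is redundant; and momentum is read here but never passed to Adam, which has no momentum argument. A minimal self-contained sketch of the two-optimizer pattern, with toy layers and hyperparameters that are illustrative rather than this repo's values:)

import pytorch_lightning as pl
import torch


class TwoOptimizerSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.netF = torch.nn.Linear(8, 8)  # stand-in for netG.netF
        self.netB = torch.nn.Linear(8, 8)  # stand-in for netG.netB

    def configure_optimizers(self):
        opt_f = torch.optim.Adam(self.netF.parameters(), lr=2e-4)
        opt_b = torch.optim.Adam(self.netB.parameters(), lr=2e-4)
        sch_f = torch.optim.lr_scheduler.MultiStepLR(
            opt_f, milestones=[5, 10], gamma=0.1)
        sch_b = torch.optim.lr_scheduler.MultiStepLR(
            opt_b, milestones=[5, 10], gamma=0.1)
        # Parallel lists: optimizer i is paired with scheduler i.
        return [opt_f, opt_b], [sch_f, sch_b]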
@@ -78,11 +72,13 @@ class Normal(pl.LightningModule):
         for name in render_tensor.keys():
             result_list.append(
                 resize(
-                    ((render_tensor[name].cpu().numpy()[0] + 1.0) /
-                     2.0).transpose(1, 2, 0),
+                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
+                        1, 2, 0
+                    ),
                     (height, height),
                     anti_aliasing=True,
-                ))
+                )
+            )
         result_array = np.concatenate(result_list, axis=1)
 
         return result_array
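(Reviewer note: the reflowed resize call does three things at once: (x + 1.0) / 2.0 maps a [-1, 1] tensor into [0, 1], transpose(1, 2, 0) reorders CHW to HWC, and skimage resizes with anti-aliasing. A standalone sketch of the same conversion, with illustrative shapes:)

import torch
from skimage.transform import resize

tensor = torch.rand(1, 3, 64, 64) * 2 - 1    # fake render in [-1, 1], NCHW
img = (tensor.cpu().numpy()[0] + 1.0) / 2.0  # drop batch dim, map to [0, 1]
img = img.transpose(1, 2, 0)                 # CHW -> HWC for image libraries
img = resize(img, (128, 128), anti_aliasing=True)
print(img.shape)                             # (128, 128, 3)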
@@ -96,16 +92,14 @@ class Normal(pl.LightningModule):
         for name in self.in_nml:
             in_tensor[name] = batch[name]
 
-        FB_tensor = {
-            "normal_F": batch["normal_F"],
-            "normal_B": batch["normal_B"]
-        }
+        FB_tensor = {"normal_F": batch["normal_F"],
+                     "normal_B": batch["normal_B"]}
 
         self.netG.train()
 
         preds_F, preds_B = self.netG(in_tensor)
-        error_NF, error_NB = self.netG.get_norm_error(preds_F, preds_B,
-                                                      FB_tensor)
+        error_NF, error_NB = self.netG.get_norm_error(
+            preds_F, preds_B, FB_tensor)
 
         (opt_nf, opt_nb) = self.optimizers()
 
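(Reviewer note: self.optimizers() returning a pair implies manual optimization, i.e. the step drives both optimizers itself; recent Lightning also requires self.automatic_optimization = False in __init__. The body is trimmed by the diff, so the following continuation is only a hedged sketch of the usual pattern, not this repo's code:)

opt_nf, opt_nb = self.optimizers()

opt_nf.zero_grad()
opt_nb.zero_grad()
# One backward over the combined loss avoids re-traversing a graph
# shared between the two per-network losses.
self.manual_backward(error_NF + error_NB)
opt_nf.step()
opt_nb.step()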
@@ -175,19 +169,18 @@ class Normal(pl.LightningModule):
         for name in self.in_nml:
             in_tensor[name] = batch[name]
 
-        FB_tensor = {
-            "normal_F": batch["normal_F"],
-            "normal_B": batch["normal_B"]
-        }
+        FB_tensor = {"normal_F": batch["normal_F"],
+                     "normal_B": batch["normal_B"]}
 
         self.netG.train()
 
         preds_F, preds_B = self.netG(in_tensor)
-        error_NF, error_NB = self.netG.get_norm_error(preds_F, preds_B,
-                                                      FB_tensor)
+        error_NF, error_NB = self.netG.get_norm_error(
+            preds_F, preds_B, FB_tensor)
 
-        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train)
-                == 0) or (batch_idx == 0):
+        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
+            batch_idx == 0
+        ):
 
             with torch.no_grad():
                 nmlF, nmlB = self.netG(in_tensor)
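(Reviewer note: since 0 % n == 0 for any positive n, the reflowed gate is equivalent to the simpler form below whenever freq_show_train is positive; the longer spelling only matters if that invariant is not guaranteed:)

# Equivalent, assuming int(self.cfg.freq_show_train) > 0:
if batch_idx % int(self.cfg.freq_show_train) == 0:
    ...  # visualization branch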
@@ -217,4 +210,4 @@ class Normal(pl.LightningModule):
 
         tf_log = tf_log_convert(metrics_log)
 
-        return {"log": tf_log}
+        return {"log": tf_log}
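(Reviewer note: the visible text of the return line is identical on both sides, so this hunk is presumably a whitespace-only reformat. Returning {"log": tf_log} is the pre-1.0 Lightning logging convention, which newer releases removed in favor of explicit logging calls; a hedged modern equivalent, assuming tf_log is a flat dict of scalar metrics:)

# Replaces the returned {"log": ...} dict under Lightning >= 1.x:
self.log_dict(tf_log)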
 