Spaces:
Build error
Build error
Update apps/Normal.py
Browse files- apps/Normal.py +31 -38
apps/Normal.py
CHANGED
@@ -10,13 +10,9 @@ import pytorch_lightning as pl
|
|
10 |
torch.backends.cudnn.benchmark = True
|
11 |
|
12 |
logging.getLogger("lightning").setLevel(logging.ERROR)
|
13 |
-
import warnings
|
14 |
-
|
15 |
-
warnings.filterwarnings("ignore")
|
16 |
|
17 |
|
18 |
class Normal(pl.LightningModule):
|
19 |
-
|
20 |
def __init__(self, cfg):
|
21 |
super(Normal, self).__init__()
|
22 |
self.cfg = cfg
|
@@ -42,28 +38,26 @@ class Normal(pl.LightningModule):
|
|
42 |
weight_decay = self.cfg.weight_decay
|
43 |
momentum = self.cfg.momentum
|
44 |
|
45 |
-
optim_params_N_F = [
|
46 |
-
"params": self.netG.netF.parameters(),
|
47 |
-
|
48 |
-
|
49 |
-
optim_params_N_B = [{
|
50 |
-
"params": self.netG.netB.parameters(),
|
51 |
-
"lr": self.lr_N
|
52 |
-
}]
|
53 |
|
54 |
-
optimizer_N_F = torch.optim.Adam(
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
optimizer_N_B = torch.optim.Adam(
|
59 |
-
|
60 |
-
|
61 |
|
62 |
scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
|
63 |
-
optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
|
|
64 |
|
65 |
scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
|
66 |
-
optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
|
|
67 |
|
68 |
self.schedulers = [scheduler_N_F, scheduler_N_B]
|
69 |
optims = [optimizer_N_F, optimizer_N_B]
|
@@ -78,11 +72,13 @@ class Normal(pl.LightningModule):
|
|
78 |
for name in render_tensor.keys():
|
79 |
result_list.append(
|
80 |
resize(
|
81 |
-
((render_tensor[name].cpu().numpy()[0] + 1.0) /
|
82 |
-
|
|
|
83 |
(height, height),
|
84 |
anti_aliasing=True,
|
85 |
-
)
|
|
|
86 |
result_array = np.concatenate(result_list, axis=1)
|
87 |
|
88 |
return result_array
|
@@ -96,16 +92,14 @@ class Normal(pl.LightningModule):
|
|
96 |
for name in self.in_nml:
|
97 |
in_tensor[name] = batch[name]
|
98 |
|
99 |
-
FB_tensor = {
|
100 |
-
|
101 |
-
"normal_B": batch["normal_B"]
|
102 |
-
}
|
103 |
|
104 |
self.netG.train()
|
105 |
|
106 |
preds_F, preds_B = self.netG(in_tensor)
|
107 |
-
error_NF, error_NB = self.netG.get_norm_error(
|
108 |
-
|
109 |
|
110 |
(opt_nf, opt_nb) = self.optimizers()
|
111 |
|
@@ -175,19 +169,18 @@ class Normal(pl.LightningModule):
|
|
175 |
for name in self.in_nml:
|
176 |
in_tensor[name] = batch[name]
|
177 |
|
178 |
-
FB_tensor = {
|
179 |
-
|
180 |
-
"normal_B": batch["normal_B"]
|
181 |
-
}
|
182 |
|
183 |
self.netG.train()
|
184 |
|
185 |
preds_F, preds_B = self.netG(in_tensor)
|
186 |
-
error_NF, error_NB = self.netG.get_norm_error(
|
187 |
-
|
188 |
|
189 |
-
if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train)
|
190 |
-
|
|
|
191 |
|
192 |
with torch.no_grad():
|
193 |
nmlF, nmlB = self.netG(in_tensor)
|
@@ -217,4 +210,4 @@ class Normal(pl.LightningModule):
|
|
217 |
|
218 |
tf_log = tf_log_convert(metrics_log)
|
219 |
|
220 |
-
return {"log": tf_log}
|
|
|
10 |
torch.backends.cudnn.benchmark = True
|
11 |
|
12 |
logging.getLogger("lightning").setLevel(logging.ERROR)
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
class Normal(pl.LightningModule):
|
|
|
16 |
def __init__(self, cfg):
|
17 |
super(Normal, self).__init__()
|
18 |
self.cfg = cfg
|
|
|
38 |
weight_decay = self.cfg.weight_decay
|
39 |
momentum = self.cfg.momentum
|
40 |
|
41 |
+
optim_params_N_F = [
|
42 |
+
{"params": self.netG.netF.parameters(), "lr": self.lr_N}]
|
43 |
+
optim_params_N_B = [
|
44 |
+
{"params": self.netG.netB.parameters(), "lr": self.lr_N}]
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
optimizer_N_F = torch.optim.Adam(
|
47 |
+
optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
|
48 |
+
)
|
49 |
|
50 |
+
optimizer_N_B = torch.optim.Adam(
|
51 |
+
optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
|
52 |
+
)
|
53 |
|
54 |
scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
|
55 |
+
optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
56 |
+
)
|
57 |
|
58 |
scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
|
59 |
+
optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
60 |
+
)
|
61 |
|
62 |
self.schedulers = [scheduler_N_F, scheduler_N_B]
|
63 |
optims = [optimizer_N_F, optimizer_N_B]
|
|
|
72 |
for name in render_tensor.keys():
|
73 |
result_list.append(
|
74 |
resize(
|
75 |
+
((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
|
76 |
+
1, 2, 0
|
77 |
+
),
|
78 |
(height, height),
|
79 |
anti_aliasing=True,
|
80 |
+
)
|
81 |
+
)
|
82 |
result_array = np.concatenate(result_list, axis=1)
|
83 |
|
84 |
return result_array
|
|
|
92 |
for name in self.in_nml:
|
93 |
in_tensor[name] = batch[name]
|
94 |
|
95 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
96 |
+
"normal_B": batch["normal_B"]}
|
|
|
|
|
97 |
|
98 |
self.netG.train()
|
99 |
|
100 |
preds_F, preds_B = self.netG(in_tensor)
|
101 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
102 |
+
preds_F, preds_B, FB_tensor)
|
103 |
|
104 |
(opt_nf, opt_nb) = self.optimizers()
|
105 |
|
|
|
169 |
for name in self.in_nml:
|
170 |
in_tensor[name] = batch[name]
|
171 |
|
172 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
173 |
+
"normal_B": batch["normal_B"]}
|
|
|
|
|
174 |
|
175 |
self.netG.train()
|
176 |
|
177 |
preds_F, preds_B = self.netG(in_tensor)
|
178 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
179 |
+
preds_F, preds_B, FB_tensor)
|
180 |
|
181 |
+
if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
|
182 |
+
batch_idx == 0
|
183 |
+
):
|
184 |
|
185 |
with torch.no_grad():
|
186 |
nmlF, nmlB = self.netG(in_tensor)
|
|
|
210 |
|
211 |
tf_log = tf_log_convert(metrics_log)
|
212 |
|
213 |
+
return {"log": tf_log}
|