Update train.py
Browse files
train.py
CHANGED
@@ -197,7 +197,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
197 |
for cur_disc_step in range(5):
|
198 |
discriminator.zero_grad()
|
199 |
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
201 |
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
202 |
|
203 |
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
@@ -227,7 +232,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
227 |
|
228 |
colorizer.generator.zero_grad()
|
229 |
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
231 |
bw, dfm = bw.to(device), dfm.to(device)
|
232 |
|
233 |
y_real = torch.ones((bw.size(0), 1), device = device)
|
@@ -243,8 +253,6 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
243 |
|
244 |
generator_loss.backward()
|
245 |
gen_optimizer.step()
|
246 |
-
|
247 |
-
|
248 |
|
249 |
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
250 |
colorizer.generator.train()
|
|
|
197 |
for cur_disc_step in range(5):
|
198 |
discriminator.zero_grad()
|
199 |
|
200 |
+
try:
|
201 |
+
bw, dfm, color_for_real = next(data_iter)
|
202 |
+
except StopIteration:
|
203 |
+
data_iter = iter(ft_dataloader)
|
204 |
+
bw, dfm, color_for_real = next(data_iter)
|
205 |
+
|
206 |
bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
|
207 |
|
208 |
y_real = torch.full((bw.size(0), 1), 0.9, device = device)
|
|
|
232 |
|
233 |
colorizer.generator.zero_grad()
|
234 |
|
235 |
+
try:
|
236 |
+
bw, dfm, _ = next(data_iter)
|
237 |
+
except StopIteration:
|
238 |
+
data_iter = iter(ft_dataloader)
|
239 |
+
bw, dfm, _ = next(data_iter)
|
240 |
+
|
241 |
bw, dfm = bw.to(device), dfm.to(device)
|
242 |
|
243 |
y_real = torch.ones((bw.size(0), 1), device = device)
|
|
|
253 |
|
254 |
generator_loss.backward()
|
255 |
gen_optimizer.step()
|
|
|
|
|
256 |
|
257 |
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
258 |
colorizer.generator.train()
|