Keiser41 commited on
Commit
41ed247
1 Parent(s): d8e5829

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +12 -4
train.py CHANGED
@@ -197,7 +197,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
197
  for cur_disc_step in range(5):
198
  discriminator.zero_grad()
199
 
200
- bw, dfm, color_for_real = next(data_iter)
 
 
 
 
 
201
  bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
202
 
203
  y_real = torch.full((bw.size(0), 1), 0.9, device = device)
@@ -227,7 +232,12 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
227
 
228
  colorizer.generator.zero_grad()
229
 
230
- bw, dfm, _ = next(data_iter)
 
 
 
 
 
231
  bw, dfm = bw.to(device), dfm.to(device)
232
 
233
  y_real = torch.ones((bw.size(0), 1), device = device)
@@ -243,8 +253,6 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
243
 
244
  generator_loss.backward()
245
  gen_optimizer.step()
246
-
247
-
248
 
249
  def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
250
  colorizer.generator.train()
 
197
  for cur_disc_step in range(5):
198
  discriminator.zero_grad()
199
 
200
+ try:
201
+ bw, dfm, color_for_real = next(data_iter)
202
+ except StopIteration:
203
+ data_iter = iter(ft_dataloader)
204
+ bw, dfm, color_for_real = next(data_iter)
205
+
206
  bw, dfm, color_for_real = bw.to(device), dfm.to(device), color_for_real.to(device)
207
 
208
  y_real = torch.full((bw.size(0), 1), 0.9, device = device)
 
232
 
233
  colorizer.generator.zero_grad()
234
 
235
+ try:
236
+ bw, dfm, _ = next(data_iter)
237
+ except StopIteration:
238
+ data_iter = iter(ft_dataloader)
239
+ bw, dfm, _ = next(data_iter)
240
+
241
  bw, dfm = bw.to(device), dfm.to(device)
242
 
243
  y_real = torch.ones((bw.size(0), 1), device = device)
 
253
 
254
  generator_loss.backward()
255
  gen_optimizer.step()
 
 
256
 
257
  def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
258
  colorizer.generator.train()