Keiser41 commited on
Commit
80cbbd2
1 Parent(s): 8ff9b84

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +2 -1
train.py CHANGED
@@ -246,7 +246,7 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
246
 
247
 
248
 
249
- def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, data_iter, device = 'cpu'):
250
  colorizer.generator.train()
251
  discriminator.train()
252
 
@@ -265,6 +265,7 @@ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, color
265
  disc_step = disc_step ^ True
266
 
267
  if n % 10 == 5:
 
268
  fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
269
 
270
  if __name__ == '__main__':
 
246
 
247
 
248
 
249
+ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
250
  colorizer.generator.train()
251
  discriminator.train()
252
 
 
265
  disc_step = disc_step ^ True
266
 
267
  if n % 10 == 5:
268
+ data_iter = iter(ft_dataloader)
269
  fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
270
 
271
  if __name__ == '__main__':