Update train.py
Browse files
train.py
CHANGED
@@ -246,7 +246,7 @@ def fine_tuning_step(data_iter, colorizer, discriminator, gen_optimizer, disc_op
|
|
246 |
|
247 |
|
248 |
|
249 |
-
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer,
|
250 |
colorizer.generator.train()
|
251 |
discriminator.train()
|
252 |
|
@@ -265,6 +265,7 @@ def fine_tuning(colorizer, discriminator, content, dataloader, iterations, color
|
|
265 |
disc_step = disc_step ^ True
|
266 |
|
267 |
if n % 10 == 5:
|
|
|
268 |
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
269 |
|
270 |
if __name__ == '__main__':
|
|
|
246 |
|
247 |
|
248 |
|
249 |
+
def fine_tuning(colorizer, discriminator, content, dataloader, iterations, colorizer_optimizer, discriminator_optimizer, ft_dataloader, device = 'cpu'):
|
250 |
colorizer.generator.train()
|
251 |
discriminator.train()
|
252 |
|
|
|
265 |
disc_step = disc_step ^ True
|
266 |
|
267 |
if n % 10 == 5:
|
268 |
+
data_iter = iter(ft_dataloader)
|
269 |
fine_tuning_step(data_iter, colorizer, discriminator, colorizer_optimizer, discriminator_optimizer, device)
|
270 |
|
271 |
if __name__ == '__main__':
|