Spaces:
Runtime error
Runtime error
File size: 1,400 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from models import MultiGPUModelWrapper
from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer
import swapae.util
class ClassifierOptimizer(SwappingAutoencoderOptimizer):
    """Optimizer whose training step updates only the classifier.

    Each training step freezes the generator parameters (``self.Gparams``),
    computes the classifier losses through ``self.model`` and applies one
    ``self.optimizer_C`` update.

    Several names are not defined in this class and presumably come from the
    parent optimizer or from setup code elsewhere: ``prepare_images``,
    ``set_requires_grad``, ``adjust_lr_if_necessary``, ``self.model`` and
    ``self.Gparams``.

    NOTE(review): ``self.optimizer_C`` is never created in this class. Confirm
    that the parent (or other setup code) builds it before
    ``train_one_step`` runs; otherwise this raises ``AttributeError``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register command-line options.

        This class adds no options of its own. It delegates to
        ``SwappingAutoencoderOptimizer.modify_commandline_options`` and
        returns the resulting parser.
        """
        parser = SwappingAutoencoderOptimizer.modify_commandline_options(parser, is_train)
        return parser

    def train_one_step(self, data_i, total_steps_so_far):
        """Run one classifier training step.

        Args:
            data_i: One batch from the data loader. ``self.prepare_images``
                unpacks it into ``(images, labels)``.
            total_steps_so_far: Global step counter, forwarded to
                ``self.adjust_lr_if_necessary``.

        Returns:
            The losses/metrics dict from ``train_classifier_one_step``,
            converted with ``swapae.util.to_numpy``.
        """
        images_minibatch, labels = self.prepare_images(data_i)
        c_losses = self.train_classifier_one_step(images_minibatch, labels)
        self.adjust_lr_if_necessary(total_steps_so_far)
        # The module does ``import swapae.util``, which binds only the name
        # ``swapae``. The bare ``util`` used here previously raised NameError.
        return swapae.util.to_numpy(c_losses)

    def train_classifier_one_step(self, images, labels):
        """Take a single ``optimizer_C`` step on the classifier losses.

        Args:
            images: Image batch passed to ``self.model``.
            labels: Labels matching ``images``, passed to ``self.model``.

        Returns:
            dict: The loss dict returned by the model's
            ``"compute_classifier_losses"`` command, updated in place with
            the accompanying metrics dict.
        """
        # Freeze the generator parameters so backward() does not compute
        # gradients for them.
        self.set_requires_grad(self.Gparams, False)
        self.optimizer_C.zero_grad()
        losses, metrics = self.model(images, labels, command="compute_classifier_losses")
        # Each loss term is reduced with .mean() before the terms are summed
        # into a single scalar. Presumably the terms may be non-scalar
        # (e.g. gathered across GPUs); verify against the model wrapper.
        loss = sum(v.mean() for v in losses.values())
        loss.backward()
        self.optimizer_C.step()
        # Merge metrics into the returned dict so callers log both together.
        losses.update(metrics)
        return losses

    def get_visuals_for_snapshot(self, data_i):
        """Return the model's snapshot visuals for one batch.

        The model is run with the ``"get_visuals_for_snapshot"`` command
        under ``torch.no_grad()``, so no autograd graph is recorded.
        """
        images, labels = self.prepare_images(data_i)
        with torch.no_grad():
            return self.model(images, labels, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
        """Delegate checkpoint saving to ``self.model.save``."""
        self.model.save(total_steps_so_far)
|