Spaces:
Runtime error
Runtime error
import torch | |
from models import MultiGPUModelWrapper | |
from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer | |
import swapae.util | |
class ClassifierOptimizer(SwappingAutoencoderOptimizer):
    """Optimizer that trains only the classifier of a swapping-autoencoder model.

    Reuses the parent optimizer's helpers (``prepare_images``,
    ``adjust_lr_if_necessary``, ``set_requires_grad``). Each training step
    updates parameters through ``self.optimizer_C`` only, while the
    parameters in ``self.Gparams`` are excluded from gradient computation.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this optimizer's command-line options on ``parser``.

        No options of its own are added; the parent optimizer's options are
        reused unchanged.

        Args:
            parser: the options parser being built (presumably an
                ``argparse.ArgumentParser`` — TODO confirm against caller).
            is_train: forwarded unchanged to the parent hook.

        Returns:
            The parser returned by
            ``SwappingAutoencoderOptimizer.modify_commandline_options``.
        """
        # Declared static because it takes no ``self``; this lets the hook be
        # called on the class or an instance, the same way the parent's hook
        # is invoked below.
        return SwappingAutoencoderOptimizer.modify_commandline_options(parser, is_train)

    def train_one_step(self, data_i, total_steps_so_far):
        """Run one classifier training step and return its losses as numpy data.

        Args:
            data_i: one batch, in whatever form ``self.prepare_images`` accepts.
            total_steps_so_far: global step counter, forwarded to
                ``self.adjust_lr_if_necessary``.

        Returns:
            The loss/metric mapping from ``train_classifier_one_step`` after
            conversion by ``swapae.util.to_numpy``.
        """
        images_minibatch, labels = self.prepare_images(data_i)
        c_losses = self.train_classifier_one_step(images_minibatch, labels)
        self.adjust_lr_if_necessary(total_steps_so_far)
        # The module imports ``swapae.util``, not a bare ``util``, so the helper
        # must be reached through the qualified name (a bare ``util`` would
        # raise NameError here).
        return swapae.util.to_numpy(c_losses)

    def train_classifier_one_step(self, images, labels):
        """Take one optimizer step on the classifier.

        Args:
            images: image batch as returned by ``self.prepare_images``.
            labels: labels as returned by ``self.prepare_images``.

        Returns:
            The losses mapping produced by the model's
            ``"compute_classifier_losses"`` command, updated in place with the
            accompanying metrics (a metric key overwrites a same-named loss).

        Raises:
            RuntimeError: if the model returns no losses, so there is nothing
                to backpropagate.
        """
        # Turn off gradients for ``self.Gparams`` so backward() does not
        # accumulate gradients for them (presumably the generator's parameters
        # — TODO confirm against the parent optimizer).
        self.set_requires_grad(self.Gparams, False)
        self.optimizer_C.zero_grad()
        losses, metrics = self.model(images, labels, command="compute_classifier_losses")
        if not losses:
            # sum() over nothing yields the int 0, which has no backward();
            # report the real problem instead of an opaque AttributeError.
            raise RuntimeError("compute_classifier_losses returned no losses to optimize")
        # Each loss value may be non-scalar (e.g. one entry per sample or per
        # device), so reduce each to its mean before summing into one scalar.
        loss = sum(v.mean() for v in losses.values())
        loss.backward()
        self.optimizer_C.step()
        losses.update(metrics)
        return losses

    def get_visuals_for_snapshot(self, data_i):
        """Return the model's snapshot visuals for one batch, without autograd.

        Args:
            data_i: one batch, in whatever form ``self.prepare_images`` accepts.

        Returns:
            Whatever the model returns for the ``"get_visuals_for_snapshot"``
            command.
        """
        images, labels = self.prepare_images(data_i)
        # Visualization only: skip building the autograd graph.
        with torch.no_grad():
            return self.model(images, labels, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
        """Delegate checkpoint saving to the wrapped model.

        Args:
            total_steps_so_far: step counter forwarded to ``self.model.save``.
        """
        self.model.save(total_steps_so_far)