TextureScraping / swapae /optimizers /classifier_optimizer.py
sunshineatnoon
Add application file
1b2a9b1
raw
history blame
No virus
1.4 kB
import torch
from models import MultiGPUModelWrapper
from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer
import swapae.util
class ClassifierOptimizer(SwappingAutoencoderOptimizer):
@staticmethod
def modify_commandline_options(parser, is_train):
parser = SwappingAutoencoderOptimizer.modify_commandline_options(parser, is_train)
return parser
def train_one_step(self, data_i, total_steps_so_far):
images_minibatch, labels = self.prepare_images(data_i)
c_losses = self.train_classifier_one_step(images_minibatch, labels)
self.adjust_lr_if_necessary(total_steps_so_far)
return util.to_numpy(c_losses)
def train_classifier_one_step(self, images, labels):
self.set_requires_grad(self.Gparams, False)
self.optimizer_C.zero_grad()
losses, metrics = self.model(images, labels, command="compute_classifier_losses")
loss = sum([v.mean() for v in losses.values()])
loss.backward()
self.optimizer_C.step()
losses.update(metrics)
return losses
def get_visuals_for_snapshot(self, data_i):
images, labels = self.prepare_images(data_i)
with torch.no_grad():
return self.model(images, labels, command="get_visuals_for_snapshot")
def save(self, total_steps_so_far):
self.model.save(total_steps_so_far)