File size: 1,400 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from models import MultiGPUModelWrapper
from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer
import swapae.util


class ClassifierOptimizer(SwappingAutoencoderOptimizer):
    """Optimizer wrapper whose training step updates only the classifier.

    Image preparation (``prepare_images``), learning-rate scheduling
    (``adjust_lr_if_necessary``) and ``set_requires_grad`` are not defined
    here and are presumably inherited from
    :class:`SwappingAutoencoderOptimizer` (TODO confirm). Within this class,
    the only optimizer that is stepped is ``self.optimizer_C``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register command-line options; currently identical to the parent's.

        Args:
            parser: Argument parser passed through to the parent class.
            is_train: Whether options are being built for training.

        Returns:
            The parser returned by
            ``SwappingAutoencoderOptimizer.modify_commandline_options``.
        """
        parser = SwappingAutoencoderOptimizer.modify_commandline_options(parser, is_train)
        return parser

    def train_one_step(self, data_i, total_steps_so_far):
        """Run one classifier training iteration.

        Args:
            data_i: One batch from the data loader; ``prepare_images`` splits
                it into an image minibatch and labels.
            total_steps_so_far: Global step counter, forwarded to
                ``adjust_lr_if_necessary``.

        Returns:
            The loss/metric dict produced by ``train_classifier_one_step``,
            converted with ``swapae.util.to_numpy``.
        """
        images_minibatch, labels = self.prepare_images(data_i)
        c_losses = self.train_classifier_one_step(images_minibatch, labels)
        self.adjust_lr_if_necessary(total_steps_so_far)
        # ``import swapae.util`` binds only the top-level name ``swapae``;
        # a bare ``util`` is not defined in this module and raised NameError.
        return swapae.util.to_numpy(c_losses)

    def train_classifier_one_step(self, images, labels):
        """Take one optimizer step on the classifier losses.

        Args:
            images: Image batch passed to the model.
            labels: Labels passed to the model alongside ``images``.

        Returns:
            The dict of loss terms returned by the model, updated in place
            with the model's metrics.
        """
        # Disable gradients for ``self.Gparams`` (presumably the generator's
        # parameters; confirm in the base class) so this step accumulates no
        # gradients for them.
        self.set_requires_grad(self.Gparams, False)
        self.optimizer_C.zero_grad()
        losses, metrics = self.model(images, labels, command="compute_classifier_losses")
        # Each term is reduced with .mean() before summing. Terms may be
        # per-sample or per-device tensors (TODO confirm).
        loss = sum(v.mean() for v in losses.values())
        loss.backward()
        self.optimizer_C.step()
        # Metrics are merged only after backward(), so they are reported
        # alongside the losses but never contribute to the optimized loss.
        losses.update(metrics)
        return losses

    def get_visuals_for_snapshot(self, data_i):
        """Return the model's snapshot visuals for ``data_i``.

        Runs under ``torch.no_grad()`` so no autograd graph is built.
        """
        images, labels = self.prepare_images(data_i)
        with torch.no_grad():
            return self.model(images, labels, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
        """Delegate saving to ``self.model.save`` with the current step count."""
        self.model.save(total_steps_so_far)