import math
import tensorflow as tf
import numpy as np
import dnnlib.tflib as tflib
from functools import partial


def create_stub(name, batch_size):
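    # Zero-width placeholder tensor used as an extra (unused) input when building the synthesis graph.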
    return tf.constant(0, dtype='float32', shape=(batch_size, 0))


def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size=1):
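    # Trainable dlatent variable that will be optimized. With tiled_dlatent, a single
    # (batch, tile_size, 512) code is learned and tiled across all model_scale layers;
    # otherwise a full (batch, model_scale, 512) W+ code is learned.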
    if tiled_dlatent:
        low_dim_dlatent = tf.get_variable('learnable_dlatents',
            shape=(batch_size, tile_size, 512),
            dtype='float32',
            initializer=tf.initializers.random_normal())
        return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1])
    else:
        return tf.get_variable('learnable_dlatents',
            shape=(batch_size, model_scale, 512),
            dtype='float32',
            initializer=tf.initializers.random_normal())


class Generator:
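    # Wraps a StyleGAN network so that the synthesis input (the dlatent / W+ code) is a
    # trainable TF variable, which an optimizer can fit to a target (e.g. to embed an image).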
    def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        self.batch_size = batch_size
        self.tiled_dlatent = tiled_dlatent
        self.model_scale = int(2 * (math.log(model_res, 2) - 1))  # For example, 1024 -> 18

        # Build the synthesis graph once, substituting the trainable 'learnable_dlatents'
        # variable (and a zero-width stub) for its usual inputs via custom_inputs.
        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
                randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
                                                partial(create_stub, batch_size=batch_size)],
                structure='fixed')
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            if custom_input is not None:
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)],
                    structure='fixed')
            else:
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
                                                    partial(create_stub, batch_size=batch_size)],
                    structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.compat.v1.get_default_session()
        self.graph = tf.compat.v1.get_default_graph()

        # Locate the trainable variable created by create_variable_for_generator above,
        # and set up an assign op so dlatents can be written into it from numpy arrays.
        self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name)
        self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph")
        self._assign_dlatent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
        self.set_dlatents(self.initial_dlatents)

        def get_tensor(name):
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        # The output tensor name depends on which networks were loaded and on how the
        # graph was built, so try the known variants in order. If we loaded only Gs and
        # didn't load G or D, the scope is "G_synthesis" rather than "G_synthesis_1".
        candidate_names = [
            'G_synthesis_1/_Run/concat:0',
            'G_synthesis_1/_Run/concat/concat:0',
            'G_synthesis_1/_Run/concat_1/concat:0',
            'G_synthesis/_Run/concat:0',
            'G_synthesis/_Run/concat/concat:0',
            'G_synthesis/_Run/concat_1/concat:0',
        ]
        self.generator_output = next((t for t in map(get_tensor, candidate_names) if t is not None), None)
        if self.generator_output is None:
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
        # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
        # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)

    def reset_dlatents(self):
        self.set_dlatents(self.initial_dlatents)

    def set_dlatents(self, dlatents):
        if self.tiled_dlatent:
            # Accept either a (batch_size, 512) array or a per-layer array, which is averaged over layers.
            if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512):
                dlatents = np.mean(dlatents, axis=1)
            if (dlatents.shape != (self.batch_size, 512)):
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 512))])
            assert (dlatents.shape == (self.batch_size, 512))
        else:
            # Trim any extra layers, then zero-pad the batch dimension if needed.
            if (dlatents.shape[1] > self.model_scale):
                dlatents = dlatents[:,:self.model_scale,:]
            if (isinstance(dlatents.shape[0], int)):
                if (dlatents.shape != (self.batch_size, self.model_scale, 512)):
                    dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))])
                assert (dlatents.shape == (self.batch_size, self.model_scale, 512))
                self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})
                return
            else:
                # dlatents has a symbolic batch dimension (it is a tensor), so assign it directly.
                self._assign_dlatent = tf.assign(self.dlatent_variable, dlatents)
                return
        self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})

    def stochastic_clip_dlatents(self):
        self.sess.run(self.stochastic_clip_op)

    def get_dlatents(self):
        return self.sess.run(self.dlatent_variable)

    def get_dlatent_avg(self):
        return self.dlatent_avg

    def set_dlatent_avg(self, dlatent_avg):
        self.dlatent_avg = dlatent_avg

    def reset_dlatent_avg(self):
        self.dlatent_avg = self.dlatent_avg_def

    def generate_images(self, dlatents=None):
        if dlatents is not None:
            self.set_dlatents(dlatents)
        return self.sess.run(self.generated_image_uint8)
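

# Minimal usage sketch, assuming a StyleGAN pickle that unpacks to (G, D, Gs)
# networks; the filename and the random W+ input below are illustrative only.
#
#     import pickle
#     import numpy as np
#     import dnnlib.tflib as tflib
#
#     tflib.init_tf()
#     with open('karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
#         _G, _D, Gs = pickle.load(f)
#
#     generator = Generator(Gs, batch_size=1, randomize_noise=False)
#     generator.set_dlatents(np.random.randn(1, 18, 512))  # (batch, layers, 512) W+ code
#     images = generator.generate_images()                 # uint8, NHWC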