File size: 7,088 Bytes
db534ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import math
import tensorflow as tf
import numpy as np
import dnnlib.tflib as tflib
from functools import partial
def create_stub(name, batch_size):
    """Return an empty float32 constant of shape ``(batch_size, 0)``.

    Used in this file as the second entry of the synthesis network's
    ``custom_inputs`` list, where no real data is needed.

    Args:
        name: accepted but unused (presumably passed by the custom-input
            hook that invokes this builder — confirm).
        batch_size: leading dimension of the returned constant.
    """
    empty_shape = (batch_size, 0)
    return tf.constant(0, dtype='float32', shape=empty_shape)
def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size=1):
    """Create the trainable ``learnable_dlatents`` variable used as generator input.

    Args:
        name: accepted but unused (presumably supplied by the synthesis
            network's custom-input hook — confirm).
        batch_size: number of dlatent rows (images per batch).
        tiled_dlatent: if True, learn only ``tile_size`` 512-d vectors per row
            and tile them along axis 1 to cover ``model_scale`` layers.
        model_scale: number of dlatent layers the synthesis network expects.
        tile_size: number of distinct learned vectors per row in tiled mode.

    Returns:
        A float32 tensor of shape ``(batch_size, model_scale, 512)``. In
        non-tiled mode this is the variable itself. In tiled mode it is a
        ``(batch_size, tile_size, 512)`` variable repeated
        ``model_scale // tile_size`` times along axis 1.

    Raises:
        ValueError: in tiled mode, if ``tile_size`` is not a positive divisor
            of ``model_scale``. Otherwise the tiled result would silently have
            the wrong number of layers (or divide by zero).
    """
    if not tiled_dlatent:
        return tf.get_variable('learnable_dlatents',
                               shape=(batch_size, model_scale, 512),
                               dtype='float32',
                               initializer=tf.initializers.random_normal())

    if tile_size <= 0 or model_scale % tile_size != 0:
        raise ValueError(
            'tile_size (%r) must be a positive divisor of model_scale (%r)'
            % (tile_size, model_scale))
    low_dim_dlatent = tf.get_variable('learnable_dlatents',
                                      shape=(batch_size, tile_size, 512),
                                      dtype='float32',
                                      initializer=tf.initializers.random_normal())
    return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1])
class Generator:
    """Drive a StyleGAN synthesis network from a trainable dlatent variable.

    The constructor builds the synthesis graph with a ``learnable_dlatents``
    variable as input. It then locates the image output tensor and prepares
    ops for assigning, stochastically clipping and rendering that variable.
    """

    # Possible names of the synthesis output tensor, tried in order. The
    # "G_synthesis_1" scope only exists when G or D were loaded alongside Gs;
    # with Gs alone the scope is "G_synthesis".
    _OUTPUT_TENSOR_CANDIDATES = (
        'G_synthesis_1/_Run/concat:0',
        'G_synthesis_1/_Run/concat/concat:0',
        'G_synthesis_1/_Run/concat_1/concat:0',
        'G_synthesis/_Run/concat:0',
        'G_synthesis/_Run/concat/concat:0',
        'G_synthesis/_Run/concat_1/concat:0',
    )

    def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        """Build the synthesis graph and the dlatent assignment/clipping ops.

        Args:
            model: network exposing ``components.synthesis.run`` and
                ``get_var`` (presumably a dnnlib ``Network`` — confirm).
            batch_size: number of images produced per run.
            custom_input: optional object whose ``eval()`` result is wrapped
                as ``partial(result, batch_size=batch_size)``. That wrapper
                replaces the learnable-variable builder. It is ignored when
                ``tiled_dlatent`` is True.
            clipping_threshold: components of the dlatent variable whose
                magnitude exceeds this are resampled by
                ``stochastic_clip_dlatents``.
            tiled_dlatent: if True, learn one 512-d vector per image and tile
                it across every layer.
            model_res: output resolution. It sets the number of dlatent layers
                via ``2 * (log2(model_res) - 1)``, e.g. 1024 -> 18.
            randomize_noise: forwarded to the synthesis network.

        Raises:
            Exception: if no known synthesis output tensor exists in the graph.
        """
        self.batch_size = batch_size
        self.tiled_dlatent = tiled_dlatent
        # math.log2 is exact for powers of two, unlike math.log(x, 2).
        self.model_scale = int(2 * (math.log2(model_res) - 1))  # For example, 1024 -> 18

        stub_input = partial(create_stub, batch_size=batch_size)
        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            synthesis_inputs = np.zeros((self.batch_size, self.model_scale, 512))
            # Pass model_scale explicitly so the tiled input has as many layers
            # as this resolution needs (the builder's default is 18).
            dlatent_input = partial(create_variable_for_generator, batch_size=batch_size,
                                    tiled_dlatent=True, model_scale=self.model_scale)
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            synthesis_inputs = self.initial_dlatents
            if custom_input is not None:
                dlatent_input = partial(custom_input.eval(), batch_size=batch_size)
            else:
                dlatent_input = partial(create_variable_for_generator, batch_size=batch_size,
                                        tiled_dlatent=False, model_scale=self.model_scale)
        model.components.synthesis.run(synthesis_inputs,
                                       randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                                       custom_inputs=[dlatent_input, stub_input],
                                       structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.compat.v1.get_default_session()
        self.graph = tf.compat.v1.get_default_graph()

        self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name)
        self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph")
        # Placeholder-fed assign op used for every numpy update. It is kept
        # separate from ``_assign_dlantent``, which set_dlatents rebinds when
        # given a tensor. Otherwise later numpy updates would be ignored.
        self._assign_dlatent_from_ph = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
        self._assign_dlantent = self._assign_dlatent_from_ph
        self.set_dlatents(self.initial_dlatents)

        self.generator_output = None
        for tensor_name in self._OUTPUT_TENSOR_CANDIDATES:
            try:
                self.generator_output = self.graph.get_tensor_by_name(tensor_name)
            except KeyError:
                continue
            break
        if self.generator_output is None:
            # Dump the graph to help diagnose which output name is in use.
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")

        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Stochastic clipping similar to https://arxiv.org/abs/1702.04782.
        # That paper used a uniform latent in [-1, 1]. Here the latent is
        # Gaussian, so any component outside
        # [-clipping_threshold, clipping_threshold] (default 2) is resampled
        # from N(0, 1). No ablation has been done for this choice.
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)

    def reset_dlatents(self):
        """Restore the dlatent variable to the all-zeros initial value."""
        self.set_dlatents(self.initial_dlatents)

    def set_dlatents(self, dlatents):
        """Assign ``dlatents`` to the learnable dlatent variable.

        Tiled mode (numpy input only):
            * If the array is not ``(batch_size, 512)`` and its axis 1 is not
              512, it is averaged over axis 1.
            * It is then zero-padded to ``batch_size`` rows.
            * Finally it is repeated along the variable's tile axis before
              being assigned.

        Non-tiled mode:
            * Axis 1 is truncated to ``model_scale`` layers.
            * If ``dlatents.shape[0]`` is a Python int (e.g. a numpy array),
              rows are zero-padded to ``batch_size`` and the values are
              assigned immediately.
            * Otherwise (presumably a TF tensor — confirm), a new assign op
              reading that tensor is built and stored in ``_assign_dlantent``
              but NOT run.
        """
        if self.tiled_dlatent:
            if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512):
                dlatents = np.mean(dlatents, axis=1)
            if dlatents.shape != (self.batch_size, 512):
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size - dlatents.shape[0], 512))])
            assert dlatents.shape == (self.batch_size, 512)
            # The tiled variable is (batch_size, tile_count, 512). Repeat each
            # per-image vector across that axis so the assignment's shape
            # matches the variable.
            tile_count = self.dlatent_variable.shape.as_list()[1]
            dlatents = np.repeat(dlatents[:, np.newaxis, :], tile_count, axis=1)
        else:
            if dlatents.shape[1] > self.model_scale:
                dlatents = dlatents[:, :self.model_scale, :]
            if not isinstance(dlatents.shape[0], int):
                # Symbolic input: build (but do not run) an assign op from it.
                self._assign_dlantent = tf.assign(self.dlatent_variable, dlatents)
                return
            if dlatents.shape != (self.batch_size, self.model_scale, 512):
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size - dlatents.shape[0], self.model_scale, 512))])
            assert dlatents.shape == (self.batch_size, self.model_scale, 512)
        self.sess.run([self._assign_dlatent_from_ph], {self._assign_dlatent_ph: dlatents})

    def stochastic_clip_dlatents(self):
        """Resample dlatent components whose magnitude exceeds the clipping threshold."""
        self.sess.run(self.stochastic_clip_op)

    def get_dlatents(self):
        """Return the current value of the dlatent variable as evaluated by the session."""
        return self.sess.run(self.dlatent_variable)

    def get_dlatent_avg(self):
        """Return the stored average dlatent (not used elsewhere in this class)."""
        return self.dlatent_avg

    def set_dlatent_avg(self, dlatent_avg):
        """Replace the stored average dlatent."""
        self.dlatent_avg = dlatent_avg

    def reset_dlatent_avg(self):
        """Restore the stored average dlatent to the model's ``dlatent_avg`` value."""
        self.dlatent_avg = self.dlatent_avg_def

    def generate_images(self, dlatents=None):
        """Optionally assign ``dlatents``, then render and return uint8 images.

        Images were converted with ``nchw_to_nhwc=True``, so the layout is
        presumably (batch, height, width, channels) — confirm against tflib.
        """
        if dlatents is not None:
            self.set_dlatents(dlatents)
        return self.sess.run(self.generated_image_uint8)
|