moflo committed on
Commit 4b21d9f
1 Parent(s): 83a62dd

Testing StyleGAN

Files changed (1)
  1. app.py +612 -4
app.py CHANGED
@@ -1,7 +1,615 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello there" + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
1
+ import sys
2
+ from subprocess import call
3
+ def run_cmd(command):
4
+ try:
5
+ print(command)
6
+ call(command, shell=True)
7
+ except KeyboardInterrupt:
8
+ print("Process interrupted")
9
+ sys.exit(1)
10
+
11
+ print("⬇️ Installing latest gradio==2.4.7b9")
12
+ run_cmd("pip install --upgrade pip")
13
+ run_cmd('pip install gradio==2.4.7b9')
14
+
15
  import gradio as gr
16
+ import io
+ import os
17
+ import random
18
+ import math
19
+ import numpy as np
20
+ import matplotlib.pyplot as plt
21
+
22
+ from enum import Enum
23
+ from glob import glob
24
+ from functools import partial
25
+
26
+ import tensorflow as tf
27
+ from tensorflow import keras
28
+ from tensorflow.keras import layers
29
+ from tensorflow.keras.models import Sequential
30
+ from tensorflow_addons.layers import InstanceNormalization
31
+
32
+ import tensorflow_datasets as tfds
33
+
34
+ # Model Definition
35
+
36
+ def log2(x):
37
+ return int(np.log2(x))
38
+
39
+
40
+ def resize_image(res, sample):
41
+ print("Call resize_image...")
42
+ image = sample["image"]
43
+ # only downsampling, so use nearest neighbor, which is faster to run
44
+ image = tf.image.resize(
45
+ image, (res, res), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
46
+ )
47
+ image = tf.cast(image, tf.float32) / 127.5 - 1.0
48
+ return image
49
+
50
+
51
+ def create_dataloader(res):
52
+ batch_size = batch_sizes[log2(res)]
53
+ dl = ds_train.map(partial(resize_image, res), num_parallel_calls=tf.data.AUTOTUNE)
54
+ dl = dl.shuffle(200).batch(batch_size, drop_remainder=True).prefetch(1).repeat()
55
+ return dl
56
+
57
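+ # Linear blend used for progressive growing: alpha ramps 0 -> 1 during the TRANSITION phase, fading the new-resolution output (a) in over the upsampled previous-resolution output (b).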
+ def fade_in(alpha, a, b):
58
+ return alpha * a + (1.0 - alpha) * b
59
+
60
+
61
+ def wasserstein_loss(y_true, y_pred):
62
+ return -tf.reduce_mean(y_true * y_pred)
63
+
64
+
65
+ def pixel_norm(x, epsilon=1e-8):
66
+ return x / tf.math.sqrt(tf.reduce_mean(x ** 2, axis=-1, keepdims=True) + epsilon)
67
+
68
+
69
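+ # Minibatch standard deviation (ProGAN trick): append the group-averaged feature std-dev as one extra constant channel so the discriminator can sense sample diversity.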
+ def minibatch_std(input_tensor, epsilon=1e-8):
70
+ n, h, w, c = tf.shape(input_tensor)
71
+ group_size = tf.minimum(4, n)
72
+ x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
73
+ group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
74
+ group_std = tf.sqrt(group_var + epsilon)
75
+ avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
76
+ x = tf.tile(avg_std, [group_size, h, w, 1])
77
+ return tf.concat([input_tensor, x], axis=-1)
78
+
79
+
80
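+ # Equalized learning rate layers: weights are initialized from N(0, 1) and rescaled at call time by sqrt(gain / fan_in), as in ProGAN/StyleGAN.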
+ class EqualizedConv(layers.Layer):
81
+ def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
82
+ super(EqualizedConv, self).__init__(**kwargs)
83
+ self.kernel = kernel
84
+ self.out_channels = out_channels
85
+ self.gain = gain
86
+ self.pad = kernel != 1
87
+
88
+ def build(self, input_shape):
89
+ self.in_channels = input_shape[-1]
90
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
91
+ self.w = self.add_weight(
92
+ shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
93
+ initializer=initializer,
94
+ trainable=True,
95
+ name="kernel",
96
+ )
97
+ self.b = self.add_weight(
98
+ shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
99
+ )
100
+ fan_in = self.kernel * self.kernel * self.in_channels
101
+ self.scale = tf.sqrt(self.gain / fan_in)
102
+
103
+ def call(self, inputs):
104
+ if self.pad:
105
+ x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
106
+ else:
107
+ x = inputs
108
+ output = (
109
+ tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
110
+ )
111
+ return output
112
+
113
+
114
+ class EqualizedDense(layers.Layer):
115
+ def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
116
+ super(EqualizedDense, self).__init__(**kwargs)
117
+ self.units = units
118
+ self.gain = gain
119
+ self.learning_rate_multiplier = learning_rate_multiplier
120
+
121
+ def build(self, input_shape):
122
+ self.in_channels = input_shape[-1]
123
+ initializer = keras.initializers.RandomNormal(
124
+ mean=0.0, stddev=1.0 / self.learning_rate_multiplier
125
+ )
126
+ self.w = self.add_weight(
127
+ shape=[self.in_channels, self.units],
128
+ initializer=initializer,
129
+ trainable=True,
130
+ name="kernel",
131
+ )
132
+ self.b = self.add_weight(
133
+ shape=(self.units,), initializer="zeros", trainable=True, name="bias"
134
+ )
135
+ fan_in = self.in_channels
136
+ self.scale = tf.sqrt(self.gain / fan_in)
137
+
138
+ def call(self, inputs):
139
+ output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
140
+ return output * self.learning_rate_multiplier
141
+
142
+
143
+ class AddNoise(layers.Layer):
144
+ def build(self, input_shape):
145
+ n, h, w, c = input_shape[0]
146
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
147
+ self.b = self.add_weight(
148
+ shape=[1, 1, 1, c], initializer=initializer, trainable=True, name="kernel"
149
+ )
150
+
151
+ def call(self, inputs):
152
+ x, noise = inputs
153
+ output = x + self.b * noise
154
+ return output
155
+
156
+
157
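+ # Adaptive instance normalization: two EqualizedDense layers map the style vector w to per-channel scale (ys) and bias (yb) applied to the feature map x (which build_block instance-normalizes just before this layer).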
+ class AdaIN(layers.Layer):
158
+ def __init__(self, gain=1, **kwargs):
159
+ super(AdaIN, self).__init__(**kwargs)
160
+ self.gain = gain
161
+
162
+ def build(self, input_shapes):
163
+ x_shape = input_shapes[0]
164
+ w_shape = input_shapes[1]
165
+
166
+ self.w_channels = w_shape[-1]
167
+ self.x_channels = x_shape[-1]
168
+
169
+ self.dense_1 = EqualizedDense(self.x_channels, gain=1)
170
+ self.dense_2 = EqualizedDense(self.x_channels, gain=1)
171
+
172
+ def call(self, inputs):
173
+ x, w = inputs
174
+ ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
175
+ yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
176
+ return ys * x + yb
177
+
178
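+ # Mapping network: pixel-normalize z, pass it through 8 EqualizedDense + LeakyReLU layers, and tile the resulting style vector w once per resolution stage.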
+ def Mapping(num_stages, input_shape=512):
179
+ z = layers.Input(shape=(input_shape))
180
+ w = pixel_norm(z)
181
+ for i in range(8):
182
+ w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
183
+ w = layers.LeakyReLU(0.2)(w)
184
+ w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
185
+ return keras.Model(z, w, name="mapping")
186
+
187
+
188
+ class Generator:
189
+ def __init__(self, start_res_log2, target_res_log2):
190
+ self.start_res_log2 = start_res_log2
191
+ self.target_res_log2 = target_res_log2
192
+ self.num_stages = target_res_log2 - start_res_log2 + 1
193
+ # list of generator blocks at increasing resolution
194
+ self.g_blocks = []
195
+ # list of layers to convert g_block activation to RGB
196
+ self.to_rgb = []
197
+ # list of noise input of different resolutions into g_blocks
198
+ self.noise_inputs = []
199
+ # filter size to use at each stage, keys are log2(resolution)
200
+ self.filter_nums = {
201
+ 0: 512,
202
+ 1: 512,
203
+ 2: 512, # 4x4
204
+ 3: 512, # 8x8
205
+ 4: 512, # 16x16
206
+ 5: 512, # 32x32
207
+ 6: 256, # 64x64
208
+ 7: 128, # 128x128
209
+ 8: 64, # 256x256
210
+ 9: 32, # 512x512
211
+ 10: 16,
212
+ } # 1024x1024
213
+
214
+ start_res = 2 ** start_res_log2
215
+ self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
216
+ self.g_input = layers.Input(self.input_shape, name="generator_input")
217
+
218
+ for i in range(start_res_log2, target_res_log2 + 1):
219
+ filter_num = self.filter_nums[i]
220
+ res = 2 ** i
221
+ self.noise_inputs.append(
222
+ layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
223
+ )
224
+ to_rgb = Sequential(
225
+ [
226
+ layers.InputLayer(input_shape=(res, res, filter_num)),
227
+ EqualizedConv(3, 1, gain=1),
228
+ ],
229
+ name=f"to_rgb_{res}x{res}",
230
+ )
231
+ self.to_rgb.append(to_rgb)
232
+ is_base = i == self.start_res_log2
233
+ if is_base:
234
+ input_shape = (res, res, self.filter_nums[i - 1])
235
+ else:
236
+ input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
237
+ g_block = self.build_block(
238
+ filter_num, res=res, input_shape=input_shape, is_base=is_base
239
+ )
240
+ self.g_blocks.append(g_block)
241
+
242
+ def build_block(self, filter_num, res, input_shape, is_base):
243
+ input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
244
+ noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
245
+ w = layers.Input(shape=512)
246
+ x = input_tensor
247
+
248
+ if not is_base:
249
+ x = layers.UpSampling2D((2, 2))(x)
250
+ x = EqualizedConv(filter_num, 3)(x)
251
+
252
+ x = AddNoise()([x, noise])
253
+ x = layers.LeakyReLU(0.2)(x)
254
+ x = InstanceNormalization()(x)
255
+ x = AdaIN()([x, w])
256
+
257
+ x = EqualizedConv(filter_num, 3)(x)
258
+ x = AddNoise()([x, noise])
259
+ x = layers.LeakyReLU(0.2)(x)
260
+ x = InstanceNormalization()(x)
261
+ x = AdaIN()([x, w])
262
+ return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")
263
+
264
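+ # Build a generator up to the requested resolution; in the transition stage the top block's RGB output is faded in against the upsampled RGB of the previous stage.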
+ def grow(self, res_log2):
265
+ res = 2 ** res_log2
266
+
267
+ num_stages = res_log2 - self.start_res_log2 + 1
268
+ w = layers.Input(shape=(self.num_stages, 512), name="w")
269
+
270
+ alpha = layers.Input(shape=(1), name="g_alpha")
271
+ x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])
272
+
273
+ if num_stages == 1:
274
+ rgb = self.to_rgb[0](x)
275
+ else:
276
+ for i in range(1, num_stages - 1):
277
+
278
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
279
+
280
+ old_rgb = self.to_rgb[num_stages - 2](x)
281
+ old_rgb = layers.UpSampling2D((2, 2))(old_rgb)
282
+
283
+ i = num_stages - 1
284
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
285
+
286
+ new_rgb = self.to_rgb[i](x)
287
+
288
+ rgb = fade_in(alpha[0], new_rgb, old_rgb)
289
+
290
+ return keras.Model(
291
+ [self.g_input, w, self.noise_inputs, alpha],
292
+ rgb,
293
+ name=f"generator_{res}_x_{res}",
294
+ )
295
+
296
+
297
+ class Discriminator:
298
+ def __init__(self, start_res_log2, target_res_log2):
299
+ self.start_res_log2 = start_res_log2
300
+ self.target_res_log2 = target_res_log2
301
+ self.num_stages = target_res_log2 - start_res_log2 + 1
302
+ # filter size to use at each stage, keys are log2(resolution)
303
+ self.filter_nums = {
304
+ 0: 512,
305
+ 1: 512,
306
+ 2: 512, # 4x4
307
+ 3: 512, # 8x8
308
+ 4: 512, # 16x16
309
+ 5: 512, # 32x32
310
+ 6: 256, # 64x64
311
+ 7: 128, # 128x128
312
+ 8: 64, # 256x256
313
+ 9: 32, # 512x512
314
+ 10: 16,
315
+ } # 1024x1024
316
+ # list of discriminator blocks at increasing resolution
317
+ self.d_blocks = []
318
+ # list of layers to convert RGB into activation for d_blocks inputs
319
+ self.from_rgb = []
320
+
321
+ for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
322
+ res = 2 ** res_log2
323
+ filter_num = self.filter_nums[res_log2]
324
+ from_rgb = Sequential(
325
+ [
326
+ layers.InputLayer(
327
+ input_shape=(res, res, 3), name=f"from_rgb_input_{res}"
328
+ ),
329
+ EqualizedConv(filter_num, 1),
330
+ layers.LeakyReLU(0.2),
331
+ ],
332
+ name=f"from_rgb_{res}",
333
+ )
334
+
335
+ self.from_rgb.append(from_rgb)
336
+
337
+ input_shape = (res, res, filter_num)
338
+ if len(self.d_blocks) == 0:
339
+ d_block = self.build_base(filter_num, res)
340
+ else:
341
+ d_block = self.build_block(
342
+ filter_num, self.filter_nums[res_log2 - 1], res
343
+ )
344
+
345
+ self.d_blocks.append(d_block)
346
+
347
+ def build_base(self, filter_num, res):
348
+ input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
349
+ x = minibatch_std(input_tensor)
350
+ x = EqualizedConv(filter_num, 3)(x)
351
+ x = layers.LeakyReLU(0.2)(x)
352
+ x = layers.Flatten()(x)
353
+ x = EqualizedDense(filter_num)(x)
354
+ x = layers.LeakyReLU(0.2)(x)
355
+ x = EqualizedDense(1)(x)
356
+ return keras.Model(input_tensor, x, name=f"d_{res}")
357
+
358
+ def build_block(self, filter_num_1, filter_num_2, res):
359
+ input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
360
+ x = EqualizedConv(filter_num_1, 3)(input_tensor)
361
+ x = layers.LeakyReLU(0.2)(x)
362
+ x = EqualizedConv(filter_num_2)(x)
363
+ x = layers.LeakyReLU(0.2)(x)
364
+ x = layers.AveragePooling2D((2, 2))(x)
365
+ return keras.Model(input_tensor, x, name=f"d_{res}")
366
+
367
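+ # Mirror of Generator.grow: the top from_rgb/d_block pair is faded in against a path that average-pools the input image into the previous stage's from_rgb, then the remaining blocks run down to the base resolution.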
+ def grow(self, res_log2):
368
+ res = 2 ** res_log2
369
+ idx = res_log2 - self.start_res_log2
370
+ alpha = layers.Input(shape=(1), name="d_alpha")
371
+ input_image = layers.Input(shape=(res, res, 3), name="input_image")
372
+ x = self.from_rgb[idx](input_image)
373
+ x = self.d_blocks[idx](x)
374
+ if idx > 0:
375
+ idx -= 1
376
+ downsized_image = layers.AveragePooling2D((2, 2))(input_image)
377
+ y = self.from_rgb[idx](downsized_image)
378
+ x = fade_in(alpha[0], x, y)
379
+
380
+ for i in range(idx, -1, -1):
381
+ x = self.d_blocks[i](x)
382
+ return keras.Model([input_image, alpha], x, name=f"discriminator_{res}_x_{res}")
383
+
384
+ class StyleGAN(tf.keras.Model):
385
+ def __init__(self, z_dim=512, target_res=64, start_res=4):
386
+ super(StyleGAN, self).__init__()
387
+ self.z_dim = z_dim
388
+
389
+ self.target_res_log2 = log2(target_res)
390
+ self.start_res_log2 = log2(start_res)
391
+ self.current_res_log2 = self.target_res_log2
392
+ self.num_stages = self.target_res_log2 - self.start_res_log2 + 1
393
+
394
+ self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")
395
+
396
+ self.mapping = Mapping(num_stages=self.num_stages)
397
+ self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
398
+ self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
399
+ self.g_input_shape = self.g_builder.input_shape
400
+
401
+ self.phase = None
402
+ self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)
403
+
404
+ self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}
405
+
406
+ def grow_model(self, res):
407
+ tf.keras.backend.clear_session()
408
+ res_log2 = log2(res)
409
+ self.generator = self.g_builder.grow(res_log2)
410
+ self.discriminator = self.d_builder.grow(res_log2)
411
+ self.current_res_log2 = res_log2
412
+ print(f"\nModel resolution:{res}x{res}")
413
+
414
+ def compile(
415
+ self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
416
+ ):
417
+ self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
418
+ self.steps_per_epoch = steps_per_epoch
419
+ if res != 2 ** self.current_res_log2:
420
+ self.grow_model(res)
421
+ self.d_optimizer = d_optimizer
422
+ self.g_optimizer = g_optimizer
423
+
424
+ self.train_step_counter.assign(0)
425
+ self.phase = phase
426
+ self.d_loss_metric = keras.metrics.Mean(name="d_loss")
427
+ self.g_loss_metric = keras.metrics.Mean(name="g_loss")
428
+ super(StyleGAN, self).compile(*args, **kwargs)
429
+
430
+ @property
431
+ def metrics(self):
432
+ return [self.d_loss_metric, self.g_loss_metric]
433
+
434
+ def generate_noise(self, batch_size):
435
+ noise = [
436
+ tf.random.normal((batch_size, 2 ** res, 2 ** res, 1))
437
+ for res in range(self.start_res_log2, self.target_res_log2 + 1)
438
+ ]
439
+ return noise
440
+
441
+ def gradient_loss(self, grad):
442
+ loss = tf.square(grad)
443
+ loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
444
+ loss = tf.sqrt(loss)
445
+ loss = tf.reduce_mean(tf.square(loss - 1))
446
+ return loss
447
+
448
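+ # WGAN-GP training step: alpha is ramped linearly over an epoch in the TRANSITION phase, the generator minimizes the Wasserstein loss, and the discriminator adds a gradient penalty (weight 10) and a small drift term (weight 0.001).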
+ def train_step(self, real_images):
449
+
450
+ self.train_step_counter.assign_add(1)
451
+
452
+ if self.phase == "TRANSITION":
453
+ self.alpha.assign(
454
+ tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
455
+ )
456
+ elif self.phase == "STABLE":
457
+ self.alpha.assign(1.0)
458
+ else:
459
+ raise NotImplementedError
460
+ alpha = tf.expand_dims(self.alpha, 0)
461
+ batch_size = tf.shape(real_images)[0]
462
+ real_labels = tf.ones(batch_size)
463
+ fake_labels = -tf.ones(batch_size)
464
+
465
+ z = tf.random.normal((batch_size, self.z_dim))
466
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
467
+ noise = self.generate_noise(batch_size)
468
+
469
+ # generator
470
+ with tf.GradientTape() as g_tape:
471
+ w = self.mapping(z)
472
+ fake_images = self.generator([const_input, w, noise, alpha])
473
+ pred_fake = self.discriminator([fake_images, alpha])
474
+ g_loss = wasserstein_loss(real_labels, pred_fake)
475
+
476
+ trainable_weights = (
477
+ self.mapping.trainable_weights + self.generator.trainable_weights
478
+ )
479
+ gradients = g_tape.gradient(g_loss, trainable_weights)
480
+ self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))
481
+
482
+ # discriminator
483
+ with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
484
+ # forward pass
485
+ pred_fake = self.discriminator([fake_images, alpha])
486
+ pred_real = self.discriminator([real_images, alpha])
487
+
488
+ epsilon = tf.random.uniform((batch_size, 1, 1, 1))
489
+ interpolates = epsilon * real_images + (1 - epsilon) * fake_images
490
+ gradient_tape.watch(interpolates)
491
+ pred_fake_grad = self.discriminator([interpolates, alpha])
492
+
493
+ # calculate losses
494
+ loss_fake = wasserstein_loss(fake_labels, pred_fake)
495
+ loss_real = wasserstein_loss(real_labels, pred_real)
496
+ loss_fake_grad = wasserstein_loss(fake_labels, pred_fake_grad)
497
+
498
+ # gradient penalty
499
+ gradients_fake = gradient_tape.gradient(loss_fake_grad, [interpolates])
500
+ gradient_penalty = self.loss_weights[
501
+ "gradient_penalty"
502
+ ] * self.gradient_loss(gradients_fake)
503
+
504
+ # drift loss
505
+ all_pred = tf.concat([pred_fake, pred_real], axis=0)
506
+ drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred ** 2)
507
+
508
+ d_loss = loss_fake + loss_real + gradient_penalty + drift_loss
509
+
510
+ gradients = total_tape.gradient(
511
+ d_loss, self.discriminator.trainable_weights
512
+ )
513
+ self.d_optimizer.apply_gradients(
514
+ zip(gradients, self.discriminator.trainable_weights)
515
+ )
516
+
517
+ # Update metrics
518
+ self.d_loss_metric.update_state(d_loss)
519
+ self.g_loss_metric.update_state(g_loss)
520
+ return {
521
+ "d_loss": self.d_loss_metric.result(),
522
+ "g_loss": self.g_loss_metric.result(),
523
+ }
524
+
525
+ def call(self, inputs: dict):
526
+ style_code = inputs.get("style_code", None)
527
+ z = inputs.get("z", None)
528
+ noise = inputs.get("noise", None)
529
+ batch_size = inputs.get("batch_size", 1)
530
+ alpha = inputs.get("alpha", 1.0)
531
+ alpha = tf.expand_dims(alpha, 0)
532
+ if style_code is None:
533
+ if z is None:
534
+ z = tf.random.normal((batch_size, self.z_dim))
535
+ style_code = self.mapping(z)
536
+
537
+ if noise is None:
538
+ noise = self.generate_noise(batch_size)
539
+
540
+ # self.alpha.assign(alpha)
541
+
542
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
543
+ images = self.generator([const_input, style_code, noise, alpha])
544
+ images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
545
+
546
+ return images
547
+
548
+ # Set up GAN
549
+
550
+ batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
551
+ train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}
552
+
553
+ START_RES = 4
554
+ TARGET_RES = 128
555
+
556
+ # style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
557
+
558
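+ # Download and extract the pretrained 128x128 checkpoint from the soon-yau/stylegan_keras release; it is loaded into the model inside InferenceWrapper below.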
+ url = "https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip"
559
+
560
+ weights_path = keras.utils.get_file(
561
+ "stylegan_128x128.ckpt.zip",
562
+ url,
563
+ extract=True,
564
+ cache_dir=os.path.abspath("."),
565
+ cache_subdir="pretrained",
566
+ )
567
+
568
+ # style_gan.grow_model(128)
569
+ # style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
570
+
571
+ # tf.random.set_seed(196)
572
+ # batch_size = 2
573
+ # z = tf.random.normal((batch_size, style_gan.z_dim))
574
+ # w = style_gan.mapping(z)
575
+ # noise = style_gan.generate_noise(batch_size=batch_size)
576
+ # images = style_gan({"style_code": w, "noise": noise, "alpha": 1.0})
577
+
578
+ # plot_images(images, 5)
579
+
580
+ class InferenceWrapper:
581
+ def __init__(self, model):
582
+ self.model = model
583
+ self.style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
584
+ self.style_gan.grow_model(128)
585
+ self.style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
586
+ self.seed = 196
587
+
588
+ def __call__(self, seed, feature):
589
+ if seed != self.seed:
590
+ print(f"Loading model: {self.model}")
591
+ tf.random.set_seed(seed)
592
+ batch_size = 1
593
+ self.z = tf.random.normal((batch_size, self.style_gan.z_dim))
594
+ self.w = self.style_gan.mapping(self.z)
595
+ self.noise = self.style_gan.generate_noise(batch_size=batch_size)
+ self.seed = seed
596
+ else:
597
+ print(f"Model '{self.model}' already loaded, reusing it.")
598
+ return self.style_gan({"style_code": self.w, "noise": self.noise, "alpha": 1.0})[0]
599
+
600
+
601
+ wrapper = InferenceWrapper('celeba')
602
 
603
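+ # Note: the 'Feature Type' radio value is passed through to the wrapper but is not used yet; only the seed changes the generated image.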
+ def fn(seed, feature):
604
+ return wrapper(seed, feature)
605
 
606
+ gr.Interface(
607
+ fn,
608
+ inputs=[
609
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed'),
610
+ gr.inputs.Radio(["test1", "test2"], type="value", default='test1', label='Feature Type')
611
+ ],
612
+ outputs='image',
613
+ examples=[[343, 'test1'], [456, 'test2']],
614
+ enable_queue=True
615
+ ).launch()