| import tensorflow as tf
|
| from tensorflow.keras import layers
|
|
|
def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolutional block as a ``tf.keras.Sequential``.

    The block is: Conv2D (stride 2, 'same' padding, no bias) ->
    optional GroupNormalization(groups=-1) -> LeakyReLU.

    Args:
        filters: Number of output channels of the convolution.
        size: Kernel size passed to ``Conv2D``.
        apply_instancenorm: When true, insert the normalization layer
            between the convolution and the activation.

    Returns:
        A ``tf.keras.Sequential`` containing the layers described above.
    """
    # Kernel weights are drawn from a normal distribution (mean 0, stddev 0.02).
    weight_init = tf.random_normal_initializer(0., 0.02)

    stack = [
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
    ]

    # groups=-1 is the setting this file uses for instance normalization
    # (see the flag name).
    if apply_instancenorm:
        stack.append(tf.keras.layers.GroupNormalization(groups=-1))

    stack.append(layers.LeakyReLU())

    return tf.keras.Sequential(stack)
|
|
|
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block as a ``tf.keras.Sequential``.

    The block is: Conv2DTranspose (stride 2, 'same' padding, no bias) ->
    GroupNormalization(groups=-1) -> optional Dropout(0.5) -> ReLU.

    Args:
        filters: Number of output channels of the transposed convolution.
        size: Kernel size passed to ``Conv2DTranspose``.
        apply_dropout: When true, insert a ``Dropout(0.5)`` layer after the
            normalization and before the activation.

    Returns:
        A ``tf.keras.Sequential`` containing the layers described above.
    """
    # Kernel weights are drawn from a normal distribution (mean 0, stddev 0.02).
    weight_init = tf.random_normal_initializer(0., 0.02)

    stack = [
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        ),
        # Unlike downsample(), normalization is unconditional here.
        tf.keras.layers.GroupNormalization(groups=-1),
    ]

    if apply_dropout:
        stack.append(layers.Dropout(0.5))

    stack.append(layers.ReLU())

    return tf.keras.Sequential(stack)
|
|
|
def resnet_block(filters, size=3):
    """Build the convolutional branch of a residual block.

    The branch is: Conv2D -> GroupNormalization(groups=-1) -> ReLU ->
    Conv2D -> GroupNormalization(groups=-1). Both convolutions use 'same'
    padding, the default stride, and no bias.

    The skip connection is NOT part of the returned model; the caller is
    responsible for adding the block's input to its output.

    Args:
        filters: Number of output channels for both convolutions.
        size: Kernel size for both convolutions.

    Returns:
        A ``tf.keras.Sequential`` containing the five layers described above.
    """
    # One initializer instance is shared by both convolutions.
    weight_init = tf.random_normal_initializer(0., 0.02)

    def make_conv():
        return layers.Conv2D(
            filters,
            size,
            padding='same',
            kernel_initializer=weight_init,
            use_bias=False,
        )

    return tf.keras.Sequential([
        make_conv(),
        tf.keras.layers.GroupNormalization(groups=-1),
        layers.ReLU(),
        make_conv(),
        tf.keras.layers.GroupNormalization(groups=-1),
    ])
|
|
|
def Generator(output_channels=3, num_resnet=9):
    """Build the generator as a functional ``tf.keras.Model``.

    Layout, in order:
      1. Stem: 7x7 Conv2D (64 filters, no bias) -> GroupNormalization(groups=-1)
         -> ReLU.
      2. Two ``downsample`` blocks (128 then 256 filters, kernel size 3).
      3. ``num_resnet`` residual stages: each runs ``resnet_block(256)`` and
         adds its output back onto its input.
      4. Two ``upsample`` blocks (128 then 64 filters, kernel size 3).
      5. Head: 7x7 Conv2D with ``output_channels`` filters and tanh activation.

    Args:
        output_channels: Number of channels produced by the final convolution.
        num_resnet: How many residual stages to stack in the middle.

    Returns:
        A ``tf.keras.Model`` whose input has shape [256, 256, 3].
    """
    image = layers.Input(shape=[256, 256, 3])

    # --- Stem -------------------------------------------------------------
    stem_conv = layers.Conv2D(
        64,
        7,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    features = stem_conv(image)
    features = tf.keras.layers.GroupNormalization(groups=-1)(features)
    features = layers.ReLU()(features)

    # --- Encoder: two stride-2 blocks --------------------------------------
    for width in (128, 256):
        features = downsample(width, 3)(features)

    # --- Residual stages ---------------------------------------------------
    # resnet_block() only returns the convolutional branch, so the skip
    # connection is added here.
    for _ in range(num_resnet):
        branch = resnet_block(256)(features)
        features = layers.Add()([features, branch])

    # --- Decoder: two stride-2 transposed-convolution blocks ---------------
    for width in (128, 64):
        features = upsample(width, 3)(features)

    # --- Head --------------------------------------------------------------
    # This convolution keeps its bias (use_bias is left at the layer default).
    head_conv = layers.Conv2D(
        output_channels,
        7,
        padding='same',
        activation='tanh',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
    )
    generated = head_conv(features)

    return tf.keras.Model(inputs=image, outputs=generated)
|
|
|
def Discriminator():
    """Build the discriminator as a functional ``tf.keras.Model``.

    Layout, in order:
      1. Three ``downsample`` blocks with kernel size 4 and 64, 128, 256
         filters; the first one skips normalization.
      2. ZeroPadding2D -> 4x4 Conv2D (512 filters, stride 1, no bias) ->
         GroupNormalization(groups=-1) -> LeakyReLU.
      3. ZeroPadding2D -> 4x4 Conv2D (1 filter, stride 1) with no activation.

    Returns:
        A ``tf.keras.Model`` whose input has shape [256, 256, 3] and whose
        output is a single-channel map of unactivated scores.
        NOTE(review): with Keras' default 1-pixel ZeroPadding2D and default
        'valid' Conv2D padding this should come out as 30x30x1 — confirm.
    """
    # One initializer instance is shared by both convolutions built here.
    weight_init = tf.random_normal_initializer(0., 0.02)
    image = layers.Input(shape=[256, 256, 3])

    # Stride-2 stack; the first block is built without normalization.
    features = image
    for width, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(width, 4, use_norm)(features)

    # Padded stride-1 convolution stage with normalization and LeakyReLU.
    features = layers.ZeroPadding2D()(features)
    features = layers.Conv2D(
        512,
        4,
        strides=1,
        kernel_initializer=weight_init,
        use_bias=False,
    )(features)
    features = tf.keras.layers.GroupNormalization(groups=-1)(features)
    features = layers.LeakyReLU()(features)

    # Final padded convolution down to one channel; bias kept, no activation.
    features = layers.ZeroPadding2D()(features)
    scores = layers.Conv2D(
        1,
        4,
        strides=1,
        kernel_initializer=weight_init,
    )(features)

    return tf.keras.Model(inputs=image, outputs=scores)
|
|
|