import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as tfkbk
from tensorflow.keras.layers import (ActivityRegularization, Conv2D, Dense,
                                     Flatten, InputLayer, LeakyReLU, PReLU,
                                     Reshape, UpSampling2D)

from blocks import ResidualBlock

INPUT_SHAPE = (64, 64)
LATENT_DIM = 512


def get_encoder():
    """Encoder: (64, 64, 1) image -> concatenated [z_mean | z_log_var]."""
    encoder = tf.keras.Sequential(name="encoder")
    encoder.add(InputLayer(input_shape=(*INPUT_SHAPE, 1)))
    # PReLU is applied as a standalone layer rather than via
    # `activation=PReLU()` so its learnable slopes are reliably tracked.
    for filters in (32, 64, 128):
        encoder.add(Conv2D(filters, 3, padding='same',
                           kernel_initializer='he_uniform'))
        encoder.add(PReLU())
        encoder.add(Conv2D(filters, 3, strides=2, padding='same',
                           kernel_initializer='he_uniform'))
        encoder.add(PReLU())
    encoder.add(Flatten())
    # One head of width 2 * LATENT_DIM: the first half is z_mean, the
    # second half z_log_var. L2 activity regularization is applied to the
    # post-activation output, matching the original regularizer placement.
    encoder.add(Dense(LATENT_DIM * 2))
    encoder.add(PReLU())
    encoder.add(ActivityRegularization(l2=1e-5))
    return encoder


def get_decoder():
    """Decoder: latent vector -> (64, 64, 1) logits (no final activation)."""
    inputs = tf.keras.layers.Input(shape=(LATENT_DIM,))
    x = Dense(8 * 8 * 16, activation='relu')(inputs)
    x = Dense(8 * 8 * 16, activation='relu')(x)
    x = Reshape(target_shape=(8, 8, 16))(x)
    # Three upsampling stages: 8x8 -> 16x16 -> 32x32 -> 64x64.
    for filters, block_names in ((128, ("res1", "res2")),
                                 (64, ("res4", "res5")),
                                 (32, ("res7", "res8"))):
        x = UpSampling2D(2)(x)
        x = Conv2D(filters, 3, padding='same',
                   kernel_initializer='he_uniform')(x)
        x = LeakyReLU()(x)
        for name in block_names:
            x = ResidualBlock(filters, 3, seed=42, name=name,
                              padding="reflect")(x)
    # Raw logits; the loss functions apply the sigmoid themselves.
    x = Conv2D(1, 3, padding='same', kernel_initializer='he_uniform')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)


class CVAE(tf.keras.Model):
    """Convolutional VAE; the loss is attached in `call` via `add_loss`."""

    def __init__(self, encoder: tf.keras.Model, decoder: tf.keras.Model,
                 latent_dim, kl_weight=1.0, loss_fun='bce',
                 include_regularization: bool = False):
        super().__init__()
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim
        self.loss_fun = loss_fun
        self.encoder = encoder
        self.decoder = decoder
        self.kl_loss = 0
        self.reconstruction_loss = 0
        self.include_regularization = include_regularization

    def call(self, inputs, training=None, mask=None):
        # The encoder emits [z_mean | z_log_var]; split along the feature axis.
        z_mean, z_log_var = tf.split(self.encoder(inputs),
                                     num_or_size_splits=2, axis=1)
        z = self.sampling(z_mean, z_log_var, self.latent_dim)
        outputs = self.decoder(z)
        if training:
            regularization_loss = tf.math.reduce_sum(self.encoder.losses)
            if self.loss_fun == 'elbo':
                # Single-sample Monte Carlo estimate of the negative ELBO.
                cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=outputs, labels=inputs)
                logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
                logpz = self.log_normal_pdf(z, 0., 0.)
                logqz_x = self.log_normal_pdf(z, z_mean, z_log_var)
                vae_loss = -tf.reduce_mean(logpx_z + logpz - logqz_x)
            else:
                # Analytic KL divergence between q(z|x) and N(0, I).
                kl_loss = (1 + z_log_var - tf.math.square(z_mean)
                           - tf.math.exp(z_log_var))
                kl_loss = tf.math.reduce_sum(kl_loss, axis=-1)
                kl_loss *= -0.5 * self.kl_weight
                self.kl_loss = kl_loss
                if self.loss_fun == 'mse':
                    reconstruction_loss = tf.keras.metrics.mean_squared_error(
                        tfkbk.flatten(inputs), tfkbk.flatten(outputs))
                elif self.loss_fun == 'bce':
                    # The decoder produces logits, hence from_logits=True.
                    reconstruction_loss = tf.keras.metrics.binary_crossentropy(
                        tfkbk.flatten(inputs), tfkbk.flatten(outputs),
                        from_logits=True)
                else:
                    raise ValueError(f"unknown loss_fun: {self.loss_fun!r}")
                # Rescale the per-pixel mean to a per-image sum (H * W).
                reconstruction_loss *= inputs.shape[1] * inputs.shape[2]
                self.reconstruction_loss = reconstruction_loss
                vae_loss = tf.math.reduce_mean(reconstruction_loss + kl_loss)
            if self.include_regularization:
                vae_loss += regularization_loss
            self.add_loss(vae_loss)
        return outputs

    @staticmethod
    def sampling(z_mean, z_log_var, latent_dim):
        """Reparameterization trick: z = mu + sigma * epsilon."""
        batch = tf.shape(z_mean)[0]
        epsilon = tf.random.normal(shape=(batch, latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    @staticmethod
    def log_normal_pdf(sample, mean, logvar, raxis=1):
        """Log density of a diagonal Gaussian evaluated at `sample`."""
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
            -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
            axis=raxis)
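

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: the random
    # data, optimizer, learning rate, epochs, and batch size below are
    # assumptions chosen only to show how the pieces fit together. Training
    # runs through the loss attached via `add_loss` in `call`, which is why
    # `compile` takes no loss and `fit` takes no targets.
    cvae = CVAE(get_encoder(), get_decoder(), latent_dim=LATENT_DIM,
                kl_weight=1.0, loss_fun='bce')
    cvae.compile(optimizer=tf.keras.optimizers.Adam(1e-4))
    dummy_images = np.random.rand(16, *INPUT_SHAPE, 1).astype("float32")
    cvae.fit(dummy_images, epochs=1, batch_size=8)
    # The decoder returns logits; squash them for inspection.
    reconstructions = tf.sigmoid(cvae(dummy_images))
    print("reconstruction batch shape:", reconstructions.shape)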