import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
import numpy as np
import tensorflow_probability as tfp

# Affine coupling layer.
# Each coupling layer is a small Keras functional model that maps the masked
# input to a (scale, translation) pair consumed by RealNVP below.

# Width of the hidden Dense layers in the scale and translation networks.
output_dim = 256
# L2 penalty factor applied to every Dense kernel in the coupling networks.
reg = 0.01


def _coupling_mlp(inputs, units_out, final_activation):
    """Apply four ReLU Dense layers followed by a `units_out`-wide projection.

    Hidden width and L2 strength are read from the module-level
    ``output_dim`` and ``reg`` when this function is called.
    """
    hidden = inputs
    for _ in range(4):
        hidden = keras.layers.Dense(
            output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
        )(hidden)
    return keras.layers.Dense(
        units_out,
        activation=final_activation,
        kernel_regularizer=regularizers.l2(reg),
    )(hidden)


def Coupling(input_shape):
    """Build the scale/translation network for one affine coupling layer.

    Args:
        input_shape: Number of input features (an int), or a shape tuple
            whose last entry is the number of features.

    Returns:
        A ``keras.Model`` mapping an input batch to ``[s, t]``, each with the
        same number of features as the input. ``s`` ends in ``tanh`` and
        ``t`` is linear.
    """
    if isinstance(input_shape, (tuple, list)):
        shape = tuple(input_shape)
    else:
        # `keras.layers.Input` documents `shape` as a tuple; pass one
        # explicitly instead of relying on a bare int being accepted.
        shape = (int(input_shape),)
    units_out = shape[-1]

    inputs = keras.layers.Input(shape=shape)
    # Translation network is built before the scale network, matching the
    # original creation order (and thus Keras' auto-generated layer names).
    t_out = _coupling_mlp(inputs, units_out, "linear")
    s_out = _coupling_mlp(inputs, units_out, "tanh")
    return keras.Model(inputs=inputs, outputs=[s_out, t_out])


# Real NVP


class RealNVP(keras.Model):
    """RealNVP normalizing flow built from alternating affine coupling layers.

    Calling the model with ``training=True`` (the default) maps data to the
    latent space and also returns the log-determinant of that map's Jacobian.
    Calling it with ``training=False`` maps latent samples back to data space.
    """

    def __init__(self, num_coupling_layers, input_dim=2, **kwargs):
        """Create the flow.

        Args:
            num_coupling_layers: Number of coupling layers. Consecutive layers
                use complementary masks; odd counts are supported.
            input_dim: Number of features per sample (default 2).
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).

        Raises:
            ValueError: If ``input_dim`` is smaller than 1.
        """
        super().__init__(**kwargs)
        if input_dim < 1:
            raise ValueError(f"input_dim must be >= 1, got {input_dim}")
        self.num_coupling_layers = num_coupling_layers
        self.input_dim = input_dim

        # Standard normal prior over the latent space.
        self.distribution = tfp.distributions.MultivariateNormalDiag(
            loc=[0.0] * input_dim, scale_diag=[1.0] * input_dim
        )
        # Checkerboard masks: mask[i, j] = (i + j) % 2, so adjacent layers
        # condition on complementary features. One row is built per layer
        # (instead of repeating a pair n // 2 times) so odd layer counts do
        # not run out of masks.
        self.masks = (
            np.add.outer(np.arange(num_coupling_layers), np.arange(input_dim)) % 2
        ).astype("float32")
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.layers_list = [Coupling(input_dim) for _ in range(num_coupling_layers)]

    @property
    def metrics(self):
        """List of the model's metrics.

        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to `reset()` the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker]

    def call(self, x, training=True):
        """Run the flow in one direction.

        Args:
            x: Batch of samples; presumably shaped ``(batch, input_dim)``
                (the log-determinant is summed over axis 1).
            training: If truthy (default), apply the data -> latent map by
                running the coupling layers last-to-first. Otherwise apply the
                latent -> data map first-to-last.

        Returns:
            ``(y, log_det_inv)``: the transformed batch and the per-sample
            accumulated log-determinant. In the data -> latent direction this
            is ``-sum(s)`` over all layers. In the latent -> data direction
            every contribution is multiplied by zero.
        """
        log_det_inv = 0
        # -1: data -> latent (inverse pass). +1: latent -> data.
        direction = -1 if training else 1
        # gate is -1 for the inverse pass and 0 for the latent -> data pass.
        gate = (direction - 1) / 2
        for i in range(self.num_coupling_layers)[::direction]:
            mask = self.masks[i]
            reversed_mask = 1 - mask
            x_masked = x * mask
            s, t = self.layers_list[i](x_masked)
            # Only the features outside the mask are transformed.
            s *= reversed_mask
            t *= reversed_mask
            # Inverse pass: (x - t) * exp(-s). Other pass: x * exp(s) + t.
            x = (
                reversed_mask
                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
                + x_masked
            )
            log_det_inv += gate * tf.reduce_sum(s, [1])
        return x, log_det_inv

    # Log likelihood of the normal distribution plus the log determinant of
    # the jacobian.
    def log_loss(self, x):
        """Return the mean negative log-likelihood of ``x`` under the flow.

        Uses ``log p(x) = log p_z(f(x)) + log|det df/dx|``, where ``f`` is the
        data -> latent map computed by ``call`` in its default direction.
        """
        y, logdet = self(x)
        log_likelihood = self.distribution.log_prob(y) + logdet
        return -tf.reduce_mean(log_likelihood)

    @staticmethod
    def _unpack_inputs(data):
        """Return the input batch, dropping any targets/weights packed with it."""
        if isinstance(data, (tuple, list)):
            return data[0]
        return data

    def train_step(self, data):
        """Take one optimizer step on the negative log-likelihood."""
        x = self._unpack_inputs(data)
        with tf.GradientTape() as tape:
            # NOTE(review): the kernel_regularizer penalties of the coupling
            # layers are not added here (e.g. via `self.losses`), so they do
            # not influence training. Confirm whether that is intended.
            loss = self.log_loss(x)
        g = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(g, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        """Evaluate the negative log-likelihood without updating weights."""
        x = self._unpack_inputs(data)
        loss = self.log_loss(x)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


def load_model():
    """Construct a fresh (untrained) RealNVP flow with 6 coupling layers."""
    return RealNVP(num_coupling_layers=6)