|
from tensorflow.keras import Model |
|
|
|
import tensorflow as tf |
|
import tensorflow_probability as tfp |
|
|
|
|
|
class MovingVAE(Model):
    """Convolutional VAE over 5-D ``(time, height, width, channels)`` inputs.

    The encoder maps one example of shape ``input_shape`` to a full-covariance
    Gaussian posterior (``MultivariateNormalTriL``) over an
    ``encoded_size``-dimensional latent code. Its ``activity_regularizer``
    adds ``KL(posterior || prior)`` against a standard-normal prior to the
    model losses. The decoder maps a latent code back to an
    ``IndependentBernoulli`` distribution whose event shape is
    ``input_shape``.

    The decoder's output geometry is fixed by its kernel sizes and strides:
    it upsamples a 1x1x1 seed to 20x64x64 with one output channel. So
    ``input_shape`` must equal that shape. This is checked at construction
    time rather than left to surface as a reshape failure (or a silent
    scrambling when only the element counts agree).

    Args:
        input_shape: Per-example input shape, without the batch axis.
        encoded_size: Dimensionality of the latent code.
        base_depth: Base number of convolution filters; deeper layers use
            ``2 * base_depth``.

    Raises:
        ValueError: If ``input_shape`` differs from the decoder's per-example
            output shape.
    """

    def __init__(self, input_shape, encoded_size=64, base_depth=32):
        super().__init__()

        self.encoded_size = encoded_size
        self.base_depth = base_depth

        # Standard-normal prior; reinterpreting the single batch dim as an
        # event dim makes log_prob sum over the latent components.
        self.prior = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
            reinterpreted_batch_ndims=1,
        )

        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                # Shift inputs to be roughly zero-centred
                # (presumably inputs lie in [0, 1] — confirm with the data pipeline).
                tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Flatten(),
                # Unconstrained parameters (loc + lower-triangular scale) for
                # the posterior distribution layer below.
                tf.keras.layers.Dense(
                    tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size),
                    activation=None,
                ),
                tfp.layers.MultivariateNormalTriL(
                    self.encoded_size,
                    # Adds the KL divergence to the prior to the model losses.
                    activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
                ),
            ]
        )

        # Convolutional part of the decoder: 1x1x1 seed -> (5, 4, 4)
        # -> (5, 8, 8) -> (10, 16, 16) -> (10, 32, 32) -> (20, 64, 64), 1 channel.
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=[self.encoded_size]),
                tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="valid",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                # NOTE(review): a Conv2D applied to a 5-D tensor relies on
                # tf.keras conv layers treating extra leading dims as batch
                # dims (kernel slides over height/width per time step).
                # Confirm for the targeted TF version; Conv3D may have been
                # intended.
                tf.keras.layers.Conv2D(
                    filters=1, kernel_size=5, strides=1, padding="same", activation=None
                ),
            ]
        )

        # Validate before the Flatten/IndependentBernoulli head reshapes the
        # logits into ``input_shape``. A mismatch with an equal element count
        # would otherwise be reshaped silently; any other mismatch would fail
        # inside the reshape with an opaque error.
        decoded_shape = tuple(self.decoder.output_shape[1:])
        expected_shape = tuple(tf.TensorShape(input_shape).as_list())
        if decoded_shape != expected_shape:
            raise ValueError(
                f"MovingVAE decoder produces per-example shape {decoded_shape}, "
                f"but input_shape is {expected_shape}; they must be equal."
            )

        self.decoder.add(tf.keras.layers.Flatten())
        self.decoder.add(
            tfp.layers.IndependentBernoulli(
                input_shape, tfp.distributions.Bernoulli.logits
            )
        )

        # Chain encoder -> decoder into one functional graph. The posterior
        # layer's output is converted to a tensor when fed to the decoder
        # (presumably a sample, TFP's default convert_to_tensor_fn).
        self.model = tf.keras.Model(
            inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0])
        )

    def call(self, inputs):
        """Encode ``inputs``, pass the latent code through the decoder.

        Returns:
            The decoder's ``IndependentBernoulli`` output over the inputs.
        """
        return self.model(inputs)
|
|