File size: 4,520 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from tensorflow.keras import Model
import tensorflow as tf
import tensorflow_probability as tfp
class MovingVAE(Model):
    """Convolutional variational autoencoder over rank-4 (per-example) inputs.

    The encoder maps one input volume (presumably time x height x width x
    channels, e.g. a short video clip -- TODO confirm against callers) to a
    full-covariance Gaussian posterior over a latent vector
    (``MultivariateNormalTriL``). A KL-divergence activity regularizer against
    a standard-normal prior is attached to that layer, so Keras collects the
    KL term in ``self.losses``. The decoder maps a latent sample back to
    independent Bernoulli distributions over every input element.

    NOTE: the decoder architecture is fixed and always produces a
    20 x 64 x 64 x 1 volume of logits (``_DECODED_SHAPE``), so
    ``input_shape`` must describe exactly that many elements.
    """

    # Shape emitted by the decoder's conv stack before flattening, derived
    # from the layers built in __init__: a 'valid' (5, 4, 4) transposed conv
    # on a 1x1x1 input gives (5, 4, 4); the following 'same' transposed convs
    # with strides (1, 2, 2), 2, (1, 2, 2), 2, 1 give (5, 8, 8) ->
    # (10, 16, 16) -> (10, 32, 32) -> (20, 64, 64) -> (20, 64, 64); the final
    # Conv2D has a single filter. Keep in sync if the decoder changes.
    _DECODED_SHAPE = (20, 64, 64, 1)

    def __init__(self, input_shape, encoded_size=64, base_depth=32):
        """Build the encoder, the decoder and the end-to-end functional model.

        Args:
          input_shape: Shape of a single example, without the batch dimension.
            The Conv3D encoder needs a rank-4 shape, and the fixed decoder
            requires it to hold exactly 20 * 64 * 64 * 1 elements.
          encoded_size: Dimensionality of the latent code.
          base_depth: Filter count of the first conv layers; other conv
            layers use this value or twice it.

        Raises:
          ValueError: If ``input_shape`` is not fully defined or does not hold
            as many elements as the decoder emits. Previously this surfaced
            only as an opaque reshape failure inside ``IndependentBernoulli``.
        """
        super().__init__()

        expected_elements = tf.TensorShape(self._DECODED_SHAPE).num_elements()
        # num_elements() is None when any dimension is unknown, which the
        # Bernoulli reshape cannot handle either, so that case is rejected too.
        if tf.TensorShape(input_shape).num_elements() != expected_elements:
            raise ValueError(
                f"MovingVAE's decoder always produces a volume of shape "
                f"{self._DECODED_SHAPE} ({expected_elements} elements); "
                f"input_shape={input_shape} must be fully defined and hold "
                f"the same number of elements."
            )

        self.encoded_size = encoded_size
        self.base_depth = base_depth

        # Standard-normal prior over the latent vector; the final axis is
        # reinterpreted as the event dimension.
        self.prior = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
            reinterpreted_batch_ndims=1,
        )

        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                # Cast to float32 and shift by 0.5; presumably inputs lie in
                # [0, 1] so this centres them -- TODO confirm data range.
                tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Flatten(),
                # Linear projection to the parameters of a full-covariance
                # Gaussian (mean + lower-triangular scale).
                tf.keras.layers.Dense(
                    tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size),
                    activation=None,
                ),
                # Posterior q(z|x); the activity regularizer adds
                # KL(q(z|x) || prior) to the model's losses.
                tfp.layers.MultivariateNormalTriL(
                    self.encoded_size,
                    activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
                ),
            ]
        )

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=[self.encoded_size]),
                # Treat the latent vector as a 1x1x1 volume with
                # encoded_size channels so it can be upsampled by transposed
                # 3D convolutions.
                tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="valid",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                # A 2D conv on this 5D tensor: Keras treats the leading
                # non-spatial axis as an extra batch dimension, so each
                # slice is convolved independently to produce 1-channel
                # logits.
                tf.keras.layers.Conv2D(
                    filters=1, kernel_size=5, strides=1, padding="same", activation=None
                ),
                tf.keras.layers.Flatten(),
                # Independent Bernoulli per element, reshaped to
                # input_shape; the layer's tensor value is the logits.
                tfp.layers.IndependentBernoulli(
                    input_shape, tfp.distributions.Bernoulli.logits
                ),
            ]
        )

        # End-to-end graph: input -> posterior sample -> reconstruction
        # distribution.
        self.model = tf.keras.Model(
            inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0])
        )

    def call(self, inputs):
        """Encode ``inputs`` and return the decoder's IndependentBernoulli output."""
        return self.model(inputs)
|