from tqdm.auto import tqdm

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    LSTM,
    LSTMCell,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Conv3D,
    Conv3DTranspose,
    Dense,
    Flatten,
    Input,
    Layer,
    LeakyReLU,
    MaxPooling2D,
    Reshape,
    TimeDistributed,
    UpSampling2D,
)
from tensorflow.keras.losses import KLDivergence, Loss, MeanSquaredError

# DCGAN-style initializers: N(0, 0.02) for conv/dense kernels,
# N(1, 0.02) for batch-norm scale parameters.
initializer_conv_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
initializer_batch_norm = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)


class KLCriterion(Loss):
    """KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) for diagonal Gaussians."""

    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        sigma1 = tf.exp(logvar1 * 0.5)
        sigma2 = tf.exp(logvar2 * 0.5)

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        # NOTE: the constant 100 matches the batch size hard-coded in
        # P2P.train_step below, so this is a per-sample sum of the KL.
        return tf.reduce_sum(kld) / 100
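
# A minimal sanity check for KLCriterion (an illustrative sketch, not part of
# the training pipeline): the KL divergence of a Gaussian against itself is
# zero, and against a wider Gaussian it is strictly positive.
#
#   kl = KLCriterion()
#   mu, logvar = tf.zeros((100, 10)), tf.zeros((100, 10))
#   kl((mu, logvar), (mu, logvar))        # -> 0.0
#   kl((mu, logvar), (mu, logvar + 1.0))  # -> positive

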
class Encoder(Model):
    """DCGAN-style frame encoder.

    The four stride-2 convolutions plus the final valid 4x4 convolution imply
    64x64 inputs, which are reduced to a 1x1 feature map of depth `dim`.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        # 64x64 -> 32x32
        self.c1 = Sequential(
            [
                Conv2D(
                    64,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 32x32 -> 16x16
        self.c2 = Sequential(
            [
                Conv2D(
                    128,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 16x16 -> 8x8
        self.c3 = Sequential(
            [
                Conv2D(
                    256,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 8x8 -> 4x4
        self.c4 = Sequential(
            [
                Conv2D(
                    512,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 4x4 -> 1x1
        self.c5 = Sequential(
            [
                Conv2D(
                    self.dim,
                    kernel_size=4,
                    strides=1,
                    padding="valid",
                    kernel_initializer=initializer_conv_dense,
                ),
                Activation("tanh"),
            ]
        )

    def call(self, inputs):
        h1 = self.c1(inputs)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        # Return the flattened feature vector plus every intermediate
        # activation; the Decoder consumes the latter as skip connections.
        return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]
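
# Illustrative shape check (a sketch; the 64x64 input resolution is implied by
# the architecture rather than configured anywhere in this file):
#
#   enc = Encoder(dim=128)
#   vec, skips = enc(tf.zeros((4, 64, 64, 1)))
#   vec.shape   # (4, 128)
#   len(skips)  # 5 feature maps, reused by Decoder as skip connections

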
class Decoder(Model):
    """Mirror of Encoder: upsamples a `dim`-vector back to a 64x64 frame,
    concatenating the encoder's skip connections at each resolution."""

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        # 1x1 -> 4x4
        self.upc1 = Sequential(
            [
                Conv2DTranspose(
                    512,
                    kernel_size=4,
                    strides=1,
                    padding="valid",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 4x4 -> 8x8
        self.upc2 = Sequential(
            [
                Conv2DTranspose(
                    256,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 8x8 -> 16x16
        self.upc3 = Sequential(
            [
                Conv2DTranspose(
                    128,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 16x16 -> 32x32
        self.upc4 = Sequential(
            [
                Conv2DTranspose(
                    64,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                LeakyReLU(alpha=0.2),
            ]
        )
        # 32x32 -> 64x64; sigmoid keeps pixels in [0, 1]. NOTE: the output
        # channel count is hard-coded to 1, `nc` is unused.
        self.upc5 = Sequential(
            [
                Conv2DTranspose(
                    1,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                Activation("sigmoid"),
            ]
        )

    def call(self, inputs):
        vec, skip = inputs
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output
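
# Encoder/decoder round trip (a sketch continuing the Encoder example above;
# the batch size and resolution are arbitrary):
#
#   dec = Decoder(dim=128)
#   x_hat = dec([vec, skips])  # -> (4, 64, 64, 1), values in [0, 1]

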
class MyLSTM(Model):
    """Dense embedding followed by a stack of LSTM cells and a linear head."""

    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(
            hidden_size,
            input_dim=input_shape,
            kernel_initializer=initializer_conv_dense,
        )
        self.lstm = [LSTMCell(hidden_size) for _ in range(self.n_layers)]
        self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
        self.out = Dense(output_size, kernel_initializer=initializer_conv_dense)

    def init_hidden(self, batch_size):
        hidden = []
        for _ in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        # Assign through __dict__ so Keras does not track these zero states
        # as model weights.
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
        # The RNN wrapper runs cell 0 and returns its final [h, c] state,
        # which then seeds the per-cell loop below.
        h_in, *state = self.lstm_rnn(h_in)
        for i in range(self.n_layers):
            h_in, state = self.lstm[i](h_in, state)
        return self.out(h_in)
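
# Illustrative call (a sketch; the sizes are arbitrary):
#
#   rnn = MyLSTM(input_shape=140, hidden_size=256, output_size=128, n_layers=2)
#   rnn(tf.zeros((4, 140))).shape  # -> (4, 128)

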
class MyGaussianLSTM(Model):
    """LSTM head that parameterizes a diagonal Gaussian (mu, logvar) and
    returns a reparameterized sample alongside the parameters."""

    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(
            hidden_size,
            input_dim=input_shape,
            kernel_initializer=initializer_conv_dense,
        )
        self.lstm = [LSTMCell(hidden_size) for _ in range(self.n_layers)]
        self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
        self.mu_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
        self.logvar_net = Dense(output_size, kernel_initializer=initializer_conv_dense)

    def reparameterize(self, mu, logvar: tf.Tensor):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients flow through mu and logvar. tf.shape is used so the
        # sampling also works when the batch dimension is dynamic.
        std = tf.exp(logvar * 0.5)
        eps = tf.random.normal(tf.shape(std))
        return mu + eps * std

    def init_hidden(self, batch_size):
        hidden = []
        for _ in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        # Assign through __dict__ so Keras does not track these zero states
        # as model weights.
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
        h_in, *state = self.lstm_rnn(h_in)
        for i in range(self.n_layers):
            h_in, state = self.lstm[i](h_in, state)

        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
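
# Illustrative usage (a sketch; 258 = g_dim + g_dim + 1 + 1 matches how P2P
# below sizes its posterior input, and the other sizes are arbitrary):
#
#   q = MyGaussianLSTM(input_shape=258, hidden_size=256, output_size=10, n_layers=1)
#   z, mu, logvar = q(tf.zeros((4, 258)))  # z.shape -> (4, 10)

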
class P2P(Model):
    """Point-to-point video prediction: generates frames toward a given
    control-point (target) frame, conditioning each step on latents from a
    prior/posterior pair and on a global descriptor of the control point."""

    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: int = 2,
        skip_prob: float = 0.5,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        # The "+ 1 + 1" terms make room for the two scalar time encodings
        # (time_until_cp and delta_time) appended to every input.
        self.frame_predictor = MyLSTM(
            self.g_dim + self.z_dim + 1 + 1,
            self.rnn_size,
            self.g_dim,
            self.predictor_rnn_layers,
        )

        self.prior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.prior_rnn_layers,
        )

        self.posterior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.posterior_rnn_layers,
        )

        self.encoder = Encoder(self.g_dim, self.channels)
        self.decoder = Decoder(self.g_dim, self.channels)

        self.mse_criterion = tf.keras.losses.MeanSquaredError()
        self.kl_criterion = KLCriterion()
        self.align_criterion = tf.keras.losses.MeanSquaredError()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
        self.cpc_loss_tracker = tf.keras.metrics.Mean(name="cpc_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.align_loss_tracker,
            self.cpc_loss_tracker,
        ]

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Encode the control-point frame x[:, cp_ix] into a global
        descriptor (start_ix is accepted for interface parity but unused)."""
        if cp_ix is None:
            cp_ix = x.shape[1] - 1

        x_cp = x[:, cp_ix, ...]
        h_cp = self.encoder(x_cp)[0]

        return x_cp, h_cp

    def compile(
        self,
        frame_predictor_optimizer,
        prior_optimizer,
        posterior_optimizer,
        encoder_optimizer,
        decoder_optimizer,
    ):
        super().compile()
        self.frame_predictor_optimizer = frame_predictor_optimizer
        self.prior_optimizer = prior_optimizer
        self.posterior_optimizer = posterior_optimizer
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
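
    # Compilation sketch (illustrative; the Adam learning rate is an
    # assumption, not a value taken from this file):
    #
    #   model = P2P()
    #   model.compile(*[tf.keras.optimizers.Adam(2e-4) for _ in range(5)])
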
    def train_step(self, data):
        # `data` unpacks as (y, x); only x, the frame sequence, is used here.
        y, x = data
        # NOTE: the batch size is hard-coded to 100 throughout this file.
        batch_size = 100

        mse_loss = 0
        kld_loss = 0
        cpc_loss = 0
        align_loss = 0

        seq_len = x.shape[1]
        start_ix = 0
        cp_ix = seq_len - 1
        x_cp, global_z = self.get_global_descriptor(x, start_ix, cp_ix)

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(low=0, high=1, size=seq_len - 1)

        with tf.GradientTape(persistent=True) as tape:
            for i in tqdm(range(1, seq_len)):
                # Randomly drop intermediate frames (never the first step or
                # the control point) so the model sees variable step sizes.
                if (
                    probs[i - 1] <= skip_prob
                    and i >= self.n_past
                    and skip_count < max_skip_count
                    and i != 1
                    and i != cp_ix
                ):
                    skip_count += 1
                    continue

                # Relative time encodings fed alongside the frame features.
                time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix)
                delta_time = tf.fill([batch_size, 1], (i - prev_i) / cp_ix)
                prev_i = i

                h = self.encoder(x[:, i - 1, ...])
                h_target = self.encoder(x[:, i, ...])[0]

                if self.last_frame_skip or i <= self.n_past:
                    h, skip = h
                else:
                    h = h[0]

                # Align the previous step's predicted feature with the
                # encoding of the frame it was predicting (h encodes frame
                # i - 1 here).
                if i > 1:
                    align_loss += self.align_criterion(h, h_pred)

                h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=-1
                )

                zt, mu, logvar = self.posterior(h_target_cpaw)
                zt_p, mu_p, logvar_p = self.prior(h_cpaw)

                frame_predictor_input = tf.concat(
                    [h, zt, time_until_cp, delta_time], axis=-1
                )
                h_pred = self.frame_predictor(frame_predictor_input)
                x_pred = self.decoder([h_pred, skip])

                # Control-point consistency: at the last step, decode a frame
                # from the prior's latent and pull it toward the true
                # control-point frame.
                if i == cp_ix:
                    h_pred_p = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = self.mse_criterion(x_pred_p, x_cp)

                mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
                kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))

            loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align

            prior_loss = kld_loss + cpc_loss * self.weight_cpc

        var_list_frame_predictor = self.frame_predictor.trainable_variables
        var_list_posterior = self.posterior.trainable_variables
        var_list_prior = self.prior.trainable_variables
        var_list_encoder = self.encoder.trainable_variables
        var_list_decoder = self.decoder.trainable_variables

        var_list = (
            var_list_frame_predictor
            + var_list_posterior
            + var_list_encoder
            + var_list_decoder
            + var_list_prior
        )

        # Two backward passes over the same persistent tape: the main loss
        # trains the predictor, posterior, encoder, and decoder, while the
        # prior is trained on its own loss.
        gradients = tape.gradient(loss, var_list)
        gradients_prior = tape.gradient(prior_loss, var_list_prior)

        self.update_model(gradients, var_list)
        self.update_prior(gradients_prior, var_list_prior)
        del tape

        self.total_loss_tracker.update_state(loss)
        self.kl_loss_tracker.update_state(kld_loss)
        self.align_loss_tracker.update_state(align_loss)
        self.reconstruction_loss_tracker.update_state(mse_loss)
        self.cpc_loss_tracker.update_state(cpc_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "align_loss": self.align_loss_tracker.result(),
            "cpc_loss": self.cpc_loss_tracker.result(),
        }

    def call(self, inputs, training=None, mask=None):
        # Fixed generation settings: roll out 20 frames with the full model
        # (posterior for context frames, prior afterwards); frame dropping is
        # disabled at evaluation time.
        len_output = 20
        eval_cp_ix = len_output - 1
        start_ix = 0
        cp_ix = -1
        model_mode = "full"
        skip_frame = False
        init_hidden = True

        batch_size, num_frames, h, w, channels = inputs.shape
        dim_shape = (h, w, channels)

        gen_seq = [inputs[:, 0, ...]]
        x_in = inputs[:, 0, ...]

        seq_len = inputs.shape[1]
        cp_ix = seq_len - 1

        x_cp, global_z = self.get_global_descriptor(inputs, cp_ix=cp_ix)

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, len_output - 1)

        for i in range(1, len_output):
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != (len_output - 1)
                and skip_frame
            ):
                skip_count += 1
                gen_seq.append(tf.zeros_like(x_in))
                continue

            time_until_cp = tf.fill([batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
            delta_time = tf.fill([batch_size, 1], (i - prev_i) / eval_cp_ix)
            prev_i = i

            h = self.encoder(x_in)

            if self.last_frame_skip or i == 1 or i < self.n_past:
                h, skip = h
            else:
                h, _ = h

            h_cpaw = tf.stop_gradient(
                tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)
            )

            if i < self.n_past:
                # Context frames: feed ground truth through the posterior and
                # the frame predictor (to advance their state), then copy the
                # real frame into the output sequence.
                h_target = self.encoder(inputs[:, i, ...])[0]
                h_target_cpaw = tf.stop_gradient(
                    tf.concat(
                        [h_target, global_z, time_until_cp, delta_time], axis=-1
                    )
                )

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior" or model_mode == "full":
                    self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
                    )
                elif model_mode == "prior":
                    self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
                    )

                x_in = inputs[:, i, ...]
                gen_seq.append(x_in)
            else:
                # Generation: sample a latent (posterior while ground truth is
                # available, prior otherwise) and decode the predicted frame.
                if i < num_frames:
                    h_target = self.encoder(inputs[:, i, ...])[0]
                    h_target_cpaw = tf.stop_gradient(
                        tf.concat(
                            [h_target, global_z, time_until_cp, delta_time], axis=-1
                        )
                    )
                else:
                    h_target_cpaw = h_cpaw

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior":
                    h = self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
                    )
                elif model_mode == "prior" or model_mode == "full":
                    h = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
                    )

                x_in = tf.stop_gradient(self.decoder([h, skip]))
                gen_seq.append(x_in)

        return tf.stack(gen_seq, axis=1)

    def update_model(self, gradients, var_list):
        # `gradients` is ordered like `var_list` (frame predictor, posterior,
        # encoder, decoder, prior), so slice it into per-submodule chunks and
        # let each optimizer update only its own variables. The prior chunk
        # is skipped here; `update_prior` applies its dedicated gradients.
        start = 0
        for module, optimizer in (
            (self.frame_predictor, self.frame_predictor_optimizer),
            (self.posterior, self.posterior_optimizer),
            (self.encoder, self.encoder_optimizer),
            (self.decoder, self.decoder_optimizer),
        ):
            n = len(module.trainable_variables)
            optimizer.apply_gradients(
                zip(gradients[start : start + n], var_list[start : start + n])
            )
            start += n

    def update_prior(self, gradients, var_list):
        self.prior_optimizer.apply_gradients(zip(gradients, var_list))
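
# End-to-end training sketch (illustrative; the random data, sequence length,
# and 64x64 resolution are assumptions, while the batch size of 100 is the
# value hard-coded above):
#
#   model = P2P()
#   model.compile(*[tf.keras.optimizers.Adam(2e-4) for _ in range(5)])
#   frames = tf.random.uniform((100, 20, 64, 64, 1))  # (batch, time, h, w, c)
#   metrics = model.train_step((frames, frames))      # data unpacks as (y, x)
#   video = model(frames)                             # rolls out 20 frames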