GANime / ganime /model /p2p /p2p_test.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
23.3 kB
from tqdm.auto import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
LSTM,
LSTMCell,
Activation,
BatchNormalization,
Conv2D,
Conv2DTranspose,
Conv3D,
Conv3DTranspose,
Dense,
Flatten,
Input,
Layer,
LeakyReLU,
MaxPooling2D,
Reshape,
TimeDistributed,
UpSampling2D,
)
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import KLDivergence, MeanSquaredError
# from tensorflow_probability.python.layers.dense_variational import (
# DenseReparameterization,
# )
# import tensorflow_probability as tfp
from tensorflow.keras.losses import Loss
initializer_conv_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
initializer_batch_norm = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)
class KLCriterion(Loss):
def call(self, y_true, y_pred):
(mu1, logvar1), (mu2, logvar2) = y_true, y_pred
"""KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))
kld = (
tf.math.log(sigma2 / sigma1)
+ (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
- 0.5
)
return tf.reduce_sum(kld) / 100
class Encoder(Model):
def __init__(self, dim, nc=1):
super().__init__()
self.dim = dim
self.c1 = Sequential(
[
Conv2D(
64,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.c2 = Sequential(
[
Conv2D(
128,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.c3 = Sequential(
[
Conv2D(
256,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.c4 = Sequential(
[
Conv2D(
512,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.c5 = Sequential(
[
Conv2D(
self.dim,
kernel_size=4,
strides=1,
padding="valid",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
Activation("tanh"),
]
)
def call(self, input):
h1 = self.c1(input)
h2 = self.c2(h1)
h3 = self.c3(h2)
h4 = self.c4(h3)
h5 = self.c5(h4)
return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]
class Decoder(Model):
def __init__(self, dim, nc=1):
super().__init__()
self.dim = dim
self.upc1 = Sequential(
[
Conv2DTranspose(
512,
kernel_size=4,
strides=1,
padding="valid",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.upc2 = Sequential(
[
Conv2DTranspose(
256,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.upc3 = Sequential(
[
Conv2DTranspose(
128,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.upc4 = Sequential(
[
Conv2DTranspose(
64,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
# BatchNormalization(),
LeakyReLU(alpha=0.2),
]
)
self.upc5 = Sequential(
[
Conv2DTranspose(
1,
kernel_size=4,
strides=2,
padding="same",
kernel_initializer=initializer_conv_dense,
),
Activation("sigmoid"),
]
)
def call(self, input):
vec, skip = input
d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
return output
class MyLSTM(Model):
def __init__(self, input_shape, hidden_size, output_size, n_layers):
super().__init__()
self.hidden_size = hidden_size
self.n_layers = n_layers
self.embed = Dense(
hidden_size,
input_dim=input_shape,
kernel_initializer=initializer_conv_dense,
)
# self.lstm = Sequential(
# [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
# )
# self.lstm = self.create_lstm(hidden_size, n_layers)
self.lstm = [
LSTMCell(
hidden_size # , return_sequences=False if i == self.n_layers - 1 else True
)
for i in range(self.n_layers)
] # LSTMCell(hidden_size)
self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
self.out = Dense(output_size, kernel_initializer=initializer_conv_dense)
def init_hidden(self, batch_size):
hidden = []
for i in range(self.n_layers):
hidden.append(
(
tf.Variable(tf.zeros([batch_size, self.hidden_size])),
tf.Variable(tf.zeros([batch_size, self.hidden_size])),
)
)
self.__dict__["hidden"] = hidden
def build(self, input_shape):
self.init_hidden(input_shape[0])
def call(self, inputs):
h_in = self.embed(inputs)
h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
h_in, *state = self.lstm_rnn(h_in)
for i in range(self.n_layers):
h_in, state = self.lstm[i](h_in, state)
return self.out(h_in)
class MyGaussianLSTM(Model):
def __init__(self, input_shape, hidden_size, output_size, n_layers):
super().__init__()
self.hidden_size = hidden_size
self.n_layers = n_layers
self.embed = Dense(
hidden_size,
input_dim=input_shape,
kernel_initializer=initializer_conv_dense,
)
# self.lstm = Sequential(
# [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
# )
self.lstm = [
LSTMCell(
hidden_size # , return_sequences=False if i == self.n_layers - 1 else True
)
for i in range(self.n_layers)
] # LSTMCell(hidden_size)
self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
self.mu_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
self.logvar_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
# self.out = Sequential(
# [
# tf.keras.layers.Dense(
# tfp.layers.MultivariateNormalTriL.params_size(output_size),
# activation=None,
# ),
# tfp.layers.MultivariateNormalTriL(output_size),
# ]
# )
def reparameterize(self, mu, logvar: tf.Tensor):
logvar = tf.math.exp(logvar * 0.5)
eps = tf.random.normal(logvar.shape)
return tf.add(tf.math.multiply(eps, logvar), mu)
def init_hidden(self, batch_size):
hidden = []
for i in range(self.n_layers):
hidden.append(
(
tf.Variable(tf.zeros([batch_size, self.hidden_size])),
tf.Variable(tf.zeros([batch_size, self.hidden_size])),
)
)
self.__dict__["hidden"] = hidden
def build(self, input_shape):
self.init_hidden(input_shape[0])
def call(self, inputs):
h_in = self.embed(inputs)
# for i in range(self.n_layers):
# # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape)
# _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
# h_in = self.hidden[i][0]
h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
h_in, *state = self.lstm_rnn(h_in)
for i in range(self.n_layers):
h_in, state = self.lstm[i](h_in, state)
mu = self.mu_net(h_in)
logvar = self.logvar_net(h_in)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
class P2P(Model):
def __init__(
self,
channels: int = 1,
g_dim: int = 128,
z_dim: int = 10,
rnn_size: int = 256,
prior_rnn_layers: int = 1,
posterior_rnn_layers: int = 1,
predictor_rnn_layers: float = 2,
skip_prob: float = 0.5,
n_past: int = 1,
last_frame_skip: bool = False,
beta: float = 0.0001,
weight_align: float = 0.1,
weight_cpc: float = 100,
):
super().__init__()
self.channels = channels
self.g_dim = g_dim
self.z_dim = z_dim
self.rnn_size = rnn_size
self.prior_rnn_layers = prior_rnn_layers
self.posterior_rnn_layers = posterior_rnn_layers
self.predictor_rnn_layers = predictor_rnn_layers
self.skip_prob = skip_prob
self.n_past = n_past
self.last_frame_skip = last_frame_skip
self.beta = beta
self.weight_align = weight_align
self.weight_cpc = weight_cpc
self.frame_predictor = MyLSTM(
self.g_dim + self.z_dim + 1 + 1,
self.rnn_size,
self.g_dim,
self.predictor_rnn_layers,
)
self.prior = MyGaussianLSTM(
self.g_dim + self.g_dim + 1 + 1,
self.rnn_size,
self.z_dim,
self.prior_rnn_layers,
)
self.posterior = MyGaussianLSTM(
self.g_dim + self.g_dim + 1 + 1,
self.rnn_size,
self.z_dim,
self.posterior_rnn_layers,
)
self.encoder = Encoder(self.g_dim, self.channels)
self.decoder = Decoder(self.g_dim, self.channels)
# criterions
self.mse_criterion = tf.keras.losses.MeanSquaredError()
self.kl_criterion = KLCriterion()
self.align_criterion = tf.keras.losses.MeanSquaredError()
self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
name="reconstruction_loss"
)
self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
self.cpc_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
# optimizers
# self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
# learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
# )
# self.posterior_optimizer = tf.keras.optimizers.Adam(
# learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
# )
# self.prior_optimizer = tf.keras.optimizers.Adam(
# learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
# )
# self.encoder_optimizer = tf.keras.optimizers.Adam(
# learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
# )
# self.decoder_optimizer = tf.keras.optimizers.Adam(
# learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
# )
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
self.align_loss_tracker,
self.cpc_loss_tracker,
]
def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
"""Get the global descriptor based on x, start_ix, cp_ix."""
if cp_ix is None:
cp_ix = x.shape[1] - 1
x_cp = x[:, cp_ix, ...]
h_cp = self.encoder(x_cp)[0] # 1 is input for skip-connection
return x_cp, h_cp
def compile(
self,
frame_predictor_optimizer,
prior_optimizer,
posterior_optimizer,
encoder_optimizer,
decoder_optimizer,
):
super().compile()
self.frame_predictor_optimizer = frame_predictor_optimizer
self.prior_optimizer = prior_optimizer
self.posterior_optimizer = posterior_optimizer
self.encoder_optimizer = encoder_optimizer
self.decoder_optimizer = decoder_optimizer
def train_step(self, data):
y, x = data
batch_size = 100
mse_loss = 0
kld_loss = 0
cpc_loss = 0
align_loss = 0
seq_len = x.shape[1]
start_ix = 0
cp_ix = seq_len - 1
x_cp, global_z = self.get_global_descriptor(
x, start_ix, cp_ix
) # here global_z is h_cp
skip_prob = self.skip_prob
prev_i = 0
max_skip_count = seq_len * skip_prob
skip_count = 0
probs = np.random.uniform(low=0, high=1, size=seq_len - 1)
with tf.GradientTape(persistent=True) as tape:
for i in tqdm(range(1, seq_len)):
if (
probs[i - 1] <= skip_prob
and i >= self.n_past
and skip_count < max_skip_count
and i != 1
and i != cp_ix
):
skip_count += 1
continue
if i > 1:
align_loss += self.align_criterion(h, h_pred)
time_until_cp = tf.fill(
[batch_size, 1],
(cp_ix - i + 1) / cp_ix,
)
delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix))
prev_i = i
h = self.encoder(x[:, i - 1, ...])
h_target = self.encoder(x[:, i, ...])[0]
if self.last_frame_skip or i <= self.n_past:
h, skip = h
else:
h = h[0]
# Control Point Aware
h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)
h_target_cpaw = tf.concat(
[h_target, global_z, time_until_cp, delta_time], axis=-1
)
zt, mu, logvar = self.posterior(h_target_cpaw)
zt_p, mu_p, logvar_p = self.prior(h_cpaw)
frame_predictor_input = tf.concat(
[h, zt, time_until_cp, delta_time], axis=-1
)
h_pred = self.frame_predictor(frame_predictor_input)
x_pred = self.decoder([h_pred, skip])
if i == cp_ix: # the gen-cp-frame should be exactly as x_cp
h_pred_p = self.frame_predictor(
tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
)
x_pred_p = self.decoder([h_pred_p, skip])
cpc_loss = self.mse_criterion(x_pred_p, x_cp)
mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))
# backward
loss = (
mse_loss
+ kld_loss * self.beta
+ align_loss * self.weight_align
# + cpc_loss * self.weight_cpc
)
prior_loss = kld_loss + cpc_loss * self.weight_cpc
var_list_frame_predictor = self.frame_predictor.trainable_variables
var_list_posterior = self.posterior.trainable_variables
var_list_prior = self.prior.trainable_variables
var_list_encoder = self.encoder.trainable_variables
var_list_decoder = self.decoder.trainable_variables
# mse: frame_predictor + decoder
# align: frame_predictor + encoder
# kld: posterior + prior + encoder
var_list = (
var_list_frame_predictor
+ var_list_posterior
+ var_list_encoder
+ var_list_decoder
+ var_list_prior
)
gradients = tape.gradient(
loss,
var_list,
)
gradients_prior = tape.gradient(
prior_loss,
var_list_prior,
)
self.update_model(
gradients,
var_list,
)
self.update_prior(gradients_prior, var_list_prior)
del tape
self.total_loss_tracker.update_state(loss)
self.kl_loss_tracker.update_state(kld_loss)
self.align_loss_tracker.update_state(align_loss)
self.reconstruction_loss_tracker.update_state(mse_loss)
self.cpc_loss_tracker.update_state(cpc_loss)
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
"align_loss": self.align_loss_tracker.result(),
"cpc_loss": self.cpc_loss_tracker.result(),
}
def call(
self,
inputs,
training=None,
mask=None
# len_output,
# eval_cp_ix,
# start_ix=0,
# cp_ix=-1,
# model_mode="full",
# skip_frame=False,
# init_hidden=True,
):
len_output = 20
eval_cp_ix = len_output - 1
start_ix = 0
cp_ix = -1
model_mode = "full"
skip_frame = False
init_hidden = True
batch_size, num_frames, h, w, channels = inputs.shape
dim_shape = (h, w, channels)
gen_seq = [inputs[:, 0, ...]]
x_in = inputs[:, 0, ...]
seq_len = inputs.shape[1]
cp_ix = seq_len - 1
x_cp, global_z = self.get_global_descriptor(
inputs, cp_ix=cp_ix
) # here global_z is h_cp
skip_prob = self.skip_prob
prev_i = 0
max_skip_count = seq_len * skip_prob
skip_count = 0
probs = np.random.uniform(0, 1, len_output - 1)
for i in range(1, len_output):
if (
probs[i - 1] <= skip_prob
and i >= self.n_past
and skip_count < max_skip_count
and i != 1
and i != (len_output - 1)
and skip_frame
):
skip_count += 1
gen_seq.append(tf.zeros_like(x_in))
continue
time_until_cp = tf.fill([100, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
delta_time = tf.fill([100, 1], ((i - prev_i) / eval_cp_ix))
prev_i = i
h = self.encoder(x_in)
if self.last_frame_skip or i == 1 or i < self.n_past:
h, skip = h
else:
h, _ = h
h_cpaw = tf.stop_gradient(tf.concat([h, global_z, time_until_cp, delta_time], axis=-1))
if i < self.n_past:
h_target = self.encoder(inputs[:, i, ...])[0]
h_target_cpaw = tf.stop_gradient(tf.concat(
[h_target, global_z, time_until_cp, delta_time], axis=1
))
zt, _, _ = self.posterior(h_target_cpaw)
zt_p, _, _ = self.prior(h_cpaw)
if model_mode == "posterior" or model_mode == "full":
self.frame_predictor(
tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
)
elif model_mode == "prior":
self.frame_predictor(
tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
)
x_in = inputs[:, i, ...]
gen_seq.append(x_in)
else:
if i < num_frames:
h_target = self.encoder(inputs[:, i, ...])[0]
h_target_cpaw = tf.stop_gradient(tf.concat(
[h_target, global_z, time_until_cp, delta_time], axis=-1
))
else:
h_target_cpaw = h_cpaw
zt, _, _ = self.posterior(h_target_cpaw)
zt_p, _, _ = self.prior(h_cpaw)
if model_mode == "posterior":
h = self.frame_predictor(
tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
)
elif model_mode == "prior" or model_mode == "full":
h = self.frame_predictor(
tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
)
x_in = tf.stop_gradient(self.decoder([h, skip]))
gen_seq.append(x_in)
return tf.stack(gen_seq, axis=1)
def update_model(self, gradients, var_list):
self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list))
self.posterior_optimizer.apply_gradients(zip(gradients, var_list))
self.encoder_optimizer.apply_gradients(zip(gradients, var_list))
self.decoder_optimizer.apply_gradients(zip(gradients, var_list))
#self.prior_optimizer.apply_gradients(zip(gradients, var_list))
def update_prior(self, gradients, var_list):
self.prior_optimizer.apply_gradients(zip(gradients, var_list))
# def update_model_without_prior(self):
# self.frame_predictor_optimizer.step()
# self.posterior_optimizer.step()
# self.encoder_optimizer.step()
# self.decoder_optimizer.step()