# Source: GANime, ganime/model/p2p/p2p_v2.py
# Uploaded by Kurokabe ("Upload 84 files", commit 3be620b).
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
LSTM,
Activation,
BatchNormalization,
Conv2D,
Conv2DTranspose,
Conv3D,
Conv3DTranspose,
Dense,
Flatten,
Input,
Layer,
LeakyReLU,
MaxPooling2D,
Reshape,
TimeDistributed,
UpSampling2D,
)
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import KLDivergence, MeanSquaredError
from tqdm.auto import tqdm
class KLCriterion(Loss):
    """Element-wise KL divergence between two diagonal Gaussians.

    ``call`` receives each distribution as a ``(mu, logvar)`` pair and returns
    ``KL(N(mu_1, sigma2_1) || N(mu_2, sigma2_2))`` per element. Any reduction
    is left to the ``reduction`` configured on this ``Loss``.
    """

    def call(self, y_true, y_pred):
        """Compute KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2))).

        Args:
            y_true: ``(mu1, logvar1)`` of the first distribution (the one the
                expectation is taken under).
            y_pred: ``(mu2, logvar2)`` of the second distribution.

        Returns:
            Element-wise KL terms, with the inputs broadcast together.
        """
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred
        # Work directly with log-variances: log(sigma2 / sigma1) equals
        # 0.5 * (logvar2 - logvar1) and sigma1^2 / sigma2^2 equals
        # exp(logvar1 - logvar2). This avoids materialising sigma1, sigma2 and
        # their ratio, which under/overflow (0/0, log(inf)) for extreme logvars.
        log_sigma_ratio = 0.5 * (logvar2 - logvar1)
        var_ratio = tf.exp(logvar1 - logvar2)
        mean_term = tf.square(mu1 - mu2) * tf.exp(-logvar2)
        return log_sigma_ratio + 0.5 * (var_ratio + mean_term) - 0.5
class Decoder(Model):
    """Decode a latent vector plus encoder skip activations into frames.

    A 4x4 "valid" transposed convolution expands the 1x1 latent to 4x4. Four
    stride-2 transposed convolutions then upsample 4 -> 8 -> 16 -> 32 -> 64.
    Before each of those four stages, one skip tensor is concatenated on the
    channel axis. Every layer is wrapped in ``TimeDistributed``.

    Args:
        dim: Size of the latent vector (last axis of ``vec`` in ``call``).
        nc: Number of channels of the generated frames.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.nc = nc
        self.upc1 = self._up_block(512, strides=1, padding="valid")
        self.upc2 = self._up_block(256, strides=2, padding="same")
        self.upc3 = self._up_block(128, strides=2, padding="same")
        self.upc4 = self._up_block(64, strides=2, padding="same")
        # Last stage: no normalisation; the sigmoid keeps outputs in (0, 1).
        # It uses `nc` filters (the output channel count) instead of a
        # hard-coded 1.
        self.upc5 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(nc, kernel_size=4, strides=2, padding="same")
                ),
                Activation("sigmoid"),
            ]
        )

    @staticmethod
    def _up_block(filters, strides, padding):
        """Build one TimeDistributed Conv2DTranspose -> BatchNorm -> LeakyReLU stage."""
        return Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(
                        filters, kernel_size=4, strides=strides, padding=padding
                    )
                ),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )

    def call(self, input):
        """Decode ``(vec, skip)`` into frames.

        Args:
            input: Pair ``(vec, skip)``. ``vec`` is reshaped to
                ``(-1, 1, 1, 1, dim)``, i.e. one time step per sample.
                ``skip[3]``, ``skip[2]``, ``skip[1]`` and ``skip[0]`` are
                concatenated (last axis) before ``upc2``, ``upc3``, ``upc4``
                and ``upc5`` respectively. Presumably these are the encoder
                activations at 4x4, 8x8, 16x16 and 32x32; verify against
                the caller.

        Returns:
            Output of the sigmoid stage, with ``nc`` channels.
        """
        vec, skip = input
        h = self.upc1(tf.reshape(vec, (-1, 1, 1, 1, self.dim)))
        h = self.upc2(tf.concat([h, skip[3]], axis=-1))
        h = self.upc3(tf.concat([h, skip[2]], axis=-1))
        h = self.upc4(tf.concat([h, skip[1]], axis=-1))
        return self.upc5(tf.concat([h, skip[0]], axis=-1))
class Sampling(Layer):
    """Draw ``z = z_mean + exp(0.5 * z_log_var) * eps`` with ``eps ~ N(0, 1)``.

    This is the reparameterisation trick: the randomness lives in ``eps``, so
    gradients flow through ``z_mean`` and ``z_log_var``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw noise with exactly the shape and dtype of the mean. Inputs of
        # any rank are supported, not only (batch, dim). Non-float32 means do
        # not clash with random_normal's floatx default.
        epsilon = tf.keras.backend.random_normal(
            shape=tf.shape(z_mean), dtype=z_mean.dtype
        )
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        # The sample has the same shape as z_mean.
        return input_shape[0]
class P2P(Model):
    """Frame-by-frame video generator conditioned on a first and a last frame.

    At each step, the model:

    1. Encodes the previous frame and the last frame.
    2. Samples a latent ``z`` from a Gaussian LSTM (``prior`` in ``call``,
       ``posterior`` in ``train_step``) fed with both encodings.
    3. Uses ``frame_predictor`` to map the previous encoding plus ``z`` to the
       next encoding.
    4. Decodes that encoding back into a frame with ``Decoder``, using the
       encoder's skip activations.

    Args:
        channels: Channels of the input and generated frames.
        g_dim: Size of the per-frame encoding.
        z_dim: Size of the stochastic latent.
        rnn_size: Width of the dense embedding and of every LSTM layer.
        prior_rnn_layers: Number of stacked LSTM layers in ``prior``.
        posterior_rnn_layers: Number of stacked LSTM layers in ``posterior``.
        predictor_rnn_layers: Number of stacked LSTM layers in
            ``frame_predictor``.
        skip_prob: Stored only; not read by this class.
        n_past: Used by ``_take_skip`` to decide which steps refresh the
            skip activations.
        last_frame_skip: If True, every step uses the skips of the frame it
            just encoded.
        beta: Weight of the KL term in the main loss.
        weight_align: Weight of the latent alignment (MSE) loss.
        weight_cpc: Weight of ``cpc_loss``, the last-step reconstruction made
            from a prior sample.
        desired_length: Number of frames returned by ``call``, including the
            first input frame.
    """

    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: int = 1,
        skip_prob: float = 0.1,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
        desired_length: int = 20,
    ):
        super().__init__()
        # Model parameters
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers
        # Training parameters
        self.skip_prob = skip_prob  # NOTE(review): not read anywhere in this class.
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc
        # Inference parameters
        self.desired_length = desired_length

        self.frame_predictor = self.build_lstm()
        self.prior = self.build_gaussian_lstm(num_layers=prior_rnn_layers)
        self.posterior = self.build_gaussian_lstm(num_layers=posterior_rnn_layers)
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
        self.cpc_loss_tracker = tf.keras.metrics.Mean(name="cpc_loss")

        # Reduction NONE: train_step sums the raw values and scales them by
        # the global batch size itself.
        self.kl_loss = KLCriterion(reduction=tf.keras.losses.Reduction.NONE)
        self.mse = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        self.align_loss = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

    # region Model building
    def _stacked_lstm(self, x, num_layers):
        """Apply ``num_layers`` LSTMs of width ``rnn_size``; return the last state.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        num_layers = int(num_layers)
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # All but the last layer must return full sequences to feed the next.
        for _ in range(num_layers - 1):
            x = LSTM(self.rnn_size, return_sequences=True)(x)
        return LSTM(self.rnn_size)(x)

    def build_lstm(self):
        """Build the frame predictor.

        Input has shape (batch, time, g_dim + z_dim). Output is a single tensor
        of shape (batch, 1, g_dim).
        """
        inputs = Input(shape=(None, self.g_dim + self.z_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(inputs)
        hidden = self._stacked_lstm(embed, self.predictor_rnn_layers)
        output = Dense(self.g_dim)(hidden)
        # Return a plain tensor, not a 1-tuple, so callers get (batch, 1, g_dim).
        output = tf.expand_dims(output, axis=1)
        return Model(inputs=inputs, outputs=output, name="frame_predictor")

    def build_gaussian_lstm(self, num_layers: int = 1):
        """Build an LSTM that parameterises a diagonal Gaussian.

        Input has shape (batch, time, g_dim). Outputs are ``[mu, logvar, z]``,
        each (batch, z_dim), where ``z`` is drawn from ``N(mu, exp(logvar))``
        by ``Sampling``.
        """
        inputs = Input(shape=(None, self.g_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(inputs)
        hidden = self._stacked_lstm(embed, num_layers)
        mu = Dense(self.z_dim)(hidden)
        logvar = Dense(self.z_dim)(hidden)
        z = Sampling()([mu, logvar])
        return Model(inputs=inputs, outputs=[mu, logvar, z])

    def build_encoder(self):
        """Build the frame encoder.

        Input is one 64x64 frame per sample, with shape
        (batch, 1, 64, 64, channels). Outputs are ``[encoding, skips]``, where:

        - ``encoding`` has shape (batch, 1, g_dim);
        - ``skips`` are the four LeakyReLU activations at 32x32x64,
          16x16x128, 8x8x256 and 4x4x512.
        """
        inputs = Input(shape=(1, 64, 64, self.channels))
        h = inputs
        skips = []
        for filters in (64, 128, 256, 512):
            h = TimeDistributed(
                Conv2D(filters, kernel_size=4, strides=2, padding="same")
            )(h)
            h = BatchNormalization()(h)
            h = LeakyReLU(alpha=0.2)(h)
            skips.append(h)
        # A 4x4 "valid" convolution collapses the 4x4 map to 1x1 x g_dim.
        h = TimeDistributed(
            Conv2D(self.g_dim, kernel_size=4, strides=1, padding="valid")
        )(h)
        h = BatchNormalization()(h)
        h = Activation("tanh")(h)
        output = tf.reshape(h, (-1, 1, self.g_dim))
        return Model(inputs=inputs, outputs=[output, skips], name="encoder")

    def build_decoder(self):
        """Build the decoder that turns g_dim encodings into ``channels``-channel frames."""
        return Decoder(self.g_dim, nc=self.channels)

    # endregion

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch start.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.align_loss_tracker,
            self.cpc_loss_tracker,
        ]

    def _take_skip(self, step):
        """Return True if the skips of the frame encoded at ``step`` should be used.

        ``call`` and ``train_step`` share this rule so that training and
        inference choose skip connections identically. Step 1 (the first
        frame) always provides skips.
        """
        return self.last_frame_skip or step == 1 or step < self.n_past

    def call(self, inputs, training=None, mask=None):
        """Generate ``desired_length`` frames from the first and last input frames.

        Args:
            inputs: Tensor whose axis 1 is time. Only ``inputs[:, :1]`` and
                ``inputs[:, -1:]`` are read. Presumably
                (batch, time, 64, 64, channels) to match the encoder; confirm.
            training: Forwarded to the sub-models.
            mask: Unused.

        Returns:
            The first frame followed by the generated frames, concatenated on
            axis 1. Each new frame is sampled from the prior and fed back in
            as the next "previous" frame.
        """
        first_frame = inputs[:, 0:1, ...]
        last_frame = inputs[:, -1:, ...]
        z_last, _ = self.encoder(last_frame, training=training)

        previous_frame = first_frame
        generated = [first_frame]
        skip = None  # always set at step 1 (see _take_skip)
        for i in range(1, self.desired_length):
            z_prev, prev_skip = self.encoder(previous_frame, training=training)
            if self._take_skip(i):
                skip = prev_skip
            prior_input = tf.concat([z_prev, z_last], axis=1)
            _, _, z_prior = self.prior(prior_input, training=training)
            predictor_input = tf.concat(
                [z_prev, tf.expand_dims(z_prior, axis=1)], axis=-1
            )
            z_pred = self.frame_predictor(predictor_input, training=training)
            current_frame = self.decoder([z_pred, skip], training=training)
            generated.append(current_frame)
            previous_frame = current_frame
        return tf.concat(generated, axis=1)

    def train_step(self, data):
        """Run one teacher-forced optimisation step.

        Args:
            data: ``(x, y)``. ``x[:, :1]`` is the first frame and ``x[:, -1:]``
                the conditioning last frame. ``y`` is the ground-truth
                sequence. Its static axis-1 length sets the number of steps,
                and its frames serve both as targets and as the next
                "previous" frame.

        Returns:
            Dict with the running means of the tracked losses.
        """
        x, y = data
        first_frame = x[:, 0:1, ...]
        last_frame = x[:, -1:, ...]
        seq_len = y.shape[1]

        # Scale summed losses by the real global batch size (per-replica
        # batch times replica count), not a hard-coded constant. Gradients
        # summed across replicas then average correctly.
        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
        global_batch_size = tf.shape(x)[0] * num_replicas

        def per_sample_sum(values):
            return tf.reduce_sum(values) / tf.cast(global_batch_size, values.dtype)

        reconstruction_loss = 0.0
        kl_loss = 0.0
        align_loss = 0.0
        cpc_loss = 0.0
        # Persistent: two gradients (main loss and prior loss) are taken.
        with tf.GradientTape(persistent=True) as tape:
            # training=True is required: without it BatchNormalization runs in
            # inference mode inside a custom train_step.
            z_last, _ = self.encoder(last_frame, training=True)
            previous_frame = first_frame
            skip = None  # always set at step 1 (see _take_skip)
            for i in range(1, seq_len):
                current_frame = y[:, i : i + 1, ...]
                z_prev, prev_skip = self.encoder(previous_frame, training=True)
                if self._take_skip(i):
                    skip = prev_skip
                z_curr, _ = self.encoder(current_frame, training=True)

                prior_input = tf.concat([z_prev, z_last], axis=1)
                posterior_input = tf.concat([z_curr, z_last], axis=1)
                z_mean_prior, z_log_var_prior, z_prior = self.prior(
                    prior_input, training=True
                )
                z_mean_post, z_log_var_post, z_post = self.posterior(
                    posterior_input, training=True
                )

                z_pred = self.frame_predictor(
                    tf.concat([z_prev, tf.expand_dims(z_post, axis=1)], axis=-1),
                    training=True,
                )

                # KLCriterion computes KL(first || second). The ELBO term is
                # KL(posterior || prior), so the posterior goes first.
                kl_loss += per_sample_sum(
                    self.kl_loss(
                        (z_mean_post, z_log_var_post),
                        (z_mean_prior, z_log_var_prior),
                    )
                )
                if i > 1:
                    align_loss += per_sample_sum(self.align_loss(z_pred, z_curr))
                if i == seq_len - 1:
                    # Final step only: decode from the prior sample and
                    # compare it with the final ground-truth frame.
                    h_pred_p = self.frame_predictor(
                        tf.concat(
                            [z_prev, tf.expand_dims(z_prior, axis=1)], axis=-1
                        ),
                        training=True,
                    )
                    x_pred_p = self.decoder([h_pred_p, skip], training=True)
                    cpc_loss = per_sample_sum(self.mse(x_pred_p, current_frame))

                prediction = self.decoder([z_pred, skip], training=True)
                reconstruction_loss += per_sample_sum(
                    self.mse(prediction, current_frame)
                )
                # Teacher forcing: the next step starts from the ground truth.
                previous_frame = current_frame

            loss = (
                reconstruction_loss
                + kl_loss * self.beta
                + align_loss * self.weight_align
                + cpc_loss * self.weight_cpc
            )
            prior_loss = kl_loss + cpc_loss * self.weight_cpc

        main_vars = (
            self.encoder.trainable_weights
            + self.decoder.trainable_weights
            + self.posterior.trainable_weights
            + self.frame_predictor.trainable_weights
        )
        prior_vars = self.prior.trainable_weights
        main_grads = tape.gradient(loss, main_vars)
        prior_grads = tape.gradient(prior_loss, prior_vars)
        del tape
        # A single apply_gradients call advances optimizer.iterations once per
        # step. It also works with optimizers that must see every variable
        # they update in their first call.
        self.optimizer.apply_gradients(
            zip(main_grads + prior_grads, main_vars + prior_vars)
        )

        self.total_loss_tracker.update_state(loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.align_loss_tracker.update_state(align_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.cpc_loss_tracker.update_state(cpc_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "align_loss": self.align_loss_tracker.result(),
            "cpc_loss": self.cpc_loss_tracker.result(),
        }
# print("KL_LOSS")
# print(kl_loss)
# print("ALIGN_LOSS")
# print(align_loss)
# print("RECONSTRUCTION_LOSS")
# print(reconstruction_loss)
# with tf.GradientTape() as tape:
# z_mean, z_log_var, z = self.encoder(x)
# reconstruction = self.decoder(z)
# reconstruction_loss = tf.reduce_mean(
# tf.reduce_sum(
# tf.keras.losses.binary_crossentropy(y, reconstruction),
# axis=(1, 2),
# )
# )
# kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
# kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
# total_loss = reconstruction_loss + self.kl_beta * kl_loss
# grads = tape.gradient(total_loss, self.trainable_weights)
# self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# self.total_loss_tracker.update_state(total_loss)
# self.reconstruction_loss_tracker.update_state(reconstruction_loss)
# self.kl_loss_tracker.update_state(kl_loss)
# return {
# "loss": self.total_loss_tracker.result(),
# "reconstruction_loss": self.reconstruction_loss_tracker.result(),
# "kl_loss": self.kl_loss_tracker.result(),
# }
# def test_step(self, data):
# if isinstance(data, tuple):
# data = data[0]
# z_mean, z_log_var, z = self.encoder(data)
# reconstruction = self.decoder(z)
# reconstruction_loss = tf.reduce_mean(
# tf.keras.losses.binary_crossentropy(data, reconstruction)
# )
# reconstruction_loss *= 28 * 28
# kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
# kl_loss = tf.reduce_mean(kl_loss)
# kl_loss *= -0.5
# total_loss = reconstruction_loss + kl_loss
# return {
# "loss": total_loss,
# "reconstruction_loss": reconstruction_loss,
# "kl_loss": kl_loss,
# }