import os
import warnings

warnings.filterwarnings('ignore')

from abc import ABC, abstractmethod

import numpy as np
import joblib
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (
    Conv1D, Flatten, Dense, Conv1DTranspose, Reshape, Input, Layer,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Mean
from tensorflow.keras.backend import random_normal


class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the latent vector encoding the input,
    via the reparameterization trick."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class BaseVariationalAutoencoder(Model, ABC):

    def __init__(self, seq_len, feat_dim, latent_dim, reconstruction_wt=3.0, **kwargs):
        super(BaseVariationalAutoencoder, self).__init__(**kwargs)
        self.seq_len = seq_len
        self.feat_dim = feat_dim
        self.latent_dim = latent_dim
        self.reconstruction_wt = reconstruction_wt
        self.total_loss_tracker = Mean(name="total_loss")
        self.reconstruction_loss_tracker = Mean(name="reconstruction_loss")
        self.kl_loss_tracker = Mean(name="kl_loss")
        self.encoder = None
        self.decoder = None

    @property
    def metrics(self):
        # listing the trackers here lets Keras reset them at the start of each epoch
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, X):
        z_mean, _, _ = self.encoder(X)
        x_decoded = self.decoder(z_mean)
        if len(x_decoded.shape) == 1:
            x_decoded = tf.reshape(x_decoded, (1, -1))
        return x_decoded

    def get_num_trainable_variables(self):
        trainable_params = int(np.sum([np.prod(v.get_shape()) for v in self.trainable_weights]))
        non_trainable_params = int(np.sum([np.prod(v.get_shape()) for v in self.non_trainable_weights]))
        total_params = trainable_params + non_trainable_params
        return trainable_params, non_trainable_params, total_params

    def get_prior_samples(self, num_samples):
        Z = np.random.randn(num_samples, self.latent_dim)
        samples = self.decoder.predict(Z)
        return samples

    def get_prior_samples_given_Z(self, Z):
        samples = self.decoder.predict(Z)
        return samples

    @abstractmethod
    def _get_encoder(self, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def _get_decoder(self, **kwargs):
        raise NotImplementedError

    def summary(self):
        self.encoder.summary()
        self.decoder.summary()

    def _get_reconstruction_loss(self, X, X_recons):

        def get_reconst_loss_by_axis(X, X_recons, axis):
            x_r = tf.reduce_mean(X, axis=axis)
            x_c_r = tf.reduce_mean(X_recons, axis=axis)
            err = tf.math.squared_difference(x_r, x_c_r)
            return tf.reduce_sum(err)

        # overall element-wise squared error
        err = tf.math.squared_difference(X, X_recons)
        reconst_loss = tf.reduce_sum(err)

        reconst_loss += get_reconst_loss_by_axis(X, X_recons, axis=[2])  # by time axis
        # reconst_loss += get_reconst_loss_by_axis(X, X_recons, axis=[1])  # by feature axis
        return reconst_loss

    def train_step(self, X):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(X)
            reconstruction = self.decoder(z)
            reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)

            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_sum(tf.reduce_sum(kl_loss, axis=1))
            # kl_loss = kl_loss / self.latent_dim

            total_loss = self.reconstruction_wt * reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
    def test_step(self, X):
        z_mean, z_log_var, z = self.encoder(X)
        reconstruction = self.decoder(z)
        reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)

        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_sum(tf.reduce_sum(kl_loss, axis=1))
        # kl_loss = kl_loss / self.latent_dim

        total_loss = self.reconstruction_wt * reconstruction_loss + kl_loss

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def save_weights(self, model_dir, file_pref):
        # weights are serialized with joblib; the .h5 extension is kept for
        # compatibility with existing artifacts, but these are not HDF5 files
        encoder_wts = self.encoder.get_weights()
        decoder_wts = self.decoder.get_weights()
        joblib.dump(encoder_wts, os.path.join(model_dir, f'{file_pref}encoder_wts.h5'))
        joblib.dump(decoder_wts, os.path.join(model_dir, f'{file_pref}decoder_wts.h5'))

    def load_weights(self, model_dir, file_pref):
        encoder_wts = joblib.load(os.path.join(model_dir, f'{file_pref}encoder_wts.h5'))
        decoder_wts = joblib.load(os.path.join(model_dir, f'{file_pref}decoder_wts.h5'))
        self.encoder.set_weights(encoder_wts)
        self.decoder.set_weights(decoder_wts)

    def save(self, model_dir, file_pref):
        self.save_weights(model_dir, file_pref)
        dict_params = {
            'seq_len': self.seq_len,
            'feat_dim': self.feat_dim,
            'latent_dim': self.latent_dim,
            'reconstruction_wt': self.reconstruction_wt,
            'hidden_layer_sizes': self.hidden_layer_sizes,
        }
        params_file = os.path.join(model_dir, f'{file_pref}parameters.pkl')
        joblib.dump(dict_params, params_file)


class TimeVAE(BaseVariationalAutoencoder):

    def __init__(self, hidden_layer_sizes, trend_poly=0, num_gen_seas=0, custom_seas=None,
                 use_scaler=False, use_residual_conn=True, **kwargs):
        '''
        hidden_layer_sizes: list of numbers of filters in the encoder's convolutional
            layers and in the decoder's residual connection.
        trend_poly: integer for the number of orders in the trend component, e.g.
            trend_poly = 2 includes linear and quadratic terms.
        num_gen_seas: number of sine waves used to model seasonalities. Each sine
            wave has its own amplitude, frequency, and phase.
        custom_seas: list of tuples of (num_seasons, len_per_season).
            num_seasons: number of seasons per cycle.
            len_per_season: number of epochs (time-steps) per season.
        use_scaler: boolean indicating whether to multiply the reconstruction by a
            feature-wise scale estimated from the latent vector.
        use_residual_conn: boolean indicating whether to use a residual connection
            for reconstruction in addition to trend, generic, and custom seasonalities.
        '''
        super(TimeVAE, self).__init__(**kwargs)

        self.hidden_layer_sizes = hidden_layer_sizes
        self.trend_poly = trend_poly
        self.num_gen_seas = num_gen_seas
        self.custom_seas = custom_seas
        self.use_scaler = use_scaler
        self.use_residual_conn = use_residual_conn

        self.encoder = self._get_encoder()
        self.decoder = self._get_decoder()

    def _get_encoder(self):
        encoder_inputs = Input(shape=(self.seq_len, self.feat_dim), name='encoder_input')
        x = encoder_inputs
        for i, num_filters in enumerate(self.hidden_layer_sizes):
            x = Conv1D(
                filters=num_filters,
                kernel_size=3,
                strides=2,
                activation='relu',
                padding='same',
                name=f'enc_conv_{i}')(x)

        x = Flatten(name='enc_flatten')(x)

        # save the dimensionality of this flattened layer before the latent
        # layers; the decoder's residual connection needs it.
        self.encoder_last_dense_dim = x.get_shape()[-1]

        z_mean = Dense(self.latent_dim, name="z_mean")(x)
        z_log_var = Dense(self.latent_dim, name="z_log_var")(x)

        encoder_output = Sampling()([z_mean, z_log_var])
        self.encoder_output = encoder_output

        encoder = Model(encoder_inputs, [z_mean, z_log_var, encoder_output], name="encoder")
        return encoder

    def _get_decoder(self):
        decoder_inputs = Input(shape=(self.latent_dim,), name='decoder_input')

        outputs = self.level_model(decoder_inputs)

        # trend polynomials
        if self.trend_poly is not None and self.trend_poly > 0:
            trend_vals = self.trend_model(decoder_inputs)
            outputs = trend_vals if outputs is None else outputs + trend_vals

        # # generic seasonalities
        # if self.num_gen_seas is not None and self.num_gen_seas > 0:
        #     gen_seas_vals = self.generic_seasonal_model(decoder_inputs)
        #     # gen_seas_vals = self.generic_seasonal_model2(decoder_inputs)
        #     outputs = gen_seas_vals if outputs is None else outputs + gen_seas_vals

        # custom seasons
        if self.custom_seas is not None and len(self.custom_seas) > 0:
            cust_seas_vals = self.custom_seasonal_model(decoder_inputs)
            outputs = cust_seas_vals if outputs is None else outputs + cust_seas_vals

        if self.use_residual_conn:
            residuals = self._get_decoder_residual(decoder_inputs)
            outputs = residuals if outputs is None else outputs + residuals

        if self.use_scaler and outputs is not None:
            scale = self.scale_model(decoder_inputs)
            outputs *= scale

        # outputs = Activation(activation='sigmoid')(outputs)

        if outputs is None:
            raise Exception('''Error: no decoder model to use.
                You must use one or more of: trend, generic seasonality(ies),
                custom seasonality(ies), and/or residual connection.''')

        decoder = Model(decoder_inputs, [outputs], name="decoder")
        return decoder

    def level_model(self, z):
        level_params = Dense(self.feat_dim, name="level_params", activation='relu')(z)
        level_params = Dense(self.feat_dim, name="level_params2")(level_params)
        level_params = Reshape(target_shape=(1, self.feat_dim))(level_params)  # shape: (N, 1, D)

        ones_tensor = tf.ones(shape=[1, self.seq_len, 1], dtype=tf.float32)  # shape: (1, T, 1)
        level_vals = level_params * ones_tensor  # broadcasts to shape: (N, T, D)
        return level_vals

    def scale_model(self, z):
        scale_params = Dense(self.feat_dim, name="scale_params", activation='relu')(z)
        scale_params = Dense(self.feat_dim, name="scale_params2")(scale_params)
        scale_params = Reshape(target_shape=(1, self.feat_dim))(scale_params)  # shape: (N, 1, D)

        scale_vals = tf.repeat(scale_params, repeats=self.seq_len, axis=1)  # shape: (N, T, D)
        return scale_vals

    def trend_model(self, z):
        trend_params = Dense(self.feat_dim * self.trend_poly, name="trend_params", activation='relu')(z)
        trend_params = Dense(self.feat_dim * self.trend_poly, name="trend_params2")(trend_params)
        trend_params = Reshape(target_shape=(self.feat_dim, self.trend_poly))(trend_params)  # shape: (N, D, P)

        lin_space = K.arange(0, float(self.seq_len), 1) / self.seq_len  # 1d tensor of length T
        poly_space = K.stack([lin_space ** float(p + 1) for p in range(self.trend_poly)], axis=0)  # shape: (P, T)

        trend_vals = K.dot(trend_params, poly_space)  # shape: (N, D, T)
        trend_vals = tf.transpose(trend_vals, perm=[0, 2, 1])  # shape: (N, T, D)
        trend_vals = K.cast(trend_vals, tf.float32)
        return trend_vals

    def custom_seasonal_model(self, z):
        N = tf.shape(z)[0]
        ones_tensor = tf.ones(shape=[N, self.feat_dim, self.seq_len], dtype=tf.int32)

        all_seas_vals = []
        for i, season_tup in enumerate(self.custom_seas):
            num_seasons, len_per_season = season_tup
            season_params = Dense(self.feat_dim * num_seasons, name=f"season_params_{i}")(z)  # shape: (N, D * S)
            season_params = Reshape(target_shape=(self.feat_dim, num_seasons))(season_params)  # shape: (N, D, S)

            season_indexes_over_time = self._get_season_indexes_over_seq(num_seasons, len_per_season)  # shape: (T,)

            dim2_idxes = ones_tensor * tf.reshape(season_indexes_over_time, shape=(1, 1, -1))  # shape: (N, D, T)
            season_vals = tf.gather(season_params, dim2_idxes, batch_dims=-1)  # shape: (N, D, T)

            all_seas_vals.append(season_vals)

        all_seas_vals = K.stack(all_seas_vals, axis=-1)  # shape: (N, D, T, #custom_seas)
        all_seas_vals = tf.reduce_sum(all_seas_vals, axis=-1)  # shape: (N, D, T)
        all_seas_vals = tf.transpose(all_seas_vals, perm=[0, 2, 1])  # shape: (N, T, D)
        return all_seas_vals

    def _get_season_indexes_over_seq(self, num_seasons, len_per_season):
        # e.g. num_seasons=3, len_per_season=2, seq_len=8 -> [0, 0, 1, 1, 2, 2, 0, 0]
        curr_len = 0
        season_idx = []
        curr_idx = 0
        while curr_len < self.seq_len:
            reps = len_per_season if curr_len + len_per_season <= self.seq_len \
                else self.seq_len - curr_len
            season_idx.extend([curr_idx] * reps)
            curr_idx += 1
            if curr_idx == num_seasons:
                curr_idx = 0
            curr_len += reps
        return season_idx

    def generic_seasonal_model(self, z):
        freq = Dense(self.feat_dim * self.num_gen_seas, name="g_season_freq", activation='sigmoid')(z)
        freq = Reshape(target_shape=(1, self.feat_dim, self.num_gen_seas))(freq)  # shape: (N, 1, D, S)

        phase = Dense(self.feat_dim * self.num_gen_seas, name="g_season_phase")(z)
        phase = Reshape(target_shape=(1, self.feat_dim, self.num_gen_seas))(phase)  # shape: (N, 1, D, S)

        amplitude = Dense(self.feat_dim * self.num_gen_seas, name="g_season_amplitude")(z)
        amplitude = Reshape(target_shape=(1, self.feat_dim, self.num_gen_seas))(amplitude)  # shape: (N, 1, D, S)

        lin_space = K.arange(0, float(self.seq_len), 1) / self.seq_len  # 1d tensor of length T
        lin_space = tf.reshape(lin_space, shape=(1, self.seq_len, 1, 1))  # shape: (1, T, 1, 1)
        seas_vals = amplitude * K.sin(2. * np.pi * freq * lin_space + phase)  # shape: (N, T, D, S)
        seas_vals = tf.math.reduce_sum(seas_vals, axis=-1)  # shape: (N, T, D)
        return seas_vals

    def generic_seasonal_model2(self, z):
        season_params = Dense(self.feat_dim * self.num_gen_seas, name="g_season_params")(z)
        season_params = Reshape(target_shape=(self.feat_dim, self.num_gen_seas))(season_params)  # shape: (N, D, S)

        p = self.num_gen_seas
        p1, p2 = (p // 2, p // 2) if p % 2 == 0 else (p // 2, p // 2 + 1)

        ls = K.arange(0, float(self.seq_len), 1) / self.seq_len  # 1d tensor of length T
        s1 = K.stack([K.cos(2 * np.pi * i * ls) for i in range(p1)], axis=0)
        s2 = K.stack([K.sin(2 * np.pi * i * ls) for i in range(p2)], axis=0)

        if p == 1:
            s = s2
        else:
            s = K.concatenate([s1, s2], axis=0)
        s = K.cast(s, np.float32)  # shape: (S, T)

        seas_vals = K.dot(season_params, s)  # shape: (N, D, T)
        seas_vals = tf.transpose(seas_vals, perm=[0, 2, 1])  # shape: (N, T, D)
        seas_vals = K.cast(seas_vals, np.float32)
        # print('seas_vals shape', tf.shape(seas_vals))
        return seas_vals

    def _get_decoder_residual(self, x):
        x = Dense(self.encoder_last_dense_dim, name="dec_dense", activation='relu')(x)
        x = Reshape(target_shape=(-1, self.hidden_layer_sizes[-1]), name="dec_reshape")(x)

        for i, num_filters in enumerate(reversed(self.hidden_layer_sizes[:-1])):
            x = Conv1DTranspose(
                filters=num_filters,
                kernel_size=3,
                strides=2,
                padding='same',
                activation='relu',
                name=f'dec_deconv_{i}')(x)

        # last de-convolution; the index comes from the list length so the layer
        # name is defined even when the loop above does not run
        x = Conv1DTranspose(
            filters=self.feat_dim,
            kernel_size=3,
            strides=2,
            padding='same',
            activation='relu',
            name=f'dec_deconv__{len(self.hidden_layer_sizes) - 1}')(x)

        x = Flatten(name='dec_flatten')(x)
        x = Dense(self.seq_len * self.feat_dim, name="decoder_dense_final")(x)
        residuals = Reshape(target_shape=(self.seq_len, self.feat_dim))(x)
        return residuals

    def save(self, model_dir, file_pref):
        super().save_weights(model_dir, file_pref)
        dict_params = {
            'seq_len': self.seq_len,
            'feat_dim': self.feat_dim,
            'latent_dim': self.latent_dim,
            'reconstruction_wt': self.reconstruction_wt,
            'hidden_layer_sizes': self.hidden_layer_sizes,
            'trend_poly': self.trend_poly,
            'num_gen_seas': self.num_gen_seas,
            'custom_seas': self.custom_seas,
            'use_scaler': self.use_scaler,
            'use_residual_conn': self.use_residual_conn,
        }
        params_file = os.path.join(model_dir, f'{file_pref}parameters.pkl')
        joblib.dump(dict_params, params_file)

    @staticmethod
    def load(model_dir, file_pref):
        params_file = os.path.join(model_dir, f'{file_pref}parameters.pkl')
        dict_params = joblib.load(params_file)
        vae_model = TimeVAE(**dict_params)
        vae_model.load_weights(model_dir, file_pref)
        vae_model.compile(optimizer=Adam())
        return vae_model
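
# ---------------------------------------------------------------------------
# Minimal usage sketch. This block is illustrative rather than part of the
# original module: the synthetic dataset, shapes, and hyperparameters below
# are arbitrary assumptions chosen only to exercise the API end to end.
if __name__ == '__main__':
    # toy dataset: 256 sequences of length 24 with 2 features, scaled to [0, 1]
    seq_len, feat_dim = 24, 2
    X = np.random.rand(256, seq_len, feat_dim).astype(np.float32)

    vae = TimeVAE(
        hidden_layer_sizes=[50, 100],
        trend_poly=2,            # linear + quadratic trend terms
        custom_seas=[(12, 2)],   # 12 seasons, 2 time-steps per season
        use_residual_conn=True,
        seq_len=seq_len,
        feat_dim=feat_dim,
        latent_dim=8,
    )
    vae.compile(optimizer=Adam())
    vae.fit(X, epochs=5, batch_size=32, verbose=0)

    # draw new sequences from the prior: shape (10, seq_len, feat_dim)
    samples = vae.get_prior_samples(num_samples=10)
    print(samples.shape)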