"""Model builders (encoders, decoders, generator, critics) for the GAN pipeline.

All networks operate on spectrogram-like tensors with a singleton height axis,
so convolutions use (1, k) kernels and pooling/striding happens along the
time (-2) axis. Outputs are cast to float32 so downstream losses stay in full
precision even when mixed precision is enabled.
"""

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import mixed_precision
from tensorflow.keras.layers import (
    Add,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Cropping1D,
    Cropping2D,
    Dense,
    Dot,
    Flatten,
    GlobalAveragePooling2D,
    Input,
    Lambda,
    LeakyReLU,
    Multiply,
    ReLU,
    Reshape,
    SeparableConv2D,
    UpSampling2D,
    ZeroPadding2D,
)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.utils.layer_utils import count_params

from layers import ConvSN2D, DenseSN, PosEnc, AddNoise


class Models_functions:
    """Factory for every network used in training/inference.

    Parameters are read from ``args`` (hop, shape, latlen, latdepth, datatype,
    mixed_precision, testing, and the various load paths).
    """

    def __init__(self, args):
        self.args = args
        if self.args.mixed_precision:
            # Keep a handle to the module so load()/build() can wrap optimizers
            # in a LossScaleOptimizer.
            self.mixed_precision = mixed_precision
            self.policy = mixed_precision.Policy("mixed_float16")
            mixed_precision.set_global_policy(self.policy)
        self.init = tf.keras.initializers.he_uniform()

    def conv_util(
        self,
        inp,
        filters,
        kernel_size=(1, 3),
        strides=(1, 1),
        noise=False,
        upsample=False,
        padding="same",
        bn=True,
    ):
        """Conv (or transposed conv) -> optional AddNoise -> optional BN -> swish."""
        x = inp
        if upsample:
            x = tf.keras.layers.Conv2DTranspose(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding=padding,
                kernel_initializer=self.init,
            )(x)
        else:
            x = tf.keras.layers.Conv2D(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding=padding,
                kernel_initializer=self.init,
            )(x)
        if noise:
            x = AddNoise()(x)
        if bn:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.activations.swish(x)
        return x

    def adain(self, x, emb):
        """Adaptive instance normalization: normalize x along the time axis (-2)
        and rescale it with a style embedding projected to x's channel count."""
        emb = tf.keras.layers.Conv2D(
            x.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
        )(emb)
        # Epsilon guards against division by zero on constant slices.
        x = x / (tf.math.reduce_std(x, -2, keepdims=True) + 1e-7)
        return x * emb

    def se_layer(self, x, filters):
        """Squeeze-and-excitation style gate: time-pooled features -> swish
        bottleneck -> sigmoid gate with ``filters`` channels."""
        x = tf.reduce_mean(x, -2, keepdims=True)
        x = tf.keras.layers.Conv2D(
            filters,
            kernel_size=(1, 1),
            strides=(1, 1),
            activation="linear",
            padding="valid",
            kernel_initializer=self.init,
            use_bias=True,
        )(x)
        x = tf.keras.activations.swish(x)
        return tf.keras.layers.Conv2D(
            filters,
            kernel_size=(1, 1),
            strides=(1, 1),
            activation="sigmoid",
            padding="valid",
            kernel_initializer=self.init,
            use_bias=True,
        )(x)

    def conv_util_gen(
        self,
        inp,
        filters,
        kernel_size=(1, 9),
        strides=(1, 1),
        noise=False,
        upsample=False,
        emb=None,
        se1=None,
    ):
        """Generator conv block: conv/transposed conv -> optional noise ->
        AdaIN (when a style embedding is given, otherwise BatchNorm) -> swish
        -> optional squeeze-excitation gating."""
        x = inp
        if upsample:
            x = tf.keras.layers.Conv2DTranspose(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
            )(x)
        else:
            x = tf.keras.layers.Conv2D(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
            )(x)
        if noise:
            x = AddNoise()(x)
        if emb is not None:
            x = self.adain(x, emb)
        else:
            x = tf.keras.layers.BatchNormalization()(x)
        x1 = tf.keras.activations.swish(x)
        if se1 is not None:
            x1 = x1 * se1
        return x1

    def res_block_disc(self, inp, filters, kernel_size=(1, 3), kernel_size_2=None, strides=(1, 1)):
        """Residual discriminator block with LeakyReLU activations.

        Each branch is scaled by sqrt(0.5) to keep variance roughly constant;
        the skip path is average-pooled / 1x1-projected when shapes differ.
        """
        if kernel_size_2 is None:
            kernel_size_2 = kernel_size
        x = tf.keras.layers.Conv2D(
            inp.shape[-1],
            kernel_size=kernel_size_2,
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(inp)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
        x = tf.keras.layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
        if strides != (1, 1):
            inp = tf.keras.layers.AveragePooling2D(strides, padding="same")(inp)
        if inp.shape[-1] != filters:
            inp = tf.keras.layers.Conv2D(
                filters,
                kernel_size=1,
                strides=1,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=False,
            )(inp)
        return x + inp

    def build_encoder2(self):
        """Second-stage encoder: compresses 128-channel latents to a 64-channel
        bottleneck, processing the sequence as 16 parallel chunks."""
        dim = 128
        inpf = Input((1, self.args.shape, dim))
        # Fold 16 time chunks into the batch axis so convs see short windows.
        inpfls = tf.split(inpf, 16, -2)
        inpb = tf.concat(inpfls, 0)
        g0 = self.conv_util(inpb, 256, kernel_size=(1, 1), strides=(1, 1), padding="valid")
        g1 = self.conv_util(g0, 256 + 256, kernel_size=(1, 3), strides=(1, 1), padding="valid")
        g2 = self.conv_util(g1, 512 + 128, kernel_size=(1, 3), strides=(1, 1), padding="valid")
        g3 = self.conv_util(g2, 512 + 128, kernel_size=(1, 1), strides=(1, 1), padding="valid")
        g4 = self.conv_util(g3, 512 + 256, kernel_size=(1, 3), strides=(1, 1), padding="valid")
        g5 = self.conv_util(g4, 512 + 256, kernel_size=(1, 2), strides=(1, 1), padding="valid")
        g = tf.keras.layers.Conv2D(
            64,
            kernel_size=(1, 1),
            strides=1,
            padding="valid",
            kernel_initializer=self.init,
            name="cbottle",
            activation="tanh",
        )(g5)
        # Unfold the chunks back onto the time axis, then re-split in halves.
        gls = tf.split(g, 16, 0)
        g = tf.concat(gls, -2)
        gls = tf.split(g, 2, -2)
        g = tf.concat(gls, 0)
        gf = tf.cast(g, tf.float32)
        return Model(inpf, gf, name="ENC2")

    def build_decoder2(self):
        """Second-stage decoder: upsamples the 64-channel bottleneck back to
        128-channel latents (inverse of ENC2)."""
        dim = 128
        bottledim = 64
        inpf = Input((1, self.args.shape // 16, bottledim))
        g = inpf
        g = self.conv_util(
            g,
            512 + 128 + 128,
            kernel_size=(1, 4),
            strides=(1, 1),
            upsample=False,
            noise=True,
        )
        g = self.conv_util(
            g,
            512 + 128 + 128,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
        )
        g = self.conv_util(g, 512 + 128, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True)
        g = self.conv_util(g, 512, kernel_size=(1, 4), strides=(1, 1), upsample=False, noise=True)
        g = self.conv_util(g, 256 + 128, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True)
        gf = tf.keras.layers.Conv2D(
            dim,
            kernel_size=(1, 1),
            strides=1,
            padding="same",
            activation="tanh",
            kernel_initializer=self.init,
        )(g)
        gfls = tf.split(gf, 2, 0)
        gf = tf.concat(gfls, -2)
        gf = tf.cast(gf, tf.float32)
        return Model(inpf, gf, name="DEC2")

    def build_encoder(self):
        """First-stage encoder: maps (freq, time, 1) spectrograms to
        128-channel latents via pointwise convolutions."""
        dim = ((4 * self.args.hop) // 2) + 1
        inpf = Input((dim, self.args.shape, 1))
        # Move frequency bins into the channel axis: (b, 1, time, freq).
        ginp = tf.transpose(inpf, [0, 3, 2, 1])
        g0 = self.conv_util(
            ginp,
            self.args.hop * 2 + 32,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="valid",
        )
        g = self.conv_util(
            g0,
            self.args.hop * 2 + 64,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="valid",
        )
        g = self.conv_util(
            g,
            self.args.hop * 2 + 64 + 64,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="valid",
        )
        g = self.conv_util(
            g,
            self.args.hop * 2 + 128 + 64,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="valid",
        )
        g = self.conv_util(
            g,
            self.args.hop * 2 + 128 + 128,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="valid",
        )
        g = tf.keras.layers.Conv2D(
            128,
            kernel_size=(1, 1),
            strides=1,
            padding="valid",
            kernel_initializer=self.init,
        )(g)
        gb = tf.keras.activations.tanh(g)
        gbls = tf.split(gb, 2, -2)
        gb = tf.concat(gbls, 0)
        gb = tf.cast(gb, tf.float32)
        return Model(inpf, gb, name="ENC")

    def build_decoder(self):
        """First-stage decoder: U-Net-like net mapping 128-channel latents back
        to spectrogram magnitude and phase heads."""
        dim = ((4 * self.args.hop) // 2) + 1
        inpf = Input((1, self.args.shape // 2, 128))
        g = inpf
        g0 = self.conv_util(g, self.args.hop * 3, kernel_size=(1, 1), strides=(1, 1), noise=True)
        g1 = self.conv_util(g0, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), noise=True)
        g2 = self.conv_util(
            g1,
            self.args.hop + self.args.hop // 2,
            kernel_size=(1, 3),
            strides=(1, 2),
            noise=True,
        )
        g = self.conv_util(
            g2,
            self.args.hop + self.args.hop // 4,
            kernel_size=(1, 3),
            strides=(1, 2),
            noise=True,
        )
        g = self.conv_util(
            g,
            self.args.hop + self.args.hop // 2,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
        )
        # Additive skip connections back up the U.
        g = self.conv_util(
            g + g2,
            self.args.hop * 2,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
        )
        g = self.conv_util(
            g + g1,
            self.args.hop * 3,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
        )
        g = self.conv_util(g + g0, self.args.hop * 5, kernel_size=(1, 1), strides=(1, 1), noise=True)
        g = Conv2D(
            dim * 2,
            kernel_size=(1, 1),
            strides=(1, 1),
            kernel_initializer=self.init,
            padding="same",
        )(g)
        g = tf.clip_by_value(g, -1.0, 1.0)
        # First dim channels = magnitude head, second dim channels = phase head.
        gf, pf = tf.split(g, 2, -1)
        gfls = tf.split(gf, self.args.shape // self.args.window, 0)
        gf = tf.concat(gfls, -2)
        pfls = tf.split(pf, self.args.shape // self.args.window, 0)
        pf = tf.concat(pfls, -2)
        s = tf.transpose(gf, [0, 2, 3, 1])
        p = tf.transpose(pf, [0, 2, 3, 1])
        s = tf.cast(tf.squeeze(s, -1), tf.float32)
        p = tf.cast(tf.squeeze(p, -1), tf.float32)
        return Model(inpf, [s, p], name="DEC")

    def build_critic(self):
        """Critic over latent sequences; returns the scalar score and the
        pre-head feature map (used by the reconstruction critic)."""
        sinp = Input(shape=(1, self.args.latlen, self.args.latdepth * 2))
        dim = 64 * 2
        # Frozen random projection of the input latents.
        sf = tf.keras.layers.Conv2D(
            self.args.latdepth * 4,
            kernel_size=(1, 1),
            strides=(1, 1),
            activation="linear",
            padding="valid",
            kernel_initializer=self.init,
            use_bias=False,
            trainable=False,
        )(sinp)
        sf = tf.keras.layers.Conv2D(
            256 + 128,
            kernel_size=(1, 3),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = self.res_block_disc(sf, 256 + 128 + 128, kernel_size=(1, 3), strides=(1, 2))
        sf = self.res_block_disc(sf, 512 + 128, kernel_size=(1, 3), strides=(1, 2))
        sf = self.res_block_disc(sf, 512 + 256, kernel_size=(1, 3), strides=(1, 2))
        sf = self.res_block_disc(sf, 512 + 128 + 256, kernel_size=(1, 3), strides=(1, 2))
        sfo = self.res_block_disc(sf, 512 + 512, kernel_size=(1, 3), strides=(1, 2), kernel_size_2=(1, 1))
        sf = sfo
        gf = tf.keras.layers.Dense(1, activation="linear", use_bias=True, kernel_initializer=self.init)(Flatten()(sf))
        gf = tf.cast(gf, tf.float32)
        sfo = tf.cast(sfo, tf.float32)
        return Model(sinp, [gf, sfo], name="C")

    def build_critic_rec(self):
        """Reconstruction head: upsamples critic features (latlen/64) back to a
        latent sequence, enabling a feature-reconstruction loss."""
        sinp = Input(shape=(1, self.args.latlen // 64, 512 + 512))
        dim = self.args.latdepth * 2
        sf = tf.keras.layers.Conv2DTranspose(
            512,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sinp)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = tf.keras.layers.Conv2DTranspose(
            256 + 128 + 64,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = tf.keras.layers.Conv2DTranspose(
            256 + 128,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = tf.keras.layers.Conv2DTranspose(
            256 + 64,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = tf.keras.layers.Conv2DTranspose(
            256,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = tf.keras.layers.Conv2DTranspose(
            128 + 64,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        gf = tf.keras.layers.Conv2D(
            dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            activation="tanh",
            padding="same",
            kernel_initializer=self.init,
        )(sf)
        gf = tf.cast(gf, tf.float32)
        return Model(sinp, gf, name="CR")

    def build_generator(self):
        """Style-conditioned generator: a pyramid of average-pooled copies of
        the input latents conditions each resolution via AdaIN."""
        dim = self.args.latdepth * 2
        inpf = Input((self.args.latlen, self.args.latdepth * 2))
        inpfls = tf.split(inpf, 2, -2)
        inpb = tf.concat(inpfls, 0)
        inpg = tf.reduce_mean(inpb, -2)
        # Multi-scale conditioning pyramid (each level halves the time axis).
        inp1 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(tf.expand_dims(inpb, -3))
        inp2 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp1)
        inp3 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp2)
        inp4 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp3)
        inp5 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp4)
        inp6 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp5)
        g = tf.keras.layers.Dense(
            4 * (512 + 256 + 128),
            activation="linear",
            use_bias=True,
            kernel_initializer=self.init,
        )(Flatten()(inp6))
        g = tf.keras.layers.Reshape((1, 4, 512 + 256 + 128))(g)
        g = AddNoise()(g)
        g = self.adain(g, inp5)
        g = tf.keras.activations.swish(g)
        g = self.conv_util_gen(
            g,
            512 + 256,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp4,
        )
        g1 = self.conv_util_gen(
            g,
            512 + 256,
            kernel_size=(1, 1),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp4,
        )
        g2 = self.conv_util_gen(
            g1,
            512 + 128,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp3,
        )
        g2b = self.conv_util_gen(
            g2,
            512 + 128,
            kernel_size=(1, 3),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp3,
        )
        # Deeper layers are gated by SE signals computed from earlier layers.
        g3 = self.conv_util_gen(
            g2b,
            256 + 256,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp2,
            se1=self.se_layer(g, 256 + 256),
        )
        g3 = self.conv_util_gen(
            g3,
            256 + 256,
            kernel_size=(1, 3),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp2,
            se1=self.se_layer(g1, 256 + 256),
        )
        g4 = self.conv_util_gen(
            g3,
            256 + 128,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp1,
            se1=self.se_layer(g2, 256 + 128),
        )
        g4 = self.conv_util_gen(
            g4,
            256 + 128,
            kernel_size=(1, 3),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp1,
            se1=self.se_layer(g2b, 256 + 128),
        )
        g5 = self.conv_util_gen(
            g4,
            256,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=tf.expand_dims(tf.cast(inpb, dtype=self.args.datatype), -3),
        )
        gf = tf.keras.layers.Conv2D(
            dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            kernel_initializer=self.init,
            padding="same",
            activation="tanh",
        )(g5)
        gfls = tf.split(gf, 2, 0)
        gf = tf.concat(gfls, -2)
        gf = tf.cast(gf, tf.float32)
        return Model(inpf, gf, name="GEN")

    # Load past models from path to resume training or test
    def load(self, path, load_dec=False):
        """Build all networks/optimizers and restore weights from ``path``.

        When ``load_dec`` is True only the decoders are restored (from
        ``args.dec_path``). Otherwise optimizer slots are materialized with a
        zero-gradient step so saved optimizer state can be loaded; critic/gen/
        encoder weights are skipped in testing mode.
        """
        gen = self.build_generator()
        critic = self.build_critic()
        enc = self.build_encoder()
        dec = self.build_decoder()
        enc2 = self.build_encoder2()
        dec2 = self.build_decoder2()
        critic_rec = self.build_critic_rec()
        gen_ema = self.build_generator()

        if self.args.mixed_precision:
            opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
            opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
        else:
            opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
            opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)

        if load_dec:
            dec.load_weights(self.args.dec_path + "/dec.h5")
            dec2.load_weights(self.args.dec_path + "/dec2.h5")
        else:
            # Apply zero gradients once so the optimizers create their slot
            # variables; set_weights() below requires them to exist.
            grad_vars = critic.trainable_weights + critic_rec.trainable_weights
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            opt_disc.apply_gradients(zip(zero_grads, grad_vars))

            grad_vars = gen.trainable_variables
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            opt_dec.apply_gradients(zip(zero_grads, grad_vars))

            if not self.args.testing:
                opt_disc.set_weights(np.load(path + "/opt_disc.npy", allow_pickle=True))
                opt_dec.set_weights(np.load(path + "/opt_dec.npy", allow_pickle=True))

            if not self.args.testing:
                critic.load_weights(path + "/critic.h5")
                gen.load_weights(path + "/gen.h5")
                enc.load_weights(path + "/enc.h5")
                enc2.load_weights(path + "/enc2.h5")
                critic_rec.load_weights(path + "/critic_rec.h5")
            # EMA generator and decoders are needed in both train and test mode.
            gen_ema.load_weights(path + "/gen_ema.h5")
            dec.load_weights(path + "/dec.h5")
            dec2.load_weights(path + "/dec2.h5")

        return (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            critic_rec,
            gen_ema,
            [opt_dec, opt_disc],
        )

    def build(self):
        """Build all networks from scratch (no weight loading); gen_ema starts
        as an exact copy of the freshly initialized generator."""
        gen = self.build_generator()
        critic = self.build_critic()
        enc = self.build_encoder()
        dec = self.build_decoder()
        enc2 = self.build_encoder2()
        dec2 = self.build_decoder2()
        critic_rec = self.build_critic_rec()
        # NOTE: the original code built a generator here and immediately
        # discarded it; cloning gen directly avoids that wasted construction.
        gen_ema = tf.keras.models.clone_model(gen)
        gen_ema.set_weights(gen.get_weights())

        if self.args.mixed_precision:
            opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
            opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
        else:
            opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
            opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)

        return (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            critic_rec,
            gen_ema,
            [opt_dec, opt_disc],
        )

    def get_networks(self):
        """Load the techno and classical checkpoints and return both bundles.

        NOTE(review): critic/gen/enc/enc2/critic_rec and the optimizers are
        rebound by the second load, so both returned bundles share the
        classical versions of those networks — only dec/dec2/gen_ema differ
        per genre. Presumably intentional; confirm against callers.
        """
        (
            critic,
            gen,
            enc,
            dec_techno,
            enc2,
            dec2_techno,
            critic_rec,
            gen_ema_techno,
            [opt_dec, opt_disc],
        ) = self.load(self.args.load_path_techno, load_dec=False)
        print(f"Techno networks loaded from {self.args.load_path_techno}")
        (
            critic,
            gen,
            enc,
            dec_classical,
            enc2,
            dec2_classical,
            critic_rec,
            gen_ema_classical,
            [opt_dec, opt_disc],
        ) = self.load(self.args.load_path_classical, load_dec=False)
        print(f"Classical networks loaded from {self.args.load_path_classical}")
        return [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
            critic,
            gen,
            enc,
            dec_classical,
            enc2,
            dec2_classical,
            critic_rec,
            gen_ema_classical,
            [opt_dec, opt_disc],
        ]

    def initialize_networks(self):
        """Load both genre bundles, print parameter counts, and pass them on."""
        [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
            critic,
            gen,
            enc,
            dec_classical,
            enc2,
            dec2_classical,
            critic_rec,
            gen_ema_classical,
            [opt_dec, opt_disc],
        ] = self.get_networks()
        print(f"Generator params: {count_params(gen_ema_techno.trainable_variables)}")
        print(f"Decoder params: {count_params(dec_techno.trainable_variables+dec2_techno.trainable_variables)}")
        return [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
            critic,
            gen,
            enc,
            dec_classical,
            enc2,
            dec2_classical,
            critic_rec,
            gen_ema_classical,
            [opt_dec, opt_disc],
        ]