import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils.layer_utils import count_params

from layers import AddNoise


class Models_functions:
    def __init__(self, args):
        self.args = args
        if self.args.mixed_precision:
            self.mixed_precision = tf.keras.mixed_precision
            self.policy = tf.keras.mixed_precision.Policy("mixed_float16")
            tf.keras.mixed_precision.set_global_policy(self.policy)
        self.init = tf.keras.initializers.he_uniform()

    def conv_util(
        self, inp, filters, kernel_size=(1, 3), strides=(1, 1), noise=False, upsample=False, padding="same", bnorm=True
    ):
        x = inp
        bias = True
        if bnorm:
            bias = False
        if upsample:
            x = tf.keras.layers.Conv2DTranspose(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding=padding,
                kernel_initializer=self.init,
                use_bias=bias,
            )(x)
        else:
            x = tf.keras.layers.Conv2D(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding=padding,
                kernel_initializer=self.init,
                use_bias=bias,
            )(x)
        if noise:
            x = AddNoise(self.args.datatype)(x)
        if bnorm:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.activations.swish(x)
        return x
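
    # Note (added commentary, not in the original file): conv_util disables the conv bias
    # whenever batch normalization follows, since BatchNormalization's learned offset makes
    # a separate bias term redundant.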

    def pixel_shuffle(self, x, factor=2):
        bs_dim, h_dim, w_dim, c_dim = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        x = tf.reshape(x, [bs_dim, h_dim, w_dim, c_dim // factor, factor])
        x = tf.transpose(x, [0, 1, 2, 4, 3])
        return tf.reshape(x, [bs_dim, h_dim, w_dim * factor, c_dim // factor])
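
    # Shape sketch for pixel_shuffle (illustrative, not part of the original file):
    # an input of shape (batch, 1, w, c) is reshaped to (batch, 1, w, c // factor, factor),
    # the last two axes are swapped, and the result is flattened back to
    # (batch, 1, w * factor, c // factor) — e.g. (4, 1, 8, 64) -> (4, 1, 16, 32) with factor=2.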

    def adain(self, x, emb, name):
        emb = tf.keras.layers.Conv2D(
            x.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
            name=name,
        )(emb)
        x = x / (tf.math.reduce_std(x, -2, keepdims=True) + 1e-5)
        return x * emb
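
    # Added note: adain is an AdaIN-style modulation. x is divided by its standard deviation
    # along the width axis (axis -2, with 1e-5 added for numerical stability) and then scaled
    # elementwise by a learned 1x1-conv projection of the conditioning embedding `emb`.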

    def conv_util_gen(
        self,
        inp,
        filters,
        kernel_size=(1, 9),
        strides=(1, 1),
        noise=False,
        upsample=False,
        emb=None,
        se1=None,
        name="0",
    ):
        x = inp
        if upsample:
            x = tf.keras.layers.Conv2DTranspose(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
                name=name + "c",
            )(x)
        else:
            x = tf.keras.layers.Conv2D(
                filters,
                kernel_size=kernel_size,
                strides=strides,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
                name=name + "c",
            )(x)
        if noise:
            x = AddNoise(self.args.datatype, name=name + "r")(x)
        if emb is not None:
            x = self.adain(x, emb, name=name + "ai")
        else:
            x = tf.keras.layers.BatchNormalization(name=name + "bn")(x)
        x = tf.keras.activations.swish(x)
        return x
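
    # Added note: conv_util_gen is the generator-side counterpart of conv_util. When a
    # conditioning embedding is passed it applies the adain modulation above; otherwise it
    # falls back to batch normalization. The `se1` argument is accepted but unused.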

    def res_block_disc(self, inp, filters, kernel_size=(1, 3), kernel_size_2=None, strides=(1, 1), name="0"):
        if kernel_size_2 is None:
            kernel_size_2 = kernel_size
        x = tf.keras.layers.Conv2D(
            inp.shape[-1],
            kernel_size=kernel_size_2,
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            name=name + "c0",
        )(inp)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
        x = tf.keras.layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            name=name + "c1",
        )(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
        if strides != (1, 1):
            inp = tf.keras.layers.AveragePooling2D(strides, padding="same")(inp)
        if inp.shape[-1] != filters:
            inp = tf.keras.layers.Conv2D(
                filters,
                kernel_size=1,
                strides=1,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=False,
                name=name + "c3",
            )(inp)
        return x + inp
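
    # Added note: the residual path is scaled by sqrt(0.5) after each activation before being
    # summed with the (pooled, possibly 1x1-projected) skip path. A common motivation for this
    # kind of scaling in residual discriminators is to keep the variance of the sum roughly
    # constant, though the original file does not state the reason.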

    def build_encoder2(self):
        inpf = tf.keras.layers.Input((1, self.args.shape, self.args.hop // 4))
        inpfls = tf.split(inpf, 8, -2)
        inpb = tf.concat(inpfls, 0)
        g0 = self.conv_util(inpb, self.args.hop, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False)
        g1 = self.conv_util(
            g0, self.args.hop + self.args.hop // 2, kernel_size=(1, 3), strides=(1, 2), padding="valid", bnorm=False
        )
        g2 = self.conv_util(
            g1, self.args.hop + self.args.hop // 2, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False
        )
        g3 = self.conv_util(g2, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), padding="valid", bnorm=False)
        g4 = self.conv_util(g3, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False)
        g5 = self.conv_util(g4, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), padding="valid", bnorm=False)
        g5 = self.conv_util(g5, self.args.hop * 3, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g = tf.keras.layers.Conv2D(
            self.args.latdepth,
            kernel_size=(1, 1),
            strides=1,
            padding="valid",
            kernel_initializer=self.init,
            name="cbottle",
            activation="tanh",
        )(g5)
        gls = tf.split(g, 8, 0)
        g = tf.concat(gls, -2)
        gls = tf.split(g, 2, -2)
        g = tf.concat(gls, 0)
        gf = tf.cast(g, tf.float32)
        return tf.keras.Model(inpf, gf, name="ENC2")

    def build_decoder2(self):
        inpf = tf.keras.layers.Input((1, self.args.shape // 32, self.args.latdepth))
        g = inpf
        g = self.conv_util(
            g, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), upsample=False, noise=True, bnorm=False
        )
        g = self.conv_util(
            g,
            self.args.hop * 2 + self.args.hop // 2,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            bnorm=False,
        )
        g = self.conv_util(
            g,
            self.args.hop * 2 + self.args.hop // 2,
            kernel_size=(1, 3),
            strides=(1, 1),
            upsample=False,
            noise=True,
            bnorm=False,
        )
        g = self.conv_util(
            g, self.args.hop * 2, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g = self.conv_util(
            g, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 1), upsample=False, noise=True, bnorm=False
        )
        g = self.conv_util(
            g,
            self.args.hop + self.args.hop // 2,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            bnorm=False,
        )
        g = self.conv_util(g, self.args.hop, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False)
        gf = tf.keras.layers.Conv2D(
            self.args.hop // 4, kernel_size=(1, 1), strides=1, padding="same", kernel_initializer=self.init, name="cout"
        )(g)
        gfls = tf.split(gf, 2, 0)
        gf = tf.concat(gfls, -2)
        gf = tf.cast(gf, tf.float32)
        return tf.keras.Model(inpf, gf, name="DEC2")

    def build_encoder(self):
        dim = ((4 * self.args.hop) // 2) + 1
        inpf = tf.keras.layers.Input((dim, self.args.shape, 1))
        ginp = tf.transpose(inpf, [0, 3, 2, 1])
        g0 = self.conv_util(ginp, self.args.hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g1 = self.conv_util(g0, self.args.hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g2 = self.conv_util(g1, self.args.hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g4 = self.conv_util(g2, self.args.hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g5 = self.conv_util(g4, self.args.hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
        g = tf.keras.layers.Conv2D(
            self.args.hop // 4, kernel_size=(1, 1), strides=1, padding="valid", kernel_initializer=self.init
        )(g5)
        g = tf.keras.activations.tanh(g)
        gls = tf.split(g, 2, -2)
        g = tf.concat(gls, 0)
        gf = tf.cast(g, tf.float32)
        return tf.keras.Model(inpf, gf, name="ENC")

    def build_decoder(self):
        dim = ((4 * self.args.hop) // 2) + 1
        inpf = tf.keras.layers.Input((1, self.args.shape // 2, self.args.hop // 4))
        g = inpf
        g0 = self.conv_util(g, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), noise=True, bnorm=False)
        g1 = self.conv_util(g0, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
        g2 = self.conv_util(g1, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
        g3 = self.conv_util(g2, self.args.hop, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
        g = self.conv_util(g3, self.args.hop, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
        g33 = self.conv_util(
            g, self.args.hop, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g22 = self.conv_util(
            g3, self.args.hop * 2, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g11 = self.conv_util(
            g22 + g2, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g00 = self.conv_util(
            g11 + g1, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g = tf.keras.layers.Conv2D(
            dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
        )(g00 + g0)
        gf = tf.clip_by_value(g, -1.0, 1.0)
        g = self.conv_util(
            g22, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g = self.conv_util(
            g + g11, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )
        g = tf.keras.layers.Conv2D(
            dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
        )(g + g00)
        pf = tf.clip_by_value(g, -1.0, 1.0)
        gfls = tf.split(gf, self.args.shape // self.args.window, 0)
        gf = tf.concat(gfls, -2)
        pfls = tf.split(pf, self.args.shape // self.args.window, 0)
        pf = tf.concat(pfls, -2)
        s = tf.transpose(gf, [0, 2, 3, 1])
        p = tf.transpose(pf, [0, 2, 3, 1])
        s = tf.cast(tf.squeeze(s, -1), tf.float32)
        p = tf.cast(tf.squeeze(p, -1), tf.float32)
        return tf.keras.Model(inpf, [s, p], name="DEC")

    def build_critic(self):
        sinp = tf.keras.layers.Input(shape=(1, self.args.latlen, self.args.latdepth * 2))
        sf = tf.keras.layers.Conv2D(
            self.args.base_channels * 3,
            kernel_size=(1, 4),
            strides=(1, 2),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            name="1c",
        )(sinp)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        sf = self.res_block_disc(sf, self.args.base_channels * 4, kernel_size=(1, 4), strides=(1, 2), name="2")
        sf = self.res_block_disc(sf, self.args.base_channels * 5, kernel_size=(1, 4), strides=(1, 2), name="3")
        sf = self.res_block_disc(sf, self.args.base_channels * 6, kernel_size=(1, 4), strides=(1, 2), name="4")
        sf = self.res_block_disc(sf, self.args.base_channels * 7, kernel_size=(1, 4), strides=(1, 2), name="5")
        if not self.args.small:
            sf = self.res_block_disc(
                sf, self.args.base_channels * 7, kernel_size=(1, 4), strides=(1, 2), kernel_size_2=(1, 1), name="6"
            )
        sf = tf.keras.layers.Conv2D(
            self.args.base_channels * 7,
            kernel_size=(1, 3),
            strides=(1, 1),
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            name="7c",
        )(sf)
        sf = tf.keras.layers.LeakyReLU(0.2)(sf)
        gf = tf.keras.layers.Dense(1, activation="linear", use_bias=True, kernel_initializer=self.init, name="7d")(
            tf.keras.layers.Flatten()(sf)
        )
        gf = tf.cast(gf, tf.float32)
        return tf.keras.Model(sinp, gf, name="C")
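
    # Added note: the critic ends in a single linear unit with no sigmoid, i.e. it outputs an
    # unbounded per-example score (as in Wasserstein-style critics) rather than a probability.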

    def build_generator(self):
        dim = self.args.latdepth * 2
        inpf = tf.keras.layers.Input((self.args.latlen, self.args.latdepth * 2))
        inpfls = tf.split(inpf, 2, -2)
        inpb = tf.concat(inpfls, 0)
        inpg = tf.reduce_mean(inpb, -2)
        inp1 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(tf.expand_dims(inpb, -3))
        inp2 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp1)
        inp3 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp2)
        inp4 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp3)
        inp5 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp4)
        if not self.args.small:
            inp6 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp5)
        if not self.args.small:
            g = tf.keras.layers.Dense(
                4 * (self.args.base_channels * 7),
                activation="linear",
                use_bias=True,
                kernel_initializer=self.init,
                name="00d",
            )(tf.keras.layers.Flatten()(inp6))
            g = tf.keras.layers.Reshape((1, 4, self.args.base_channels * 7))(g)
            g = AddNoise(self.args.datatype, name="00n")(g)
            g = self.adain(g, inp5, name="00ai")
            g = tf.keras.activations.swish(g)
        else:
            g = tf.keras.layers.Dense(
                4 * (self.args.base_channels * 7),
                activation="linear",
                use_bias=True,
                kernel_initializer=self.init,
                name="00d",
            )(tf.keras.layers.Flatten()(inp5))
            g = tf.keras.layers.Reshape((1, 4, self.args.base_channels * 7))(g)
            g = AddNoise(self.args.datatype, name="00n")(g)
            g = self.adain(g, inp4, name="00ai")
            g = tf.keras.activations.swish(g)
        if not self.args.small:
            g1 = self.conv_util_gen(
                g,
                self.args.base_channels * 6,
                kernel_size=(1, 4),
                strides=(1, 2),
                upsample=True,
                noise=True,
                emb=inp4,
                name="0",
            )
            g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
            g1 = self.conv_util_gen(
                g1,
                self.args.base_channels * 6,
                kernel_size=(1, 4),
                strides=(1, 1),
                upsample=False,
                noise=True,
                emb=inp4,
                name="1",
            )
            g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
            g1 = g1 + tf.keras.layers.Conv2D(
                g1.shape[-1],
                kernel_size=(1, 1),
                strides=1,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
                name="res1c",
            )(self.pixel_shuffle(g))
        else:
            g1 = self.conv_util_gen(
                g,
                self.args.base_channels * 6,
                kernel_size=(1, 1),
                strides=(1, 1),
                upsample=False,
                noise=True,
                emb=inp4,
                name="0_small",
            )
            g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
            g1 = self.conv_util_gen(
                g1,
                self.args.base_channels * 6,
                kernel_size=(1, 1),
                strides=(1, 1),
                upsample=False,
                noise=True,
                emb=inp4,
                name="1_small",
            )
            g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
            g1 = g1 + tf.keras.layers.Conv2D(
                g1.shape[-1],
                kernel_size=(1, 1),
                strides=1,
                activation="linear",
                padding="same",
                kernel_initializer=self.init,
                use_bias=True,
                name="res1c_small",
            )(g)
        g2 = self.conv_util_gen(
            g1,
            self.args.base_channels * 5,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp3,
            name="2",
        )
        g2 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g2
        g2 = self.conv_util_gen(
            g2,
            self.args.base_channels * 5,
            kernel_size=(1, 4),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp3,
            name="3",
        )
        g2 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g2
        g2 = g2 + tf.keras.layers.Conv2D(
            g2.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
            name="res2c",
        )(self.pixel_shuffle(g1))
        g3 = self.conv_util_gen(
            g2,
            self.args.base_channels * 4,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp2,
            name="4",
        )
        g3 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g3
        g3 = self.conv_util_gen(
            g3,
            self.args.base_channels * 4,
            kernel_size=(1, 4),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp2,
            name="5",
        )
        g3 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g3
        g3 = g3 + tf.keras.layers.Conv2D(
            g3.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
            name="res3c",
        )(self.pixel_shuffle(g2))
        g4 = self.conv_util_gen(
            g3,
            self.args.base_channels * 3,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=inp1,
            name="6",
        )
        g4 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g4
        g4 = self.conv_util_gen(
            g4,
            self.args.base_channels * 3,
            kernel_size=(1, 4),
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=inp1,
            name="7",
        )
        g4 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g4
        g4 = g4 + tf.keras.layers.Conv2D(
            g4.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
            name="res4c",
        )(self.pixel_shuffle(g3))
        g5 = self.conv_util_gen(
            g4,
            self.args.base_channels * 2,
            kernel_size=(1, 4),
            strides=(1, 2),
            upsample=True,
            noise=True,
            emb=tf.expand_dims(tf.cast(inpb, dtype=self.args.datatype), -3),
            name="8",
        )
        gf = tf.keras.layers.Conv2D(
            dim, kernel_size=(1, 4), strides=(1, 1), kernel_initializer=self.init, padding="same", name="9c"
        )(g5)
        gfls = tf.split(gf, 2, 0)
        gf = tf.concat(gfls, -2)
        gf = tf.cast(gf, tf.float32)
        return tf.keras.Model(inpf, gf, name="GEN")
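
    # Added note: build_generator conditions each resolution of the upsampling stack on
    # progressively average-pooled copies of the input latent (inp1 .. inp5/inp6), feeding the
    # matching one into each conv_util_gen block through its `emb` argument.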

    # Load past models from path to resume training or test
    def load(self, path, load_dec=False):
        gen = self.build_generator()
        critic = self.build_critic()
        enc = self.build_encoder()
        dec = self.build_decoder()
        enc2 = self.build_encoder2()
        dec2 = self.build_decoder2()
        gen_ema = self.build_generator()
        switch = tf.Variable(-1.0, dtype=tf.float32)
        if self.args.mixed_precision:
            opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
            opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
        else:
            opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
            opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)
        if load_dec:
            dec.load_weights(self.args.dec_path + "/dec.h5")
            dec2.load_weights(self.args.dec_path + "/dec2.h5")
            enc.load_weights(self.args.dec_path + "/enc.h5")
            enc2.load_weights(self.args.dec_path + "/enc2.h5")
        else:
            grad_vars = critic.trainable_weights
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            opt_disc.apply_gradients(zip(zero_grads, grad_vars))
            grad_vars = gen.trainable_variables
            zero_grads = [tf.zeros_like(w) for w in grad_vars]
            opt_dec.apply_gradients(zip(zero_grads, grad_vars))
            if not self.args.testing:
                opt_disc.set_weights(np.load(path + "/opt_disc.npy", allow_pickle=True))
                opt_dec.set_weights(np.load(path + "/opt_dec.npy", allow_pickle=True))
                critic.load_weights(path + "/critic.h5")
                gen.load_weights(path + "/gen.h5")
                switch = tf.Variable(float(np.load(path + "/switch.npy", allow_pickle=True)), dtype=tf.float32)
            gen_ema.load_weights(path + "/gen_ema.h5")
            dec.load_weights(self.args.dec_path + "/dec.h5")
            dec2.load_weights(self.args.dec_path + "/dec2.h5")
            enc.load_weights(self.args.dec_path + "/enc.h5")
            enc2.load_weights(self.args.dec_path + "/enc2.h5")
        return (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema,
            [opt_dec, opt_disc],
            switch,
        )
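
    # Added note on load(): applying one step of zero gradients before opt.set_weights() is a
    # standard Keras trick — it forces the optimizer to create its slot variables so that the
    # serialized optimizer state from opt_disc.npy / opt_dec.npy can be restored into them.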

    def build(self):
        gen = self.build_generator()
        critic = self.build_critic()
        enc = self.build_encoder()
        dec = self.build_decoder()
        enc2 = self.build_encoder2()
        dec2 = self.build_decoder2()
        gen_ema = self.build_generator()
        switch = tf.Variable(-1.0, dtype=tf.float32)
        gen_ema = tf.keras.models.clone_model(gen)
        gen_ema.set_weights(gen.get_weights())
        if self.args.mixed_precision:
            opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
            opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
        else:
            opt_disc = tf.keras.optimizers.Adam(0.0001, 0.5)
            opt_dec = tf.keras.optimizers.Adam(0.0001, 0.5)
        return (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema,
            [opt_dec, opt_disc],
            switch,
        )

    def get_networks(self):
        (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema_1,
            [opt_dec, opt_disc],
            switch,
        ) = self.load(self.args.load_path_1, load_dec=False)
        print(f"Networks loaded from {self.args.load_path_1}")
        (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema_2,
            [opt_dec, opt_disc],
            switch,
        ) = self.load(self.args.load_path_2, load_dec=False)
        print(f"Networks loaded from {self.args.load_path_2}")
        (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema_3,
            [opt_dec, opt_disc],
            switch,
        ) = self.load(self.args.load_path_3, load_dec=False)
        print(f"Networks loaded from {self.args.load_path_3}")
        return (
            (critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch),
        )

    def initialize_networks(self):
        (
            (critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch),
        ) = self.get_networks()
        print(f"Critic params: {count_params(critic.trainable_variables)}")
        print(f"Generator params: {count_params(gen.trainable_variables)}")
        return (
            (critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch),
            (critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch),
        )
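

# Minimal usage sketch (added, not part of the original file). It only builds the critic,
# which does not depend on the window/shape constraints of the encoder/decoder graphs.
# The argument values below are illustrative assumptions, not the project's actual defaults
# (those come from the training/argument-parsing code that constructs `args`).
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        mixed_precision=False,  # keep plain float32 for the sketch
        datatype=tf.float32,    # dtype used for the sqrt(0.5) residual scaling
        base_channels=16,       # assumed channel multiplier
        latlen=256,             # assumed latent sequence length
        latdepth=64,            # assumed latent depth
        small=False,            # build the full (non-small) critic
    )
    models = Models_functions(demo_args)
    critic = models.build_critic()
    print(f"Critic params: {count_params(critic.trainable_variables)}")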