File size: 1,601 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import tensorflow as tf
from ganime.model.vqgan_clean.vqgan import VQGAN
def _build_adam(trainer_config: dict, prefix: str) -> tf.keras.optimizers.Adam:
    """Create an Adam optimizer from the ``<prefix>_*`` trainer settings."""
    return tf.keras.optimizers.Adam(
        learning_rate=trainer_config[f"{prefix}_lr"],
        beta_1=trainer_config[f"{prefix}_beta_1"],
        beta_2=trainer_config[f"{prefix}_beta_2"],
        clipnorm=trainer_config[f"{prefix}_clip_norm"],
    )


def load_model(
    model: str, config: dict, strategy: tf.distribute.Strategy
) -> tf.keras.Model:
    """Build and compile the model selected by ``model``.

    Args:
        model: Name of the model to build. Only ``"vqgan"`` is handled.
        config: Configuration mapping. ``config["model"]`` is expanded as
            keyword arguments to ``VQGAN``; ``config["trainer"]`` supplies the
            ``gen_*`` and ``disc_*`` Adam settings (``lr``, ``beta_1``,
            ``beta_2``, ``clip_norm``).
        strategy: Distribution strategy whose scope wraps the construction of
            the model and its optimizers, and the ``compile`` call.

    Returns:
        The compiled ``VQGAN`` instance.

    Raises:
        ValueError: If ``model`` is not a known model name.
        KeyError: If ``config`` is a plain dict and a required entry is
            missing.
    """
    # Guard clause: reject unknown names before entering the strategy scope.
    if model != "vqgan":
        raise ValueError(f"Unknown model: {model}")

    with strategy.scope():
        print(config["model"])
        # Use a dedicated local instead of rebinding the ``model`` parameter,
        # which is annotated as ``str``.
        vqgan = VQGAN(**config["model"])
        gen_optimizer = _build_adam(config["trainer"], "gen")
        disc_optimizer = _build_adam(config["trainer"], "disc")
        vqgan.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
    return vqgan
# if model == "moving_vae":
# from ganime.model.moving_vae import MovingVAE
# with strategy.scope():
# model = MovingVAE(input_shape=input_shape)
# negloglik = lambda x, rv_x: -rv_x.log_prob(x)
# model.compile(
# optimizer=tf.optimizers.Adam(learning_rate=config["lr"]),
# loss=negloglik,
# )
# # model.build(input_shape=(None, *input_shape))
# # model.summary()
# return model
|