|
import logging

import tensorflow as tf

from ganime.model.vqgan_clean.vqgan import VQGAN
|
|
|
|
|
def load_model(
    model: str, config: dict, strategy: tf.distribute.Strategy
) -> tf.keras.Model:
    """Build and compile a model by name inside a distribution strategy scope.

    Args:
        model: Name of the model to build. Only ``"vqgan"`` is supported.
        config: Configuration dict. ``config["model"]`` is passed as keyword
            arguments to the model constructor. ``config["trainer"]`` must
            provide ``gen_lr``, ``gen_beta_1``, ``gen_beta_2``,
            ``gen_clip_norm`` and the matching ``disc_*`` keys, which are
            used to build the generator and discriminator Adam optimizers.
        strategy: ``tf.distribute`` strategy whose ``scope()`` the model and
            its optimizers are created in.

    Returns:
        The compiled model.

    Raises:
        ValueError: If ``model`` is not a known model name.
        KeyError: If a required config section or key is missing.
    """
    if model != "vqgan":
        raise ValueError(f"Unknown model: {model}")

    model_config = config["model"]
    trainer_config = config["trainer"]

    # Build the model and both optimizers under the strategy scope so that the
    # variables they create are owned by `strategy`.
    with strategy.scope():
        # Previously a bare print(); log instead so callers control verbosity.
        logging.getLogger(__name__).debug("Model config: %s", model_config)
        vqgan = VQGAN(**model_config)

        gen_optimizer = _adam_from_config(trainer_config, "gen")
        disc_optimizer = _adam_from_config(trainer_config, "disc")

        vqgan.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)

    return vqgan


def _adam_from_config(trainer_config: dict, prefix: str) -> tf.keras.optimizers.Adam:
    """Create an Adam optimizer from the ``{prefix}_*`` keys of ``trainer_config``.

    Reads ``{prefix}_lr``, ``{prefix}_beta_1``, ``{prefix}_beta_2`` and
    ``{prefix}_clip_norm``; a missing key raises ``KeyError``.
    """
    return tf.keras.optimizers.Adam(
        learning_rate=trainer_config[f"{prefix}_lr"],
        beta_1=trainer_config[f"{prefix}_beta_1"],
        beta_2=trainer_config[f"{prefix}_beta_2"],
        clipnorm=trainer_config[f"{prefix}_clip_norm"],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|