# -*- coding: utf-8 -*- # Copyright 2020 Minh Nguyen (@dathudeptrai) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Based Trainer.""" import abc import logging import os import tensorflow as tf from tqdm import tqdm from tensorflow_tts.optimizers import GradientAccumulator from tensorflow_tts.utils import utils class BasedTrainer(metaclass=abc.ABCMeta): """Customized trainer module for all models.""" def __init__(self, steps, epochs, config): self.steps = steps self.epochs = epochs self.config = config self.finish_train = False self.writer = tf.summary.create_file_writer(config["outdir"]) self.train_data_loader = None self.eval_data_loader = None self.train_metrics = None self.eval_metrics = None self.list_metrics_name = None def init_train_eval_metrics(self, list_metrics_name): """Init train and eval metrics to save it to tensorboard.""" self.train_metrics = {} self.eval_metrics = {} for name in list_metrics_name: self.train_metrics.update( {name: tf.keras.metrics.Mean(name="train_" + name, dtype=tf.float32)} ) self.eval_metrics.update( {name: tf.keras.metrics.Mean(name="eval_" + name, dtype=tf.float32)} ) def reset_states_train(self): """Reset train metrics after save it to tensorboard.""" for metric in self.train_metrics.keys(): self.train_metrics[metric].reset_states() def reset_states_eval(self): """Reset eval metrics after save it to tensorboard.""" for metric in self.eval_metrics.keys(): self.eval_metrics[metric].reset_states() def update_train_metrics(self, dict_metrics_losses): for name, value in dict_metrics_losses.items(): self.train_metrics[name].update_state(value) def update_eval_metrics(self, dict_metrics_losses): for name, value in dict_metrics_losses.items(): self.eval_metrics[name].update_state(value) def set_train_data_loader(self, train_dataset): """Set train data loader (MUST).""" self.train_data_loader = train_dataset def get_train_data_loader(self): """Get train data loader.""" return self.train_data_loader def set_eval_data_loader(self, eval_dataset): """Set eval data loader (MUST).""" self.eval_data_loader = eval_dataset def get_eval_data_loader(self): """Get eval data loader.""" return self.eval_data_loader @abc.abstractmethod def compile(self): pass @abc.abstractmethod def create_checkpoint_manager(self, saved_path=None, max_to_keep=10): """Create checkpoint management.""" pass def run(self): """Run training.""" self.tqdm = tqdm( initial=self.steps, total=self.config["train_max_steps"], desc="[train]" ) while True: self._train_epoch() if self.finish_train: break self.tqdm.close() logging.info("Finish training.") @abc.abstractmethod def save_checkpoint(self): """Save checkpoint.""" pass @abc.abstractmethod def load_checkpoint(self, pretrained_path): """Load checkpoint.""" pass def _train_epoch(self): """Train model one epoch.""" for train_steps_per_epoch, batch in enumerate(self.train_data_loader, 1): # one step training self._train_step(batch) # check interval self._check_log_interval() self._check_eval_interval() self._check_save_interval() # check wheter training is finished if self.finish_train: return # update self.epochs += 1 self.train_steps_per_epoch = train_steps_per_epoch logging.info( f"(Steps: {self.steps}) Finished {self.epochs} epoch training " f"({self.train_steps_per_epoch} steps per epoch)." ) @abc.abstractmethod def _eval_epoch(self): """One epoch evaluation.""" pass @abc.abstractmethod def _train_step(self, batch): """One step training.""" pass @abc.abstractmethod def _check_log_interval(self): """Save log interval.""" pass @abc.abstractmethod def fit(self): pass def _check_eval_interval(self): """Evaluation interval step.""" if self.steps % self.config["eval_interval_steps"] == 0: self._eval_epoch() def _check_save_interval(self): """Save interval checkpoint.""" if self.steps % self.config["save_interval_steps"] == 0: self.save_checkpoint() logging.info(f"Successfully saved checkpoint @ {self.steps} steps.") def generate_and_save_intermediate_result(self, batch): """Generate and save intermediate result.""" pass def _write_to_tensorboard(self, list_metrics, stage="train"): """Write variables to tensorboard.""" with self.writer.as_default(): for key, value in list_metrics.items(): tf.summary.scalar(stage + "/" + key, value.result(), step=self.steps) self.writer.flush() class GanBasedTrainer(BasedTrainer): """Customized trainer module for GAN TTS training (MelGAN, GAN-TTS, ParallelWaveGAN).""" def __init__( self, steps, epochs, config, strategy, is_generator_mixed_precision=False, is_discriminator_mixed_precision=False, ): """Initialize trainer. Args: steps (int): Initial global steps. epochs (int): Initial global epochs. config (dict): Config dict loaded from yaml format configuration file. """ super().__init__(steps, epochs, config) self._is_generator_mixed_precision = is_generator_mixed_precision self._is_discriminator_mixed_precision = is_discriminator_mixed_precision self._strategy = strategy self._already_apply_input_signature = False self._generator_gradient_accumulator = GradientAccumulator() self._discriminator_gradient_accumulator = GradientAccumulator() self._generator_gradient_accumulator.reset() self._discriminator_gradient_accumulator.reset() def init_train_eval_metrics(self, list_metrics_name): with self._strategy.scope(): super().init_train_eval_metrics(list_metrics_name) def get_n_gpus(self): return self._strategy.num_replicas_in_sync def _get_train_element_signature(self): return self.train_data_loader.element_spec def _get_eval_element_signature(self): return self.eval_data_loader.element_spec def set_gen_model(self, generator_model): """Set generator class model (MUST).""" self._generator = generator_model def get_gen_model(self): """Get generator model.""" return self._generator def set_dis_model(self, discriminator_model): """Set discriminator class model (MUST).""" self._discriminator = discriminator_model def get_dis_model(self): """Get discriminator model.""" return self._discriminator def set_gen_optimizer(self, generator_optimizer): """Set generator optimizer (MUST).""" self._gen_optimizer = generator_optimizer if self._is_generator_mixed_precision: self._gen_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( self._gen_optimizer, "dynamic" ) def get_gen_optimizer(self): """Get generator optimizer.""" return self._gen_optimizer def set_dis_optimizer(self, discriminator_optimizer): """Set discriminator optimizer (MUST).""" self._dis_optimizer = discriminator_optimizer if self._is_discriminator_mixed_precision: self._dis_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( self._dis_optimizer, "dynamic" ) def get_dis_optimizer(self): """Get discriminator optimizer.""" return self._dis_optimizer def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer): self.set_gen_model(gen_model) self.set_dis_model(dis_model) self.set_gen_optimizer(gen_optimizer) self.set_dis_optimizer(dis_optimizer) def _train_step(self, batch): if self._already_apply_input_signature is False: train_element_signature = self._get_train_element_signature() eval_element_signature = self._get_eval_element_signature() self.one_step_forward = tf.function( self._one_step_forward, input_signature=[train_element_signature] ) self.one_step_evaluate = tf.function( self._one_step_evaluate, input_signature=[eval_element_signature] ) self.one_step_predict = tf.function( self._one_step_predict, input_signature=[eval_element_signature] ) self._already_apply_input_signature = True # run one_step_forward self.one_step_forward(batch) # update counts self.steps += 1 self.tqdm.update(1) self._check_train_finish() def _one_step_forward(self, batch): per_replica_losses = self._strategy.run( self._one_step_forward_per_replica, args=(batch,) ) return self._strategy.reduce( tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None ) @abc.abstractmethod def compute_per_example_generator_losses(self, batch, outputs): """Compute per example generator losses and return dict_metrics_losses Note that all element of the loss MUST has a shape [batch_size] and the keys of dict_metrics_losses MUST be in self.list_metrics_name. Args: batch: dictionary batch input return from dataloader outputs: outputs of the model Returns: per_example_losses: per example losses for each GPU, shape [B] dict_metrics_losses: dictionary loss. """ per_example_losses = 0.0 dict_metrics_losses = {} return per_example_losses, dict_metrics_losses @abc.abstractmethod def compute_per_example_discriminator_losses(self, batch, gen_outputs): """Compute per example discriminator losses and return dict_metrics_losses Note that all element of the loss MUST has a shape [batch_size] and the keys of dict_metrics_losses MUST be in self.list_metrics_name. Args: batch: dictionary batch input return from dataloader outputs: outputs of the model Returns: per_example_losses: per example losses for each GPU, shape [B] dict_metrics_losses: dictionary loss. """ per_example_losses = 0.0 dict_metrics_losses = {} return per_example_losses, dict_metrics_losses def _calculate_generator_gradient_per_batch(self, batch): outputs = self._generator(**batch, training=True) ( per_example_losses, dict_metrics_losses, ) = self.compute_per_example_generator_losses(batch, outputs) per_replica_gen_losses = tf.nn.compute_average_loss( per_example_losses, global_batch_size=self.config["batch_size"] * self.get_n_gpus() * self.config["gradient_accumulation_steps"], ) if self._is_generator_mixed_precision: scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss( per_replica_gen_losses ) if self._is_generator_mixed_precision: scaled_gradients = tf.gradients( scaled_per_replica_gen_losses, self._generator.trainable_variables ) gradients = self._gen_optimizer.get_unscaled_gradients(scaled_gradients) else: gradients = tf.gradients( per_replica_gen_losses, self._generator.trainable_variables ) # gradient accumulate for generator here if self.config["gradient_accumulation_steps"] > 1: self._generator_gradient_accumulator(gradients) # accumulate loss into metrics self.update_train_metrics(dict_metrics_losses) if self.config["gradient_accumulation_steps"] == 1: return gradients, per_replica_gen_losses else: return per_replica_gen_losses def _calculate_discriminator_gradient_per_batch(self, batch): ( per_example_losses, dict_metrics_losses, ) = self.compute_per_example_discriminator_losses( batch, self._generator(**batch, training=True) ) per_replica_dis_losses = tf.nn.compute_average_loss( per_example_losses, global_batch_size=self.config["batch_size"] * self.get_n_gpus() * self.config["gradient_accumulation_steps"], ) if self._is_discriminator_mixed_precision: scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss( per_replica_dis_losses ) if self._is_discriminator_mixed_precision: scaled_gradients = tf.gradients( scaled_per_replica_dis_losses, self._discriminator.trainable_variables, ) gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients) else: gradients = tf.gradients( per_replica_dis_losses, self._discriminator.trainable_variables ) # accumulate loss into metrics self.update_train_metrics(dict_metrics_losses) # gradient accumulate for discriminator here if self.config["gradient_accumulation_steps"] > 1: self._discriminator_gradient_accumulator(gradients) if self.config["gradient_accumulation_steps"] == 1: return gradients, per_replica_dis_losses else: return per_replica_dis_losses def _one_step_forward_per_replica(self, batch): per_replica_gen_losses = 0.0 per_replica_dis_losses = 0.0 if self.config["gradient_accumulation_steps"] == 1: ( gradients, per_replica_gen_losses, ) = self._calculate_generator_gradient_per_batch(batch) self._gen_optimizer.apply_gradients( zip(gradients, self._generator.trainable_variables) ) else: # gradient acummulation here. for i in tf.range(self.config["gradient_accumulation_steps"]): reduced_batch = { k: v[ i * self.config["batch_size"] : (i + 1) * self.config["batch_size"] ] for k, v in batch.items() } # run 1 step accumulate reduced_batch_losses = self._calculate_generator_gradient_per_batch( reduced_batch ) # sum per_replica_losses per_replica_gen_losses += reduced_batch_losses gradients = self._generator_gradient_accumulator.gradients self._gen_optimizer.apply_gradients( zip(gradients, self._generator.trainable_variables) ) self._generator_gradient_accumulator.reset() # one step discriminator # recompute y_hat after 1 step generator for discriminator training. if self.steps >= self.config["discriminator_train_start_steps"]: if self.config["gradient_accumulation_steps"] == 1: ( gradients, per_replica_dis_losses, ) = self._calculate_discriminator_gradient_per_batch(batch) self._dis_optimizer.apply_gradients( zip(gradients, self._discriminator.trainable_variables) ) else: # gradient acummulation here. for i in tf.range(self.config["gradient_accumulation_steps"]): reduced_batch = { k: v[ i * self.config["batch_size"] : (i + 1) * self.config["batch_size"] ] for k, v in batch.items() } # run 1 step accumulate reduced_batch_losses = ( self._calculate_discriminator_gradient_per_batch(reduced_batch) ) # sum per_replica_losses per_replica_dis_losses += reduced_batch_losses gradients = self._discriminator_gradient_accumulator.gradients self._dis_optimizer.apply_gradients( zip(gradients, self._discriminator.trainable_variables) ) self._discriminator_gradient_accumulator.reset() return per_replica_gen_losses + per_replica_dis_losses def _eval_epoch(self): """Evaluate model one epoch.""" logging.info(f"(Steps: {self.steps}) Start evaluation.") # calculate loss for each batch for eval_steps_per_epoch, batch in enumerate( tqdm(self.eval_data_loader, desc="[eval]"), 1 ): # eval one step self.one_step_evaluate(batch) if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]: # save intermedia self.generate_and_save_intermediate_result(batch) logging.info( f"(Steps: {self.steps}) Finished evaluation " f"({eval_steps_per_epoch} steps per epoch)." ) # average loss for key in self.eval_metrics.keys(): logging.info( f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}." ) # record self._write_to_tensorboard(self.eval_metrics, stage="eval") # reset self.reset_states_eval() def _one_step_evaluate_per_replica(self, batch): ################################################ # one step generator. outputs = self._generator(**batch, training=False) _, dict_metrics_losses = self.compute_per_example_generator_losses( batch, outputs ) # accumulate loss into metrics self.update_eval_metrics(dict_metrics_losses) ################################################ # one step discriminator if self.steps >= self.config["discriminator_train_start_steps"]: _, dict_metrics_losses = self.compute_per_example_discriminator_losses( batch, outputs ) # accumulate loss into metrics self.update_eval_metrics(dict_metrics_losses) ################################################ def _one_step_evaluate(self, batch): self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,)) def _one_step_predict_per_replica(self, batch): outputs = self._generator(**batch, training=False) return outputs def _one_step_predict(self, batch): outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,)) return outputs @abc.abstractmethod def generate_and_save_intermediate_result(self, batch): return def create_checkpoint_manager(self, saved_path=None, max_to_keep=10): """Create checkpoint management.""" if saved_path is None: saved_path = self.config["outdir"] + "/checkpoints/" os.makedirs(saved_path, exist_ok=True) self.saved_path = saved_path self.ckpt = tf.train.Checkpoint( steps=tf.Variable(1), epochs=tf.Variable(1), gen_optimizer=self.get_gen_optimizer(), dis_optimizer=self.get_dis_optimizer(), ) self.ckp_manager = tf.train.CheckpointManager( self.ckpt, saved_path, max_to_keep=max_to_keep ) def save_checkpoint(self): """Save checkpoint.""" self.ckpt.steps.assign(self.steps) self.ckpt.epochs.assign(self.epochs) self.ckp_manager.save(checkpoint_number=self.steps) utils.save_weights( self._generator, self.saved_path + "generator-{}.h5".format(self.steps) ) utils.save_weights( self._discriminator, self.saved_path + "discriminator-{}.h5".format(self.steps) ) def load_checkpoint(self, pretrained_path): """Load checkpoint.""" self.ckpt.restore(pretrained_path) self.steps = self.ckpt.steps.numpy() self.epochs = self.ckpt.epochs.numpy() self._gen_optimizer = self.ckpt.gen_optimizer # re-assign iterations (global steps) for gen_optimizer. self._gen_optimizer.iterations.assign(tf.cast(self.steps, tf.int64)) # re-assign iterations (global steps) for dis_optimizer. try: discriminator_train_start_steps = self.config[ "discriminator_train_start_steps" ] discriminator_train_start_steps = tf.math.maximum( 0, self.steps - discriminator_train_start_steps ) except Exception: discriminator_train_start_steps = self.steps self._dis_optimizer = self.ckpt.dis_optimizer self._dis_optimizer.iterations.assign( tf.cast(discriminator_train_start_steps, tf.int64) ) # load weights. utils.load_weights( self._generator, self.saved_path + "generator-{}.h5".format(self.steps) ) utils.load_weights( self._discriminator, self.saved_path + "discriminator-{}.h5".format(self.steps) ) def _check_train_finish(self): """Check training finished.""" if self.steps >= self.config["train_max_steps"]: self.finish_train = True if ( self.steps != 0 and self.steps == self.config["discriminator_train_start_steps"] ): self.finish_train = True logging.info( f"Finished training only generator at {self.steps}steps, pls resume and continue training." ) def _check_log_interval(self): """Log to tensorboard.""" if self.steps % self.config["log_interval_steps"] == 0: for metric_name in self.list_metrics_name: logging.info( f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}." ) self._write_to_tensorboard(self.train_metrics, stage="train") # reset self.reset_states_train() def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None): self.set_train_data_loader(train_data_loader) self.set_eval_data_loader(valid_data_loader) self.train_data_loader = self._strategy.experimental_distribute_dataset( self.train_data_loader ) self.eval_data_loader = self._strategy.experimental_distribute_dataset( self.eval_data_loader ) with self._strategy.scope(): self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000) if len(resume) > 1: self.load_checkpoint(resume) logging.info(f"Successfully resumed from {resume}.") self.run() class Seq2SeqBasedTrainer(BasedTrainer, metaclass=abc.ABCMeta): """Customized trainer module for Seq2Seq TTS training (Tacotron, FastSpeech).""" def __init__( self, steps, epochs, config, strategy, is_mixed_precision=False, ): """Initialize trainer. Args: steps (int): Initial global steps. epochs (int): Initial global epochs. config (dict): Config dict loaded from yaml format configuration file. strategy (tf.distribute): Strategy for distributed training. is_mixed_precision (bool): Use mixed_precision training or not. """ super().__init__(steps, epochs, config) self._is_mixed_precision = is_mixed_precision self._strategy = strategy self._model = None self._optimizer = None self._trainable_variables = None # check if we already apply input_signature for train_step. self._already_apply_input_signature = False # create gradient accumulator self._gradient_accumulator = GradientAccumulator() self._gradient_accumulator.reset() def init_train_eval_metrics(self, list_metrics_name): with self._strategy.scope(): super().init_train_eval_metrics(list_metrics_name) def set_model(self, model): """Set generator class model (MUST).""" self._model = model def get_model(self): """Get generator model.""" return self._model def set_optimizer(self, optimizer): """Set optimizer (MUST).""" self._optimizer = optimizer if self._is_mixed_precision: self._optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer( self._optimizer, "dynamic" ) def get_optimizer(self): """Get optimizer.""" return self._optimizer def get_n_gpus(self): return self._strategy.num_replicas_in_sync def compile(self, model, optimizer): self.set_model(model) self.set_optimizer(optimizer) self._trainable_variables = self._train_vars() def _train_vars(self): if self.config["var_train_expr"]: list_train_var = self.config["var_train_expr"].split("|") return [ v for v in self._model.trainable_variables if self._check_string_exist(list_train_var, v.name) ] return self._model.trainable_variables def _check_string_exist(self, list_string, inp_string): for string in list_string: if string in inp_string: return True return False def _get_train_element_signature(self): return self.train_data_loader.element_spec def _get_eval_element_signature(self): return self.eval_data_loader.element_spec def _train_step(self, batch): if self._already_apply_input_signature is False: train_element_signature = self._get_train_element_signature() eval_element_signature = self._get_eval_element_signature() self.one_step_forward = tf.function( self._one_step_forward, input_signature=[train_element_signature] ) self.one_step_evaluate = tf.function( self._one_step_evaluate, input_signature=[eval_element_signature] ) self.one_step_predict = tf.function( self._one_step_predict, input_signature=[eval_element_signature] ) self._already_apply_input_signature = True # run one_step_forward self.one_step_forward(batch) # update counts self.steps += 1 self.tqdm.update(1) self._check_train_finish() def _one_step_forward(self, batch): per_replica_losses = self._strategy.run( self._one_step_forward_per_replica, args=(batch,) ) return self._strategy.reduce( tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None ) def _calculate_gradient_per_batch(self, batch): outputs = self._model(**batch, training=True) per_example_losses, dict_metrics_losses = self.compute_per_example_losses( batch, outputs ) per_replica_losses = tf.nn.compute_average_loss( per_example_losses, global_batch_size=self.config["batch_size"] * self.get_n_gpus() * self.config["gradient_accumulation_steps"], ) if self._is_mixed_precision: scaled_per_replica_losses = self._optimizer.get_scaled_loss( per_replica_losses ) if self._is_mixed_precision: scaled_gradients = tf.gradients( scaled_per_replica_losses, self._trainable_variables ) gradients = self._optimizer.get_unscaled_gradients(scaled_gradients) else: gradients = tf.gradients(per_replica_losses, self._trainable_variables) # gradient accumulate here if self.config["gradient_accumulation_steps"] > 1: self._gradient_accumulator(gradients) # accumulate loss into metrics self.update_train_metrics(dict_metrics_losses) if self.config["gradient_accumulation_steps"] == 1: return gradients, per_replica_losses else: return per_replica_losses def _one_step_forward_per_replica(self, batch): if self.config["gradient_accumulation_steps"] == 1: gradients, per_replica_losses = self._calculate_gradient_per_batch(batch) self._optimizer.apply_gradients( zip(gradients, self._trainable_variables), 1.0 ) else: # gradient acummulation here. per_replica_losses = 0.0 for i in tf.range(self.config["gradient_accumulation_steps"]): reduced_batch = { k: v[ i * self.config["batch_size"] : (i + 1) * self.config["batch_size"] ] for k, v in batch.items() } # run 1 step accumulate reduced_batch_losses = self._calculate_gradient_per_batch(reduced_batch) # sum per_replica_losses per_replica_losses += reduced_batch_losses gradients = self._gradient_accumulator.gradients self._optimizer.apply_gradients( zip(gradients, self._trainable_variables), 1.0 ) self._gradient_accumulator.reset() return per_replica_losses @abc.abstractmethod def compute_per_example_losses(self, batch, outputs): """Compute per example losses and return dict_metrics_losses Note that all element of the loss MUST has a shape [batch_size] and the keys of dict_metrics_losses MUST be in self.list_metrics_name. Args: batch: dictionary batch input return from dataloader outputs: outputs of the model Returns: per_example_losses: per example losses for each GPU, shape [B] dict_metrics_losses: dictionary loss. """ per_example_losses = 0.0 dict_metrics_losses = {} return per_example_losses, dict_metrics_losses def _eval_epoch(self): """Evaluate model one epoch.""" logging.info(f"(Steps: {self.steps}) Start evaluation.") # calculate loss for each batch for eval_steps_per_epoch, batch in enumerate( tqdm(self.eval_data_loader, desc="[eval]"), 1 ): # eval one step self.one_step_evaluate(batch) if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]: # save intermedia self.generate_and_save_intermediate_result(batch) logging.info( f"(Steps: {self.steps}) Finished evaluation " f"({eval_steps_per_epoch} steps per epoch)." ) # average loss for key in self.eval_metrics.keys(): logging.info( f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}." ) # record self._write_to_tensorboard(self.eval_metrics, stage="eval") # reset self.reset_states_eval() def _one_step_evaluate_per_replica(self, batch): outputs = self._model(**batch, training=False) _, dict_metrics_losses = self.compute_per_example_losses(batch, outputs) self.update_eval_metrics(dict_metrics_losses) def _one_step_evaluate(self, batch): self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,)) def _one_step_predict_per_replica(self, batch): outputs = self._model(**batch, training=False) return outputs def _one_step_predict(self, batch): outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,)) return outputs @abc.abstractmethod def generate_and_save_intermediate_result(self, batch): return def create_checkpoint_manager(self, saved_path=None, max_to_keep=10): """Create checkpoint management.""" if saved_path is None: saved_path = self.config["outdir"] + "/checkpoints/" os.makedirs(saved_path, exist_ok=True) self.saved_path = saved_path self.ckpt = tf.train.Checkpoint( steps=tf.Variable(1), epochs=tf.Variable(1), optimizer=self.get_optimizer() ) self.ckp_manager = tf.train.CheckpointManager( self.ckpt, saved_path, max_to_keep=max_to_keep ) def save_checkpoint(self): """Save checkpoint.""" self.ckpt.steps.assign(self.steps) self.ckpt.epochs.assign(self.epochs) self.ckp_manager.save(checkpoint_number=self.steps) utils.save_weights( self._model, self.saved_path + "model-{}.h5".format(self.steps) ) def load_checkpoint(self, pretrained_path): """Load checkpoint.""" self.ckpt.restore(pretrained_path) self.steps = self.ckpt.steps.numpy() self.epochs = self.ckpt.epochs.numpy() self._optimizer = self.ckpt.optimizer # re-assign iterations (global steps) for optimizer. self._optimizer.iterations.assign(tf.cast(self.steps, tf.int64)) # load weights. utils.load_weights( self._model, self.saved_path + "model-{}.h5".format(self.steps) ) def _check_train_finish(self): """Check training finished.""" if self.steps >= self.config["train_max_steps"]: self.finish_train = True def _check_log_interval(self): """Log to tensorboard.""" if self.steps % self.config["log_interval_steps"] == 0: for metric_name in self.list_metrics_name: logging.info( f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}." ) self._write_to_tensorboard(self.train_metrics, stage="train") # reset self.reset_states_train() def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None): self.set_train_data_loader(train_data_loader) self.set_eval_data_loader(valid_data_loader) self.train_data_loader = self._strategy.experimental_distribute_dataset( self.train_data_loader ) self.eval_data_loader = self._strategy.experimental_distribute_dataset( self.eval_data_loader ) with self._strategy.scope(): self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000) if len(resume) > 1: self.load_checkpoint(resume) logging.info(f"Successfully resumed from {resume}.") self.run()