# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Based Trainer.""" | |
import abc | |
import logging | |
import os | |
import tensorflow as tf | |
from tqdm import tqdm | |
from tensorflow_tts.optimizers import GradientAccumulator | |
from tensorflow_tts.utils import utils | |


class BasedTrainer(metaclass=abc.ABCMeta):
    """Customized trainer module for all models."""

    def __init__(self, steps, epochs, config):
        self.steps = steps
        self.epochs = epochs
        self.config = config
        self.finish_train = False
        self.writer = tf.summary.create_file_writer(config["outdir"])
        self.train_data_loader = None
        self.eval_data_loader = None
        self.train_metrics = None
        self.eval_metrics = None
        self.list_metrics_name = None

    def init_train_eval_metrics(self, list_metrics_name):
        """Initialize train and eval metrics so they can be written to tensorboard."""
        self.train_metrics = {}
        self.eval_metrics = {}
        for name in list_metrics_name:
            self.train_metrics.update(
                {name: tf.keras.metrics.Mean(name="train_" + name, dtype=tf.float32)}
            )
            self.eval_metrics.update(
                {name: tf.keras.metrics.Mean(name="eval_" + name, dtype=tf.float32)}
            )

    def reset_states_train(self):
        """Reset train metrics after writing them to tensorboard."""
        for metric in self.train_metrics.keys():
            self.train_metrics[metric].reset_states()

    def reset_states_eval(self):
        """Reset eval metrics after writing them to tensorboard."""
        for metric in self.eval_metrics.keys():
            self.eval_metrics[metric].reset_states()

    def update_train_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.train_metrics[name].update_state(value)

    def update_eval_metrics(self, dict_metrics_losses):
        for name, value in dict_metrics_losses.items():
            self.eval_metrics[name].update_state(value)

    def set_train_data_loader(self, train_dataset):
        """Set train data loader (MUST)."""
        self.train_data_loader = train_dataset

    def get_train_data_loader(self):
        """Get train data loader."""
        return self.train_data_loader

    def set_eval_data_loader(self, eval_dataset):
        """Set eval data loader (MUST)."""
        self.eval_data_loader = eval_dataset

    def get_eval_data_loader(self):
        """Get eval data loader."""
        return self.eval_data_loader

    def compile(self):
        pass

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint management."""
        pass

    def run(self):
        """Run training."""
        self.tqdm = tqdm(
            initial=self.steps, total=self.config["train_max_steps"], desc="[train]"
        )
        while True:
            self._train_epoch()
            if self.finish_train:
                break
        self.tqdm.close()
        logging.info("Finished training.")

    def save_checkpoint(self):
        """Save checkpoint."""
        pass

    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        pass

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.train_data_loader, 1):
            # one step training
            self._train_step(batch)
            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()
            # check whether training is finished
            if self.finish_train:
                return
        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

    def _eval_epoch(self):
        """One epoch evaluation."""
        pass

    def _train_step(self, batch):
        """One step training."""
        pass

    def _check_log_interval(self):
        """Check log interval."""
        pass

    def fit(self):
        pass

    def _check_eval_interval(self):
        """Evaluation interval step."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_save_interval(self):
        """Save interval checkpoint."""
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint()
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def generate_and_save_intermediate_result(self, batch):
        """Generate and save intermediate result."""
        pass

    def _write_to_tensorboard(self, list_metrics, stage="train"):
        """Write variables to tensorboard."""
        with self.writer.as_default():
            for key, value in list_metrics.items():
                tf.summary.scalar(stage + "/" + key, value.result(), step=self.steps)
            self.writer.flush()
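
# A minimal usage sketch (hypothetical, for illustration only) of the metric
# helpers above; "mels_loss" and `trainer` are assumed names, not part of this
# module:
#
#     trainer.init_train_eval_metrics(["mels_loss"])
#     trainer.update_train_metrics({"mels_loss": 0.5})       # inside a train step
#     trainer._write_to_tensorboard(trainer.train_metrics)   # at a log interval
#     trainer.reset_states_train()                           # start a fresh window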


class GanBasedTrainer(BasedTrainer):
    """Customized trainer module for GAN TTS training (MelGAN, GAN-TTS, ParallelWaveGAN)."""

    def __init__(
        self,
        steps,
        epochs,
        config,
        strategy,
        is_generator_mixed_precision=False,
        is_discriminator_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Strategy for distributed training.
            is_generator_mixed_precision (bool): Use mixed-precision training for the generator or not.
            is_discriminator_mixed_precision (bool): Use mixed-precision training for the discriminator or not.
        """
        super().__init__(steps, epochs, config)
        self._is_generator_mixed_precision = is_generator_mixed_precision
        self._is_discriminator_mixed_precision = is_discriminator_mixed_precision
        self._strategy = strategy
        self._already_apply_input_signature = False
        self._generator_gradient_accumulator = GradientAccumulator()
        self._discriminator_gradient_accumulator = GradientAccumulator()
        self._generator_gradient_accumulator.reset()
        self._discriminator_gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def set_gen_model(self, generator_model):
        """Set generator class model (MUST)."""
        self._generator = generator_model

    def get_gen_model(self):
        """Get generator model."""
        return self._generator

    def set_dis_model(self, discriminator_model):
        """Set discriminator class model (MUST)."""
        self._discriminator = discriminator_model

    def get_dis_model(self):
        """Get discriminator model."""
        return self._discriminator

    def set_gen_optimizer(self, generator_optimizer):
        """Set generator optimizer (MUST)."""
        self._gen_optimizer = generator_optimizer
        if self._is_generator_mixed_precision:
            self._gen_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._gen_optimizer, "dynamic"
            )

    def get_gen_optimizer(self):
        """Get generator optimizer."""
        return self._gen_optimizer

    def set_dis_optimizer(self, discriminator_optimizer):
        """Set discriminator optimizer (MUST)."""
        self._dis_optimizer = discriminator_optimizer
        if self._is_discriminator_mixed_precision:
            self._dis_optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._dis_optimizer, "dynamic"
            )

    def get_dis_optimizer(self):
        """Get discriminator optimizer."""
        return self._dis_optimizer

    def compile(self, gen_model, dis_model, gen_optimizer, dis_optimizer):
        self.set_gen_model(gen_model)
        self.set_dis_model(dis_model)
        self.set_gen_optimizer(gen_optimizer)
        self.set_dis_optimizer(dis_optimizer)

    def _train_step(self, batch):
        if not self._already_apply_input_signature:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    def compute_per_example_generator_losses(self, batch, outputs):
        """Compute per-example generator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader.
            outputs: outputs of the model.

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B].
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses
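
    # A hedged sketch of how a concrete subclass might implement this hook (the
    # same pattern applies to compute_per_example_discriminator_losses below).
    # The least-squares adversarial loss and the assumption that the
    # discriminator output is 3-D are illustrative, not guaranteed by this
    # module:
    #
    #     def compute_per_example_generator_losses(self, batch, outputs):
    #         p_hat = self._discriminator(outputs)
    #         adv_loss = tf.reduce_mean(
    #             tf.math.squared_difference(p_hat, 1.0), axis=[1, 2]
    #         )  # shape [batch_size]
    #         return adv_loss, {"adversarial_loss": adv_loss}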

    def compute_per_example_discriminator_losses(self, batch, gen_outputs):
        """Compute per-example discriminator losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader.
            gen_outputs: outputs of the generator.

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B].
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses

    def _calculate_generator_gradient_per_batch(self, batch):
        outputs = self._generator(**batch, training=True)
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_generator_losses(batch, outputs)
        per_replica_gen_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_generator_mixed_precision:
            scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss(
                per_replica_gen_losses
            )
            scaled_gradients = tf.gradients(
                scaled_per_replica_gen_losses, self._generator.trainable_variables
            )
            gradients = self._gen_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_gen_losses, self._generator.trainable_variables
            )

        # accumulate gradients for the generator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._generator_gradient_accumulator(gradients)

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_gen_losses
        else:
            return per_replica_gen_losses

    def _calculate_discriminator_gradient_per_batch(self, batch):
        (
            per_example_losses,
            dict_metrics_losses,
        ) = self.compute_per_example_discriminator_losses(
            batch, self._generator(**batch, training=True)
        )
        per_replica_dis_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_discriminator_mixed_precision:
            scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss(
                per_replica_dis_losses
            )
            scaled_gradients = tf.gradients(
                scaled_per_replica_dis_losses,
                self._discriminator.trainable_variables,
            )
            gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(
                per_replica_dis_losses, self._discriminator.trainable_variables
            )

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        # accumulate gradients for the discriminator here
        if self.config["gradient_accumulation_steps"] > 1:
            self._discriminator_gradient_accumulator(gradients)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_dis_losses
        else:
            return per_replica_dis_losses

    def _one_step_forward_per_replica(self, batch):
        per_replica_gen_losses = 0.0
        per_replica_dis_losses = 0.0
        if self.config["gradient_accumulation_steps"] == 1:
            (
                gradients,
                per_replica_gen_losses,
            ) = self._calculate_generator_gradient_per_batch(batch)
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
        else:
            # gradient accumulation here.
            for i in tf.range(self.config["gradient_accumulation_steps"]):
                reduced_batch = {
                    k: v[
                        i
                        * self.config["batch_size"] : (i + 1)
                        * self.config["batch_size"]
                    ]
                    for k, v in batch.items()
                }

                # run 1 accumulation step
                reduced_batch_losses = self._calculate_generator_gradient_per_batch(
                    reduced_batch
                )

                # sum per_replica_losses
                per_replica_gen_losses += reduced_batch_losses

            gradients = self._generator_gradient_accumulator.gradients
            self._gen_optimizer.apply_gradients(
                zip(gradients, self._generator.trainable_variables)
            )
            self._generator_gradient_accumulator.reset()

        # one step discriminator
        # recompute y_hat after 1 step generator for discriminator training.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            if self.config["gradient_accumulation_steps"] == 1:
                (
                    gradients,
                    per_replica_dis_losses,
                ) = self._calculate_discriminator_gradient_per_batch(batch)
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
            else:
                # gradient accumulation here.
                for i in tf.range(self.config["gradient_accumulation_steps"]):
                    reduced_batch = {
                        k: v[
                            i
                            * self.config["batch_size"] : (i + 1)
                            * self.config["batch_size"]
                        ]
                        for k, v in batch.items()
                    }

                    # run 1 accumulation step
                    reduced_batch_losses = self._calculate_discriminator_gradient_per_batch(
                        reduced_batch
                    )

                    # sum per_replica_losses
                    per_replica_dis_losses += reduced_batch_losses

                gradients = self._discriminator_gradient_accumulator.gradients
                self._dis_optimizer.apply_gradients(
                    zip(gradients, self._discriminator.trainable_variables)
                )
                self._discriminator_gradient_accumulator.reset()

        return per_replica_gen_losses + per_replica_dis_losses
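
    # Worked example (assumed numbers, for illustration): with batch_size=8 and
    # gradient_accumulation_steps=4, the dataloader is expected to yield
    # 8 * 4 = 32 examples per replica. Each loop iteration above slices one
    # sub-batch of 8 (v[0:8], v[8:16], ...), accumulates its gradients, and a
    # single optimizer step is applied from the accumulated sum, so the
    # effective global batch size matches the
    # batch_size * n_gpus * gradient_accumulation_steps denominator used in
    # tf.nn.compute_average_loss.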

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.eval_data_loader, desc="[eval]"), 1
        ):
            # eval one step
            self.one_step_evaluate(batch)

            if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:
                # save intermediate results
                self.generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.eval_metrics.keys():
            logging.info(
                f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."
            )

        # record
        self._write_to_tensorboard(self.eval_metrics, stage="eval")

        # reset
        self.reset_states_eval()

    def _one_step_evaluate_per_replica(self, batch):
        ################################################
        # one step generator.
        outputs = self._generator(**batch, training=False)
        _, dict_metrics_losses = self.compute_per_example_generator_losses(
            batch, outputs
        )

        # accumulate loss into metrics
        self.update_eval_metrics(dict_metrics_losses)

        ################################################
        # one step discriminator.
        if self.steps >= self.config["discriminator_train_start_steps"]:
            _, dict_metrics_losses = self.compute_per_example_discriminator_losses(
                batch, outputs
            )

            # accumulate loss into metrics
            self.update_eval_metrics(dict_metrics_losses)
        ################################################

    def _one_step_evaluate(self, batch):
        self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))

    def _one_step_predict_per_replica(self, batch):
        outputs = self._generator(**batch, training=False)
        return outputs

    def _one_step_predict(self, batch):
        outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))
        return outputs

    def generate_and_save_intermediate_result(self, batch):
        return

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint management."""
        if saved_path is None:
            saved_path = self.config["outdir"] + "/checkpoints/"
        os.makedirs(saved_path, exist_ok=True)

        self.saved_path = saved_path
        self.ckpt = tf.train.Checkpoint(
            steps=tf.Variable(1),
            epochs=tf.Variable(1),
            gen_optimizer=self.get_gen_optimizer(),
            dis_optimizer=self.get_dis_optimizer(),
        )
        self.ckp_manager = tf.train.CheckpointManager(
            self.ckpt, saved_path, max_to_keep=max_to_keep
        )

    def save_checkpoint(self):
        """Save checkpoint."""
        self.ckpt.steps.assign(self.steps)
        self.ckpt.epochs.assign(self.epochs)
        self.ckp_manager.save(checkpoint_number=self.steps)
        utils.save_weights(
            self._generator,
            self.saved_path + "generator-{}.h5".format(self.steps),
        )
        utils.save_weights(
            self._discriminator,
            self.saved_path + "discriminator-{}.h5".format(self.steps),
        )

    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        self.ckpt.restore(pretrained_path)
        self.steps = self.ckpt.steps.numpy()
        self.epochs = self.ckpt.epochs.numpy()
        self._gen_optimizer = self.ckpt.gen_optimizer
        # re-assign iterations (global steps) for gen_optimizer.
        self._gen_optimizer.iterations.assign(tf.cast(self.steps, tf.int64))
        # re-assign iterations for dis_optimizer; the discriminator only starts
        # training after discriminator_train_start_steps, so its iteration
        # count lags the global step by that offset.
        try:
            discriminator_steps = tf.math.maximum(
                0, self.steps - self.config["discriminator_train_start_steps"]
            )
        except Exception:
            discriminator_steps = self.steps
        self._dis_optimizer = self.ckpt.dis_optimizer
        self._dis_optimizer.iterations.assign(
            tf.cast(discriminator_steps, tf.int64)
        )

        # load weights.
        utils.load_weights(
            self._generator,
            self.saved_path + "generator-{}.h5".format(self.steps),
        )
        utils.load_weights(
            self._discriminator,
            self.saved_path + "discriminator-{}.h5".format(self.steps),
        )
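
    # Worked example (assumed numbers): resuming at steps=200000 with
    # discriminator_train_start_steps=100000 gives the discriminator optimizer
    # iterations = max(0, 200000 - 100000) = 100000, while the generator
    # optimizer resumes at the full 200000 steps.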

    def _check_train_finish(self):
        """Check whether training is finished."""
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True

        if (
            self.steps != 0
            and self.steps == self.config["discriminator_train_start_steps"]
        ):
            self.finish_train = True
            logging.info(
                f"Finished training the generator only at {self.steps} steps; please resume to continue training."
            )

    def _check_log_interval(self):
        """Log to tensorboard."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for metric_name in self.list_metrics_name:
                logging.info(
                    f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."
                )
            self._write_to_tensorboard(self.train_metrics, stage="train")

            # reset
            self.reset_states_train()

    def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):
        self.set_train_data_loader(train_data_loader)
        self.set_eval_data_loader(valid_data_loader)
        self.train_data_loader = self._strategy.experimental_distribute_dataset(
            self.train_data_loader
        )
        self.eval_data_loader = self._strategy.experimental_distribute_dataset(
            self.eval_data_loader
        )
        with self._strategy.scope():
            self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)
            if resume is not None and len(resume) > 1:
                self.load_checkpoint(resume)
                logging.info(f"Successfully resumed from {resume}.")
        self.run()
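
# A hedged end-to-end sketch of how GanBasedTrainer is typically driven; the
# model/optimizer objects, metric names, and config keys below are assumptions
# for illustration, not defined in this module:
#
#     strategy = tf.distribute.MirroredStrategy()
#     trainer = GanBasedTrainer(steps=0, epochs=0, config=config, strategy=strategy)
#     trainer.init_train_eval_metrics(["adversarial_loss", "real_loss", "fake_loss"])
#     with strategy.scope():
#         trainer.compile(generator, discriminator, gen_optimizer, dis_optimizer)
#     trainer.fit(train_dataset, valid_dataset, saved_path=config["outdir"], resume="")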


class Seq2SeqBasedTrainer(BasedTrainer, metaclass=abc.ABCMeta):
    """Customized trainer module for Seq2Seq TTS training (Tacotron, FastSpeech)."""

    def __init__(
        self, steps, epochs, config, strategy, is_mixed_precision=False,
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            config (dict): Config dict loaded from yaml format configuration file.
            strategy (tf.distribute.Strategy): Strategy for distributed training.
            is_mixed_precision (bool): Use mixed-precision training or not.
        """
        super().__init__(steps, epochs, config)
        self._is_mixed_precision = is_mixed_precision
        self._strategy = strategy
        self._model = None
        self._optimizer = None
        self._trainable_variables = None

        # check if we already applied the input_signature for train_step.
        self._already_apply_input_signature = False

        # create gradient accumulator
        self._gradient_accumulator = GradientAccumulator()
        self._gradient_accumulator.reset()

    def init_train_eval_metrics(self, list_metrics_name):
        with self._strategy.scope():
            super().init_train_eval_metrics(list_metrics_name)

    def set_model(self, model):
        """Set class model (MUST)."""
        self._model = model

    def get_model(self):
        """Get model."""
        return self._model

    def set_optimizer(self, optimizer):
        """Set optimizer (MUST)."""
        self._optimizer = optimizer
        if self._is_mixed_precision:
            self._optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self._optimizer, "dynamic"
            )

    def get_optimizer(self):
        """Get optimizer."""
        return self._optimizer

    def get_n_gpus(self):
        return self._strategy.num_replicas_in_sync

    def compile(self, model, optimizer):
        self.set_model(model)
        self.set_optimizer(optimizer)
        self._trainable_variables = self._train_vars()

    def _train_vars(self):
        if self.config["var_train_expr"]:
            list_train_var = self.config["var_train_expr"].split("|")
            return [
                v
                for v in self._model.trainable_variables
                if self._check_string_exist(list_train_var, v.name)
            ]
        return self._model.trainable_variables

    def _check_string_exist(self, list_string, inp_string):
        for string in list_string:
            if string in inp_string:
                return True
        return False
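
    # Example (assumed config value): with config["var_train_expr"] set to
    # "embeddings|encoder", only variables whose names contain "embeddings" or
    # "encoder" (e.g. a hypothetical "tacotron2/encoder/conv_0/kernel:0") are
    # trained; with an empty/None value every trainable variable is included.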

    def _get_train_element_signature(self):
        return self.train_data_loader.element_spec

    def _get_eval_element_signature(self):
        return self.eval_data_loader.element_spec

    def _train_step(self, batch):
        if not self._already_apply_input_signature:
            train_element_signature = self._get_train_element_signature()
            eval_element_signature = self._get_eval_element_signature()
            self.one_step_forward = tf.function(
                self._one_step_forward, input_signature=[train_element_signature]
            )
            self.one_step_evaluate = tf.function(
                self._one_step_evaluate, input_signature=[eval_element_signature]
            )
            self.one_step_predict = tf.function(
                self._one_step_predict, input_signature=[eval_element_signature]
            )
            self._already_apply_input_signature = True

        # run one_step_forward
        self.one_step_forward(batch)

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _one_step_forward(self, batch):
        per_replica_losses = self._strategy.run(
            self._one_step_forward_per_replica, args=(batch,)
        )
        return self._strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
        )

    def _calculate_gradient_per_batch(self, batch):
        outputs = self._model(**batch, training=True)
        per_example_losses, dict_metrics_losses = self.compute_per_example_losses(
            batch, outputs
        )
        per_replica_losses = tf.nn.compute_average_loss(
            per_example_losses,
            global_batch_size=self.config["batch_size"]
            * self.get_n_gpus()
            * self.config["gradient_accumulation_steps"],
        )

        if self._is_mixed_precision:
            scaled_per_replica_losses = self._optimizer.get_scaled_loss(
                per_replica_losses
            )
            scaled_gradients = tf.gradients(
                scaled_per_replica_losses, self._trainable_variables
            )
            gradients = self._optimizer.get_unscaled_gradients(scaled_gradients)
        else:
            gradients = tf.gradients(per_replica_losses, self._trainable_variables)

        # accumulate gradients here
        if self.config["gradient_accumulation_steps"] > 1:
            self._gradient_accumulator(gradients)

        # accumulate loss into metrics
        self.update_train_metrics(dict_metrics_losses)

        if self.config["gradient_accumulation_steps"] == 1:
            return gradients, per_replica_losses
        else:
            return per_replica_losses

    def _one_step_forward_per_replica(self, batch):
        if self.config["gradient_accumulation_steps"] == 1:
            gradients, per_replica_losses = self._calculate_gradient_per_batch(batch)
            self._optimizer.apply_gradients(
                zip(gradients, self._trainable_variables), 1.0
            )
        else:
            # gradient accumulation here.
            per_replica_losses = 0.0
            for i in tf.range(self.config["gradient_accumulation_steps"]):
                reduced_batch = {
                    k: v[
                        i
                        * self.config["batch_size"] : (i + 1)
                        * self.config["batch_size"]
                    ]
                    for k, v in batch.items()
                }

                # run 1 accumulation step
                reduced_batch_losses = self._calculate_gradient_per_batch(reduced_batch)

                # sum per_replica_losses
                per_replica_losses += reduced_batch_losses

            gradients = self._gradient_accumulator.gradients
            self._optimizer.apply_gradients(
                zip(gradients, self._trainable_variables), 1.0
            )
            self._gradient_accumulator.reset()

        return per_replica_losses

    def compute_per_example_losses(self, batch, outputs):
        """Compute per-example losses and return dict_metrics_losses.

        Note that every element of the loss MUST have shape [batch_size] and
        the keys of dict_metrics_losses MUST be in self.list_metrics_name.

        Args:
            batch: dictionary batch input returned from the dataloader.
            outputs: outputs of the model.

        Returns:
            per_example_losses: per-example losses for each GPU, shape [B].
            dict_metrics_losses: dictionary of losses.
        """
        per_example_losses = 0.0
        dict_metrics_losses = {}
        return per_example_losses, dict_metrics_losses
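
    # A hedged sketch of a concrete override; the batch key "mel_gts" and the
    # assumption that `outputs` are predicted mel-spectrograms of shape
    # [B, T, n_mels] are illustrative, not guaranteed by this module:
    #
    #     def compute_per_example_losses(self, batch, outputs):
    #         # mean absolute error per example, reduced over time and mel bins
    #         mels_loss = tf.reduce_mean(
    #             tf.abs(batch["mel_gts"] - outputs), axis=[1, 2]
    #         )  # shape [batch_size]
    #         return mels_loss, {"mels_loss": mels_loss}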

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(
            tqdm(self.eval_data_loader, desc="[eval]"), 1
        ):
            # eval one step
            self.one_step_evaluate(batch)

            if eval_steps_per_epoch <= self.config["num_save_intermediate_results"]:
                # save intermediate results
                self.generate_and_save_intermediate_result(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.eval_metrics.keys():
            logging.info(
                f"(Steps: {self.steps}) eval_{key} = {self.eval_metrics[key].result():.4f}."
            )

        # record
        self._write_to_tensorboard(self.eval_metrics, stage="eval")

        # reset
        self.reset_states_eval()

    def _one_step_evaluate_per_replica(self, batch):
        outputs = self._model(**batch, training=False)
        _, dict_metrics_losses = self.compute_per_example_losses(batch, outputs)

        self.update_eval_metrics(dict_metrics_losses)

    def _one_step_evaluate(self, batch):
        self._strategy.run(self._one_step_evaluate_per_replica, args=(batch,))

    def _one_step_predict_per_replica(self, batch):
        outputs = self._model(**batch, training=False)
        return outputs

    def _one_step_predict(self, batch):
        outputs = self._strategy.run(self._one_step_predict_per_replica, args=(batch,))
        return outputs

    def generate_and_save_intermediate_result(self, batch):
        return

    def create_checkpoint_manager(self, saved_path=None, max_to_keep=10):
        """Create checkpoint management."""
        if saved_path is None:
            saved_path = self.config["outdir"] + "/checkpoints/"
        os.makedirs(saved_path, exist_ok=True)

        self.saved_path = saved_path
        self.ckpt = tf.train.Checkpoint(
            steps=tf.Variable(1), epochs=tf.Variable(1), optimizer=self.get_optimizer()
        )
        self.ckp_manager = tf.train.CheckpointManager(
            self.ckpt, saved_path, max_to_keep=max_to_keep
        )

    def save_checkpoint(self):
        """Save checkpoint."""
        self.ckpt.steps.assign(self.steps)
        self.ckpt.epochs.assign(self.epochs)
        self.ckp_manager.save(checkpoint_number=self.steps)
        utils.save_weights(
            self._model,
            self.saved_path + "model-{}.h5".format(self.steps),
        )

    def load_checkpoint(self, pretrained_path):
        """Load checkpoint."""
        self.ckpt.restore(pretrained_path)
        self.steps = self.ckpt.steps.numpy()
        self.epochs = self.ckpt.epochs.numpy()
        self._optimizer = self.ckpt.optimizer
        # re-assign iterations (global steps) for optimizer.
        self._optimizer.iterations.assign(tf.cast(self.steps, tf.int64))

        # load weights.
        utils.load_weights(
            self._model,
            self.saved_path + "model-{}.h5".format(self.steps),
        )

    def _check_train_finish(self):
        """Check whether training is finished."""
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True

    def _check_log_interval(self):
        """Log to tensorboard."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for metric_name in self.list_metrics_name:
                logging.info(
                    f"(Step: {self.steps}) train_{metric_name} = {self.train_metrics[metric_name].result():.4f}."
                )
            self._write_to_tensorboard(self.train_metrics, stage="train")

            # reset
            self.reset_states_train()

    def fit(self, train_data_loader, valid_data_loader, saved_path, resume=None):
        self.set_train_data_loader(train_data_loader)
        self.set_eval_data_loader(valid_data_loader)
        self.train_data_loader = self._strategy.experimental_distribute_dataset(
            self.train_data_loader
        )
        self.eval_data_loader = self._strategy.experimental_distribute_dataset(
            self.eval_data_loader
        )
        with self._strategy.scope():
            self.create_checkpoint_manager(saved_path=saved_path, max_to_keep=10000)
            if resume is not None and len(resume) > 1:
                self.load_checkpoint(resume)
                logging.info(f"Successfully resumed from {resume}.")
        self.run()
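

# A hedged usage sketch for Seq2SeqBasedTrainer; the subclass, model, optimizer,
# and config keys below are assumptions for illustration, not defined here:
#
#     strategy = tf.distribute.MirroredStrategy()
#     trainer = MyTacotronTrainer(  # hypothetical subclass implementing the hooks
#         steps=0, epochs=0, config=config, strategy=strategy,
#     )
#     trainer.init_train_eval_metrics(["mels_loss"])
#     with strategy.scope():
#         trainer.compile(model, optimizer)
#     trainer.fit(train_dataset, valid_dataset, saved_path=config["outdir"], resume="")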