voice-xtts2 / TTS /bin /train_vocoder_wavegrad.py
antoniomae1234's picture
changes in flenema
2493d72 verified
raw
history blame contribute delete
No virus
18.3 kB
import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# Configure the torch training environment once, at import time. The two
# results are module-level globals read by the functions in this file.
# NOTE(review): the two positional flags presumably toggle cuDNN
# enable/benchmark -- confirm against setup_torch_training_env.
use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
    """Build the ``DataLoader`` for either the training or the eval split.

    Reads the module-level globals ``c``, ``train_data``, ``eval_data`` and
    ``num_gpus``.

    Args:
        ap: audio processor handed to the dataset; its ``hop_length`` is
            used as the dataset hop length.
        is_val: when true, load ``eval_data`` instead of ``train_data``.
        verbose: forwarded to ``WaveGradDataset``.

    Returns:
        The data loader, or ``None`` when an eval loader is requested while
        ``c.run_eval`` is falsy.
    """
    # Evaluation disabled in the config: there is nothing to load.
    if is_val and not c.run_eval:
        return None

    items = eval_data if is_val else train_data
    dataset = WaveGradDataset(
        ap=ap,
        items=items,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad_short=c.pad_short,
        conv_pad=c.conv_pad,
        is_training=not is_val,
        return_segments=True,
        use_noise_augment=False,
        use_cache=c.use_cache,
        verbose=verbose,
    )

    # With several GPUs a distributed sampler is used, so the loader itself
    # must not shuffle.
    multi_gpu = num_gpus > 1
    sampler = DistributedSampler(dataset) if multi_gpu else None
    if is_val:
        worker_count = c.num_val_loader_workers
    else:
        worker_count = c.num_loader_workers

    return DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=not multi_gpu,
        drop_last=False,
        sampler=sampler,
        num_workers=worker_count,
        pin_memory=False,
    )
def format_data(data):
    """Unpack a training/eval batch and prepare it for the model.

    Args:
        data: a pair ``(m, x)`` as yielded by the data loader.
            NOTE(review): presumably conditioning features and waveform
            segments -- confirm against WaveGradDataset.

    Returns:
        ``(m, x)`` where ``x`` has gained a new axis at position 1; both are
        moved to CUDA when the module-level ``use_cuda`` flag is set.
    """
    feats, audio = data
    audio = audio.unsqueeze(1)
    if not use_cuda:
        return feats, audio
    return feats.cuda(non_blocking=True), audio.cuda(non_blocking=True)
def format_test_data(data):
    """Unpack a single test sample and give it batch-style leading axes.

    Args:
        data: a pair ``(m, x)`` for one sample.
            NOTE(review): presumably features and the full waveform --
            confirm against WaveGradDataset.load_test_samples.

    Returns:
        ``(m, x)`` where ``m`` has one new leading axis and ``x`` has two;
        both are moved to CUDA when the module-level ``use_cuda`` flag is set.
    """
    feats, audio = data
    feats = feats[None, ...]
    audio = audio[None, None, ...]
    if not use_cuda:
        return feats, audio
    return feats.cuda(non_blocking=True), audio.cuda(non_blocking=True)
def train(model, criterion, optimizer,
          scheduler, scaler, ap, global_step, epoch):
    """Run one training epoch.

    Reads the module-level globals ``c``, ``args``, ``c_logger``,
    ``tb_logger``, ``OUT_PATH``, ``use_cuda`` and ``num_gpus``.

    Args:
        model: the network, optionally wrapped so that the real model is
            reachable as ``model.module``.
        criterion: loss called as ``criterion(noise, noise_hat)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        scaler: AMP gradient scaler; only used when ``c.mixed_precision``.
        ap: audio processor forwarded to ``setup_loader``.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index (dataset verbosity is on for epoch 0).

    Returns:
        ``(keep_avg.avg_values, global_step)`` -- the running averages of the
        epoch and the advanced step counter.

    Raises:
        RuntimeError: if the loss becomes NaN.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()

    # Model-specific helpers live on the underlying module when the model is
    # wrapped (e.g. by DistributedDataParallel); resolve it once.
    net = model.module if hasattr(model, 'module') else model

    # setup noise schedule
    noise_schedule = c['train_noise_schedule']
    betas = np.linspace(noise_schedule['min_val'],
                        noise_schedule['max_val'],
                        noise_schedule['num_steps'])
    net.compute_noise_level(betas)

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # compute noisy input
            noise, x_noisy, noise_scale = net.compute_y_n(x)

            # forward pass goes through the (possibly wrapped) model
            noise_hat = model(x_noisy, m, noise_scale)

            # compute losses
            loss = criterion(noise, noise_hat)
            loss_wavegrad_dict = {'wavegrad_loss': loss}

        # check nan loss
        if torch.isnan(loss).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

        optimizer.zero_grad()

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss).backward()
            # gradients must be unscaled before they are clipped
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            optimizer.step()

        # schedule update
        if scheduler is not None:
            scheduler.step()

        # Disconnect loss values from the graph. Plain Python numbers are
        # passed through as-is (same rule as in ``evaluate``); anything else
        # is expected to provide ``.item()``.
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        # epoch/step timing
        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        # update avg stats
        update_train_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item()
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm.item(),
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0 and c.checkpoint:
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None)

        # reference point for measuring the next batch's loading time
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Run one pass over the eval split and log a synthesized test sample.

    Reads the module-level globals ``c``, ``args``, ``c_logger`` and
    ``tb_logger``.

    Args:
        model: the network, optionally wrapped so that the real model is
            reachable as ``model.module``.
        criterion: loss called as ``criterion(noise, noise_hat)``.
        ap: audio processor forwarded to ``setup_loader`` and ``plot_results``.
        global_step: training step counter used to tag the logged results.
        epoch: current epoch index (dataset verbosity is on for epoch 0).

    Returns:
        ``keep_avg.avg_values`` -- the running averages over the eval batches.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    keep_avg = KeepAverage()

    # Model-specific helpers live on the underlying module when the model is
    # wrapped (e.g. by DistributedDataParallel); resolve it once.
    net = model.module if hasattr(model, 'module') else model

    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        # NOTE(review): the counter is advanced per eval batch but never
        # returned, so it only shifts the step under which the eval results
        # below are logged -- confirm this offset is intended.
        global_step += 1

        # compute noisy input
        noise, x_noisy, noise_scale = net.compute_y_n(x)

        # forward pass
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}

        # plain Python numbers pass through; tensors are converted via .item()
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        step_time = time.time() - start_time

        # update avg stats
        update_eval_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Reset the reference point so that loader_time measures only the
        # wait for the next batch (as the training loop does) instead of the
        # time elapsed since the start of evaluation.
        end_time = time.time()

    if args.rank == 0:
        # switch the dataset to whole-sample mode for the test sample
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup noise schedule and inference
        noise_schedule = c['test_noise_schedule']
        betas = np.linspace(noise_schedule['min_val'],
                            noise_schedule['max_val'],
                            noise_schedule['num_steps'])
        net.compute_noise_level(betas)
        # compute voice
        x_pred = net.inference(m)

        # compute spectrograms
        figures = plot_results(x_pred, x, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    data_loader.dataset.return_segments = True
    return keep_avg.avg_values
def main(args):  # pylint: disable=redefined-outer-name
    """Load the data, build or restore the model and run the training loop.

    Reads the module-level globals ``c``, ``c_logger``, ``OUT_PATH``,
    ``use_cuda`` and ``num_gpus``; defines the globals ``train_data`` and
    ``eval_data`` that ``setup_loader`` reads.

    Args:
        args: parsed command-line namespace. ``restore_path``, ``rank`` and
            ``group_id`` are read; ``restore_step`` is set here.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path,
                                                   c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # setup criterion
    criterion = torch.nn.L1Loss()

    # Move the model to the GPU *before* the optimizer is built and its state
    # restored: optimizer.load_state_dict() places the restored state on the
    # device the parameters live on at that moment, so restoring on CPU and
    # moving the model afterwards leaves optimizer state and parameters on
    # different devices.
    if use_cuda:
        model.cuda()
        criterion.cuda()

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Only restore the scheduler when one is configured AND a state
            # was actually stored; either side may be None.
            if scheduler is not None and checkpoint.get('scheduler') is not None:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            # The scaler entry is passed as None by the save calls in this
            # file when mixed precision is off, so guard against that too.
            if c.mixed_precision and checkpoint.get("scaler") is not None:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model, criterion, optimizer,
                               scheduler, scaler, ap, global_step,
                               epoch)
        # NOTE(review): evaluate() is called unconditionally, but
        # setup_loader() returns None for the eval split when c.run_eval is
        # falsy -- confirm that configs used with this script enable eval.
        eval_avg_loss_dict = evaluate(model, criterion, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None)
if __name__ == '__main__':
    # Script entry point: parse the command line, resolve config and output
    # paths, set up the loggers (module-level globals read by the functions
    # above) and run training with cleanup on interruption or failure.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including
        ``"False"``, into ``True``; this parser maps the usual spellings
        explicitly and rejects anything else.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(
            f'Expected a boolean value, got {value!r}.')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help=
        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to config file for training.',
                        required='--continue_path' not in sys.argv)
    parser.add_argument('--debug',
                        type=_str2bool,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        # Continue a previous run: reuse its folder and config, and restore
        # from the most recently created checkpoint in it.
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(
            os.path.join(args.continue_path, "*.pth.tar"))
        if not list_of_files:
            # max() on an empty list would fail with an unhelpful message
            raise FileNotFoundError(
                f"No '*.pth.tar' checkpoint found in {args.continue_path}")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)

    if c.mixed_precision:
        print(" > Mixed precision is enabled")

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    # Only the rank-0 process prepares the output folder and owns the
    # tensorboard logger; every use of tb_logger above is guarded by the
    # same rank check.
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_model_files(c, args.config_path,
                         OUT_PATH, new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

        # write model desc to tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)