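"""Training script for WaveGlow.

Configuration is read from a JSON file with train_config, data_config,
dist_config, and waveglow_config sections (parsed in __main__ below).
"""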
import argparse
import json
import os

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
from glow import WaveGlow, WaveGlowLoss
from mel2samp import Mel2Samp
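

# Checkpoints pickle the entire WaveGlow module under 'model' rather than just
# a state_dict, so the glow module must be importable when a checkpoint is
# reloaded; load_checkpoint copies the saved weights into the live model.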
def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, iteration
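

# waveglow_config is a module-level global populated from the JSON config in
# __main__. The weights are copied into a freshly constructed WaveGlow before
# saving, which keeps distributed/fp16 wrappers out of the pickled model.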
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    model_for_saving = WaveGlow(**waveglow_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save({'model': model_for_saving,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)
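

# train() drives the full optimization loop for one process; for multi-GPU
# runs each process supplies its own rank and a shared group_name. Note that
# dist_config, data_config, and waveglow_config are module-level globals set
# in __main__.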
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
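
    # apex is imported lazily so it is only required for fp16 runs; every
    # later use of 'amp' below is guarded by the same fp16_run flag, so the
    # name is always bound when it is needed.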
    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
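
    # Resume from a checkpoint if one was given; bump the counter so the
    # checkpointed iteration is not repeated.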
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1
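
    # shuffle=False is required when a sampler is passed: in the multi-GPU
    # case DistributedSampler handles shuffling, while single-GPU runs see the
    # dataset in fixed order.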
    trainset = Mel2Samp(**data_config)
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
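
    # ================ MAIN TRAINING LOOP ================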
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            mel, audio = batch
            mel = mel.cuda()
            audio = audio.cuda()
            outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if with_tensorboard and rank == 0:
                logger.add_scalar('training_loss', reduced_loss,
                                  i + len(train_loader) * epoch)

            if iteration % iters_per_checkpoint == 0:
                if rank == 0:
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
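

# Typical single-GPU invocation, assuming the JSON config provides the four
# sections read below:
#   python train.py -c config.json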
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-r', '--rank', type=int, default=0,
                        help='rank of process for distributed')
    parser.add_argument('-g', '--group_name', type=str, default='',
                        help='name of group for distributed')
    args = parser.parse_args()
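
    # The config is parsed at module scope on purpose: data_config,
    # dist_config, and waveglow_config are consumed as globals by train() and
    # save_checkpoint().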
    with open(args.config) as f:
        data = f.read()
    config = json.loads(data)
    train_config = config["train_config"]
    data_config = config["data_config"]
    dist_config = config["dist_config"]
    waveglow_config = config["waveglow_config"]

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        if args.group_name == '':
            print("WARNING: Multiple GPUs detected but no distributed group set")
            print("Only running 1 GPU. Use distributed.py for multiple GPUs")
            num_gpus = 1

    if num_gpus == 1 and args.rank != 0:
        raise Exception("Doing single GPU training on rank > 0")

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    train(num_gpus, args.rank, args.group_name, **train_config)