so-vits-svc / diffusion /solver.py
aoxiang1221's picture
update
85ce65e
raw
history blame
7.27 kB
import time
import librosa
import numpy as np
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from diffusion.logger import utils
from diffusion.logger.saver import Saver
def test(args, model, vocoder, loader_test, saver):
print(' [*] testing...')
model.eval()
# losses
test_loss = 0.
# intialization
num_batches = len(loader_test)
rtf_all = []
# run
with torch.no_grad():
for bidx, data in enumerate(loader_test):
fn = data['name'][0].split("/")[-1]
speaker = data['name'][0].split("/")[-2]
print('--------')
print('{}/{} - {}'.format(bidx, num_batches, fn))
# unpack data
for k in data.keys():
if not k.startswith('name'):
data[k] = data[k].to(args.device)
print('>>', data['name'][0])
# forward
st_time = time.time()
mel = model(
data['units'],
data['f0'],
data['volume'],
data['spk_id'],
gt_spec=None if model.k_step_max == model.timesteps else data['mel'],
infer=True,
infer_speedup=args.infer.speedup,
method=args.infer.method,
k_step=model.k_step_max
)
signal = vocoder.infer(mel, data['f0'])
ed_time = time.time()
# RTF
run_time = ed_time - st_time
song_time = signal.shape[-1] / args.data.sampling_rate
rtf = run_time / song_time
print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
rtf_all.append(rtf)
# loss
for i in range(args.train.batch_size):
loss = model(
data['units'],
data['f0'],
data['volume'],
data['spk_id'],
gt_spec=data['mel'],
infer=False,
k_step=model.k_step_max)
test_loss += loss.item()
# log mel
saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)
# log audi
path_audio = data['name_ext'][0]
audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
if len(audio.shape) > 1:
audio = librosa.to_mono(audio)
audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal})
# report
test_loss /= args.train.batch_size
test_loss /= num_batches
# check
print(' [test_loss] test_loss:', test_loss)
print(' Real Time Factor', np.mean(rtf_all))
return test_loss
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
# saver
saver = Saver(args, initial_global_step=initial_global_step)
# model size
params_count = utils.get_network_paras_amount({'model': model})
saver.log_info('--- model size ---')
saver.log_info(params_count)
# run
num_batches = len(loader_train)
model.train()
saver.log_info('======= start training =======')
scaler = GradScaler()
if args.train.amp_dtype == 'fp32':
dtype = torch.float32
elif args.train.amp_dtype == 'fp16':
dtype = torch.float16
elif args.train.amp_dtype == 'bf16':
dtype = torch.bfloat16
else:
raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
for epoch in range(args.train.epochs):
for batch_idx, data in enumerate(loader_train):
saver.global_step_increment()
optimizer.zero_grad()
# unpack data
for k in data.keys():
if not k.startswith('name'):
data[k] = data[k].to(args.device)
# forward
if dtype == torch.float32:
loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False, k_step=model.k_step_max)
else:
with autocast(device_type=args.device, dtype=dtype):
loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False, k_step=model.k_step_max)
# handle nan loss
if torch.isnan(loss):
raise ValueError(' [x] nan loss ')
else:
# backpropagate
if dtype == torch.float32:
loss.backward()
optimizer.step()
else:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()
# log loss
if saver.global_step % args.train.interval_log == 0:
current_lr = optimizer.param_groups[0]['lr']
saver.log_info(
'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
epoch,
batch_idx,
num_batches,
args.env.expdir,
args.train.interval_log/saver.get_interval_time(),
current_lr,
loss.item(),
saver.get_total_time(),
saver.global_step
)
)
saver.log_value({
'train/loss': loss.item()
})
saver.log_value({
'train/lr': current_lr
})
# validation
if saver.global_step % args.train.interval_val == 0:
optimizer_save = optimizer if args.train.save_opt else None
# save latest
saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
last_val_step = saver.global_step - args.train.interval_val
if last_val_step % args.train.interval_force_save != 0:
saver.delete_model(postfix=f'{last_val_step}')
# run testing set
test_loss = test(args, model, vocoder, loader_test, saver)
# log loss
saver.log_info(
' --- <validation> --- \nloss: {:.3f}. '.format(
test_loss,
)
)
saver.log_value({
'validation/loss': test_loss
})
model.train()