# Diff-Pitcher / pitch_controller / train_world_tuner_24k.py
# (Repository file-viewer header: last commit 90f7c1e by jerryhai,
#  "Track binary files with Git LFS"; file size 10.2 kB.)
import os, json, argparse, yaml
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from diffusers import DDIMScheduler
from dataset import VCDecLPCDataset, VCDecLPCBatchCollate, VCDecLPCTest
from models.unet import UNetVC
from modules.BigVGAN.inference import load_model
from utils import save_plot, save_audio
from utils import minmax_norm_diff, reverse_minmax_norm_diff
def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` converts every non-empty string
    (including ``"False"`` and ``"0"``) to True, so such flags can never be
    switched off from the command line. This accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('-config', type=str, default='config/DiffWorld_24k_log.yaml')
parser.add_argument('-seed', type=int, default=98)
# Booleans go through _str2bool so that e.g. "-amp False" really disables AMP.
parser.add_argument('-amp', type=_str2bool, default=True)
parser.add_argument('-compile', type=_str2bool, default=False)
parser.add_argument('-data_dir', type=str, default='../24k_center/')
parser.add_argument('-lpc_dir', type=str, default='world')
parser.add_argument('-vocoder_dir', type=str, default='modules/BigVGAN/ckpt/bigvgan_base_24khz_100band/g_05000000')
parser.add_argument('-train_frames', type=int, default=128)
parser.add_argument('-batch_size', type=int, default=32)
parser.add_argument('-test_size', type=int, default=1)
parser.add_argument('-num_workers', type=int, default=4)
parser.add_argument('-lr', type=float, default=5e-5)
# Weight decay is fractional; type=int rejected values such as "1e-4".
parser.add_argument('-weight_decay', type=float, default=1e-6)
parser.add_argument('-epochs', type=int, default=80)
parser.add_argument('-save_every', type=int, default=2)
parser.add_argument('-log_step', type=int, default=200)
parser.add_argument('-log_dir', type=str, default='logs_dec_world_24k')
parser.add_argument('-ckpt_dir', type=str, default='ckpt_world_24k')
args = parser.parse_args()
# When True, the next validation pass also renders the source and content
# spectrograms; the training loop clears it after that first pass.
args.save_ori = True

# Close the config file deterministically instead of leaking the handle.
with open(args.config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
mel_cfg = config['logmel']      # spectrogram normalisation bounds and sampling rate
ddpm_cfg = config['ddpm']       # diffusion schedule / sampler settings
unet_cfg = config['unet']       # keyword arguments for UNetVC
f0_type = unet_cfg['pitch_type']
def _norm(spec):
    """Scale a spectrogram into the model's -1..1 working range using the config min/max."""
    return minmax_norm_diff(spec, vmax=mel_cfg['max'], vmin=mel_cfg['min'])


def _denorm(spec):
    """Invert ``_norm``: map a model-range spectrogram back to the config's value range."""
    return reverse_minmax_norm_diff(spec, vmax=mel_cfg['max'], vmin=mel_cfg['min'])


def _prepare_batch(batch, device):
    """Move one batch to ``device`` and normalise its spectrogram inputs.

    Returns:
        (mel, mel_ref, f0, mean): target spectrogram, optional reference
        spectrogram (None unless ``unet_cfg['use_ref_t']``), pitch input and
        content spectrogram. All spectrograms are normalised with ``_norm``.
    """
    mel = _norm(batch['mel1'].to(device))
    if unet_cfg['use_ref_t']:
        mel_ref = _norm(batch['mel2'].to(device))
    else:
        mel_ref = None
    f0 = batch['f0_1'].to(device)
    mean = _norm(batch['content1'].to(device))
    return mel, mel_ref, f0, mean


def _save_sample(spec, tag, index, epoch, vocoder):
    """De-normalise ``spec``, then save its plot and vocoded audio for validation item ``index``."""
    spec = _denorm(spec)
    save_plot(spec.squeeze().cpu(), f'{args.log_dir}/pic/{index}/{epoch}_{tag}.png')
    audio = vocoder(spec)
    save_audio(f'{args.log_dir}/audio/{index}/{epoch}_{tag}.wav', mel_cfg['sampling_rate'], audio)


def _setup_device():
    """Seed the RNGs and choose the compute device, storing it on ``args.device``."""
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            # NOTE(review): benchmark=True lets cuDNN auto-tune kernels per run,
            # which works against deterministic=True for exact reproducibility.
            # Kept as in the original; set benchmark=False if bitwise repeats matter.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = True
    else:
        args.device = 'cpu'


def _train_one_epoch(model, train_loader, optimizer, scaler, noise_scheduler,
                     use_amp, epoch, global_step):
    """Run one epoch of noise-prediction training (MSE between predicted and true noise).

    Appends the mean loss to ``train_dec.log`` every ``args.log_step``
    global steps. Returns the updated ``global_step``.
    """
    model.train()
    losses = []
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    for step, batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        mel, mel_ref, f0, mean = _prepare_batch(batch, args.device)

        # Sample noise directly on the device and one timestep per item of the real batch.
        noise = torch.randn_like(mel)
        timesteps = torch.randint(0, num_train_timesteps, (mel.shape[0],), device=args.device)
        noisy_mel = noise_scheduler.add_noise(mel, noise, timesteps)

        # With use_amp False, autocast and GradScaler are no-ops, so this single
        # path is equivalent to a plain forward/backward/step.
        with autocast(enabled=use_amp):
            noise_pred = model(x=noisy_mel, mean=mean, f0=f0, t=timesteps, ref=mel_ref, embed=None)
            loss = F.mse_loss(noise_pred, noise)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.append(loss.item())
        global_step += 1
        if global_step % args.log_step == 0:
            msg = '\nEpoch: [{}][{}]\t' \
                  'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(epoch,
                                                           args.epochs,
                                                           step + 1,
                                                           len(train_loader),
                                                           np.mean(losses))
            with open(f'{args.log_dir}/train_dec.log', 'a') as f:
                f.write(msg)
            losses = []
    return global_step


@torch.no_grad()
def _run_validation(model, val_loader, noise_scheduler, vocoder, epoch):
    """Sample every validation item with the DDIM sampler and save plots/audio.

    Each item restarts its generator from ``args.seed``, so samples are
    comparable across epochs. Source and content spectrograms are also saved
    on the first call only (controlled by ``args.save_ori``).
    """
    noise_scheduler.set_timesteps(ddpm_cfg['inference_steps'])
    model.eval()
    for i, batch in enumerate(val_loader):
        generator = torch.Generator(device=args.device).manual_seed(args.seed)
        mel, mel_ref, f0, mean = _prepare_batch(batch, args.device)

        # Draw noise per item: caching the first item's noise breaks if
        # validation items have different lengths, and reseeding keeps this deterministic.
        pred = torch.randn(mel.shape, generator=generator, device=args.device)
        for t in noise_scheduler.timesteps:
            pred = noise_scheduler.scale_model_input(pred, t)
            model_output = model(x=pred, mean=mean, f0=f0, t=t, ref=mel_ref, embed=None)
            pred = noise_scheduler.step(model_output=model_output,
                                        timestep=t,
                                        sample=pred,
                                        eta=ddpm_cfg['eta'], generator=generator).prev_sample

        # Create each output directory independently so a partial tree cannot crash the run.
        os.makedirs(f'{args.log_dir}/audio/{i}/', exist_ok=True)
        os.makedirs(f'{args.log_dir}/pic/{i}/', exist_ok=True)

        _save_sample(pred, 'pred', i, epoch, vocoder)
        if args.save_ori:
            _save_sample(mel, 'source', i, epoch, vocoder)
            _save_sample(mean, 'avg', i, epoch, vocoder)
    args.save_ori = False


def main():
    """Train the UNetVC diffusion model, checkpointing and sampling every ``args.save_every`` epochs."""
    _setup_device()
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.ckpt_dir, exist_ok=True)

    print('Initializing vocoder...')
    hifigan, _ = load_model(args.vocoder_dir, device=args.device)

    print('Initializing data loaders...')
    train_set = VCDecLPCDataset(args.data_dir, subset='train', content_dir=args.lpc_dir, f0_type=f0_type)
    collate_fn = VCDecLPCBatchCollate(args.train_frames)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=args.num_workers, drop_last=True)
    val_set = VCDecLPCTest(args.data_dir, content_dir=args.lpc_dir, f0_type=f0_type)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    print('Initializing and loading models...')
    model = UNetVC(**unet_cfg).to(args.device)
    print('Number of parameters = %.2fm\n' % (model.nparams / 1e6))
    # prepare DPM scheduler
    noise_scheduler = DDIMScheduler(num_train_timesteps=ddpm_cfg['num_train_steps'])

    print('Initializing optimizers...')
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # torch.cuda.amp only applies on CUDA; on CPU it would merely warn and disable itself.
    use_amp = bool(args.amp) and args.device == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    # Keep a handle on the un-compiled module: torch.compile wraps it and
    # prefixes every state_dict key with "_orig_mod.", which would make the
    # saved checkpoints unloadable into a plain UNetVC.
    raw_model = model
    if args.compile:
        model = torch.compile(model)

    print('Start training.')
    global_step = 0
    for epoch in range(1, args.epochs + 1):
        print(f'Epoch: {epoch} [iteration: {global_step}]')
        global_step = _train_one_epoch(model, train_loader, optimizer, scaler,
                                       noise_scheduler, use_amp, epoch, global_step)
        if epoch % args.save_every > 0:
            continue
        print('Saving model...\n')
        torch.save(raw_model.state_dict(), f=f"{args.ckpt_dir}/lpc_vc_{epoch}.pt")
        print('Inference...\n')
        _run_validation(model, val_loader, noise_scheduler, hifigan, epoch)


if __name__ == "__main__":
    main()