import copy
import functools
import json
import os
import random
from datetime import datetime
from pathlib import Path

import torch
import torchaudio
from omegaconf import OmegaConf
from torch import nn
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from vocos import Vocos
from accelerate import Accelerator

from ttts.diffusion.aa_model import AA_diffusion, denormalize_tacotron_mel, normalize_tacotron_mel
from ttts.diffusion.dataset import DiffusionDataset, DiffusionCollater
from ttts.diffusion.model import DiffusionTts
from ttts.gpt.voice_tokenizer import VoiceBpeTokenizer
from ttts.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from ttts.utils.infer_utils import load_model
from ttts.utils.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler, LossSecondMomentResampler
from ttts.utils.utils import EMA, clean_checkpoints, plot_spectrogram_to_numpy, summarize, update_moving_average
# from ttts.diffusion.diffusion_util import Diffuser
# from accelerate import DistributedDataParallelKwargs
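
# Training entry point for the AA_diffusion acoustic model: GPT latents
# ("hint") plus a normalized reference mel ("refer") condition a diffusion
# model that predicts 100-bin mel spectrograms, which Vocos decodes to
# 24 kHz audio for the periodic eval samples below.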
def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True):
    """
    Uses the specified diffusion model to convert discrete codes into a spectrogram.
    """
    with torch.no_grad():
        output_seq_len = latents.shape[2] * 4  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
        output_shape = (latents.shape[0], 100, output_seq_len)
        noise = torch.randn(output_shape, device=latents.device) * temperature
        mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
                                     model_kwargs={
                                         "hint": latents,
                                         "refer": conditioning_latents
                                     },
                                     progress=verbose)
        return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
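
# Usage sketch (mirrors the eval path in Trainer.train below; `model` and
# `diffuser` stand for an AA_diffusion instance and the 50-step inference
# SpacedDiffusion, `gpt_latents`/`refer_mel` for the GPT's return_latent
# output and a normalize_tacotron_mel'd reference mel):
#   mel = do_spectrogram_diffusion(model, diffuser, gpt_latents, refer_mel, temperature=0.8)
#   audio = vocos.decode(mel.cpu())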
def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val
def get_grad_norm(model):
    total_norm = 0
    for name, p in model.named_parameters():
        if p.grad is None:  # skip parameters that received no gradient this step
            continue
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm
def cycle(dl):
    while True:
        for data in dl:
            yield data
def warmup(step):
    if step < 1000:
        return float(step / 1000)
    else:
        return 1
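
# LambdaLR multiplier: linear ramp over the first 1000 steps, e.g.
# warmup(250) == 0.25 and warmup(1500) == 1, so the configured lr is
# reached at step 1000 and held afterwards.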
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Trainer(object):
    def __init__(self, cfg_path='ttts/diffusion/config.yaml'):
        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        # self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
        self.accelerator = Accelerator()
        self.cfg = OmegaConf.load(cfg_path)
        # self.cfg = json.load(open(cfg_path))
        trained_diffusion_steps = self.trained_diffusion_steps = 1000
        desired_diffusion_steps = self.desired_diffusion_steps = 1000
        cond_free_k = 2.
        self.gpt = load_model('gpt', self.cfg['dataset']['gpt_path'], 'ttts/gpt/config.json', 'cuda')
        self.mel_length_compression = self.gpt.mel_length_compression
        self.diffuser = SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
                                        model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
                                        conditioning_free=False, conditioning_free_k=cond_free_k)
        self.infer_diffuser = SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [50]), model_mean_type='epsilon',
                                              model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
                                              conditioning_free=True, conditioning_free_k=cond_free_k, sampler='dpm++2m')
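        # The two SpacedDiffusion instances above share one linear beta schedule:
        # `diffuser` trains on the full 1000-step chain without conditioning-free
        # guidance, while `infer_diffuser` respaces it to 50 DPM++2M steps with
        # conditioning-free guidance (k=2) for eval sampling.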
        # self.diffusion = DiffusionTts(**self.cfg['diffusion'])
        self.diffusion = AA_diffusion(self.cfg)
        print("model params:", count_parameters(self.diffusion))
        self.dataset = DiffusionDataset(self.cfg)
        self.dataloader = DataLoader(self.dataset, **self.cfg['dataloader'], collate_fn=DiffusionCollater())
        self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        self.train_steps = self.cfg['train']['train_steps']
        self.val_freq = self.cfg['train']['val_freq']
        if self.accelerator.is_main_process:
            self.eval_dataloader = DataLoader(self.dataset, batch_size=1, shuffle=False, num_workers=16, pin_memory=True, collate_fn=DiffusionCollater())
            self.eval_dataloader = cycle(self.eval_dataloader)
            self.ema_model = self._get_target_encoder(self.diffusion).to(self.accelerator.device)
            now = datetime.now()
            self.logs_folder = Path(self.cfg['train']['logs_folder'] + '/' + now.strftime("%Y-%m-%d-%H-%M-%S"))
            self.logs_folder.mkdir(exist_ok=True, parents=True)
        self.ema_updater = EMA(0.999)
        self.optimizer = AdamW(self.diffusion.parameters(), lr=self.cfg['train']['lr'], betas=(0.9, 0.999), weight_decay=0.01)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup)
        self.diffusion, self.dataloader, self.optimizer, self.scheduler, self.gpt = self.accelerator.prepare(
            self.diffusion, self.dataloader, self.optimizer, self.scheduler, self.gpt)
        self.dataloader = cycle(self.dataloader)
        self.step = 0
        self.gradient_accumulate_every = self.cfg['train']['accumulate_num']
        self.unconditioned_percentage = self.cfg['train']['unconditioned_percentage']
    def _get_target_encoder(self, model):
        # Frozen deep copy used as the EMA target; DO_NOT_TRAIN marks the
        # parameters so they are excluded from optimization.
        target_encoder = copy.deepcopy(model)
        set_requires_grad(target_encoder, False)
        for p in target_encoder.parameters():
            p.DO_NOT_TRAIN = True
        return target_encoder
    def save(self, milestone):
        if not self.accelerator.is_local_main_process:
            return
        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.diffusion),
        }
        torch.save(data, str(self.logs_folder / f'model-{milestone}.pt'))
    def load(self, model_path):
        accelerator = self.accelerator
        device = accelerator.device
        data = torch.load(model_path, map_location=device)
        state_dict = data['model']
        self.step = data['step']
        model = self.accelerator.unwrap_model(self.diffusion)
        model.load_state_dict(state_dict)
    def train(self):
        accelerator = self.accelerator
        device = accelerator.device
        if accelerator.is_main_process:
            writer = SummaryWriter(log_dir=self.logs_folder)
            writer_eval = SummaryWriter(log_dir=os.path.join(self.logs_folder, 'eval'))
        with tqdm(initial=self.step, total=self.train_steps, disable=not accelerator.is_main_process) as pbar:
            while self.step < self.train_steps:
                total_loss = 0.
                # with torch.autograd.detect_anomaly():
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dataloader)
                    if data is None:
                        continue
                    with torch.no_grad():
                        latent = self.gpt(data['padded_mel_refer'], data['padded_text'],
                                          torch.tensor([data['padded_text'].shape[-1]], device=device), data['padded_mel_code'],
                                          torch.tensor([data['padded_mel_code'].shape[-1] * self.mel_length_compression], device=device),
                                          return_latent=True, clip_inputs=False).transpose(1, 2)
                    # mel_recon_padded, mel_padded, mel_lengths, refer_padded, refer_lengths
                    x_start = normalize_tacotron_mel(data['padded_mel'].to(device))
                    aligned_conditioning = latent
                    conditioning_latent = normalize_tacotron_mel(data['padded_mel_refer'].to(device))
                    t = torch.randint(0, self.desired_diffusion_steps, (x_start.shape[0],), device=device).long()
                    with self.accelerator.autocast():
                        loss = self.diffuser.training_losses(
                            model=self.diffusion,
                            x_start=x_start,
                            t=t,
                            model_kwargs={
                                "hint": aligned_conditioning,
                                "refer": conditioning_latent
                            },
                        )["loss"].mean()
                        unused_params = []
                        model = self.accelerator.unwrap_model(self.diffusion)
                        unused_params.extend(list(model.refer_model.blocks.parameters()))
                        unused_params.extend(list(model.refer_model.out.parameters()))
                        unused_params.extend(list(model.refer_model.hint_converter.parameters()))
                        unused_params.extend(list(model.refer_enc.visual.proj))
                        extraneous_addition = 0
                        for p in unused_params:
                            extraneous_addition = extraneous_addition + p.mean()
                        loss = loss + 0 * extraneous_addition
                    loss = loss / self.gradient_accumulate_every
                    total_loss += loss.item()
                    self.accelerator.backward(loss)
                grad_norm = get_grad_norm(self.diffusion)
                accelerator.clip_grad_norm_(self.diffusion.parameters(), 1.0)
                pbar.set_description(f'loss: {total_loss:.4f}')
                accelerator.wait_for_everyone()
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()
                accelerator.wait_for_everyone()
                # if accelerator.is_main_process:
                #     update_moving_average(self.ema_updater, self.ema_model, self.diffusion)
                if accelerator.is_main_process and self.step % self.val_freq == 0:
                    scalar_dict = {"loss": total_loss, "loss/grad": grad_norm, "lr": self.scheduler.get_last_lr()[0]}
                    summarize(
                        writer=writer,
                        global_step=self.step,
                        scalars=scalar_dict
                    )
                if accelerator.is_main_process and self.step % self.cfg['train']['save_freq'] == 0:
                    self.ema_model.eval()
                    data = next(self.eval_dataloader)
                    text_padded = data['padded_text'].to(device)
                    mel_code_padded = data['padded_mel_code'].to(device)
                    mel_padded, mel_lengths = data['padded_mel'], data['mel_lengths']
                    refer_padded = data['padded_mel_refer'].to(device)
                    refer_lengths = data['mel_refer_lengths']
                    with torch.no_grad():
                        latent = self.gpt(refer_padded, text_padded,
                                          torch.tensor([text_padded.shape[-1]], device=device), mel_code_padded,
                                          torch.tensor([mel_code_padded.shape[-1] * self.mel_length_compression], device=device),
                                          return_latent=True, clip_inputs=False).transpose(1, 2)
                    refer_padded = normalize_tacotron_mel(refer_padded)
                    with torch.no_grad():
                        diffusion = self.accelerator.unwrap_model(self.diffusion)
                        mel = do_spectrogram_diffusion(diffusion, self.infer_diffuser, latent, refer_padded, temperature=0.8)
                        mel = mel.detach().cpu()
                    milestone = self.step // self.cfg['train']['save_freq']
                    gen = self.vocos.decode(mel)
                    torchaudio.save(str(self.logs_folder / f'sample-{milestone}.wav'), gen, 24000)
                    audio_dict = {"gen/audio": gen}
                    image_dict = {
                        "gt/mel": plot_spectrogram_to_numpy(mel_padded[0, :, :].detach().unsqueeze(-1).cpu()),
                        "gen/mel": plot_spectrogram_to_numpy(mel[0, :, :].detach().unsqueeze(-1).cpu()),
                    }
                    summarize(
                        writer=writer_eval,
                        audios=audio_dict,
                        global_step=self.step,
                        images=image_dict,
                    )
                    keep_ckpts = self.cfg['train']['keep_ckpts']
                    if keep_ckpts > 0:
                        clean_checkpoints(path_to_models=self.logs_folder, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
                    self.save(self.step // 1000)
                    self.ema_model.train()
                self.step += 1
                pbar.update(1)
        accelerator.print('training complete')
if __name__ == '__main__':
    trainer = Trainer()
    trainer.load('/home/hyc/tortoise_plus_zh/ttts/diffusion/logs/2024-01-09-17-44-36/model-855.pt')
    trainer.train()