import os

# HF_HUB_CACHE must be set before importing any library that may touch the
# Hugging Face hub, so this assignment deliberately precedes the other imports.
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'

import argparse
import glob
import shutil
import time

import torch
import torchaudio
import yaml
import hydra
from omegaconf import DictConfig
from tqdm import tqdm
from accelerate import Accelerator, DistributedDataParallelKwargs

from data.ft_dataset import build_ft_dataloader
from optimizers import build_single_optimizer
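
# Fine-tuning trainer: the model defined by the Hydra/YAML config (a VC wrapper
# with CFM and AR sub-modules) is instantiated and frozen, and only the
# sub-modules selected via --train-cfm / --train-ar (plus their length
# regulators) are updated.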
class Trainer:
    def __init__(
        self,
        config_path,
        pretrained_cfm_ckpt_path,
        pretrained_ar_ckpt_path,
        data_dir,
        run_name,
        batch_size=0,
        num_workers=0,
        steps=1000,
        save_interval=500,
        max_epochs=1000,
        train_cfm=True,
        train_ar=False,
        mixed_precision=None,
    ):
        self.config_path = config_path
        self.mixed_precision = mixed_precision

        # Load configuration
        self.config = yaml.safe_load(open(config_path))

        # Set up the logging directory and keep a copy of the config next to the run
        self.log_dir = os.path.join("runs", run_name)
        os.makedirs(self.log_dir, exist_ok=True)
        shutil.copy(config_path, os.path.join(self.log_dir, os.path.basename(config_path)))

        # Set up the accelerator
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
        self.accelerator = Accelerator(
            project_dir=self.log_dir,
            split_batches=True,
            kwargs_handlers=[ddp_kwargs],
            mixed_precision=mixed_precision,
        )
        self.device = self.accelerator.device

        # Initialize the dataloader
        self._init_dataloader(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            spect_params=self.config['mel_fn'],
            sr=self.config['sr'],
        )

        # Initialize models and optimizers
        self._init_models(train_cfm=train_cfm, train_ar=train_ar)

        # Load checkpoints if available
        self._load_checkpoint(pretrained_cfm_ckpt_path, pretrained_ar_ckpt_path)

        # Initialize training-loop counters and limits
        self.iters = 0
        self.start_epoch = 0
        self.log_interval = 10
        self.max_steps = steps
        self.save_interval = save_interval
        self.max_epochs = max_epochs
    def _init_dataloader(self, data_dir, batch_size, num_workers, spect_params, sr):
        self.spect_params = spect_params
        self.sr = sr
        self.train_dataloader = build_ft_dataloader(
            data_dir,
            spect_params,
            self.sr,
            batch_size=batch_size,
            num_workers=num_workers,
        )
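
    # Batches produced by build_ft_dataloader unpack as
    # (waves, mels, wave_lens, mel_lens); see _process_batch below.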
    def _init_models(self, train_cfm=True, train_ar=False):
        """Initialize models and optimizers"""
        assert train_cfm or train_ar, "At least one model should be trained"
        self.train_cfm = train_cfm
        self.train_ar = train_ar
        # Initialize main model
        self._init_main_model(train_cfm=train_cfm, train_ar=train_ar)
        # Initialize optimizers
        self._init_optimizers()

    def _init_main_model(self, train_cfm=True, train_ar=False):
        """Initialize the main model"""
        with self.accelerator.main_process_first():
            cfg = DictConfig(self.config)
            self.model = hydra.utils.instantiate(cfg).to(self.device)
        # Freeze everything, then unfreeze only the sub-modules being fine-tuned
        for p in self.model.parameters():
            p.requires_grad = False
        if train_cfm:
            for p in self.model.cfm.parameters():
                p.requires_grad = True
            for p in self.model.cfm_length_regulator.parameters():
                p.requires_grad = True
        if train_ar:
            for p in self.model.ar.parameters():
                p.requires_grad = True
            for p in self.model.ar_length_regulator.parameters():
                p.requires_grad = True
    def _init_optimizers(self):
        """Initialize the optimizer and scheduler"""
        self.optimizer, self.scheduler = build_single_optimizer(
            self.model,
            lr=2e-5,
        )
        self.optimizer = self.accelerator.prepare(self.optimizer)
        self.scheduler = self.accelerator.prepare(self.scheduler)
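
    # The fine-tuning learning rate (2e-5) is hardcoded above rather than
    # exposed as a CLI flag; change the lr argument to build_single_optimizer
    # to adjust it.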
    def _find_checkpoint(self, name_pattern, max_keep=1):
        """Return the latest checkpoint matching name_pattern, pruning the oldest one"""
        available_checkpoints = glob.glob(os.path.join(self.log_dir, name_pattern))
        if len(available_checkpoints) >= max_keep:
            # The step number is the last underscore-separated field of the filename
            latest_checkpoint = max(
                available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )
            earliest_checkpoint = min(
                available_checkpoints, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )
            # Delete the earliest checkpoint once more than max_keep exist
            if (
                earliest_checkpoint != latest_checkpoint
                and self.accelerator.is_main_process
                and len(available_checkpoints) > max_keep
            ):
                os.remove(earliest_checkpoint)
                print(f"Removed {earliest_checkpoint}")
            return latest_checkpoint
        else:
            return None
    def _load_checkpoint(self, pretrained_cfm_ckpt_path, pretrained_ar_ckpt_path):
        """Load checkpoints if available"""
        cfm_checkpoint_path = pretrained_cfm_ckpt_path or self._find_checkpoint("CFM_epoch_*_step_*.pth", max_keep=1)
        ar_checkpoint_path = pretrained_ar_ckpt_path or self._find_checkpoint("AR_epoch_*_step_*.pth", max_keep=1)
        with self.accelerator.main_process_first():
            if cfm_checkpoint_path:
                print(f"Loading CFM checkpoint from {cfm_checkpoint_path}")
            if ar_checkpoint_path:
                print(f"Loading AR checkpoint from {ar_checkpoint_path}")
            self.model.load_checkpoints(cfm_checkpoint_path=cfm_checkpoint_path, ar_checkpoint_path=ar_checkpoint_path)
        self.model = self.accelerator.prepare(self.model)
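
    # Note: load_checkpoints restores weights only; self.iters and
    # self.start_epoch are not recovered from the saved 'iters'/'epoch' fields,
    # so a resumed run restarts its step and epoch counters from zero.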
    def filter_state_dict_shapes(self, params, model):
        """Drop entries from params whose shapes do not match model's state dict"""
        model_state_dict = model.state_dict()
        filtered_state_dict = {
            k: v
            for k, v in params.items()
            if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
        if skipped_keys:
            print(
                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
            )
        return filtered_state_dict, skipped_keys
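
    # filter_state_dict_shapes is not called within this script; it is a helper
    # for tolerantly loading checkpoints whose tensor shapes only partially match.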
    def train(self):
        """Main training loop"""
        for epoch in range(self.start_epoch, self.max_epochs):
            epoch_start_time = time.time()
            # Reshuffle per epoch when a distributed sampler is in use
            try:
                self.train_dataloader.sampler.set_epoch(epoch)
            except AttributeError:
                pass
            self.model.train()
            for i, batch in enumerate(tqdm(self.train_dataloader)):
                self._process_batch(epoch, i, batch)
                # Stop on all ranks, not just the main process, to avoid a
                # distributed deadlock; only the main process saves and logs
                if self.iters >= self.max_steps:
                    if self.accelerator.is_main_process:
                        print("Reached max steps, stopping training")
                        self._save_checkpoint(epoch)
                    return
            # Log epoch completion
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds")
        # The loop bound enforces max_epochs; save a final checkpoint on exit
        if self.accelerator.is_main_process:
            print("Reached max epochs, stopping training")
            self._save_checkpoint(epoch)
    def _process_batch(self, epoch, i, batch):
        """Run one optimization step on a single batch"""
        # Unpack the batch (waveforms, mel spectrograms, and their lengths)
        waves, mels, wave_lens, mel_lens = batch
        # Resample to 16kHz for ASR models
        waves_16k = torchaudio.functional.resample(waves, self.sr, 16000)
        wave_lengths_16k = (wave_lens.float() * 16000 / self.sr).long()
        # Forward pass and loss calculation
        with self.accelerator.autocast():
            loss_ar, loss_cfm = self.model(
                waves_16k.to(self.device),
                mels.to(self.device),
                wave_lengths_16k.to(self.device),
                mel_lens.to(self.device),
                forward_ar=self.train_ar,
                forward_cfm=self.train_cfm,
            )
            loss = loss_ar + loss_cfm
        self.accelerator.backward(loss)
        # Very loose clipping threshold; this mainly records the gradient norm
        grad_norm_g = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), 1000.0
        )
        self.optimizer.step()
        self.scheduler.step(self.iters)
        self.optimizer.zero_grad()
        # Log training progress
        self._log_training_progress(epoch, i, loss, loss_ar, loss_cfm, grad_norm_g)
        # Save checkpoint
        if self.iters != 0 and self.iters % self.save_interval == 0 and self.accelerator.is_main_process:
            self._save_checkpoint(epoch)
        # Increment iteration counter
        self.iters += 1
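
    # self.iters counts optimizer steps globally across epochs; log_interval,
    # save_interval, and max_steps are all measured against this counter.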
    def _log_training_progress(self, epoch, i, loss, loss_ar, loss_cfm, grad_norm_g):
        """Log training progress to the console"""
        if self.iters % self.log_interval == 0 and self.accelerator.is_main_process:
            with torch.no_grad():
                cur_lr = self.scheduler.get_last_lr()[0] if i != 0 else 0
                print("Epoch %d, Iteration %d, Loss: %.4f, Loss AR: %.4f, Loss CFM: %.4f, Grad Norm: %.4f, LR: %.6f"
                      % (epoch, i, loss.item(), loss_ar.item(), loss_cfm.item(), grad_norm_g, cur_lr))
    def _save_checkpoint(self, epoch):
        """Save the trainable sub-modules, keeping only the newest file per model"""
        print('Saving checkpoint...')
        model = self.accelerator.unwrap_model(self.model)
        targets = []
        if self.train_ar:
            targets.append(('AR', {
                'ar': model.ar.state_dict(),
                'length_regulator': model.ar_length_regulator.state_dict(),
            }))
        if self.train_cfm:
            targets.append(('CFM', {
                'cfm': model.cfm.state_dict(),
                'length_regulator': model.cfm_length_regulator.state_dict(),
            }))
        for name, net_state in targets:
            state = {
                'net': net_state,
                'iters': self.iters,
                'epoch': epoch,
            }
            save_path = os.path.join(self.log_dir, '%s_epoch_%05d_step_%05d.pth' % (name, epoch, self.iters))
            torch.save(state, save_path)
            print(f"Saved {name} checkpoint to {save_path}")
            # Find all checkpoints of this type and remove old ones
            self._remove_old_checkpoints(f"{name}_epoch_*_step_*.pth", max_keep=1)
    def _remove_old_checkpoints(self, name_pattern, max_keep=1):
        """Remove old checkpoints, keeping the newest max_keep"""
        checkpoints = glob.glob(os.path.join(self.log_dir, name_pattern))
        if len(checkpoints) > max_keep:
            # Sort by step number, ascending
            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
            # Remove all except the newest max_keep
            for cp in checkpoints[:-max_keep]:
                os.remove(cp)

def main(args):
    trainer = Trainer(
        config_path=args.config,
        pretrained_cfm_ckpt_path=args.pretrained_cfm_ckpt,
        pretrained_ar_ckpt_path=args.pretrained_ar_ckpt,
        data_dir=args.dataset_dir,
        run_name=args.run_name,
        batch_size=args.batch_size,
        steps=args.max_steps,
        max_epochs=args.max_epochs,
        save_interval=args.save_every,
        num_workers=args.num_workers,
        train_cfm=args.train_cfm,
        train_ar=args.train_ar,
    )
    trainer.train()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/v2/vc_wrapper.yaml')
    parser.add_argument('--pretrained-cfm-ckpt', type=str, default=None)
    parser.add_argument('--pretrained-ar-ckpt', type=str, default=None)
    parser.add_argument('--dataset-dir', type=str, default='/path/to/dataset')
    parser.add_argument('--run-name', type=str, default='my_run')
    parser.add_argument('--batch-size', type=int, default=2)
    parser.add_argument('--max-steps', type=int, default=1000)
    parser.add_argument('--max-epochs', type=int, default=1000)
    parser.add_argument('--save-every', type=int, default=500)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--train-cfm', action='store_true', help='Train CFM model')
    parser.add_argument('--train-ar', action='store_true', help='Train AR model')
    args = parser.parse_args()
    main(args)
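
# Example invocation (the script name and paths are illustrative):
#   python train_v2.py --config configs/v2/vc_wrapper.yaml \
#       --dataset-dir /data/my_speaker --run-name my_run \
#       --batch-size 2 --max-steps 1000 --save-every 500 --train-cfm
# Checkpoints land in runs/<run-name>/ as CFM_epoch_*_step_*.pth (and
# AR_epoch_*_step_*.pth when --train-ar is set).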