import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

import logging
import math
import os
import random
from datetime import timedelta
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.distributed as distributed
import wandb
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record
from tqdm import tqdm

from meanaudio.data.data_setup import setup_training_datasets, setup_val_datasets
from meanaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from meanaudio.runner_flowmatching import RunnerFlowMatching
from meanaudio.runner_meanflow import RunnerMeanFlow
from meanaudio.sample import sample
from meanaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
from meanaudio.utils.logger import TensorboardLogger
from meanaudio.utils.synthesize_ema import synthesize_ema
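# Enable TF32 for matmul and cuDNN kernels: faster on Ampere+ GPUs at slightly
# reduced numerical precision, which is generally acceptable for training.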
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

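# Join the default NCCL process group; the 2-hour timeout leaves headroom for slow
# dataset/checkpoint setup on some ranks. local_rank / world_size come from
# meanaudio.utils.dist_utils (presumably derived from the torchrun environment).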
def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size

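# @record captures exceptions raised in this entrypoint so torchrun's elastic error
# handling can report the full traceback.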
@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
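    # Optional remote debugging: when cfg.debug is set, rank 0 waits for a debugpy client.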
    if cfg.get("debug", False):
        import debugpy
        if "RANK" not in os.environ or int(os.environ["RANK"]) == 0:
            debugpy.listen(6665)
            print(f'Waiting for debugger attach (rank {os.environ.get("RANK", "0")})...')
            debugpy.wait_for_client()
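    # Bind this process to its local GPU, then initialize distributed training.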
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    distributed_setup()
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir
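    # 16 kHz sequence configuration; open_dict temporarily relaxes OmegaConf struct mode
    # so the latent sequence length can be written into the config.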
    seq_cfg = CONFIG_16K
    with open_dict(cfg):
        cfg.data_dim.latent_seq_len = seq_cfg.latent_seq_len
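    # Rank-aware TensorBoard logger; note that this shadows the module-level `log`
    # inside train().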
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')
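    # Seed every RNG identically across ranks; numpy is reseeded with a per-rank offset
    # just before the training loop below.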
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
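    # cfg.batch_size is the global batch size; convert it to a per-GPU value.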
    info_if_rank_zero(log, f'Training configuration: {cfg}')
    cfg.batch_size //= num_gpus
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')

    total_iterations = cfg['num_iterations']

    if cfg['text_encoder_name'] == 't5_clap_cat':
        cfg['concat_text_fc'] = True
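    # Build the training and validation/evaluation dataloaders; the returned sampler is
    # used for per-epoch shuffling below (sampler.set_epoch).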
    dataset, sampler, loader = setup_training_datasets(cfg)
    info_if_rank_zero(log, f'Number of training samples: {len(dataset)}')
    info_if_rank_zero(log, f'Number of training batches: {len(loader)}')

    val_dataset, val_loader, eval_loader = setup_val_datasets(cfg)
    info_if_rank_zero(log, f'Number of val samples: {len(val_dataset)}')
    val_cfg = cfg.data.AudioCaps_val_npz
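    # Precomputed latent mean/std tensors, handed to the runner (presumably used there
    # to normalize the audio latents).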
    latent_mean, latent_std = torch.load(cfg.data.latent_mean), torch.load(cfg.data.latent_std)
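    # Select the training runner: MeanFlow when cfg.use_meanflow is set, otherwise
    # standard flow matching. REPA training is not implemented in this script.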
    if not cfg.use_repa:
        if cfg.use_meanflow:
            trainer = RunnerMeanFlow(cfg,
                                     log=log,
                                     run_path=run_dir,
                                     for_training=True,
                                     latent_mean=latent_mean,
                                     latent_std=latent_std).enter_train()
        else:
            trainer = RunnerFlowMatching(cfg,
                                         log=log,
                                         run_path=run_dir,
                                         for_training=True,
                                         latent_mean=latent_mean,
                                         latent_std=latent_std).enter_train()
    else:
        raise NotImplementedError('REPA is not supported yet')
        # Unreachable until a REPA runner (e.g. RunnerAT_REPA) is available:
        # trainer = RunnerAT_REPA(cfg,
        #                         log=log,
        #                         run_path=run_dir,
        #                         for_training=True,
        #                         latent_mean=latent_mean,
        #                         latent_std=latent_std).enter_train()
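    # Capture an RNG state to reuse for every validation/inference pass, so evaluation
    # randomness stays fixed and metrics remain comparable across iterations.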
    eval_rng_clone = trainer.rng.graphsafe_get_state()
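    # Resume priority: explicit cfg.checkpoint > latest checkpoint found by the runner
    # > pretrained weights (cfg.weights) > training from scratch.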
    if cfg['checkpoint'] is not None:
        curr_iter = trainer.load_checkpoint(cfg['checkpoint'])
        cfg['checkpoint'] = None
        info_if_rank_zero(log, 'Model checkpoint loaded!')
    else:
        checkpoint = trainer.get_latest_checkpoint_path()
        if checkpoint is not None:
            curr_iter = trainer.load_checkpoint(checkpoint)
            info_if_rank_zero(log, 'Latest checkpoint loaded!')
        else:
            curr_iter = 0
            if cfg['weights'] is not None:
                info_if_rank_zero(log, 'Loading weights from the disk')
                trainer.load_weights(cfg['weights'])
                cfg['weights'] = None
            else:
                info_if_rank_zero(log, 'No checkpoint or weights found, starting from scratch')
    total_epoch = math.ceil(total_iterations / len(loader))
    current_epoch = curr_iter // len(loader)
    info_if_rank_zero(log, f'We will approximately use {total_epoch - current_epoch} epochs.')
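    # Main loop: train until total_iterations, validating every cfg.val_interval steps
    # and running inference/evaluation every cfg.eval_interval steps. The finally block
    # saves a checkpoint even if training is interrupted (unless cfg.debug).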
    try:
        np.random.seed(np.random.randint(2**30 - 1) + local_rank * 1000)
        while curr_iter < total_iterations:
            sampler.set_epoch(current_epoch)
            current_epoch += 1
            log.debug(f'Current epoch: {current_epoch}')

            trainer.enter_train()
            trainer.log.data_timer.start()
            for data in loader:
                trainer.train_pass(data, curr_iter)
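                # Periodic validation: swap in the fixed eval RNG state, average the
                # validation loss over val_loader, log it, then restore the training RNG.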
                if (curr_iter + 1) % cfg.val_interval == 0:
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: validating')
                    total_loss = 0
                    n = 0
                    if cfg.use_repa:
                        total_diff_loss = 0
                        total_proj_loss = 0
                    for data in tqdm(val_loader):
                        n += 1
                        if not cfg.use_repa:
                            mean_loss = trainer.validation_pass(data, curr_iter)
                            total_loss += mean_loss
                        else:
                            mean_loss, diff_loss, proj_loss = trainer.validation_pass(data, curr_iter)
                            total_loss += mean_loss
                            total_diff_loss += diff_loss
                            total_proj_loss += proj_loss
                    total_loss /= n
                    if cfg.use_repa:
                        total_diff_loss /= n
                        total_proj_loss /= n
                    if cfg.use_wandb and local_rank == 0:
                        wandb.log({"val/loss": total_loss})
                        if cfg.use_repa:
                            wandb.log({"val/diff_loss": total_diff_loss}, step=curr_iter)
                            wandb.log({"val/proj_loss": total_proj_loss}, step=curr_iter)

                    distributed.barrier()
                    trainer.val_integrator.finalize('val', curr_iter, ignore_timer=True)
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
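                # Periodic inference: generate audio for the eval set under the fixed
                # eval RNG state, then evaluate the generated audio via trainer.eval().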
                if (curr_iter + 1) % cfg.eval_interval == 0:
                    save_eval = (curr_iter + 1) % cfg.save_eval_interval == 0
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: inference')
                    for data in tqdm(eval_loader):
                        audio_path = trainer.inference_pass(data,
                                                            curr_iter,
                                                            val_cfg,
                                                            save_eval=save_eval)
                    distributed.barrier()
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
                    trainer.eval(audio_path, curr_iter, val_cfg)
                curr_iter += 1

                if curr_iter >= total_iterations:
                    break

    except Exception as e:
        log.error(f'Error occurred at iteration {curr_iter}!')
        log.critical(e.message if hasattr(e, 'message') else str(e))
        raise
    finally:
        if not cfg.debug:
            trainer.save_checkpoint(curr_iter)
            trainer.save_weights(curr_iter)

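    # Training finished: release the trainer and cached GPU memory before the
    # post-training EMA synthesis and sampling below.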
    del trainer
    torch.cuda.empty_cache()

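    # Rank 0 synthesizes the final EMA weights at the configured output sigma and saves
    # them under run_dir; all ranks then sync at the barrier.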
    if local_rank == 0:
        log.info(f'Synthesizing EMA with sigma={cfg.ema.default_output_sigma}')
        ema_sigma = cfg.ema.default_output_sigma
        state_dict = synthesize_ema(cfg, ema_sigma, step=None)
        save_dir = Path(run_dir) / f'{cfg.exp_id}_ema_final.pth'
        torch.save(state_dict, save_dir)
        log.info(f'Synthesized EMA saved to {save_dir}!')
    distributed.barrier()
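    # Final sampling/evaluation pass with the trained model, then clean shutdown of the
    # process group.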
    log.info(f'Evaluation: {cfg}')
    sample(cfg)

    log.complete()
    distributed.barrier()
    distributed.destroy_process_group()

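# Example launch (hypothetical values; assumes a torchrun-managed environment, this file
# saved as train.py, and the Hydra config at ./config/train_config.yaml):
#   torchrun --standalone --nproc_per_node=8 train.py exp_id=my_run num_iterations=300000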
if __name__ == '__main__':
    train()