import argparse
import logging
import os
import pathlib
from typing import Tuple

import lightning.pytorch as pl
import torch
from torch.utils.tensorboard import SummaryWriter

from data.datamodules import *
from utils import create_logging, parse_yaml
from models.resunet import *
from losses import get_loss_function
from models.audiosep import AudioSep, get_model_class
from data.waveform_mixers import SegmentMixer
from models.clap_encoder import CLAP_Encoder
from callbacks.base import CheckpointEveryNSteps
from optimizers.lr_schedulers import get_lr_lambda


def get_dirs(
    workspace: str, filename: str, config_yaml: str, devices_num: int
) -> Tuple[str, str, str, str]:
    r"""Get directories and paths.

    Args:
        workspace (str): directory of workspace.
        filename (str): filename of the current .py file.
        config_yaml (str): path of the config yaml file.
        devices_num (int): number of devices, e.g., 8 for training with 8 GPUs.

    Returns:
        checkpoints_dir (str): directory to save checkpoints.
        logs_dir (str): directory to save logs.
        tf_logs_dir (str): directory to save TensorBoard logs.
        statistics_path (str): path to save statistics.
    """
    os.makedirs(workspace, exist_ok=True)

    yaml_name = pathlib.Path(config_yaml).stem

    # Directory to save checkpoints
    checkpoints_dir = os.path.join(
        workspace,
        "checkpoints",
        filename,
        "{},devices={}".format(yaml_name, devices_num),
    )
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Directory to save logs
    logs_dir = os.path.join(
        workspace,
        "logs",
        filename,
        "{},devices={}".format(yaml_name, devices_num),
    )
    os.makedirs(logs_dir, exist_ok=True)

    # Set up logging to file.
    create_logging(logs_dir, filemode="w")

    # Directory to save TensorBoard logs
    tf_logs_dir = os.path.join(
        workspace,
        "tf_logs",
        filename,
        "{},devices={}".format(yaml_name, devices_num),
    )

    # Path to save statistics
    statistics_path = os.path.join(
        workspace,
        "statistics",
        filename,
        "{},devices={}".format(yaml_name, devices_num),
        "statistics.pkl",
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path


def get_data_module(
    config_yaml: str,
    num_workers: int,
    batch_size: int,
) -> DataModule:
    r"""Create a data module.

    Mini-batch data can be obtained by:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        config_yaml (str): path of the config yaml file.
        num_workers (int): number of workers for preparing data in parallel,
            e.g., 0 for non-parallel loading and 8 for using 8 CPU cores.
        batch_size (int): mini-batch size per device.

    Returns:
        data_module (DataModule)
    """
    # Read configurations.
    configs = parse_yaml(config_yaml)
    sampling_rate = configs["data"]["sampling_rate"]
    segment_seconds = configs["data"]["segment_seconds"]

    # Audio-text datafiles.
    datafiles = configs["data"]["datafiles"]

    # Dataset.
    dataset = AudioTextDataset(
        datafiles=datafiles,
        sampling_rate=sampling_rate,
        max_clip_len=segment_seconds,
    )

    # Data module.
    data_module = DataModule(
        train_dataset=dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    return data_module


def train(args) -> None:
    r"""Train, evaluate, and save checkpoints.

    Args:
        args.workspace (str): directory of workspace.
        args.config_yaml (str): path of config file for training.
        args.resume_checkpoint_path (str): path of a checkpoint to resume
            from; an empty string trains from scratch.
    """
    # Arguments & parameters
    workspace = args.workspace
    config_yaml = args.config_yaml
    filename = args.filename

    devices_num = torch.cuda.device_count()

    # Read config file.
    configs = parse_yaml(config_yaml)

    # Configuration of data
    max_mix_num = configs["data"]["max_mix_num"]
    sampling_rate = configs["data"]["sampling_rate"]
    lower_db = configs["data"]["loudness_norm"]["lower_db"]
    higher_db = configs["data"]["loudness_norm"]["higher_db"]

    # Configuration of the separation model
    query_net = configs["model"]["query_net"]
    model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]
    use_text_ratio = configs["model"]["use_text_ratio"]

    # Configuration of the trainer
    num_nodes = configs["train"]["num_nodes"]
    batch_size = configs["train"]["batch_size_per_device"]
    sync_batchnorm = configs["train"]["sync_batchnorm"]
    num_workers = configs["train"]["num_workers"]
    loss_type = configs["train"]["loss_type"]
    optimizer_type = configs["train"]["optimizer"]["optimizer_type"]
    learning_rate = float(configs["train"]["optimizer"]["learning_rate"])
    lr_lambda_type = configs["train"]["optimizer"]["lr_lambda_type"]
    warm_up_steps = configs["train"]["optimizer"]["warm_up_steps"]
    reduce_lr_steps = configs["train"]["optimizer"]["reduce_lr_steps"]
    save_step_frequency = configs["train"]["save_step_frequency"]

    resume_checkpoint_path = args.resume_checkpoint_path
    if resume_checkpoint_path == "":
        resume_checkpoint_path = None
    else:
        logging.info(
            f"Finetuning AudioSep from checkpoint [{resume_checkpoint_path}]"
        )

    # Get directories and paths.
    checkpoints_dir, logs_dir, tf_logs_dir, statistics_path = get_dirs(
        workspace, filename, config_yaml, devices_num,
    )

    logging.info(args)
    logging.info(configs)

    # Data module.
    data_module = get_data_module(
        config_yaml=config_yaml,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # Separation model.
    Model = get_model_class(model_type=model_type)

    ss_model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Loss function.
    loss_function = get_loss_function(loss_type)

    # Mixer that creates training mixtures on the fly.
    segment_mixer = SegmentMixer(
        max_mix_num=max_mix_num,
        lower_db=lower_db,
        higher_db=higher_db,
    )

    # Query encoder.
    if query_net == "CLAP":
        query_encoder = CLAP_Encoder()
    else:
        raise NotImplementedError(f"Unsupported query_net: {query_net}")

    # Learning rate scheduler lambda.
    lr_lambda_func = get_lr_lambda(
        lr_lambda_type=lr_lambda_type,
        warm_up_steps=warm_up_steps,
        reduce_lr_steps=reduce_lr_steps,
    )

    # PyTorch Lightning model.
    pl_model = AudioSep(
        ss_model=ss_model,
        waveform_mixer=segment_mixer,
        query_encoder=query_encoder,
        loss_function=loss_function,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        lr_lambda_func=lr_lambda_func,
        use_text_ratio=use_text_ratio,
    )

    # Callback that saves a checkpoint every N training steps.
    checkpoint_every_n_steps = CheckpointEveryNSteps(
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # TensorBoard writer (note: the trainer below runs with logger=None, so
    # this writer is not attached to the trainer).
    summary_writer = SummaryWriter(log_dir=tf_logs_dir)

    callbacks = [checkpoint_every_n_steps]

    trainer = pl.Trainer(
        accelerator="auto",
        devices="auto",
        strategy="ddp_find_unused_parameters_true",
        num_nodes=num_nodes,
        precision="32-true",
        logger=None,
        callbacks=callbacks,
        fast_dev_run=False,
        max_epochs=-1,
        log_every_n_steps=50,
        use_distributed_sampler=True,
        sync_batchnorm=sync_batchnorm,
        num_sanity_val_steps=2,
        enable_checkpointing=False,
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(
        model=pl_model,
        train_dataloaders=None,
        val_dataloaders=None,
        datamodule=data_module,
        ckpt_path=resume_checkpoint_path,
    )


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )
    parser.add_argument(
        "--resume_checkpoint_path",
        type=str,
        default="",
        help="Path of pretrained checkpoint for finetuning; leave empty to "
        "train from scratch.",
    )

    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    train(args)
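
# Example invocation (a sketch: the workspace and config paths below are
# illustrative assumptions, substitute your own):
#
#   python train.py \
#       --workspace ./workspace \
#       --config_yaml config/audiosep_base.yaml
#
# To finetune from a pretrained checkpoint, additionally pass
# --resume_checkpoint_path /path/to/checkpoint.ckpt; omitting the flag
# (or passing an empty string) trains from scratch.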