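"""Train an AudioSep model.

Builds the audio-text dataset and data module, the separation model, the CLAP
query encoder, and the on-the-fly waveform mixer from a YAML config, then
trains with PyTorch Lightning, saving a checkpoint every N steps."""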
import argparse
import logging
import os
import pathlib
from typing import Tuple

import lightning.pytorch as pl
import torch
from torch.utils.tensorboard import SummaryWriter

# The wildcard imports supply DataModule, AudioTextDataset, and the model
# classes referenced below.
from data.datamodules import *
from models.resunet import *
from utils import create_logging, parse_yaml
from losses import get_loss_function
from models.audiosep import AudioSep, get_model_class
from data.waveform_mixers import SegmentMixer
from models.clap_encoder import CLAP_Encoder
from callbacks.base import CheckpointEveryNSteps
from optimizers.lr_schedulers import get_lr_lambda


def get_dirs(
    workspace: str,
    filename: str,
    config_yaml: str,
    devices_num: int,
) -> Tuple[str, str, str, str]:
    r"""Create and return the output directories and paths.

    Args:
        workspace (str): directory of the workspace.
        filename (str): filename of the current .py file.
        config_yaml (str): path of the config YAML file.
        devices_num (int): number of devices, e.g., 0 for CPU and 8 when
            training with 8 GPUs.

    Returns:
        checkpoints_dir (str): directory to save checkpoints.
        logs_dir (str): directory to save logs.
        tf_logs_dir (str): directory to save TensorBoard logs.
        statistics_path (str): path to save statistics.
    """
    os.makedirs(workspace, exist_ok=True)

    yaml_name = pathlib.Path(config_yaml).stem
    sub_dir = "{},devices={}".format(yaml_name, devices_num)

    checkpoints_dir = os.path.join(workspace, "checkpoints", filename, sub_dir)
    os.makedirs(checkpoints_dir, exist_ok=True)

    logs_dir = os.path.join(workspace, "logs", filename, sub_dir)
    os.makedirs(logs_dir, exist_ok=True)
    create_logging(logs_dir, filemode="w")

    tf_logs_dir = os.path.join(workspace, "tf_logs", filename, sub_dir)

    statistics_path = os.path.join(
        workspace, "statistics", filename, sub_dir, "statistics.pkl"
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path


def get_data_module(
    config_yaml: str,
    num_workers: int,
    batch_size: int,
) -> DataModule:
    r"""Create the data module. Mini-batch data can be obtained by:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        config_yaml (str): path of the config YAML file.
        num_workers (int): number of workers for preparing data in parallel,
            e.g., 0 for loading in the main process and 8 for using 8 CPU cores.
        batch_size (int): batch size per device.

    Returns:
        data_module (DataModule)
    """
    configs = parse_yaml(config_yaml)
    sampling_rate = configs['data']['sampling_rate']
    segment_seconds = configs['data']['segment_seconds']
    datafiles = configs['data']['datafiles']

    # Audio-text paired dataset.
    dataset = AudioTextDataset(
        datafiles=datafiles,
        sampling_rate=sampling_rate,
        max_clip_len=segment_seconds,
    )

    data_module = DataModule(
        train_dataset=dataset,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    return data_module


def train(args) -> None:
    r"""Train, evaluate, and save checkpoints.

    Args:
        args (argparse.Namespace):
            workspace (str): directory of the workspace.
            config_yaml (str): path of the config YAML file.
            resume_checkpoint_path (str): checkpoint to finetune from, or ""
                to train from scratch.
            filename (str): filename of the current .py file.
    """
    # Arguments & parameters.
    workspace = args.workspace
    config_yaml = args.config_yaml
    filename = args.filename

    # Number of visible GPUs; also used in the output directory names.
    devices_num = torch.cuda.device_count()

    configs = parse_yaml(config_yaml)

    # Data and mixing configuration.
    max_mix_num = configs['data']['max_mix_num']
    lower_db = configs['data']['loudness_norm']['lower_db']
    higher_db = configs['data']['loudness_norm']['higher_db']

    # Model configuration.
    query_net = configs['model']['query_net']
    model_type = configs['model']['model_type']
    input_channels = configs['model']['input_channels']
    output_channels = configs['model']['output_channels']
    condition_size = configs['model']['condition_size']
    use_text_ratio = configs['model']['use_text_ratio']

    # Training configuration.
    num_nodes = configs['train']['num_nodes']
    batch_size = configs['train']['batch_size_per_device']
    sync_batchnorm = configs['train']['sync_batchnorm']
    num_workers = configs['train']['num_workers']
    loss_type = configs['train']['loss_type']
    optimizer_type = configs['train']['optimizer']['optimizer_type']
    learning_rate = float(configs['train']['optimizer']['learning_rate'])
    lr_lambda_type = configs['train']['optimizer']['lr_lambda_type']
    warm_up_steps = configs['train']['optimizer']['warm_up_steps']
    reduce_lr_steps = configs['train']['optimizer']['reduce_lr_steps']
    save_step_frequency = configs['train']['save_step_frequency']
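
    # Directories and paths for checkpoints, logs, and statistics.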
    checkpoints_dir, logs_dir, tf_logs_dir, statistics_path = get_dirs(
        workspace, filename, config_yaml, devices_num,
    )

    logging.info(args)
    logging.info(configs)

    resume_checkpoint_path = args.resume_checkpoint_path
    if resume_checkpoint_path == "":
        resume_checkpoint_path = None
    else:
        logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]')

    # Data module.
    data_module = get_data_module(
        config_yaml=config_yaml,
        batch_size=batch_size,
        num_workers=num_workers,
    )
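
    # Separation model, selected by model_type in the config.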
    Model = get_model_class(model_type=model_type)

    ss_model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Loss function.
    loss_function = get_loss_function(loss_type)

    # Mixer that creates training mixtures from single sources on the fly.
    segment_mixer = SegmentMixer(
        max_mix_num=max_mix_num,
        lower_db=lower_db,
        higher_db=higher_db,
    )

    # Query encoder.
    if query_net == 'CLAP':
        query_encoder = CLAP_Encoder()
    else:
        raise NotImplementedError(f"Unsupported query_net: {query_net}")

    # Learning rate lambda for the scheduler.
    lr_lambda_func = get_lr_lambda(
        lr_lambda_type=lr_lambda_type,
        warm_up_steps=warm_up_steps,
        reduce_lr_steps=reduce_lr_steps,
    )
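
    # Lightning module wrapping the separation model, mixer, and query encoder.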
    pl_model = AudioSep(
        ss_model=ss_model,
        waveform_mixer=segment_mixer,
        query_encoder=query_encoder,
        loss_function=loss_function,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        lr_lambda_func=lr_lambda_func,
        use_text_ratio=use_text_ratio,
    )

    # Save a checkpoint every save_step_frequency training steps.
    checkpoint_every_n_steps = CheckpointEveryNSteps(
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # The Trainer below is created with logger=None; this writer is available
    # for manual TensorBoard logging.
    summary_writer = SummaryWriter(log_dir=tf_logs_dir)

    callbacks = [checkpoint_every_n_steps]
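
    # Trainer: DDP across all visible devices. Checkpoint saving is handled by
    # the callback above, so built-in checkpointing is disabled.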
    trainer = pl.Trainer(
        accelerator='auto',
        devices='auto',
        strategy='ddp_find_unused_parameters_true',
        num_nodes=num_nodes,
        precision="32-true",
        logger=None,
        callbacks=callbacks,
        fast_dev_run=False,
        max_epochs=-1,
        log_every_n_steps=50,
        use_distributed_sampler=True,
        sync_batchnorm=sync_batchnorm,
        num_sanity_val_steps=2,
        enable_checkpointing=False,
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    # Fit; the dataloaders come from the data module.
    trainer.fit(
        model=pl_model,
        train_dataloaders=None,
        val_dataloaders=None,
        datamodule=data_module,
        ckpt_path=resume_checkpoint_path,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )
    parser.add_argument(
        "--resume_checkpoint_path",
        type=str,
        required=False,
        default='',
        help="Path of pretrained checkpoint for finetuning; leave empty to train from scratch.",
    )
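
    # Example invocation (paths are illustrative):
    #   python train.py --workspace ./workspace --config_yaml config/audiosep_base.yaml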
    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    train(args)