from typing import List

import pytorch_lightning as pl
import torch.nn as nn


def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    The task-specific callback factory is imported lazily, so only the
    module belonging to the requested task is loaded. All arguments other
    than ``task_name`` are forwarded unchanged to that factory.

    Args:
        task_name: str, one of 'musdb18', 'voicebank-demand',
            'vctk-musdb18', 'violin-piano', 'piano-symphony'
        config_yaml: str
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]

    Raises:
        NotImplementedError: if ``task_name`` is not a supported task.
    """
    # Pick the factory first; the call itself is identical for every task.
    if task_name == 'musdb18':
        from bytesep.callbacks.musdb18 import (
            get_musdb18_callbacks as get_task_callbacks,
        )

    elif task_name == 'voicebank-demand':
        from bytesep.callbacks.voicebank_demand import (
            get_voicebank_demand_callbacks as get_task_callbacks,
        )

    elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:
        from bytesep.callbacks.instruments_callbacks import (
            get_instruments_callbacks as get_task_callbacks,
        )

    else:
        # Name the offending value so a misconfigured task is easy to spot.
        raise NotImplementedError(
            "No callbacks are implemented for task_name={!r}. Supported "
            "tasks: 'musdb18', 'voicebank-demand', 'vctk-musdb18', "
            "'violin-piano', 'piano-symphony'.".format(task_name)
        )

    return get_task_callbacks(
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )