| """ | |
| The backbone run procedure for the common train/valid/test | |
| Authors | |
| * Leo 2022 | |
| """ | |
| import logging | |
| import pickle | |
| import shutil | |
| from pathlib import Path | |
| import pandas as pd | |
| import torch | |
| import yaml | |
| from s3prl.problem.base import Problem | |
| from s3prl.task.utterance_classification_task import UtteranceClassificationTask | |
| logger = logging.getLogger(__name__) | |
| __all__ = ["Common"] | |
class Common(Problem):
    """Backbone run procedure shared by the common train/valid/test problems."""

    def run(
        self,
        target_dir: str,
        cache_dir: str = None,
        remove_all_cache: bool = False,
        start: int = 0,
        stop: int = None,
        num_workers: int = 6,
        eval_batch: int = -1,
        device: str = "cuda",
        world_size: int = 1,
        rank: int = 0,
        test_ckpt_dir: str = None,
        prepare_data: dict = None,
        build_encoder: dict = None,
        build_dataset: dict = None,
        build_batch_sampler: dict = None,
        build_collate_fn: dict = None,
        build_upstream: dict = None,
        build_featurizer: dict = None,
        build_downstream: dict = None,
        build_model: dict = None,
        build_task: dict = None,
        build_optimizer: dict = None,
        build_scheduler: dict = None,
        save_model: dict = None,
        save_task: dict = None,
        train: dict = None,
        evaluate: dict = None,
    ):
        """
        Run the problem script stage-by-stage.

        ======== ====================
        stage    description
        ======== ====================
        0        Parse the corpus and save the metadata file (waveform path, label...)
        1        Build the encoder to encode the labels
        2        Train the model
        3        Evaluate the model on multiple test sets
        ======== ====================

        Args:
            target_dir (str):
                The directory that stores the script result.
            cache_dir (str):
                The directory that caches the processed data.
                Default: /home/user/.cache/s3prl/data
            remove_all_cache (bool):
                Whether to remove all the cache stored under `cache_dir`.
                Default: False
            start (int):
                The starting stage of the problem script.
                Default: 0
            stop (int):
                The stopping stage of the problem script, set `None` to reach the final stage.
                Default: None
            num_workers (int): num_workers for all the torch DataLoader
            eval_batch (int):
                During evaluation (valid or test), limit the number of batches.
                This is helpful for fast development to check everything won't crash.
                If it is -1, disable this feature and evaluate the entire epoch.
                Default: -1
            device (str):
                The device type for all torch-related operation: "cpu" or "cuda"
                Default: "cuda"
            world_size (int):
                How many processes are running this script simultaneously (in parallel).
                Usually this is just 1, however if you are running distributed training,
                this should be > 1.
                Default: 1
            rank (int):
                When distributed training, world_size > 1. Take :code:`world_size == 8` for
                example, this means 8 processes (8 GPUs) are running in parallel. The script
                needs to know which process among 8 processes it is. In this case, :code:`rank`
                can range from 0~7. All the 8 processes have the same :code:`world_size` but
                different :code:`rank` (process id).
            test_ckpt_dir (str):
                Specify the checkpoint path for testing. If not, use the validation best
                checkpoint under the given :code:`target_dir` directory.
            **others:
                The other arguments like :code:`prepare_data` and :code:`build_model` are
                method specific-arguments for methods like :obj:`prepare_data` and
                :obj:`build_model`, and will not be used in the core :obj:`run` logic.
                See the specific method documentation for their supported arguments and
                meaning
        """
        # Snapshot the fully-resolved arguments so each run is reproducible.
        yaml_path = Path(target_dir) / "configs" / f"{self._get_time_tag()}.yaml"
        yaml_path.parent.mkdir(exist_ok=True, parents=True)
        with yaml_path.open("w") as f:
            yaml.safe_dump(self._get_current_arguments(), f)

        # Normalize all optional configs: None -> default dir / empty dict.
        cache_dir: str = cache_dir or Path.home() / ".cache" / "s3prl" / "data"
        prepare_data: dict = prepare_data or {}
        build_encoder: dict = build_encoder or {}
        build_dataset: dict = build_dataset or {}
        build_batch_sampler: dict = build_batch_sampler or {}
        build_collate_fn: dict = build_collate_fn or {}
        build_upstream: dict = build_upstream or {}
        build_featurizer: dict = build_featurizer or {}
        build_downstream: dict = build_downstream or {}
        build_model: dict = build_model or {}
        build_task: dict = build_task or {}
        build_optimizer: dict = build_optimizer or {}
        build_scheduler: dict = build_scheduler or {}
        save_model: dict = save_model or {}
        save_task: dict = save_task or {}
        train: dict = train or {}
        evaluate = evaluate or {}

        target_dir: Path = Path(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(exist_ok=True, parents=True)
        if remove_all_cache:
            shutil.rmtree(cache_dir)
            # Re-create the (now removed) cache dir so later stages can
            # still write cache files without hitting FileNotFoundError.
            cache_dir.mkdir(exist_ok=True, parents=True)

        stage_id = 0
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: prepare data")
            train_csv, valid_csv, test_csvs = self.prepare_data(
                prepare_data, target_dir, cache_dir, get_path_only=False
            )

        # Always resolve the metadata paths (cheap, path-only call), so they
        # are available even when stage 0 itself is skipped via `start`.
        train_csv, valid_csv, test_csvs = self.prepare_data(
            prepare_data, target_dir, cache_dir, get_path_only=True
        )

        def check_fn():
            # Stage 0 postcondition: all metadata csv files must exist.
            assert Path(train_csv).is_file() and Path(valid_csv).is_file()
            for test_csv in test_csvs:
                assert Path(test_csv).is_file()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 1
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: build encoder")
            encoder_path = self.build_encoder(
                build_encoder,
                target_dir,
                cache_dir,
                train_csv,
                valid_csv,
                test_csvs,
                get_path_only=False,
            )

        # Same pattern as stage 0: always resolve the encoder path.
        encoder_path = self.build_encoder(
            build_encoder,
            target_dir,
            cache_dir,
            train_csv,
            valid_csv,
            test_csvs,
            get_path_only=True,
        )

        def check_fn():
            # Stage 1 postcondition: the pickled encoder must exist.
            assert Path(encoder_path).is_file()

        self._stage_check(stage_id, stop, check_fn)

        with open(encoder_path, "rb") as f:
            encoder = pickle.load(f)

        # Build a throwaway model only to query its downsample rate, which the
        # datasets need as `frame_shift` to align frame-level information.
        model_output_size = len(encoder)
        model = self.build_model(
            build_model,
            model_output_size,
            build_upstream,
            build_featurizer,
            build_downstream,
        )
        frame_shift = model.downsample_rate

        stage_id = 2
        train_dir = target_dir / "train"
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Train Model")
            train_ds, train_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "train",
                train_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )
            valid_ds, valid_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "valid",
                valid_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )

            with Path(encoder_path).open("rb") as f:
                encoder = pickle.load(f)

            # `self.train` (re)builds the model and task internally (e.g. when
            # resuming), so pass the builder arguments, not the built objects.
            build_model_all_args = dict(
                build_model=build_model,
                model_output_size=len(encoder),
                build_upstream=build_upstream,
                build_featurizer=build_featurizer,
                build_downstream=build_downstream,
            )
            build_task_all_args_except_model = dict(
                build_task=build_task,
                encoder=encoder,
                valid_df=pd.read_csv(valid_csv),
            )

            self.train(
                train,
                train_dir,
                build_model_all_args,
                build_task_all_args_except_model,
                save_model,
                save_task,
                build_optimizer,
                build_scheduler,
                evaluate,
                train_ds,
                train_bs,
                self.build_collate_fn(build_collate_fn, "train"),
                valid_ds,
                valid_bs,
                self.build_collate_fn(build_collate_fn, "valid"),
                device=device,
                eval_batch=eval_batch,
                num_workers=num_workers,
                world_size=world_size,
                rank=rank,
            )

        def check_fn():
            # Stage 2 postcondition: training saved a validation-best checkpoint.
            assert (train_dir / "valid_best").is_dir()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 3
        if start <= stage_id:
            # Default to the validation-best checkpoint unless one is given.
            test_ckpt_dir: Path = Path(
                test_ckpt_dir or target_dir / "train" / "valid_best"
            )
            assert test_ckpt_dir.is_dir(), f"checkpoint dir not found: {test_ckpt_dir}"
            logger.info(f"Stage {stage_id}: Test model: {test_ckpt_dir}")

            for test_idx, test_csv in enumerate(test_csvs):
                test_name = Path(test_csv).stem
                # One result directory per (checkpoint, test set) pair; the
                # checkpoint's relative path is flattened into a single name.
                test_dir: Path = (
                    target_dir
                    / "evaluate"
                    / test_ckpt_dir.relative_to(train_dir).as_posix().replace("/", "-")
                    / test_name
                )
                test_dir.mkdir(exist_ok=True, parents=True)

                logger.info(f"Stage {stage_id}.{test_idx}: Test model on {test_csv}")
                test_ds, test_bs = self._build_dataset_and_sampler(
                    target_dir,
                    cache_dir,
                    "test",
                    test_csv,
                    encoder_path,
                    frame_shift,
                    build_dataset,
                    build_batch_sampler,
                )

                # Reload the checkpointed task, overriding its dataframe with
                # the current test set.
                _, valid_best_task = self.load_model_and_task(
                    test_ckpt_dir, task_overrides={"test_df": pd.read_csv(test_csv)}
                )
                logs = self.evaluate(
                    evaluate,
                    "test",
                    valid_best_task,
                    test_ds,
                    test_bs,
                    self.build_collate_fn(build_collate_fn, "test"),
                    eval_batch,
                    test_dir,
                    device,
                    num_workers,
                )
                test_metrics = {name: float(value) for name, value in logs.items()}

                logger.info(f"test results: {test_metrics}")
                with (test_dir / "result.yaml").open("w") as f:
                    yaml.safe_dump(test_metrics, f)

    def _build_dataset_and_sampler(
        self,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        encoder_path: str,
        frame_shift: int,
        build_dataset: dict,
        build_batch_sampler: dict,
    ):
        """
        Build the dataset and its batch sampler for one split.

        Args:
            target_dir (str): same as :obj:`run`
            cache_dir (str): same as :obj:`run`
            mode (str): one of "train" / "valid" / "test"
            data_csv (str): the metadata csv of this split
            encoder_path (str): path of the pickled label encoder
            frame_shift (int): the model's downsample rate
            build_dataset (dict): config for :obj:`build_dataset`
            build_batch_sampler (dict): config for :obj:`build_batch_sampler`

        Returns:
            tuple: (dataset, batch_sampler)
        """
        logger.info(f"Build {mode} dataset")
        dataset = self.build_dataset(
            build_dataset,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            encoder_path,
            frame_shift,
        )
        logger.info(f"Build {mode} batch sampler")
        batch_sampler = self.build_batch_sampler(
            build_batch_sampler,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            dataset,
        )
        return dataset, batch_sampler

    def build_task(
        self,
        build_task: dict,
        model: torch.nn.Module,
        encoder,
        valid_df: pd.DataFrame = None,
        test_df: pd.DataFrame = None,
    ):
        """
        Build the task, which defines the logics for every train/valid/test forward step for the :code:`model`,
        and the logics for how to reduce all the batch results from multiple train/valid/test steps into metrics

        By default build :obj:`UtteranceClassificationTask`

        Args:
            build_task (dict): same in :obj:`default_config`, no argument supported for now
            model (torch.nn.Module): the model built by :obj:`build_model`
            encoder: the encoder built by :obj:`build_encoder`
            valid_df (pd.DataFrame): unused here; accepted for interface compatibility
            test_df (pd.DataFrame): unused here; accepted for interface compatibility

        Returns:
            Task
        """
        task = UtteranceClassificationTask(model, encoder)
        return task