# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
| """ data_modules.py """ | |
| from typing import Optional, Dict, List, Any | |
| import os | |
| import numpy as np | |
| from pytorch_lightning import LightningDataModule | |
| from pytorch_lightning.utilities import CombinedLoader | |
| from utils.datasets_train import get_cache_data_loader | |
| from utils.datasets_eval import get_eval_dataloader | |
| from utils.datasets_helper import create_merged_train_dataset_info, get_list_of_weighted_random_samplers | |
| from utils.task_manager import TaskManager | |
| from config.config import shared_cfg | |
| from config.config import audio_cfg as default_audio_cfg | |
| from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg | |


class AMTDataModule(LightningDataModule):

    def __init__(
            self,
            data_home: Optional[os.PathLike] = None,
            data_preset_multi: Dict[str, Any] = {
                "presets": ["musicnet_mt3_synth_only"],
            },  # only a multi_preset_cfg is allowed; a single_preset_cfg should be converted to a multi_preset_cfg
            task_manager: TaskManager = TaskManager(task_name="mt3_full_plus"),
            train_num_samples_per_epoch: Optional[int] = None,
            train_random_amp_range: List[float] = [0.6, 1.2],
            train_stem_iaug_prob: Optional[float] = 0.7,
            train_stem_xaug_policy: Optional[Dict] = {
                "max_k": 3,
                "tau": 0.3,
                "alpha": 1.0,
                "max_subunit_stems": 12,  # subunit stems are reduced to at most this number of stems
                "p_include_singing": 0.8,  # probability of including singing in cross-augmented examples; if None, the base probability is used
                "no_instr_overlap": True,
                "no_drum_overlap": True,
                "uhat_intra_stem_augment": True,
            },
            train_pitch_shift_range: Optional[List[int]] = None,
            audio_cfg: Optional[Dict] = None) -> None:
        super().__init__()

        # Check the dataset path.
        if data_home is None:
            data_home = shared_cfg["PATH"]["data_home"]
        if os.path.exists(data_home):
            self.data_home = data_home
        else:
            raise ValueError(f"Invalid data_home: {data_home}")

        # Resolve each preset of the multi-preset config into a single-preset config,
        # e.g. [{"dataset_name": ..., "train_split": ..., "validation_split": ...}, {...}]
        self.preset_multi = data_preset_multi
        self.preset_singles = []
        for dp in self.preset_multi["presets"]:
            if dp not in data_preset_single_cfg.keys():
                raise ValueError(f"Invalid data_preset: {dp}")
            self.preset_singles.append(data_preset_single_cfg[dp])

        # Task manager
        self.task_manager = task_manager

        # Number of training samples per epoch, passed to the samplers
        self.train_num_samples_per_epoch = train_num_samples_per_epoch
        assert shared_cfg["BSZ"]["train_local"] % shared_cfg["BSZ"]["train_sub"] == 0
        self.num_train_samplers = shared_cfg["BSZ"]["train_local"] // shared_cfg["BSZ"]["train_sub"]

        # Training augmentation parameters
        self.train_random_amp_range = train_random_amp_range
        self.train_stem_iaug_prob = train_stem_iaug_prob
        self.train_stem_xaug_policy = train_stem_xaug_policy
        self.train_pitch_shift_range = train_pitch_shift_range

        # Training data info, set in setup()
        self.train_data_info = None

        # Maximum number of validation/test files
        self.val_max_num_files = data_preset_multi.get("val_max_num_files", None)
        self.test_max_num_files = data_preset_multi.get("test_max_num_files", None)

        # Audio config
        self.audio_cfg = audio_cfg if audio_cfg is not None else default_audio_cfg
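
    # Illustration with assumed config values (not taken from config.config):
    # if BSZ.train_local = 24 and BSZ.train_sub = 6, then num_train_samplers = 4,
    # and each training step combines four sub-batches of size 6.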

    def set_merged_train_data_info(self) -> None:
        """Collect the train datasets and create the merged info:

        self.train_data_info = {
            "n_datasets": 0,
            "n_notes_per_dataset": [],
            "n_files_per_dataset": [],
            "dataset_names": [],  # dataset names, in order of merging the file lists
            "train_split_names": [],  # train split names, in order of merging the file lists
            "index_ranges": [],  # index range of each dataset within the merged file list
            "dataset_weights": [],  # predefined dataset weights for sampling, if available
            "merged_file_list": {},
        }
        """
        self.train_data_info = create_merged_train_dataset_info(self.preset_multi)
        print(f"AMTDataModule: added {len(self.train_data_info['merged_file_list'])} files "
              f"from {self.train_data_info['n_datasets']} datasets to the training set.")

    def setup(self, stage: str) -> None:
        """Prepare the dataloader arguments used in each stage.

        `stage` is passed automatically by the pytorch-lightning Trainer.
        """
        if stage == "fit":
            # Set up the merged train data info.
            self.set_merged_train_data_info()

            # Distributed weighted random samplers for training.
            actual_train_num_samples_per_epoch = (self.train_num_samples_per_epoch //
                                                  shared_cfg["BSZ"]["train_local"]
                                                  if self.train_num_samples_per_epoch else None)
            samplers = get_list_of_weighted_random_samplers(
                num_samplers=self.num_train_samplers,
                dataset_weights=self.train_data_info["dataset_weights"],
                dataset_index_ranges=self.train_data_info["index_ranges"],
                num_samples_per_epoch=actual_train_num_samples_per_epoch)

            # Train dataloader arguments: one sub-batch loader per sampler.
            self.train_data_args = []
            for sampler in samplers:
                self.train_data_args.append({
                    "dataset_name": None,
                    "split": None,
                    "file_list": self.train_data_info["merged_file_list"],
                    "sub_batch_size": shared_cfg["BSZ"]["train_sub"],
                    "task_manager": self.task_manager,
                    "random_amp_range": self.train_random_amp_range,
                    "stem_iaug_prob": self.train_stem_iaug_prob,
                    "stem_xaug_policy": self.train_stem_xaug_policy,
                    "pitch_shift_range": self.train_pitch_shift_range,
                    "shuffle": True,
                    "sampler": sampler,
                    "audio_cfg": self.audio_cfg,
                })

            # Validation dataloader arguments: one loader per dataset that has a validation split.
            self.val_data_args = []
            for preset_single in self.preset_singles:
                if preset_single["validation_split"] is not None:
                    self.val_data_args.append({
                        "dataset_name": preset_single["dataset_name"],
                        "split": preset_single["validation_split"],
                        "task_manager": self.task_manager,
                        # "tokenizer": self.task_manager.get_tokenizer(),
                        "max_num_files": self.val_max_num_files,
                        "audio_cfg": self.audio_cfg,
                    })

        if stage == "test":
            # Test dataloader arguments: one loader per dataset that has a test split.
            self.test_data_args = []
            for preset_single in self.preset_singles:
                if preset_single["test_split"] is not None:
                    self.test_data_args.append({
                        "dataset_name": preset_single["dataset_name"],
                        "split": preset_single["test_split"],
                        "task_manager": self.task_manager,
                        "max_num_files": self.test_max_num_files,
                        "audio_cfg": self.audio_cfg,
                    })

    def train_dataloader(self) -> Any:
        loaders = {}
        for i, args_dict in enumerate(self.train_data_args):
            loaders[f"data_loader_{i}"] = get_cache_data_loader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return CombinedLoader(loaders, mode="min_size")  # all loaders have identical size
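
    # In "min_size" mode, the CombinedLoader above yields one batch per step as
    # a dict keyed by loader name, e.g. {"data_loader_0": sub_batch_0, ...}, so
    # each training step draws one sub-batch from every sampler.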

    def val_dataloader(self) -> Any:
        loaders = {}
        for args_dict in self.val_data_args:
            dataset_name = args_dict["dataset_name"]
            loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return loaders

    def test_dataloader(self) -> Any:
        loaders = {}
        for args_dict in self.test_data_args:
            dataset_name = args_dict["dataset_name"]
            loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return loaders

    # CombinedLoader in "sequential" mode returns a dataloader_idx to the
    # trainer, which is used to look up the dataset name in the logger.
    def num_val_dataloaders(self) -> int:
        return len(self.val_data_args)

    def num_test_dataloaders(self) -> int:
        return len(self.test_data_args)

    def get_val_dataset_name(self, dataloader_idx: int) -> str:
        return self.val_data_args[dataloader_idx]["dataset_name"]

    def get_test_dataset_name(self, dataloader_idx: int) -> str:
        return self.test_data_args[dataloader_idx]["dataset_name"]
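

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes that shared_cfg["PATH"]["data_home"] points to a valid dataset root
# and that the "musicnet_mt3_synth_only" preset is available, as in the
# defaults above.
if __name__ == "__main__":
    dm = AMTDataModule(data_preset_multi={"presets": ["musicnet_mt3_synth_only"]},
                       task_manager=TaskManager(task_name="mt3_full_plus"))

    # Lightning normally calls setup() itself; calling it manually here lets us
    # inspect the prepared dataloaders.
    dm.setup("fit")
    train_loader = dm.train_dataloader()  # CombinedLoader over sub-batch loaders
    val_loaders = dm.val_dataloader()  # dict: {dataset_name: eval DataLoader}
    print(f"num val dataloaders: {dm.num_val_dataloaders()}")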