import os
import json
from operator import itemgetter
from typing import List, Tuple, Union, Iterator, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, Sampler, DistributedSampler
from torch.utils.data import RandomSampler, WeightedRandomSampler

from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
from config.config import shared_cfg


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            # Materialize the sampler's index stream once, on first access.
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
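# Example (illustrative sketch; commented out so nothing runs on import). The class
# simply turns a sampler's index stream into an indexable dataset, which is what lets
# DistributedSampler split it across processes:
#
#   ds = DatasetFromSampler(RandomSampler(range(10)))
#   len(ds)   # 10
#   ds[0]     # first index drawn by the wrapped RandomSampler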


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows using any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Re-materialize the wrapped sampler so a fresh draw is made each epoch,
        # then let DistributedSampler pick this rank's share of those indices.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
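# Example (illustrative sketch; commented out so nothing runs on import). `my_dataset`
# and `num_epochs` are placeholders, and torch.distributed is assumed to be initialized:
#
#   base_sampler = WeightedRandomSampler(weights=[1.0] * len(my_dataset), num_samples=1000)
#   dist_sampler = DistributedSamplerWrapper(base_sampler)
#   loader = torch.utils.data.DataLoader(my_dataset, batch_size=8, sampler=dist_sampler)
#   for epoch in range(num_epochs):
#       dist_sampler.set_epoch(epoch)  # reshuffle the wrapped indices each epoch
#       for batch in loader:
#           ...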


def discount_to_target(samples: np.ndarray, target_sum: int) -> np.ndarray:
    """Discounts samples to a target sum.

    NOTE: this function is deprecated.

    This function adjusts an array of sample values so that their sum equals a target sum, while ensuring
    that each element remains greater than or equal to 1 and attempting to maintain a distribution similar
    to the original.

    Example 1:
        samples = np.array([3, 1, 1, 1, 1, 1])
        target_sum = 7
        discounted_samples = discount_to_target(samples, target_sum)
        # [2, 1, 1, 1, 1, 1]

    Example 2:
        samples = np.array([3, 1, 10, 1, 1, 1])
        target_sum = 7
        discounted_samples = discount_to_target(samples, target_sum)
        # [1, 1, 2, 1, 1, 1]

    Parameters:
        samples (np.ndarray): Original array of sample values.
        target_sum (int): The desired sum of the sample array.

    Returns:
        np.ndarray: Adjusted array of sample values whose sum should equal the target sum,
                    and where each element is greater than or equal to 1.
    """
    samples = samples.copy().astype(int)
    if samples.sum() <= target_sum:
        samples[0] += 1
        return samples

    while samples.sum() > target_sum:
        # Only elements greater than 1 may be discounted.
        indices_to_discount = np.where(samples > 1)[0]
        if indices_to_discount.size == 0:
            # Every element is already at 1; the target cannot be reached.
            print("Cannot reach target sum without going below 1 for some elements.")
            return samples
        # Discount at most as many elements as needed to close the remaining gap.
        discount_count = int(min(len(indices_to_discount), samples.sum() - target_sum))
        indices_to_discount = indices_to_discount[:discount_count]
        samples[indices_to_discount] -= 1
    return samples


def create_merged_train_dataset_info(data_preset_multi: dict, data_home: Optional[os.PathLike] = None):
    """Create merged dataset info from a multi-dataset preset.

    Args:
        data_preset_multi (dict): data preset multi
        data_home (os.PathLike, optional): path to data home. If None, the path defined
            in config/config.py is used.

    Returns:
        dict: merged dataset info
    """
    train_dataset_info = {
        "n_datasets": 0,
        "n_notes_per_dataset": None,
        "n_files_per_dataset": [],
        "dataset_names": [],
        "data_split_names": [],
        "index_ranges": [],  # (start, end) index range of each dataset within the merged file list
        "dataset_weights": None,
        "merged_file_list": {},
    }

    if data_home is None:
        data_home = shared_cfg["PATH"]["data_home"]
    assert os.path.exists(data_home)

    for dp in data_preset_multi["presets"]:
        train_dataset_info["n_datasets"] += 1

        dataset_name = data_preset_single_cfg[dp]["dataset_name"]
        train_dataset_info["dataset_names"].append(dataset_name)
        train_dataset_info["data_split_names"].append(dp)

        # The train split is either a split name (str) pointing to a pre-built index
        # file, or an inline file list (dict).
        if isinstance(data_preset_single_cfg[dp]["train_split"], str):
            train_split_name = data_preset_single_cfg[dp]["train_split"]
            file_list_path = os.path.join(data_home, 'yourmt3_indexes',
                                          f'{dataset_name}_{train_split_name}_file_list.json')
            if not os.path.exists(file_list_path):
                raise ValueError(f"File list {file_list_path} does not exist.")
            with open(file_list_path, 'r') as f:
                _file_list = json.load(f)
        elif isinstance(data_preset_single_cfg[dp]["train_split"], dict):
            _file_list = data_preset_single_cfg[dp]["train_split"]
        else:
            raise ValueError("Invalid train split.")

        # Append this dataset's files to the merged list with globally unique integer keys.
        start_idx = len(train_dataset_info["merged_file_list"])
        for i, v in enumerate(_file_list.values()):
            train_dataset_info["merged_file_list"][start_idx + i] = v
        train_dataset_info["n_files_per_dataset"].append(len(_file_list))
        train_dataset_info["index_ranges"].append((start_idx, start_idx + len(_file_list)))

    if "weights" in data_preset_multi.keys() and data_preset_multi["weights"] is not None:
        train_dataset_info["dataset_weights"] = data_preset_multi["weights"]
        assert len(train_dataset_info["dataset_weights"]) == train_dataset_info["n_datasets"]
    else:
        train_dataset_info["dataset_weights"] = np.ones(train_dataset_info["n_datasets"])
        print("No dataset weights specified, using equal weights for all datasets.")
    return train_dataset_info
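# Example (illustrative sketch; commented out so nothing runs on import). The preset
# name "your_multi_preset" is a placeholder for a key defined in config/data_presets.py,
# and the shown values are hypothetical, for a preset with two datasets of 100 and 250 files:
#
#   info = create_merged_train_dataset_info(data_preset_multi_cfg["your_multi_preset"])
#   info["n_datasets"]          # 2
#   info["n_files_per_dataset"] # [100, 250]
#   info["index_ranges"]        # [(0, 100), (100, 350)]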


def get_random_sampler(dataset, num_samples):
    """Get a random sampler drawing `num_samples` per epoch, distributed if required."""
    if torch.distributed.is_initialized():
        return DistributedSamplerWrapper(sampler=RandomSampler(dataset, num_samples=num_samples))
    else:
        return RandomSampler(dataset, num_samples=num_samples)


def get_weighted_random_sampler(dataset_weights: List[float],
                                dataset_index_ranges: List[Tuple[int]],
                                num_samples_per_epoch: Optional[int] = None,
                                replacement: bool = True) -> torch.utils.data.sampler.Sampler:
    """Get a (distributed) weighted random sampler.

    Args:
        dataset_weights (List[float]): list of dataset weights of length n_datasets
        dataset_index_ranges (List[Tuple[int]]): list of dataset index ranges
        num_samples_per_epoch (Optional[int]): number of samples per epoch, typically the length
            of the entire dataset. Defaults to None. If None, the total number of files across
            all datasets is used.
        replacement (bool, optional): sample with replacement. Defaults to True.

    Returns:
        (distributed) weighted random sampler
    """
    assert len(dataset_weights) == len(dataset_index_ranges)

    sample_weights = []
    n_total_samples_in_datasets = dataset_index_ranges[-1][1]
    if len(dataset_weights) > 1 and len(dataset_index_ranges) > 1:
        for dataset_weight, index_range in zip(dataset_weights, dataset_index_ranges):
            assert dataset_weight >= 0
            n_samples_in_dataset = index_range[1] - index_range[0]
            # Per-sample weight: the dataset weight attenuated by the dataset's share of the
            # merged file list, so that large datasets do not dominate the epoch.
            sample_weight = dataset_weight * (1 - n_samples_in_dataset / n_total_samples_in_datasets)
            sample_weights += [sample_weight] * (index_range[1] - index_range[0])
    elif len(dataset_weights) == 1 and len(dataset_index_ranges) == 1:
        # Single dataset: uniform per-sample weights.
        sample_weights = [1] * n_total_samples_in_datasets

    if num_samples_per_epoch is None:
        num_samples_per_epoch = n_total_samples_in_datasets

    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples_per_epoch, replacement=replacement)

    if torch.distributed.is_initialized():
        return DistributedSamplerWrapper(sampler=sampler)
    else:
        return sampler
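# Example (illustrative sketch; commented out so nothing runs on import), using the
# weights and index ranges produced by create_merged_train_dataset_info; the preset
# name is again a placeholder:
#
#   info = create_merged_train_dataset_info(data_preset_multi_cfg["your_multi_preset"])
#   sampler = get_weighted_random_sampler(info["dataset_weights"], info["index_ranges"])
#   # Indices drawn by `sampler` address entries of info["merged_file_list"].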


def get_list_of_weighted_random_samplers(num_samplers: int,
                                         dataset_weights: List[float],
                                         dataset_index_ranges: List[Tuple[int]],
                                         num_samples_per_epoch: Optional[int] = None,
                                         replacement: bool = True) -> List[torch.utils.data.sampler.Sampler]:
    """Get a list of (distributed) weighted random samplers.

    Args:
        num_samplers (int): number of independent samplers to create
        dataset_weights (List[float]): list of dataset weights of length n_datasets
        dataset_index_ranges (List[Tuple[int]]): list of dataset index ranges
        num_samples_per_epoch (Optional[int]): number of samples per epoch, typically the length
            of the entire dataset. Defaults to None. If None, the total number of files across
            all datasets is used.
        replacement (bool, optional): sample with replacement. Defaults to True.

    Returns:
        List[(distributed) weighted random sampler]
    """
    assert num_samplers > 0
    samplers = []
    for i in range(num_samplers):
        samplers.append(
            get_weighted_random_sampler(dataset_weights, dataset_index_ranges, num_samples_per_epoch, replacement))
    return samplers
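# Example (illustrative sketch; commented out so nothing runs on import). Creates
# independent samplers that share the same dataset weighting:
#
#   samplers = get_list_of_weighted_random_samplers(4, info["dataset_weights"], info["index_ranges"])
#   len(samplers)  # 4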