import pickle
from typing import Dict, Iterator, List
import numpy as np
import torch.distributed as dist
class SegmentSampler:
def __init__(
self,
indexes_path: str,
segment_samples: int,
mixaudio_dict: Dict,
batch_size: int,
steps_per_epoch: int,
random_seed=1234,
):
r"""Sample training indexes of sources.
Args:
indexes_path: str, path of indexes dict
segment_samplers: int
mixaudio_dict, dict, including hyper-parameters for mix-audio data
augmentation, e.g., {'voclas': 2, 'accompaniment': 2}
batch_size: int
steps_per_epoch: int, #steps_per_epoch is called an `epoch`
random_seed: int
"""
self.segment_samples = segment_samples
self.mixaudio_dict = mixaudio_dict
self.batch_size = batch_size
self.steps_per_epoch = steps_per_epoch
        with open(indexes_path, "rb") as f:
            self.meta_dict = pickle.load(f)
# E.g., {
# 'vocals': [
# {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
# ...
# ],
# 'accompaniment': [
        # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
        # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 445410},
# ...
# ]
# }
self.source_types = self.meta_dict.keys()
# E.g., ['vocals', 'accompaniment']
self.pointers_dict = {source_type: 0 for source_type in self.source_types}
# E.g., {'vocals': 0, 'accompaniment': 0}
self.indexes_dict = {
source_type: np.arange(len(self.meta_dict[source_type]))
for source_type in self.source_types
}
# E.g. {
# 'vocals': [0, 1, ..., 225751],
# 'accompaniment': [0, 1, ..., 225751]
# }
self.random_state = np.random.RandomState(random_seed)
# Shuffle indexes.
for source_type in self.source_types:
self.random_state.shuffle(self.indexes_dict[source_type])
print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))
    def __iter__(self) -> Iterator[List[Dict]]:
        r"""Yield batches of meta info.
        Yields:
            batch_meta_list: (batch_size,), e.g., when mix-audio is 2, looks like [
{'vocals': [
{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
{'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
'accompaniment': [
{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
{'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
}
...
]
"""
batch_size = self.batch_size
while True:
batch_meta_dict = {source_type: [] for source_type in self.source_types}
for source_type in self.source_types:
# E.g., ['vocals', 'accompaniment']
# Loop until get a mini-batch.
while len(batch_meta_dict[source_type]) != batch_size:
largest_index = (
len(self.indexes_dict[source_type])
- self.mixaudio_dict[source_type]
)
# E.g., 225750 = 225752 - 2
if self.pointers_dict[source_type] > largest_index:
# Reset pointer, and shuffle indexes.
self.pointers_dict[source_type] = 0
self.random_state.shuffle(self.indexes_dict[source_type])
source_metas = []
mix_audios_num = self.mixaudio_dict[source_type]
for _ in range(mix_audios_num):
pointer = self.pointers_dict[source_type]
# E.g., 1
index = self.indexes_dict[source_type][pointer]
# E.g., 12231
self.pointers_dict[source_type] += 1
source_meta = self.meta_dict[source_type][index]
                        # E.g., {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 198450, 'end_sample': 330750}
source_metas.append(source_meta)
batch_meta_dict[source_type].append(source_metas)
# When mix-audio is 2, batch_meta_dict looks like: {
# 'vocals': [
# [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
# [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590},
# {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}]
# ]
# 'accompaniment': [
            # [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 24232950, 'end_sample': 24365250},
            # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 1569960, 'end_sample': 1702260}],
            # [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 2795940, 'end_sample': 2928240},
            # {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 10923570, 'end_sample': 11055870}]
# ]
# }
batch_meta_list = [
{
source_type: batch_meta_dict[source_type][i]
for source_type in self.source_types
}
for i in range(batch_size)
]
# When mix-audio is 2, batch_meta_list looks like: [
# {'vocals': [
# {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
# 'accompaniment': [
            # {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
            # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
# }
# ...
# ]
yield batch_meta_list
def __len__(self) -> int:
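        # The iterator above yields batches indefinitely; an "epoch" is simply
        # defined as `steps_per_epoch` batches for the training loop.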
return self.steps_per_epoch
def state_dict(self) -> Dict:
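        # Expose the pointers and shuffled indexes so that sampling can be
        # checkpointed and resumed later via `load_state_dict`.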
state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
return state
    def load_state_dict(self, state: Dict) -> None:
self.pointers_dict = state['pointers_dict']
self.indexes_dict = state['indexes_dict']
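# Illustrative sketch (not part of the original code): one way a single
# `source_metas` list yielded by SegmentSampler could be consumed for
# mix-audio augmentation, by loading each listed segment from its HDF5 file
# and summing the waveforms. The helper name and the `h5py` usage below are
# assumptions made for illustration only.
def load_mixed_segment(source_metas: List[Dict]) -> np.ndarray:
    r"""Load and sum the segments described by `source_metas` (hypothetical helper)."""
    import h5py  # assumed dependency, not imported by the original module
    mixed = None
    for meta in source_metas:
        with h5py.File(meta['hdf5_path'], 'r') as hf:
            # Slice the segment [begin_sample, end_sample) from the dataset.
            data = np.array(
                hf[meta['key_in_hdf5']][meta['begin_sample']: meta['end_sample']]
            )
        mixed = data if mixed is None else mixed + data
    return mixed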
class DistributedSamplerWrapper:
def __init__(self, sampler):
r"""Distributed wrapper of sampler."""
self.sampler = sampler
def __iter__(self):
num_replicas = dist.get_world_size()
rank = dist.get_rank()
for indices in self.sampler:
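            # Each replica keeps every `num_replicas`-th element of the batch
            # meta list, starting at its own rank, so one global batch is
            # split evenly across processes.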
yield indices[rank::num_replicas]
def __len__(self) -> int:
return len(self.sampler)
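if __name__ == "__main__":
    # Minimal usage sketch, assuming an indexes pickle exists at the
    # hypothetical path below and contains 'vocals' and 'accompaniment'
    # entries. In distributed training, the sampler would additionally be
    # wrapped as DistributedSamplerWrapper(sampler) after torch.distributed
    # is initialized, so each rank receives its own slice of every batch.
    sampler = SegmentSampler(
        indexes_path="indexes/train_indexes.pkl",  # hypothetical path
        segment_samples=44100 * 3,  # e.g., 3-second segments at 44.1 kHz (assumed)
        mixaudio_dict={'vocals': 2, 'accompaniment': 2},
        batch_size=16,
        steps_per_epoch=10000,
    )
    batch_meta_list = next(iter(sampler))
    print(len(batch_meta_list))  # -> 16, one meta dict per example in the batch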