Spaces:
Runtime error
Runtime error
import pickle
from typing import Dict, Iterator, List, NoReturn

import numpy as np
import torch.distributed as dist
class SegmentSampler:
    r"""Endless batch sampler over per-source segment meta info.

    For every source type found in the pickled indexes dict, the sampler
    keeps a shuffled array of indexes and a pointer into it. Each batch item
    draws `mixaudio_dict[source_type]` consecutive entries per source type;
    when too few entries remain the pointer is reset and the indexes are
    reshuffled.
    """

    def __init__(
        self,
        indexes_path: str,
        segment_samples: int,
        mixaudio_dict: Dict,
        batch_size: int,
        steps_per_epoch: int,
        random_seed=1234,
    ):
        r"""Sample training indexes of sources.

        Args:
            indexes_path: str, path of the pickled indexes dict
            segment_samples: int, stored on the instance (not used by the
                sampler itself)
            mixaudio_dict: dict, including hyper-parameters for mix-audio data
                augmentation, e.g., {'vocals': 2, 'accompaniment': 2}
            batch_size: int
            steps_per_epoch: int, this many steps are called an `epoch`;
                it is the value reported by `__len__`
            random_seed: int
        """
        self.segment_samples = segment_samples
        self.mixaudio_dict = mixaudio_dict
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        # NOTE(review): unpickling can execute arbitrary code; only load
        # indexes files that come from a trusted source.
        # The context manager guarantees the file handle is closed.
        with open(indexes_path, "rb") as indexes_file:
            self.meta_dict = pickle.load(indexes_file)
        # E.g., {
        #     'vocals': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals',
        #          'begin_sample': 0, 'end_sample': 132300},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals',
        #          'begin_sample': 4410, 'end_sample': 445410},
        #         ...
        #     ],
        #     'accompaniment': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'accompaniment',
        #          'begin_sample': 0, 'end_sample': 132300},
        #         ...
        #     ]
        # }

        # A real list (rather than a dict view) so it matches the documented
        # example and can be indexed.
        self.source_types = list(self.meta_dict)
        # E.g., ['vocals', 'accompaniment']

        self.pointers_dict = {source_type: 0 for source_type in self.source_types}
        # E.g., {'vocals': 0, 'accompaniment': 0}

        self.indexes_dict = {
            source_type: np.arange(len(self.meta_dict[source_type]))
            for source_type in self.source_types
        }
        # E.g. {
        #     'vocals': [0, 1, ..., 225751],
        #     'accompaniment': [0, 1, ..., 225751]
        # }

        self.random_state = np.random.RandomState(random_seed)

        # Shuffle indexes.
        for source_type in self.source_types:
            self.random_state.shuffle(self.indexes_dict[source_type])
            print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))

    def __iter__(self) -> Iterator[List[Dict]]:
        r"""Yield batches of meta info forever.

        Yields:
            batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
                {'vocals': [
                    {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals',
                     'begin_sample': 13406400, 'end_sample': 13538700},
                    {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals',
                     'begin_sample': 4440870, 'end_sample': 4573170}],
                 'accompaniment': [
                    {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment',
                     'begin_sample': 14579460, 'end_sample': 14711760},
                    {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment',
                     'begin_sample': 3995460, 'end_sample': 4127760}]
                },
                ...
            ]

        Raises:
            ValueError: if a source type has fewer segments than the number
                of segments requested for it in `mixaudio_dict`.
        """
        batch_size = self.batch_size

        while True:
            batch_meta_dict = {}

            for source_type in self.source_types:
                # E.g., 'vocals', then 'accompaniment'
                mix_audios_num = self.mixaudio_dict[source_type]
                indexes = self.indexes_dict[source_type]
                metas = self.meta_dict[source_type]

                # Last pointer position from which `mix_audios_num`
                # consecutive indexes can still be read.
                largest_index = len(indexes) - mix_audios_num
                # E.g., 225750 = 225752 - 2

                if largest_index < 0:
                    # Without this check the inner loop would run off the
                    # end of `indexes` with an opaque IndexError.
                    raise ValueError(
                        "Source type {!r} has {} segments but mix-audio "
                        "requires {}.".format(
                            source_type, len(indexes), mix_audios_num
                        )
                    )

                groups = []

                # Loop until get a mini-batch.
                while len(groups) < batch_size:
                    if self.pointers_dict[source_type] > largest_index:
                        # Reset pointer, and shuffle indexes.
                        self.pointers_dict[source_type] = 0
                        self.random_state.shuffle(indexes)

                    source_metas = []

                    for _ in range(mix_audios_num):
                        pointer = self.pointers_dict[source_type]
                        # E.g., 1
                        index = indexes[pointer]
                        # E.g., 12231
                        self.pointers_dict[source_type] += 1
                        source_metas.append(metas[index])

                    groups.append(source_metas)

                batch_meta_dict[source_type] = groups
            # When mix-audio is 2, batch_meta_dict looks like: {
            #     'vocals': [
            #         [{'hdf5_path': 'songA.h5', ...}, {'hdf5_path': 'songB.h5', ...}],
            #         [{'hdf5_path': 'songC.h5', ...}, {'hdf5_path': 'songD.h5', ...}]
            #     ],
            #     'accompaniment': [
            #         [{'hdf5_path': 'songE.h5', ...}, {'hdf5_path': 'songF.h5', ...}],
            #         [{'hdf5_path': 'songG.h5', ...}, {'hdf5_path': 'songH.h5', ...}]
            #     ]
            # }

            # Transpose from {source_type: [item0, item1, ...]} to
            # [{source_type: item0}, {source_type: item1}, ...].
            batch_meta_list = [
                {
                    source_type: batch_meta_dict[source_type][i]
                    for source_type in self.source_types
                }
                for i in range(batch_size)
            ]

            yield batch_meta_list

    def __len__(self) -> int:
        r"""Return the number of steps that make up one `epoch`."""
        return self.steps_per_epoch

    def state_dict(self) -> Dict:
        r"""Return a snapshot of the pointers and shuffled indexes.

        The returned containers are copies, so continuing to iterate the
        sampler does not modify a previously taken snapshot. The random
        state is not part of the snapshot.
        """
        return {
            'pointers_dict': dict(self.pointers_dict),
            'indexes_dict': {
                source_type: indexes.copy()
                for source_type, indexes in self.indexes_dict.items()
            },
        }

    def load_state_dict(self, state) -> None:
        r"""Restore pointers and shuffled indexes from `state_dict()` output.

        The values are copied so that later in-place shuffles do not modify
        the caller's `state` object.
        """
        self.pointers_dict = dict(state['pointers_dict'])
        self.indexes_dict = {
            source_type: np.array(indexes)
            for source_type, indexes in state['indexes_dict'].items()
        }
class DistributedSamplerWrapper:
    r"""Split every batch of a wrapped sampler across distributed processes."""

    def __init__(self, sampler):
        r"""Distributed wrapper of sampler.

        Args:
            sampler: iterable with a length whose items support slicing
        """
        self.sampler = sampler

    def __iter__(self):
        r"""Yield, for each batch of the wrapped sampler, this rank's share.

        The share is every `world_size`-th element of the batch, starting at
        the position equal to this process's rank.
        """
        world_size = dist.get_world_size()
        this_rank = dist.get_rank()
        share = slice(this_rank, None, world_size)
        yield from (batch[share] for batch in self.sampler)

    def __len__(self) -> int:
        r"""Return the length of the wrapped sampler."""
        wrapped = self.sampler
        return len(wrapped)