import argparse
import os
import pickle

import h5py

from bytesep.utils import read_yaml


def create_indexes(args) -> None:
    r"""Create training indexes and write them out to disk. The indexes may
    contain information from multiple datasets. During training, the indexes
    are shuffled and iterated over to select segments to be mixed. E.g., the
    training indexes_dict looks like: {
        'vocals': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710},
            ...
        ],
        'accompaniment': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710},
            ...
        ]
    }
    """
    # Arguments & parameters
    workspace = args.workspace
    config_yaml = args.config_yaml

    # Only create indexes for training, because evaluation is done on entire pieces.
    split = "train"

    # Read config file.
    configs = read_yaml(config_yaml)
    sample_rate = configs["sample_rate"]
    segment_samples = int(configs["segment_seconds"] * sample_rate)
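    # For example, with sample_rate = 44100 and segment_seconds = 3.0 (values
    # assumed here purely for illustration), segment_samples = int(3.0 * 44100)
    # = 132300, matching the 'end_sample' values in the docstring example.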

    # Path to write out indexes.
    indexes_path = os.path.join(workspace, configs[split]["indexes"])
    os.makedirs(os.path.dirname(indexes_path), exist_ok=True)

    source_types = configs[split]["source_types"].keys()
    # E.g., ['vocals', 'accompaniment']

    indexes_dict = {source_type: [] for source_type in source_types}
    # indexes_dict maps each source type to a list of segment metadata dicts,
    # in the format shown in the docstring above.
    # Get training indexes for each source type.
    for source_type in source_types:
        # E.g., 'vocals', 'bass', ...
        print("--- {} ---".format(source_type))

        # Each source can come from multiple datasets. dataset_types is a dict
        # mapping dataset names, e.g., 'musdb18', to their settings.
        dataset_types = configs[split]["source_types"][source_type]
        for dataset_type in dataset_types:
            hdf5s_dir = os.path.join(
                workspace, dataset_types[dataset_type]["hdf5s_directory"]
            )
            hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)

            key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
            # E.g., 'vocals'

            hdf5_names = sorted(os.listdir(hdf5s_dir))
            print("Hdf5 files num: {}".format(len(hdf5_names)))
            # Traverse all packed hdf5 files of a dataset.
            for n, hdf5_name in enumerate(hdf5_names):
                print(n, hdf5_name)
                hdf5_path = os.path.join(hdf5s_dir, hdf5_name)

                with h5py.File(hdf5_path, "r") as hf:
                    # Slide a window of segment_samples over the audio with a
                    # stride of hop_samples, recording one index per position.
                    bgn_sample = 0
                    while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': bgn_sample,
                            'end_sample': bgn_sample + segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

                        bgn_sample += hop_samples
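                    # For example, with hop_seconds = 0.1 at 44100 Hz (values
                    # assumed for illustration), hop_samples = 4410, so segments
                    # begin at samples 0, 4410, 8820, ... as in the docstring.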
                    # If the audio is shorter than one segment, index the whole
                    # piece as a single segment. Note that end_sample may exceed
                    # the actual audio length here; downstream loading must
                    # handle (e.g., pad) the short read.
                    if bgn_sample == 0:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': 0,
                            'end_sample': segment_samples,
                        }
                        indexes_dict[source_type].append(meta)
        print(
            "Total indexes for {}: {}".format(
                source_type, len(indexes_dict[source_type])
            )
        )
    with open(indexes_path, "wb") as f:
        pickle.dump(indexes_dict, f)

    print("Write out index dict to {}".format(indexes_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml", type=str, required=True, help="User-defined config file."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Create training indexes.
    create_indexes(args)
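
# Example invocation (the script name and paths are illustrative):
#   python create_indexes.py --workspace=./workspace --config_yaml=./configs/train.yaml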