Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json | |
from fractions import Fraction | |
from concurrent import futures | |
import musdb | |
from torch import distributed | |
from .audio import AudioFile | |
def get_musdb_tracks(root, *args, **kwargs):
    """Return a dict mapping each track name to its path.

    The tracks are obtained by iterating over `musdb.DB(root, *args, **kwargs)`;
    any extra positional or keyword argument is forwarded untouched to `musdb.DB`.
    """
    database = musdb.DB(root, *args, **kwargs)
    paths_by_name = {}
    for track in database:
        paths_by_name[track.name] = track.path
    return paths_by_name
class StemsSet:
    """Indexable dataset of audio excerpts taken from a collection of tracks.

    Each track contributes a single example when `duration` is None, and one
    example per window of length `duration` (windows start every `stride`)
    otherwise. Examples are numbered track after track, with the tracks
    sorted by name.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks: mapping from track name to track path.
            metadata: mapping from track name to a dict providing at least
                the keys "duration", "mean" and "std". Entries are copied,
                the mapping given by the caller is not modified.
            duration: length of one example, in the same unit as the
                "duration" metadata entries (seconds when the dataset is
                built by `get_compressed_datasets`). None means one example
                per track, read with `duration=None`.
            stride: offset between the starts of two consecutive examples of
                the same track, in the same unit as `duration`.
            samplerate: forwarded as is to `AudioFile.read`.
            channels: forwarded as is to `AudioFile.read`.
            streams: forwarded as is to `AudioFile.read`.

        Raises:
            KeyError: if a track has no entry in `metadata`.
            ValueError: if a track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            self.metadata.append(meta)
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
        # Sort by name so that the example numbering does not depend on the
        # iteration order of `tracks`.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        """Return the total number of examples over all the tracks."""
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        """Return the number of examples provided by the track described by `meta`."""
        if self.duration is None:
            return 1
        else:
            return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Return `(meta, offset)` for the example number `index`.

        `meta` is the metadata of the track holding the example and `offset`
        the rank of the example inside that track. Negative indexes count
        from the end, like for the builtin sequences.

        Raises:
            IndexError: if `index` is out of range.
        """
        # Work on a fresh value and never use `+=` / `-=`: `index` could be a
        # tensor or array owned by the caller, which would be mutated in place.
        position = index
        if position < 0:
            position = position + len(self)
        if position >= 0:
            for meta in self.metadata:
                examples = self._examples_count(meta)
                if position < examples:
                    return meta, position
                position = position - examples
        # Raising (rather than implicitly returning None) is what terminates
        # a plain `for example in dataset` loop.
        raise IndexError(f"example index {index} is out of range")

    def track_metadata(self, index):
        """Return the metadata dict of the track holding example `index`.

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, _ = self._locate(index)
        return meta

    def __getitem__(self, index):
        """Read example `index` and normalize it with the track statistics.

        The excerpt starts at `offset * stride` inside its track, `offset`
        being the rank of the example within the track. The value returned by
        `AudioFile.read` is shifted by the track "mean" and divided by the
        track "std".

        Raises:
            IndexError: if `index` is out of range.
        """
        meta, offset = self._locate(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        return (streams - meta["mean"]) / meta["std"]
def _get_track_metadata(path):
    """Return the duration and normalization statistics of the track at `path`.

    The result is a dict with the keys "duration" (taken from the `AudioFile`),
    "std" and "mean" (both computed over stream 0 and converted to plain
    Python numbers with `.item()`).
    """
    # Mono at 44kHz is used as the reference. With any other setting the data
    # won't be perfectly normalized, but it should be good enough.
    track = AudioFile(path)
    reference = track.read(streams=0, channels=1, samplerate=44100)
    stats = {
        "duration": track.duration,
        "std": reference.std().item(),
        "mean": reference.mean().item(),
    }
    return stats
def _build_metadata(tracks, workers=10): | |
pendings = [] | |
with futures.ProcessPoolExecutor(workers) as pool: | |
for name, path in tracks.items(): | |
pendings.append((name, pool.submit(_get_track_metadata, path))) | |
return {name: p.result() for name, p in pendings} | |
def _build_musdb_metadata(path, musdb, workers):
    """Compute the metadata of all the MUSDB tracks and save it as JSON.

    Args:
        path: destination file, as a `pathlib.Path`-like object. Its parent
            directory is created when missing.
        musdb: root of the MUSDB dataset, passed to `get_musdb_tracks`.
            Note that this parameter hides the `musdb` module inside this
            function.
        workers: number of worker processes, passed to `_build_metadata`.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Close the file explicitly so the JSON is entirely flushed to disk when
    # this function returns, instead of whenever the file object happens to
    # be garbage collected.
    with open(path, "w") as file:
        json.dump(metadata, file)
def get_compressed_datasets(args, samples):
    """Build the train and valid `StemsSet` for MUSDB.

    The track metadata is read from `args.metadata / "musdb.json"`. When that
    file is missing, the process with `args.rank == 0` creates it first.

    Args:
        args: object providing the attributes `metadata` (a path-like
            directory supporting `/`), `musdb`, `workers`, `rank`,
            `world_size`, `samplerate`, `data_stride` and `audio_channels`.
        samples: length of one training example, in samples at
            `args.samplerate`.

    Returns:
        A `(train_set, valid_set)` tuple. The train set is cut into examples
        of `samples` samples starting every `args.data_stride` samples and
        only reads the streams from index 1 onwards. The valid set is built
        with the default `duration` and `streams` of `StemsSet`.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        # The other processes wait here until rank 0 has reached this point,
        # i.e. after it has written the metadata file when it was missing.
        distributed.barrier()
    # Use a context manager so that the file handle is closed right away.
    with open(metadata_file) as file:
        metadata = json.load(file)
    # Exact fractions (in seconds) avoid any rounding when the positions are
    # converted between seconds and samples.
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set