Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import json | |
| from fractions import Fraction | |
| from concurrent import futures | |
| import musdb | |
| from torch import distributed | |
| from .audio import AudioFile | |
def get_musdb_tracks(root, *args, **kwargs):
    """Return a dict mapping each musdb track name to its path.

    `root` and any extra positional/keyword arguments are forwarded unchanged
    to `musdb.DB`; the resulting database is iterated once to collect the
    `name` and `path` attributes of every track it yields.
    """
    database = musdb.DB(root, *args, **kwargs)
    paths_by_name = {}
    for entry in database:
        name, path = entry.name, entry.path
        paths_by_name[name] = path
    return paths_by_name
class StemsSet:
    """Indexable dataset of normalized audio excerpts taken from a set of tracks.

    When `duration` is None every track contributes exactly one example.
    Otherwise a track contributes one example per `stride` step for as long as
    a full window of length `duration` still fits inside the track. Examples
    are numbered track after track, with tracks sorted by name.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks: mapping from track name to track path.
            metadata: mapping from track name to a dict holding at least the
                keys "duration", "mean" and "std" for that track.
            duration: length of each example, in the same unit as the metadata
                "duration" values, or None to use one example per track.
            stride: offset between the starts of two consecutive examples of
                the same track, in the same unit as `duration`.
            samplerate: forwarded to `AudioFile.read`.
            channels: forwarded to `AudioFile.read`.
            streams: forwarded to `AudioFile.read`.

        Raises:
            KeyError: if `metadata` has no entry for one of the tracks.
            ValueError: if a track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            # Copy so that the caller's metadata dict is never mutated.
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
            self.metadata.append(meta)
        # Sorting by name makes the example numbering independent of the
        # iteration order of `tracks`.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        """Return the total number of examples over all tracks."""
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        """Return how many examples the track described by `meta` provides."""
        if self.duration is None:
            return 1
        # Number of window starts, spaced by `stride`, for which a whole
        # `duration` window still ends inside the track.
        return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Map a global example index to `(meta, index within that track)`.

        Returns None when `index` lies past the last example.
        """
        for meta in self.metadata:
            examples = self._examples_count(meta)
            if index < examples:
                return meta, index
            index -= examples
        return None

    def track_metadata(self, index):
        """Return the metadata dict of the track that example `index` belongs to.

        Returns None when `index` lies past the last example.
        """
        located = self._locate(index)
        return None if located is None else located[0]

    def __getitem__(self, index):
        """Read example `index` and normalize it with its track statistics.

        Returns:
            Whatever `AudioFile.read` returns for the selected excerpt, with
            the track "mean" subtracted and divided by the track "std".

        Raises:
            IndexError: if `index` is negative or not smaller than `len(self)`.
                Raising here (rather than returning None) is what lets plain
                iteration over the dataset terminate.
        """
        located = None if index < 0 else self._locate(index)
        if located is None:
            raise IndexError(f"Example index {index} is out of range")
        meta, offset = located
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        return (streams - meta["mean"]) / meta["std"]
def _get_track_metadata(path):
    """Compute the duration and normalization statistics of one audio file.

    Returns a dict with the keys "duration" (taken from the `AudioFile`),
    "std" and "mean" (both computed over stream 0 of the file).
    """
    # Mono at 44kHz is the reference setting. With any other setting the data
    # will not be perfectly normalized, but it should be good enough.
    track = AudioFile(path)
    reference = track.read(streams=0, channels=1, samplerate=44100)
    stats = {"duration": track.duration}
    stats["std"] = reference.std().item()
    stats["mean"] = reference.mean().item()
    return stats
| def _build_metadata(tracks, workers=10): | |
| pendings = [] | |
| with futures.ProcessPoolExecutor(workers) as pool: | |
| for name, path in tracks.items(): | |
| pendings.append((name, pool.submit(_get_track_metadata, path))) | |
| return {name: p.result() for name, p in pendings} | |
def _build_musdb_metadata(path, musdb, workers):
    """Compute the metadata of all musdb tracks and store it as JSON at `path`.

    Args:
        path: destination file; must provide `.parent.mkdir`, e.g. a
            `pathlib.Path`. Missing parent directories are created.
        musdb: root of the musdb dataset, forwarded to `get_musdb_tracks`.
            Note that inside this function the name shadows the `musdb` module.
        workers: number of worker processes used to compute the metadata.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Close the file explicitly so the JSON is fully flushed to disk when this
    # function returns, instead of relying on garbage collection.
    with open(path, "w") as handle:
        json.dump(metadata, handle)
def get_compressed_datasets(args, samples):
    """Build the training and validation `StemsSet` for the musdb dataset.

    The per-track metadata is cached in `args.metadata / "musdb.json"`. When
    that file is missing, the process with `args.rank == 0` builds it; when
    `args.world_size > 1` all processes then meet at a barrier before reading it.

    Args:
        args: object providing the attributes `metadata`, `rank`, `musdb`,
            `workers`, `world_size`, `samplerate`, `data_stride` and
            `audio_channels`.
        samples: length of each training example, in samples at
            `args.samplerate`.

    Returns:
        A `(train_set, valid_set)` tuple of `StemsSet`.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        # Wait until rank 0 has had the chance to write the metadata file.
        distributed.barrier()
    # Use a context manager so the file handle is closed deterministically.
    with open(metadata_file) as handle:
        metadata = json.load(handle)
    # Exact fractions avoid rounding errors when converting sample counts to
    # time offsets.
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    # Training examples are fixed-length excerpts and skip stream 0
    # (presumably the mixture -- confirm against `AudioFile.read`).
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    # No duration is given for validation: one example per track, all streams.
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set