| |
| |
| |
| |
| |
|
|
| import json |
| from fractions import Fraction |
| from concurrent import futures |
|
|
| import musdb |
| from torch import distributed |
|
|
| from .audio import AudioFile |
|
|
|
|
def get_musdb_tracks(root, *args, **kwargs):
    """Return a ``{track name: track path}`` mapping for a MusDB dataset.

    Extra positional/keyword arguments (e.g. ``subsets``, ``split``) are
    forwarded untouched to ``musdb.DB``.
    """
    database = musdb.DB(root, *args, **kwargs)
    tracks = {}
    for track in database:
        tracks[track.name] = track.path
    return tracks
|
|
|
|
class StemsSet:
    """Dataset of stem tracks, indexable as fixed-length normalized examples.

    Each track contributes several examples: windows of length `duration`
    taken every `stride` (both in seconds), or a single example covering the
    whole track when `duration` is None. Examples are normalized with the
    per-track mean/std stored in the metadata.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks (dict): mapping from track name to audio file path.
            metadata (dict): per-track dict with at least ``duration``,
                ``mean`` and ``std`` keys (as built by `_build_metadata`).
            duration: window length in seconds, or None for whole tracks.
            stride: hop between successive windows, in seconds.
            samplerate, channels, streams: forwarded to `AudioFile.read`.

        Raises:
            ValueError: if any track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            self.metadata.append(meta)
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
        # Sort by name so example indices are stable across runs/platforms.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        # Number of windows a single track yields.
        if self.duration is None:
            return 1
        return int((meta["duration"] - self.duration) // self.stride + 1)

    def _locate(self, index):
        """Map a global example index to ``(track metadata, local window index)``.

        Raises:
            IndexError: if `index` is out of range (the original code silently
                returned None here, which crashed confusingly downstream).
        """
        for meta in self.metadata:
            examples = self._examples_count(meta)
            if index >= examples:
                index -= examples
                continue
            return meta, index
        raise IndexError(f"index {index} out of range for StemsSet")

    def track_metadata(self, index):
        """Return the metadata dict of the track containing example `index`."""
        meta, _ = self._locate(index)
        return meta

    def __getitem__(self, index):
        meta, offset = self._locate(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        # Normalize with precomputed per-track statistics.
        return (streams - meta["mean"]) / meta["std"]
|
|
|
|
def _get_track_metadata(path):
    """Compute metadata for a single track: total duration plus the mean and
    standard deviation of a mono 44.1 kHz mix of stream 0."""
    audio = AudioFile(path)
    mono_mix = audio.read(streams=0, channels=1, samplerate=44100)
    return {
        "duration": audio.duration,
        "std": mono_mix.std().item(),
        "mean": mono_mix.mean().item(),
    }
|
|
|
|
def _build_metadata(tracks, workers=10):
    """Compute metadata for every track in parallel.

    Args:
        tracks (dict): mapping from track name to audio file path.
        workers (int): number of worker processes.

    Returns:
        dict: ``{name: metadata}`` as produced by `_get_track_metadata`.
    """
    with futures.ProcessPoolExecutor(workers) as pool:
        jobs = {name: pool.submit(_get_track_metadata, path)
                for name, path in tracks.items()}
    # The pool has shut down here, so every future is already complete.
    return {name: job.result() for name, job in jobs.items()}
|
|
|
|
def _build_musdb_metadata(path, musdb, workers):
    """Build the metadata cache for a MusDB root and save it as JSON.

    Args:
        path (Path): destination JSON file; parent dirs are created.
        musdb: MusDB root directory (NOTE: this parameter shadows the
            imported ``musdb`` module inside this function — kept for
            backward compatibility with positional callers).
        workers (int): number of worker processes for metadata extraction.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Fix: the original `json.dump(metadata, open(path, "w"))` never closed
    # the file handle; `with` guarantees the data is flushed and closed.
    with open(path, "w") as fp:
        json.dump(metadata, fp)
|
|
|
|
def get_compressed_datasets(args, samples):
    """Build the train/valid `StemsSet` pair from a compressed MusDB dataset.

    Only rank 0 builds the metadata cache; in distributed runs all other
    ranks wait on a barrier until the file exists, then everyone loads it.

    Args:
        args: experiment namespace; reads ``metadata``, ``musdb``, ``rank``,
            ``workers``, ``world_size``, ``samplerate``, ``data_stride`` and
            ``audio_channels``.
        samples (int): training example length, in samples.

    Returns:
        tuple: ``(train_set, valid_set)``. The valid set uses whole tracks
        (no windowing), hence no ``duration``/``stride``.
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        distributed.barrier()
    # Fix: the original `json.load(open(metadata_file))` leaked the handle.
    with open(metadata_file) as fp:
        metadata = json.load(fp)
    # Exact rational durations avoid float rounding when converting
    # sample counts to seconds.
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    # streams=slice(1, None) drops the mixture stream, keeping only stems.
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set
|
|