# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from collections import defaultdict, namedtuple
from pathlib import Path

import musdb
import numpy as np
import torch as th
import tqdm
from torch.utils.data import DataLoader

from .audio import AudioFile

# Location of one dataset item: which track (`file_index` into the dataset's
# entry list), the sample offset inside that track, and the chunk number
# within the track.
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])


class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Each track is stored as one file per stream named
    `<name>.<stream>.raw`. Each file holds interleaved float32 samples,
    i.e. `length * channels` values laid out as (time, channel).
    When `samples` is given, tracks are cut into windows of `samples`
    frames taken every `stride` frames. Otherwise each track is one item.
    """

    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        """
        Args:
            path: root folder, walked recursively (symlinks followed).
            samples: window length in frames, or None for whole tracks.
            stride: hop between windows in frames; defaults to `samples`
                (non-overlapping windows), or 0 when `samples` is None.
            channels: number of interleaved channels per file.
            streams: stream indices returned by `__getitem__`; defaults to
                all streams present on disk.

        Raises:
            ValueError: if no usable `.raw` track is found.
        """
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            stride = samples if samples is not None else 0
        self.stride = stride

        # Group stream indices by (folder relative to root, track name).
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    path = Path(root) / file
                    name, stream = path.stem.rsplit('.', 1)
                    entries[(path.parent.relative_to(self.path), name)].append(int(stream))

        if not entries:
            raise ValueError(f"Empty dataset {self.path}")

        # A single canonical order for every per-track list below. The
        # os.walk/file-name order can differ from the sorted key order, e.g.
        # "Foo - Bar.0.raw" < "Foo.0.raw" while "Foo" < "Foo - Bar".
        ordered_entries = sorted(entries.keys())

        # Every track must provide the same contiguous set of streams 0..N-1.
        ref_streams = sorted(entries[ordered_entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        if streams is None:
            self.streams = ref_streams
        else:
            self.streams = streams

        # `_entries`, `_lengths` and `sizes` are filled in lockstep so a
        # `file_index` computed by `chunk_info` refers to the same track in
        # all of them.
        self._entries = []
        self._lengths = []
        sizes = []
        for entry in ordered_entries:
            entry_streams = entries[entry]
            assert sorted(entry_streams) == ref_streams
            file = self._path(*entry)
            # Frames per file: 4 bytes per float32 value, `channels` values per frame.
            length = file.stat().st_size // (4 * channels)
            if samples is None:
                sizes.append(1)
            else:
                if length < samples:
                    # Too short for even one window: drop the track.
                    continue
                sizes.append((length - samples) // stride + 1)
            self._entries.append(entry)
            self._lengths.append(length)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes

    def __len__(self):
        """Total number of items (windows or whole tracks) across all tracks."""
        return self._cumulative_sizes[-1]

    @property
    def total_length(self):
        """Sum of the lengths, in frames, of all kept tracks."""
        return sum(self._lengths)

    def chunk_info(self, index):
        """Map a global item `index` to its track and frame offset."""
        # First track whose cumulative item count exceeds `index`.
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)

    def _path(self, folder, name, stream=0):
        """Path of the `.raw` file holding `stream` of track `folder/name`."""
        return self.path / folder / (name + f'.{stream}.raw')

    def __getitem__(self, index):
        """
        Read item `index` as a float32 tensor of shape
        (len(self.streams), channels, length), where `length` is `samples`
        when set and the full track length otherwise.
        """
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]

        length = self.samples or self._lengths[chunk.file_index]
        to_read = length * self.channels * 4
        offset = chunk.offset * 4 * self.channels
        streams = []
        for stream in self.streams:
            # Close each file promptly; a DataLoader may call this many times.
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(offset)
                content = file.read(to_read)
            assert len(content) == to_read
            content = np.frombuffer(content, dtype=np.float32)
            content = content.copy()  # make writable
            # On-disk layout is (time, channel); expose (channel, time).
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        return th.stack(streams, dim=0)

    def name(self, index):
        """Relative `folder / track_name` of the track containing item `index`."""
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name


class MusDBSet:
    """
    Wraps a musdb DB so each item is `(track_name, audio)`, where `audio` is
    whatever `AudioFile.read` returns for the requested streams, channel
    count and sample rate.
    """

    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.mus = mus
        self.streams = streams
        self.samplerate = samplerate
        self.channels = channels

    def __len__(self):
        return len(self.mus.tracks)

    def __getitem__(self, index):
        track = self.mus.tracks[index]
        return (track.name, AudioFile(track.path).read(channels=self.channels,
                                                       seek_time=0,
                                                       streams=self.streams,
                                                       samplerate=self.samplerate))


def _first_item(batch):
    """Collate function for batch_size=1: return the single sample unchanged.

    Defined at module level (not as a lambda) so it can be pickled when
    DataLoader workers are started with the `spawn` method.
    """
    return batch[0]


def build_raw(mus, destination, normalize, workers, samplerate, channels):
    """
    Decode every track of `mus` and write one `<name>.<stream>.raw` file per
    stream into `destination`, in the layout `Rawset` reads.

    When `normalize` is true, all streams of a track are shifted and scaled
    by the mean and standard deviation of the mono version of stream 0.
    """
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=_first_item)
    for name, streams in tqdm.tqdm(loader):
        if normalize:
            ref = streams[0].mean(dim=0)  # use mono mixture as reference
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            # Transpose (channel, time) -> (time, channel) so samples are
            # interleaved on disk.
            # NOTE(review): Rawset reads float32 — assumes AudioFile.read
            # yields float32 tensors; confirm in .audio.
            with open(destination / (name + f'.{index}.raw'), "wb") as file:
                file.write(stream.t().numpy().tobytes())


def main():
    parser = argparse.ArgumentParser('rawset')
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('--samplerate', type=int, default=44100)
    parser.add_argument('--channels', type=int, default=2)
    parser.add_argument('musdb', type=Path)
    parser.add_argument('destination', type=Path)

    args = parser.parse_args()

    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
              args.destination / "train",
              normalize=True,
              channels=args.channels,
              samplerate=args.samplerate,
              workers=args.workers)
    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
              args.destination / "valid",
              normalize=True,
              samplerate=args.samplerate,
              channels=args.channels,
              workers=args.workers)


if __name__ == "__main__":
    main()