| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
| from collections import defaultdict, namedtuple |
| from pathlib import Path |
|
|
| import musdb |
| import numpy as np |
| import torch as th |
| import tqdm |
| from torch.utils.data import DataLoader |
|
|
| from .audio import AudioFile |
|
|
# Location of one dataset chunk: the file it lives in (`file_index`), the
# frame offset of the chunk inside that file (`offset`), and the chunk's
# index relative to that file (`local_index`).
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
|
|
|
|
class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Each entry on disk is a group of files ``name.{stream}.raw`` holding
    interleaved float32 PCM with ``channels`` channels per frame. When
    ``samples`` is given, every file is cut into fixed-length chunks of
    ``samples`` frames spaced ``stride`` frames apart; otherwise each file
    is a single example.
    """
    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        """
        Args:
            path: root folder that is walked recursively for ``*.raw`` files.
            samples (int or None): chunk length in frames; None = whole file.
            stride (int or None): hop between chunk starts; defaults to
                ``samples`` (non-overlapping chunks).
            channels (int): interleaved channels per frame in the raw files.
            streams (list[int] or None): subset of stream indexes to load;
                None loads all streams found for the first entry.
        Raises:
            ValueError: if no usable entry remains (empty dataset).
        """
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            # Whole-file mode never uses the stride, so 0 is a safe filler.
            stride = samples if samples is not None else 0
        self.stride = stride
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            # Sort in place so the walk order is deterministic across
            # filesystems.
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    path = Path(root) / file
                    name, stream = path.stem.rsplit('.', 1)
                    entries[(path.parent.relative_to(self.path), name)].append(int(stream))

        # Keep the entries in the exact order used below to accumulate
        # sizes/lengths, so chunk_info's file_index always lines up.
        # (The original used dict insertion order here but sorted order in
        # the loop, which could disagree.)
        self._entries = sorted(entries.keys())

        sizes = []
        self._lengths = []
        ref_streams = sorted(entries[self._entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        if streams is None:
            self.streams = ref_streams
        else:
            self.streams = streams
        # Iterate over a copy: entries that are too short are removed from
        # self._entries while looping.
        for entry in list(self._entries):
            streams = entries[entry]
            assert sorted(streams) == ref_streams
            file = self._path(*entry)
            # 4 bytes per float32 sample, `channels` samples per frame.
            length = file.stat().st_size // (4 * channels)
            if samples is None:
                sizes.append(1)
            else:
                if length < samples:
                    # Too short to yield even one chunk; drop it entirely.
                    self._entries.remove(entry)
                    continue
                sizes.append((length - samples) // stride + 1)
            self._lengths.append(length)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes

    def __len__(self):
        return self._cumulative_sizes[-1]

    @property
    def total_length(self):
        """Total number of frames across all kept files."""
        return sum(self._lengths)

    def chunk_info(self, index):
        """Map a flat dataset index to a `ChunkInfo` (file, offset, local index)."""
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)

    def _path(self, folder, name, stream=0):
        """Path of the raw file for one (folder, name) entry and stream."""
        return self.path / folder / (name + f'.{stream}.raw')

    def __getitem__(self, index):
        """Return a float32 tensor of shape (streams, channels, length)."""
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]

        length = self.samples or self._lengths[chunk.file_index]
        streams = []
        to_read = length * self.channels * 4
        for stream in self.streams:
            offset = chunk.offset * 4 * self.channels
            # Context manager so the handle is always closed (the original
            # leaked one open file per stream per item).
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(offset)
                content = file.read(to_read)
            assert len(content) == to_read
            content = np.frombuffer(content, dtype=np.float32)
            # frombuffer yields a read-only view; torch needs writable memory.
            content = content.copy()
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        return th.stack(streams, dim=0)

    def name(self, index):
        """Relative path (minus the `.{stream}.raw` suffix) backing `index`."""
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name
|
|
|
|
class MusDBSet:
    """Expose a musdb track collection as an indexable (name, audio) dataset."""

    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams
        self.mus = mus

    def __len__(self):
        return len(self.mus.tracks)

    def __getitem__(self, index):
        track = self.mus.tracks[index]
        audio = AudioFile(track.path).read(channels=self.channels,
                                           seek_time=0,
                                           streams=self.streams,
                                           samplerate=self.samplerate)
        return track.name, audio
|
|
|
|
def build_raw(mus, destination, normalize, workers, samplerate, channels):
    """
    Decode every track of `mus` and dump each stream to
    ``destination/<name>.<index>.raw`` as interleaved float32 PCM.

    Args:
        mus: a musdb.DB instance whose tracks are exported.
        destination (Path): output folder, created if missing.
        normalize (bool): if True, shift/scale all streams by the mean/std
            of stream 0 averaged over channels.
        workers (int): DataLoader worker processes used for decoding.
        samplerate (int): sample rate requested from the decoder.
        channels (int): number of output channels.
    """
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=lambda x: x[0])
    for name, streams in tqdm.tqdm(loader):
        if normalize:
            # Channel-averaged reference taken from the first stream.
            ref = streams[0].mean(dim=0)
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            # write_bytes opens and closes the file itself; the original
            # `open(...).write(...)` leaked the handle.
            (destination / f'{name}.{index}.raw').write_bytes(
                stream.t().numpy().tobytes())
|
|
|
|
def main():
    """Command-line entry point: export musdb train/valid splits as raw files."""
    parser = argparse.ArgumentParser('rawset')
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('--samplerate', type=int, default=44100)
    parser.add_argument('--channels', type=int, default=2)
    parser.add_argument('musdb', type=Path)
    parser.add_argument('destination', type=Path)

    args = parser.parse_args()

    # Both splits come from the "train" subset; musdb separates them via
    # the `split` argument.
    for split in ("train", "valid"):
        build_raw(musdb.DB(root=args.musdb, subsets=["train"], split=split),
                  args.destination / split,
                  normalize=True,
                  samplerate=args.samplerate,
                  channels=args.channels,
                  workers=args.workers)
|
|
|
|
# Script entry point guard.
if __name__ == "__main__":
    main()
|
|