# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import json
from pathlib import Path
import math
import os
import sys

import torchaudio
from torch.nn import functional as F


def find_audio_files(path, exts=(".wav",), progress=True):
    """Recursively collect audio files under `path` along with their lengths.

    Args:
        path: directory to walk; symbolic links are followed.
        exts: container of lowercase file suffixes (leading dot included)
            to keep. The default is an immutable tuple so that it cannot be
            accidentally shared and mutated across calls.
        progress: when true, print the fraction of files processed so far
            to stderr, overwriting the same terminal line.

    Returns:
        A sorted list of `(resolved_file_path, length)` tuples.
    """
    audio_files = []
    for root, _, files in os.walk(path, followlinks=True):
        for file in files:
            file = Path(root) / file
            if file.suffix.lower() in exts:
                audio_files.append(str(file.resolve()))

    meta = []
    for idx, file in enumerate(audio_files):
        # NOTE(review): relies on the torchaudio.info variant that returns a
        # (signal_info, encoding_info) pair -- confirm against the pinned
        # torchaudio version.
        siginfo, _ = torchaudio.info(file)
        # Presumably `siginfo.length` counts samples over all channels, so
        # this yields the per-channel length -- TODO confirm.
        length = siginfo.length // siginfo.channels
        meta.append((file, length))
        if progress:
            print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta


class Audioset:
    """Index-addressable set of audio examples cut out of a list of files.

    Each file contributes a number of examples computed from its length,
    the requested example `length` and the `stride` between examples.
    """

    def __init__(self, files=None, length=None, stride=None,
                 pad=True, with_path=False, sample_rate=None):
        """
        files should be a list [(file, length)]

        Args:
            files: list of `(file, length)` pairs, e.g. as returned by
                `find_audio_files`. `None` is treated as an empty list.
            length: length of each example, in the same unit as the file
                lengths. When `None`, each file yields exactly one example
                that is loaded whole.
            stride: step between the starts of consecutive examples of the
                same file. Defaults to `length`.
            pad: when true, a trailing partial example is kept (and later
                padded by `__getitem__`); when false, it is dropped.
            with_path: when true, `__getitem__` returns `(audio, file)`
                instead of just `audio`.
            sample_rate: when not `None`, `__getitem__` raises if a loaded
                file reports a different sample rate.
        """
        # The declared default is `None`; iterating over it used to raise a
        # TypeError, so normalize it to an empty dataset instead.
        self.files = [] if files is None else files
        self.num_examples = []
        self.length = length
        self.stride = stride or length
        self.with_path = with_path
        self.sample_rate = sample_rate
        for file, file_length in self.files:
            if length is None:
                examples = 1
            elif file_length < length:
                # File shorter than one example: keep it only when padding.
                examples = 1 if pad else 0
            elif pad:
                # Round up so that the trailing partial window is included.
                examples = int(math.ceil((file_length - self.length) / self.stride) + 1)
            else:
                # Only windows that fit entirely inside the file.
                examples = (file_length - self.length) // self.stride + 1
            self.num_examples.append(examples)

    def __len__(self):
        """Return the total number of examples over all files."""
        return sum(self.num_examples)

    def __getitem__(self, index):
        """Load and return the example at position `index`.

        Negative indices count from the end, as for built-in sequences.

        Returns:
            The loaded audio, right-padded along its last dimension up to
            `self.length` when a length was requested; or the pair
            `(audio, file)` when `with_path` is set.

        Raises:
            IndexError: if `index` is out of range. This is also what
                terminates plain `for example in dataset` iteration.
            RuntimeError: if `sample_rate` was given and the file reports a
                different sample rate.
        """
        requested = index
        if index < 0:
            index += len(self)
            if index < 0:
                raise IndexError(f"Audioset index {requested} is out of range")

        for (file, _), examples in zip(self.files, self.num_examples):
            if index >= examples:
                # The example lives in a later file; make index file-relative.
                index -= examples
                continue
            # With no fixed length both values stay 0; the whole file is
            # loaded and the padding step below is skipped.
            num_frames = 0
            offset = 0
            if self.length is not None:
                offset = self.stride * index
                num_frames = self.length
            out, sr = torchaudio.load(str(file), offset=offset, num_frames=num_frames)
            if self.sample_rate is not None:
                if sr != self.sample_rate:
                    raise RuntimeError(f"Expected {file} to have sample rate of "
                                       f"{self.sample_rate}, but got {sr}")
            if num_frames:
                # Right-pad the last dimension when fewer frames were read.
                out = F.pad(out, (0, num_frames - out.shape[-1]))
            if self.with_path:
                return out, file
            else:
                return out

        # Every file was skipped: the index is past the last example.
        raise IndexError(f"Audioset index {requested} is out of range")


if __name__ == "__main__":
    # Scan every directory given on the command line and dump the combined
    # (file, length) metadata as JSON on stdout.
    meta = []
    for path in sys.argv[1:]:
        meta += find_audio_files(path)
    json.dump(meta, sys.stdout, indent=4)