import os
import json
import glob
import torch
import random
from tqdm import tqdm

# from deepafx_st.plugins.channel import Channel
from deepafx_st.processors.processor import Processor
from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils


class DSPProxyDataset(torch.utils.data.Dataset):
    """Class for generating input-output audio from Python DSP effects.

    Args:
        input_dir (str): Path to the directory containing input audio files.
        processor (Processor): Processor object to create a proxy of.
        processor_type (str): Processor name. One of ["channel", "peq", "comp"].
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 1.0
            Note: This is the buffer size PER DataLoader worker, so total RAM = buffer_size_gb * num_workers.
        buffer_reload_rate (int, optional): Number of items to generate before loading the next chunk of the dataset. Default: 1000
        half (bool, optional): Load audio in half precision (float16). Default: False
        num_examples_per_epoch (int, optional): Define an epoch as a certain number of audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
        soft_clip (bool, optional): Soft clip (tanh) outputs to the range (-2, 2). Default: True
    """

    def __init__(
        self,
        input_dir: str,
        processor: Processor,
        processor_type: str,
        subset="train",
        length=65536,
        buffer_size_gb=1.0,
        buffer_reload_rate=1000,
        half=False,
        num_examples_per_epoch=10000,
        ext="wav",
        soft_clip=True,
    ):
        super().__init__()
        self.input_dir = input_dir
        self.processor = processor
        self.processor_type = processor_type
        self.subset = subset
        self.length = length
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext
        self.soft_clip = soft_clip

        # find all audio files in the input directory
        search_path = os.path.join(input_dir, f"*.{ext}")
        self.input_filepaths = glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        if len(self.input_filepaths) < 1:
            raise RuntimeError(f"No files found in {input_dir}.")

        # get training split
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths, self.subset, 0.9
        )

        # get details about audio files
        cnt = 0
        self.input_files = {}
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )
            if audio_file.num_frames < self.length:
                continue
            self.input_files[file_id] = audio_file
            self.sample_rate = self.input_files[file_id].sample_rate
            cnt += 1
            if cnt > 1000:
                break

        # some setup for iterable loading of the dataset into RAM
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        return self.num_examples_per_epoch

    def load_audio_buffer(self):
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM

        # different subset in each reload
        random.shuffle(self.input_filepaths)

        # load files into RAM
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )

            if audio_file.num_frames < self.length:
                continue

            self.input_files_loaded[file_id] = audio_file

            nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
            nbytes_loaded += nbytes

            if nbytes_loaded > self.buffer_size_gb * 1e9:
                break

    def __getitem__(self, _):
        # increment counter
        self.items_since_load += 1
        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
        # use this random key to retrieve an input file
        input_file = self.input_files_loaded[rand_input_file_id]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`
        # start_idx, stop_idx = utils.get_random_patch(input_file, self.sample_rate, self.length)
        start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()

        # random scaling
        input_audio /= input_audio.abs().max()
        scale_dB = (torch.rand(1).squeeze().numpy() * 12) + 12
        input_audio *= 10 ** (-scale_dB / 20.0)

        # generate random parameters (uniform) over 0 to 1
        params = torch.rand(self.processor.num_control_params)

        # apply plugins with random parameters (processor expects a batch dim)
        if self.processor_type == "channel":
            params[-1] = 0.5  # set makeup gain to 0dB
            target_audio = self.processor(
                input_audio.view(1, 1, -1),
                params.view(1, -1),
            )
            target_audio = target_audio.view(1, -1)
        elif self.processor_type == "peq":
            target_audio = self.processor(
                input_audio.view(1, 1, -1).numpy(),
                params.view(1, -1).numpy(),
            )
            target_audio = torch.tensor(target_audio).view(1, -1)
        elif self.processor_type == "comp":
            params[-1] = 0.5  # set makeup gain to 0dB
            target_audio = self.processor(
                input_audio.view(1, 1, -1).numpy(),
                params.view(1, -1).numpy(),
            )
            target_audio = torch.tensor(target_audio).view(1, -1)

        # clip
        if self.soft_clip:
            # target_audio = target_audio.clamp(-2.0, 2.0)
            target_audio = torch.tanh(target_audio / 2.0) * 2.0

        return input_audio, target_audio, params
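

# ----------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original training pipeline).
# It shows one way the dataset might be wrapped in a DataLoader. The
# `IdentityProxy` stand-in and the "/path/to/audio" directory are placeholders;
# in practice you would pass one of the repo's Processor implementations
# (e.g. a parametric EQ) and a real directory of .wav files.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from torch.utils.data import DataLoader

    class IdentityProxy:
        """Minimal stand-in exposing only the interface the dataset uses."""

        num_control_params = 4  # placeholder parameter count

        def __call__(self, x: np.ndarray, params: np.ndarray) -> np.ndarray:
            return x  # pass audio through unchanged

    dataset = DSPProxyDataset(
        input_dir="/path/to/audio",  # placeholder directory of .wav files
        processor=IdentityProxy(),
        processor_type="peq",
        subset="train",
        length=65536,
    )

    loader = DataLoader(dataset, batch_size=4, num_workers=0)
    input_audio, target_audio, params = next(iter(loader))
    print(input_audio.shape, target_audio.shape, params.shape)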