yourusername's picture
:beers: cheers
66a6dc0
import os
import sys
import csv
import glob
import torch
import random
from tqdm import tqdm
from typing import List, Any
from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils
import deepafx_st.data.augmentations as augmentations
class AudioDataset(torch.utils.data.Dataset):
"""Audio dataset which returns an input and target file.
Args:
audio_dir (str): Path to the top level of the audio dataset.
input_dir (List[str], optional): List of paths to the directories containing input audio files. Default: ["clean"]
subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
length (int, optional): Number of samples to load for each example. Default: 65536
train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1
buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0
Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000
half (bool, optional): Sotre audio samples as float 16. Default: False
num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False
random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False
augmentations (dict, optional): List of augmentation types to apply to inputs. Default: []
freq_corrupt (bool, optional): Apply bad EQ filters. Default: False
drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False
ext (str, optional): Expected audio file extension. Default: "wav"
"""
def __init__(
self,
audio_dir,
input_dirs: List[str] = ["cleanraw"],
subset: str = "train",
length: int = 65536,
train_frac: float = 0.8,
val_per: float = 0.1,
buffer_size_gb: float = 1.0,
buffer_reload_rate: float = 1000,
half: bool = False,
num_examples_per_epoch: int = 10000,
random_scale_input: bool = False,
random_scale_target: bool = False,
augmentations: dict = {},
freq_corrupt: bool = False,
drc_corrupt: bool = False,
ext: str = "wav",
):
super().__init__()
self.audio_dir = audio_dir
self.dataset_name = os.path.basename(audio_dir)
self.input_dirs = input_dirs
self.subset = subset
self.length = length
self.train_frac = train_frac
self.val_per = val_per
self.buffer_size_gb = buffer_size_gb
self.buffer_reload_rate = buffer_reload_rate
self.half = half
self.num_examples_per_epoch = num_examples_per_epoch
self.random_scale_input = random_scale_input
self.random_scale_target = random_scale_target
self.augmentations = augmentations
self.freq_corrupt = freq_corrupt
self.drc_corrupt = drc_corrupt
self.ext = ext
self.input_filepaths = []
for input_dir in input_dirs:
search_path = os.path.join(audio_dir, input_dir, f"*.{ext}")
self.input_filepaths += glob.glob(search_path)
self.input_filepaths = sorted(self.input_filepaths)
# create dataset split based on subset
self.input_filepaths = utils.split_dataset(
self.input_filepaths,
subset,
train_frac,
)
# get details about input audio files
input_files = {}
input_dur_frames = 0
for input_filepath in tqdm(self.input_filepaths, ncols=80):
file_id = os.path.basename(input_filepath)
audio_file = AudioFile(
input_filepath,
preload=False,
half=half,
)
if audio_file.num_frames < (self.length * 2):
continue
input_files[file_id] = audio_file
input_dur_frames += input_files[file_id].num_frames
if len(list(input_files.items())) < 1:
raise RuntimeError(f"No files found in {search_path}.")
input_dur_hr = (input_dur_frames / input_files[file_id].sample_rate) / 3600
print(
f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours."
)
self.sample_rate = input_files[file_id].sample_rate
# save a csv file with details about the train and test split
splits_dir = os.path.join("configs", "splits")
if not os.path.isdir(splits_dir):
os.makedirs(splits_dir)
csv_filepath = os.path.join(splits_dir, f"{self.dataset_name}_{self.subset}_set.csv")
with open(csv_filepath, "w") as fp:
dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"])
dw.writeheader()
for input_filepath in self.input_filepaths:
dw.writerow(
{
"file_id": self.get_file_id(input_filepath),
"filepath": input_filepath,
"type": "input",
"subset": self.subset,
}
)
# some setup for iteratble loading of the dataset into RAM
self.items_since_load = self.buffer_reload_rate
def __len__(self):
return self.num_examples_per_epoch
def load_audio_buffer(self):
self.input_files_loaded = {} # clear audio buffer
self.items_since_load = 0 # reset iteration counter
nbytes_loaded = 0 # counter for data in RAM
# different subset in each
random.shuffle(self.input_filepaths)
# load files into RAM
for input_filepath in self.input_filepaths:
file_id = os.path.basename(input_filepath)
audio_file = AudioFile(
input_filepath,
preload=True,
half=self.half,
)
if audio_file.num_frames < (self.length * 2):
continue
self.input_files_loaded[file_id] = audio_file
nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
nbytes_loaded += nbytes
# check the size of loaded data
if nbytes_loaded > self.buffer_size_gb * 1e9:
break
def generate_pair(self):
# ------------------------ Input audio ----------------------
rand_input_file_id = None
input_file = None
start_idx = None
stop_idx = None
while True:
rand_input_file_id = self.get_random_file_id(self.input_files_loaded.keys())
# use this random key to retrieve an input file
input_file = self.input_files_loaded[rand_input_file_id]
# load the audio data if needed
if not input_file.loaded:
raise RuntimeError("Audio not loaded.")
# get a random patch of size `self.length` x 2
start_idx, stop_idx = self.get_random_patch(
input_file, int(self.length * 2)
)
if start_idx >= 0:
break
input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
input_audio = input_audio.view(1, -1)
if self.half:
input_audio = input_audio.float()
# peak normalize to -12 dBFS
input_audio /= input_audio.abs().max()
input_audio *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
if len(list(self.augmentations.items())) > 0:
if torch.rand(1).sum() < 0.5:
input_audio_aug = augmentations.apply(
[input_audio],
self.sample_rate,
self.augmentations,
)[0]
else:
input_audio_aug = input_audio.clone()
else:
input_audio_aug = input_audio.clone()
input_audio_corrupt = input_audio_aug.clone()
# apply frequency and dynamic range corrpution (expander)
if self.freq_corrupt and torch.rand(1).sum() < 0.75:
input_audio_corrupt = augmentations.frequency_corruption(
[input_audio_corrupt], self.sample_rate
)[0]
# peak normalize again before passing through dynamic range expander
input_audio_corrupt /= input_audio_corrupt.abs().max()
input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
if self.drc_corrupt and torch.rand(1).sum() < 0.10:
input_audio_corrupt = augmentations.dynamic_range_corruption(
[input_audio_corrupt], self.sample_rate
)[0]
# ------------------------ Target audio ----------------------
# use the same augmented audio clip, add different random EQ and compressor
target_audio_corrupt = input_audio_aug.clone()
# apply frequency and dynamic range corrpution (expander)
if self.freq_corrupt and torch.rand(1).sum() < 0.75:
target_audio_corrupt = augmentations.frequency_corruption(
[target_audio_corrupt], self.sample_rate
)[0]
# peak normalize again before passing through dynamic range compressor
input_audio_corrupt /= input_audio_corrupt.abs().max()
input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
if self.drc_corrupt and torch.rand(1).sum() < 0.75:
target_audio_corrupt = augmentations.dynamic_range_compression(
[target_audio_corrupt], self.sample_rate
)[0]
return input_audio_corrupt, target_audio_corrupt
def __getitem__(self, _):
""" """
# increment counter
self.items_since_load += 1
# load next chunk into buffer if needed
if self.items_since_load > self.buffer_reload_rate:
self.load_audio_buffer()
# generate pairs for style training
input_audio, target_audio = self.generate_pair()
# ------------------------ Conform length of files -------------------
input_audio = utils.conform_length(input_audio, int(self.length * 2))
target_audio = utils.conform_length(target_audio, int(self.length * 2))
# ------------------------ Apply fade in and fade out -------------------
input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate)
target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate)
# ------------------------ Final normalizeation ----------------------
# always peak normalize final input to -12 dBFS
input_audio /= input_audio.abs().max()
input_audio *= 10 ** (-12.0 / 20.0)
# always peak normalize the target to -12 dBFS
target_audio /= target_audio.abs().max()
target_audio *= 10 ** (-12.0 / 20.0)
return input_audio, target_audio
@staticmethod
def get_random_file_id(keys):
# generate a random index into the keys of the input files
rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0]
# find the key (file_id) correponding to the random index
rand_input_file_id = list(keys)[rand_input_idx]
return rand_input_file_id
@staticmethod
def get_random_patch(audio_file, length, check_silence=True):
silent = True
count = 0
while silent:
count += 1
start_idx = torch.randint(0, audio_file.num_frames - length - 1, [1])[0]
# int(torch.rand(1) * (audio_file.num_frames - length))
stop_idx = start_idx + length
patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()
length = patch.shape[-1]
first_patch = patch[..., : length // 2]
second_patch = patch[..., length // 2 :]
if (
(first_patch**2).mean() > 1e-5 and (second_patch**2).mean() > 1e-5
) or not check_silence:
silent = False
if count > 100:
print("get_random_patch count", count)
return -1, -1
# break
return start_idx, stop_idx
def get_file_id(self, filepath):
"""Given a filepath extract the DAPS file id.
Args:
filepath (str): Path to an audio files in the DAPS dataset.
Returns:
file_id (str): DAPS file id of the form <participant_id>_<script_id>
file_set (str): The DAPS set to which the file belongs.
"""
file_id = os.path.basename(filepath).split("_")[:2]
file_id = "_".join(file_id)
return file_id
def get_file_set(self, filepath):
"""Given a filepath extract the DAPS file set name.
Args:
filepath (str): Path to an audio files in the DAPS dataset.
Returns:
file_set (str): The DAPS set to which the file belongs.
"""
file_set = os.path.basename(filepath).split("_")[2:]
file_set = "_".join(file_set)
file_set = file_set.replace(f".{self.ext}", "")
return file_set