|
import sys |
|
|
|
sys.path.append("src") |
|
import os |
|
import pandas as pd |
|
import yaml |
|
import audioldm_train.utilities.audio as Audio |
|
from audioldm_train.utilities.tools import load_json |
|
from audioldm_train.dataset_plugin import * |
|
from librosa.filters import mel as librosa_mel_fn |
|
|
|
import random |
|
from torch.utils.data import Dataset |
|
import torch.nn.functional |
|
import torch |
|
import numpy as np |
|
import torchaudio |
|
import json |
|
|
|
|
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude tensor.

    Computes ``log(max(x, clip_val) * C)`` element-wise. The floor at
    ``clip_val`` keeps the logarithm finite for zero magnitudes.

    :param x: Tensor of non-negative magnitudes.
    :param C: Multiplicative gain applied before the log.
    :param clip_val: Lower bound applied to ``x`` before scaling.
    :return: Tensor of the same shape as ``x``.
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
|
|
|
|
|
def dynamic_range_decompression_torch(x, C=1):
    """Invert ``dynamic_range_compression_torch`` (ignoring its clipping).

    :param x: Log-compressed tensor.
    :param C: The gain that was applied during compression.
    :return: ``exp(x) / C``, element-wise.
    """
    expanded = x.exp()
    return expanded / C
|
|
|
|
|
def spectral_normalize_torch(magnitudes):
    """Log-compress spectral magnitudes using the default gain and floor.

    Thin wrapper over ``dynamic_range_compression_torch`` with its default
    ``C`` and ``clip_val``.
    """
    return dynamic_range_compression_torch(magnitudes)
|
|
|
|
|
def spectral_de_normalize_torch(magnitudes):
    """Undo ``spectral_normalize_torch`` via exponentiation (default gain)."""
    return dynamic_range_decompression_torch(magnitudes)
|
|
|
|
|
class AudioDataset(Dataset):
    """Dataset that manages audio recordings.

    Each item is a dict holding the text caption, file name, label vector,
    waveform, linear STFT magnitude and log-mel spectrogram of one audio
    file (see ``__getitem__``). Entries come either from ``dataset_json`` or
    from the metadata files referenced by ``config["metadata_root"]``.
    """

    def __init__(
        self,
        config=None,
        split="train",
        waveform_only=False,
        add_ons=None,
        dataset_json=None,
    ):
        """Build the dataset.

        :param config: Parsed configuration dict. The ``preprocessing``,
            ``augmentation`` and ``data`` sections are read, plus
            ``metadata_root`` when ``dataset_json`` is None.
        :param split: Split name; used as a key into ``config["data"]`` and
            into the per-dataset metadata paths.
        :param waveform_only: If True, skip STFT / log-mel extraction.
        :param add_ons: Iterable of expression strings; each is ``eval``-ed
            and the result is later called as ``add_on(config, data, datum)``.
            Defaults to no add-ons.
        :param dataset_json: Optional dict with a ``"data"`` list. When given,
            metadata lookup and label loading are skipped.
        """
        if add_ons is None:
            add_ons = []
        self.config = config
        self.split = split
        self.pad_wav_start_sample = 0
        # Boolean switch for silence trimming. The trimming method is named
        # ``trim_wav_silence`` so this attribute does not shadow it.
        self.trim_wav = False
        self.waveform_only = waveform_only
        # SECURITY: eval executes arbitrary code; add-on names must only come
        # from a trusted config file.
        self.add_ons = [eval(x) for x in add_ons]
        print("Add-ons:", self.add_ons)

        self.build_setting_parameters()

        if dataset_json is not None:
            self.data = dataset_json["data"]
            self.id2label, self.index_dict, self.num2label = {}, {}, {}
        else:
            self.metadata_root = load_json(self.config["metadata_root"])
            # Validate before indexing, so a missing split produces this
            # descriptive message rather than a bare KeyError.
            assert split in self.config["data"].keys(), (
                "The dataset split %s you specified is not present in the config. You can choose from %s"
                % (split, self.config["data"].keys())
            )
            self.dataset_name = self.config["data"][self.split]
            self.build_dataset()
            self.build_id_to_label()

        self.build_dsp()
        self.label_num = len(self.index_dict)
        print("Dataset initialize finished")

    def __getitem__(self, index):
        """Return the feature dict for ``index``.

        ``feature_extraction`` may skip entries that fail to load, so the
        returned item can come from a later index. Add-ons always receive the
        datum that was actually loaded.

        Missing features (``None``) are stored as empty strings.
        """
        (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_vector,
            (datum, mix_datum),
            random_start,
        ) = self.feature_extraction(index)
        text = self.get_sample_text_caption(datum, mix_datum, label_vector)

        data = {
            "text": text,
            "fname": self.text_to_filename(text) if (not fname) else fname,
            "label_vector": "" if (label_vector is None) else label_vector.float(),
            "waveform": "" if (waveform is None) else waveform.float(),
            "stft": "" if (stft is None) else stft.float(),
            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
            "duration": self.duration,
            "sampling_rate": self.sampling_rate,
            "random_start_sample_in_original_audio_file": random_start,
            "mos": 1,
        }

        for add_on in self.add_ons:
            # Use the loaded datum, not self.data[index]: they differ when a
            # failing entry was skipped.
            data.update(add_on(self.config, data, datum))

        if data["text"] is None:
            print("Warning: The model return None on key text", fname)
            data["text"] = ""

        return data

    def text_to_filename(self, text):
        """Make a caption file-name friendly by replacing spaces and quotes with '_'."""
        return text.replace(" ", "_").replace("'", "_").replace('"', "_")

    def get_dataset_root_path(self, dataset):
        """Return the root directory registered for ``dataset`` in ``metadata_root``."""
        assert dataset in self.metadata_root.keys()
        return self.metadata_root[dataset]

    def get_dataset_metadata_path(self, dataset, key):
        """Look up ``metadata_root["metadata"]["path"][dataset][key]``.

        :return: The stored path, or None when ``dataset`` is not listed.
        :raises ValueError: If the metadata section is malformed, or if
            ``dataset`` is listed but has no entry for ``key``.
        """
        try:
            paths = self.metadata_root["metadata"]["path"]
            if dataset in paths.keys():
                return paths[dataset][key]
        except (KeyError, TypeError, AttributeError) as e:
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            ) from e
        return None

    def __len__(self):
        """Number of data entries."""
        return len(self.data)

    def feature_extraction(self, index):
        """Load the entry at ``index`` and compute its features.

        If loading raises, the next entry (wrapping around) is tried. Once
        every entry has failed, a RuntimeError is raised instead of looping
        forever.

        :return: ``(fname, waveform, stft, log_mel_spec, label_indices,
            (datum, mix_datum), random_start)``
        """
        if index > len(self.data) - 1:
            print(
                "The index of the dataloader is out of range: %s/%s"
                % (index, len(self.data))
            )
            index = random.randint(0, len(self.data) - 1)

        for _ in range(len(self.data)):
            datum = self.data[index]
            try:
                (
                    log_mel_spec,
                    stft,
                    waveform,
                    random_start,
                ) = self.read_audio_file(datum["wav"])
                label_indices = np.zeros(self.label_num, dtype=np.float32)
                mix_datum = None
                if self.label_num > 0 and "labels" in datum.keys():
                    # Multi-hot vector from the comma-separated label ids.
                    for label_str in datum["labels"].split(","):
                        label_indices[int(self.index_dict[label_str])] = 1.0
                label_indices = torch.FloatTensor(label_indices)
                break
            except Exception as e:
                index = (index + 1) % len(self.data)
                print(
                    "Error encounter during audio feature extraction: ",
                    e,
                    datum.get("wav"),
                )
        else:
            raise RuntimeError(
                "Audio feature extraction failed for every item in the dataset"
            )

        fname = datum["wav"]
        waveform = torch.FloatTensor(waveform)

        return (
            fname,
            waveform,
            stft,
            log_mel_spec,
            label_indices,
            (datum, mix_datum),
            random_start,
        )

    def build_setting_parameters(self):
        """Read audio/STFT/mel settings and the mixup rate from the config.

        ``target_length`` is the expected number of spectrogram frames:
        ``int(duration * sampling_rate / hop_length)``. Mixup is disabled for
        any split whose name does not contain "train".
        """
        preprocessing = self.config["preprocessing"]
        self.melbins = preprocessing["mel"]["n_mel_channels"]
        self.sampling_rate = preprocessing["audio"]["sampling_rate"]
        self.hopsize = preprocessing["stft"]["hop_length"]
        self.duration = preprocessing["audio"]["duration"]
        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)

        self.mixup = self.config["augmentation"]["mixup"]
        if "train" not in self.split:
            self.mixup = 0.0

    def _relative_path_to_absolute_path(self, metadata, dataset_name):
        """Prefix every ``wav`` path in ``metadata["data"]`` with the dataset root.

        Entries are modified in place; the same ``metadata`` dict is returned.
        Absolute paths (starting with "/") are rejected.
        """
        root_path = self.get_dataset_root_path(dataset_name)
        for entry in metadata["data"]:
            assert "wav" in entry.keys(), entry
            assert not entry["wav"].startswith("/"), (
                "The dataset metadata should only contain relative path to the audio file: "
                + str(entry["wav"])
            )
            entry["wav"] = os.path.join(root_path, entry["wav"])
        return metadata

    def build_dataset(self):
        """Load and concatenate the metadata of one or several datasets into ``self.data``."""
        self.data = []
        print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if isinstance(self.dataset_name, str):
            dataset_names = [self.dataset_name]
        elif isinstance(self.dataset_name, list):
            dataset_names = self.dataset_name
        else:
            raise TypeError("Invalid data format")

        for dataset_name in dataset_names:
            data_json = load_json(
                self.get_dataset_metadata_path(dataset_name, key=self.split)
            )
            data_json = self._relative_path_to_absolute_path(data_json, dataset_name)
            self.data += data_json["data"]
        print("Data size: {}".format(len(self.data)))

    def build_dsp(self):
        """Read STFT/mel parameters and build the STFT helper and the per-device caches."""
        # Caches filled lazily by mel_spectrogram_train, keyed by fmax/device.
        self.mel_basis = {}
        self.hann_window = {}

        self.filter_length = self.config["preprocessing"]["stft"]["filter_length"]
        self.hop_length = self.config["preprocessing"]["stft"]["hop_length"]
        self.win_length = self.config["preprocessing"]["stft"]["win_length"]
        self.n_mel = self.config["preprocessing"]["mel"]["n_mel_channels"]
        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
        self.mel_fmin = self.config["preprocessing"]["mel"]["mel_fmin"]
        self.mel_fmax = self.config["preprocessing"]["mel"]["mel_fmax"]

        self.STFT = Audio.stft.TacotronSTFT(
            self.filter_length,
            self.hop_length,
            self.win_length,
            self.n_mel,
            self.sampling_rate,
            self.mel_fmin,
            self.mel_fmax,
        )

    def build_id_to_label(self):
        """Load the class-label CSV (columns ``index``, ``mid``, ``display_name``).

        Sets ``id2label`` (mid -> name), ``index_dict`` (mid -> index) and
        ``num2label`` (index -> name). All three are empty when no CSV is
        configured.
        """
        id2label = {}
        id2num = {}
        num2label = {}
        class_label_indices_path = self.get_dataset_metadata_path(
            dataset=self.config["data"]["class_label_indices"],
            key="class_label_indices",
        )
        if class_label_indices_path is not None:
            df = pd.read_csv(class_label_indices_path)
            for _, row in df.iterrows():
                index, mid, display_name = row["index"], row["mid"], row["display_name"]
                id2label[mid] = display_name
                id2num[mid] = index
                num2label[index] = display_name
            self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
        else:
            self.id2label, self.index_dict, self.num2label = {}, {}, {}

    def resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` to ``self.sampling_rate``."""
        waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
        return waveform

    def normalize_wav(self, waveform):
        """Remove the mean and scale so the peak magnitude is about 0.5."""
        waveform = waveform - np.mean(waveform)
        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5

    def random_segment_wav(self, waveform, target_length):
        """Cut a random ``target_length`` window along the last axis.

        Up to 10 random windows are tried; the first one with a sample of
        magnitude above 1e-4 is kept, otherwise the 10th is returned.
        Indexing assumes ``waveform`` is 2-D (channels, samples).

        :return: ``(segment, start_sample)``. The input is returned unchanged
            (with start 0) when it is not longer than ``target_length``.
        """
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        if (waveform_length - target_length) <= 0:
            return waveform, 0

        for _ in range(10):
            random_start = int(self.random_uniform(0, waveform_length - target_length))
            if torch.max(
                torch.abs(waveform[:, random_start : random_start + target_length])
                > 1e-4
            ):
                break

        return waveform[:, random_start : random_start + target_length], random_start

    def pad_wav(self, waveform, target_length):
        """Zero-pad (or truncate) a (1, n) waveform to ``target_length`` samples.

        When ``pad_wav_start_sample`` is None the signal is placed at a random
        offset. Otherwise it is placed at that offset, clamped so the signal
        fits. Input longer than the target (e.g. a resampling overshoot) is
        truncated.
        """
        waveform_length = waveform.shape[-1]
        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

        if waveform_length == target_length:
            return waveform
        if waveform_length > target_length:
            return waveform[..., :target_length]

        temp_wav = np.zeros((1, target_length), dtype=np.float32)
        max_start = target_length - waveform_length
        if self.pad_wav_start_sample is None:
            rand_start = int(self.random_uniform(0, max_start))
        else:
            rand_start = min(int(self.pad_wav_start_sample), max_start)

        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
        return temp_wav

    def trim_wav_silence(self, waveform):
        """Trim leading/trailing silence from a 1-D waveform.

        Silence is detected in 1000-sample chunks below 1e-4 magnitude. One
        extra chunk is kept after the last non-silent chunk. A waveform that
        is entirely below the threshold is returned unchanged.
        """
        if np.max(np.abs(waveform)) < 0.0001:
            return waveform

        def detect_leading_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = 0
            while start + chunk_size < waveform_length:
                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
                    start += chunk_size
                else:
                    break
            return start

        def detect_ending_silence(waveform, threshold=0.0001):
            chunk_size = 1000
            waveform_length = waveform.shape[0]
            start = waveform_length
            while start - chunk_size > 0:
                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
                    start -= chunk_size
                else:
                    break
            if start == waveform_length:
                return start
            else:
                return start + chunk_size

        start = detect_leading_silence(waveform)
        end = detect_ending_silence(waveform)

        return waveform[start:end]

    def read_wav_file(self, filename):
        """Load, segment, resample, normalize and pad one audio file.

        Only the first channel is kept.

        :return: ``(waveform of shape (1, sampling_rate * duration),
            random_start)``, where ``random_start`` is measured in samples
            at the file's original rate.
        """
        waveform, sr = torchaudio.load(filename)

        # Segment at the original rate before resampling, to save work.
        waveform, random_start = self.random_segment_wav(
            waveform, target_length=int(sr * self.duration)
        )

        waveform = self.resample(waveform, sr)
        waveform = waveform.numpy()[0, ...]
        waveform = self.normalize_wav(waveform)

        if self.trim_wav:
            waveform = self.trim_wav_silence(waveform)

        waveform = waveform[None, ...]
        waveform = self.pad_wav(
            waveform, target_length=int(self.sampling_rate * self.duration)
        )
        return waveform, random_start

    def read_audio_file(self, filename, filename2=None):
        """Read a file and compute its features.

        A missing file yields an all-zero waveform instead of an error.
        ``filename2`` is unused.

        :return: ``(log_mel_spec, stft, waveform, random_start)``. The two
            spectral features are None when ``waveform_only`` is set.
        """
        if os.path.exists(filename):
            waveform, random_start = self.read_wav_file(filename)
        else:
            print(
                'Non-fatal Warning [dataset.py]: The wav path "',
                filename,
                '" is not find in the metadata. Use empty waveform instead. This is normal in the inference process.',
            )
            target_length = int(self.sampling_rate * self.duration)
            waveform = torch.zeros((1, target_length))
            random_start = 0

        if not self.waveform_only:
            log_mel_spec, stft = self.wav_feature_extraction(waveform)
        else:
            log_mel_spec, stft = None, None

        return log_mel_spec, stft, waveform, random_start

    def get_sample_text_caption(self, datum, mix_datum, label_indices):
        """Build the caption for ``datum``, appending ``mix_datum``'s caption if given."""
        text = self.label_indices_to_text(datum, label_indices)
        if mix_datum is not None:
            text += " " + self.label_indices_to_text(mix_datum, label_indices)
        return text

    def mel_spectrogram_train(self, y):
        """Compute the log-mel spectrogram and STFT magnitude of ``y``.

        Only the first batch item is returned. This assumes ``y`` is
        (batch, samples), since it is unsqueezed at dim 1 for padding —
        TODO confirm. The mel filterbank and Hann window are cached per
        fmax/device.

        :return: ``(log_mel[0], stft_magnitude[0])``, each (bins, frames).
        """
        if torch.min(y) < -1.0:
            print("train min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("train max value is ", torch.max(y))

        device_key = str(y.device)
        mel_key = str(self.mel_fmax) + "_" + device_key
        if mel_key not in self.mel_basis:
            # Keyword arguments: librosa >= 0.10 rejects positional ones.
            mel = librosa_mel_fn(
                sr=self.sampling_rate,
                n_fft=self.filter_length,
                n_mels=self.n_mel,
                fmin=self.mel_fmin,
                fmax=self.mel_fmax,
            )
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        if device_key not in self.hann_window:
            self.hann_window[device_key] = torch.hann_window(self.win_length).to(
                y.device
            )

        # Manual reflect padding, used in place of center=True below.
        pad = int((self.filter_length - self.hop_length) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        stft_spec = torch.stft(
            y,
            self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.hann_window[device_key],
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        stft_spec = torch.abs(stft_spec)

        mel = spectral_normalize_torch(torch.matmul(self.mel_basis[mel_key], stft_spec))

        return mel[0], stft_spec[0]

    def wav_feature_extraction(self, waveform):
        """Compute the (frames, bins) log-mel and STFT features of a (1, n) waveform.

        Both features are padded/cropped to ``target_length`` frames by
        ``pad_spec``.
        """
        waveform = waveform[0, ...]
        waveform = torch.FloatTensor(waveform)

        log_mel_spec, stft = self.mel_spectrogram_train(waveform.unsqueeze(0))

        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
        stft = torch.FloatTensor(stft.T)

        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
        return log_mel_spec, stft

    def pad_spec(self, log_mel_spec):
        """Zero-pad or crop a (frames, bins) tensor to ``target_length`` frames.

        The last bin is dropped when the bin count is odd, so the feature
        width is always even.
        """
        n_frames = log_mel_spec.shape[0]
        p = self.target_length - n_frames

        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            log_mel_spec = m(log_mel_spec)
        elif p < 0:
            log_mel_spec = log_mel_spec[0 : self.target_length, :]

        if log_mel_spec.size(-1) % 2 != 0:
            log_mel_spec = log_mel_spec[..., :-1]

        return log_mel_spec

    def _read_datum_caption(self, datum):
        """Return the value of a uniformly random key containing "caption"."""
        caption_keys = [x for x in datum.keys() if ("caption" in x)]
        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
        return datum[caption_keys[random_index]]

    def _is_contain_caption(self, datum):
        """True if any key of ``datum`` contains the substring "caption"."""
        return any("caption" in x for x in datum.keys())

    def label_indices_to_text(self, datum, label_indices):
        """Build the text for ``datum``.

        A caption is used if present. Otherwise, for labelled entries
        (``"label"`` or ``"labels"`` key), the display names of active labels
        (value > 0.1) are joined as "A, B, C.". Otherwise "" is returned.
        """
        if self._is_contain_caption(datum):
            return self._read_datum_caption(datum)
        # feature_extraction reads the "labels" key; "label" is kept for
        # backward compatibility.
        if "label" in datum.keys() or "labels" in datum.keys():
            name_indices = torch.where(label_indices > 0.1)[0]
            names = [self.num2label[int(each)] for each in name_indices]
            if not names:
                return ""
            return ", ".join(names) + "."
        return ""

    def random_uniform(self, start, end):
        """Draw a float uniformly from [start, end) using torch's RNG."""
        val = torch.rand(1).item()
        return start + (end - start) * val

    def frequency_masking(self, log_mel_spec, freqm):
        """Zero a random band of up to ``freqm`` bins on a (batch, freq, time) tensor, in place."""
        _, freq, _ = log_mel_spec.size()
        mask_len = int(self.random_uniform(freqm // 8, freqm))
        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
        return log_mel_spec

    def time_masking(self, log_mel_spec, timem):
        """Zero a random span of up to ``timem`` frames on a (batch, freq, time) tensor, in place."""
        _, _, tsteps = log_mel_spec.size()
        mask_len = int(self.random_uniform(timem // 8, timem))
        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
        return log_mel_spec
|
|
|
|
|
if __name__ == "__main__":
    # Debug entry point: build the training split and drop into ipdb on the
    # first batch. The config path can be given as the first CLI argument.
    from tqdm import tqdm
    from pytorch_lightning import seed_everything
    from torch.utils.data import DataLoader

    seed_everything(0)

    def write_json(my_dict, fname):
        """Serialize ``my_dict`` as JSON into ``fname``."""
        json_str = json.dumps(my_dict)
        with open(fname, "w") as json_file:
            json_file.write(json_str)

    # NOTE: defined at module level, so this rebinds the ``load_json``
    # name used by AudioDataset when the file is run as a script.
    def load_json(fname):
        """Read and parse the JSON file ``fname``."""
        with open(fname, "r") as f:
            data = json.load(f)
        return data

    default_config_path = "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/config/vae_48k_256/ds_8_kl_1.0_ch_16.yaml"
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_config_path
    # Use a context manager so the config file handle is closed.
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    add_ons = config["data"]["dataloader_add_ons"]

    dataset = AudioDataset(
        config=config, split="train", waveform_only=False, add_ons=add_ons
    )

    loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=True)

    for cnt, each in tqdm(enumerate(loader)):
        import ipdb

        ipdb.set_trace()
|
|
|
|