# FRN / dataset.py
# Last change: "Update dataset.py" by anhnv125 (revision dc58348).
import glob
import os
import random
import librosa
import numpy as np
import soundfile as sf
import torch
from numpy.random import default_rng
from pydtmc import MarkovChain
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from config import CONFIG
# Pin NumPy's legacy global RNG so every np.random.* draw in this module is reproducible.
np.random.seed(seed=0)
# Module-level Generator from NumPy's newer API. Note it is created without a
# seed, so its draws differ between runs (it is not used in this file as seen).
rng = default_rng(seed=None)
def load_audio(
    path,
    sample_rate: int = 16000,
    chunk_len=None,
):
    """Read an audio file as a float32 ``(channels, samples)`` array.

    Args:
        path: Path of any file readable by ``soundfile``.
        sample_rate: Rate (Hz) of the returned signal. The audio is resampled
            with ``librosa.resample`` when the file's native rate differs.
        chunk_len: Optional number of output samples (at ``sample_rate``) to
            read from a uniformly random offset. If ``None``, or if the file is
            not longer than the chunk, the whole file is returned. In that case
            the result can be shorter than ``chunk_len``.

    Returns:
        ``np.ndarray`` of shape ``(channels, samples)``.
    """
    with sf.SoundFile(path) as f:
        sr = f.samplerate
        audio_len = f.frames
        # chunk_len is given at the output rate. Convert it to the file's native
        # rate so the chunk has the requested length after resampling.
        src_chunk_len = None
        if chunk_len is not None:
            src_chunk_len = int(np.ceil(chunk_len * sr / sample_rate))
        if src_chunk_len is not None and src_chunk_len < audio_len:
            # randint's upper bound is exclusive; +1 makes the final window reachable.
            start_index = int(torch.randint(0, audio_len - src_chunk_len + 1, (1,))[0])
            # Public seek/read replaces the private SoundFile._prepare_read.
            f.seek(start_index)
            audio = f.read(src_chunk_len, always_2d=True, dtype="float32")
        else:
            audio = f.read(always_2d=True, dtype="float32")
    if sr != sample_rate:
        # Resample every channel along the time axis of the (channels, samples)
        # view. librosa >= 0.10 requires the rates as keyword arguments.
        audio = librosa.resample(audio.T, orig_sr=sr, target_sr=sample_rate).T
        if chunk_len is not None:
            # The rounded-up source chunk can resample to a few extra samples.
            audio = audio[:chunk_len]
    return audio.T
def pad(sig, length):
    """Zero-pad or randomly crop ``sig`` along its last axis to ``length`` samples.

    Args:
        sig: 2-D tensor indexed as ``(channels, samples)``.
        length: Target number of samples.

    Returns:
        Tensor of shape ``(sig.shape[0], length)``. Shorter inputs are padded
        with zeros at the end. Longer inputs are cropped at a start offset drawn
        with Python's ``random`` module. Dtype and device match ``sig``.
    """
    n_samples = sig.shape[1]
    if n_samples < length:
        # new_zeros keeps sig's dtype and device. A plain torch.zeros would be
        # float32 on the CPU, which promotes integer input and breaks CUDA tensors.
        padding = sig.new_zeros((sig.shape[0], length - n_samples))
        return torch.hstack((sig, padding))
    start = random.randint(0, n_samples - length)
    return sig[:, start:start + length]
class MaskGenerator:
    """Packet keep/lose mask generator built on 2-state Markov chains (pydtmc).

    Each chain has states ``'1'`` (packet kept) and ``'0'`` (packet lost). Its
    transition matrix is ``[[p_N, 1 - p_N], [1 - p_L, p_L]]``, so ``p_N`` is the
    probability of staying in ``'1'`` and ``p_L`` the probability of staying
    in ``'0'``.
    """

    def __init__(self, is_train=True, probs=((0.9, 0.1), (0.5, 0.1), (0.5, 0.5))):
        '''
        is_train: if True, mask generator for training otherwise for evaluation
        probs: a list of transition probability (p_N, p_L) for Markov Chain. Only allow 1 tuple if 'is_train=False'

        Raises:
            ValueError: if ``is_train`` is False and ``probs`` does not contain
                exactly one ``(p_N, p_L)`` pair (note: the default has three).
        '''
        # Validate explicitly, because an ``assert`` is stripped under ``python -O``.
        if not is_train and len(probs) != 1:
            raise ValueError(
                "MaskGenerator(is_train=False) expects exactly one (p_N, p_L) pair, "
                f"got {len(probs)}"
            )
        self.is_train = is_train
        self.probs = probs
        # One chain per (p_N, p_L) pair. Evaluation mode has exactly one.
        self.mcs = [self._build_chain(prob[0], prob[1]) for prob in probs]

    @staticmethod
    def _build_chain(p_n, p_l):
        """Build the 2-state keep('1')/lose('0') MarkovChain for stay-probabilities p_n, p_l."""
        return MarkovChain([[p_n, 1 - p_n], [1 - p_l, p_l]], ['1', '0'])

    def gen_mask(self, length, seed=0):
        """Return an integer 0/1 mask for ``length`` packets (1 = kept, 0 = lost).

        In training mode a chain is picked with Python's ``random.choice``.
        In evaluation mode the single configured chain is used. ``seed`` is
        forwarded to ``MarkovChain.walk``.
        """
        mc = random.choice(self.mcs) if self.is_train else self.mcs[0]
        # walk(length - 1) is expected to yield ``length`` states: the initial
        # state plus one per step. Callers broadcast the mask against ``length`` packets.
        states = mc.walk(length - 1, seed=seed)
        return np.array([int(state) for state in states])
class TestLoader(Dataset):
    """Evaluation dataset: whole utterances with simulated or real packet loss.

    Each item is ``(sig_stft, target_stft, sig_wav, target_wav)``:

    * The STFTs are float tensors permuted to ``(2, freq, frames)``, with
      real/imag first.
    * The waveforms are flattened tensors (1-D for mono input).
    * ``sig`` is ``target`` with lost packets of
      ``CONFIG.DATA.EVAL.packet_size`` samples zeroed.
    """

    def __init__(self):
        dataset_name = CONFIG.DATA.dataset
        self.mask = CONFIG.DATA.EVAL.masking
        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']
        txt_list = CONFIG.DATA.data_dir[dataset_name]['test']
        self.data_list = self.load_txt(txt_list)
        if self.mask == 'real':
            # Each trace file holds one integer per line, with 1 marking a lost
            # packet. Traces are inverted so the stored mask is 1 = keep, 0 = drop.
            trace_txt = sorted(glob.glob(os.path.join(CONFIG.DATA.EVAL.trace_path, '*.txt')))
            self.trace_list = [1 - self._read_trace(txt) for txt in trace_txt]
        else:
            self.mask_generator = MaskGenerator(is_train=False, probs=CONFIG.DATA.EVAL.transition_probs)
        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.window_size = CONFIG.DATA.window_size
        self.audio_chunk_len = CONFIG.DATA.audio_chunk_len
        self.p_size = CONFIG.DATA.EVAL.packet_size  # 20ms
        self.hann = torch.sqrt(torch.hann_window(self.window_size))

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Return the sorted, de-duplicated paths listed in ``txt_list``, joined onto ``self.target_root``."""
        with open(txt_list) as f:
            return sorted({os.path.join(self.target_root, line.strip('\n')) for line in f})

    @staticmethod
    def _read_trace(path):
        """Read a loss-trace file (one integer per line) into an integer array."""
        # Use a context manager so each trace file is closed after reading.
        with open(path, 'r') as f:
            return np.array(list(map(int, f.read().strip('\n').split('\n'))))

    @staticmethod
    def _fit_trace(trace, n_packets):
        """Cycle or truncate a 1-D trace to ``n_packets`` entries, shaped ``(n_packets, 1)``."""
        # np.tile repeats the whole trace, which preserves recorded burst lengths.
        # np.repeat would stretch every entry, and it also rejects a float repeat count.
        reps = (n_packets + len(trace) - 1) // len(trace)
        return np.tile(trace, reps)[:n_packets][:, np.newaxis]

    def _stft(self, wav):
        """Return the STFT of 1-D ``wav`` as a real tensor permuted to ``(2, freq, frames)``."""
        # return_complex=False is deprecated since torch 2.0.
        # view_as_real gives the same (freq, frames, 2) layout.
        spec = torch.stft(wav, self.window_size, self.stride, window=self.hann, return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1)

    def __getitem__(self, index):
        target = load_audio(self.data_list[index], sample_rate=self.sr)
        # Drop the trailing partial packet so the signal splits evenly into packets.
        target = target[:, :(target.shape[1] // self.p_size) * self.p_size]
        # One row per packet. Copy so that masking leaves ``target`` untouched.
        sig = np.reshape(target, (-1, self.p_size)).copy()
        if self.mask == 'real':
            mask = self._fit_trace(self.trace_list[index % len(self.trace_list)], len(sig))
        else:
            mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask
        sig = torch.tensor(sig).reshape(-1)
        target = torch.tensor(target).squeeze(0)
        sig_wav = sig.clone()
        target_wav = target.clone()
        target = self._stft(target)
        sig = self._stft(sig)
        return sig.float(), target.float(), sig_wav, target_wav
class BlindTestLoader(Dataset):
    """Dataset over every ``*.wav`` file in ``test_dir``, with no clean reference.

    Each item is the float STFT of one whole file, permuted to ``(2, freq, frames)``
    with real/imag first.
    """

    def __init__(self, test_dir):
        # Sort so that item ``i`` refers to the same file on every run.
        # glob.glob returns files in filesystem-dependent order.
        self.data_list = sorted(glob.glob(os.path.join(test_dir, '*.wav')))
        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        sig = load_audio(self.data_list[index], sample_rate=self.sr)
        sig = torch.from_numpy(sig).squeeze(0)
        # return_complex=False is deprecated since torch 2.0.
        # view_as_real gives the same (freq, frames, 2) layout.
        spec = torch.stft(sig, self.chunk_len, self.stride, window=self.hann, return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1).float()
class TrainDataset(Dataset):
    """Training/validation dataset of fixed-length chunks with simulated packet loss.

    Each item is ``(sig_stft, target_stft)``: float tensors permuted to
    ``(2, freq, frames)`` with real/imag first. ``sig`` is ``target`` with
    packets zeroed according to a ``MaskGenerator`` walk. The packet size is
    drawn at random from ``CONFIG.DATA.TRAIN.packet_sizes``.
    """

    def __init__(self, mode='train'):
        dataset_name = CONFIG.DATA.dataset
        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']
        txt_list = CONFIG.DATA.data_dir[dataset_name]['train']
        self.data_list = self.load_txt(txt_list)
        # Deterministic (random_state=0) train/val split of the same file list.
        # Any other mode keeps the full list.
        if mode == 'train':
            self.data_list, _ = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)
        elif mode == 'val':
            _, self.data_list = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)
        self.p_sizes = CONFIG.DATA.TRAIN.packet_sizes
        self.mode = mode
        self.sr = CONFIG.DATA.sr
        self.window = CONFIG.DATA.audio_chunk_len  # samples per training chunk
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size  # STFT size (n_fft) and window length
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))
        self.mask_generator = MaskGenerator(is_train=True, probs=CONFIG.DATA.TRAIN.transition_probs)

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Return the sorted, de-duplicated paths listed in ``txt_list``, joined onto ``self.target_root``."""
        with open(txt_list) as f:
            return sorted({os.path.join(self.target_root, line.strip('\n')) for line in f})

    def fetch_audio(self, index):
        """Load ``self.window`` samples from file ``index``, topping up short files.

        While the chunk is shorter than ``self.window``, the gap is filled in one
        of two ways:

        * with zeros, if it is under 20 ms (``0.02 * self.sr`` samples);
        * otherwise with a random excerpt of a randomly chosen file, which may
          be the same file.
        """
        sig = load_audio(self.data_list[index], sample_rate=self.sr, chunk_len=self.window)
        while sig.shape[1] < self.window:
            idx = int(torch.randint(0, len(self.data_list), (1,))[0])
            pad_len = self.window - sig.shape[1]
            if pad_len < 0.02 * self.sr:
                # np.float was removed in NumPy 1.24. float32 also matches load_audio's output.
                padding = np.zeros((1, pad_len), dtype=np.float32)
            else:
                padding = load_audio(self.data_list[idx], sample_rate=self.sr, chunk_len=pad_len)
            sig = np.hstack((sig, padding))
        return sig

    def _stft(self, wav):
        """Return the STFT of 1-D ``wav`` as a float tensor permuted to ``(2, freq, frames)``."""
        # return_complex=False is deprecated since torch 2.0.
        # view_as_real gives the same (freq, frames, 2) layout.
        spec = torch.stft(wav, self.chunk_len, self.stride, window=self.hann, return_complex=True)
        return torch.view_as_real(spec).permute(2, 0, 1).float()

    def __getitem__(self, index):
        sig = self.fetch_audio(index).reshape(-1).astype(np.float32)
        target = torch.tensor(sig.copy())
        p_size = random.choice(self.p_sizes)
        # One row per packet. np.reshape requires self.window to be a multiple of p_size.
        packets = np.reshape(sig, (-1, p_size))
        # NOTE(review): seeding the walk with ``index`` gives the same loss pattern
        # for a given file, chain and packet size every epoch. Confirm this is intended.
        mask = self.mask_generator.gen_mask(len(packets), seed=index)[:, np.newaxis]
        packets *= mask
        sig = torch.tensor(packets.copy()).reshape(-1)
        return self._stft(sig), self._stft(target)