import glob
import os
import random

import librosa
import numpy as np
import soundfile as sf
import torch
from numpy.random import default_rng
from pydtmc import MarkovChain
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from config import CONFIG

np.random.seed(0)
rng = default_rng()


def load_audio(
        path,
        sample_rate: int = 16000,
        chunk_len=None,
):
    """Read an audio file, optionally only a random chunk, and resample it.

    Args:
        path: Path of the audio file to open with ``soundfile``.
        sample_rate: Sample rate the returned audio is resampled to.
        chunk_len: If given and smaller than the file length, only
            ``chunk_len`` consecutive frames starting at a random offset are
            read. The length is counted in frames of the file's native
            sample rate, i.e. before any resampling.

    Returns:
        A float32 numpy array laid out as ``(channels, frames)``.
        NOTE(review): the resampling branch squeezes the array, so it
        presumably only works for mono files - confirm against the data.
    """
    with sf.SoundFile(path) as f:
        sr = f.samplerate
        audio_len = f.frames

        if chunk_len is not None and chunk_len < audio_len:
            # Plain int offset + public seek/read instead of the private
            # SoundFile._prepare_read helper that was fed a tensor before.
            start_index = torch.randint(0, audio_len - chunk_len, (1,)).item()
            f.seek(start_index)
            audio = f.read(frames=chunk_len, always_2d=True, dtype="float32")
        else:
            audio = f.read(always_2d=True, dtype="float32")

    if sr != sample_rate:
        # Keyword arguments: librosa >= 0.10 made orig_sr/target_sr
        # keyword-only, and older releases accept the same names.
        audio = librosa.resample(np.squeeze(audio), orig_sr=sr, target_sr=sample_rate)[:, np.newaxis]

    return audio.T


def pad(sig, length):
    """Force a ``(channels, samples)`` tensor to exactly ``length`` samples.

    Shorter signals are right-padded with zeros; longer (or equal) signals
    are cropped to a random window of ``length`` samples.
    """
    if sig.shape[1] < length:
        pad_len = length - sig.shape[1]
        sig = torch.hstack((sig, torch.zeros((sig.shape[0], pad_len))))
    else:
        # random.randint is inclusive on both ends, so the last valid
        # start position can be drawn as well.
        start = random.randint(0, sig.shape[1] - length)
        sig = sig[:, start:start + length]
    return sig


def _read_file_list(root, txt_list):
    """Return the sorted, de-duplicated paths listed in ``txt_list`` joined onto ``root``."""
    with open(txt_list) as f:
        paths = {os.path.join(root, line.strip('\n')) for line in f}
    return sorted(paths)


def _stft_real(wave, n_fft, hop_length, window):
    """Compute an STFT and return it as a float tensor shaped ``(2, freq, frames)``.

    ``torch.stft(..., return_complex=False)`` is deprecated; computing the
    complex spectrum and viewing it as real yields the same
    ``(freq, frames, 2)`` layout, which is then permuted so that the
    real/imaginary axis comes first.
    """
    spec = torch.stft(wave, n_fft, hop_length, window=window, return_complex=True)
    return torch.view_as_real(spec).permute(2, 0, 1).float()


class MaskGenerator:
    """Generate binary packet masks by walking a two-state Markov chain.

    The chain states are labelled ``'1'`` and ``'0'``; the walk is converted
    to integers and used by the datasets as a multiplicative mask, so a ``0``
    zeroes out the corresponding packet.
    """

    def __init__(self, is_train=True, probs=((0.9, 0.1), (0.5, 0.1), (0.5, 0.5))):
        """
        is_train: if True, mask generator for training otherwise for evaluation
        probs: a list of transition probability (p_N, p_L) for Markov Chain.
               p_N is the probability of staying in state '1' and p_L the
               probability of staying in state '0'.
               Only allow 1 tuple if 'is_train=False'
        """
        self.is_train = is_train
        self.probs = probs
        if not self.is_train:
            assert len(probs) == 1, "evaluation mode expects exactly one (p_N, p_L) pair"
        self.mcs = [self._build_chain(prob) for prob in probs]

    @staticmethod
    def _build_chain(prob):
        """Build the two-state chain for one ``(p_N, p_L)`` pair."""
        p_n, p_l = prob[0], prob[1]
        return MarkovChain([[p_n, 1 - p_n], [1 - p_l, p_l]], ['1', '0'])

    def gen_mask(self, length, seed=0):
        """Return an integer numpy mask produced by a seeded chain walk.

        In training mode one of the configured chains is picked at random;
        in evaluation mode the single configured chain is used.
        NOTE(review): ``walk(length - 1)`` is presumably expected to yield
        ``length`` states (initial state plus steps) - confirm against the
        installed pydtmc version.
        """
        mc = random.choice(self.mcs) if self.is_train else self.mcs[0]
        mask = mc.walk(length - 1, seed=seed)
        return np.array(list(map(int, mask)))


class TestLoader(Dataset):
    """Evaluation dataset yielding masked/clean spectrograms and waveforms.

    Packets are dropped either according to trace files
    (``CONFIG.DATA.EVAL.masking == 'real'``) or according to a Markov chain
    mask.
    """

    def __init__(self):
        dataset_name = CONFIG.DATA.dataset
        self.mask = CONFIG.DATA.EVAL.masking

        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']
        txt_list = CONFIG.DATA.data_dir[dataset_name]['test']
        self.data_list = self.load_txt(txt_list)
        if self.mask == 'real':
            trace_txt = sorted(glob.glob(os.path.join(CONFIG.DATA.EVAL.trace_path, '*.txt')))
            if not trace_txt:
                # Fail early with a clear message instead of a
                # ZeroDivisionError on the first __getitem__ call.
                raise FileNotFoundError(
                    f"no trace .txt files found in {CONFIG.DATA.EVAL.trace_path!r}")
            self.trace_list = [self._load_trace(txt) for txt in trace_txt]
        else:
            self.mask_generator = MaskGenerator(is_train=False, probs=CONFIG.DATA.EVAL.transition_probs)

        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.window_size = CONFIG.DATA.window_size
        self.audio_chunk_len = CONFIG.DATA.audio_chunk_len
        self.p_size = CONFIG.DATA.EVAL.packet_size  # 20ms
        self.hann = torch.sqrt(torch.hann_window(self.window_size))

    @staticmethod
    def _load_trace(path):
        """Read one trace file (one integer per line) and invert it into a mask."""
        # Context manager so the handle is closed right away instead of
        # being left to the garbage collector.
        with open(path, 'r') as f:
            entries = f.read().strip('\n').split('\n')
        return 1 - np.array(list(map(int, entries)))

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Return the sorted, unique audio paths listed in ``txt_list``."""
        return _read_file_list(self.target_root, txt_list)

    def __getitem__(self, index):
        """Return ``(masked_spec, clean_spec, masked_wav, clean_wav)`` for one file."""
        target = load_audio(self.data_list[index], sample_rate=self.sr)
        # Drop the tail so the signal splits into whole packets.
        target = target[:, :(target.shape[1] // self.p_size) * self.p_size]

        sig = np.reshape(target, (-1, self.p_size)).copy()
        if self.mask == 'real':
            mask = self.trace_list[index % len(self.trace_list)]
            # np.repeat needs an integer count; np.ceil returns a float,
            # which NumPy refuses to cast implicitly.
            # NOTE(review): np.repeat stretches each trace entry rather
            # than tiling the trace - kept as in the original.
            repeats = int(np.ceil(len(sig) / len(mask)))
            mask = np.repeat(mask, repeats, 0)[:len(sig)][:, np.newaxis]
        else:
            mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask
        sig = torch.tensor(sig).reshape(-1)

        target = torch.tensor(target).squeeze(0)

        sig_wav = sig.clone()
        target_wav = target.clone()

        target = _stft_real(target, self.window_size, self.stride, self.hann)
        sig = _stft_real(sig, self.window_size, self.stride, self.hann)
        return sig.float(), target.float(), sig_wav, target_wav


class BlindTestLoader(Dataset):
    """Dataset over every ``*.wav`` file in a directory, yielding only spectrograms."""

    def __init__(self, test_dir):
        # Sorted so that an index maps to the same file on every run;
        # glob order is otherwise filesystem dependent.
        self.data_list = sorted(glob.glob(os.path.join(test_dir, '*.wav')))
        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the ``(2, freq, frames)`` float spectrogram of one file."""
        sig = load_audio(self.data_list[index], sample_rate=self.sr)
        sig = torch.from_numpy(sig).squeeze(0)
        return _stft_real(sig, self.chunk_len, self.stride, self.hann)


class TrainDataset(Dataset):
    """Training/validation dataset yielding ``(masked_spec, clean_spec)`` pairs.

    ``mode='train'`` and ``mode='val'`` select the two sides of a fixed
    ``train_test_split``; any other mode keeps the full file list.
    """

    def __init__(self, mode='train'):
        dataset_name = CONFIG.DATA.dataset
        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']

        txt_list = CONFIG.DATA.data_dir[dataset_name]['train']
        self.data_list = self.load_txt(txt_list)

        # random_state=0 keeps the train and val partitions disjoint and
        # reproducible across separately constructed datasets.
        if mode == 'train':
            self.data_list, _ = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)
        elif mode == 'val':
            _, self.data_list = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)

        self.p_sizes = CONFIG.DATA.TRAIN.packet_sizes
        self.mode = mode
        self.sr = CONFIG.DATA.sr
        self.window = CONFIG.DATA.audio_chunk_len
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))
        self.mask_generator = MaskGenerator(is_train=True, probs=CONFIG.DATA.TRAIN.transition_probs)

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Return the sorted, unique audio paths listed in ``txt_list``."""
        return _read_file_list(self.target_root, txt_list)

    def fetch_audio(self, index):
        """Return exactly ``self.window`` samples starting from file ``index``.

        Files shorter than the window are extended with audio from randomly
        chosen other files; a remainder shorter than 20 ms is filled with
        zeros instead.
        """
        sig = load_audio(self.data_list[index], sample_rate=self.sr, chunk_len=self.window)
        while sig.shape[1] < self.window:
            # Drawn before the branch (as in the original) so the torch RNG
            # stream is consumed identically on both paths.
            idx = torch.randint(0, len(self.data_list), (1,)).item()
            pad_len = self.window - sig.shape[1]
            if pad_len < 0.02 * self.sr:
                # np.float was removed in NumPy 1.24; float32 matches the
                # dtype load_audio reads with.
                padding = np.zeros((1, pad_len), dtype=np.float32)
            else:
                padding = load_audio(self.data_list[idx], sample_rate=self.sr, chunk_len=pad_len)
            sig = np.hstack((sig, padding))
        # load_audio measures chunk_len before resampling, so an upsampled
        # chunk can overshoot the window; trim to keep batch shapes equal.
        return sig[:, :self.window]

    def __getitem__(self, index):
        """Return ``(masked_spec, clean_spec)``, each a ``(2, freq, frames)`` float tensor.

        NOTE(review): the mask walk is seeded with ``index``, so a given
        item presumably receives the same loss pattern every epoch for a
        given chain - confirm this is intended.
        """
        # astype returns a fresh array, so the in-place masking below never
        # touches what fetch_audio returned.
        wave = self.fetch_audio(index).reshape(-1).astype(np.float32)
        # torch.tensor copies, so the clean target is unaffected by masking.
        target = torch.tensor(wave)

        p_size = random.choice(self.p_sizes)
        packets = np.reshape(wave, (-1, p_size))
        mask = self.mask_generator.gen_mask(len(packets), seed=index)[:, np.newaxis]
        packets *= mask
        sig = torch.tensor(packets).reshape(-1)

        target = _stft_real(target, self.chunk_len, self.stride, self.hann)
        sig = _stft_real(sig, self.chunk_len, self.stride, self.hann)
        return sig, target