|
import glob |
|
import os |
|
import random |
|
|
|
import librosa |
|
import numpy as np |
|
import soundfile as sf |
|
import torch |
|
from numpy.random import default_rng |
|
from pydtmc import MarkovChain |
|
from sklearn.model_selection import train_test_split |
|
from torch.utils.data import Dataset |
|
|
|
from config import CONFIG |
|
|
|
np.random.seed(0) |
|
rng = default_rng() |
|
|
|
|
|
def load_audio(
        path,
        sample_rate: int = 16000,
        chunk_len=None,
):
    """Load an audio file as a float32 array of shape (channels, samples).

    Args:
        path: path to any file readable by ``soundfile``.
        sample_rate: target sample rate; the audio is resampled if the
            file's native rate differs.
        chunk_len: if given and smaller than the file length, read a
            uniformly random window of ``chunk_len`` frames instead of
            the whole file.

    Returns:
        np.ndarray of shape (channels, num_samples), dtype float32.
    """
    with sf.SoundFile(path) as f:
        sr = f.samplerate
        audio_len = f.frames

        if chunk_len is not None and chunk_len < audio_len:
            # Random crop: pick a start offset so the window fits.
            start_index = torch.randint(0, audio_len - chunk_len, (1,))[0]

            # Use the public seek/read API instead of the private
            # SoundFile._prepare_read helper (which may break across
            # soundfile versions).
            f.seek(int(start_index))
            audio = f.read(chunk_len, always_2d=True, dtype="float32")

        else:
            audio = f.read(always_2d=True, dtype="float32")

        if sr != sample_rate:
            # librosa >= 0.10 removed positional sample-rate arguments;
            # orig_sr/target_sr must be passed by keyword.
            audio = librosa.resample(np.squeeze(audio), orig_sr=sr, target_sr=sample_rate)[:, np.newaxis]

    return audio.T
|
|
|
|
|
def pad(sig, length):
    """Force `sig` (a 2-D tensor, channels x samples) to exactly `length` samples.

    Signals shorter than `length` are zero-padded on the right; longer
    (or equal-length) signals are cropped at a uniformly random offset.
    """
    n_channels, n_samples = sig.shape
    if n_samples < length:
        tail = torch.zeros((n_channels, length - n_samples))
        return torch.hstack((sig, tail))
    offset = random.randint(0, n_samples - length)
    return sig[:, offset:offset + length]
|
|
|
|
|
class MaskGenerator:
    def __init__(self, is_train=True, probs=((0.9, 0.1), (0.5, 0.1), (0.5, 0.5))):
        """Binary packet-loss mask generator backed by 2-state Markov chains.

        is_train: if True, build a mask generator for training (one chain
            per probability pair, sampled at random); otherwise for
            evaluation (exactly one chain allowed).
        probs: tuples of transition probabilities (p_N, p_L), one per chain.
        """
        self.is_train = is_train
        self.probs = probs
        if not is_train:
            # Evaluation must use a single, fixed loss pattern.
            assert len(probs) == 1
        # States: '1' = packet kept, '0' = packet lost.
        self.mcs = [
            MarkovChain([[p_n, 1 - p_n], [1 - p_l, p_l]], ['1', '0'])
            for p_n, p_l in probs
        ]

    def gen_mask(self, length, seed=0):
        """Return a 0/1 numpy array of `length` entries from a chain walk."""
        chain = random.choice(self.mcs) if self.is_train else self.mcs[0]
        # walk(n) yields n+1 states including the initial one.
        states = chain.walk(length - 1, seed=seed)
        return np.array([int(s) for s in states])
|
|
|
|
|
class TestLoader(Dataset):
    """Evaluation dataset: yields (lossy STFT, clean STFT, lossy wav, clean wav)."""

    def __init__(self):
        dataset_name = CONFIG.DATA.dataset
        self.mask = CONFIG.DATA.EVAL.masking  # 'real' = recorded traces, else Markov

        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']
        txt_list = CONFIG.DATA.data_dir[dataset_name]['test']
        self.data_list = self.load_txt(txt_list)
        if self.mask == 'real':
            trace_txt = glob.glob(os.path.join(CONFIG.DATA.EVAL.trace_path, '*.txt'))
            trace_txt.sort()
            # Each trace file holds one 0/1 value per line; the values are
            # inverted (1 - x) before use — presumably the files mark losses
            # with 1 while the mask convention is 1 = kept; TODO confirm.
            self.trace_list = []
            for txt in trace_txt:
                # Use a context manager: the original comprehension leaked
                # open file handles.
                with open(txt, 'r') as fh:
                    bits = list(map(int, fh.read().strip('\n').split('\n')))
                self.trace_list.append(1 - np.array(bits))
        else:
            self.mask_generator = MaskGenerator(is_train=False, probs=CONFIG.DATA.EVAL.transition_probs)

        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.window_size = CONFIG.DATA.window_size
        self.audio_chunk_len = CONFIG.DATA.audio_chunk_len
        self.p_size = CONFIG.DATA.EVAL.packet_size  # packet size in samples
        # sqrt-Hann analysis window.
        self.hann = torch.sqrt(torch.hann_window(self.window_size))

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Read a newline-separated file list, join with the dataset root,
        deduplicate, and return it sorted."""
        target = []
        with open(txt_list) as f:
            for line in f:
                target.append(os.path.join(self.target_root, line.strip('\n')))
        target = list(set(target))
        target.sort()
        return target

    def __getitem__(self, index):
        target = load_audio(self.data_list[index], sample_rate=self.sr)
        # Trim so the signal divides evenly into whole packets.
        target = target[:, :(target.shape[1] // self.p_size) * self.p_size]

        sig = np.reshape(target, (-1, self.p_size)).copy()
        if self.mask == 'real':
            mask = self.trace_list[index % len(self.trace_list)]
            # Stretch each trace entry to span enough packets, then trim.
            # np.repeat requires an integer repeat count (a float raises on
            # modern NumPy), hence the explicit int cast.
            reps = int(np.ceil(len(sig) / len(mask)))
            mask = np.repeat(mask, reps, 0)[:len(sig)][:, np.newaxis]
        else:
            mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask  # zero out lost packets
        sig = torch.tensor(sig).reshape(-1)

        target = torch.tensor(target).squeeze(0)

        sig_wav = sig.clone()
        target_wav = target.clone()

        # Real/imaginary STFT, permuted to (2, freq, time).
        target = torch.stft(target, self.window_size, self.stride, window=self.hann,
                            return_complex=False).permute(2, 0, 1)
        sig = torch.stft(sig, self.window_size, self.stride, window=self.hann, return_complex=False).permute(2, 0, 1)
        return sig.float(), target.float(), sig_wav, target_wav
|
|
|
|
|
class BlindTestLoader(Dataset):
    """Dataset over a flat directory of .wav files; yields STFTs only (no targets)."""

    def __init__(self, test_dir):
        self.data_list = glob.glob(os.path.join(test_dir, '*.wav'))
        self.sr = CONFIG.DATA.sr
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size
        # sqrt-Hann analysis window.
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        waveform = load_audio(self.data_list[index], sample_rate=self.sr)
        waveform = torch.from_numpy(waveform).squeeze(0)
        # Real/imaginary STFT, permuted to (2, freq, time).
        spec = torch.stft(waveform, self.chunk_len, self.stride, window=self.hann,
                          return_complex=False)
        return spec.permute(2, 0, 1).float()
|
|
|
|
|
class TrainDataset(Dataset):
    """Training/validation dataset: random audio chunks with simulated packet loss."""

    def __init__(self, mode='train'):
        """mode: 'train' or 'val'; the split is deterministic (random_state=0)."""
        dataset_name = CONFIG.DATA.dataset
        self.target_root = CONFIG.DATA.data_dir[dataset_name]['root']

        txt_list = CONFIG.DATA.data_dir[dataset_name]['train']
        self.data_list = self.load_txt(txt_list)

        if mode == 'train':
            self.data_list, _ = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)

        elif mode == 'val':
            _, self.data_list = train_test_split(self.data_list, test_size=CONFIG.TRAIN.val_split, random_state=0)

        self.p_sizes = CONFIG.DATA.TRAIN.packet_sizes  # candidate packet sizes (samples)
        self.mode = mode
        self.sr = CONFIG.DATA.sr
        self.window = CONFIG.DATA.audio_chunk_len  # training chunk length in samples
        self.stride = CONFIG.DATA.stride
        self.chunk_len = CONFIG.DATA.window_size  # STFT window size
        # sqrt-Hann analysis window.
        self.hann = torch.sqrt(torch.hann_window(self.chunk_len))
        self.mask_generator = MaskGenerator(is_train=True, probs=CONFIG.DATA.TRAIN.transition_probs)

    def __len__(self):
        return len(self.data_list)

    def load_txt(self, txt_list):
        """Read a newline-separated file list, join with the dataset root,
        deduplicate, and return it sorted."""
        target = []
        with open(txt_list) as f:
            for line in f:
                target.append(os.path.join(self.target_root, line.strip('\n')))
        target = list(set(target))
        target.sort()
        return target

    def fetch_audio(self, index):
        """Load a `self.window`-sample chunk; if the file is too short,
        extend it with silence (tiny gaps) or audio from random other files."""
        sig = load_audio(self.data_list[index], sample_rate=self.sr, chunk_len=self.window)
        while sig.shape[1] < self.window:
            idx = torch.randint(0, len(self.data_list), (1,))[0]
            pad_len = self.window - sig.shape[1]
            if pad_len < 0.02 * self.sr:
                # Gap shorter than 20 ms: just pad with silence.
                # np.float was removed in NumPy 1.24; the builtin float is
                # the exact alias it pointed to (float64), so behavior is
                # unchanged.
                padding = np.zeros((1, pad_len), dtype=float)
            else:
                padding = load_audio(self.data_list[idx], sample_rate=self.sr, chunk_len=pad_len)
            sig = np.hstack((sig, padding))
        return sig

    def __getitem__(self, index):
        sig = self.fetch_audio(index)

        sig = sig.reshape(-1).astype(np.float32)

        target = torch.tensor(sig.copy())
        # Packet size is drawn per sample for robustness to varying sizes.
        p_size = random.choice(self.p_sizes)

        sig = np.reshape(sig, (-1, p_size))
        mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
        sig *= mask  # zero out lost packets
        sig = torch.tensor(sig.copy()).reshape(-1)

        # Real/imaginary STFT, permuted to (2, freq, time).
        target = torch.stft(target, self.chunk_len, self.stride, window=self.hann,
                            return_complex=False).permute(2, 0, 1).float()
        sig = torch.stft(sig, self.chunk_len, self.stride, window=self.hann, return_complex=False)
        sig = sig.permute(2, 0, 1).float()
        return sig, target
|
|