# NOTE: removed web-scrape artifact ("Spaces: / Running / Running") — page chrome, not Python code.
import json
import os
from typing import Any, Dict, Iterable, Iterator, List, Tuple

import numpy as np
import soundfile as sf
import torch
from pytorch_lightning import LightningDataModule
# from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.utilities import rank_zero_only
from rich import print
from torch.utils.data import ConcatDataset, DataLoader, Dataset
def print_(message: str):
    """Forward *message* to ``print``.

    NOTE(review): ``rank_zero_only`` is imported at file level but unused —
    presumably this wrapper was meant to carry that decorator; confirm.
    """
    print(message)
def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
    """Normalize a waveform tensor along its last axis.

    Subtracts the per-row mean and divides by ``std + eps``; when *std* is
    not supplied, the tensor's own last-axis standard deviation is used.
    """
    mu = wav_tensor.mean(-1, keepdim=True)
    sigma = wav_tensor.std(-1, keepdim=True) if std is None else std
    return (wav_tensor - mu) / (sigma + eps)
class MP3DDataset(Dataset):
    """Speech-separation dataset backed by JSON manifests of wav files.

    ``json_dir`` must contain ``mix.json``, ``s1.json`` and ``s2.json``,
    each a list of ``[wav_path, n_samples]`` entries aligned by index
    (schema inferred from the indexing below — confirm against the
    manifest generator).

    Args:
        json_dir: Directory holding the manifest JSON files.
        n_src: Number of sources per item; must be 1 or 2.
        sample_rate: Sample rate in Hz (used for segment length and logging).
        segment: Crop length in seconds; ``None`` selects full-utterance
            ("test") mode with no random cropping.
        normalize_audio: If True, mixture and sources are normalized with
            the mixture's mean/std.

    Raises:
        ValueError: If ``json_dir`` is None or ``n_src`` is not 1 or 2.
    """

    def __init__(
        self,
        json_dir: str = "",
        n_src: int = 2,
        sample_rate: int = 8000,
        segment: float = 4.0,
        normalize_audio: bool = False,
    ) -> None:
        super().__init__()
        self.EPS = 1e-8
        if json_dir is None:  # was `== None`; identity test is the idiom
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))
        self.json_dir = json_dir
        self.sample_rate = sample_rate
        self.normalize_audio = normalize_audio
        if segment is None:
            self.seg_len = None
            # NOTE(review): fps_len exists only on this branch and is unused
            # in this file — confirm it is dead before relying on it.
            self.fps_len = None
        else:
            self.seg_len = int(segment * sample_rate)
        self.n_src = n_src
        # Full-utterance (no-crop) mode.
        self.test = self.seg_len is None
        mix_json = os.path.join(json_dir, "mix.json")
        sources_json = [
            os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
        ]
        with open(mix_json, "r") as f:
            mix_infos = json.load(f)
        sources_infos = []
        for src_json in sources_json:
            with open(src_json, "r") as f:
                sources_infos.append(json.load(f))
        self.mix = []
        self.sources = []
        if self.n_src == 1:
            # Each mixture yields one example per source, hence twice as many.
            orig_len = len(mix_infos) * 2
            drop_utt, drop_len = 0, 0
            if not self.test:
                # Go backward so `del` does not shift unvisited indices.
                for i in range(len(mix_infos) - 1, -1, -1):
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt = drop_utt + 1
                        drop_len = drop_len + mix_infos[i][1]
                        del mix_infos[i]
                        for src_inf in sources_infos:
                            del src_inf[i]
                    else:
                        # NOTE(review): backward iteration appends kept items
                        # in reverse manifest order — preserved as-is.
                        for src_inf in sources_infos:
                            self.mix.append(mix_infos[i])
                            self.sources.append(src_inf[i])
            else:
                for i in range(len(mix_infos)):
                    for src_inf in sources_infos:
                        self.mix.append(mix_infos[i])
                        self.sources.append(src_inf[i])
            print_(
                "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
                )
            )
            self.length = len(self.mix)
        elif self.n_src == 2:
            orig_len = len(mix_infos)
            drop_utt, drop_len = 0, 0
            if not self.test:
                for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                    if mix_infos[i][1] < self.seg_len:
                        drop_utt = drop_utt + 1
                        drop_len = drop_len + mix_infos[i][1]
                        del mix_infos[i]
                        for src_inf in sources_infos:
                            del src_inf[i]
            # Fixed: divisor was 36000, making the logged hours 10x too small;
            # hours = samples / sample_rate / 3600 (matches the n_src == 1 branch).
            print_(
                "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
                )
            )
            # For n_src == 2, self.sources keeps the per-source nesting:
            # indexed as self.sources[src][idx].
            self.mix = mix_infos
            self.sources = sources_infos
            self.length = len(self.mix)

    def __len__(self):
        """Return the number of examples (doubled when n_src == 1)."""
        return self.length

    def preprocess_audio_only(self, idx: int):
        """Load one example from disk.

        Returns:
            ``(mixture, target, mix_path)`` where ``target`` has shape
            ``[1, T]`` for n_src == 1 and ``[2, T]`` for n_src == 2.
        """
        if self.n_src == 1:
            # Random crop unless the utterance is exactly seg_len long or we
            # are in full-utterance (test) mode.
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            # Load mixture
            x, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # Load the single source
            s, _ = sf.read(
                self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # torch from numpy
            target = torch.from_numpy(s)
            mixture = torch.from_numpy(x)
            if self.normalize_audio:
                # Source shares the mixture's std so relative scale is kept.
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
            return mixture, target.unsqueeze(0), self.mix[idx][0]
        if self.n_src == 2:
            if self.mix[idx][1] == self.seg_len or self.test:
                rand_start = 0
            else:
                rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
            if self.test:
                stop = None
            else:
                stop = rand_start + self.seg_len
            # Load mixture
            x, _ = sf.read(
                self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
            )
            # Load both sources, cropped identically to the mixture.
            source_arrays = []
            for src in self.sources:
                s, _ = sf.read(
                    src[idx][0], start=rand_start, stop=stop, dtype="float32"
                )
                source_arrays.append(s)
            sources = torch.from_numpy(np.vstack(source_arrays))
            mixture = torch.from_numpy(x)
            if self.normalize_audio:
                m_std = mixture.std(-1, keepdim=True)
                mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
                sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
            return mixture, sources, self.mix[idx][0]

    def __getitem__(self, index: int):
        """Delegate to :meth:`preprocess_audio_only`."""
        return self.preprocess_audio_only(index)
class EchoSetDataModule(object):
    """Bundle train/val/test :class:`MP3DDataset` splits and their loaders.

    NOTE(review): plain ``object`` subclass despite ``LightningDataModule``
    being imported at file level — confirm this is intentional.

    Args:
        train_dir: Manifest directory for the training split.
        valid_dir: Manifest directory for the validation split.
        test_dir: Manifest directory for the test split.
        n_src: Number of sources per item; must be 1 or 2.
        sample_rate: Audio sample rate in Hz.
        segment: Training crop length in seconds (val/test always use
            full utterances, see :meth:`setup`).
        normalize_audio: Forwarded to :class:`MP3DDataset`.
        batch_size: Loader batch size (shared by all three loaders).
        num_workers: Loader worker count.
        pin_memory: Forwarded to each :class:`DataLoader`.
        persistent_workers: Forwarded to each :class:`DataLoader`.

    Raises:
        ValueError: If any directory is None or ``n_src`` is not 1 or 2.
    """

    def __init__(
        self,
        train_dir: str,
        valid_dir: str,
        test_dir: str,
        n_src: int = 2,
        sample_rate: int = 8000,
        segment: float = 4.0,
        normalize_audio: bool = False,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        persistent_workers: bool = False,
    ) -> None:
        super().__init__()
        # Was `== None` three times; `is None` is the Python idiom.
        if train_dir is None or valid_dir is None or test_dir is None:
            raise ValueError("JSON DIR is None!")
        if n_src not in [1, 2]:
            raise ValueError("{} is not in [1, 2]".format(n_src))
        # Store init params explicitly (no Lightning `save_hyperparameters`
        # here, so there is no `self.hparams`).
        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.test_dir = test_dir
        self.n_src = n_src
        self.sample_rate = sample_rate
        self.segment = segment
        self.normalize_audio = normalize_audio
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        # Populated by setup().
        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

    def setup(self) -> None:
        """Instantiate the three datasets (val/test use full utterances)."""
        self.data_train = MP3DDataset(
            json_dir=self.train_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            segment=self.segment,
            normalize_audio=self.normalize_audio,
        )
        self.data_val = MP3DDataset(
            json_dir=self.valid_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            segment=None,
            normalize_audio=self.normalize_audio,
        )
        self.data_test = MP3DDataset(
            json_dir=self.test_dir,
            n_src=self.n_src,
            sample_rate=self.sample_rate,
            segment=None,
            normalize_audio=self.normalize_audio,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; drops the last partial batch."""
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Deterministic validation loader; keeps the last partial batch."""
        return DataLoader(
            dataset=self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Deterministic test loader; keeps the last partial batch."""
        return DataLoader(
            dataset=self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            pin_memory=self.pin_memory,
            drop_last=False,
        )

    def make_loader(self):
        """Return (train, val, test) DataLoaders."""
        return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()

    def make_sets(self):
        """Return the (train, val, test) datasets as set by setup()."""
        return self.data_train, self.data_val, self.data_test