# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Défossez (adefossez)

import json
import logging
import math
import os
import re
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.utils.data as data

from .preprocess import preprocess_one_dir
from .audio import Audioset

logger = logging.getLogger(__name__)


def sort(infos):
    # Sort file infos by length (the second field of each entry), longest first.
    return sorted(infos, key=lambda info: int(info[1]), reverse=True)


class Trainset:
    def __init__(self, json_dir, sample_rate=16000, segment=4.0, stride=1.0, pad=True):
        mix_json = os.path.join(json_dir, 'mix.json')
        s_jsons = list()
        s_infos = list()
        # source metadata files are named s1.json, s2.json, ...
        sets_re = re.compile(r's[0-9]+\.json')
        logger.debug("Content of %s: %s", json_dir, os.listdir(json_dir))
        for s in os.listdir(json_dir):
            if sets_re.search(s):
                s_jsons.append(os.path.join(json_dir, s))
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        for s_json in s_jsons:
            with open(s_json, 'r') as f:
                s_infos.append(json.load(f))

        length = int(sample_rate * segment)
        stride = int(sample_rate * stride)
        kw = {'length': length, 'stride': stride, 'pad': pad}
        self.mix_set = Audioset(sort(mix_infos), **kw)
        self.sets = list()
        for s_info in s_infos:
            self.sets.append(Audioset(sort(s_info), **kw))
        # verify that all source sets have the same size as the mixture set
        for s in self.sets:
            assert len(s) == len(self.mix_set)

    def __getitem__(self, index):
        mix_sig = self.mix_set[index]
        tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
        # reuse mix_sig instead of indexing (and reloading) the mix set a second time
        return mix_sig, torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)

    def __len__(self):
        return len(self.mix_set)
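
# Example usage sketch (hypothetical json_dir): Trainset can be wrapped in a regular
# torch DataLoader. The directory is assumed to contain mix.json plus s1.json,
# s2.json, ... produced by the preprocessing step.
#
#     trainset = Trainset("egs/debug/tr", sample_rate=16000, segment=4.0, stride=1.0)
#     loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
#     mix, lengths, sources = next(iter(loader))
#     # With 1-D items from Audioset: mix is [B, T], lengths is [B, 1],
#     # and sources is [B, n_src, T] (one entry per sX.json source set).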


class Validset:
    """
    Loads the entire wav files (no segmentation); used for validation.
    """

    def __init__(self, json_dir):
        mix_json = os.path.join(json_dir, 'mix.json')
        s_jsons = list()
        s_infos = list()
        sets_re = re.compile(r's[0-9]+\.json')
        for s in os.listdir(json_dir):
            if sets_re.search(s):
                s_jsons.append(os.path.join(json_dir, s))
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        for s_json in s_jsons:
            with open(s_json, 'r') as f:
                s_infos.append(json.load(f))

        self.mix_set = Audioset(sort(mix_infos))
        self.sets = list()
        for s_info in s_infos:
            self.sets.append(Audioset(sort(s_info)))
        # verify that all source sets have the same size as the mixture set
        for s in self.sets:
            assert len(s) == len(self.mix_set)

    def __getitem__(self, index):
        mix_sig = self.mix_set[index]
        tgt_sig = [self.sets[i][index] for i in range(len(self.sets))]
        # reuse mix_sig instead of indexing (and reloading) the mix set a second time
        return mix_sig, torch.LongTensor([mix_sig.shape[0]]), torch.stack(tgt_sig)

    def __len__(self):
        return len(self.mix_set)


# The following piece of code was adapted from https://github.com/kaituoxu/Conv-TasNet
# released under the MIT License.
# Author: Kaituo XU
# Created on 2018/12
class EvalDataset(data.Dataset):
    def __init__(self, mix_dir, mix_json, batch_size, sample_rate=8000):
        """
        Args:
            mix_dir: directory containing the mixture wav files
            mix_json: json file listing the mixture wav files
        """
        super(EvalDataset, self).__init__()
        assert mix_dir is not None or mix_json is not None
        if mix_dir is not None:
            # generate mix.json for the given mix_dir
            preprocess_one_dir(mix_dir, mix_dir, 'mix',
                               sample_rate=sample_rate)
            mix_json = os.path.join(mix_dir, 'mix.json')
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        # sort by number of samples, longest first (implements length bucketing);
        # reuse the module-level sort helper
        sorted_mix_infos = sort(mix_infos)
        # split the sorted list into minibatches of at most `batch_size` utterances
        minibatch = []
        start = 0
        while True:
            end = min(len(sorted_mix_infos), start + batch_size)
            minibatch.append([sorted_mix_infos[start:end],
                              sample_rate])
            if end == len(sorted_mix_infos):
                break
            start = end
        self.minibatch = minibatch

    def __getitem__(self, index):
        return self.minibatch[index]

    def __len__(self):
        return len(self.minibatch)


class EvalDataLoader(data.DataLoader):
    """
    NOTE: this loader is meant to be used with batch_size=1, so each batch is a
    single pre-built minibatch from EvalDataset and drop_last has no effect.
    """

    def __init__(self, *args, **kwargs):
        super(EvalDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn_eval
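
# Example usage sketch (hypothetical mix_dir): EvalDataset pre-builds the minibatches,
# so the loader itself runs with batch_size=1 and _collate_fn_eval does the padding.
#
#     eval_set = EvalDataset("egs/debug/mix", None, batch_size=8, sample_rate=8000)
#     eval_loader = EvalDataLoader(eval_set, batch_size=1)
#     for mixtures_pad, ilens, filenames in eval_loader:
#         pass  # mixtures_pad: [B, T_max], ilens: [B], filenames: list of B paths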


def _collate_fn_eval(batch):
    """
    Args:
        batch: list, len(batch) = 1. See EvalDataset.__getitem__()
    Returns:
        mixtures_pad: B x T, torch.Tensor
        ilens: B, torch.Tensor
        filenames: a list containing B strings
    """
    # batch is a list holding a single EvalDataset minibatch
    assert len(batch) == 1
    mixtures, filenames = load_mixtures(batch[0])
    # get batch of lengths of input sequences
    ilens = np.array([mix.shape[0] for mix in mixtures])
    # perform padding and convert to tensor
    pad_value = 0
    mixtures_pad = pad_list([torch.from_numpy(mix).float()
                             for mix in mixtures], pad_value)
    ilens = torch.from_numpy(ilens)
    return mixtures_pad, ilens, filenames


def load_mixtures(batch):
    """
    Returns:
        mixtures: a list containing B items; each item is a T-length np.ndarray.
        filenames: a list containing B strings.
        T varies from item to item.
    """
    mixtures, filenames = [], []
    mix_infos, sample_rate = batch
    # for each utterance
    for mix_info in mix_infos:
        mix_path = mix_info[0]
        # read the wav file, resampling to sample_rate if needed
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        mixtures.append(mix)
        filenames.append(mix_path)
    return mixtures, filenames


def pad_list(xs, pad_value):
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    # allocate a [n_batch, max_len, ...] tensor filled with pad_value
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        # copy each sequence into the left part of its row
        pad[i, :xs[i].size(0)] = xs[i]
    return pad
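
# Example (hypothetical values) of what pad_list produces: tensors are right-padded
# with pad_value up to the length of the longest tensor in the list.
#
#     a = torch.ones(3)
#     b = torch.ones(5)
#     padded = pad_list([a, b], 0)      # shape [2, 5]
#     # padded[0] -> tensor([1., 1., 1., 0., 0.])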