import io
from functools import partial

import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
|
AUDIO_CFG = {
    "sample_rate": 48000,
    "audio_length": 1024,
    "clip_samples": 480000,
    "mel_bins": 64,
    "window_size": 1024,
    "hop_size": 480,
    "fmin": 50,
    "fmax": 14000,
    "class_num": 527,
}
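# Sanity notes (derived from the values above, not from an external spec):
# clip_samples / sample_rate = 480000 / 48000 = 10 seconds per clip, and with
# center=True framing (see get_mel below) a full clip yields
# clip_samples // hop_size + 1 = 1001 mel frames. Optional quick check:
#
#     assert AUDIO_CFG["clip_samples"] == 10 * AUDIO_CFG["sample_rate"]
#     assert AUDIO_CFG["clip_samples"] // AUDIO_CFG["hop_size"] + 1 == 1001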
|
|
|
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
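# Illustrative usage of dotdict (not part of the original API surface): keys
# become attributes, and missing keys return None because __getattr__ is bound
# to dict.get.
#
#     cfg = dotdict(AUDIO_CFG)
#     cfg.sample_rate  # -> 48000
#     cfg.missing_key  # -> None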
|
|
|
class Map(dict):
    """
    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v

        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]
|
|
|
|
|
def int16_to_float32(x):
    # Scale int16 PCM samples to float32 in (approximately) [-1, 1].
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    # Clip to [-1, 1], then scale to the int16 range.
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
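# Round-trip sketch (illustrative): int16 quantization is lossy, but the
# reconstruction error stays within about 1/32767 of the original value.
#
#     x = np.random.uniform(-1, 1, size=48000).astype(np.float32)
#     y = int16_to_float32(float32_to_int16(x))
#     assert np.abs(x - y).max() < 1e-4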
|
|
|
|
|
def get_mel(audio_data, audio_cfg):
    # Compute a power mel spectrogram, convert it to dB, and return it
    # transposed to shape (time_frames, mel_bins).
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg['sample_rate'],
        n_fft=audio_cfg['window_size'],
        win_length=audio_cfg['window_size'],
        hop_length=audio_cfg['hop_size'],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=audio_cfg['mel_bins'],
        f_min=audio_cfg['fmin'],
        f_max=audio_cfg['fmax']
    )(audio_data)
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel.T
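# Shape check (illustrative, assuming the AUDIO_CFG defaults above): a full
# 10-second, 48 kHz clip produces 480000 // 480 + 1 = 1001 frames of 64 mel bins.
#
#     waveform = torch.randn(AUDIO_CFG["clip_samples"])
#     get_mel(waveform, AUDIO_CFG).shape  # -> torch.Size([1001, 64])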
|
|
|
|
|
def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg):
    """
    Calculate and add audio features to sample.
    sample: a dict containing all the data of the current sample.
    audio_data: a tensor of shape (T) containing audio data.
    max_len: the maximum length of audio data.
    data_truncating: the method of truncating data.
    data_filling: the method of filling data.
    audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
    """
    with torch.no_grad():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # Build a 4-channel "fusion" mel: three chunks sampled from the
                # front/middle/back of the clip plus a globally shrunk mel.
                mel = get_mel(audio_data, audio_cfg)
                # Number of mel frames covered by a max_len-sample chunk
                # (the +1 accounts for center=True framing in get_mel).
                chunk_frames = max_len // audio_cfg['hop_size'] + 1
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # Corner case: the audio is longer than max_len but shorter
                    # than max_len + hop_size, so only one chunk position exists;
                    # use the whole mel for all four channels.
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])
                else:
                    # Split the valid chunk start positions into three ranges
                    # covering the front, middle, and back of the clip.
                    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                    # If the clip is too short to populate the later ranges,
                    # fall back to starting at frame 0.
                    if len(ranges[1]) == 0:
                        ranges[1] = [0]
                    if len(ranges[2]) == 0:
                        ranges[2] = [0]
                    # Randomly choose a start index within each range.
                    idx_front = np.random.choice(ranges[0])
                    idx_middle = np.random.choice(ranges[1])
                    idx_back = np.random.choice(ranges[2])
                    # Slice the three mel chunks.
                    mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                    mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                    mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                    # Shrink the full mel to chunk size as a global view
                    # (64 matches audio_cfg['mel_bins']).
                    mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]
                    # Stack the three local chunks and the global view.
                    mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([True])
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # Randomly crop the waveform itself to max_len.
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx: idx + max_len]

        else:  # the clip is not longer than max_len; fill it if it is shorter
            if len(audio_data) < max_len:
                if data_filling == "repeatpad":
                    # Repeat the clip as many whole times as fit, then zero-pad
                    # the remainder.
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat)
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "pad":
                    # Zero-pad up to max_len.
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "repeat":
                    # Tile the clip and cut at exactly max_len.
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                else:
                    raise NotImplementedError(
                        f"data_filling {data_filling} not implemented"
                    )
            if data_truncating == 'fusion':
                # Short clips get the same mel repeated over all four channels.
                mel = get_mel(audio_data, audio_cfg)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                sample["mel_fusion"] = mel_fusion
            longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

    return sample
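# Usage sketch (illustrative; values follow the AUDIO_CFG defaults above). A
# 3-second clip is shorter than max_len, so it is repeat-padded to 480000
# samples and, under "fusion", the same mel is stacked across all four channels.
#
#     sample = get_audio_features(
#         {}, torch.randn(3 * 48000), max_len=480000,
#         data_truncating="fusion", data_filling="repeatpad", audio_cfg=AUDIO_CFG,
#     )
#     sample["waveform"].shape    # -> torch.Size([480000])
#     sample["longer"]            # -> tensor([False])
#     sample["mel_fusion"].shape  # -> torch.Size([4, 1001, 64])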