|
import math
from typing import Optional, Tuple

import librosa
import torch

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
|
|
|
|
class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats.

    The arguments are the same as those of ``librosa.filters.mel``, plus
    ``log_base``, which selects the logarithm applied to the mel energies.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm applied to the mel energies.
            If `None`, the natural logarithm is used. Must be > 0 and != 1.

    Raises:
        ValueError: if ``log_base`` is given and is <= 0 or equal to 1.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: Optional[float] = None,
        fmax: Optional[float] = None,
        htk: bool = False,
        log_base: Optional[float] = None,
    ):
        super().__init__()

        # A base <= 0 has no real logarithm and base 1 would divide by ln(1) == 0;
        # fail at construction time instead of on the first forward pass.
        if log_base is not None and (log_base <= 0 or log_base == 1):
            raise ValueError(f"log_base must be > 0 and != 1, got {log_base}")

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.mel_options = _mel_options
        self.log_base = log_base

        # librosa returns the filterbank as (n_mels, 1 + n_fft // 2); store the
        # transpose so forward() can right-multiply along the frequency axis.
        melmat = librosa.filters.mel(**_mel_options)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project spectral features onto the mel filterbank and take the log.

        Args:
            feat: input features whose last dimension is the frequency axis
                (it must match ``melmat``'s first dimension). Presumably a
                (batch, frames, freq) power spectrogram -- TODO confirm.
            ilens: optional per-utterance lengths. When given, padded
                positions are filled with 0 via
                ``make_pad_mask(ilens, logmel_feat, 1)`` (presumably along
                the frame axis, dim 1).

        Returns:
            A tuple of the log-mel features (last dimension ``n_mels``) and
            the lengths. When ``ilens`` is None, each length is
            ``feat.size(1)``.
        """
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor the energies so the log below never produces -inf for zeros.
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base: log_b(x) = ln(x) / ln(b). ``log_base`` is a plain
            # Python number, so use math.log (torch.log only accepts Tensors).
            logmel_feat = mel_feat.log() / math.log(self.log_base)

        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
|
|