from typing import Optional

import numpy as np
import torch
from librosa.filters import mel

from .stft import STFT


class MelSpectrogram(torch.nn.Module):
    """Turn audio into a clamped log-mel spectrogram.

    The module wraps the project's ``STFT`` helper, projects its output onto
    a mel filter bank built with ``librosa.filters.mel`` (HTK formula), and
    returns the natural log of the clamped result.
    """

    def __init__(
        self,
        is_half: bool,
        n_mel_channels: int,
        sampling_rate: int,
        win_length: int,
        hop_length: int,
        n_fft: Optional[int] = None,
        mel_fmin: int = 0,
        mel_fmax: Optional[int] = None,
        clamp: float = 1e-5,
        device=torch.device("cpu"),
    ):
        """Build the mel filter bank and the underlying STFT module.

        Args:
            is_half: When true, ``forward`` casts the mel projection to
                float16 before the clamp/log step.
            n_mel_channels: Number of mel bands (``n_mels`` for librosa).
            sampling_rate: Sample rate passed to librosa as ``sr``.
            win_length: STFT window length.
            hop_length: STFT hop length.
            n_fft: FFT size; falls back to ``win_length`` when ``None``.
            mel_fmin: Lowest filter-bank frequency passed to librosa.
            mel_fmax: Highest filter-bank frequency passed to librosa;
                ``None`` is forwarded unchanged.
            clamp: Lower bound applied to the mel values before ``log``.
            device: Device the STFT module is moved to. Only its string form
                is inspected here, to choose the STFT implementation.
        """
        super().__init__()
        if n_fft is None:
            n_fft = win_length

        filter_bank = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        # Stored as a buffer so it follows the module across devices without
        # being treated as a trainable parameter.
        self.register_buffer("mel_basis", torch.from_numpy(filter_bank).float())

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.clamp = clamp
        self.is_half = is_half

        # NOTE(review): "privateuseone" devices get the non-torch STFT path,
        # presumably because torch.stft is unavailable on that backend --
        # confirm against the STFT implementation.
        self.stft = STFT(
            filter_length=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window="hann",
            use_torch_stft="privateuseone" not in str(device),
        ).to(device)

    def forward(
        self,
        audio: torch.Tensor,
        keyshift=0,
        speed=1,
        center=True,
    ) -> torch.Tensor:
        """Compute the log-mel spectrogram of ``audio``.

        Args:
            audio: Input passed straight to the STFT module (layout is
                defined by ``STFT`` -- presumably ``(batch, samples)``).
            keyshift: Pitch shift in semitones, forwarded to the STFT. When
                non-zero, the STFT output is padded or cropped along dim 1 to
                ``n_fft // 2 + 1`` bins and rescaled by
                ``win_length / round(win_length * 2 ** (keyshift / 12))``.
            speed: Forwarded unchanged to the STFT module.
            center: Forwarded unchanged to the STFT module.

        Returns:
            ``log(clamp(mel_basis @ magnitude, min=self.clamp))``; float16
            when ``is_half`` is set.
        """
        magnitude = self.stft(audio, keyshift, speed, center)

        if keyshift != 0:
            # The shifted analysis window changes the number of bins the STFT
            # yields, so bring dim 1 back to the size the filter bank expects.
            factor = 2 ** (keyshift / 12)
            win_length_new = int(np.round(self.win_length * factor))
            expected_bins = self.n_fft // 2 + 1
            actual_bins = magnitude.size(1)
            if actual_bins < expected_bins:
                # Zero-pad the end of dim 1 (last dim is left untouched).
                magnitude = torch.nn.functional.pad(
                    magnitude, (0, 0, 0, expected_bins - actual_bins)
                )
            # Crop any surplus bins and compensate for the window-length change.
            magnitude = (
                magnitude[:, :expected_bins, :] * self.win_length / win_length_new
            )

        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()

        # Clamp first so log never sees zero.
        return torch.log(torch.clamp(mel_output, min=self.clamp))