|
import torch |
|
from torch_complex.tensor import ComplexTensor |
|
|
|
from espnet2.enh.encoder.abs_encoder import AbsEncoder |
|
from espnet2.layers.stft import Stft |
|
|
|
|
|
class STFTEncoder(AbsEncoder):
    """Encoder that maps waveforms to complex STFT spectra.

    Wraps an ``Stft`` layer and repackages its output as a
    ``ComplexTensor`` for speech enhancement and separation models.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """Build the underlying ``Stft`` layer and record the output size.

        All arguments are forwarded unchanged, by keyword, to ``Stft``.
        ``n_fft`` and ``onesided`` additionally determine ``output_dim``.
        """
        super().__init__()

        # Number of frequency bins reported by ``output_dim``: a one-sided
        # spectrum has n_fft // 2 + 1 bins, a two-sided one has n_fft.
        if onesided:
            self._output_dim = n_fft // 2 + 1
        else:
            self._output_dim = n_fft

        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    @property
    def output_dim(self) -> int:
        """Number of frequency bins computed from ``n_fft`` and ``onesided``."""
        return self._output_dim

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Compute the complex STFT spectrum of the input.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            spectrum (ComplexTensor): STFT spectrum, (Batch, Frames, Freq)
                or (Batch, Frames, Channels, Freq)
            flens (torch.Tensor): lengths returned by the ``Stft`` layer
                alongside the spectrum
        """
        stft_out, flens = self.stft(input, ilens)

        # The trailing axis of the Stft output is indexed as
        # [real, imaginary]; split it to build the complex tensor.
        real_part = stft_out[..., 0]
        imag_part = stft_out[..., 1]

        return ComplexTensor(real_part, imag_part), flens
|
|