import torch from torch_complex.tensor import ComplexTensor from espnet2.enh.decoder.abs_decoder import AbsDecoder from espnet2.layers.stft import Stft class STFTDecoder(AbsDecoder): """STFT decoder for speech enhancement and separation """ def __init__( self, n_fft: int = 512, win_length: int = None, hop_length: int = 128, window="hann", center: bool = True, normalized: bool = False, onesided: bool = True, ): super().__init__() self.stft = Stft( n_fft=n_fft, win_length=win_length, hop_length=hop_length, window=window, center=center, normalized=normalized, onesided=onesided, ) def forward(self, input: ComplexTensor, ilens: torch.Tensor): """Forward. Args: input (ComplexTensor): spectrum [Batch, T, F] ilens (torch.Tensor): input lengths [Batch] """ if not isinstance(input, ComplexTensor): raise TypeError("Only support ComplexTensor for stft decoder") wav, wav_lens = self.stft.inverse(input, ilens) return wav, wav_lens