conex / espnet2 /enh /encoder /conv_encoder.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
1.3 kB
import torch
from espnet2.enh.encoder.abs_encoder import AbsEncoder
class ConvEncoder(AbsEncoder):
    """Convolutional encoder for speech enhancement and separation.

    Applies a single bias-free 1-D convolution followed by a ReLU to a
    single-channel waveform, producing a non-negative, frame-level feature
    sequence.

    Args:
        channel: number of convolution output channels, i.e. the feature
            dimension reported by ``output_dim``.
        kernel_size: length of the convolution kernel, in samples.
        stride: hop between consecutive frames, in samples.
    """

    def __init__(
        self,
        channel: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        # One input channel (mono waveform) -> ``channel`` learned filters.
        self.conv1d = torch.nn.Conv1d(
            1, channel, kernel_size=kernel_size, stride=stride, bias=False
        )
        # Kept so forward() can derive per-utterance frame counts.
        self.stride = stride
        self.kernel_size = kernel_size

        self._output_dim = channel

    @property
    def output_dim(self) -> int:
        """Size of the last axis of the features returned by ``forward``."""
        return self._output_dim

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            feature (torch.Tensor): mixed feature after encoder
                [Batch, flens, channel]
            flens (torch.Tensor): number of feature frames per utterance
                [Batch]

        Raises:
            AssertionError: if ``input`` is not a 2-D tensor.
        """
        # Explicit raise rather than ``assert`` so the check is still
        # enforced under ``python -O``; the exception type is kept as
        # AssertionError so existing callers see the same error class.
        if input.dim() != 2:
            raise AssertionError("Currently only support single channel input")

        # [Batch, sample] -> [Batch, 1, sample]: Conv1d needs a channel axis.
        waveform = torch.unsqueeze(input, 1)
        feature = torch.nn.functional.relu(self.conv1d(waveform))
        # [Batch, channel, frames] -> [Batch, frames, channel]
        feature = feature.transpose(1, 2)

        # Frame count of an unpadded convolution over each true input length.
        flens = (ilens - self.kernel_size) // self.stride + 1
        return feature, flens