#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""ConvolutionModule definition."""
from torch import nn
class ConvolutionModule(nn.Module):
    """Convolution module of the Conformer model.

    The layers are applied in this order: a pointwise convolution that
    doubles the channel count, a GLU gate that halves it again, a depthwise
    convolution along the time axis, batch normalization, the activation,
    and a final pointwise convolution.

    :param int channels: number of input (and output) channels
    :param int kernel_size: kernel size of the depthwise convolution;
        must be odd so that 'SAME' padding keeps the time length unchanged
    :param torch.nn.Module activation: activation applied after batch norm
    :param bool bias: whether the three convolution layers use a bias term
    """

    # NOTE: the default ``nn.ReLU()`` is evaluated once, when the class is
    # defined, so every module built with the default shares that instance.
    # This is harmless for a parameter-free activation; pass a fresh module
    # explicitly when using an activation that carries parameters or state.
    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size must be an odd number: with an even kernel the
        # symmetric padding below would shorten the time axis by one frame.
        assert (kernel_size - 1) % 2 == 0, (
            "kernel_size must be odd for 'SAME' padding, got {}".format(
                kernel_size
            )
        )

        # Expands to 2 * channels so that GLU can split into value and gate.
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # groups=channels makes this a depthwise (per-channel) convolution.
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        :param torch.Tensor x: input of shape (batch, time, channels)
        :return torch.Tensor: convolved output of shape (batch, time, channels)
        """
        # Conv1d expects channels before time: (batch, channels, time).
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2 * channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D depthwise convolution along time, then norm and activation
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        # Back to the caller's layout: (batch, time, channels).
        return x.transpose(1, 2)