zyingt's picture
Upload 685 files
0d80816
raw
history blame contribute delete
No virus
2.96 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
#
"""Conv2d Module with Valid Padding"""
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional
class Conv2dValid(_ConvNd):
    """2-D convolution whose padding is computed from the input at call time.

    The constructor accepts the same arguments as a regular 2-D convolution
    plus two switches, ``valid_trigx`` and ``valid_trigy``. During the forward
    pass the padding handed to ``F.conv2d`` is ``(pad_x, pad_y)`` where, for
    each of the last two input dimensions:

    * the value is ``0`` when the matching switch is ``False``;
    * otherwise it is
      ``(input.size(dim) * (stride - 1) - 1 + kernel_size) // 2``
      (``dim=-2`` for ``valid_trigx``, ``dim=-1`` for ``valid_trigy``).

    Note that ``padding`` and ``padding_mode`` are forwarded to the
    ``_ConvNd`` constructor but are not read by ``_conv_forward``; only the
    padding described above is applied.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
        valid_trigx: bool = False,
        valid_trigy: bool = False,
    ) -> None:
        """Build the convolution and remember the two padding switches.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size; a scalar is expanded to a pair.
            stride: Stride; a scalar is expanded to a pair.
            padding: Passed to the base class (a string is kept as is, a
                scalar is expanded to a pair). Not used by the forward pass.
            dilation: Dilation; a scalar is expanded to a pair.
            groups: Number of channel groups.
            bias: Whether the layer owns a bias term.
            padding_mode: Passed to the base class unchanged.
            device: Factory argument passed to the base class.
            dtype: Factory argument passed to the base class.
            valid_trigx: Enable input-dependent padding on dimension ``-2``.
            valid_trigy: Enable input-dependent padding on dimension ``-1``.
        """
        # String paddings go to the base class untouched; anything else is
        # normalised to a 2-tuple.
        if isinstance(padding, str):
            normalized_padding = padding
        else:
            normalized_padding = _pair(padding)

        super().__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            normalized_padding,
            _pair(dilation),
            False,  # transposed
            _pair(0),  # output_padding
            groups,
            bias,
            padding_mode,
            device=device,
            dtype=dtype,
        )

        # Assigned after the base initialiser has run, as in the original.
        self.valid_trigx = valid_trigx
        self.valid_trigy = valid_trigy

    def _conv_forward(
        self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """Run ``F.conv2d`` with padding chosen from the two switches.

        Args:
            input: Tensor to convolve; only its last two sizes are inspected.
            weight: Convolution kernel.
            bias: Optional bias term.

        Returns:
            The result of ``F.conv2d`` using this module's stride, dilation
            and groups together with the computed ``(pad_x, pad_y)``.
        """
        pad_x = 0
        pad_y = 0

        if self.valid_trigx:
            extent_x = (
                input.size(-2) * (self.stride[-2] - 1) - 1 + self.kernel_size[-2]
            )
            pad_x = extent_x // 2

        if self.valid_trigy:
            extent_y = (
                input.size(-1) * (self.stride[-1] - 1) - 1 + self.kernel_size[-1]
            )
            pad_y = extent_y // 2

        computed_padding = (pad_x, pad_y)
        return F.conv2d(
            input,
            weight,
            bias,
            self.stride,
            computed_padding,
            self.dilation,
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with this module's own weight and bias."""
        return self._conv_forward(input, self.weight, self.bias)