JustinLin610
update
8437114
raw history blame
No virus
1.68 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn.modules.utils import _single
from torch import Tensor
class ConvTBC(torch.nn.Module):
    """1D convolution over an input of shape (time x batch x channel).

    The implementation uses gemm to perform the convolution. This
    implementation is faster than cuDNN for small kernel sizes.

    The actual computation is delegated to ``torch.conv_tbc``; this module
    only owns the learnable ``weight`` and ``bias`` and the padding amount.

    Args:
        in_channels: Size of the channel dimension of the input.
        out_channels: Size of the channel dimension of the output.
        kernel_size: Width of the kernel along the time dimension. Stored
            as a 1-tuple after normalisation with ``_single``.
        padding: Padding amount forwarded to ``torch.conv_tbc``. Stored as
            a 1-tuple after normalisation with ``_single``.

    Attributes:
        weight: Parameter of shape ``(kernel_size, in_channels, out_channels)``.
        bias: Parameter of shape ``(out_channels,)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalised to 1-tuples; element 0 is the value actually used.
        self.kernel_size = _single(kernel_size)
        self.padding = _single(padding)

        # torch.empty allocates uninitialised storage; real values are
        # written by reset_parameters() immediately below.
        self.weight = torch.nn.Parameter(
            torch.empty(self.kernel_size[0], in_channels, out_channels)
        )
        self.bias = torch.nn.Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` (Xavier normal) and ``bias`` (zeros) in place."""
        nn.init.xavier_normal_(self.weight)
        nn.init.zeros_(self.bias)

    def conv_tbc(self, input: Tensor) -> Tensor:
        """Apply ``torch.conv_tbc`` to ``input`` with this module's parameters.

        ``input`` is made contiguous before the call, so non-contiguous
        views (e.g. the result of a transpose) are accepted.
        """
        return torch.conv_tbc(
            input.contiguous(), self.weight, self.bias, self.padding[0]
        )

    def forward(self, input: Tensor) -> Tensor:
        """Run the convolution; equivalent to calling :meth:`conv_tbc`."""
        return self.conv_tbc(input)

    def __repr__(self) -> str:
        """Return e.g. ``ConvTBC(4, 8, kernel_size=(3,), padding=(1,))``."""
        parts = [
            f"{type(self).__name__}({self.in_channels}, {self.out_channels}",
            f", kernel_size={self.kernel_size}",
            f", padding={self.padding}",
        ]
        # __init__ always creates a bias; this only fires if it was later
        # replaced with None from outside.
        if self.bias is None:
            parts.append(", bias=False")
        parts.append(")")
        return "".join(parts)