from torch import nn
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The module stacks two ``nn.Conv2d`` layers in ``self.net``:

    1. A ``dim_in -> dim_in`` convolution with ``groups=dim_in``, so each
       input channel is filtered by its own kernel.
    2. A ``kernel_size=1`` convolution mapping ``dim_in -> dim_out``, which
       mixes the channels produced by the first layer.

    Args:
        dim_in: Number of channels in the input.
        dim_out: Number of channels in the output.
        kernel_size: Kernel size of the first (depthwise) convolution;
            forwarded unchanged to ``nn.Conv2d``.
        padding: Padding of the first convolution; forwarded unchanged to
            ``nn.Conv2d``. The 1x1 convolution uses the ``nn.Conv2d`` default.
        stride: Stride of the first convolution; forwarded unchanged to
            ``nn.Conv2d``. The 1x1 convolution uses the ``nn.Conv2d`` default.
        bias: Whether both convolutions get a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        self.net = nn.Sequential(
            # groups == in_channels: one independent filter per input channel.
            nn.Conv2d(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                groups=dim_in,
                stride=stride,
                bias=bias,
            ),
            # 1x1 convolution: combines channels and changes the channel count.
            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias),
        )

    def forward(self, x):
        """Apply the depthwise and then the pointwise convolution to ``x``.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``dim_in`` channels.

        Returns:
            The output of the two stacked convolutions, with ``dim_out``
            channels.
        """
        return self.net(x)