from torch import nn


class DepthWiseConv2d(nn.Module):
    """Depthwise-separable 2D convolution.

    A per-channel (depthwise) convolution with ``groups=dim_in`` is followed
    by a 1x1 (pointwise) convolution that maps ``dim_in`` channels to
    ``dim_out`` channels.

    Args:
        dim_in: Number of input channels. The depthwise stage keeps this
            channel count.
        dim_out: Number of output channels produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Padding applied by the depthwise convolution.
        stride: Stride of the depthwise convolution.
        bias: Whether both convolutions carry a learnable bias.
    """

    def __init__(
        self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True
    ):
        super().__init__()
        # Stage 1: spatial filtering with one kernel per input channel
        # (groups == in_channels makes the convolution depthwise).
        depthwise = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=dim_in,
            bias=bias,
        )
        # Stage 2: 1x1 projection that mixes channels, dim_in -> dim_out.
        pointwise = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=1,
            bias=bias,
        )
        # Held in an nn.Sequential under ``net`` so the parameter names
        # stay ``net.0.*`` / ``net.1.*`` in the state_dict.
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        """Run ``x`` through the depthwise, then the pointwise convolution."""
        return self.net(x)