import torch import torch.nn as nn class DepthWiseSeperableConv(nn.Module): def __init__(self, in_dim, out_dim, *args, **kwargs): super().__init__() if 'groups' in kwargs: # ignoring groups for Depthwise Sep Conv del kwargs['groups'] self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) def forward(self, x): out = self.depthwise(x) out = self.pointwise(out) return out