'''
For MEMO implementations of ImageNet-ConvNet
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
import torch
import torch.nn as nn


# for imagenet
def first_block(in_channels, out_channels):
    """Build the stem block: 7x7 stride-2 conv -> BatchNorm -> ReLU -> 2x2 max-pool.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order above.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


def conv_block(in_channels, out_channels):
    """Build a block: 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order above.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class ConvNet(nn.Module):
    """Four-block convolutional feature extractor followed by a 7x7 average pool.

    For a 224x224 input the spatial size goes 224 -> 56 (block1) -> 28 -> 14
    -> 7 (block4), so the 7x7 average pool leaves a 1x1 map and the flattened
    feature vector has ``z_dim`` entries.

    Args:
        x_dim: Number of input image channels.
        hid_dim: Channel width of the first three blocks.
        z_dim: Channel width of the last block.

    Attributes:
        out_dim: Channel count of the last block (equals ``z_dim``).
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        # Track the configured width rather than a literal 512 so that
        # non-default z_dim values are reported correctly.
        self.out_dim = z_dim

    def forward(self, x: torch.Tensor) -> dict:
        """Run all four blocks, average-pool, and flatten per sample.

        Returns:
            A dict with the single key ``"features"`` mapping to a 2-D tensor
            whose first dimension is the batch size.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return {"features": features}


class GeneralizedConvNet(nn.Module):
    """The first three blocks of :class:`ConvNet`, returning the feature map.

    Args:
        x_dim: Number of input image channels.
        hid_dim: Channel width of all three blocks.
        z_dim: Unused here; kept so the signature mirrors :class:`ConvNet`.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the un-pooled, un-flattened output of the third block."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


class SpecializedConvNet(nn.Module):
    """The last block of :class:`ConvNet` plus the 7x7 average pool.

    Consumes a feature map with ``hid_dim`` channels (such as the output of
    :class:`GeneralizedConvNet`) and returns flattened features.

    Args:
        hid_dim: Channel count of the incoming feature map.
        z_dim: Channel width of the block.

    Attributes:
        feature_dim: Channel count of the block (equals ``z_dim``).
    """

    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        self.avgpool = nn.AvgPool2d(7)
        # Same reasoning as ConvNet.out_dim: follow z_dim, not a literal 512.
        self.feature_dim = z_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block and pool, then flatten to one row per sample."""
        x = self.block4(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return features


def conv4():
    """Return a :class:`ConvNet` built with its default arguments."""
    model = ConvNet()
    return model


def conv_a2fc_imagenet():
    """Return a ``(GeneralizedConvNet, SpecializedConvNet)`` pair with default arguments."""
    _base = GeneralizedConvNet()
    _adaptive_net = SpecializedConvNet()
    return _base, _adaptive_net