"""
CIFAR ConvNet backbones for MEMO implementations.

Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


# for cifar
def conv_block(in_channels, out_channels):
    """Build one conv stage: 3x3 conv -> batch norm -> ReLU -> 2x2 max pool.

    The convolution uses ``padding=1`` with a 3x3 kernel, and the trailing
    ``MaxPool2d(2)`` halves the spatial resolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in order.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class ConvNet2(nn.Module):
    """Two-stage convolutional encoder followed by 8x8 average pooling.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width after the first conv stage.
        z_dim: Channel width after the second conv stage.

    Attributes:
        out_dim: Channel width of the encoder output (``z_dim``).
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        # Track the configured width instead of a hard-coded 64 so the
        # attribute stays correct when a non-default z_dim is requested.
        self.out_dim = z_dim
        # NOTE(review): a kernel of 8 presumably targets 32x32 inputs, which
        # the two 2x max-pools reduce to 8x8 -- confirm against callers.
        self.avgpool = nn.AvgPool2d(8)
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return ``{"features": tensor}``.

        The pooled map is flattened to one row per sample (first dimension
        preserved, everything else collapsed).
        """
        x = self.encoder(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return {
            "features": features
        }


class GeneralizedConvNet2(nn.Module):
    """Single conv stage that produces the shared base feature map.

    Args:
        x_dim: Number of input channels.
        hid_dim: Channel width produced by the conv stage.
        z_dim: Accepted for signature parity with ``ConvNet2``; not used here.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
        )

    def forward(self, x):
        """Return the un-pooled, un-flattened output of the conv stage."""
        base_features = self.encoder(x)
        return base_features


class SpecializedConvNet2(nn.Module):
    """Single conv stage plus 8x8 average pooling, applied to base features.

    Args:
        hid_dim: Channel width of the incoming base feature map.
        z_dim: Channel width produced by the adaptive conv stage.

    Attributes:
        feature_dim: Channel width of the adaptive block output (``z_dim``).
    """

    def __init__(self, hid_dim=64, z_dim=64):
        super().__init__()
        # Follow z_dim rather than a hard-coded 64 so the advertised width
        # matches the block that is actually built.
        self.feature_dim = z_dim
        # NOTE(review): kernel 8 presumably matches an 8x8 map after this
        # block's max-pool (i.e. a 16x16 input map) -- confirm with callers.
        self.avgpool = nn.AvgPool2d(8)
        self.AdaptiveBlock = conv_block(hid_dim, z_dim)

    def forward(self, x):
        """Run the adaptive block, average-pool, and flatten per sample."""
        base_features = self.AdaptiveBlock(x)
        pooled = self.avgpool(base_features)
        features = pooled.view(pooled.size(0), -1)
        return features


def conv2():
    """Return a ``ConvNet2`` built with its default arguments."""
    return ConvNet2()


def get_conv_a2fc():
    """Return a ``(GeneralizedConvNet2, SpecializedConvNet2)`` pair.

    Both networks are built with their default arguments.
    """
    basenet = GeneralizedConvNet2()
    adaptivenet = SpecializedConvNet2()
    return basenet, adaptivenet


if __name__ == '__main__':
    # Print parameter counts of the split (base + adaptive) and monolithic nets.
    a, b = get_conv_a2fc()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"conv :{_base+_adap}")

    # Use a distinct local name so the conv2() factory is not rebound.
    conv2_model = conv2()
    conv2_sum = sum(p.numel() for p in conv2_model.parameters())
    print(f"conv2 :{conv2_sum}")