'''
MEMO implementation of the CIFAR ConvNet backbone.
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_cifar.py
'''
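# Note (added for clarity, following the reference implementation above): the backbone is
# provided in two forms. ConvNet2 is the monolithic network, while GeneralizedConvNet2
# (the base block, intended to be shared) and SpecializedConvNet2 (the adaptive block,
# intended to be expanded per task in MEMO) together form the split counterpart.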
import torch
import torch.nn as nn
import torch.nn.functional as F


# Basic CIFAR conv block: 3x3 conv -> BatchNorm -> ReLU -> 2x2 max-pool (halves spatial size).
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
class ConvNet2(nn.Module):
    """Monolithic two-block ConvNet: maps a 32x32 image to a 64-d feature vector."""

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.out_dim = 64
        self.avgpool = nn.AvgPool2d(8)
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        x = self.encoder(x)                   # (B, z_dim, 8, 8) for 32x32 inputs
        x = self.avgpool(x)                   # (B, z_dim, 1, 1)
        features = x.view(x.shape[0], -1)     # (B, z_dim)
        return {
            "features": features
        }
class GeneralizedConvNet2(nn.Module):
    """Generalized (base) part of the split backbone: the first conv block only."""

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
        )

    def forward(self, x):
        base_features = self.encoder(x)       # (B, hid_dim, 16, 16) for 32x32 inputs
        return base_features
class SpecializedConvNet2(nn.Module):
    """Specialized (adaptive) part of the split backbone: the second conv block plus pooling."""

    def __init__(self, hid_dim=64, z_dim=64):
        super().__init__()
        self.feature_dim = 64
        self.avgpool = nn.AvgPool2d(8)
        self.AdaptiveBlock = conv_block(hid_dim, z_dim)

    def forward(self, x):
        base_features = self.AdaptiveBlock(x)       # (B, z_dim, 8, 8)
        pooled = self.avgpool(base_features)        # (B, z_dim, 1, 1)
        features = pooled.view(pooled.size(0), -1)  # (B, z_dim)
        return features
def conv2():
    """Factory for the monolithic ConvNet2."""
    return ConvNet2()


def get_conv_a2fc():
    """Factory for the split backbone: returns (generalized base net, specialized adaptive net)."""
    basenet = GeneralizedConvNet2()
    adaptivenet = SpecializedConvNet2()
    return basenet, adaptivenet
if __name__ == '__main__':
    # Compare parameter counts of the split backbone and the monolithic one.
    a, b = get_conv_a2fc()
    _base = sum(p.numel() for p in a.parameters())
    _adap = sum(p.numel() for p in b.parameters())
    print(f"conv :{_base + _adap}")

    model = conv2()   # renamed from `conv2` to avoid shadowing the factory function
    conv2_sum = sum(p.numel() for p in model.parameters())
    print(f"conv2 :{conv2_sum}")