# Source: PyCIL_Stanford_Car/convs/conv_imagenet.py
# Author: HungNP, commit cb80c28 ("New single commit message")
'''
ConvNet backbones for ImageNet-sized inputs, used by the MEMO implementation.
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
import torch.nn as nn
import torch
# for imagenet
def first_block(in_channels, out_channels):
    """Build the ImageNet stem: 7x7 stride-2 conv -> BN -> ReLU -> 2x2 max-pool.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of feature maps produced by the stem.

    Returns:
        An ``nn.Sequential`` whose children are, in order, the convolution,
        batch norm, ReLU and max-pool. This order fixes the ``state_dict``
        keys (``0.*`` for the conv, ``1.*`` for the batch norm).
    """
    stem_layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        # Stride defaults to the kernel size, so this halves H and W again.
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*stem_layers)
def conv_block(in_channels, out_channels):
    """Build one stage: 3x3 same-padding conv -> BN -> ReLU -> 2x2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` with children conv, batch norm, ReLU, max-pool
        (in that order, which fixes the ``state_dict`` keys). The conv keeps
        the spatial size; the pool halves it.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(kernel_size=2))
class ConvNet(nn.Module):
    """Four-stage ConvNet backbone for ImageNet-sized inputs.

    Stages: ``first_block`` (x_dim -> hid_dim), two ``conv_block`` stages
    (hid_dim -> hid_dim), a final ``conv_block`` (hid_dim -> z_dim), then a
    7x7 average pool and flattening.

    Args:
        x_dim: number of input image channels.
        hid_dim: channel width of the first three stages.
        z_dim: channel width of the last stage, i.e. the feature size.

    Attributes:
        out_dim: width of the returned feature vectors (equal to ``z_dim``
            when the pooled map is 1x1, e.g. for 224x224 inputs).
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        # The four blocks downsample by 32x in total (stem 4x, then 2x per
        # conv_block). A 224x224 input therefore reaches this pool as a 7x7 map
        # and is reduced to 1x1. Other input sizes give a different output size.
        self.avgpool = nn.AvgPool2d(7)
        # Must follow z_dim: block4 emits z_dim channels, which become the
        # feature width after pooling to 1x1.
        self.out_dim = z_dim

    def forward(self, x):
        """Extract pooled features.

        Args:
            x: image batch, expected shape (N, x_dim, H, W).

        Returns:
            ``{"features": tensor}``, where the tensor has shape (N, out_dim)
            for 224x224 inputs.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.avgpool(x)
        # flatten(x, 1) keeps dim 0 as the batch. Unlike view(N, -1), it also
        # works for an empty batch.
        features = torch.flatten(x, 1)
        return {
            "features": features
        }
class GeneralizedConvNet(nn.Module):
    """Shared lower stages of the ImageNet ConvNet (first three blocks).

    Returns the raw feature map after ``block3``; no pooling is applied.

    Args:
        x_dim: number of input image channels.
        hid_dim: channel width of every stage.
        z_dim: accepted for signature parity with ``ConvNet``; unused here.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        # Attribute names (block1..block3) determine the state_dict keys.
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x):
        """Apply the three stages in order and return the hid_dim-channel map."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x
class SpecializedConvNet(nn.Module):
    """Final stage of the split ImageNet ConvNet: block4, pooling, flatten.

    Args:
        hid_dim: channels of the incoming feature map.
        z_dim: channels produced by ``block4``, i.e. the feature size.

    Attributes:
        feature_dim: width of the returned feature vectors (equal to
            ``z_dim`` when the pooled map is 1x1).
    """

    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        # block4 halves the spatial size. This pool is sized for a 7x7 map,
        # presumably from a 14x14 input (224x224 images after the shared
        # stages) -- confirm against the caller.
        self.avgpool = nn.AvgPool2d(7)
        # Must follow z_dim (block4's output channels), not a fixed 512.
        self.feature_dim = z_dim

    def forward(self, x):
        """Return flattened features of shape (N, C*H'*W') from a (N, hid_dim, H, W) map."""
        x = self.block4(x)
        x = self.avgpool(x)
        # flatten(x, 1) keeps the batch dim and, unlike view(N, -1), also
        # handles an empty batch.
        features = torch.flatten(x, 1)
        return features
def conv4():
    """Return a freshly initialised ``ConvNet`` with its default sizes."""
    return ConvNet()
def conv_a2fc_imagenet():
    """Build the split backbone pair used by MEMO.

    Returns:
        ``(base, adaptive_net)``: a ``GeneralizedConvNet`` followed by a
        ``SpecializedConvNet``, both with default sizes.
    """
    # Tuple elements are evaluated left to right, so the base is still
    # constructed (and randomly initialised) before the adaptive net.
    return GeneralizedConvNet(), SpecializedConvNet()