File size: 2,275 Bytes
cb80c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
ConvNet backbones for ImageNet-sized inputs, as used by the MEMO implementation.
Reference:
https://github.com/wangkiw/ICLR23-MEMO/blob/main/convs/conv_imagenet.py
'''
import torch.nn as nn
import torch

# for imagenet
def first_block(in_channels, out_channels):
    """Build the stem block: 7x7/stride-2 conv -> BatchNorm -> ReLU -> 2x2 max-pool.

    Together the strided conv and the max-pool reduce each spatial side by
    roughly a factor of 4.

    Args:
        in_channels: number of channels in the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)
    
def conv_block(in_channels, out_channels):
    """Build a standard block: 3x3 conv (same padding) -> BatchNorm -> ReLU -> 2x2 max-pool.

    The convolution keeps the spatial size; the max-pool halves it.

    Args:
        in_channels: number of channels in the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(2))

class ConvNet(nn.Module):
    """Four-block ConvNet backbone (stem + three conv blocks) with average pooling.

    Blocks 1-3 output ``hid_dim`` channels and block 4 outputs ``z_dim``
    channels. The stride-2 stem conv plus one 2x2 max-pool per block
    downsample each spatial side by 32 in total.

    Args:
        x_dim: number of input image channels.
        hid_dim: channel width of blocks 1-3.
        z_dim: channel width of block 4, i.e. the per-position feature width.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)
        self.block4 = conv_block(hid_dim, z_dim)
        # Fixed 7x7 window. Assuming 224x224 inputs (ImageNet convention --
        # confirm with callers), the /32 downsampling leaves a 7x7 map, so
        # this pools it to 1x1.
        self.avgpool = nn.AvgPool2d(7)
        # Width of the flattened feature vector. It follows z_dim, the channel
        # count of block4 (previously hard-coded to 512, which was wrong for
        # any non-default z_dim). Exact only when the pooled map is 1x1.
        self.out_dim = z_dim

    def forward(self, x):
        """Run the backbone and return pooled, flattened features.

        Args:
            x: image batch with batch size on dim 0, presumably
               (N, x_dim, 224, 224) -- TODO confirm with callers.

        Returns:
            A dict ``{"features": tensor of shape (N, -1)}``.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)

        return {
            "features": features
        }

class GeneralizedConvNet(nn.Module):
    """Shared lower part of the backbone: the stem plus two conv blocks (blocks 1-3).

    This mirrors the first three blocks of ``ConvNet`` and returns the raw
    feature map without pooling.

    Args:
        x_dim: number of input image channels.
        hid_dim: channel width of all three blocks.
        z_dim: accepted for signature parity with ``ConvNet``; not used here.
    """

    def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
        super().__init__()
        # Attribute names are kept as block1..block3 so state_dict keys are stable.
        self.block1 = first_block(x_dim, hid_dim)
        self.block2 = conv_block(hid_dim, hid_dim)
        self.block3 = conv_block(hid_dim, hid_dim)

    def forward(self, x):
        """Apply blocks 1-3 in order and return the resulting feature map."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x

class SpecializedConvNet(nn.Module):
    """Specialized upper part of the backbone: block 4, average pooling and flattening.

    It takes the ``hid_dim``-channel map produced by ``GeneralizedConvNet``
    and returns flattened features.

    Args:
        hid_dim: number of input channels (output width of the shared blocks).
        z_dim: output channel width of block 4, i.e. the per-position feature width.
    """

    def __init__(self, hid_dim=128, z_dim=512):
        super().__init__()
        self.block4 = conv_block(hid_dim, z_dim)
        # Fixed 7x7 window. Presumably the incoming map is 14x14 (224x224
        # input after blocks 1-3) and is halved to 7x7 by block4, giving a
        # 1x1 output -- TODO confirm.
        self.avgpool = nn.AvgPool2d(7)
        # Width of the returned feature vector. It follows z_dim, the channel
        # count of block4 (previously hard-coded to 512, which was wrong for
        # any non-default z_dim). Exact only when the pooled map is 1x1.
        self.feature_dim = z_dim

    def forward(self, x):
        """Apply block 4, pool, and flatten to shape (N, -1).

        Args:
            x: feature map with batch size on dim 0 and ``hid_dim`` channels.

        Returns:
            A tensor of flattened features, one row per batch element.
        """
        x = self.block4(x)
        x = self.avgpool(x)
        features = x.view(x.shape[0], -1)
        return features
    
def conv4():
    """Return a ``ConvNet`` built with its default hyper-parameters."""
    return ConvNet()

def conv_a2fc_imagenet():
    """Return the split backbone as a pair ``(shared_base, specialized_head)``.

    Both parts use their default hyper-parameters. The shared base is built
    first, then the specialized head.
    """
    return GeneralizedConvNet(), SpecializedConvNet()