Shanmuk4622 commited on
Commit
c2aaa6b
·
verified ·
1 Parent(s): c09c98b

Upload test1/mobilevit_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test1/mobilevit_model.py +190 -0
test1/mobilevit_model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+
5
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) convolution -> BatchNorm -> SiLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
    """
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),  # bias folded into BN
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)
11
+
12
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    """k x k convolution -> BatchNorm -> SiLU.

    Padding is derived as kernel_size // 2 so odd kernels preserve spatial
    size at stride 1. The original hard-coded padding=1, which is only
    "same" padding for kernel_size=3; for the default (and every call site
    in this file, all k=3) behavior is unchanged.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        kernel_size: square kernel size (odd values keep H/W at stride 1).
        stride: convolution stride.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    )
18
+
19
class PreNorm(nn.Module):
    """Wrap a module so its input is LayerNorm-ed first (pre-norm residual style)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
26
+
27
class FeedForward(nn.Module):
    """Two-layer MLP (dim -> hidden_dim -> dim) with SiLU and dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
39
+
40
class Attention(nn.Module):
    """Multi-head self-attention applied independently per patch group.

    Expects input of shape (b, p, n, dim) — see the rearrange patterns in
    forward(); attention runs over the `n` axis within each of the `p` groups.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head of width dim, the qkv projection is already
        # dim -> dim, so the output projection can be skipped.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads) for t in qkv)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        out = torch.matmul(weights, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)
66
+
67
class Transformer(nn.Module):
    """Stack of `depth` pre-norm attention + feed-forward blocks with residuals."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return x
81
+
82
class MV2Block(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> project.

    A residual connection is used only when stride == 1 and inp == oup.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            # No pointwise expansion: the depthwise conv reads the input
            # channels directly (hidden_dim == inp here).
            stages = [
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            stages = [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = out + x
        return out
114
+
115
class MobileViTBlock(nn.Module):
    """MobileViT block: local conv encoding, patch-wise transformer, fusion.

    The input is locally encoded (conv1), projected to `dim` channels (conv2),
    unfolded into ph x pw patches, processed by a transformer, folded back,
    projected back to `channel` (conv3), and fused with the untouched input
    through a final k x k conv (conv4) over the concatenated channels.
    """

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        # Local representation.
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        # Global representation: single head, dim_head fixed to 32.
        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

        # Fusion back to `channel` channels.
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        residual = x.clone()

        x = self.conv2(self.conv1(x))

        # Unfold to patches, attend across patches, fold back.
        # Assumes H is divisible by ph and W by pw — TODO confirm at call sites.
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        x = self.conv3(x)
        x = self.conv4(torch.cat((x, residual), 1))
        return x
143
+
144
class MobileViTv3_Small(nn.Module):
    """Small MobileViT-style classifier: conv stem, MV2 stages interleaved
    with MobileViT blocks, 1x1 head conv, global average pool, linear head.

    Attribute names and ModuleList layouts are kept exactly as-is so that
    state_dict keys match existing checkpoints.
    """

    def __init__(self, image_size=(224, 224), num_classes=10):
        super().__init__()
        ih, iw = image_size  # currently unused; kept for interface parity
        ph, pw = 2, 2

        dims = [144, 192, 240]
        channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]

        # Stem: stride-2 3x3 conv, 3 -> 16 channels.
        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        # Early MobileNetV2 stages.
        self.mv2 = nn.ModuleList([
            MV2Block(channels[0], channels[1], 1, 4),
            MV2Block(channels[1], channels[2], 2, 4),
            MV2Block(channels[2], channels[3], 1, 4),
            MV2Block(channels[3], channels[4], 2, 4),
        ])

        # Alternating downsample (MV2) / global (MobileViT) stages.
        self.mvit = nn.ModuleList([
            MobileViTBlock(dims[0], 2, channels[5], 3, (ph, pw), int(dims[0] * 2)),
        ])
        self.mv2_2 = nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, 4),
        ])
        self.mvit_2 = nn.ModuleList([
            MobileViTBlock(dims[1], 4, channels[7], 3, (ph, pw), int(dims[1] * 2)),
        ])
        self.mv2_3 = nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, 4),
        ])
        self.mvit_3 = nn.ModuleList([
            MobileViTBlock(dims[2], 3, channels[9], 3, (ph, pw), int(dims[2] * 2)),
        ])

        # Classification head.
        self.conv2 = conv_1x1_bn(channels[9], channels[10])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels[10], num_classes)

    def forward(self, x):
        x = self.conv1(x)
        for block in self.mv2:
            x = block(x)
        for block in self.mvit:
            x = block(x)
        for block in self.mv2_2:
            x = block(x)
        for block in self.mvit_2:
            x = block(x)
        for block in self.mv2_3:
            x = block(x)
        for block in self.mvit_3:
            x = block(x)
        x = self.conv2(x)
        x = self.pool(x).view(-1, x.shape[1])
        return self.fc(x)