Image Classification
timm
PDE
ConvNet
liuyao committed
Commit 793ef19
1 Parent(s): 020c806

Upload 2 files

Files changed (2)
  1. QLNet_symmetry.ipynb +551 -0
  2. qlnet.py +385 -0
QLNet_symmetry.ipynb ADDED
@@ -0,0 +1,551 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "71b6152c",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch, timm\n",
+     "from qlnet import QLNet"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "4e7ed219",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "m = QLNet()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "3f703be8",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "state_dict = torch.load('qlnet-50-v0.pth.tar')['state_dict']"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "435e2358",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "<All keys matched successfully>"
+       ]
+      },
+      "execution_count": 4,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "m.load_state_dict(state_dict)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "f14d984a",
+    "metadata": {
+     "scrolled": true
+    },
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "QLNet(\n",
+        "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
+        "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "  (act1): ReLU(inplace=True)\n",
+        "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
+        "  (layer1): Sequential(\n",
+        "    (0): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
+        "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (1): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
+        "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (2): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
+        "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "  )\n",
+        "  (layer2): Sequential(\n",
+        "    (0): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
+        "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): ConvBN(\n",
+        "        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+        "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (1): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (2): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (3): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "  )\n",
+        "  (layer3): Sequential(\n",
+        "    (0): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): ConvBN(\n",
+        "        (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (1): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (2): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (3): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (4): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (5): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "  )\n",
+        "  (layer4): Sequential(\n",
+        "    (0): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
+        "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): ConvBN(\n",
+        "        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (1): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
+        "      (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "    (2): QLBlock(\n",
+        "      (conv1): ConvBN(\n",
+        "        (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
+        "      (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      (conv3): ConvBN(\n",
+        "        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+        "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+        "      )\n",
+        "      (skip): Identity()\n",
+        "      (act3): hardball()\n",
+        "    )\n",
+        "  )\n",
+        "  (act): hardball()\n",
+        "  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
+        "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
+        ")"
+       ]
+      },
+      "execution_count": 5,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "m.eval()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "id": "2099b937",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "layer1 >>\n",
+       "torch.Size([512, 64, 1, 1])\n",
+       "torch.Size([64, 512, 1, 1])\n",
+       "torch.Size([512, 64, 1, 1])\n",
+       "torch.Size([64, 512, 1, 1])\n",
+       "torch.Size([512, 64, 1, 1])\n",
+       "torch.Size([64, 512, 1, 1])\n",
+       "layer2 >>\n",
+       "torch.Size([512, 64, 1, 1])\n",
+       "torch.Size([128, 512, 1, 1])\n",
+       "torch.Size([128, 64, 1, 1])\n",
+       "torch.Size([1024, 128, 1, 1])\n",
+       "torch.Size([128, 1024, 1, 1])\n",
+       "torch.Size([1024, 128, 1, 1])\n",
+       "torch.Size([128, 1024, 1, 1])\n",
+       "torch.Size([1024, 128, 1, 1])\n",
+       "torch.Size([128, 1024, 1, 1])\n",
+       "layer3 >>\n",
+       "torch.Size([1024, 128, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "torch.Size([256, 128, 1, 1])\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([256, 1024, 1, 1])\n",
+       "layer4 >>\n",
+       "torch.Size([1024, 256, 1, 1])\n",
+       "torch.Size([512, 1024, 1, 1])\n",
+       "torch.Size([512, 256, 1, 1])\n",
+       "torch.Size([2048, 512, 1, 1])\n",
+       "torch.Size([512, 2048, 1, 1])\n",
+       "torch.Size([2048, 512, 1, 1])\n",
+       "torch.Size([512, 2048, 1, 1])\n"
+      ]
+     }
+    ],
+    "source": [
+     "# fuse ConvBN\n",
+     "i = 1\n",
+     "for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:\n",
+     "    print(f'layer{i} >>')\n",
+     "    for block in layer:\n",
+     "        # Fuse the weights in conv1 and conv3\n",
+     "        block.conv1.fuse_bn()\n",
+     "        print(block.conv1.fused_weight.size())\n",
+     "        block.conv3.fuse_bn()\n",
+     "        print(block.conv3.fused_weight.size())\n",
+     "        if not isinstance(block.skip, torch.nn.Identity):\n",
+     "            block.skip.fuse_bn()\n",
+     "            print(block.skip.fused_weight.size())\n",
+     "    i += 1"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "b3a55f82",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "inpt = torch.randn(5, 3, 224, 224)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "dccbf19c",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "out_old = m(inpt)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "id": "f0c74a04",
+    "metadata": {
+     "scrolled": true
+    },
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "torch.Size([5, 1000])"
+       ]
+      },
+      "execution_count": 10,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "out_old.size()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "a5991c8f",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def apply_transform(block1, block2, Q, keep_identity=True):\n",
+     "    with torch.no_grad():\n",
+     "        # Ensure that the out_channels of block1 is equal to the in_channels of block2\n",
+     "        assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
+     "        n = Q.size()[0]\n",
+     "        assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
+     "\n",
+     "        # Calculate the inverse of Q\n",
+     "        Q_inv = torch.inverse(Q)\n",
+     "\n",
+     "        # Modify the weights of conv layers in block1\n",
+     "        block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
+     "        block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
+     "\n",
+     "        if isinstance(block1.skip, torch.nn.Identity):\n",
+     "            if not keep_identity:\n",
+     "                block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
+     "                block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
+     "        else:\n",
+     "            block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
+     "            block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
+     "\n",
+     "        # Modify the weights of conv layers in block2\n",
+     "        block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
+     "\n",
+     "        if isinstance(block2.skip, torch.nn.Identity):\n",
+     "            if not keep_identity:\n",
+     "                block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
+     "                block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
+     "        else:\n",
+     "            block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 12,
+    "id": "dd96acd7",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
+     "for i in range(5):\n",
+     "    apply_transform(m.layer3[i], m.layer3[i+1], Q, True)\n",
+     "apply_transform(m.layer3[5], m.layer4[0], Q, True)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 13,
+    "id": "e5d3628d",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "6.0558319091796875e-05\n"
+      ]
+     }
+    ],
+    "source": [
+     "out_new = m(inpt)\n",
+     "print((out_new - out_old).abs().max().item())"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "9fce3a38",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "5a54fe8b",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.6"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
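
Note: the symmetry the notebook demonstrates is that QLNet's blocks admit a continuous change of basis. Because hardball rescales a feature map by a function of its channel norm alone, any orthogonal matrix Q commutes with it, so folding Q into the output side of one block and Q^-1 into the input side of the next leaves the network function unchanged. Below is a condensed, self-contained sketch of that check; it uses a randomly initialized model so the qlnet-50-v0.pth.tar checkpoint is not needed, and transform_interface is a helper name introduced here, mirroring the notebook's apply_transform with keep_identity=True.

import torch
from qlnet import QLNet

m = QLNet().eval()  # random weights suffice for the invariance check

# Fuse every ConvBN so the 1x1 convs expose explicit fused_weight / fused_bias tensors.
for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:
    for block in layer:
        block.conv1.fuse_bn()
        block.conv3.fuse_bn()
        if not isinstance(block.skip, torch.nn.Identity):
            block.skip.fuse_bn()

def transform_interface(block1, block2, Q):
    # Rotate block1's output channels by Q and fold Q^-1 into block2's input channels.
    Q_inv = torch.inverse(Q)
    block1.conv3.fused_weight = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight)
    block1.conv3.fused_bias = Q @ block1.conv3.fused_bias
    if not isinstance(block1.skip, torch.nn.Identity):
        block1.skip.fused_weight = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight)
        block1.skip.fused_bias = Q @ block1.skip.fused_bias
    block2.conv1.fused_weight = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight)
    if not isinstance(block2.skip, torch.nn.Identity):
        block2.skip.fused_weight = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight)

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out_old = m(x)

    # hardball(Qx) = Q hardball(x) for orthogonal Q, and the Identity skips stay
    # consistent because the same Q is used at every interface of the stage.
    Q = torch.nn.init.orthogonal_(torch.empty(256, 256))
    for i in range(5):
        transform_interface(m.layer3[i], m.layer3[i + 1], Q)
    transform_interface(m.layer3[5], m.layer4[0], Q)

    # difference should be at float-noise level (the notebook reports ~6e-5)
    print((m(x) - out_old).abs().max().item())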
qlnet.py ADDED
@@ -0,0 +1,385 @@
+ """PyTorch ResNet
+
+ This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
+ additional dropout and dynamic global avg/max pool.
+
+ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
+
+ Copyright 2019, Ross Wightman
+ """
+ import math
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
+     get_act_layer, get_norm_layer, create_classifier, LayerNorm2d
+ from timm.models._manipulate import checkpoint_seq
+ from timm.models._registry import model_entrypoint
+
+
+ def get_padding(kernel_size, stride, dilation=1):
+     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+     return padding
+
+
+ class softball(nn.Module):
+     def __init__(self, radius2=None, inplace=True):
+         super(softball, self).__init__()
+         self.radius2 = radius2  # defaults to the channel count of the first input
+
+     def forward(self, x):
+         if self.radius2 is None:
+             self.radius2 = x.size()[1]
+         norm = torch.sqrt(1 + (x*x).sum(1, keepdim=True) / self.radius2)
+         return x / norm
+
+
+ class hardball(nn.Module):
+     def __init__(self, radius2=None):
+         super(hardball, self).__init__()
+         self.radius = np.sqrt(radius2) if radius2 is not None else None
+
+     def forward(self, x):
+         norm = torch.sqrt((x*x).sum(1, keepdim=True))
+         if self.radius is None:
+             self.radius = np.sqrt(x.size()[1])
+         # project onto the ball of radius sqrt(radius2); depends on x only through its norm
+         return torch.where(norm > self.radius, self.radius * x / norm, x)
+
+
+ class ConvBN(nn.Module):
+     def __init__(self, conv, bn):
+         super(ConvBN, self).__init__()
+         self.conv = conv
+         self.bn = bn
+         self.fused_weight = None
+         self.fused_bias = None
+
+     def forward(self, x):
+         if self.training:
+             x = self.conv(x)
+             x = self.bn(x)
+         else:
+             if self.fused_weight is not None and self.fused_bias is not None:
+                 x = F.conv2d(x, self.fused_weight, self.fused_bias,
+                              self.conv.stride, self.conv.padding,
+                              self.conv.dilation, self.conv.groups)
+             else:
+                 x = self.conv(x)
+                 x = self.bn(x)
+         return x
+
+     def fuse_bn(self):
+         if self.training:
+             raise RuntimeError("Call fuse_bn only in eval mode")
+
+         # Calculate the fused weight and bias
+         w = self.conv.weight
+         mean = self.bn.running_mean
+         std = torch.sqrt(self.bn.running_var + self.bn.eps)
+         gamma = self.bn.weight
+         beta = self.bn.bias
+
+         self.fused_weight = w * (gamma / std).reshape(-1, 1, 1, 1)
+         self.fused_bias = beta - (gamma * mean / std)
+
+
+ class QLBlock(nn.Module):  # quasilinear hyperbolic system
+     expansion = 1
+
+     def __init__(
+             self,
+             inplanes,
+             planes,
+             stride=1,
+             downsample=None,
+             cardinality=1,
+             base_width=64,
+             reduce_first=1,
+             dilation=1,
+             first_dilation=None,
+             act_layer=nn.ReLU,
+             norm_layer=nn.BatchNorm2d,
+     ):
+         super(QLBlock, self).__init__()
+
+         k = 4 if inplanes <= 128 else 2
+         width = inplanes * k
+         outplanes = inplanes if downsample is None else inplanes * 2
+         first_dilation = first_dilation or dilation
+
+         self.conv1 = ConvBN(
+             nn.Conv2d(inplanes, width*2, kernel_size=1, stride=1,
+                       dilation=first_dilation, groups=1, bias=False),
+             norm_layer(width*2))
+
+         self.conv2 = nn.Conv2d(width, width*2, kernel_size=3, stride=stride,
+                                padding=1, dilation=first_dilation, groups=width, bias=False)
+         self.bn2 = norm_layer(width*2)
+
+         self.conv3 = ConvBN(
+             nn.Conv2d(width*2, outplanes, kernel_size=1, groups=1, bias=False),
+             norm_layer(outplanes))
+
+         self.skip = ConvBN(
+             nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride,
+                       dilation=first_dilation, groups=1, bias=False),
+             norm_layer(outplanes)) if downsample is not None else nn.Identity()
+
+         self.act3 = hardball(radius2=outplanes)  # if downsample is not None else None
+
+     def zero_init_last(self):
+         if getattr(self.conv3.bn, 'weight', None) is not None:
+             nn.init.zeros_(self.conv3.bn.weight)
+
+     def conv_forward(self, x):
+         conv = self.conv2
+         k = conv.in_channels
+         C = x.size()[1] // k
+         kernel = conv.weight.repeat(C, 1, 1, 1)
+         bias = conv.bias.repeat(C) if conv.bias is not None else None
+         return F.conv2d(x, kernel, bias, conv.stride,
+                         conv.padding, conv.dilation, C * k)
+
+     def forward(self, x):
+         x0 = self.skip(x)
+         x = self.conv1(x)
+         C = x.size(1) // 2
+         x = x[:, :C, :, :] * x[:, C:, :, :]  # quadratic (quasilinear) interaction
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.conv3(x)
+         x += x0
+         if self.act3 is not None:
+             x = self.act3(x)
+         return x
+
+
+ def make_blocks(
+         block_fn,
+         channels,
+         block_repeats,
+         inplanes,
+         reduce_first=1,
+         output_stride=32,
+         down_kernel_size=1,
+         avg_down=False,
+         **kwargs,
+ ):
+     stages = []
+     feature_info = []
+     net_num_blocks = sum(block_repeats)
+     net_block_idx = 0
+     net_stride = 4
+     dilation = prev_dilation = 1
+     for stage_idx, (planes, num_blocks) in enumerate(zip(channels, block_repeats)):
+         stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
+         stride = 1 if stage_idx == 0 else 2
+         if net_stride >= output_stride:
+             dilation *= stride
+             stride = 1
+         else:
+             net_stride *= stride
+
+         downsample = None
+         if stride != 1 or inplanes != planes * block_fn.expansion:
+             downsample = True
+
+         block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
+         blocks = []
+         for block_idx in range(num_blocks):
+             downsample = downsample if block_idx == 0 else None
+             stride = stride if block_idx == 0 else 1
+             blocks.append(block_fn(
+                 inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+                 **block_kwargs))
+             prev_dilation = dilation
+             inplanes = planes * block_fn.expansion
+             net_block_idx += 1
+
+         stages.append((stage_name, nn.Sequential(*blocks)))
+         feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+     return stages, feature_info
+
+
+ class QLNet(nn.Module):
+     # based on timm code for ResNet / ResNeXt / SE-ResNeXt / SE-Net
+
+     def __init__(
+             self,
+             block=QLBlock,  # new block
+             layers=[3, 4, 6, 3],  # as in resnet50
+             num_classes=1000,
+             in_chans=3,
+             output_stride=32,
+             global_pool='avg',
+             cardinality=1,
+             base_width=64,
+             stem_width=64,
+             stem_type='',
+             replace_stem_pool=False,
+             block_reduce_first=1,
+             down_kernel_size=1,
+             avg_down=False,
+             act_layer=nn.ReLU,
+             norm_layer=nn.BatchNorm2d,
+             zero_init_last=True,
+             block_args=None,
+     ):
+         """
+         Args:
+             block (nn.Module): class for the residual block (default: QLBlock)
+             layers (List[int]): number of blocks in each stage
+             num_classes (int): number of classification classes (default 1000)
+             in_chans (int): number of input (color) channels. (default 3)
+             output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+             global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+             cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
+             base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
+             stem_width (int): number of channels in stem convolutions (default 64)
+             stem_type (str): The type of stem (default ''):
+                 * '', default - a single 7x7 conv with a width of stem_width
+                 * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
+                 * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
+             replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
+             block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
+                 1 for all archs except senets, where 2 (default 1)
+             down_kernel_size (int): kernel size of residual block downsample path,
+                 1x1 for most, 3x3 for senets (default: 1)
+             avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
+             act_layer (str, nn.Module): activation layer
+             norm_layer (str, nn.Module): normalization layer
+             zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
+             block_args (dict): Extra kwargs to pass through to block module
+         """
+         super(QLNet, self).__init__()
+         block_args = block_args or dict()
+         assert output_stride in (8, 16, 32)
+         self.num_classes = num_classes
+         self.grad_checkpointing = False
+
+         act_layer = get_act_layer(act_layer)
+         norm_layer = get_norm_layer(norm_layer)
+
+         # Stem
+         deep_stem = 'deep' in stem_type
+         inplanes = stem_width * 2 if deep_stem else 64
+         if deep_stem:
+             stem_chs = (stem_width, stem_width)
+             if 'tiered' in stem_type:
+                 stem_chs = (3 * (stem_width // 4), stem_width)
+             self.conv1 = nn.Sequential(*[
+                 nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
+                 norm_layer(stem_chs[0]),
+                 act_layer(inplace=True),
+                 nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
+                 norm_layer(stem_chs[1]),
+                 act_layer(inplace=True),
+                 nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
+         else:
+             self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = norm_layer(inplanes)
+         self.act1 = act_layer(inplace=True)
+         self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
+
+         # Stem pooling. The name 'maxpool' remains for weight compatibility.
+         if replace_stem_pool:
+             self.maxpool = nn.Sequential(*filter(None, [
+                 nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
+                 norm_layer(inplanes),
+                 act_layer(inplace=True)
+             ]))
+         else:
+             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         # Feature Blocks
+         channels = [64, 128, 256, 512]
+         stage_modules, stage_feature_info = make_blocks(
+             block,
+             channels,
+             layers,
+             inplanes,
+             cardinality=cardinality,
+             base_width=base_width,
+             output_stride=output_stride,
+             reduce_first=block_reduce_first,
+             avg_down=avg_down,
+             down_kernel_size=down_kernel_size,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+             **block_args,
+         )
+         for stage in stage_modules:
+             self.add_module(*stage)  # layer1, layer2, etc
+         self.feature_info.extend(stage_feature_info)
+
+         self.act = hardball(radius2=512)
+         # self.act = nn.Hardtanh(max_val=5, min_val=-5, inplace=True)
+         # self.act = nn.ReLU(inplace=True)
+
+         # Head (Pooling and Classifier)
+         self.num_features = 512 * block.expansion
+         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+         self.init_weights(zero_init_last=zero_init_last)
+
+     @staticmethod
+     def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'QLNet':
+         # retained from the timm ResNet interface; QLNet itself is not registered with timm
+         entry_fn = model_entrypoint(model_name, 'resnet')
+         return entry_fn(pretrained=load_weights, **kwargs)
+
+     @torch.jit.ignore
+     def init_weights(self, zero_init_last=True):
+         for n, m in self.named_modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear')  # 'linear' for non-relu activations
+                 # nn.init.xavier_normal_(m.weight)
+         if zero_init_last:
+             for m in self.modules():
+                 if hasattr(m, 'zero_init_last'):
+                     m.zero_init_last()
+
+     @torch.jit.ignore
+     def group_matcher(self, coarse=False):
+         matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
+         return matcher
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def get_classifier(self, name_only=False):
+         return 'fc' if name_only else self.fc
+
+     def reset_classifier(self, num_classes, global_pool='avg'):
+         self.num_classes = num_classes
+         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes,
+                                                       pool_type=global_pool)
+
+     def forward_features(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.act1(x)
+         x = self.maxpool(x)
+
+         if self.grad_checkpointing and not torch.jit.is_scripting():
+             x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
+         else:
+             x = self.layer1(x)
+             x = self.layer2(x)
+             x = self.layer3(x)
+             x = self.layer4(x)
+         return x
+
+     def forward_head(self, x, pre_logits: bool = False):
+         x = self.global_pool(x)
+         return x if pre_logits else self.fc(x)
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         x = self.act(x)
+         x = self.forward_head(x)
+         return x
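
Note: the arithmetic in ConvBN.fuse_bn is the standard BatchNorm-folding identity: in eval mode, BN(conv(x)) equals a single convolution with per-output-channel rescaled weights W' = W * gamma / sqrt(var + eps) and bias b' = beta - gamma * mean / sqrt(var + eps). A quick self-contained sanity check of that identity in plain PyTorch (no dependence on this repo):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
bn = torch.nn.BatchNorm2d(16)
# give BN non-trivial statistics and affine parameters, as if trained
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)
torch.nn.init.normal_(bn.weight)
torch.nn.init.normal_(bn.bias)
bn.eval()

std = torch.sqrt(bn.running_var + bn.eps)
fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
fused_bias = bn.bias - bn.weight * bn.running_mean / std

x = torch.randn(4, 8, 7, 7)
with torch.no_grad():
    ref = bn(conv(x))
    fused = F.conv2d(x, fused_weight, fused_bias, padding=1)
print((ref - fused).abs().max().item())  # ~1e-7: the two paths agree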