Image Classification
timm
PDE
ConvNet
liuyao committed
Commit 3265f09
1 Parent(s): a703759

Upload QLNet_symmetry.ipynb

Files changed (1)
  1. QLNet_symmetry.ipynb +594 -540
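The commit replaces the local notebook with a Colab-ready version (pip install, git clone of the repo, explicit device selection), but the experiment is unchanged: fuse each ConvBN, draw an orthogonal matrix Q, conjugate consecutive blocks of layer3 (and the layer3-to-layer4 boundary) by Q, and check that the network output stays numerically the same. A minimal standalone sketch of the cancellation being exploited, independent of the QLNet code and using hypothetical tensor names, is:

import torch
import torch.nn.functional as F

# Two 1x1 convolutions standing in for block1.conv3 (fused) and block2.conv1 (fused).
n, c = 8, 16                      # hypothetical channel counts
w_out = torch.randn(n, c, 1, 1)   # produces the n channels passed between blocks
w_in  = torch.randn(c, n, 1, 1)   # consumes those n channels in the next block
x = torch.randn(2, c, 7, 7)

Q = torch.nn.init.orthogonal_(torch.empty(n, n))
Q_inv = torch.inverse(Q)

y_ref = F.conv2d(F.conv2d(x, w_out), w_in)

# Mix the output channels of w_out by Q and the input channels of w_in by Q^{-1},
# the same einsum pattern used in apply_transform in the diff below.
w_out_t = torch.einsum('ij,jklm->iklm', Q, w_out)
w_in_t  = torch.einsum('ki,jklm->jilm', Q_inv, w_in)
y_new = F.conv2d(F.conv2d(x, w_out_t), w_in_t)

print((y_new - y_ref).abs().max())  # differs from y_ref only by float round-off

In the full network the transformed channels also pass through the skip path and the hardball activation between blocks, so apply_transform rewrites the skip weights as well; the notebook's final printout (on the order of 1e-4) is the end-to-end version of this check.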
QLNet_symmetry.ipynb CHANGED
@@ -1,551 +1,605 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "71b6152c",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import torch, timm\n",
11
- "from qlnet import QLNet"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 2,
17
- "id": "4e7ed219",
18
- "metadata": {},
19
- "outputs": [],
20
- "source": [
21
- "m = QLNet()"
22
- ]
23
- },
24
- {
25
- "cell_type": "code",
26
- "execution_count": 3,
27
- "id": "3f703be8",
28
- "metadata": {},
29
- "outputs": [],
30
- "source": [
31
- "state_dict = torch.load('qlnet-50-v0.pth.tar')['state_dict']"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": 4,
37
- "id": "435e2358",
38
- "metadata": {},
39
- "outputs": [
40
  {
41
- "data": {
42
- "text/plain": [
43
- "<All keys matched successfully>"
 
 
 
 
 
 
 
 
 
44
  ]
45
- },
46
- "execution_count": 4,
47
- "metadata": {},
48
- "output_type": "execute_result"
49
- }
50
- ],
51
- "source": [
52
- "m.load_state_dict(state_dict)"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 5,
58
- "id": "f14d984a",
59
- "metadata": {
60
- "scrolled": true
61
- },
62
- "outputs": [
63
  {
64
- "data": {
65
- "text/plain": [
66
- "QLNet(\n",
67
- " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
68
- " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
69
- " (act1): ReLU(inplace=True)\n",
70
- " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
71
- " (layer1): Sequential(\n",
72
- " (0): QLBlock(\n",
73
- " (conv1): ConvBN(\n",
74
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
75
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
76
- " )\n",
77
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
78
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
79
- " (conv3): ConvBN(\n",
80
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
81
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
82
- " )\n",
83
- " (skip): Identity()\n",
84
- " (act3): hardball()\n",
85
- " )\n",
86
- " (1): QLBlock(\n",
87
- " (conv1): ConvBN(\n",
88
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
89
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
90
- " )\n",
91
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
92
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
93
- " (conv3): ConvBN(\n",
94
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
95
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
96
- " )\n",
97
- " (skip): Identity()\n",
98
- " (act3): hardball()\n",
99
- " )\n",
100
- " (2): QLBlock(\n",
101
- " (conv1): ConvBN(\n",
102
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
103
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
104
- " )\n",
105
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
106
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
107
- " (conv3): ConvBN(\n",
108
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
109
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
110
- " )\n",
111
- " (skip): Identity()\n",
112
- " (act3): hardball()\n",
113
- " )\n",
114
- " )\n",
115
- " (layer2): Sequential(\n",
116
- " (0): QLBlock(\n",
117
- " (conv1): ConvBN(\n",
118
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
119
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
120
- " )\n",
121
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
122
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
123
- " (conv3): ConvBN(\n",
124
- " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
125
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
126
- " )\n",
127
- " (skip): ConvBN(\n",
128
- " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
129
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
130
- " )\n",
131
- " (act3): hardball()\n",
132
- " )\n",
133
- " (1): QLBlock(\n",
134
- " (conv1): ConvBN(\n",
135
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
136
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
137
- " )\n",
138
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
139
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
140
- " (conv3): ConvBN(\n",
141
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
142
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
143
- " )\n",
144
- " (skip): Identity()\n",
145
- " (act3): hardball()\n",
146
- " )\n",
147
- " (2): QLBlock(\n",
148
- " (conv1): ConvBN(\n",
149
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
150
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
151
- " )\n",
152
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
153
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
154
- " (conv3): ConvBN(\n",
155
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
156
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
157
- " )\n",
158
- " (skip): Identity()\n",
159
- " (act3): hardball()\n",
160
- " )\n",
161
- " (3): QLBlock(\n",
162
- " (conv1): ConvBN(\n",
163
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
164
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
165
- " )\n",
166
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
167
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
168
- " (conv3): ConvBN(\n",
169
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
170
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
171
- " )\n",
172
- " (skip): Identity()\n",
173
- " (act3): hardball()\n",
174
- " )\n",
175
- " )\n",
176
- " (layer3): Sequential(\n",
177
- " (0): QLBlock(\n",
178
- " (conv1): ConvBN(\n",
179
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
180
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
181
- " )\n",
182
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
183
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
184
- " (conv3): ConvBN(\n",
185
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
186
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
187
- " )\n",
188
- " (skip): ConvBN(\n",
189
- " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
190
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
191
- " )\n",
192
- " (act3): hardball()\n",
193
- " )\n",
194
- " (1): QLBlock(\n",
195
- " (conv1): ConvBN(\n",
196
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
197
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
198
- " )\n",
199
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
200
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
201
- " (conv3): ConvBN(\n",
202
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
203
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204
- " )\n",
205
- " (skip): Identity()\n",
206
- " (act3): hardball()\n",
207
- " )\n",
208
- " (2): QLBlock(\n",
209
- " (conv1): ConvBN(\n",
210
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
211
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
212
- " )\n",
213
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
214
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
215
- " (conv3): ConvBN(\n",
216
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
217
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
218
- " )\n",
219
- " (skip): Identity()\n",
220
- " (act3): hardball()\n",
221
- " )\n",
222
- " (3): QLBlock(\n",
223
- " (conv1): ConvBN(\n",
224
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
225
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
226
- " )\n",
227
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
228
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229
- " (conv3): ConvBN(\n",
230
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
231
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
232
- " )\n",
233
- " (skip): Identity()\n",
234
- " (act3): hardball()\n",
235
- " )\n",
236
- " (4): QLBlock(\n",
237
- " (conv1): ConvBN(\n",
238
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
239
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
240
- " )\n",
241
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
242
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
243
- " (conv3): ConvBN(\n",
244
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
245
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
246
- " )\n",
247
- " (skip): Identity()\n",
248
- " (act3): hardball()\n",
249
- " )\n",
250
- " (5): QLBlock(\n",
251
- " (conv1): ConvBN(\n",
252
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
253
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
254
- " )\n",
255
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
256
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
257
- " (conv3): ConvBN(\n",
258
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
259
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
260
- " )\n",
261
- " (skip): Identity()\n",
262
- " (act3): hardball()\n",
263
- " )\n",
264
- " )\n",
265
- " (layer4): Sequential(\n",
266
- " (0): QLBlock(\n",
267
- " (conv1): ConvBN(\n",
268
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
269
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
270
- " )\n",
271
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
272
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273
- " (conv3): ConvBN(\n",
274
- " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
275
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
276
- " )\n",
277
- " (skip): ConvBN(\n",
278
- " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
279
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
- " )\n",
281
- " (act3): hardball()\n",
282
- " )\n",
283
- " (1): QLBlock(\n",
284
- " (conv1): ConvBN(\n",
285
- " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
286
- " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
287
- " )\n",
288
- " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
289
- " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
290
- " (conv3): ConvBN(\n",
291
- " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
292
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
293
- " )\n",
294
- " (skip): Identity()\n",
295
- " (act3): hardball()\n",
296
- " )\n",
297
- " (2): QLBlock(\n",
298
- " (conv1): ConvBN(\n",
299
- " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
300
- " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
301
- " )\n",
302
- " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
303
- " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
304
- " (conv3): ConvBN(\n",
305
- " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
306
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
307
- " )\n",
308
- " (skip): Identity()\n",
309
- " (act3): hardball()\n",
310
- " )\n",
311
- " )\n",
312
- " (act): hardball()\n",
313
- " (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
314
- " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
315
- ")"
316
  ]
317
- },
318
- "execution_count": 5,
319
- "metadata": {},
320
- "output_type": "execute_result"
321
- }
322
- ],
323
- "source": [
324
- "m.eval()"
325
- ]
326
- },
327
- {
328
- "cell_type": "code",
329
- "execution_count": 6,
330
- "id": "2099b937",
331
- "metadata": {},
332
- "outputs": [
333
  {
334
- "name": "stdout",
335
- "output_type": "stream",
336
- "text": [
337
- "layer1 >>\n",
338
- "torch.Size([512, 64, 1, 1])\n",
339
- "torch.Size([64, 512, 1, 1])\n",
340
- "torch.Size([512, 64, 1, 1])\n",
341
- "torch.Size([64, 512, 1, 1])\n",
342
- "torch.Size([512, 64, 1, 1])\n",
343
- "torch.Size([64, 512, 1, 1])\n",
344
- "layer2 >>\n",
345
- "torch.Size([512, 64, 1, 1])\n",
346
- "torch.Size([128, 512, 1, 1])\n",
347
- "torch.Size([128, 64, 1, 1])\n",
348
- "torch.Size([1024, 128, 1, 1])\n",
349
- "torch.Size([128, 1024, 1, 1])\n",
350
- "torch.Size([1024, 128, 1, 1])\n",
351
- "torch.Size([128, 1024, 1, 1])\n",
352
- "torch.Size([1024, 128, 1, 1])\n",
353
- "torch.Size([128, 1024, 1, 1])\n",
354
- "layer3 >>\n",
355
- "torch.Size([1024, 128, 1, 1])\n",
356
- "torch.Size([256, 1024, 1, 1])\n",
357
- "torch.Size([256, 128, 1, 1])\n",
358
- "torch.Size([1024, 256, 1, 1])\n",
359
- "torch.Size([256, 1024, 1, 1])\n",
360
- "torch.Size([1024, 256, 1, 1])\n",
361
- "torch.Size([256, 1024, 1, 1])\n",
362
- "torch.Size([1024, 256, 1, 1])\n",
363
- "torch.Size([256, 1024, 1, 1])\n",
364
- "torch.Size([1024, 256, 1, 1])\n",
365
- "torch.Size([256, 1024, 1, 1])\n",
366
- "torch.Size([1024, 256, 1, 1])\n",
367
- "torch.Size([256, 1024, 1, 1])\n",
368
- "layer4 >>\n",
369
- "torch.Size([1024, 256, 1, 1])\n",
370
- "torch.Size([512, 1024, 1, 1])\n",
371
- "torch.Size([512, 256, 1, 1])\n",
372
- "torch.Size([2048, 512, 1, 1])\n",
373
- "torch.Size([512, 2048, 1, 1])\n",
374
- "torch.Size([2048, 512, 1, 1])\n",
375
- "torch.Size([512, 2048, 1, 1])\n"
376
- ]
377
- }
378
- ],
379
- "source": [
380
- "# fuse ConvBN\n",
381
- "i = 1\n",
382
- "for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:\n",
383
- " print(f'layer{i} >>')\n",
384
- " for block in layer:\n",
385
- " # Fuse the weights in conv1 and conv3\n",
386
- " block.conv1.fuse_bn()\n",
387
- " print(block.conv1.fused_weight.size())\n",
388
- " block.conv3.fuse_bn()\n",
389
- " print(block.conv3.fused_weight.size())\n",
390
- " if not isinstance(block.skip, torch.nn.Identity):\n",
391
- " layer[0].skip.fuse_bn()\n",
392
- " print(layer[0].skip.fused_weight.size())\n",
393
- " i += 1"
394
- ]
395
- },
396
- {
397
- "cell_type": "code",
398
- "execution_count": 7,
399
- "id": "b3a55f82",
400
- "metadata": {},
401
- "outputs": [],
402
- "source": [
403
- "inpt = torch.randn(5,3,224,224)"
404
- ]
405
- },
406
- {
407
- "cell_type": "code",
408
- "execution_count": 8,
409
- "id": "dccbf19c",
410
- "metadata": {},
411
- "outputs": [],
412
- "source": [
413
- "out_old = m(inpt)"
414
- ]
415
- },
416
- {
417
- "cell_type": "code",
418
- "execution_count": 10,
419
- "id": "f0c74a04",
420
- "metadata": {
421
- "scrolled": true
422
- },
423
- "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
  {
425
- "data": {
426
- "text/plain": [
427
- "torch.Size([5, 1000])"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  ]
429
- },
430
- "execution_count": 10,
431
- "metadata": {},
432
- "output_type": "execute_result"
433
- }
434
- ],
435
- "source": [
436
- "out_old.size()"
437
- ]
438
- },
439
- {
440
- "cell_type": "code",
441
- "execution_count": 11,
442
- "id": "a5991c8f",
443
- "metadata": {},
444
- "outputs": [],
445
- "source": [
446
- "def apply_transform(block1, block2, Q, keep_identity=True):\n",
447
- " with torch.no_grad():\n",
448
- " # Ensure that the out_channels of block1 is equal to the in_channels of block2\n",
449
- " assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
450
- " n = Q.size()[0]\n",
451
- " assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
452
- "\n",
453
- " n = block1.conv3.conv.out_channels\n",
454
- " \n",
455
- " # Calculate the inverse of Q\n",
456
- " Q_inv = torch.inverse(Q)\n",
457
- "\n",
458
- " # Modify the weights of conv layers in block1\n",
459
- " block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
460
- " block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
461
- " \n",
462
- " if isinstance(block1.skip, torch.nn.Identity):\n",
463
- " if not keep_identity:\n",
464
- " block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
465
- " block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
466
- " else:\n",
467
- " block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
468
- " block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
469
- "\n",
470
- " # Modify the weights of conv layers in block2\n",
471
- " block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
472
- " \n",
473
- " if isinstance(block2.skip, torch.nn.Identity):\n",
474
- " if not keep_identity:\n",
475
- " block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
476
- " block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
477
- " else:\n",
478
- " block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
479
- ]
480
- },
481
- {
482
- "cell_type": "code",
483
- "execution_count": 12,
484
- "id": "dd96acd7",
485
- "metadata": {},
486
- "outputs": [],
487
- "source": [
488
- "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
489
- "for i in range(5):\n",
490
- " apply_transform(m.layer3[i], m.layer3[i+1], Q, True)\n",
491
- "apply_transform(m.layer3[5], m.layer4[0], Q, True)"
492
- ]
493
- },
494
- {
495
- "cell_type": "code",
496
- "execution_count": 13,
497
- "id": "e5d3628d",
498
- "metadata": {},
499
- "outputs": [
 
 
 
 
 
 
 
 
 
 
 
500
  {
501
- "name": "stdout",
502
- "output_type": "stream",
503
- "text": [
504
- "6.0558319091796875e-05\n"
505
- ]
506
  }
507
- ],
508
- "source": [
509
- "out_new = m(inpt)\n",
510
- "print((out_new - out_old).abs().max().item())"
511
- ]
512
- },
513
- {
514
- "cell_type": "code",
515
- "execution_count": null,
516
- "id": "9fce3a38",
517
- "metadata": {},
518
- "outputs": [],
519
- "source": []
520
- },
521
- {
522
- "cell_type": "code",
523
- "execution_count": null,
524
- "id": "5a54fe8b",
525
- "metadata": {},
526
- "outputs": [],
527
- "source": []
528
- }
529
- ],
530
- "metadata": {
531
- "kernelspec": {
532
- "display_name": "Python 3 (ipykernel)",
533
- "language": "python",
534
- "name": "python3"
535
  },
536
- "language_info": {
537
- "codemirror_mode": {
538
- "name": "ipython",
539
- "version": 3
540
- },
541
- "file_extension": ".py",
542
- "mimetype": "text/x-python",
543
- "name": "python",
544
- "nbconvert_exporter": "python",
545
- "pygments_lexer": "ipython3",
546
- "version": "3.10.6"
547
- }
548
- },
549
- "nbformat": 4,
550
- "nbformat_minor": 5
551
- }
 
1
  {
2
+ "cells": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "71b6152c",
7
+ "metadata": {
8
+ "id": "71b6152c"
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "# Install PyTorch and timm\n",
13
+ "!pip install torch timm\n",
14
+ "\n",
15
+ "!git clone https://huggingface.co/liuyao/QLNet"
16
  ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "source": [
21
+ "# Navigate to the repository directory\n",
22
+ "import os\n",
23
+ "os.chdir('QLNet')"
24
+ ],
25
+ "metadata": {
26
+ "id": "pmVezdbxzcw7"
27
+ },
28
+ "id": "pmVezdbxzcw7",
29
+ "execution_count": 2,
30
+ "outputs": []
31
+ },
32
  {
33
+ "cell_type": "code",
34
+ "source": [
35
+ "import torch, timm\n",
36
+ "from qlnet import QLNet"
37
+ ],
38
+ "metadata": {
39
+ "id": "7vDt28zlzi0r"
40
+ },
41
+ "id": "7vDt28zlzi0r",
42
+ "execution_count": 5,
43
+ "outputs": []
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 9,
48
+ "id": "3f703be8",
49
+ "metadata": {
50
+ "colab": {
51
+ "base_uri": "https://localhost:8080/"
52
+ },
53
+ "id": "3f703be8",
54
+ "outputId": "de73c734-305f-4955-fe69-7b7253b4f95e"
55
+ },
56
+ "outputs": [
57
+ {
58
+ "output_type": "stream",
59
+ "name": "stdout",
60
+ "text": [
61
+ "Using device: cpu\n"
62
+ ]
63
+ },
64
+ {
65
+ "output_type": "execute_result",
66
+ "data": {
67
+ "text/plain": [
68
+ "<All keys matched successfully>"
69
+ ]
70
+ },
71
+ "metadata": {},
72
+ "execution_count": 9
73
+ }
74
+ ],
75
+ "source": [
76
+ "# Check if GPU is available and set the device accordingly\n",
77
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
+ "print(f\"Using device: {device}\")\n",
79
+ "\n",
80
+ "# Create an instance of your model and load it to the device\n",
81
+ "model = QLNet().to(device)\n",
82
+ "\n",
83
+ "# Load the model weights\n",
84
+ "model.load_state_dict(torch.load('qlnet-50-v0.pth.tar', map_location=device)['state_dict'])"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  ]
86
+ },
87
  {
88
+ "cell_type": "code",
89
+ "execution_count": 10,
90
+ "id": "f14d984a",
91
+ "metadata": {
92
+ "scrolled": true,
93
+ "colab": {
94
+ "base_uri": "https://localhost:8080/"
95
+ },
96
+ "id": "f14d984a",
97
+ "outputId": "efc70253-4bc0-4d0c-92d8-d247118138bc"
98
+ },
99
+ "outputs": [
100
+ {
101
+ "output_type": "execute_result",
102
+ "data": {
103
+ "text/plain": [
104
+ "QLNet(\n",
105
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
106
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
107
+ " (act1): ReLU(inplace=True)\n",
108
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
109
+ " (layer1): Sequential(\n",
110
+ " (0): QLBlock(\n",
111
+ " (conv1): ConvBN(\n",
112
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
113
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
114
+ " )\n",
115
+ " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
116
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
117
+ " (conv3): ConvBN(\n",
118
+ " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
119
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
120
+ " )\n",
121
+ " (skip): Identity()\n",
122
+ " (act3): hardball()\n",
123
+ " )\n",
124
+ " (1): QLBlock(\n",
125
+ " (conv1): ConvBN(\n",
126
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
127
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
128
+ " )\n",
129
+ " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
130
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
131
+ " (conv3): ConvBN(\n",
132
+ " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
133
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
134
+ " )\n",
135
+ " (skip): Identity()\n",
136
+ " (act3): hardball()\n",
137
+ " )\n",
138
+ " (2): QLBlock(\n",
139
+ " (conv1): ConvBN(\n",
140
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
141
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
142
+ " )\n",
143
+ " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
144
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
145
+ " (conv3): ConvBN(\n",
146
+ " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
147
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
148
+ " )\n",
149
+ " (skip): Identity()\n",
150
+ " (act3): hardball()\n",
151
+ " )\n",
152
+ " )\n",
153
+ " (layer2): Sequential(\n",
154
+ " (0): QLBlock(\n",
155
+ " (conv1): ConvBN(\n",
156
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
157
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
158
+ " )\n",
159
+ " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
160
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
161
+ " (conv3): ConvBN(\n",
162
+ " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
163
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
164
+ " )\n",
165
+ " (skip): ConvBN(\n",
166
+ " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
167
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
168
+ " )\n",
169
+ " (act3): hardball()\n",
170
+ " )\n",
171
+ " (1): QLBlock(\n",
172
+ " (conv1): ConvBN(\n",
173
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
174
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
175
+ " )\n",
176
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
177
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
178
+ " (conv3): ConvBN(\n",
179
+ " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
180
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
181
+ " )\n",
182
+ " (skip): Identity()\n",
183
+ " (act3): hardball()\n",
184
+ " )\n",
185
+ " (2): QLBlock(\n",
186
+ " (conv1): ConvBN(\n",
187
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
188
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
189
+ " )\n",
190
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
191
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
192
+ " (conv3): ConvBN(\n",
193
+ " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
194
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
195
+ " )\n",
196
+ " (skip): Identity()\n",
197
+ " (act3): hardball()\n",
198
+ " )\n",
199
+ " (3): QLBlock(\n",
200
+ " (conv1): ConvBN(\n",
201
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
202
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
203
+ " )\n",
204
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
205
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
206
+ " (conv3): ConvBN(\n",
207
+ " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
208
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
209
+ " )\n",
210
+ " (skip): Identity()\n",
211
+ " (act3): hardball()\n",
212
+ " )\n",
213
+ " )\n",
214
+ " (layer3): Sequential(\n",
215
+ " (0): QLBlock(\n",
216
+ " (conv1): ConvBN(\n",
217
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
218
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
219
+ " )\n",
220
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
221
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
222
+ " (conv3): ConvBN(\n",
223
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
224
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
225
+ " )\n",
226
+ " (skip): ConvBN(\n",
227
+ " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
228
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229
+ " )\n",
230
+ " (act3): hardball()\n",
231
+ " )\n",
232
+ " (1): QLBlock(\n",
233
+ " (conv1): ConvBN(\n",
234
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
235
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
236
+ " )\n",
237
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
238
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
239
+ " (conv3): ConvBN(\n",
240
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
241
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
242
+ " )\n",
243
+ " (skip): Identity()\n",
244
+ " (act3): hardball()\n",
245
+ " )\n",
246
+ " (2): QLBlock(\n",
247
+ " (conv1): ConvBN(\n",
248
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
249
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
250
+ " )\n",
251
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
252
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
253
+ " (conv3): ConvBN(\n",
254
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
255
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
256
+ " )\n",
257
+ " (skip): Identity()\n",
258
+ " (act3): hardball()\n",
259
+ " )\n",
260
+ " (3): QLBlock(\n",
261
+ " (conv1): ConvBN(\n",
262
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
263
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
264
+ " )\n",
265
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
266
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
267
+ " (conv3): ConvBN(\n",
268
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
269
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
270
+ " )\n",
271
+ " (skip): Identity()\n",
272
+ " (act3): hardball()\n",
273
+ " )\n",
274
+ " (4): QLBlock(\n",
275
+ " (conv1): ConvBN(\n",
276
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
277
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
278
+ " )\n",
279
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
280
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
281
+ " (conv3): ConvBN(\n",
282
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
283
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
284
+ " )\n",
285
+ " (skip): Identity()\n",
286
+ " (act3): hardball()\n",
287
+ " )\n",
288
+ " (5): QLBlock(\n",
289
+ " (conv1): ConvBN(\n",
290
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
291
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
292
+ " )\n",
293
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
294
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
295
+ " (conv3): ConvBN(\n",
296
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
297
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
298
+ " )\n",
299
+ " (skip): Identity()\n",
300
+ " (act3): hardball()\n",
301
+ " )\n",
302
+ " )\n",
303
+ " (layer4): Sequential(\n",
304
+ " (0): QLBlock(\n",
305
+ " (conv1): ConvBN(\n",
306
+ " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
307
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
308
+ " )\n",
309
+ " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
310
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
311
+ " (conv3): ConvBN(\n",
312
+ " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
313
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
314
+ " )\n",
315
+ " (skip): ConvBN(\n",
316
+ " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
317
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
318
+ " )\n",
319
+ " (act3): hardball()\n",
320
+ " )\n",
321
+ " (1): QLBlock(\n",
322
+ " (conv1): ConvBN(\n",
323
+ " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
324
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
325
+ " )\n",
326
+ " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
327
+ " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
328
+ " (conv3): ConvBN(\n",
329
+ " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
330
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
331
+ " )\n",
332
+ " (skip): Identity()\n",
333
+ " (act3): hardball()\n",
334
+ " )\n",
335
+ " (2): QLBlock(\n",
336
+ " (conv1): ConvBN(\n",
337
+ " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
338
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
339
+ " )\n",
340
+ " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
341
+ " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
342
+ " (conv3): ConvBN(\n",
343
+ " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
344
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
345
+ " )\n",
346
+ " (skip): Identity()\n",
347
+ " (act3): hardball()\n",
348
+ " )\n",
349
+ " )\n",
350
+ " (act): hardball()\n",
351
+ " (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
352
+ " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
353
+ ")"
354
+ ]
355
+ },
356
+ "metadata": {},
357
+ "execution_count": 10
358
+ }
359
+ ],
360
+ "source": [
361
+ "model.eval()"
362
+ ]
363
+ },
364
  {
365
+ "cell_type": "code",
366
+ "execution_count": 12,
367
+ "id": "2099b937",
368
+ "metadata": {
369
+ "colab": {
370
+ "base_uri": "https://localhost:8080/"
371
+ },
372
+ "id": "2099b937",
373
+ "outputId": "ac4557a4-ed2a-47b2-eca7-d9a337fff3f1"
374
+ },
375
+ "outputs": [
376
+ {
377
+ "output_type": "stream",
378
+ "name": "stdout",
379
+ "text": [
380
+ "layer1 >>\n",
381
+ "torch.Size([512, 64, 1, 1])\n",
382
+ "torch.Size([64, 512, 1, 1])\n",
383
+ "torch.Size([512, 64, 1, 1])\n",
384
+ "torch.Size([64, 512, 1, 1])\n",
385
+ "torch.Size([512, 64, 1, 1])\n",
386
+ "torch.Size([64, 512, 1, 1])\n",
387
+ "layer2 >>\n",
388
+ "torch.Size([512, 64, 1, 1])\n",
389
+ "torch.Size([128, 512, 1, 1])\n",
390
+ "torch.Size([128, 64, 1, 1])\n",
391
+ "torch.Size([1024, 128, 1, 1])\n",
392
+ "torch.Size([128, 1024, 1, 1])\n",
393
+ "torch.Size([1024, 128, 1, 1])\n",
394
+ "torch.Size([128, 1024, 1, 1])\n",
395
+ "torch.Size([1024, 128, 1, 1])\n",
396
+ "torch.Size([128, 1024, 1, 1])\n",
397
+ "layer3 >>\n",
398
+ "torch.Size([1024, 128, 1, 1])\n",
399
+ "torch.Size([256, 1024, 1, 1])\n",
400
+ "torch.Size([256, 128, 1, 1])\n",
401
+ "torch.Size([1024, 256, 1, 1])\n",
402
+ "torch.Size([256, 1024, 1, 1])\n",
403
+ "torch.Size([1024, 256, 1, 1])\n",
404
+ "torch.Size([256, 1024, 1, 1])\n",
405
+ "torch.Size([1024, 256, 1, 1])\n",
406
+ "torch.Size([256, 1024, 1, 1])\n",
407
+ "torch.Size([1024, 256, 1, 1])\n",
408
+ "torch.Size([256, 1024, 1, 1])\n",
409
+ "torch.Size([1024, 256, 1, 1])\n",
410
+ "torch.Size([256, 1024, 1, 1])\n",
411
+ "layer4 >>\n",
412
+ "torch.Size([1024, 256, 1, 1])\n",
413
+ "torch.Size([512, 1024, 1, 1])\n",
414
+ "torch.Size([512, 256, 1, 1])\n",
415
+ "torch.Size([2048, 512, 1, 1])\n",
416
+ "torch.Size([512, 2048, 1, 1])\n",
417
+ "torch.Size([2048, 512, 1, 1])\n",
418
+ "torch.Size([512, 2048, 1, 1])\n"
419
+ ]
420
+ }
421
+ ],
422
+ "source": [
423
+ "# fuse ConvBN\n",
424
+ "i = 1\n",
425
+ "for layer in [model.layer1, model.layer2, model.layer3, model.layer4]:\n",
426
+ " print(f'layer{i} >>')\n",
427
+ " for block in layer:\n",
428
+ " # Fuse the weights in conv1 and conv3\n",
429
+ " block.conv1.fuse_bn()\n",
430
+ " print(block.conv1.fused_weight.size())\n",
431
+ " block.conv3.fuse_bn()\n",
432
+ " print(block.conv3.fused_weight.size())\n",
433
+ " if not isinstance(block.skip, torch.nn.Identity):\n",
434
+ " layer[0].skip.fuse_bn()\n",
435
+ " print(layer[0].skip.fused_weight.size())\n",
436
+ " i += 1"
437
  ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 13,
442
+ "id": "b3a55f82",
443
+ "metadata": {
444
+ "id": "b3a55f82"
445
+ },
446
+ "outputs": [],
447
+ "source": [
448
+ "x = torch.randn(5,3,224,224)"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": 15,
454
+ "id": "dccbf19c",
455
+ "metadata": {
456
+ "colab": {
457
+ "base_uri": "https://localhost:8080/"
458
+ },
459
+ "id": "dccbf19c",
460
+ "outputId": "4a5409f4-761b-4682-a5be-5f55fd595135"
461
+ },
462
+ "outputs": [
463
+ {
464
+ "output_type": "stream",
465
+ "name": "stdout",
466
+ "text": [
467
+ "torch.Size([5, 1000])\n"
468
+ ]
469
+ }
470
+ ],
471
+ "source": [
472
+ "y_old = model(x)\n",
473
+ "print(y_old.size())"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 16,
479
+ "id": "a5991c8f",
480
+ "metadata": {
481
+ "id": "a5991c8f"
482
+ },
483
+ "outputs": [],
484
+ "source": [
485
+ "def apply_transform(block1, block2, Q, keep_identity=True):\n",
486
+ " with torch.no_grad():\n",
487
+ " # Ensure that the out_channels of block1 is equal to the in_channels of block2\n",
488
+ " assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
489
+ " n = Q.size()[0]\n",
490
+ " assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
491
+ "\n",
492
+ " n = block1.conv3.conv.out_channels\n",
493
+ "\n",
494
+ " # Calculate the inverse of Q\n",
495
+ " Q_inv = torch.inverse(Q)\n",
496
+ "\n",
497
+ " # Modify the weights of conv layers in block1\n",
498
+ " block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
499
+ " block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
500
+ "\n",
501
+ " if isinstance(block1.skip, torch.nn.Identity):\n",
502
+ " if not keep_identity:\n",
503
+ " block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
504
+ " block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
505
+ " else:\n",
506
+ " block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
507
+ " block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
508
+ "\n",
509
+ " # Modify the weights of conv layers in block2\n",
510
+ " block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
511
+ "\n",
512
+ " if isinstance(block2.skip, torch.nn.Identity):\n",
513
+ " if not keep_identity:\n",
514
+ " block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
515
+ " block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
516
+ " else:\n",
517
+ " block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
518
+ ]
519
+ },
520
  {
521
+ "cell_type": "code",
522
+ "execution_count": 17,
523
+ "id": "dd96acd7",
524
+ "metadata": {
525
+ "id": "dd96acd7"
526
+ },
527
+ "outputs": [],
528
+ "source": [
529
+ "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
530
+ "for i in range(5):\n",
531
+ " apply_transform(model.layer3[i], model.layer3[i+1], Q, True)\n",
532
+ "apply_transform(model.layer3[5], model.layer4[0], Q, True)"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": 18,
538
+ "id": "e5d3628d",
539
+ "metadata": {
540
+ "colab": {
541
+ "base_uri": "https://localhost:8080/"
542
+ },
543
+ "id": "e5d3628d",
544
+ "outputId": "667cfe17-e3fb-4009-9553-a765c6377321"
545
+ },
546
+ "outputs": [
547
+ {
548
+ "output_type": "stream",
549
+ "name": "stdout",
550
+ "text": [
551
+ "8.472800254821777e-05\n"
552
+ ]
553
+ }
554
+ ],
555
+ "source": [
556
+ "y_new = model(x)\n",
557
+ "print((y_new - y_old).abs().max().item())"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": null,
563
+ "id": "9fce3a38",
564
+ "metadata": {
565
+ "id": "9fce3a38"
566
+ },
567
+ "outputs": [],
568
+ "source": []
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "id": "5a54fe8b",
574
+ "metadata": {
575
+ "id": "5a54fe8b"
576
+ },
577
+ "outputs": [],
578
+ "source": []
579
+ }
580
+ ],
581
+ "metadata": {
582
+ "kernelspec": {
583
+ "display_name": "Python 3 (ipykernel)",
584
+ "language": "python",
585
+ "name": "python3"
586
+ },
587
+ "language_info": {
588
+ "codemirror_mode": {
589
+ "name": "ipython",
590
+ "version": 3
591
+ },
592
+ "file_extension": ".py",
593
+ "mimetype": "text/x-python",
594
+ "name": "python",
595
+ "nbconvert_exporter": "python",
596
+ "pygments_lexer": "ipython3",
597
+ "version": "3.10.6"
598
+ },
599
+ "colab": {
600
+ "provenance": []
601
  }
602
  },
603
+ "nbformat": 4,
604
+ "nbformat_minor": 5
605
+ }
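Note on the fuse_bn() step used above: apply_transform writes Q directly into fused_weight and fused_bias, so each ConvBN's BatchNorm has to be folded into its convolution first. The ConvBN.fuse_bn() implementation is not part of this diff; the sketch below is the standard eval-mode BatchNorm folding, with all names hypothetical.

import torch

def fuse_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d):
    # Return (weight, bias) of a single conv equivalent to conv followed by bn (eval mode).
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std                                  # per-output-channel factor
    weight = conv.weight * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    bias = (conv_bias - bn.running_mean) * scale + bn.bias
    return weight, bias

# Quick check that the folded conv matches conv -> bn.
conv = torch.nn.Conv2d(8, 16, kernel_size=1, bias=False)
bn = torch.nn.BatchNorm2d(16).eval()
x = torch.randn(1, 8, 5, 5)
w, b = fuse_conv_bn(conv, bn)
print((torch.nn.functional.conv2d(x, w, b) - bn(conv(x))).abs().max())  # round-off only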