Image Classification
timm
PDE
ConvNet
liuyao committed on
Commit
3f1b078
1 Parent(s): 3a6393b

Upload 2 files

Files changed (2)
  1. QLNet_symmetry.ipynb +540 -594
  2. qlnet.py +4 -4
QLNet_symmetry.ipynb CHANGED
@@ -1,605 +1,551 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "71b6152c",
7
- "metadata": {
8
- "id": "71b6152c"
9
- },
10
- "outputs": [],
11
- "source": [
12
- "# Install PyTorch and timm\n",
13
- "!pip install torch timm\n",
14
- "\n",
15
- "!git clone https://huggingface.co/liuyao/QLNet"
16
- ]
17
- },
18
- {
19
- "cell_type": "code",
20
- "source": [
21
- "# Navigate to the repository directory\n",
22
- "import os\n",
23
- "os.chdir('QLNet')"
24
- ],
25
- "metadata": {
26
- "id": "pmVezdbxzcw7"
27
- },
28
- "id": "pmVezdbxzcw7",
29
- "execution_count": 2,
30
- "outputs": []
31
- },
32
- {
33
- "cell_type": "code",
34
- "source": [
35
- "import torch, timm\n",
36
- "from qlnet import QLNet"
37
- ],
38
- "metadata": {
39
- "id": "7vDt28zlzi0r"
40
- },
41
- "id": "7vDt28zlzi0r",
42
- "execution_count": 5,
43
- "outputs": []
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": 9,
48
- "id": "3f703be8",
49
- "metadata": {
50
- "colab": {
51
- "base_uri": "https://localhost:8080/"
52
- },
53
- "id": "3f703be8",
54
- "outputId": "de73c734-305f-4955-fe69-7b7253b4f95e"
55
- },
56
- "outputs": [
57
- {
58
- "output_type": "stream",
59
- "name": "stdout",
60
- "text": [
61
- "Using device: cpu\n"
62
- ]
63
- },
64
- {
65
- "output_type": "execute_result",
66
- "data": {
67
- "text/plain": [
68
- "<All keys matched successfully>"
69
- ]
70
- },
71
- "metadata": {},
72
- "execution_count": 9
73
- }
74
- ],
75
- "source": [
76
- "# Check if GPU is available and set the device accordingly\n",
77
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
- "print(f\"Using device: {device}\")\n",
79
- "\n",
80
- "# Create an instance of your model and load it to the device\n",
81
- "model = QLNet().to(device)\n",
82
- "\n",
83
- "# Load the model weights\n",
84
- "model.load_state_dict(torch.load('qlnet-50-v0.pth.tar', map_location=device)['state_dict'])"
85
- ]
86
- },
87
- {
88
- "cell_type": "code",
89
- "execution_count": 10,
90
- "id": "f14d984a",
91
- "metadata": {
92
- "scrolled": true,
93
- "colab": {
94
- "base_uri": "https://localhost:8080/"
95
- },
96
- "id": "f14d984a",
97
- "outputId": "efc70253-4bc0-4d0c-92d8-d247118138bc"
98
- },
99
- "outputs": [
100
- {
101
- "output_type": "execute_result",
102
- "data": {
103
- "text/plain": [
104
- "QLNet(\n",
105
- " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
106
- " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
107
- " (act1): ReLU(inplace=True)\n",
108
- " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
109
- " (layer1): Sequential(\n",
110
- " (0): QLBlock(\n",
111
- " (conv1): ConvBN(\n",
112
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
113
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
114
- " )\n",
115
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
116
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
117
- " (conv3): ConvBN(\n",
118
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
119
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
120
- " )\n",
121
- " (skip): Identity()\n",
122
- " (act3): hardball()\n",
123
- " )\n",
124
- " (1): QLBlock(\n",
125
- " (conv1): ConvBN(\n",
126
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
127
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
128
- " )\n",
129
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
130
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
131
- " (conv3): ConvBN(\n",
132
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
133
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
134
- " )\n",
135
- " (skip): Identity()\n",
136
- " (act3): hardball()\n",
137
- " )\n",
138
- " (2): QLBlock(\n",
139
- " (conv1): ConvBN(\n",
140
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
141
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
142
- " )\n",
143
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
144
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
145
- " (conv3): ConvBN(\n",
146
- " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
147
- " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
148
- " )\n",
149
- " (skip): Identity()\n",
150
- " (act3): hardball()\n",
151
- " )\n",
152
- " )\n",
153
- " (layer2): Sequential(\n",
154
- " (0): QLBlock(\n",
155
- " (conv1): ConvBN(\n",
156
- " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
157
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
158
- " )\n",
159
- " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
160
- " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
161
- " (conv3): ConvBN(\n",
162
- " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
163
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
164
- " )\n",
165
- " (skip): ConvBN(\n",
166
- " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
167
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
168
- " )\n",
169
- " (act3): hardball()\n",
170
- " )\n",
171
- " (1): QLBlock(\n",
172
- " (conv1): ConvBN(\n",
173
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
174
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
175
- " )\n",
176
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
177
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
178
- " (conv3): ConvBN(\n",
179
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
180
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
181
- " )\n",
182
- " (skip): Identity()\n",
183
- " (act3): hardball()\n",
184
- " )\n",
185
- " (2): QLBlock(\n",
186
- " (conv1): ConvBN(\n",
187
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
188
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
189
- " )\n",
190
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
191
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
192
- " (conv3): ConvBN(\n",
193
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
194
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
195
- " )\n",
196
- " (skip): Identity()\n",
197
- " (act3): hardball()\n",
198
- " )\n",
199
- " (3): QLBlock(\n",
200
- " (conv1): ConvBN(\n",
201
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
202
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
203
- " )\n",
204
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
205
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
206
- " (conv3): ConvBN(\n",
207
- " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
208
- " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
209
- " )\n",
210
- " (skip): Identity()\n",
211
- " (act3): hardball()\n",
212
- " )\n",
213
- " )\n",
214
- " (layer3): Sequential(\n",
215
- " (0): QLBlock(\n",
216
- " (conv1): ConvBN(\n",
217
- " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
218
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
219
- " )\n",
220
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
221
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
222
- " (conv3): ConvBN(\n",
223
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
224
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
225
- " )\n",
226
- " (skip): ConvBN(\n",
227
- " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
228
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229
- " )\n",
230
- " (act3): hardball()\n",
231
- " )\n",
232
- " (1): QLBlock(\n",
233
- " (conv1): ConvBN(\n",
234
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
235
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
236
- " )\n",
237
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
238
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
239
- " (conv3): ConvBN(\n",
240
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
241
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
242
- " )\n",
243
- " (skip): Identity()\n",
244
- " (act3): hardball()\n",
245
- " )\n",
246
- " (2): QLBlock(\n",
247
- " (conv1): ConvBN(\n",
248
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
249
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
250
- " )\n",
251
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
252
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
253
- " (conv3): ConvBN(\n",
254
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
255
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
256
- " )\n",
257
- " (skip): Identity()\n",
258
- " (act3): hardball()\n",
259
- " )\n",
260
- " (3): QLBlock(\n",
261
- " (conv1): ConvBN(\n",
262
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
263
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
264
- " )\n",
265
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
266
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
267
- " (conv3): ConvBN(\n",
268
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
269
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
270
- " )\n",
271
- " (skip): Identity()\n",
272
- " (act3): hardball()\n",
273
- " )\n",
274
- " (4): QLBlock(\n",
275
- " (conv1): ConvBN(\n",
276
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
277
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
278
- " )\n",
279
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
280
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
281
- " (conv3): ConvBN(\n",
282
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
283
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
284
- " )\n",
285
- " (skip): Identity()\n",
286
- " (act3): hardball()\n",
287
- " )\n",
288
- " (5): QLBlock(\n",
289
- " (conv1): ConvBN(\n",
290
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
291
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
292
- " )\n",
293
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
294
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
295
- " (conv3): ConvBN(\n",
296
- " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
297
- " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
298
- " )\n",
299
- " (skip): Identity()\n",
300
- " (act3): hardball()\n",
301
- " )\n",
302
- " )\n",
303
- " (layer4): Sequential(\n",
304
- " (0): QLBlock(\n",
305
- " (conv1): ConvBN(\n",
306
- " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
307
- " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
308
- " )\n",
309
- " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
310
- " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
311
- " (conv3): ConvBN(\n",
312
- " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
313
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
314
- " )\n",
315
- " (skip): ConvBN(\n",
316
- " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
317
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
318
- " )\n",
319
- " (act3): hardball()\n",
320
- " )\n",
321
- " (1): QLBlock(\n",
322
- " (conv1): ConvBN(\n",
323
- " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
324
- " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
325
- " )\n",
326
- " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
327
- " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
328
- " (conv3): ConvBN(\n",
329
- " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
330
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
331
- " )\n",
332
- " (skip): Identity()\n",
333
- " (act3): hardball()\n",
334
- " )\n",
335
- " (2): QLBlock(\n",
336
- " (conv1): ConvBN(\n",
337
- " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
338
- " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
339
- " )\n",
340
- " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
341
- " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
342
- " (conv3): ConvBN(\n",
343
- " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
344
- " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
345
- " )\n",
346
- " (skip): Identity()\n",
347
- " (act3): hardball()\n",
348
- " )\n",
349
- " )\n",
350
- " (act): hardball()\n",
351
- " (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
352
- " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
353
- ")"
354
- ]
355
- },
356
- "metadata": {},
357
- "execution_count": 10
358
- }
359
- ],
360
- "source": [
361
- "model.eval()"
362
- ]
363
- },
364
- {
365
- "cell_type": "code",
366
- "execution_count": 12,
367
- "id": "2099b937",
368
- "metadata": {
369
- "colab": {
370
- "base_uri": "https://localhost:8080/"
371
- },
372
- "id": "2099b937",
373
- "outputId": "ac4557a4-ed2a-47b2-eca7-d9a337fff3f1"
374
- },
375
- "outputs": [
376
- {
377
- "output_type": "stream",
378
- "name": "stdout",
379
- "text": [
380
- "layer1 >>\n",
381
- "torch.Size([512, 64, 1, 1])\n",
382
- "torch.Size([64, 512, 1, 1])\n",
383
- "torch.Size([512, 64, 1, 1])\n",
384
- "torch.Size([64, 512, 1, 1])\n",
385
- "torch.Size([512, 64, 1, 1])\n",
386
- "torch.Size([64, 512, 1, 1])\n",
387
- "layer2 >>\n",
388
- "torch.Size([512, 64, 1, 1])\n",
389
- "torch.Size([128, 512, 1, 1])\n",
390
- "torch.Size([128, 64, 1, 1])\n",
391
- "torch.Size([1024, 128, 1, 1])\n",
392
- "torch.Size([128, 1024, 1, 1])\n",
393
- "torch.Size([1024, 128, 1, 1])\n",
394
- "torch.Size([128, 1024, 1, 1])\n",
395
- "torch.Size([1024, 128, 1, 1])\n",
396
- "torch.Size([128, 1024, 1, 1])\n",
397
- "layer3 >>\n",
398
- "torch.Size([1024, 128, 1, 1])\n",
399
- "torch.Size([256, 1024, 1, 1])\n",
400
- "torch.Size([256, 128, 1, 1])\n",
401
- "torch.Size([1024, 256, 1, 1])\n",
402
- "torch.Size([256, 1024, 1, 1])\n",
403
- "torch.Size([1024, 256, 1, 1])\n",
404
- "torch.Size([256, 1024, 1, 1])\n",
405
- "torch.Size([1024, 256, 1, 1])\n",
406
- "torch.Size([256, 1024, 1, 1])\n",
407
- "torch.Size([1024, 256, 1, 1])\n",
408
- "torch.Size([256, 1024, 1, 1])\n",
409
- "torch.Size([1024, 256, 1, 1])\n",
410
- "torch.Size([256, 1024, 1, 1])\n",
411
- "layer4 >>\n",
412
- "torch.Size([1024, 256, 1, 1])\n",
413
- "torch.Size([512, 1024, 1, 1])\n",
414
- "torch.Size([512, 256, 1, 1])\n",
415
- "torch.Size([2048, 512, 1, 1])\n",
416
- "torch.Size([512, 2048, 1, 1])\n",
417
- "torch.Size([2048, 512, 1, 1])\n",
418
- "torch.Size([512, 2048, 1, 1])\n"
419
- ]
420
- }
421
- ],
422
- "source": [
423
- "# fuse ConvBN\n",
424
- "i = 1\n",
425
- "for layer in [model.layer1, model.layer2, model.layer3, model.layer4]:\n",
426
- " print(f'layer{i} >>')\n",
427
- " for block in layer:\n",
428
- " # Fuse the weights in conv1 and conv3\n",
429
- " block.conv1.fuse_bn()\n",
430
- " print(block.conv1.fused_weight.size())\n",
431
- " block.conv3.fuse_bn()\n",
432
- " print(block.conv3.fused_weight.size())\n",
433
- " if not isinstance(block.skip, torch.nn.Identity):\n",
434
- " layer[0].skip.fuse_bn()\n",
435
- " print(layer[0].skip.fused_weight.size())\n",
436
- " i += 1"
437
- ]
438
- },
439
- {
440
- "cell_type": "code",
441
- "execution_count": 13,
442
- "id": "b3a55f82",
443
- "metadata": {
444
- "id": "b3a55f82"
445
- },
446
- "outputs": [],
447
- "source": [
448
- "x = torch.randn(5,3,224,224)"
449
- ]
450
- },
451
  {
452
- "cell_type": "code",
453
- "execution_count": 15,
454
- "id": "dccbf19c",
455
- "metadata": {
456
- "colab": {
457
- "base_uri": "https://localhost:8080/"
458
- },
459
- "id": "dccbf19c",
460
- "outputId": "4a5409f4-761b-4682-a5be-5f55fd595135"
461
- },
462
- "outputs": [
463
- {
464
- "output_type": "stream",
465
- "name": "stdout",
466
- "text": [
467
- "torch.Size([5, 1000])\n"
468
- ]
469
- }
470
- ],
471
- "source": [
472
- "y_old = model(x)\n",
473
- "print(y_old.size())"
474
  ]
475
- },
476
  {
477
- "cell_type": "code",
478
- "execution_count": 16,
479
- "id": "a5991c8f",
480
- "metadata": {
481
- "id": "a5991c8f"
482
- },
483
- "outputs": [],
484
- "source": [
485
- "def apply_transform(block1, block2, Q, keep_identity=True):\n",
486
- " with torch.no_grad():\n",
487
- " # Ensure that the out_channels of block1 is equal to the in_channels of block2\n",
488
- " assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
489
- " n = Q.size()[0]\n",
490
- " assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
491
- "\n",
492
- " n = block1.conv3.conv.out_channels\n",
493
- "\n",
494
- " # Calculate the inverse of Q\n",
495
- " Q_inv = torch.inverse(Q)\n",
496
- "\n",
497
- " # Modify the weights of conv layers in block1\n",
498
- " block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
499
- " block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
500
- "\n",
501
- " if isinstance(block1.skip, torch.nn.Identity):\n",
502
- " if not keep_identity:\n",
503
- " block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
504
- " block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
505
- " else:\n",
506
- " block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
507
- " block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
508
- "\n",
509
- " # Modify the weights of conv layers in block2\n",
510
- " block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
511
- "\n",
512
- " if isinstance(block2.skip, torch.nn.Identity):\n",
513
- " if not keep_identity:\n",
514
- " block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
515
- " block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
516
- " else:\n",
517
- " block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
518
  ]
519
- },
520
  {
521
- "cell_type": "code",
522
- "execution_count": 17,
523
- "id": "dd96acd7",
524
- "metadata": {
525
- "id": "dd96acd7"
526
- },
527
- "outputs": [],
528
- "source": [
529
- "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
530
- "for i in range(5):\n",
531
- " apply_transform(model.layer3[i], model.layer3[i+1], Q, True)\n",
532
- "apply_transform(model.layer3[5], model.layer4[0], Q, True)"
533
- ]
534
- },
535
  {
536
- "cell_type": "code",
537
- "execution_count": 18,
538
- "id": "e5d3628d",
539
- "metadata": {
540
- "colab": {
541
- "base_uri": "https://localhost:8080/"
542
- },
543
- "id": "e5d3628d",
544
- "outputId": "667cfe17-e3fb-4009-9553-a765c6377321"
545
- },
546
- "outputs": [
547
- {
548
- "output_type": "stream",
549
- "name": "stdout",
550
- "text": [
551
- "8.472800254821777e-05\n"
552
- ]
553
- }
554
- ],
555
- "source": [
556
- "y_new = model(x)\n",
557
- "print((y_new - y_old).abs().max().item())"
558
  ]
559
- },
560
- {
561
- "cell_type": "code",
562
- "execution_count": null,
563
- "id": "9fce3a38",
564
- "metadata": {
565
- "id": "9fce3a38"
566
- },
567
- "outputs": [],
568
- "source": []
569
- },
570
- {
571
- "cell_type": "code",
572
- "execution_count": null,
573
- "id": "5a54fe8b",
574
- "metadata": {
575
- "id": "5a54fe8b"
576
- },
577
- "outputs": [],
578
- "source": []
579
  }
580
- ],
581
- "metadata": {
582
- "kernelspec": {
583
- "display_name": "Python 3 (ipykernel)",
584
- "language": "python",
585
- "name": "python3"
586
- },
587
- "language_info": {
588
- "codemirror_mode": {
589
- "name": "ipython",
590
- "version": 3
591
- },
592
- "file_extension": ".py",
593
- "mimetype": "text/x-python",
594
- "name": "python",
595
- "nbconvert_exporter": "python",
596
- "pygments_lexer": "ipython3",
597
- "version": "3.10.6"
598
- },
599
- "colab": {
600
- "provenance": []
601
  }
602
  },
603
- "nbformat": 4,
604
- "nbformat_minor": 5
605
- }
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "71b6152c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch, timm\n",
11
+ "from qlnet import QLNet"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "id": "4e7ed219",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "m = QLNet()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 3,
27
+ "id": "3f703be8",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "state_dict = torch.load('qlnet-10m.pth.tar')"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 4,
37
+ "id": "435e2358",
38
+ "metadata": {},
39
+ "outputs": [
40
  {
41
+ "data": {
42
+ "text/plain": [
43
+ "<All keys matched successfully>"
44
  ]
45
+ },
46
+ "execution_count": 4,
47
+ "metadata": {},
48
+ "output_type": "execute_result"
49
+ }
50
+ ],
51
+ "source": [
52
+ "m.load_state_dict(state_dict)"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 5,
58
+ "id": "f14d984a",
59
+ "metadata": {
60
+ "scrolled": true
61
+ },
62
+ "outputs": [
63
  {
64
+ "data": {
65
+ "text/plain": [
66
+ "QLNet(\n",
67
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
68
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
69
+ " (act1): ReLU(inplace=True)\n",
70
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
71
+ " (layer1): Sequential(\n",
72
+ " (0): QLBlock(\n",
73
+ " (conv1): ConvBN(\n",
74
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
75
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
76
+ " )\n",
77
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
78
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
79
+ " (conv3): ConvBN(\n",
80
+ " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
81
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
82
+ " )\n",
83
+ " (skip): Identity()\n",
84
+ " (act3): hardball()\n",
85
+ " )\n",
86
+ " (1): QLBlock(\n",
87
+ " (conv1): ConvBN(\n",
88
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
89
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
90
+ " )\n",
91
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
92
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
93
+ " (conv3): ConvBN(\n",
94
+ " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
95
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
96
+ " )\n",
97
+ " (skip): Identity()\n",
98
+ " (act3): hardball()\n",
99
+ " )\n",
100
+ " (2): QLBlock(\n",
101
+ " (conv1): ConvBN(\n",
102
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
103
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
104
+ " )\n",
105
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
106
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
107
+ " (conv3): ConvBN(\n",
108
+ " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
109
+ " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
110
+ " )\n",
111
+ " (skip): Identity()\n",
112
+ " (act3): hardball()\n",
113
+ " )\n",
114
+ " )\n",
115
+ " (layer2): Sequential(\n",
116
+ " (0): QLBlock(\n",
117
+ " (conv1): ConvBN(\n",
118
+ " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
119
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
120
+ " )\n",
121
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
122
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
123
+ " (conv3): ConvBN(\n",
124
+ " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
125
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
126
+ " )\n",
127
+ " (skip): ConvBN(\n",
128
+ " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
129
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
130
+ " )\n",
131
+ " (act3): hardball()\n",
132
+ " )\n",
133
+ " (1): QLBlock(\n",
134
+ " (conv1): ConvBN(\n",
135
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
136
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
137
+ " )\n",
138
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
139
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
140
+ " (conv3): ConvBN(\n",
141
+ " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
142
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
143
+ " )\n",
144
+ " (skip): Identity()\n",
145
+ " (act3): hardball()\n",
146
+ " )\n",
147
+ " (2): QLBlock(\n",
148
+ " (conv1): ConvBN(\n",
149
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
150
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
151
+ " )\n",
152
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
153
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
154
+ " (conv3): ConvBN(\n",
155
+ " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
156
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
157
+ " )\n",
158
+ " (skip): Identity()\n",
159
+ " (act3): hardball()\n",
160
+ " )\n",
161
+ " (3): QLBlock(\n",
162
+ " (conv1): ConvBN(\n",
163
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
164
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
165
+ " )\n",
166
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
167
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
168
+ " (conv3): ConvBN(\n",
169
+ " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
170
+ " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
171
+ " )\n",
172
+ " (skip): Identity()\n",
173
+ " (act3): hardball()\n",
174
+ " )\n",
175
+ " )\n",
176
+ " (layer3): Sequential(\n",
177
+ " (0): QLBlock(\n",
178
+ " (conv1): ConvBN(\n",
179
+ " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
180
+ " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
181
+ " )\n",
182
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
183
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
184
+ " (conv3): ConvBN(\n",
185
+ " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
186
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
187
+ " )\n",
188
+ " (skip): ConvBN(\n",
189
+ " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
190
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
191
+ " )\n",
192
+ " (act3): hardball()\n",
193
+ " )\n",
194
+ " (1): QLBlock(\n",
195
+ " (conv1): ConvBN(\n",
196
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
197
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
198
+ " )\n",
199
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
200
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
201
+ " (conv3): ConvBN(\n",
202
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
203
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204
+ " )\n",
205
+ " (skip): Identity()\n",
206
+ " (act3): hardball()\n",
207
+ " )\n",
208
+ " (2): QLBlock(\n",
209
+ " (conv1): ConvBN(\n",
210
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
211
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
212
+ " )\n",
213
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
214
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
215
+ " (conv3): ConvBN(\n",
216
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
217
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
218
+ " )\n",
219
+ " (skip): Identity()\n",
220
+ " (act3): hardball()\n",
221
+ " )\n",
222
+ " (3): QLBlock(\n",
223
+ " (conv1): ConvBN(\n",
224
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
225
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
226
+ " )\n",
227
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
228
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229
+ " (conv3): ConvBN(\n",
230
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
231
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
232
+ " )\n",
233
+ " (skip): Identity()\n",
234
+ " (act3): hardball()\n",
235
+ " )\n",
236
+ " (4): QLBlock(\n",
237
+ " (conv1): ConvBN(\n",
238
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
239
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
240
+ " )\n",
241
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
242
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
243
+ " (conv3): ConvBN(\n",
244
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
245
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
246
+ " )\n",
247
+ " (skip): Identity()\n",
248
+ " (act3): hardball()\n",
249
+ " )\n",
250
+ " (5): QLBlock(\n",
251
+ " (conv1): ConvBN(\n",
252
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
253
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
254
+ " )\n",
255
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
256
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
257
+ " (conv3): ConvBN(\n",
258
+ " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
259
+ " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
260
+ " )\n",
261
+ " (skip): Identity()\n",
262
+ " (act3): hardball()\n",
263
+ " )\n",
264
+ " )\n",
265
+ " (layer4): Sequential(\n",
266
+ " (0): QLBlock(\n",
267
+ " (conv1): ConvBN(\n",
268
+ " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
269
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
270
+ " )\n",
271
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=1024, bias=False)\n",
272
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273
+ " (conv3): ConvBN(\n",
274
+ " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
275
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
276
+ " )\n",
277
+ " (skip): ConvBN(\n",
278
+ " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
279
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " )\n",
281
+ " (act3): hardball()\n",
282
+ " )\n",
283
+ " (1): QLBlock(\n",
284
+ " (conv1): ConvBN(\n",
285
+ " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
286
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
287
+ " )\n",
288
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
289
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
290
+ " (conv3): ConvBN(\n",
291
+ " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
292
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
293
+ " )\n",
294
+ " (skip): Identity()\n",
295
+ " (act3): hardball()\n",
296
+ " )\n",
297
+ " (2): QLBlock(\n",
298
+ " (conv1): ConvBN(\n",
299
+ " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
300
+ " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
301
+ " )\n",
302
+ " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
303
+ " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
304
+ " (conv3): ConvBN(\n",
305
+ " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
306
+ " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
307
+ " )\n",
308
+ " (skip): Identity()\n",
309
+ " (act3): hardball()\n",
310
+ " )\n",
311
+ " )\n",
312
+ " (act): hardball()\n",
313
+ " (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
314
+ " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
315
+ ")"
316
  ]
317
+ },
318
+ "execution_count": 5,
319
+ "metadata": {},
320
+ "output_type": "execute_result"
321
+ }
322
+ ],
323
+ "source": [
324
+ "m.eval()"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": 6,
330
+ "id": "2099b937",
331
+ "metadata": {},
332
+ "outputs": [
333
  {
334
+ "name": "stdout",
335
+ "output_type": "stream",
336
+ "text": [
337
+ "layer1 >>\n",
338
+ "torch.Size([512, 64, 1, 1])\n",
339
+ "torch.Size([64, 256, 1, 1])\n",
340
+ "torch.Size([512, 64, 1, 1])\n",
341
+ "torch.Size([64, 256, 1, 1])\n",
342
+ "torch.Size([512, 64, 1, 1])\n",
343
+ "torch.Size([64, 256, 1, 1])\n",
344
+ "layer2 >>\n",
345
+ "torch.Size([512, 64, 1, 1])\n",
346
+ "torch.Size([128, 256, 1, 1])\n",
347
+ "torch.Size([128, 64, 1, 1])\n",
348
+ "torch.Size([1024, 128, 1, 1])\n",
349
+ "torch.Size([128, 512, 1, 1])\n",
350
+ "torch.Size([1024, 128, 1, 1])\n",
351
+ "torch.Size([128, 512, 1, 1])\n",
352
+ "torch.Size([1024, 128, 1, 1])\n",
353
+ "torch.Size([128, 512, 1, 1])\n",
354
+ "layer3 >>\n",
355
+ "torch.Size([1024, 128, 1, 1])\n",
356
+ "torch.Size([256, 512, 1, 1])\n",
357
+ "torch.Size([256, 128, 1, 1])\n",
358
+ "torch.Size([2048, 256, 1, 1])\n",
359
+ "torch.Size([256, 1024, 1, 1])\n",
360
+ "torch.Size([2048, 256, 1, 1])\n",
361
+ "torch.Size([256, 1024, 1, 1])\n",
362
+ "torch.Size([2048, 256, 1, 1])\n",
363
+ "torch.Size([256, 1024, 1, 1])\n",
364
+ "torch.Size([2048, 256, 1, 1])\n",
365
+ "torch.Size([256, 1024, 1, 1])\n",
366
+ "torch.Size([2048, 256, 1, 1])\n",
367
+ "torch.Size([256, 1024, 1, 1])\n",
368
+ "layer4 >>\n",
369
+ "torch.Size([2048, 256, 1, 1])\n",
370
+ "torch.Size([512, 1024, 1, 1])\n",
371
+ "torch.Size([512, 256, 1, 1])\n",
372
+ "torch.Size([2048, 512, 1, 1])\n",
373
+ "torch.Size([512, 1024, 1, 1])\n",
374
+ "torch.Size([2048, 512, 1, 1])\n",
375
+ "torch.Size([512, 1024, 1, 1])\n"
376
+ ]
377
+ }
378
+ ],
379
+ "source": [
380
+ "# fuse ConvBN\n",
381
+ "i = 1\n",
382
+ "for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:\n",
383
+ " print(f'layer{i} >>')\n",
384
+ " for block in layer:\n",
385
+ " # Fuse the weights in conv1 and conv3\n",
386
+ " block.conv1.fuse_bn()\n",
387
+ " print(block.conv1.fused_weight.size())\n",
388
+ " block.conv3.fuse_bn()\n",
389
+ " print(block.conv3.fused_weight.size())\n",
390
+ " if not isinstance(block.skip, torch.nn.Identity):\n",
391
+ " layer[0].skip.fuse_bn()\n",
392
+ " print(layer[0].skip.fused_weight.size())\n",
393
+ " i += 1"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 7,
399
+ "id": "b3a55f82",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "x = torch.randn(5,3,224,224)"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 8,
409
+ "id": "dccbf19c",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "out_old = m(x)"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 9,
419
+ "id": "f0c74a04",
420
+ "metadata": {
421
+ "scrolled": true
422
+ },
423
+ "outputs": [
424
  {
425
+ "data": {
426
+ "text/plain": [
427
+ "torch.Size([5, 1000])"
428
  ]
429
+ },
430
+ "execution_count": 9,
431
+ "metadata": {},
432
+ "output_type": "execute_result"
433
  }
434
+ ],
435
+ "source": [
436
+ "out_old.size()"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 10,
442
+ "id": "a5991c8f",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "def apply_transform(block1, block2, Q, keep_identity=True):\n",
447
+ " with torch.no_grad():\n",
448
+ " # Ensure that the out_channels of block1 is equal to the in_channels of block2\n",
449
+ " assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
450
+ " n = Q.size()[0]\n",
451
+ " assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
452
+ "\n",
453
+ " n = block1.conv3.conv.out_channels\n",
454
+ " \n",
455
+ " # Calculate the inverse of Q\n",
456
+ " Q_inv = torch.inverse(Q)\n",
457
+ "\n",
458
+ " # Modify the weights of conv layers in block1\n",
459
+ " block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
460
+ " block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
461
+ " \n",
462
+ " if isinstance(block1.skip, torch.nn.Identity):\n",
463
+ " if not keep_identity:\n",
464
+ " block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
465
+ " block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
466
+ " else:\n",
467
+ " block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
468
+ " block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
469
+ "\n",
470
+ " # Modify the weights of conv layers in block2\n",
471
+ " block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
472
+ " \n",
473
+ " if isinstance(block2.skip, torch.nn.Identity):\n",
474
+ " if not keep_identity:\n",
475
+ " block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
476
+ " block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
477
+ " else:\n",
478
+ " block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": 11,
484
+ "id": "dd96acd7",
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": [
488
+ "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
489
+ "for i in range(5):\n",
490
+ " apply_transform(m.layer3[i], m.layer3[i+1], Q, True)\n",
491
+ "apply_transform(m.layer3[5], m.layer4[0], Q, True)"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": 12,
497
+ "id": "e5d3628d",
498
+ "metadata": {},
499
+ "outputs": [
500
+ {
501
+ "name": "stdout",
502
+ "output_type": "stream",
503
+ "text": [
504
+ "6.666779518127441e-05\n"
505
+ ]
506
  }
507
+ ],
508
+ "source": [
509
+ "out_new = m(x)\n",
510
+ "print((out_new - out_old).abs().max().item())"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "id": "9fce3a38",
517
+ "metadata": {},
518
+ "outputs": [],
519
+ "source": []
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "id": "5a54fe8b",
525
+ "metadata": {},
526
+ "outputs": [],
527
+ "source": []
528
+ }
529
+ ],
530
+ "metadata": {
531
+ "kernelspec": {
532
+ "display_name": "Python 3 (ipykernel)",
533
+ "language": "python",
534
+ "name": "python3"
535
  },
536
+ "language_info": {
537
+ "codemirror_mode": {
538
+ "name": "ipython",
539
+ "version": 3
540
+ },
541
+ "file_extension": ".py",
542
+ "mimetype": "text/x-python",
543
+ "name": "python",
544
+ "nbconvert_exporter": "python",
545
+ "pygments_lexer": "ipython3",
546
+ "version": "3.10.6"
547
+ }
548
+ },
549
+ "nbformat": 4,
550
+ "nbformat_minor": 5
551
+ }
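The notebook above loads the pretrained QLNet, fuses every ConvBN pair into a single convolution (exposing fused_weight and fused_bias), and then checks a channel-basis symmetry of the architecture: an orthogonal matrix Q is folded into one block's fused 1x1 output convolution (and into its fused skip, where the skip is not an Identity), while Q^-1 is folded into the next block's fused 1x1 input convolution, and the network output changes only at floating-point level (max abs difference ~7e-5 in the run above). The following is a minimal standalone sketch of that cancellation on two plain 1x1 convolutions; the layer names and shapes are illustrative placeholders, not the repository's API.

import torch
import torch.nn as nn

# Minimal sketch (hypothetical layers, not QLNet's API): folding Q into the output of one
# 1x1 conv and Q^-1 into the input of the next leaves their composition unchanged.
torch.manual_seed(0)
n = 8
conv_out = nn.Conv2d(4, n, kernel_size=1, bias=True)   # stands in for block1.conv3 after fuse_bn
conv_in  = nn.Conv2d(n, 4, kernel_size=1, bias=False)  # stands in for block2.conv1 after fuse_bn

x = torch.randn(2, 4, 7, 7)
with torch.no_grad():
    y_old = conv_in(conv_out(x))

    Q = nn.init.orthogonal_(torch.empty(n, n))
    Q_inv = torch.inverse(Q)

    # Output side: W1 <- Q W1, b1 <- Q b1 (same einsum contractions as apply_transform above)
    conv_out.weight.data = torch.einsum('ij,jklm->iklm', Q, conv_out.weight.data)
    conv_out.bias.data   = torch.einsum('ij,j->i', Q, conv_out.bias.data)
    # Input side of the next conv: W2 <- W2 Q^-1
    conv_in.weight.data  = torch.einsum('ki,jklm->jilm', Q_inv, conv_in.weight.data)

    y_new = conv_in(conv_out(x))

print((y_new - y_old).abs().max().item())  # only numerical noise, typically ~1e-6

Inside QLNet the same contractions are applied to the fused weights of conv3, conv1 and the skip branch, so the residual path is transformed consistently with the main path; the notebook's final cell confirms that the end-to-end outputs agree.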
qlnet.py CHANGED
@@ -104,7 +104,7 @@ class QLBlock(nn.Module): # quasilinear hyperbolic system
104
  ):
105
  super(QLBlock, self).__init__()
106
 
107
- k = 4 if inplanes <= 128 else 2
108
  width = inplanes * k
109
  outplanes = inplanes if downsample is None else inplanes * 2
110
  first_dilation = first_dilation or dilation
@@ -114,12 +114,12 @@ class QLBlock(nn.Module): # quasilinear hyperbolic system
114
  dilation=first_dilation, groups=1, bias=False),
115
  norm_layer(width*2))
116
 
117
- self.conv2 = nn.Conv2d(width, width*2, kernel_size=3, stride=stride,
118
  padding=1, dilation=first_dilation, groups=width, bias=False)
119
- self.bn2 = norm_layer(width*2)
120
 
121
  self.conv3 = ConvBN(
122
- nn.Conv2d(width*2, outplanes, kernel_size=1, groups=1, bias=False),
123
  norm_layer(outplanes))
124
 
125
  self.skip = ConvBN(
 
104
  ):
105
  super(QLBlock, self).__init__()
106
 
107
+ k = 4 if inplanes <= 256 else 2
108
  width = inplanes * k
109
  outplanes = inplanes if downsample is None else inplanes * 2
110
  first_dilation = first_dilation or dilation
 
114
  dilation=first_dilation, groups=1, bias=False),
115
  norm_layer(width*2))
116
 
117
+ self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
118
  padding=1, dilation=first_dilation, groups=width, bias=False)
119
+ self.bn2 = norm_layer(width)
120
 
121
  self.conv3 = ConvBN(
122
+ nn.Conv2d(width, outplanes, kernel_size=1, groups=1, bias=False),
123
  norm_layer(outplanes))
124
 
125
  self.skip = ConvBN(
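The qlnet.py change above raises the expansion-factor threshold (k = 4 now applies for inplanes <= 256 rather than <= 128) and makes the grouped 3x3 convolution purely depthwise (width -> width instead of width -> width*2), with bn2 and conv3 adjusted to match. Below is a quick sketch of the per-block channel counts this implies, using the inplanes values 64/128/256/512 that appear in the model printout earlier; the helper is illustrative only, not part of the repository.

# Hypothetical helper: per-block (in, out) channel counts under the new qlnet.py settings.
def block_widths(inplanes):
    k = 4 if inplanes <= 256 else 2            # new threshold (previously inplanes <= 128)
    width = inplanes * k
    return {
        'conv1': (inplanes, width * 2),        # 1x1 ConvBN expansion
        'conv2': (width, width),               # 3x3 depthwise, groups=width (was width -> width*2)
        'conv3': (width, 'inplanes or 2*inplanes'),  # 1x1 ConvBN projection to outplanes
    }

for c in (64, 128, 256, 512):
    print(c, block_widths(c))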