Image Classification
timm
PDE
ConvNet
File size: 23,102 Bytes
793ef19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "71b6152c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, timm\n",
    "from qlnet import QLNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4e7ed219",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = QLNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3f703be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_dict = torch.load('qlnet-50-v0.pth.tar')['state_dict']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "435e2358",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f14d984a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "QLNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (act1): ReLU(inplace=True)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (1): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (2): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): ConvBN(\n",
       "        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (1): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (2): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (3): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): ConvBN(\n",
       "        (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (1): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (2): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (3): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (4): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (5): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n",
       "      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): ConvBN(\n",
       "        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (1): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
       "      (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "    (2): QLBlock(\n",
       "      (conv1): ConvBN(\n",
       "        (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n",
       "      (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): ConvBN(\n",
       "        (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (skip): Identity()\n",
       "      (act3): hardball()\n",
       "    )\n",
       "  )\n",
       "  (act): hardball()\n",
       "  (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n",
       "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2099b937",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1 >>\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "layer2 >>\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([128, 512, 1, 1])\n",
      "torch.Size([128, 64, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "layer3 >>\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([256, 128, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "layer4 >>\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([512, 1024, 1, 1])\n",
      "torch.Size([512, 256, 1, 1])\n",
      "torch.Size([2048, 512, 1, 1])\n",
      "torch.Size([512, 2048, 1, 1])\n",
      "torch.Size([2048, 512, 1, 1])\n",
      "torch.Size([512, 2048, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "# fuse ConvBN\n",
    "i = 1\n",
    "for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:\n",
    "    print(f'layer{i} >>')\n",
    "    for block in layer:\n",
    "        # Fuse the weights in conv1 and conv3\n",
    "        block.conv1.fuse_bn()\n",
    "        print(block.conv1.fused_weight.size())\n",
    "        block.conv3.fuse_bn()\n",
    "        print(block.conv3.fused_weight.size())\n",
    "        if not isinstance(block.skip, torch.nn.Identity):\n",
    "            layer[0].skip.fuse_bn()\n",
    "            print(layer[0].skip.fused_weight.size())\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b3a55f82",
   "metadata": {},
   "outputs": [],
   "source": [
    "inpt = torch.randn(5,3,224,224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dccbf19c",
   "metadata": {},
   "outputs": [],
   "source": [
    "out_old = m(inpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f0c74a04",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 1000])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_old.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a5991c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_transform(block1, block2, Q, keep_identity=True):\n",
    "    with torch.no_grad():\n",
    "        # Q must be square, and its size must match the channels joining the blocks (block1.conv3 out == block2.conv1 in)\n",
    "        assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n",
    "        n = Q.size()[0]\n",
    "        assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n",
    "\n",
    "        n = block1.conv3.conv.out_channels\n",
    "        \n",
    "        # Calculate the inverse of Q\n",
    "        Q_inv = torch.inverse(Q)\n",
    "\n",
    "        # block1: transform its output channels by Q (left-multiply conv3, and a non-Identity skip, weight and bias by Q)\n",
    "        block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
    "        block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
    "        \n",
    "        if isinstance(block1.skip, torch.nn.Identity):\n",
    "            if not keep_identity:\n",
    "                block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
    "                block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n",
    "        else:\n",
    "            block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n",
    "            block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n",
    "\n",
    "        # block2: cancel Q on its input channels (right-multiply conv1, and a non-Identity skip, weight by Q^-1; biases untouched)\n",
    "        block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
    "        \n",
    "        if isinstance(block2.skip, torch.nn.Identity):\n",
    "            if not keep_identity:\n",
    "                block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n",
    "                block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n",
    "        else:\n",
    "            block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dd96acd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n",
    "for i in range(5):\n",
    "    apply_transform(m.layer3[i], m.layer3[i+1], Q, True)\n",
    "apply_transform(m.layer3[5], m.layer4[0], Q, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e5d3628d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.0558319091796875e-05\n"
     ]
    }
   ],
   "source": [
    "out_new = m(inpt)\n",
    "print((out_new - out_old).abs().max().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fce3a38",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a54fe8b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}