{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "71b6152c", "metadata": {}, "outputs": [], "source": [ "import torch, timm\n", "from qlnet import QLNet" ] }, { "cell_type": "code", "execution_count": 2, "id": "4e7ed219", "metadata": {}, "outputs": [], "source": [ "m = QLNet()" ] }, { "cell_type": "code", "execution_count": 3, "id": "3f703be8", "metadata": {}, "outputs": [], "source": [ "state_dict = torch.load('qlnet22-16m.pth.tar')" ] }, { "cell_type": "code", "execution_count": 4, "id": "435e2358", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f14d984a", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "QLNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (act1): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 
2), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (3): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " 
(conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (3): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (4): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (5): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 2048, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (act): hardball()\n", " (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n", " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.eval()" ] }, { "cell_type": "code", "execution_count": 6, "id": "2099b937", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer1 >>\n", "torch.Size([512, 64, 1, 1])\n", "torch.Size([64, 256, 1, 1])\n", "torch.Size([512, 64, 1, 1])\n", "torch.Size([64, 256, 1, 1])\n", "torch.Size([512, 64, 1, 1])\n", "torch.Size([64, 256, 1, 1])\n", "layer2 >>\n", "torch.Size([512, 64, 1, 1])\n", "torch.Size([128, 256, 1, 1])\n", "torch.Size([128, 64, 1, 1])\n", "torch.Size([1024, 128, 1, 1])\n", "torch.Size([128, 512, 1, 1])\n", "torch.Size([1024, 128, 1, 1])\n", "torch.Size([128, 512, 1, 1])\n", "torch.Size([1024, 128, 1, 1])\n", "torch.Size([128, 512, 1, 1])\n", "layer3 >>\n", "torch.Size([1024, 128, 1, 1])\n", "torch.Size([256, 512, 1, 1])\n", "torch.Size([256, 128, 1, 1])\n", "torch.Size([2048, 256, 1, 1])\n", "torch.Size([256, 1024, 1, 1])\n", "torch.Size([2048, 256, 1, 1])\n", "torch.Size([256, 1024, 1, 1])\n", "torch.Size([2048, 256, 1, 1])\n", "torch.Size([256, 1024, 1, 1])\n", "torch.Size([2048, 256, 1, 1])\n", "torch.Size([256, 1024, 1, 1])\n", 
"torch.Size([2048, 256, 1, 1])\n", "torch.Size([256, 1024, 1, 1])\n", "layer4 >>\n", "torch.Size([2048, 256, 1, 1])\n", "torch.Size([512, 1024, 1, 1])\n", "torch.Size([512, 256, 1, 1])\n", "torch.Size([2048, 512, 1, 1])\n", "torch.Size([512, 1024, 1, 1])\n", "torch.Size([2048, 512, 1, 1])\n", "torch.Size([512, 1024, 1, 1])\n" ] } ], "source": [ "# fuse ConvBN\n", "i = 1\n", "for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:\n", " print(f'layer{i} >>')\n", " for block in layer:\n", " # Fuse the weights in conv1 and conv3\n", " block.conv1.fuse_bn()\n", " print(block.conv1.fused_weight.size())\n", " block.conv3.fuse_bn()\n", " print(block.conv3.fused_weight.size())\n", " if not isinstance(block.skip, torch.nn.Identity):\n", " layer[0].skip.fuse_bn()\n", " print(layer[0].skip.fused_weight.size())\n", " i += 1" ] }, { "cell_type": "code", "execution_count": 7, "id": "b3a55f82", "metadata": {}, "outputs": [], "source": [ "x = torch.randn(5,3,224,224)" ] }, { "cell_type": "code", "execution_count": 8, "id": "dccbf19c", "metadata": {}, "outputs": [], "source": [ "out_old = m(x)" ] }, { "cell_type": "code", "execution_count": 9, "id": "f0c74a04", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 1000])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out_old.size()" ] }, { "cell_type": "code", "execution_count": 10, "id": "a5991c8f", "metadata": {}, "outputs": [], "source": [ "def apply_transform(block1, block2, Q, keep_identity=True):\n", " with torch.no_grad():\n", " # Ensure that the out_channels of block1 is equal to the in_channels of block2\n", " assert Q.size()[0] == Q.size()[1], \"Q needs to be a square matrix\"\n", " n = Q.size()[0]\n", " assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, \"Mismatched channels between blocks\"\n", "\n", " n = block1.conv3.conv.out_channels\n", " \n", " # Calculate the inverse of Q\n", " Q_inv = torch.inverse(Q)\n", "\n", " # Modify the weights of conv layers in block1\n", " block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n", " block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n", " \n", " if isinstance(block1.skip, torch.nn.Identity):\n", " if not keep_identity:\n", " block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n", " block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1)\n", " else:\n", " block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)\n", " block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)\n", "\n", " # Modify the weights of conv layers in block2\n", " block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n", " \n", " if isinstance(block2.skip, torch.nn.Identity):\n", " if not keep_identity:\n", " block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)\n", " block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)\n", " else:\n", " block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "dd96acd7", "metadata": {}, "outputs": [], "source": [ "Q = torch.nn.init.orthogonal_(torch.empty(256, 256))\n", "for i in range(5):\n", " apply_transform(m.layer3[i], m.layer3[i+1], Q, True)\n", "apply_transform(m.layer3[5], m.layer4[0], Q, True)" ] }, { "cell_type": "code", "execution_count": 12, "id": 
"e5d3628d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.666779518127441e-05\n" ] } ], "source": [ "out_new = m(x)\n", "print((out_new - out_old).abs().max().item())" ] }, { "cell_type": "code", "execution_count": null, "id": "9fce3a38", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "5a54fe8b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }