{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "qlnet-title",
   "metadata": {},
   "source": [
    "# QLNet: change of basis between residual blocks\n",
    "\n",
    "Loads a pretrained QLNet-50, folds BatchNorm into its 1x1 convolutions, applies a random orthogonal matrix $Q$ to the activations passed between the blocks of `layer3` (compensated by $Q^{-1}$ on the input side of each following block) and checks that the network output is unchanged up to floating-point error."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qlnet-setup-md",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b6152c",
   "metadata": {
    "id": "71b6152c"
   },
   "outputs": [],
   "source": [
    "# Install PyTorch and timm into the running kernel's environment (%pip, unlike !pip, targets the kernel).\n",
    "# TODO: pin the versions known to work with the QLNet checkpoint.\n",
    "%pip install -q torch timm\n",
    "\n",
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Clone the QLNet repository (model code and checkpoint) unless it is already present,\n",
    "# either next to this notebook or as the current working directory (cell re-run).\n",
    "if os.path.basename(os.getcwd()) != 'QLNet' and not os.path.isdir('QLNet'):\n",
    "    subprocess.run(['git', 'clone', 'https://huggingface.co/liuyao/QLNet'], check=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "pmVezdbxzcw7",
   "metadata": {
    "id": "pmVezdbxzcw7"
   },
   "outputs": [],
   "source": [
    "# Work from inside the repository so `qlnet` is importable and the checkpoint path resolves.\n",
    "# Guarded so that re-running the cell neither fails nor descends into QLNet/QLNet.\n",
    "if os.path.basename(os.getcwd()) != 'QLNet':\n",
    "    os.chdir('QLNet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7vDt28zlzi0r",
   "metadata": {
    "id": "7vDt28zlzi0r"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import timm  # used by qlnet (e.g. SelectAdaptivePool2d in the model head)\n",
    "\n",
    "from qlnet import QLNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3f703be8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3f703be8",
    "outputId": "de73c734-305f-4955-fe69-7b7253b4f95e"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Using device: cpu\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "# Use the GPU when available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Build the model and move it to the chosen device\n",
    "model = QLNet().to(device)\n",
    "\n",
    "# Load the pretrained weights. torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load('qlnet-50-v0.pth.tar', map_location=device)\n",
    "model.load_state_dict(checkpoint['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f14d984a",
   "metadata": {
    "scrolled": true,
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f14d984a",
    "outputId": "efc70253-4bc0-4d0c-92d8-d247118138bc"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "QLNet(\n",
       " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "
(act1): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 64, kernel_size=(1, 
1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " 
(conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (3): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(128, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (3): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): 
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (4): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (5): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)\n", " (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " 
(conv3): ConvBN(\n", " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): ConvBN(\n", " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (act3): hardball()\n", " )\n", " (1): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " (2): QLBlock(\n", " (conv1): ConvBN(\n", " (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (conv2): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024, bias=False)\n", " (bn2): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): ConvBN(\n", " (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (skip): Identity()\n", " (act3): hardball()\n", " )\n", " )\n", " (act): hardball()\n", " (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))\n", " (fc): Linear(in_features=512, out_features=1000, 
bias=True)\n",
       ")"
      ]
     },
     "metadata": {},
     "execution_count": 10
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qlnet-fuse-md",
   "metadata": {},
   "source": [
    "## Fold BatchNorm into the 1x1 convolutions\n",
    "\n",
    "`fuse_bn()` is called on every `ConvBN` (`conv1`, `conv3` and any projection `skip`). The change of basis below then edits the resulting `fused_weight` / `fused_bias` tensors directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2099b937",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2099b937",
    "outputId": "ac4557a4-ed2a-47b2-eca7-d9a337fff3f1"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "layer1 >>\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([64, 512, 1, 1])\n",
      "layer2 >>\n",
      "torch.Size([512, 64, 1, 1])\n",
      "torch.Size([128, 512, 1, 1])\n",
      "torch.Size([128, 64, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([128, 1024, 1, 1])\n",
      "layer3 >>\n",
      "torch.Size([1024, 128, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([256, 128, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([256, 1024, 1, 1])\n",
      "layer4 >>\n",
      "torch.Size([1024, 256, 1, 1])\n",
      "torch.Size([512, 1024, 1, 1])\n",
      "torch.Size([512, 256, 1, 1])\n",
      "torch.Size([2048, 512, 1, 1])\n",
      "torch.Size([512, 2048, 1, 1])\n",
      "torch.Size([2048, 512, 1, 1])\n",
      "torch.Size([512, 2048, 1, 1])\n"
     ]
    }
   ],
   "source": [
    "# Fuse BatchNorm into conv1, conv3 and (when present) the projection shortcut of every block\n",
    "for layer_idx, layer in enumerate([model.layer1, model.layer2, model.layer3, model.layer4], start=1):\n",
    "    print(f'layer{layer_idx} >>')\n",
    "    for block in layer:\n",
    "        block.conv1.fuse_bn()\n",
    "        print(block.conv1.fused_weight.size())\n",
    "        block.conv3.fuse_bn()\n",
    "        print(block.conv3.fused_weight.size())\n",
    "        # Fuse this block's own shortcut (not layer[0]'s) so every projection skip is handled exactly once\n",
    "        if not isinstance(block.skip, torch.nn.Identity):\n",
    "            block.skip.fuse_bn()\n",
    "            print(block.skip.fused_weight.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qlnet-reference-md",
   "metadata": {},
   "source": [
    "## Reference output on a fixed batch\n",
    "\n",
    "The fused model's output on a seeded random batch is recorded so the transformed model can be compared against it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b3a55f82",
   "metadata": {
    "id": "b3a55f82"
   },
   "outputs": [],
   "source": [
    "# Seeded random batch, drawn on the CPU (same values on every device) and moved to the model's device\n",
    "torch.manual_seed(0)\n",
    "x = torch.randn(5, 3, 224, 224).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "dccbf19c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dccbf19c",
    "outputId": "4a5409f4-761b-4682-a5be-5f55fd595135"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "torch.Size([5, 1000])\n"
     ]
    }
   ],
   "source": [
    "# Inference only: no autograd graph is needed\n",
    "with torch.no_grad():\n",
    "    y_old = model(x)\n",
    "print(y_old.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qlnet-transform-md",
   "metadata": {},
   "source": [
    "## Change of basis between consecutive blocks\n",
    "\n",
    "For an invertible $Q$, `apply_transform` left-multiplies `block1`'s output weights (`conv3` and a projection `skip`) by $Q$ and right-multiplies `block2`'s input weights (`conv1` and a projection `skip`) by $Q^{-1}$. The network output is unchanged provided the activation applied after each residual add (`hardball`) commutes with $Q$. That is assumed here for orthogonal $Q$ and checked numerically at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a5991c8f",
   "metadata": {
    "id": "a5991c8f"
   },
   "outputs": [],
   "source": [
    "def _conv1x1(M):\n",
    "    \"\"\"Return a bias-free 1x1 Conv2d whose weight is the (out, in) matrix M.\"\"\"\n",
    "    conv = torch.nn.Conv2d(M.size(1), M.size(0), kernel_size=1, bias=False)\n",
    "    conv.weight.data = M[:, :, None, None].clone()\n",
    "    return conv\n",
    "\n",
    "\n",
    "def _transform_skip_output(skip, Q, keep_identity):\n",
    "    \"\"\"Left-multiply a shortcut's output by Q; return the (possibly replaced) shortcut module.\"\"\"\n",
    "    if isinstance(skip, torch.nn.Identity):\n",
    "        return skip if keep_identity else _conv1x1(Q)\n",
    "    if hasattr(skip, 'fused_weight'):  # fused ConvBN projection shortcut\n",
    "        skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, skip.fused_weight.data)\n",
    "        skip.fused_bias.data = torch.einsum('ij,j->i', Q, skip.fused_bias.data)\n",
    "    else:  # plain 1x1 Conv2d inserted by an earlier call with keep_identity=False\n",
    "        skip.weight.data = torch.einsum('ij,jklm->iklm', Q, skip.weight.data)\n",
    "    return skip\n",
    "\n",
    "\n",
    "def _transform_skip_input(skip, Q_inv, keep_identity):\n",
    "    \"\"\"Right-multiply a shortcut's weights by Q_inv so it accepts Q-rotated input; return the shortcut.\"\"\"\n",
    "    if isinstance(skip, torch.nn.Identity):\n",
    "        return skip if keep_identity else _conv1x1(Q_inv)\n",
    "    if hasattr(skip, 'fused_weight'):  # fused ConvBN projection shortcut (bias is unaffected)\n",
    "        skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, skip.fused_weight.data)\n",
    "    else:  # plain 1x1 Conv2d inserted by an earlier call with keep_identity=False\n",
    "        skip.weight.data = torch.einsum('ki,jklm->jilm', Q_inv, skip.weight.data)\n",
    "    return skip\n",
    "\n",
    "\n",
    "def apply_transform(block1, block2, Q, keep_identity=True):\n",
    "    \"\"\"Re-express the activations passed from block1 to block2 in the basis given by Q.\n",
    "\n",
    "    block1's output weights (conv3 and a projection shortcut) are left-multiplied by Q and\n",
    "    block2's input weights (conv1 and a projection shortcut) are right-multiplied by Q^{-1}.\n",
    "    The end-to-end output is unchanged as long as the activation applied after each residual\n",
    "    add commutes with Q (assumed for orthogonal Q; checked numerically in this notebook).\n",
    "\n",
    "    Both blocks must already have been fused with fuse_bn(); they are modified in place.\n",
    "\n",
    "    Args:\n",
    "        block1, block2: consecutive QLBlocks.\n",
    "        Q: square invertible matrix of size n = block1's output channels. It is moved to the\n",
    "            device/dtype of the model weights before use; the caller's tensor is not modified.\n",
    "        keep_identity: if True, identity shortcuts are left untouched. This is exact only when\n",
    "            the calls are chained (block2 becomes block1 of the next call with the same Q)\n",
    "            up to a block whose shortcut is a projection. If False, identity shortcuts are\n",
    "            replaced by explicit 1x1 convolutions applying Q (block1) or Q^{-1} (block2);\n",
    "            such calls can also be chained.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        assert Q.dim() == 2 and Q.size(0) == Q.size(1), 'Q needs to be a square matrix'\n",
    "        n = Q.size(0)\n",
    "        assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, 'Mismatched channels between blocks'\n",
    "\n",
    "        # Match the weights' device/dtype so einsum never mixes e.g. a CPU Q with a CUDA model\n",
    "        Q = Q.to(block1.conv3.fused_weight)\n",
    "        Q_inv = torch.inverse(Q)\n",
    "\n",
    "        # block1: express its output (residual branch + shortcut) in the Q basis\n",
    "        block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)\n",
    "        block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)\n",
    "        block1.skip = _transform_skip_output(block1.skip, Q, keep_identity)\n",
    "\n",
    "        # block2: undo the basis change on its input; conv1's bias is added after the channel mix, so it is unchanged\n",
    "        block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)\n",
    "        block2.skip = _transform_skip_input(block2.skip, Q_inv, keep_identity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "dd96acd7",
   "metadata": {
    "id": "dd96acd7"
   },
   "outputs": [],
   "source": [
    "# Seeded random orthogonal Q over the output channels of layer3 (256 for QLNet-50)\n",
    "torch.manual_seed(1)\n",
    "n_channels = model.layer3[-1].conv3.conv.out_channels\n",
    "Q = torch.nn.init.orthogonal_(torch.empty(n_channels, n_channels))\n",
    "\n",
    "# Chain the transform across every block boundary of layer3 and into layer4[0], whose projection\n",
    "# shortcut absorbs Q^{-1}; the identity shortcuts in between carry Q-rotated activations.\n",
    "blocks = list(model.layer3) + [model.layer4[0]]\n",
    "for block1, block2 in zip(blocks[:-1], blocks[1:]):\n",
    "    apply_transform(block1, block2, Q, keep_identity=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e5d3628d",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e5d3628d",
    "outputId": "667cfe17-e3fb-4009-9553-a765c6377321"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "8.472800254821777e-05\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    y_new = model(x)\n",
    "\n",
    "# Max absolute difference to the reference output; expected to be at floating-point noise level\n",
    "print((y_new - y_old).abs().max().item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "colab": {
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}