In [1]:
import torch, timm
from qlnet import QLNet

In [2]:
# Build the network with QLNet's default constructor arguments; the weights
# are whatever the constructor initialises until a checkpoint is loaded.
m = QLNet()

In [3]:
# Load the checkpoint onto the CPU regardless of the device it was saved from,
# so the notebook also runs on machines without a GPU (the rest of the notebook
# keeps the model and the inputs on the CPU).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source, or pass weights_only=True on
# PyTorch versions that support it.
state_dict = torch.load('qlnet22-16m.pth.tar', map_location='cpu')

In [4]:
# Copy the checkpoint tensors into the model. As the last expression of the
# cell, the return value of load_state_dict is what the cell displays.
m.load_state_dict(state_dict)



In [5]:
# Switch to evaluation mode (the BatchNorm2d layers shown in the repr track
# running statistics). As the last expression of the cell, the returned
# module's repr is displayed.
m.eval()

QLNet(
 (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
 (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 (act1): ReLU(inplace=True)
 (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 (layer1): Sequential(
 (0): QLBlock(
 (conv1): ConvBN(
 (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
 (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 )
 (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
 (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 (conv3): ConvBN(
 (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 )
 (skip): Identity()
 (act3): hardball()
 )
 (1): QLBlock(
 (conv1): ConvBN(
 (conv): Conv2d(64, 512, kernel_size=

In [6]:
# Fuse ConvBN: call fuse_bn() on every ConvBN so that it exposes `fused_weight`
# (presumably the BatchNorm folded into the conv kernel -- see qlnet.ConvBN),
# and print each fused kernel's shape as a sanity check.
for i, layer in enumerate([m.layer1, m.layer2, m.layer3, m.layer4], start=1):
    print(f'layer{i} >>')
    for block in layer:
        # Fuse the weights in conv1 and conv3
        block.conv1.fuse_bn()
        print(block.conv1.fused_weight.size())
        block.conv3.fuse_bn()
        print(block.conv3.fused_weight.size())
        # Only non-identity skips carry weights to fuse. Fuse the skip of the
        # block being visited (not unconditionally layer[0]'s), so the check
        # and the action always refer to the same module.
        if not isinstance(block.skip, torch.nn.Identity):
            block.skip.fuse_bn()
            print(block.skip.fused_weight.size())

layer1 >>
torch.Size([512, 64, 1, 1])
torch.Size([64, 256, 1, 1])
torch.Size([512, 64, 1, 1])
torch.Size([64, 256, 1, 1])
torch.Size([512, 64, 1, 1])
torch.Size([64, 256, 1, 1])
layer2 >>
torch.Size([512, 64, 1, 1])
torch.Size([128, 256, 1, 1])
torch.Size([128, 64, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 512, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 512, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 512, 1, 1])
layer3 >>
torch.Size([1024, 128, 1, 1])
torch.Size([256, 512, 1, 1])
torch.Size([256, 128, 1, 1])
torch.Size([2048, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([2048, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([2048, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([2048, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([2048, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
layer4 >>
torch.Size([2048, 256, 1, 1])
torch.Size([512, 1024, 1, 1])
torch.Size([512, 256, 1, 1])
torch.Size([2048, 512, 1, 1])
torch.Size([512, 1024, 

In [7]:
# Random probe batch of 5 inputs with shape (3, 224, 224). Seed the global RNG
# first so this tensor (and anything drawn after it) is reproducible on
# Restart & Run All.
torch.manual_seed(0)
x = torch.randn(5, 3, 224, 224)

In [8]:
# Reference forward pass before any weights are transformed. Only the output
# values are needed for the later comparison, so run under no_grad to avoid
# building (and holding on to) an autograd graph.
with torch.no_grad():
    out_old = m(x)

In [9]:
# Inspect the output shape; the recorded run shows torch.Size([5, 1000]),
# i.e. one row per input in the batch of 5.
out_old.size()

torch.Size([5, 1000])

In [10]:
def _weight_attr_names(module):
    """Return the (weight, bias) attribute names holding ``module``'s kernel.

    Modules exposing ``fused_weight`` use their fused tensors; anything else is
    treated like a plain ``torch.nn.Conv2d`` (``weight`` / ``bias``), which is
    what ``apply_transform`` itself installs when ``keep_identity`` is False.
    """
    if hasattr(module, 'fused_weight'):
        return 'fused_weight', 'fused_bias'
    return 'weight', 'bias'


def _mix_output_channels(module, Q):
    """Left-multiply ``module``'s kernel and bias (if any) by ``Q``: W <- Q W, b <- Q b."""
    weight_name, bias_name = _weight_attr_names(module)
    weight = getattr(module, weight_name)
    weight.data = torch.einsum('ij,jklm->iklm', Q, weight.data)
    bias = getattr(module, bias_name)
    if bias is not None:  # a Conv2d built with bias=False has bias None
        bias.data = torch.einsum('ij,j->i', Q, bias.data)


def _mix_input_channels(module, Q_inv):
    """Right-multiply ``module``'s kernel by ``Q_inv``: W <- W Q_inv (bias untouched)."""
    weight_name, _ = _weight_attr_names(module)
    weight = getattr(module, weight_name)
    weight.data = torch.einsum('ki,jklm->jilm', Q_inv, weight.data)


def apply_transform(block1, block2, Q, keep_identity=True):
    """Insert the change of basis ``Q`` between ``block1`` and ``block2`` in place.

    The output channels of ``block1`` (its ``conv3`` and, when it has weights,
    its ``skip``) are left-multiplied by ``Q``; the input channels of ``block2``
    (its ``conv1`` and weighted ``skip``) are right-multiplied by ``Q^-1``.

    Args:
        block1: producer block; needs ``conv3`` (with ``conv.out_channels``,
            ``fused_weight``, ``fused_bias``) and ``skip``.
        block2: consumer block; needs ``conv1`` (with ``conv.in_channels``,
            ``fused_weight``) and ``skip``.
        Q: square, invertible ``(n, n)`` tensor, where ``n`` is the channel
            count shared by ``block1``'s output and ``block2``'s input.
        keep_identity: if True, ``Identity`` skips are left untouched. If
            False, they are replaced by bias-free 1x1 convolutions holding
            ``Q`` (on ``block1``) or ``Q^-1`` (on ``block2``).

    Raises:
        AssertionError: if ``Q`` is not square or the channel counts differ.

    NOTE(review): whether the network's function is preserved depends on the
    block activation commuting with ``Q`` and, with ``keep_identity=True``, on
    the identity skip path already carrying ``Q``-transformed features (e.g.
    when the same ``Q`` is applied along a chain of blocks) -- confirm against
    the QLNet block definition.
    """
    with torch.no_grad():
        # Ensure that the out_channels of block1 is equal to the in_channels of block2
        assert Q.size()[0] == Q.size()[1], "Q needs to be a square matrix"
        n = Q.size()[0]
        assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, "Mismatched channels between blocks"

        # Calculate the inverse of Q
        Q_inv = torch.inverse(Q)

        # Producer side: everything block1 emits is expressed in the new basis.
        _mix_output_channels(block1.conv3, Q)
        if isinstance(block1.skip, torch.nn.Identity):
            if not keep_identity:
                block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)
                # clone(): unsqueeze returns a view, and the skip weight must
                # not share storage with the caller's Q (which is typically
                # reused for other blocks).
                block1.skip.weight.data = Q.unsqueeze(-1).unsqueeze(-1).clone()
        else:
            # Works for fused ConvBN skips and for plain Conv2d skips created
            # by an earlier call with keep_identity=False.
            _mix_output_channels(block1.skip, Q)

        # Consumer side: undo the change of basis on everything block2 reads.
        _mix_input_channels(block2.conv1, Q_inv)
        if isinstance(block2.skip, torch.nn.Identity):
            if not keep_identity:
                block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)
                # Q_inv is local to this call, so no copy is needed here.
                block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)
        else:
            _mix_input_channels(block2.skip, Q_inv)


In [11]:
# One random orthogonal 256x256 matrix, shared by every junction below.
Q = torch.nn.init.orthogonal_(torch.empty(256, 256))

# Blocks in data-flow order: the six layer3 blocks followed by the first
# layer4 block. Each adjacent (producer, consumer) pair gets the same Q.
# NOTE: apply_transform edits weights in place, so re-running this cell
# applies a further transform on top of the previous one.
chain = [m.layer3[idx] for idx in range(6)] + [m.layer4[0]]
for producer, consumer in zip(chain[:-1], chain[1:]):
    apply_transform(producer, consumer, Q, True)

In [12]:
# Re-run the same batch through the transformed model and report the largest
# absolute deviation from the reference output. Inference only, so no autograd
# graph is needed.
with torch.no_grad():
    out_new = m(x)
print((out_new - out_old).abs().max().item())

6.666779518127441e-05
