# Converting PyTorch to ONNX

In [9]:
import torch

print(torch.cuda.is_available())

True


## Defining the model

In [10]:
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# helper methods

def group_dict_by_key(cond, d):
 return_val = [dict(), dict()]
 for key in d.keys():
 match = bool(cond(key))
 ind = int(not match)
 return_val[ind][key] = d[key]
 return (*return_val,)

def group_by_key_prefix_and_remove_prefix(prefix, d):
 kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: x.startswith(prefix), d)
 kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
 return kwargs_without_prefix, kwargs

# classes

class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
 def __init__(self, dim, eps = 1e-5):
 super().__init__()
 self.eps = eps
 self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

 def forward(self, x):
 var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
 mean = torch.mean(x, dim = 1, keepdim = True)
 return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
 def __init__(self, dim, fn):
 super().__init__()
 self.norm = LayerNorm(dim)
 self.fn = fn
 def forward(self, x, **kwargs):
 x = self.norm(x)
 return self.fn(x, **kwargs)

class FeedForward(nn.Module):
 def __init__(self, dim, mult = 4, dropout = 0.):
 super().__init__()
 self.net = nn.Sequential(
 nn.Conv2d(dim, dim * mult, 1),
 nn.GELU(),
 nn.Dropout(dropout),
 nn.Conv2d(dim * mult, dim, 1),
 nn.Dropout(dropout)
 )
 def forward(self, x):
 return self.net(x)

class DepthWiseConv2d(nn.Module):
 def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
 super().__init__()
 self.net = nn.Sequential(
 nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
 nn.BatchNorm2d(dim_in),
 nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
 )
 def forward(self, x):
 return self.net(x)

class Attention(nn.Module):
 def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
 super().__init__()
 inner_dim = dim_head * heads
 padding = proj_kernel // 2
 self.heads = heads
 self.scale = dim_head ** -0.5

 self.attend = nn.Softmax(dim = -1)
 self.dropout = nn.Dropout(dropout)

 self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
 self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

 self.to_out = nn.Sequential(
 nn.Conv2d(inner_dim, dim, 1),
 nn.Dropout(dropout)
 )

 def forward(self, x):
 shape = x.shape
 b, n, _, y, h = *shape, self.heads
 q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
 q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

 dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

 attn = self.attend(dots)
 attn = self.dropout(attn)

 out = einsum('b i j, b j d -> b i d', attn, v)
 out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
 return self.to_out(out)

class Transformer(nn.Module):
 def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
 super().__init__()
 self.layers = nn.ModuleList([])
 for _ in range(depth):
 self.layers.append(nn.ModuleList([
 PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
 PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
 ]))
 def forward(self, x):
 for attn, ff in self.layers:
 x = attn(x) + x
 x = ff(x) + x
 return x

class CvT(nn.Module):
 def __init__(
 self,
 *,
 num_classes,
 s1_emb_dim = 64,
 s1_emb_kernel = 7,
 s1_emb_stride = 4,
 s1_proj_kernel = 3,
 s1_kv_proj_stride = 2,
 s1_heads = 1,
 s1_depth = 1,
 s1_mlp_mult = 4,
 s2_emb_dim = 192,
 s2_emb_kernel = 3,
 s2_emb_stride = 2,
 s2_proj_kernel = 3,
 s2_kv_proj_stride = 2,
 s2_heads = 3,
 s2_depth = 2,
 s2_mlp_mult = 4,
 s3_emb_dim = 384,
 s3_emb_kernel = 3,
 s3_emb_stride = 2,
 s3_proj_kernel = 3,
 s3_kv_proj_stride = 2,
 s3_heads = 6,
 s3_depth = 10,
 s3_mlp_mult = 4,
 dropout = 0.
 ):
 super().__init__()
 kwargs = dict(locals())

 dim = 1
 layers = []

 for prefix in ('s1', 's2', 's3'):
 config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)

 layers.append(nn.Sequential(
 nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
 LayerNorm(config['emb_dim']),
 Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
 ))

 dim = config['emb_dim']

 self.layers = nn.Sequential(*layers)

 self.to_logits = nn.Sequential(
 nn.AdaptiveAvgPool2d(1),
 Rearrange('... () () -> ...'),
 nn.Linear(dim, num_classes)
 )

 def forward(self, x):
 latents = self.layers(x)
 return self.to_logits(latents)

## Loading the model

In [11]:
model = CvT(num_classes = 2)
model.load_state_dict(torch.load('../model/test-5-cvt-model.pth'))

dummy_input = torch.randn(1, 1, 256, 256)

model.eval()
torch_out = model(dummy_input)

## Converting to ONNX

In [13]:
onnx_path = '../model/model.onnx'

torch.onnx.export(model,
 dummy_input,
 onnx_path,
 verbose=True,
 input_names = ['input'], # the model's input names
 output_names = ['output'], # the model's output names
 dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
 'output' : {0 : 'batch_size'}},
 opset_version=13)

 inferred_length: int = length // known_product
 known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
 unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}


graph(%input : Float(*, 1, 256, 256, strides=[65536, 65536, 256, 1], requires_grad=0, device=cpu),
 %layers.0.0.weight : Float(64, 1, 7, 7, strides=[49, 49, 7, 1], requires_grad=1, device=cpu),
 %layers.0.0.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
 %layers.0.1.g : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.1.b : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.2.layers.0.0.norm.g : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.2.layers.0.0.norm.b : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.2.layers.0.0.fn.to_q.net.2.weight : Float(64, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.2.layers.0.0.fn.to_kv.net.2.weight : Float(128, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %layers.0.2.layers.0.0.fn.to_out.0.weight : Float(64, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1

## Verifying the ONNX model

In [15]:
import onnx

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

## Comparing ONNX Runtime and PyTorch results

In [16]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession(onnx_path)

def to_numpy(tensor):
 return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")


Exported model has been tested with ONNXRuntime, and the result looks good!
