# Converting PyTorch to ONNX

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

 from .autonotebook import tqdm as notebook_tqdm


## Defining the model

In [2]:
class DoubleConv(nn.Module):
 def __init__(self, in_channels, out_channels):
 super().__init__()
 self.conv = nn.Sequential(
 nn.Conv2d(in_channels, out_channels, 3, padding=1),
 nn.BatchNorm2d(out_channels),
 nn.ReLU(inplace=True),
 nn.Conv2d(out_channels, out_channels, 3, padding=1),
 nn.BatchNorm2d(out_channels),
 nn.ReLU(inplace=True)
 )
 def forward(self, x):
 return self.conv(x)

class Down(nn.Module):
 def __init__(self, in_channels, out_channels):
 super().__init__()
 self.down = nn.Sequential(
 nn.MaxPool2d(2),
 DoubleConv(in_channels, out_channels)
 )
 def forward(self, x):
 return self.down(x)

class Up(nn.Module):
 def __init__(self, in_channels, out_channels, bilinear=False):
 super().__init__()
 if bilinear:
 self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
 nn.Conv2d(in_channels, in_channels // 2, 1))
 else:
 self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
 
 self.conv = DoubleConv(in_channels, out_channels)

 def forward(self, x1, x2):
 x1 = self.up(x1)
 diffY = x2.size()[2] - x1.size()[2]
 diffX = x2.size()[3] - x1.size()[3]
 x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
 x = torch.cat([x2, x1], dim=1)
 return self.conv(x)

class OutConv(nn.Module):
 def __init__(self, in_channels, out_channels):
 super().__init__()
 self.conv = nn.Conv2d(in_channels, out_channels, 1)
 self.sigmoid = nn.Sigmoid()

 def forward(self, x):
 return self.sigmoid(self.conv(x))

class UNet(nn.Module):
 def __init__(self, n_channels, n_classes):
 super().__init__()
 self.inc = DoubleConv(n_channels, 64)
 self.down1 = Down(64, 128)
 self.down2 = Down(128, 256)
 self.down3 = Down(256, 512)
 self.down4 = Down(512, 1024)
 self.up1 = Up(1024, 512)
 self.up2 = Up(512, 256)
 self.up3 = Up(256, 128)
 self.up4 = Up(128, 64)
 self.outc = OutConv(64, n_classes)

 def forward(self, x):
 x1 = self.inc(x)
 x2 = self.down1(x1)
 x3 = self.down2(x2)
 x4 = self.down3(x3)
 x5 = self.down4(x4)
 x = self.up1(x5, x4)
 x = self.up2(x4, x3)
 x = self.up3(x3, x2)
 x = self.up4(x2, x1)
 logits = self.outc(x)
 return logits

## Loading the model

In [3]:
model = UNet(n_channels=3, n_classes=1)
model.load_state_dict(torch.load('../weights/water_bodies_model.pth'))

dummy_input = torch.randn(1, 3, 256, 256)

model.eval()
torch_out = model(dummy_input)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

## Converting to ONNX

In [None]:
onnx_path = '../weights/model.onnx'

torch.onnx.export(model,
 dummy_input,
 onnx_path,
 verbose=True,
 input_names = ['input'], # the model's input names
 output_names = ['output'], # the model's output names
 dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
 'output' : {0 : 'batch_size'}})

Exported graph: graph(%input : Float(*, 3, 256, 256, strides=[196608, 65536, 256, 1], requires_grad=0, device=cpu),
 %up4.up.weight : Float(128, 64, 2, 2, strides=[256, 4, 2, 1], requires_grad=1, device=cpu),
 %up4.up.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
 %outc.conv.weight : Float(1, 64, 1, 1, strides=[64, 1, 1, 1], requires_grad=1, device=cpu),
 %outc.conv.bias : Float(1, strides=[1], requires_grad=1, device=cpu),
 %onnx::Conv_226 : Float(64, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),
 %onnx::Conv_227 : Float(64, strides=[1], requires_grad=0, device=cpu),
 %onnx::Conv_229 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
 %onnx::Conv_230 : Float(64, strides=[1], requires_grad=0, device=cpu),
 %onnx::Conv_232 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
 %onnx::Conv_233 : Float(128, strides=[1], requires_grad=0, device=cpu),
 %onnx::Conv_235 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1

 _C._jit_pass_onnx_graph_shape_type_inference(
 _C._jit_pass_onnx_graph_shape_type_inference(


## Verifying the ONNX model

In [None]:
import onnx

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

## Comparing ONNX Runtime and PyTorch results

In [None]:
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession(onnx_path)

def to_numpy(tensor):
 return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")


Exported model has been tested with ONNXRuntime, and the result looks good!
