# inference_hw_onnx / convert_to_onnx.py
# Uploaded by vladmir077 — "first commit" (rev 6e231e7)
import random
import numpy as np
import torch
import onnx
import onnxruntime
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images)
# NOTE: passing the truncation value as a plain Python float through the
# exported input signature caused problems during ONNX graph conversion, so
# the override bakes a default truncation into forward() instead of
# requiring it as a traced input.
class BigGanONNX(BigGAN):
    """BigGAN subclass with a fixed-default truncation, suitable for ONNX export.

    The ONNX graph is traced with only (z, class_label) as inputs; the
    truncation is supplied by the keyword default below.
    """

    def forward(self, z, class_label, truncation=0.4, **kwargs):
        """Generate images from noise and class labels.

        Args:
            z: noise tensor (batch first — shape depends on the BigGAN
                variant; presumably (batch, dim_z) — TODO confirm).
            class_label: one-hot class tensor with matching batch size.
            truncation: truncation value forwarded to BigGAN; defaults to
                0.4 to match the noise sampling in ``create_inputs``.
            **kwargs: accepted for signature compatibility; ignored.
        """
        return super().forward(z, class_label, truncation)
def load_model():
    """Return the pretrained 'biggan-deep-256' generator, set to eval mode."""
    generator = BigGanONNX.from_pretrained('biggan-deep-256')
    generator.eval()
    return generator
def create_inputs():
    """Sample a random (noise_vector, class_vector) input pair for BigGAN.

    Returns:
        Tuple of torch tensors ``(noise_vector, class_vector)``, each with
        batch size 1, converted from the numpy arrays produced by the
        pytorch_pretrained_biggan sampling helpers.
    """
    truncation = 0.4
    # BigGAN is trained on the 1000 ImageNet classes, so valid indices are
    # 0..999. random.randint is inclusive on BOTH ends — the original upper
    # bound of 1000 could emit an out-of-range class index.
    class_vector = one_hot_from_int(random.randint(0, 999), batch_size=1)
    noise_vector = truncated_noise_sample(
        truncation=truncation, batch_size=1, seed=random.randint(0, 1000)
    )
    # Convert numpy arrays to torch tensors for export / inference.
    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)
    return noise_vector, class_vector
def export_model_to_onnx():
    """Trace the pretrained BigGAN with sample inputs and write 'biggan.onnx'.

    The first axis of both inputs and the output is marked dynamic so the
    exported graph accepts arbitrary batch sizes.
    """
    biggan = load_model()
    sample_args = create_inputs()
    torch.onnx.export(
        biggan,
        sample_args,
        'biggan.onnx',
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['noise_vector', 'class_vector'],
        output_names=['output'],
        dynamic_axes={
            'noise_vector': {0: 'batch_size'},
            'class_vector': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
def check_model():
    """Load 'biggan.onnx' and validate its graph with the official checker."""
    onnx.checker.check_model(onnx.load('biggan.onnx'))
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor to a NumPy array on the CPU.

    ``detach()`` is a cheap no-op for tensors outside the autograd graph,
    so it is safe to apply unconditionally — this removes the redundant
    branch on ``tensor.requires_grad`` from the original implementation
    while preserving behavior for both grad and non-grad tensors.
    """
    return tensor.detach().cpu().numpy()
def run():
    """Smoke-test the exported model: run one inference with ONNX Runtime."""
    session = onnxruntime.InferenceSession('biggan.onnx')
    # Feed a freshly sampled (noise, class) pair under the session's own
    # input names.
    noise, labels = create_inputs()
    session_inputs = session.get_inputs()
    feed = {
        session_inputs[0].name: to_numpy(noise),
        session_inputs[1].name: to_numpy(labels),
    }
    outputs = session.run(None, feed)
    print(type(outputs[0]))
if __name__ == '__main__':
    # Export, validate, then smoke-test the ONNX model end to end.
    for step in (export_model_to_onnx, check_model, run):
        step()