# Hugging Face Spaces status: Runtime error
import random | |
import numpy as np | |
import torch | |
import onnx | |
import onnxruntime | |
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images) | |
# Overridden because passing a constant Python float to model.forward caused
# problems when the model was converted to an ONNX graph.
class BigGanONNX(BigGAN):
    """BigGAN whose ``forward`` takes only tensor arguments.

    ``BigGAN.forward`` is called with a third, plain-float truncation value.
    Rather than passing that float in at export time, this subclass supplies it
    from the ``truncation`` attribute, so it is captured as a constant in the
    traced graph.
    """

    # Truncation value handed to BigGAN.forward. It is frozen into the graph
    # when the model is traced, so set it (class-wide or per instance) before
    # exporting. It should match the truncation used to sample the noise input.
    truncation = 0.4

    def forward(self, z, class_label, **kwargs):
        """Run the generator on ``z`` and ``class_label``.

        Extra keyword arguments are accepted for call compatibility but are
        ignored; the truncation always comes from ``self.truncation``.
        """
        return super().forward(z, class_label, self.truncation)
def load_model(model_name='biggan-deep-256'):
    """Load pre-trained BigGAN weights into a ``BigGanONNX`` model.

    Args:
        model_name: identifier passed to ``BigGanONNX.from_pretrained``;
            defaults to ``'biggan-deep-256'``.

    Returns:
        The loaded model, already switched to evaluation mode.
    """
    # Pre-trained generator weights (an image model -- there is no tokenizer).
    model = BigGanONNX.from_pretrained(model_name)
    # eval() puts train/eval-dependent layers into inference mode before export.
    model.eval()
    return model
def create_inputs(truncation=0.4, batch_size=1):
    """Sample a random ``(noise_vector, class_vector)`` input pair for BigGAN.

    Args:
        truncation: truncation passed to ``truncated_noise_sample``; it should
            match the truncation the model's ``forward`` applies.
        batch_size: number of samples to draw; each row gets its own randomly
            chosen class.

    Returns:
        Tuple ``(noise_vector, class_vector)`` of tensors made with
        ``torch.from_numpy``, in the same order as the ONNX input names used
        by the exporter.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    num_classes = 1000  # valid class indices are 0..999
    # randint's upper bound is inclusive, so randint(0, 1000) could return 1000
    # and index past the end of the one-hot vector; randrange excludes it.
    class_ids = [random.randrange(num_classes) for _ in range(batch_size)]
    class_vector = one_hot_from_int(class_ids, batch_size=batch_size)
    noise_vector = truncated_noise_sample(
        truncation=truncation,
        batch_size=batch_size,
        seed=random.randint(0, 1000),
    )
    # All in tensors
    return torch.from_numpy(noise_vector), torch.from_numpy(class_vector)
def export_model_to_onnx(path='biggan.onnx', opset_version=11):
    """Trace the pre-trained model on a random input pair and save it as ONNX.

    The batch dimension (axis 0) of both inputs and of the output is declared
    dynamic in the exported graph.

    Args:
        path: destination file; defaults to ``'biggan.onnx'``.
        opset_version: ONNX opset to target; defaults to 11.

    Returns:
        ``path``, so the caller can hand the file to the next step.
    """
    model = load_model()
    model_inputs = create_inputs()
    # Positional: must follow create_inputs()'s (noise, class) return order.
    input_names = ['noise_vector', 'class_vector']
    output_names = ['output']
    torch.onnx.export(
        model,
        model_inputs,
        path,
        export_params=True,  # store the trained weights inside the file
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={
            name: {0: 'batch_size'} for name in input_names + output_names
        },
    )
    return path
def check_model(path='biggan.onnx'):
    """Load an ONNX file and run ``onnx.checker.check_model`` on it.

    Args:
        path: ONNX file to validate; defaults to ``'biggan.onnx'``.

    Returns:
        The loaded ONNX model, for callers that want to inspect it further.

    Raises:
        Whatever ``onnx.load`` or ``onnx.checker.check_model`` raise, e.g.
        for a missing file or an invalid model.
    """
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)
    return onnx_model
def to_numpy(tensor: torch.Tensor):
    """Return the contents of ``tensor`` as a NumPy array on the CPU.

    Tensors that require grad are detached first, since ``.numpy()`` cannot
    be called on a tensor that is tracking gradients.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
def run(path='biggan.onnx', providers=None):
    """Run the exported graph once in ONNX Runtime on random inputs.

    Args:
        path: ONNX file to load; defaults to ``'biggan.onnx'``.
        providers: execution providers for ``InferenceSession``. When None,
            ``['CPUExecutionProvider']`` is used. The list is always passed
            explicitly because onnxruntime >= 1.9 raises when it is omitted
            on builds that ship several providers (e.g. GPU builds).

    Returns:
        The list of output arrays returned by ``InferenceSession.run``.

    Raises:
        ValueError: if the session does not declare exactly as many inputs as
            ``create_inputs()`` produces.
    """
    if providers is None:
        providers = ['CPUExecutionProvider']
    ort_session = onnxruntime.InferenceSession(path, providers=providers)
    # compute ONNX Runtime output prediction
    inputs = create_inputs()
    session_inputs = ort_session.get_inputs()
    if len(session_inputs) != len(inputs):
        raise ValueError(
            f'model declares {len(session_inputs)} inputs, '
            f'but {len(inputs)} were created'
        )
    # Pair inputs positionally: the export declared them in create_inputs()'s order.
    ort_inputs = {
        meta.name: to_numpy(tensor) for meta, tensor in zip(session_inputs, inputs)
    }
    # None as output_names -> fetch every graph output.
    ort_outs = ort_session.run(None, ort_inputs)
    print(type(ort_outs[0]))
    return ort_outs
if __name__ == '__main__':
    # Script entry point: export the model, validate the written file, then
    # run it once through ONNX Runtime -- strictly in this order.
    for step in (export_model_to_onnx, check_model, run):
        step()