monai
medical
katielink's picture
update license files
c8da237
raw
history blame
4.22 kB
import argparse
import os
import numpy as np
import onnx
import onnxruntime
import torch
from monai.networks.nets import FlexibleUNet
# Run on the default CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model_and_export(
    modelname, outname, out_channels, height, width, multigpu=False, in_channels=3, backbone="efficientnet-b0"
):
    """
    Load a FlexibleUNet checkpoint, export it to ONNX and validate the export with ONNX Runtime.

    Args:
        modelname: a whole path name of the model (state dict) that needs to be loaded.
        outname: path where the output onnx model is written.
        out_channels: output channels, which usually equals to 1 + class_number.
        height: input images' height.
        width: input images' width.
        multigpu: whether the checkpoint was saved from a ``torch.nn.DataParallel``-wrapped model;
            if so the network is wrapped the same way before loading the state dict.
        in_channels: input images' channel number.
        backbone: a name of backbone used by the flexible unet.

    Raises:
        FileNotFoundError: if ``modelname`` does not exist.
        AssertionError: if the ONNX Runtime output does not match the PyTorch output within tolerance.
    """
    if not os.path.exists(modelname):
        # FileNotFoundError subclasses Exception, so callers catching Exception keep working.
        raise FileNotFoundError("The specified model to load does not exist!")

    model = FlexibleUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        backbone=backbone,
        is_pad=False,
        pretrained=False,
        dropout=None,
    )
    if multigpu:
        # Wrap before loading so the state-dict key names match a checkpoint saved from DataParallel.
        model = torch.nn.DataParallel(model)
    # Use the module-level `device` rather than a hard-coded `.cuda()` so export also works on CPU-only hosts.
    model = model.to(device)
    model.load_state_dict(torch.load(modelname, map_location=device))
    model = model.eval()

    # Fixed-seed dummy input used both for tracing and for the PyTorch-vs-ONNX comparison.
    # A local generator avoids mutating NumPy's global random state.
    # NOTE(review): spatial dims are ordered (width, height) exactly as in the original script;
    # confirm this matches the tensor layout the model was trained with.
    rng = np.random.default_rng(0)
    x = torch.tensor(rng.random((1, in_channels, width, height)), dtype=torch.float32, device=device)
    # Reference output only; no gradients are needed.
    with torch.no_grad():
        torch_out = model(x)

    input_names = ["INPUT__0"]
    output_names = ["OUTPUT__0"]
    # Export the bare network: the DataParallel wrapper should not end up in the ONNX graph.
    model_trans = model.module if multigpu else model
    torch.onnx.export(
        model_trans,  # model to save
        x,  # model input
        outname,  # model save path
        export_params=True,
        verbose=True,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        opset_version=15,
        # Only the batch dimension is dynamic; spatial dims are fixed to those of the dummy input.
        dynamic_axes={"INPUT__0": {0: "batch_size"}, "OUTPUT__0": {0: "batch_size"}},
    )

    onnx_model = onnx.load(outname)
    onnx.checker.check_model(onnx_model, full_check=True)

    # Recent ONNX Runtime builds with non-CPU providers require `providers` to be passed explicitly.
    ort_session = onnxruntime.InferenceSession(outname, providers=onnxruntime.get_available_providers())

    def to_numpy(tensor):
        # detach() is harmless on tensors that do not require grad.
        return tensor.detach().cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(output_names, ort_inputs)
    numpy_torch_out = to_numpy(torch_out)
    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(numpy_torch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# the original model for converting.
parser.add_argument(
"--model", type=str, default=r"/workspace/models/model.pt", help="Input an existing model weight"
)
# path to save the onnx model.
parser.add_argument(
"--outpath", type=str, default=r"/workspace/models/model.onnx", help="A path to save the onnx model."
)
parser.add_argument("--width", type=int, default=736, help="Width for exporting onnx model.")
parser.add_argument("--height", type=int, default=480, help="Height for exporting onnx model.")
parser.add_argument(
"--out_channels", type=int, default=2, help="Number of expected out_channels in model for exporting to onnx."
)
parser.add_argument("--multigpu", type=bool, default=False, help="If loading model trained with multi gpu.")
args = parser.parse_args()
modelname = args.model
outname = args.outpath
out_channels = args.out_channels
height = args.height
width = args.width
multigpu = args.multigpu
if os.path.exists(outname):
raise Exception(
"The specified outpath already exists! Change the outpath to avoid overwriting your saved model. "
)
model = load_model_and_export(modelname, outname, out_channels, height, width, multigpu)