# SemanticFPN / infer_onnx.py
# Last commit: e66aed5 by zhengrongzhang -- "change onnx to NHWC (#1)"
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils import data
import torchvision.transforms as transform
import torch.nn.functional as F
import onnxruntime
from PIL import Image
import argparse
from datasets.utils import colorize_mask, build_img
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='SemanticFPN model')
parser.add_argument('--onnx_path', type=str, default='FPN_int_NHWC.onnx')
parser.add_argument('--save_path', type=str, default='./data/demo_results/senmatic_results.png')
parser.add_argument('--input_path', type=str, default='data/cityscapes/leftImg8bit/test/bonn/bonn_000000_000019_leftImg8bit.png')
parser.add_argument('--ipu', action='store_true', help='use ipu')
parser.add_argument('--provider_config', type=str, default=None,
help='provider config path')
args = parser.parse_args()
if args.ipu:
providers = ["VitisAIExecutionProvider"]
provider_options = [{"config_file": args.provider_config}]
else:
providers = ['CPUExecutionProvider']
provider_options = None
onnx_path = args.onnx_path
input_img = build_img(args)
session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
ort_input = {session.get_inputs()[0].name: input_img.cpu().numpy().transpose(0,2,3,1)}
ort_output = session.run(None, ort_input)[0].transpose(0,3,1,2)
if isinstance(ort_output, (tuple, list)):
ort_output = ort_output[0]
output = ort_output[0].transpose(1, 2, 0)
seg_pred = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
color_mask = colorize_mask(seg_pred)
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
color_mask.save(args.save_path)