# retinaface/widerface_onnx_inference.py
import os
import argparse

import cv2
import numpy as np
import torch
import onnxruntime as ort
from utils import *  # provides preprocess() and postprocess()
# RetinaFace (MobileNet0.25 backbone) prior/anchor configuration
CFG = {
    "name": "mobilenet0.25",
    "min_sizes": [[16, 32], [64, 128], [256, 512]],  # anchor sizes per feature level
    "steps": [8, 16, 32],  # feature-map strides
    "variance": [0.1, 0.2],  # variances used when decoding box regressions
    "clip": False,
}
INPUT_SIZE = [608, 640]  # resize scale
DEVICE = torch.device("cpu")

def vis(img_raw, dets, vis_thres):
    """Visualize detections on the original image.
    Args:
        img_raw: original image
        dets: detections, one row per face: [x1, y1, x2, y2, score, 10 landmark coords]
        vis_thres: visualization score threshold
    Returns:
        None; the annotated image is saved to ./results/result.jpg
    """
    for b in dets:
        if b[4] < vis_thres:
            continue
        text = "{:.4f}".format(b[4])
        b = list(map(int, b))
        # bounding box and confidence score
        cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
        cx = b[0]
        cy = b[1] + 12
        cv2.putText(img_raw, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
        # five facial landmarks
        cv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)
        cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)
        cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)
        cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)
        cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)
    # save the annotated image
    if not os.path.exists("./results/"):
        os.makedirs("./results/")
    name = "./results/result.jpg"
    cv2.imwrite(name, img_raw)

def Retinaface_inference(run_ort, args):
    """Run inference on a single image with an ONNX Runtime session.
    Args:
        run_ort: ONNX Runtime session
        args: parsed command-line arguments, including the image path and thresholds
    Returns: boxes_list, confidence_list, landm_list
        boxes_list = [[left, top, right, bottom], ...]
        confidence_list = [[confidence], ...]
        landm_list = [[landms (dim=10)], ...]
    """
    img_raw = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    # preprocess: resize to INPUT_SIZE and normalize
    img, scale, resize = preprocess(img_raw, INPUT_SIZE, DEVICE)
    # the model expects NHWC input
    img = np.transpose(img, (0, 2, 3, 1))
    # forward pass
    outputs = run_ort.run(None, {run_ort.get_inputs()[0].name: img})
    # postprocess: decode priors, filter by confidence and apply NMS
    dets = postprocess(CFG, img, outputs, scale, resize, args.confidence_threshold, args.nms_threshold, DEVICE)
    # split detections into boxes, scores and landmarks
    boxes = dets[:, :4]
    confidences = dets[:, 4:5]
    landms = dets[:, 5:]
    boxes_list = [box.tolist() for box in boxes]
    confidence_list = [confidence.tolist() for confidence in confidences]
    landm_list = [landm.tolist() for landm in landms]
    # optionally draw and save the detections
    if args.save_image:
        vis(img_raw, dets, args.vis_thres)
    return boxes_list, confidence_list, landm_list

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Retinaface")
    parser.add_argument(
        "-m",
        "--trained_model",
        default="./weights/RetinaFace_int.onnx",
        type=str,
        help="path of the trained ONNX model to load",
    )
    parser.add_argument(
        "--image_path",
        default="./data/widerface/val/images/18--Concerts/18_Concerts_Concerts_18_38.jpg",
        type=str,
        help="path of the input image",
    )
    parser.add_argument(
        "--confidence_threshold",
        default=0.4,
        type=float,
        help="confidence threshold for keeping detections",
    )
    parser.add_argument(
        "--nms_threshold",
        default=0.4,
        type=float,
        help="IoU threshold for non-maximum suppression",
    )
    parser.add_argument(
        "-s",
        "--save_image",
        action="store_true",
        default=False,
        help="save visualized detection results to ./results/",
    )
    parser.add_argument(
        "--vis_thres",
        default=0.5,
        type=float,
        help="visualization score threshold",
    )
    parser.add_argument(
        "--ipu",
        action="store_true",
        help="use IPU (Vitis AI execution provider) for inference",
    )
    parser.add_argument(
        "--provider_config",
        type=str,
        default="vaip_config.json",
        help="path of the config file for setting provider_options",
    )
    args = parser.parse_args()

    # select ONNX Runtime execution providers
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
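    # Note (assumption): the VitisAIExecutionProvider branch above requires an ONNX
    # Runtime build with AMD Vitis AI / Ryzen AI support; a stock onnxruntime
    # package typically exposes only the CUDA/CPU providers used in the else branch.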
print("Loading pretrained model from {}".format(args.trained_model))
run_ort = ort.InferenceSession(args.trained_model, providers=providers, provider_options=provider_options)
boxes_list, confidence_list, landm_list = Retinaface_inference(run_ort, args)
print('inference done!')
print(boxes_list, confidence_list, landm_list)
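
# Example invocations (a sketch; paths follow the argparse defaults above and assume
# the repository's weights/ and data/ layout):
#   python widerface_onnx_inference.py -s
#   python widerface_onnx_inference.py --ipu --provider_config vaip_config.json -s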