u2net_portrait / modnet.py
hylee's picture
init
ec88f2e
raw
history blame contribute delete
No virus
2.82 kB
import os
import cv2
import argparse
import numpy as np
from PIL import Image
import onnx
import onnxruntime
class ModNet:
    """Run an ONNX model that predicts a matte and composite the image on white.

    The model is loaded once into an ONNX Runtime session; ``segment`` then
    feeds it a preprocessed image and blends the original image with a white
    background using the predicted matte.
    """

    def __init__(self, model_path):
        """Create the ONNX Runtime inference session for ``model_path``."""
        self.session = onnxruntime.InferenceSession(model_path, None)

    def get_scale_factor(self, im_h, im_w, ref_size):
        """Return ``(x_scale_factor, y_scale_factor)`` used to resize the input.

        If both sides are smaller than ``ref_size`` or both are larger, the
        shorter side is scaled to ``ref_size`` and the other side keeps the
        aspect ratio; otherwise the original size is kept. Each target side is
        then rounded down to a multiple of 32, but never below 32.

        Args:
            im_h: Image height in pixels (must be positive).
            im_w: Image width in pixels (must be positive).
            ref_size: Reference size for the shorter side.

        Returns:
            Tuple ``(x_scale_factor, y_scale_factor)`` of floats.

        Raises:
            ValueError: If ``im_h`` or ``im_w`` is not positive.
        """
        if im_h <= 0 or im_w <= 0:
            raise ValueError(
                "image dimensions must be positive, got {}x{}".format(im_w, im_h)
            )

        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        # Round down to a multiple of 32. Clamp to one full multiple so a
        # side shorter than 32 px does not collapse to 0 (a zero scale factor
        # makes the subsequent resize fail).
        multiple = 32
        im_rw = max(multiple, im_rw - im_rw % multiple)
        im_rh = max(multiple, im_rh - im_rh % multiple)

        x_scale_factor = im_rw / im_w
        y_scale_factor = im_rh / im_h
        return x_scale_factor, y_scale_factor

    def segment(self, image_path, ref_size=512):
        """Predict a matte for the image and composite it on a white background.

        Args:
            image_path: Path of the image file to read.
            ref_size: Reference size passed to ``get_scale_factor``.

        Returns:
            Float array of shape ``(H, W, 3)`` in RGB order holding
            ``image * matte + 255 * (1 - matte)`` at the original resolution.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            ValueError: If the file exists but cannot be decoded as an image.
        """
        # cv2.imread signals failure by returning None instead of raising.
        bgr = cv2.imread(image_path)
        if bgr is None:
            if not os.path.isfile(image_path):
                raise FileNotFoundError("image not found: {}".format(image_path))
            raise ValueError("could not decode image: {}".format(image_path))

        # Default imread always yields a 3-channel BGR array, so no further
        # channel unification is needed after converting to RGB.
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        im_h, im_w = image.shape[:2]

        # Normalize values to the range [-1, 1].
        im = (image - 127.5) / 127.5

        # Resize to the model-friendly resolution.
        x, y = self.get_scale_factor(im_h, im_w, ref_size)
        im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)

        # HWC -> CHW, then add a leading batch axis.
        im = np.transpose(im, (2, 0, 1))
        im = np.expand_dims(im, axis=0).astype('float32')

        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        result = self.session.run([output_name], {input_name: im})

        # Refine matte: scale to 8-bit and bring it back to the input size.
        # NOTE(review): assumes the model output lies in [0, 1] - confirm.
        matte = (np.squeeze(result[0]) * 255).astype('uint8')
        matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)

        # Obtain predicted foreground: blend the image with white using the
        # matte as per-pixel weight (broadcast over the 3 colour channels).
        matte = matte[:, :, None] / 255
        foreground = image * matte + 255 * (1 - matte)
        return foreground