Spaces:
Runtime error
Runtime error
import numpy as np | |
from PIL import Image | |
import cv2 | |
import onnxruntime | |
from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze | |
from .vision import image2bgr | |
def read_modnet_image(input_image, ref_size=512):
    """Preprocess an image into the input tensor fed to the matting model.

    The image is cast to uint8, passed through ``image2bgr``, resized to a
    ``ref_size`` x ``ref_size`` square, normalized with mean 0.5 / std 0.5 per
    channel, and finally converted with ``NTo_Tensor`` and ``NUnsqueeze``.

    Args:
        input_image: Array-like image data; it is cast with ``np.uint8``.
        ref_size: Side length of the square the image is resized to.

    Returns:
        A tuple ``(tensor, width, length)`` where ``tensor`` is the
        preprocessed model input and ``width`` / ``length`` are the original
        image dimensions as reported by ``PIL.Image.size``
        (i.e. width and height).
    """
    pil_image = Image.fromarray(np.uint8(input_image))
    # PIL reports size as (width, height); "length" is the height here.
    width, length = pil_image.size

    pixels = image2bgr(np.asarray(pil_image))
    # The aspect ratio is not preserved: the model input is always square.
    pixels = cv2.resize(
        pixels,
        (ref_size, ref_size),
        interpolation=cv2.INTER_AREA,
    )
    pixels = NNormalize(
        pixels,
        mean=np.array([0.5, 0.5, 0.5]),
        std=np.array([0.5, 0.5, 0.5]),
    )

    tensor = NTo_Tensor(pixels)
    batched = NUnsqueeze(tensor)
    return batched, width, length
def get_modnet_matting(input_image, checkpoint_path="./test.onnx", ref_size=512):
    """Run the ONNX matting model and attach its matte as a fourth channel.

    Args:
        input_image: Array-like image data; it is cast with ``np.uint8``.
            Three-channel input is used as-is. Two-dimensional (grayscale)
            input is replicated to three channels and four-channel input has
            its last channel dropped before compositing.
        checkpoint_path: Path to the ONNX model file.
        ref_size: Side length of the square input fed to the model.

    Returns:
        A 4-channel uint8 array with the same height and width as
        ``input_image``: the three colour channels of the input followed by
        the predicted matte, resized to the original resolution.
    """
    print("checkpoint_path:", checkpoint_path)
    # NOTE(review): a new session is created on every call; callers invoking
    # this repeatedly pay the model load cost each time.
    sess = onnxruntime.InferenceSession(checkpoint_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)
    matte = sess.run([output_name], {input_name: im})

    # Clip before the uint8 cast: a float value even slightly outside the
    # 0..255 range would otherwise wrap around instead of saturating.
    # Presumably the model emits values in [0, 1] - verify against the model.
    matte = np.clip(matte[0] * 255, 0, 255).astype('uint8')
    matte = np.squeeze(matte)
    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)

    # Unpacking into exactly three channels used to fail with an opaque
    # ValueError for grayscale or 4-channel input, after inference had
    # already run. Normalise to three channels first.
    colour = np.uint8(input_image)
    if colour.ndim == 2:
        colour = np.repeat(colour[:, :, None], 3, axis=2)
    elif colour.ndim == 3 and colour.shape[2] == 4:
        colour = np.ascontiguousarray(colour[:, :, :3])

    b, g, r = cv2.split(colour)
    output_image = cv2.merge((b, g, r, mask))
    return output_image