Spaces:
Running
Running
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
r""" | |
@DATE: 2024/9/5 21:21
@File: human_matting.py
@IDE: pycharm
@Description:
    Portrait (human) matting
""" | |
import numpy as np | |
from PIL import Image | |
import onnxruntime | |
from .tensor2numpy import NNormalize, NTo_Tensor, NUnsqueeze | |
from .context import Context | |
import cv2 | |
import os | |
# Path to the MODNet ONNX weights file ("weights/hivision_modnet.onnx"),
# resolved relative to the directory containing this module.
weight_path = os.path.join(os.path.dirname(__file__), "weights", "hivision_modnet.onnx")
def extract_human(ctx: Context):
    """
    Run portrait matting on the image held by the context.

    ``ctx.processing_image`` is matted with the model at ``weight_path``, the
    resulting alpha channel is repaired by ``hollow_out_fix``, and the result
    is written back to ``ctx.processing_image``. A separate copy of that
    result is stored in ``ctx.matting_image``.

    :param ctx: processing context; it is read and updated in place
    """
    # Step 1: model-based matting of the current working image.
    raw_matting = get_modnet_matting(ctx.processing_image, weight_path)

    # Step 2: patch holes in the alpha channel left by the model.
    repaired_matting = hollow_out_fix(raw_matting)
    ctx.processing_image = repaired_matting

    # Keep an independent copy so later changes to the working image do not
    # affect the stored matting result.
    ctx.matting_image = ctx.processing_image.copy()
def hollow_out_fix(src: np.ndarray) -> np.ndarray:
    """
    Patch holes in the alpha channel of a matted image.

    This supplements a matting model whose output is not precise enough.
    The alpha channel is binarized and eroded, the largest external contour
    of the result is outlined, and every pixel enclosed by that outline is
    made fully opaque in the returned alpha channel. The three color channels
    are returned unchanged.

    :param src: 4-channel image whose last channel is the alpha matte
    :return: a new 4-channel image with the repaired alpha channel; when no
        contour survives thresholding and erosion (e.g. the alpha channel is
        empty) an unmodified copy of ``src`` is returned
    :raises ValueError: if ``src`` is not an H x W x 4 array
    """
    if src.ndim != 3 or src.shape[2] != 4:
        raise ValueError(
            f"hollow_out_fix expects an H x W x 4 image, got shape {src.shape}"
        )

    # Width of the temporary zero border added around the alpha channel so
    # that the flood-fill seed at (0, 0) always lies outside the subject.
    pad = 10

    b, g, r, a = cv2.split(src)
    src_bgr = cv2.merge((b, g, r))

    # -----------padding---------- #
    a = np.pad(a, pad, mode="constant", constant_values=0)
    # -------------end------------ #

    # Binarize the alpha channel, then erode it to drop thin fringes.
    _, a_threshold = cv2.threshold(a, 127, 255, cv2.THRESH_BINARY)
    a_erode = cv2.erode(
        a_threshold,
        kernel=cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        iterations=3,
    )

    contours, _ = cv2.findContours(
        a_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if not contours:
        # Nothing survived thresholding/erosion, so there is no region to
        # repair; previously this case crashed with an IndexError.
        return src.copy()

    # Only the contour with the largest area is used.
    largest = max(contours, key=cv2.contourArea)

    # NOTE(review): the contour is passed unwrapped, exactly as before.
    # Presumably OpenCV then draws each boundary point on its own (thickness
    # 2), which still gives a closed outline because CHAIN_APPROX_NONE keeps
    # every boundary pixel - confirm before wrapping it in a list.
    a_contour = cv2.drawContours(np.zeros(a.shape, np.uint8), largest, -1, 255, 2)

    h, w = a.shape[:2]
    # The floodFill mask must be 2 pixels larger than the image in both
    # dimensions and must be a single-channel uint8 array.
    mask = np.zeros((h + 2, w + 2), np.uint8)

    # Fill everything reachable from the corner, i.e. the area outside the
    # outline; afterwards only the interior of the outline is still 0.
    cv2.floodFill(a_contour, mask=mask, seedPoint=(0, 0), newVal=255)

    # 255 - a_contour is 255 exactly on the interior; the saturating add
    # therefore forces the interior to fully opaque and leaves the rest as is.
    a = cv2.add(a, 255 - a_contour)

    # Strip the temporary border before re-attaching the alpha channel.
    return cv2.merge((src_bgr, a[pad:-pad, pad:-pad]))
def image2bgr(input_image):
    """
    Return a three-channel version of an image array.

    - 2-D input is treated as a single-channel image.
    - Single-channel input is repeated three times along the channel axis.
    - Four-channel input has its last channel dropped.
    - Any other input is returned unchanged.

    No channel reordering is performed.

    :param input_image: 2-D or 3-D image array
    :return: the three-channel array (or the input itself if no rule applies)
    """
    image = input_image
    if len(image.shape) == 2:
        # Add an explicit channel axis so the checks below apply uniformly.
        image = image[:, :, None]

    channel_count = image.shape[2]
    if channel_count == 1:
        return np.repeat(image, 3, axis=2)
    if channel_count == 4:
        return image[:, :, 0:3]
    return image
def read_modnet_image(input_image, ref_size=512):
    """
    Prepare an image as model input and report its original size.

    The image is cast to uint8, converted to three channels with
    ``image2bgr``, resized to ``ref_size`` x ``ref_size``, normalized with a
    per-channel mean and std of 0.5, and then passed through ``NTo_Tensor``
    and ``NUnsqueeze``.

    :param input_image: image array
    :param ref_size: side length of the square image fed to the model
    :return: tuple ``(model_input, width, length)`` where ``width`` and
        ``length`` are the first and second entries of the PIL image size
    """
    pil_image = Image.fromarray(np.uint8(input_image))
    # Remember the original size so the caller can resize the matte back.
    width, length = pil_image.size[0], pil_image.size[1]

    pixels = image2bgr(np.asarray(pil_image))
    pixels = cv2.resize(
        pixels, (ref_size, ref_size), interpolation=cv2.INTER_AREA
    )
    pixels = NNormalize(
        pixels,
        mean=np.array([0.5, 0.5, 0.5]),
        std=np.array([0.5, 0.5, 0.5]),
    )
    model_input = NUnsqueeze(NTo_Tensor(pixels))
    return model_input, width, length
# Cached ONNX Runtime session, created lazily by get_modnet_matting.
sess = None
# Checkpoint path the cached session was created from by get_modnet_matting
# (None until that function has loaded a session itself).
_sess_checkpoint_path = None


def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
    """
    Run the matting model on an image and attach the result as alpha.

    The ONNX session is created on first use and cached in the module-level
    ``sess``. If a later call passes a different ``checkpoint_path`` than the
    one the cached session was loaded from, the session is reloaded instead
    of silently reusing the old model. A session assigned to ``sess`` from
    outside this function (whose path is unknown) is left in place.

    :param input_image: image array; grayscale, 3-channel and 4-channel
        inputs are accepted (converted to three channels with ``image2bgr``)
    :param checkpoint_path: path of the ONNX model file
    :param ref_size: side length of the square image fed to the model
    :return: 4-channel uint8 image: the three color channels of the input
        (order unchanged) followed by the predicted matte as alpha
    """
    global sess, _sess_checkpoint_path

    loaded_from_other_path = (
        _sess_checkpoint_path is not None
        and _sess_checkpoint_path != checkpoint_path
    )
    if sess is None or loaded_from_other_path:
        sess = onnxruntime.InferenceSession(checkpoint_path)
        _sess_checkpoint_path = checkpoint_path

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    im, width, length = read_modnet_image(input_image=input_image, ref_size=ref_size)

    # Scale the first model output by 255 and store it as uint8.
    matte = sess.run([output_name], {input_name: im})
    matte = np.squeeze((matte[0] * 255).astype("uint8"))

    # Resize the matte back to the original image size.
    mask = cv2.resize(matte, (width, length), interpolation=cv2.INTER_AREA)

    # Use the same channel handling as the preprocessing step so grayscale
    # and 4-channel inputs no longer fail when unpacking into b, g, r.
    color = image2bgr(np.uint8(input_image))

    # Stack the color channels and the matte into one H x W x 4 array.
    output_image = np.dstack((color, mask))
    return output_image