Spaces:
Runtime error
Runtime error
from typing import Optional | |
import numpy as np | |
import paddle | |
import paddle.nn.functional as F | |
def reverse_transform(alpha, trans_info):
    """Undo the recorded preprocessing steps so the prediction regains its original size.

    Args:
        alpha: Prediction to restore. It is indexed as ``alpha[:, :, h, w]``
            for the padding case, so it is expected to have four dimensions.
        trans_info: Sequence of ``(name, shape)`` records appended during
            preprocessing, where ``shape[0]``/``shape[1]`` are the height and
            width *before* that step was applied.

    Returns:
        ``alpha`` after every recorded step has been reverted.

    Raises:
        ValueError: If a record has a name other than ``"resize"`` or
            ``"padding"``.
    """
    # Steps must be undone in the opposite order to how they were applied.
    for item in reversed(trans_info):
        step = item[0]
        h, w = item[1][0], item[1][1]
        if step == "resize":
            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
        elif step == "padding":
            # Crop back to the pre-padding extent, anchored at the top-left corner.
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise ValueError(f"Unexpected info '{step}' in trans_info")
    return alpha
def preprocess(img, transforms, trimap=None):
    """Build the model input dict from an image and an optional trimap.

    Args:
        img: Input image, passed to ``transforms`` under the ``"img"`` key.
        transforms: Callable that receives the data dict and returns the
            transformed dict. It is given an empty ``"trans_info"`` list in
            that dict — presumably to record the steps it applies; verify
            against the transform implementation.
        trimap: Optional trimap. When given it is stored under ``"trimap"``
            and listed in ``"gt_fields"`` before ``transforms`` is called.

    Returns:
        The transformed dict, with ``"img"`` converted to a tensor with a
        leading axis added, and (if a trimap was given) ``"trimap"``
        converted to a tensor with two leading axes added.
    """
    has_trimap = trimap is not None

    data = {"img": img}
    if has_trimap:
        data["trimap"] = trimap
        data["gt_fields"] = ["trimap"]
    data["trans_info"] = []

    data = transforms(data)

    # Add a leading axis to the image (assumed to be the batch axis).
    data["img"] = paddle.to_tensor(data["img"]).unsqueeze(0)
    if has_trimap:
        # Add two leading axes to the trimap (assumed batch and channel).
        data["trimap"] = paddle.to_tensor(data["trimap"]).unsqueeze((0, 1))
    return data
def predict(
    model,
    transforms,
    image: np.ndarray,
    trimap: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Run the matting model on one image and return the alpha matte.

    Args:
        model: Callable that takes the preprocessed data dict and returns the
            alpha prediction as a paddle tensor.
        transforms: Preprocessing callable forwarded to ``preprocess``.
        image: Input image.
        trimap: Optional trimap used as a hard constraint on the result:
            positions equal to 0 are forced to alpha 0 and positions equal to
            255 are forced to alpha 1. Size-1 axes are squeezed so that e.g.
            an ``(H, W, 1)`` trimap lines up with the squeezed alpha.

    Returns:
        The predicted alpha as a NumPy array with size-1 axes removed.
    """
    with paddle.no_grad():
        # NOTE(review): the trimap is intentionally not forwarded here and is
        # only applied as a post-hoc constraint below. Confirm the model does
        # not need data["trimap"] as an input.
        data = preprocess(img=image, transforms=transforms, trimap=None)
        alpha = model(data)
        alpha = reverse_transform(alpha, data["trans_info"])
        alpha = alpha.numpy().squeeze()

    if trimap is not None:
        # alpha was squeezed above; squeeze the mask the same way so boolean
        # indexing does not fail on a trailing singleton channel axis.
        mask = np.squeeze(trimap)
        alpha[mask == 0] = 0
        alpha[mask == 255] = 1.0
    return alpha