Spaces:
Runtime error
Runtime error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import cv2 | |
import numpy as np | |
import time | |
import paddle | |
import paddle.nn.functional as F | |
from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar | |
from ppmatting.metrics import metric | |
from pymatting.util.util import load_image, save_image, stack_images | |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml | |
np.set_printoptions(suppress=True) | |
def save_alpha_pred(alpha, path):
    """
    Write an alpha matte to ``path`` as an 8-bit image.

    The array is cast to ``uint8`` without any rescaling, so the values are
    expected to already be in the range [0, 255] (``evaluate_ml`` passes
    ``np.round(alpha * 255)``). Shape should be [h, w].

    Args:
        alpha: Array exposing ``astype``; converted to ``uint8`` before saving.
        path: Destination file path. Missing parent directories are created.
    """
    dirname = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs('') raises, so only
    # create directories when there is one. exist_ok avoids failing if the
    # directory appears between a check and the creation.
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    alpha = alpha.astype('uint8')
    cv2.imwrite(path, alpha)
def reverse_transform(alpha, trans_info):
    """
    Recover a prediction to its original shape.

    The recorded transforms are undone in reverse order: a 'resize' entry
    resizes ``alpha`` back to the stored (h, w) with ``cv2.resize`` and a
    'padding' entry crops ``alpha`` to the stored (h, w) from the top-left
    corner.

    Args:
        alpha: 2-D prediction array.
        trans_info: Sequence of entries where ``item[0][0]`` is the transform
            name and ``item[1][0]`` / ``item[1][1]`` are objects whose
            ``.numpy()[0]`` yields the height / width to restore.

    Returns:
        The prediction restored to the original shape.

    Raises:
        ValueError: If an entry names a transform other than 'resize' or
            'padding'.
    """
    for item in trans_info[::-1]:
        op = item[0][0]
        if op not in ('resize', 'padding'):
            raise ValueError("Unexpected info '{}' in im_info".format(item[0]))

        # Cast to built-in int: the values come out of .numpy() as NumPy
        # scalars, and plain ints are the safest input for cv2.resize's dsize
        # and for slicing.
        h = int(item[1][0].numpy()[0])
        w = int(item[1][1].numpy()[0])

        if op == 'resize':
            # cv2 expects dsize as (width, height).
            alpha = cv2.resize(alpha, dsize=(w, h))
        else:
            alpha = alpha[0:h, 0:w]
    return alpha
def evaluate_ml(model,
                eval_dataset,
                num_workers=0,
                print_detail=True,
                save_dir='output/results',
                save_results=True):
    """
    Evaluate a trimap-based matting callable on a dataset.

    Each sample is loaded with batch size 1, the image is de-normalized, and
    ``model(image, trimap)`` is called to obtain the alpha prediction. The
    prediction is restored to the original shape, scaled to [0, 255] and fed
    to the SAD / MSE / Grad / Conn metrics together with the ground truth and
    the original trimap. Samples whose trimap has no pixel >= 0.9 or no pixel
    <= 0.1 are skipped and counted.

    Args:
        model: Callable invoked as ``model(image, trimap)`` with an HWC image
            array and a 2-D trimap scaled by 1/255; its result is passed to
            ``reverse_transform``.
        eval_dataset: Dataset yielding dicts with the keys 'img', 'trimap',
            'alpha', 'ori_trimap', 'trans_info' and 'img_name'.
        num_workers: Number of worker processes for the data loader.
        print_detail: Whether to log the start message and update the
            progress bar.
        save_dir: Directory where predicted alpha mattes are written.
        save_results: Whether to save the predicted alpha mattes.

    Returns:
        Tuple ``(sad, mse, grad, conn)`` as returned by the metrics'
        ``evaluate()`` methods.
    """
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=1,
        drop_last=False,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)
    mse_metric = metric.MSE()
    sad_metric = metric.SAD()
    grad_metric = metric.Grad()
    conn_metric = metric.Conn()

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()

    # State used to give consecutive samples that share a file name distinct
    # output names (name_0, name_1, ...).
    img_name = ''
    dup_idx = 0
    ignore_cnt = 0
    for idx, data in enumerate(loader):
        reader_cost_averager.record(time.time() - batch_start)

        image_rgb_chw = data['img'].numpy()[0]
        image_rgb_hwc = np.transpose(image_rgb_chw, (1, 2, 0))
        trimap = data['trimap'].numpy().squeeze() / 255.0
        # Reverse normalize (x/255 - mean) / std; the constants presume
        # mean = std = 0.5 -- confirm against the dataset transforms.
        image = image_rgb_hwc * 0.5 + 0.5

        is_fg = trimap >= 0.9
        is_bg = trimap <= 0.1
        if is_fg.sum() == 0 or is_bg.sum() == 0:
            # Skip samples without both a known-foreground and a
            # known-background pixel.
            ignore_cnt += 1
            logger.info(str(idx))
            # Restart the timers so the skipped sample's time is not charged
            # to the next sample's reader cost.
            reader_cost_averager.reset()
            batch_start = time.time()
            continue

        alpha_pred = model(image, trimap)
        alpha_pred = reverse_transform(alpha_pred, data['trans_info'])

        alpha_gt = data['alpha'].numpy().squeeze() * 255
        # From here on use the original (untransformed) trimap in 0..255.
        trimap = data['ori_trimap'].numpy().squeeze()

        alpha_pred = np.round(alpha_pred * 255)
        mse = mse_metric.update(alpha_pred, alpha_gt, trimap)
        sad = sad_metric.update(alpha_pred, alpha_gt, trimap)
        grad = grad_metric.update(alpha_pred, alpha_gt, trimap)
        conn = conn_metric.update(alpha_pred, alpha_gt, trimap)

        # Surface samples with an unusually large SAD.
        if sad > 1000:
            print(data['img_name'][0])

        if save_results:
            # Work on a copy so the prediction used for the metrics is not
            # modified in place.
            alpha_pred_one = alpha_pred.copy()
            # Force the known regions of the trimap to their definite values.
            alpha_pred_one[trimap == 255] = 255
            alpha_pred_one[trimap == 0] = 0

            save_name = data['img_name'][0]
            name, ext = os.path.splitext(save_name)
            if save_name == img_name:
                save_name = name + '_' + str(dup_idx) + ext
                dup_idx += 1
            else:
                img_name = save_name
                save_name = name + '_' + str(0) + ext
                dup_idx = 1
            save_alpha_pred(alpha_pred_one, os.path.join(save_dir, save_name))

        # The loader uses batch_size=1, so each iteration is one sample
        # (len(alpha_gt) would be the image height after squeeze()).
        batch_cost_averager.record(time.time() - batch_start, num_samples=1)
        batch_cost = batch_cost_averager.get_average()
        reader_cost = reader_cost_averager.get_average()

        if print_detail:
            progbar_val.update(idx + 1,
                               [('SAD', sad), ('MSE', mse), ('Grad', grad),
                                ('Conn', conn), ('batch_cost', batch_cost),
                                ('reader cost', reader_cost)])

        reader_cost_averager.reset()
        batch_cost_averager.reset()
        batch_start = time.time()

    mse = mse_metric.evaluate()
    sad = sad_metric.evaluate()
    grad = grad_metric.evaluate()
    conn = conn_metric.evaluate()

    logger.info('[EVAL] SAD: {:.4f}, MSE: {:.4f}, Grad: {:.4f}, Conn: {:.4f}'.
                format(sad, mse, grad, conn))
    # Number of samples skipped for lacking definite foreground/background.
    logger.info('{}'.format(ignore_cnt))

    return sad, mse, grad, conn