# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import math import time import cv2 import numpy as np import paddle import paddle.nn.functional as F from paddleseg import utils from paddleseg.core import infer from paddleseg.utils import logger, progbar, TimeAverager from matting.utils import mkdir def partition_list(arr, m): """split the list 'arr' into m pieces""" n = int(math.ceil(len(arr) / float(m))) return [arr[i:i + n] for i in range(0, len(arr), n)] def save_alpha_pred(alpha, path, trimap=None): """ The value of alpha is range [0, 1], shape should be [h,w] """ dirname = os.path.dirname(path) if not os.path.exists(dirname): os.makedirs(dirname) trimap = cv2.imread(trimap, 0) alpha[trimap == 0] = 0 alpha[trimap == 255] = 255 alpha = (alpha).astype('uint8') cv2.imwrite(path, alpha) def reverse_transform(alpha, trans_info): """recover pred to origin shape""" for item in trans_info[::-1]: if item[0] == 'resize': h, w = item[1][0], item[1][1] alpha = F.interpolate(alpha, [h, w], mode='bilinear') elif item[0] == 'padding': h, w = item[1][0], item[1][1] alpha = alpha[:, :, 0:h, 0:w] else: raise Exception("Unexpected info '{}' in im_info".format(item[0])) return alpha def preprocess(img, transforms, trimap=None): data = {} data['img'] = img if trimap is not None: data['trimap'] = trimap data['gt_fields'] = ['trimap'] data['trans_info'] = [] data = transforms(data) data['img'] = paddle.to_tensor(data['img']) data['img'] = data['img'].unsqueeze(0) if trimap is not None: data['trimap'] = paddle.to_tensor(data['trimap']) data['trimap'] = data['trimap'].unsqueeze((0, 1)) return data def predict(model, model_path, transforms, image_list, image_dir=None, trimap_list=None, save_dir='output'): """ predict and visualize the image_list. Args: model (nn.Layer): Used to predict for input image. model_path (str): The path of pretrained model. transforms (transforms.Compose): Preprocess for input image. image_list (list): A list of image path to be predicted. image_dir (str, optional): The root directory of the images predicted. Default: None. trimap_list (list, optional): A list of trimap of image_list. Default: None. save_dir (str, optional): The directory to save the visualized results. Default: 'output'. """ utils.utils.load_entire_model(model, model_path) model.eval() nranks = paddle.distributed.get_world_size() local_rank = paddle.distributed.get_rank() if nranks > 1: img_lists = partition_list(image_list, nranks) trimap_lists = partition_list( trimap_list, nranks) if trimap_list is not None else None else: img_lists = [image_list] trimap_lists = [trimap_list] if trimap_list is not None else None logger.info("Start to predict...") progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1) preprocess_cost_averager = TimeAverager() infer_cost_averager = TimeAverager() postprocess_cost_averager = TimeAverager() batch_start = time.time() with paddle.no_grad(): for i, im_path in enumerate(img_lists[local_rank]): preprocess_start = time.time() trimap = trimap_lists[local_rank][ i] if trimap_list is not None else None data = preprocess(img=im_path, transforms=transforms, trimap=trimap) preprocess_cost_averager.record(time.time() - preprocess_start) infer_start = time.time() alpha_pred = model(data) infer_cost_averager.record(time.time() - infer_start) postprocess_start = time.time() alpha_pred = reverse_transform(alpha_pred, data['trans_info']) alpha_pred = (alpha_pred.numpy()).squeeze() alpha_pred = (alpha_pred * 255).astype('uint8') # get the saved name # if image_dir is not None: # im_file = im_path.replace(image_dir, '') # else: # im_file = os.path.basename(im_path) # if im_file[0] == '/' or im_file[0] == '\\': # im_file = im_file[1:] # save_path = os.path.join(save_dir, im_file) # mkdir(save_path) # save_alpha_pred(alpha_pred, save_path, trimap=trimap) postprocess_cost_averager.record(time.time() - postprocess_start) preprocess_cost = preprocess_cost_averager.get_average() infer_cost = infer_cost_averager.get_average() postprocess_cost = postprocess_cost_averager.get_average() if local_rank == 0: progbar_pred.update(i + 1, [('preprocess_cost', preprocess_cost), ('infer_cost cost', infer_cost), ('postprocess_cost', postprocess_cost)]) preprocess_cost_averager.reset() infer_cost_averager.reset() postprocess_cost_averager.reset() return alpha_pred