# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import random

import cv2
import numpy as np
import paddle
from paddleseg.cvlibs import manager

import matting.transforms as T


@manager.DATASETS.add_component
class MattingDataset(paddle.io.Dataset):
    """
    Pass in a dataset that conforms to the format.

        matting_dataset/
        |--bg/
        |
        |--train/
        |  |--fg/
        |  |--alpha/
        |
        |--val/
        |  |--fg/
        |  |--alpha/
        |  |--trimap/ (if existing)
        |
        |--train.txt
        |
        |--val.txt

    See README.md for more information of dataset.

    Each non-blank line of a file list has the form
    `fg [bg [trimap]]`, joined by `separator`, with paths relative to
    `dataset_root`. The alpha matte path is derived from the foreground path
    by replacing every '/fg' substring with '/alpha'.

    Args:
        dataset_root(str): The root path of dataset.
        transforms(list): Transforms for image.
        mode (str, optional): Which part of dataset to use. It is one of
            ('train', 'val', 'trainval'). Default: 'train'.
        train_file (str|list, optional): File list used for training. Each
            line should be `foreground_image.png background_image.png` or
            `foreground_image.png`. It should be provided if mode is 'train'
            or 'trainval'. Default: None.
        val_file (str|list, optional): File list used for evaluation. Each
            line should be `foreground_image.png background_image.png`,
            `foreground_image.png` or
            `foreground_image.png background_image.png trimap_image.png`.
            It should be provided if mode is 'val' or 'trainval'.
            Default: None.
        get_trimap (bool, optional): Whether to get a trimap (read from the
            file list in 'val' mode if given, otherwise generated from the
            alpha matte). Default: True.
        separator (str, optional): The separator of train_file or val_file.
            If a file name contains ' ', '|' may be a better choice.
            Default: ' '.

    Raises:
        ValueError: If `mode` is unsupported or a required file list is
            missing.
    """

    def __init__(self,
                 dataset_root,
                 transforms,
                 mode='train',
                 train_file=None,
                 val_file=None,
                 get_trimap=True,
                 separator=' '):
        super().__init__()
        # Fail fast: an unknown mode would otherwise leave `file_list`
        # unbound and surface as an obscure UnboundLocalError below.
        if mode not in ('train', 'val', 'trainval'):
            raise ValueError(
                "`mode` should be one of ('train', 'val', 'trainval'), "
                "but got {}.".format(mode))
        self.dataset_root = dataset_root
        self.transforms = T.Compose(transforms)
        self.mode = mode
        self.get_trimap = get_trimap
        self.separator = separator

        # check file
        if mode == 'train' or mode == 'trainval':
            if train_file is None:
                raise ValueError(
                    "When `mode` is 'train' or 'trainval', `train_file` must be provided!"
                )
            if isinstance(train_file, str):
                train_file = [train_file]
            file_list = train_file

        if mode == 'val' or mode == 'trainval':
            if val_file is None:
                raise ValueError(
                    "When `mode` is 'val' or 'trainval', `val_file` must be provided!"
                )
            if isinstance(val_file, str):
                val_file = [val_file]
            file_list = val_file

        if mode == 'trainval':
            file_list = train_file + val_file

        # read file
        self.fg_bg_list = []
        for file in file_list:
            file = os.path.join(dataset_root, file)
            with open(file, 'r') as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines (e.g. an extra trailing newline); they
                    # would otherwise become samples with no image path.
                    if line:
                        self.fg_bg_list.append(line)

    @staticmethod
    def _imread(path, flags=cv2.IMREAD_COLOR):
        """Read an image with cv2, raising instead of silently returning None."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(
                'Image is not found or cannot be read: {}'.format(path))
        return img

    def __getitem__(self, idx):
        data = {}
        fg_bg_file = self.fg_bg_list[idx]
        fg_bg_file = fg_bg_file.split(self.separator)
        data['img_name'] = fg_bg_file[0]  # used when saving prediction results
        fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
        # NOTE: replaces every '/fg' in the joined path, including any in
        # dataset_root or in the file name itself.
        alpha_file = fg_file.replace('/fg', '/alpha')
        fg = self._imread(fg_file)
        alpha = self._imread(alpha_file, cv2.IMREAD_GRAYSCALE)
        data['alpha'] = alpha
        data['gt_fields'] = []

        # line is: fg [bg] [trimap]
        if len(fg_bg_file) >= 2:
            bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
            bg = self._imread(bg_file)
            data['img'], data['bg'] = self.composite(fg, alpha, bg)
            data['fg'] = fg
            if self.mode in ['train', 'trainval']:
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')
            if len(fg_bg_file) == 3 and self.get_trimap:
                if self.mode == 'val':
                    trimap_path = os.path.join(self.dataset_root,
                                               fg_bg_file[2])
                    if os.path.exists(trimap_path):
                        # Stored as a path string; presumably loaded into an
                        # array by one of the transforms (confirm in
                        # matting.transforms).
                        data['trimap'] = trimap_path
                        data['gt_fields'].append('trimap')
                        data['ori_trimap'] = self._imread(
                            trimap_path, cv2.IMREAD_GRAYSCALE)
                    else:
                        raise FileNotFoundError(
                            'trimap is not Found: {}'.format(fg_bg_file[2]))
        else:
            data['img'] = fg
            if self.mode in ['train', 'trainval']:
                # No background given: the foreground doubles as fg and bg.
                data['fg'] = fg.copy()
                data['bg'] = fg.copy()
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')

        data['trans_info'] = []  # Record shape change information

        # Generate trimap from alpha if no trimap file provided
        if self.get_trimap:
            if 'trimap' not in data:
                data['trimap'] = self.gen_trimap(
                    data['alpha'], mode=self.mode).astype('float32')
                data['gt_fields'].append('trimap')
                if self.mode == 'val':
                    data['ori_trimap'] = data['trimap'].copy()

        data = self.transforms(data)

        # When evaluating, gt alpha should not be transformed, so it is only
        # registered as a gt field after the transforms have run.
        if self.mode == 'val':
            data['gt_fields'].append('alpha')

        data['img'] = data['img'].astype('float32')
        for key in data.get('gt_fields', []):
            data[key] = data[key].astype('float32')

        if 'trimap' in data:
            data['trimap'] = data['trimap'][np.newaxis, :, :]
        if 'ori_trimap' in data:
            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]

        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.

        return data

    def __len__(self):
        return len(self.fg_bg_list)

    def composite(self, fg, alpha, ori_bg):
        """
        Alpha-blend `fg` over a background cropped to the size of `fg`.

        If `ori_bg` is smaller than `fg` in either dimension, it is first
        upscaled by a single ratio (preserving its aspect ratio) so that it
        covers `fg`. It is then cropped from the top-left corner to
        `fg`'s height and width.

        Args:
            fg (np.ndarray): Foreground image, indexed as (h, w, c).
            alpha (np.ndarray): Alpha matte on a 0-255 scale, (h, w).
            ori_bg (np.ndarray): Background image, indexed as (h, w, c).

        Returns:
            tuple: (composited image as uint8, cropped background).
        """
        fg_h, fg_w = fg.shape[:2]
        ori_bg_h, ori_bg_w = ori_bg.shape[:2]

        wratio = fg_w / ori_bg_w
        hratio = fg_h / ori_bg_h
        ratio = wratio if wratio > hratio else hratio

        # Resize ori_bg if it is smaller than fg.
        if ratio > 1:
            resize_h = math.ceil(ori_bg_h * ratio)
            resize_w = math.ceil(ori_bg_w * ratio)
            bg = cv2.resize(
                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        else:
            bg = ori_bg

        bg = bg[0:fg_h, 0:fg_w, :]
        alpha = alpha / 255
        alpha = np.expand_dims(alpha, axis=2)
        image = alpha * fg + (1 - alpha) * bg
        image = image.astype(np.uint8)
        return image, bg

    @staticmethod
    def gen_trimap(alpha, mode='train', eval_kernel=7):
        """
        Generate a trimap from an alpha matte.

        Output pixels are 0 (definitely background), 128 (unknown) or 255
        (definitely foreground).

        Args:
            alpha (np.ndarray): Alpha matte, thresholded on a 0-255 scale.
            mode (str, optional): For 'train' and 'trainval', the alpha is
                dilated and eroded with an elliptical kernel of random size in
                [2, 4] for a random number of iterations in [5, 14], so the
                unknown band varies per call. For any other mode, a single
                dilation with an `eval_kernel` sized kernel is used.
                Default: 'train'.
            eval_kernel (int, optional): Kernel size for non-training modes.
                Default: 7.

        Returns:
            np.ndarray: float64 array with the same shape as `alpha`.
        """
        # 'trainval' is a training mode everywhere else in this class, so it
        # must also get the randomized training trimap.
        if mode in ('train', 'trainval'):
            k_size = random.choice(range(2, 5))
            iterations = np.random.randint(5, 15)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel, iterations=iterations)
            eroded = cv2.erode(alpha, kernel, iterations=iterations)
            trimap = np.zeros(alpha.shape)
            trimap.fill(128)
            trimap[eroded > 254.5] = 255
            trimap[dilated < 0.5] = 0
        else:
            k_size = eval_kernel
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel)
            trimap = np.zeros(alpha.shape)
            trimap.fill(128)
            trimap[alpha >= 250] = 255
            trimap[dilated <= 5] = 0

        return trimap