PaddleSeg-Matting / matting /dataset /matting_dataset.py
marta-0's picture
add files
6da6215
raw
history blame contribute delete
No virus
8.54 kB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import cv2
import numpy as np
import random
import paddle
from paddleseg.cvlibs import manager
import matting.transforms as T
@manager.DATASETS.add_component
class MattingDataset(paddle.io.Dataset):
    """
    Pass in a dataset that conforms to the format.

        matting_dataset/
        |--bg/
        |
        |--train/
        |  |--fg/
        |  |--alpha/
        |
        |--val/
        |  |--fg/
        |  |--alpha/
        |  |--trimap/ (if existing)
        |
        |--train.txt
        |
        |--val.txt

    See README.md for more information of dataset.

    Args:
        dataset_root (str): The root path of dataset.
        transforms (list): Transforms for image.
        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val', 'trainval'). Default: 'train'.
        train_file (str|list, optional): File list used for training. Each line should be
            `foreground_image.png background_image.png` or `foreground_image.png`.
            It should be provided if mode is 'train' or 'trainval'. Default: None.
        val_file (str|list, optional): File list used for evaluation. Each line should be
            `foreground_image.png background_image.png`, `foreground_image.png`, or
            `foreground_image.png background_image.png trimap_image.png`.
            It should be provided if mode is 'val' or 'trainval'. Default: None.
        get_trimap (bool, optional): Whether to get trimap. Default: True.
        separator (str, optional): The separator of train_file or val_file. If file names contain ' ',
            '|' may be a better choice. Default: ' '.

    Blank lines in the list files are ignored. Paths in the list files are
    relative to `dataset_root`.

    Raises:
        ValueError: If `mode` is invalid or the list file required by `mode` is missing.
    """

    def __init__(self,
                 dataset_root,
                 transforms,
                 mode='train',
                 train_file=None,
                 val_file=None,
                 get_trimap=True,
                 separator=' '):
        super().__init__()
        if mode not in ('train', 'val', 'trainval'):
            # Without this check an unknown mode would leave the file list
            # unbound and fail later with an unhelpful UnboundLocalError.
            raise ValueError(
                "`mode` should be one of ('train', 'val', 'trainval'), "
                "but got {}".format(mode))
        self.dataset_root = dataset_root
        self.transforms = T.Compose(transforms)
        self.mode = mode
        self.get_trimap = get_trimap
        self.separator = separator

        # Collect the list files of the requested split(s); for 'trainval'
        # the train lists come first, then the val lists.
        file_list = []
        if mode in ('train', 'trainval'):
            if train_file is None:
                raise ValueError(
                    "When `mode` is 'train' or 'trainval', `train_file` must be provided!"
                )
            if isinstance(train_file, str):
                train_file = [train_file]
            file_list.extend(train_file)
        if mode in ('val', 'trainval'):
            if val_file is None:
                raise ValueError(
                    "When `mode` is 'val' or 'trainval', `val_file` must be provided!"
                )
            if isinstance(val_file, str):
                val_file = [val_file]
            file_list.extend(val_file)

        # Each non-empty line is one sample: "fg [bg] [trimap]".
        self.fg_bg_list = []
        for list_file in file_list:
            with open(os.path.join(dataset_root, list_file), 'r') as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. an extra newline at the end of the
                    # file) would otherwise become samples that point at
                    # dataset_root itself and crash in __getitem__.
                    if line:
                        self.fg_bg_list.append(line)

    @staticmethod
    def _read_image(path, flags=cv2.IMREAD_COLOR):
        """Read an image with OpenCV, raising instead of returning None.

        cv2.imread returns None (it does not raise) when a file is missing or
        cannot be decoded; failing here gives an error that names the path.
        """
        img = cv2.imread(path, flags)
        if img is None:
            raise IOError('Failed to read image: {}'.format(path))
        return img

    def __getitem__(self, idx):
        """Load, composite and transform the sample at `idx`.

        Returns:
            dict: Sample dict with at least 'img', 'alpha', 'img_name',
            'gt_fields' and 'trans_info'; 'fg', 'bg', 'trimap' and
            'ori_trimap' are present depending on mode and the line format.
        """
        data = {}
        fg_bg_file = self.fg_bg_list[idx].split(self.separator)
        data['img_name'] = fg_bg_file[0]  # used when saving prediction results
        fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
        # NOTE(review): this rewrites every '/fg' in the full path, including
        # in dataset_root or in file names starting with 'fg'. Kept as-is for
        # compatibility with existing datasets.
        alpha_file = fg_file.replace('/fg', '/alpha')
        fg = self._read_image(fg_file)
        alpha = self._read_image(alpha_file, cv2.IMREAD_GRAYSCALE)
        data['alpha'] = alpha
        data['gt_fields'] = []

        # line is: fg [bg] [trimap]
        if len(fg_bg_file) >= 2:
            bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
            bg = self._read_image(bg_file)
            data['img'], data['bg'] = self.composite(fg, alpha, bg)
            data['fg'] = fg
            if self.mode in ('train', 'trainval'):
                data['gt_fields'].extend(['fg', 'bg', 'alpha'])
            # A trimap file given in the list is only used in 'val' mode;
            # in training modes a trimap is generated from alpha below.
            if len(fg_bg_file) == 3 and self.get_trimap:
                if self.mode == 'val':
                    trimap_path = os.path.join(self.dataset_root,
                                               fg_bg_file[2])
                    if os.path.exists(trimap_path):
                        # Stored as a path; presumably loaded by one of the
                        # transforms (e.g. a LoadImages step) — confirm.
                        data['trimap'] = trimap_path
                        data['gt_fields'].append('trimap')
                        data['ori_trimap'] = self._read_image(
                            trimap_path, cv2.IMREAD_GRAYSCALE)
                    else:
                        raise FileNotFoundError(
                            'trimap is not Found: {}'.format(fg_bg_file[2]))
        else:
            data['img'] = fg
            if self.mode in ('train', 'trainval'):
                # No background given: copies of fg fill the fg/bg fields
                # (presumably because the train transforms expect them).
                data['fg'] = fg.copy()
                data['bg'] = fg.copy()
                data['gt_fields'].extend(['fg', 'bg', 'alpha'])

        data['trans_info'] = []  # Record shape change information

        # Generate trimap from alpha if no trimap file provided
        if self.get_trimap and 'trimap' not in data:
            data['trimap'] = self.gen_trimap(
                data['alpha'], mode=self.mode).astype('float32')
            data['gt_fields'].append('trimap')
            if self.mode == 'val':
                data['ori_trimap'] = data['trimap'].copy()

        data = self.transforms(data)

        # When evaluating, gt should not be transformed: alpha joins
        # gt_fields only after the transforms so the cast below applies.
        if self.mode == 'val':
            data['gt_fields'].append('alpha')

        data['img'] = data['img'].astype('float32')
        for key in data.get('gt_fields', []):
            data[key] = data[key].astype('float32')

        # Add a leading channel axis to the single-channel maps.
        if 'trimap' in data:
            data['trimap'] = data['trimap'][np.newaxis, :, :]
        if 'ori_trimap' in data:
            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]
        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.
        return data

    def __len__(self):
        """Return the number of samples (non-empty list-file lines)."""
        return len(self.fg_bg_list)

    def composite(self, fg, alpha, ori_bg):
        """Alpha-blend `fg` over a background.

        If the background is smaller than the foreground in either dimension,
        it is first upscaled, keeping its aspect ratio. It is then cropped from
        its top-left corner to the foreground size.

        Args:
            fg (np.ndarray): Foreground image, (H, W, C).
            alpha (np.ndarray): Alpha matte, (H, W), values in [0, 255].
            ori_bg (np.ndarray): Background image, (h, w, C).

        Returns:
            tuple: (image, bg). `image` is the uint8 composite; `bg` is the
            resized and cropped background that was blended.
        """
        fg_h, fg_w = fg.shape[:2]
        ori_bg_h, ori_bg_w = ori_bg.shape[:2]
        # Smallest uniform scale that lets the background cover the foreground.
        ratio = max(fg_w / ori_bg_w, fg_h / ori_bg_h)
        # Resize ori_bg if it is smaller than fg.
        if ratio > 1:
            resize_h = math.ceil(ori_bg_h * ratio)
            resize_w = math.ceil(ori_bg_w * ratio)
            bg = cv2.resize(
                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        else:
            bg = ori_bg
        bg = bg[0:fg_h, 0:fg_w, :]
        alpha = np.expand_dims(alpha / 255, axis=2)
        image = (alpha * fg + (1 - alpha) * bg).astype(np.uint8)
        return image, bg

    @staticmethod
    def gen_trimap(alpha, mode='train', eval_kernel=7):
        """Generate a trimap (0 = background, 128 = unknown, 255 = foreground).

        Args:
            alpha (np.ndarray): Alpha matte with values in [0, 255]
                (read as 8-bit grayscale in __getitem__).
            mode (str): 'train' or 'trainval' use a randomized structuring
                element and iteration count, so the unknown band varies in
                width. Any other value uses the deterministic eval rule.
            eval_kernel (int): Elliptical kernel size for the eval rule.

        Returns:
            np.ndarray: float64 trimap with the same shape as `alpha`.
        """
        trimap = np.full(alpha.shape, 128.0)
        if mode in ('train', 'trainval'):
            k_size = random.choice(range(2, 5))
            iterations = np.random.randint(5, 15)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel, iterations=iterations)
            eroded = cv2.erode(alpha, kernel, iterations=iterations)
            # Definite foreground survives erosion fully opaque; definite
            # background stays fully transparent after dilation.
            trimap[eroded > 254.5] = 255
            trimap[dilated < 0.5] = 0
        else:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (eval_kernel, eval_kernel))
            dilated = cv2.dilate(alpha, kernel)
            trimap[alpha >= 250] = 255
            trimap[dilated <= 5] = 0
        return trimap