import os

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class Image_Editing_Dataset(Dataset):
    """Test-split dataset for image editing: for each sample name it loads the
    RGB image, a grayscale semantic label map, an integer instance map, and a
    predefined inpainting mask from the 'test_processed' directory layout."""

    def __init__(self, cfg, dataset_root, split='test', dataset_name=''):
        self.split = split
        self.cfg = cfg
        self.dataset_name = dataset_name
        self.img_format = '.png'

        self.dir_img = os.path.join(dataset_root, 'test_processed', 'images')
        self.dir_lab = os.path.join(dataset_root, 'test_processed', 'labels')
        self.dir_ins = os.path.join(dataset_root, 'test_processed', 'inst_map')
        name_list = os.listdir(self.dir_img)
        self.name_list = [os.path.splitext(n)[0] for n in name_list
                          if n.endswith(self.img_format)]
        self.name_list.sort()
        self.predefined_mask_path = os.path.join(dataset_root, 'test_processed', 'predefined_masks')

    def __getitem__(self, index):
        name = self.name_list[index]

        # Load the image (BGR, as read by OpenCV), the grayscale label map,
        # and the instance map (opened via PIL to preserve integer ids).
        img = cv2.imread(os.path.join(self.dir_img, name + self.img_format))
        lab = cv2.imread(os.path.join(self.dir_lab, name + self.img_format), 0)
        inst_map = Image.open(os.path.join(self.dir_ins, name + self.img_format))
        inst_map = np.array(inst_map, dtype=np.int32)

        assert inst_map.ndim == 2

        # ToTensor scales to [0, 1]; the image is further normalized to [-1, 1],
        # while the label map is rescaled back to its original [0, 255] indices.
        img = get_transform(img)
        lab = get_transform(lab, normalize=False)
        lab = lab * 255.0

        # Binary inpainting mask (1 = region to edit), shaped (1, H, W).
        mask = cv2.imread(os.path.join(self.predefined_mask_path, 'type_0',
                                       name + self.img_format), 0) / 255
        mask = mask.reshape((1,) + mask.shape).astype(np.float32)
        mask = torch.from_numpy(mask)

        # Zero out the masked region to produce the network input.
        masked_img = img * (1. - mask)

        inst_map = inst_map.reshape((1,) + inst_map.shape).astype(np.float32)
        inst_map = torch.from_numpy(inst_map)

        return {'img': img, 'masked_img': masked_img, 'lab': lab,
                'mask': mask, 'inst_map': inst_map, 'name': name}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.name_list)


def get_transform(img, normalize=True):
    """Convert an HWC uint8 array to a CHW float tensor in [0, 1];
    optionally normalize each channel to [-1, 1]."""
    transform_list = [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)(img)
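

if __name__ == '__main__':
    # Minimal usage sketch, added for illustration; it is not part of the
    # original module. Assumptions: cfg is only stored on the instance, so
    # None suffices, and './data' is a placeholder for a real dataset_root
    # containing the test_processed/{images,labels,inst_map,predefined_masks}
    # layout expected above.
    from torch.utils.data import DataLoader

    dataset = Image_Editing_Dataset(cfg=None, dataset_root='./data', split='test')
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    sample = next(iter(loader))
    # Expected per-batch contents:
    #   img        (B, 3, H, W) float, normalized to [-1, 1]
    #   masked_img (B, 3, H, W) float, img with the masked region zeroed
    #   lab        (B, 1, H, W) float, label indices in [0, 255]
    #   mask       (B, 1, H, W) float, binary {0, 1}
    #   inst_map   (B, 1, H, W) float32 instance ids
    #   name       list of B sample names
    print({k: v.shape for k, v in sample.items() if torch.is_tensor(v)})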