File size: 2,371 Bytes
3b49518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os

import pandas as pd
from PIL import Image, ImageFile
from torch.utils.data import Dataset

ImageFile.LOAD_TRUNCATED_IMAGES = True

from .transforms import task_transform


class TaskonomyDataset(Dataset):
    def __init__(self, 
                 data_root,
                 tasks, 
                 split='train', 
                 variant='tiny',
                 image_size=256,
                 max_images=None):
        """
        Taskonomy dataloader.

        Args:
            data_root: Root of Taskonomy data directory
            tasks: List of tasks. Any of ['rgb', 'depth_euclidean', 'depth_zbuffer',
                'edge_occlusion', 'edge_texture', 'keypoints2d', 'keypoints3d', 'normal',
                'principal_curvature', 'reshading', 'mask_valid'].
            split: One of {'train', 'val', 'test'}
            variant: One of {'debug', 'tiny', 'medium', 'full', 'fullplus'}
            image_size: Target image size
            max_images: Optional subset selection
        """
        super(TaskonomyDataset, self).__init__()
        self.data_root = data_root
        self.tasks = tasks
        self.split = split
        self.variant = variant
        self.image_size=image_size
        self.max_images = max_images
        
        self.image_ids = pd.read_csv(
            os.path.join(os.path.dirname(__file__), 'splits', f'{self.variant}_{self.split}.csv')
        ).to_numpy()
        
        if isinstance(self.max_images, int):
            self.image_ids = self.image_ids[:self.max_images]
        
        print(f'Initialized TaskonomyDataset with {len(self.image_ids)} images from variant {self.variant} in split {self.split}.')
        
        
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, index):
        
        # building / point / view
        building, point, view = self.image_ids[index]
        
        result = {}
        for task in self.tasks:
            task_id = 'depth_zbuffer' if task == 'mask_valid' else task
            path = os.path.join(
                self.data_root, task, building, f'point_{point}_view_{view}_domain_{task_id}.png'
            )
            img = Image.open(path)
            # Perform transformations
            img = task_transform(img, task=task, image_size=self.image_size)
            result[task] = img

        return result