File size: 1,916 Bytes
7200298
a711240
 
 
 
 
 
 
 
 
818ec2e
a711240
 
 
 
 
 
3410172
a711240
 
 
 
3410172
a711240
 
 
 
 
 
 
 
 
 
 
 
 
818ec2e
7200298
 
 
a711240
 
 
 
7200298
 
 
a711240
 
7200298
a711240
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import yaml
from fastai.vision.all import \
    DataLoaders, \
    delegates, \
    DataBlock, \
    ImageBlock, \
    PILImage, \
    PILImageBW, \
    RandomSplitter, \
    Path, \
    get_files


class ImageImageDataLoaders(DataLoaders):
    """`DataLoaders` with factory methods for image-to-image problems.

    Inputs are opened as colour `PILImage`s and targets as single-channel
    `PILImageBW`s.
    """

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, filenames, label_func, valid_pct=0.2, seed=None, item_transforms=None,
                        batch_transforms=None, **kwargs):
        """Create loaders from the items `filenames` under `path`, with targets given by `label_func`.

        Args:
            path: root path forwarded to `from_dblock`.
            filenames: input image items, randomly split into train/valid sets.
            label_func: called on an input item; returns the target image's location.
            valid_pct: fraction of items placed in the validation set.
            seed: seed passed to `RandomSplitter`.
            item_transforms: used as the `DataBlock`'s `item_tfms`.
            batch_transforms: used as the `DataBlock`'s `batch_tfms`.
            **kwargs: forwarded unchanged to `DataLoaders.from_dblock`.

        Returns:
            The object produced by `cls.from_dblock`.
        """
        # Colour input image paired with a single-channel target image.
        input_block = ImageBlock(cls=PILImage)
        target_block = ImageBlock(cls=PILImageBW)
        splitter = RandomSplitter(valid_pct, seed=seed)

        block = DataBlock(
            blocks=(input_block, target_block),
            get_y=label_func,
            splitter=splitter,
            item_tfms=item_transforms,
            batch_tfms=batch_transforms,
        )
        return cls.from_dblock(block, filenames, path=path, **kwargs)


def get_y_fn(x):
    """Return the path of the depth map paired with the image `x`.

    The depth map is looked up next to the image, named after the image's
    stem plus `_depth.png`, e.g. `dir/0001.jpg` -> `<absolute dir>/0001_depth.png`.

    Only the final path component is rewritten. Directory names that happen
    to contain ".jpg" are left untouched, and the image extension is
    replaced regardless of its case (e.g. `.JPG`). The previous
    `str.replace('.jpg', ...)` approach would corrupt the former and silently
    return the input image itself for the latter.

    Args:
        x: path to an input image (`Path` or `str`).

    Returns:
        Absolute `Path` of the matching depth image.
    """
    image = Path(x).absolute()
    return image.with_name(image.stem + '_depth.png')


def create_data(data_path, params_path=r"./src/code/params.yml"):
    """Build image -> depth `ImageImageDataLoaders` from the `.jpg` files under `data_path`.

    Loader settings are read from the YAML file at `params_path`, which must
    define `seed`, `batch_size` and `num_workers`.

    Args:
        data_path: directory searched with `get_files` for `.jpg` input images.
            Each image's target is located with `get_y_fn`.
        params_path: YAML parameter file. The default is relative to the
            current working directory, so pass it explicitly when running
            from elsewhere.

    Returns:
        The loaders built by `ImageImageDataLoaders.from_label_func`.

    Raises:
        ValueError: if no `.jpg` files are found under `data_path`.
        KeyError: if the parameter file is missing a required key.
    """
    # safe_load rather than load: the config is parsed without constructing
    # arbitrary Python objects. An empty file yields None, hence `or {}`.
    with open(params_path, encoding="utf-8") as f:
        params = yaml.safe_load(f) or {}

    filenames = get_files(data_path, extensions='.jpg')
    if len(filenames) == 0:
        raise ValueError("Could not find any files in the given path")

    # Report every missing key at once instead of failing on the first lookup.
    required = ('seed', 'batch_size', 'num_workers')
    missing = [key for key in required if key not in params]
    if missing:
        raise KeyError(f"{params_path} is missing required keys: {missing}")

    return ImageImageDataLoaders.from_label_func(data_path,
                                                 seed=int(params['seed']),
                                                 bs=int(params['batch_size']),
                                                 num_workers=int(params['num_workers']),
                                                 filenames=filenames,
                                                 label_func=get_y_fn)