import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image


def get_state(gpu):
    """Load the MiDaS model and its default input transform via torch.hub."""
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
    if gpu:
        midas.cuda()
    midas.eval()  # switch to inference mode

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    transform = midas_transforms.default_transform

    state = {"model": midas,
             "transform": transform}
    return state


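def depth_to_rgba(x):
    """Reinterpret an (H, W) float32 depth map as an (H, W, 4) uint8 image.

    The four channels are the raw bytes of each float32 value, so the depth
    map can be stored losslessly as an RGBA PNG.
    """
    assert x.dtype == np.float32
    assert len(x.shape) == 2
    # view() reinterprets the bytes of the (contiguous) copy without
    # converting values; this replaces assigning to y.dtype in place.
    y = x.copy().view(np.uint8)
    y = y.reshape(x.shape+(4,))
    return np.ascontiguousarray(y)


def rgba_to_depth(x):
    """Inverse of depth_to_rgba: (H, W, 4) uint8 back to (H, W) float32."""
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy().view(np.float32)
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)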
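

# Illustrative sketch, not part of the original pipeline: the helper name
# below is hypothetical. It repeats the byte reinterpretation used by
# depth_to_rgba / rgba_to_depth on random data to show that packing float32
# values into four uint8 channels and unpacking them again is bit-exact.
def _example_float32_rgba_roundtrip(shape=(4, 6), seed=0):
    import numpy as np  # local import keeps the sketch self-contained

    rng = np.random.default_rng(seed)
    depth = rng.random(shape, dtype=np.float32)

    # pack: every float32 becomes its 4 raw bytes along a new last axis
    rgba = depth.copy().view(np.uint8).reshape(shape + (4,))
    # unpack: read each group of 4 bytes as one float32 again
    restored = rgba.copy().view(np.float32).reshape(shape)

    # array_equal compares values exactly, without any tolerance
    return bool(np.array_equal(depth, restored))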


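def run(x, state):
    """Predict a depth map for one image x (H x W x 3, assumed in [-1, 1])."""
    model = state["model"]
    transform = state["transform"]
    hw = x.shape[:2]
    # follow the model's device instead of always calling .cuda(),
    # so that get_state(gpu=False) also works
    device = next(model.parameters()).device
    with torch.no_grad():
        # rescale from [-1, 1] to [0, 255] before the MiDaS transform
        prediction = model(transform((x + 1.0) * 127.5).to(device))
        # resize the prediction back to the input resolution
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=hw,
            mode="bicubic",
            align_corners=False,
        ).squeeze()
        output = prediction.cpu().numpy()
    return output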


def get_filename(relpath, level=-2):
    # keep the class folder and the filename (without extension), e.g.
    # "n01440764/n01440764_10026.JPEG" -> ("n01440764", "n01440764_10026")
    fn = relpath.split(os.sep)[level:]
    folder = fn[-2]
    file   = fn[-1].split('.')[0]
    return folder, file


def save_depth(dataset, path, debug=False):
    os.makedirs(path)
    N = len(dataset)
    if debug:
        N = 10
    state = get_state(gpu=True)
    for idx in trange(N, desc="Data"):
        ex = dataset[idx]
        image, relpath = ex["image"], ex["relpath"]
        folder, filename = get_filename(relpath)
        # prepare
        folderabspath = os.path.join(path, folder)
        os.makedirs(folderabspath, exist_ok=True)
        savepath = os.path.join(folderabspath, filename)
        # run model
        xout = run(image, state)
        I = depth_to_rgba(xout)
        Image.fromarray(I).save("{}.png".format(savepath))


if __name__ == "__main__":
    from taming.data.imagenet import ImageNetTrain, ImageNetValidation
    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).")
        raise SystemExit(1)

    # go
    dset = ImageNetValidation()
    abspath = os.path.join(out, "val")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with validation split")

    dset = ImageNetTrain()
    abspath = os.path.join(out, "train")
    if os.path.exists(abspath):
        print("{} exists - not doing anything.".format(abspath))
    else:
        print("preparing {}".format(abspath))
        save_depth(dset, abspath)
        print("done with train split")

    print("done done.")