File size: 5,266 Bytes
2ddc005
1d37acd
 
 
2ddc005
 
 
1d37acd
2ddc005
 
1d37acd
 
 
 
 
 
 
 
 
 
 
 
e2dc8d7
1d37acd
 
 
 
 
 
 
 
 
 
 
 
2ddc005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9afc96
2ddc005
 
 
 
 
be529c0
2ddc005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d37acd
7b9ea01
2ddc005
1d37acd
2ddc005
 
 
 
 
 
 
3c74d2a
a336e88
 
 
 
 
 
1d37acd
b9afc96
 
 
9ce1914
 
1d37acd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from tqdm.auto import trange
from PIL import Image
import gradio as gr
import numpy as np
import pyrender
import trimesh
import scipy
import torch
import cv2
import os


class MidasDepth(object):
    """Monocular depth estimator backed by a MiDaS model loaded via ``torch.hub``.

    The network and its preprocessing transform are fetched from the
    ``intel-isl/MiDaS`` hub repository when the object is constructed.

    NOTE(review): ``dpt_transform`` is used whatever ``model_type`` is;
    presumably non-DPT checkpoints need a different transform - confirm
    before passing another model type.
    """

    def __init__(self, model_type="DPT_Large", device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """Load the model and its transform.

        Args:
            model_type: Entry-point name passed to ``torch.hub.load``.
            device: Device the model and input batches are moved to.
        """
        self.device = device
        # eval() + requires_grad_(False): the network is only used for inference.
        self.midas = torch.hub.load("intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
        self.transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

    def get_depth(self, image):
        """Predict a per-pixel depth map for ``image``.

        Args:
            image: Array-like of shape ``(H, W, C)``; only the first three
                channels are fed to the model.

        Returns:
            A 2-D ``numpy`` array of shape ``(H, W)`` holding the raw model
            prediction resampled to the input resolution.
        """
        image = np.asarray(image)
        # Anything with a value above 1 is treated as 0-255 data and rescaled.
        # NOTE(review): the MiDaS hub transforms appear to divide by 255
        # themselves, so this may scale twice - confirm against the hubconf.
        if (image > 1).any():
            image = image.astype("float64") / 255.
        with torch.inference_mode():
            batch = self.transform(image[..., :3]).to(self.device)
            prediction = self.midas(batch)
            # Index batch/channel explicitly instead of a bare squeeze(), which
            # would also drop a height or width of 1 and return the wrong rank.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.shape[:2],
                mode="bicubic",
                align_corners=False,
            )[0, 0]
        return prediction.detach().cpu().numpy()


def process_depth(dep):
    """Normalise a raw depth prediction and flag depth discontinuities.

    The input is shifted and scaled to ``[0, 1]``, clipped to ``[0.2, 1]`` and
    inverted, so the returned depth lies in ``[1, 5]``. Pixels where the local
    (3x3) max-min range of the median-blurred depth exceeds a threshold are
    marked as edges.

    Args:
        dep: 2-D array-like depth prediction. It is not modified.

    Returns:
        ``(depth, pick_edges)``: the float32 processed depth and a boolean
        mask of the same shape that is True on detected edges.
    """
    # float32 copy: cv2.medianBlur with ksize 5 does not accept float64, and
    # the in-place division below would fail on integer input.
    depth = np.asarray(dep).astype(np.float32)
    depth -= depth.min()
    peak = depth.max()
    # A constant map has peak == 0; skip the division rather than produce NaNs.
    if peak > 0:
        depth /= peak
    depth = 1 / np.clip(depth, 0.2, 1)
    blurred = cv2.medianBlur(depth, 5)  # 9 not available because it requires 8-bit
    kernel = np.ones((3, 3))
    maxd = cv2.dilate(blurred, kernel)
    mind = cv2.erode(blurred, kernel)
    edges = maxd - mind
    threshold = .05  # Better to have false positives
    pick_edges = edges > threshold
    return depth, pick_edges


def make_mesh(pic, depth, pick_edges):
    """Build a coloured triangle mesh from an image and its depth map.

    Every pixel becomes one vertex. Its image-plane coordinates are centred,
    scaled by ``2 / width`` and multiplied by the pixel's depth. Each pixel
    that is not flagged in ``pick_edges`` contributes up to two triangles
    (one towards its upper/left neighbours, one towards its lower/right
    neighbours), coloured with that pixel's RGB value.

    Args:
        pic: Array-like image of shape ``(H, W, C)`` with ``C >= 3``; only the
            first three channels are used as colour.
        depth: Array of ``H * W`` depth values in row-major pixel order.
        pick_edges: ``(H, W)`` mask; truthy pixels emit no triangles.

    Returns:
        A ``trimesh.Trimesh`` with per-face RGBA colours (alpha 255).

    Raises:
        ValueError: If ``pic`` is not a 3-D array with at least 3 channels.
    """
    im = np.asarray(pic)
    if im.ndim != 3 or im.shape[-1] < 3:
        raise ValueError("pic must have shape (H, W, C) with C >= 3")
    im = im[..., :3]  # ignore an alpha channel if present
    height, width = im.shape[:2]

    # Row-major pixel coordinates: row i of `grid` is (x, y) for flat index
    # i = y * width + x.
    ys, xs = np.mgrid[0:height, 0:width]
    grid = np.stack((xs.reshape(-1), ys.reshape(-1)), axis=-1)

    # Centre is (width / 2, height / 2) to match the (x, y) column order, so
    # non-square images are centred correctly too.
    centre = np.array([width, height]) / 2
    plane = (grid - centre) / width * 2
    positions = np.concatenate(
        (plane, np.asarray(depth).reshape(-1, 1)), axis=-1)
    positions[:, :-1] *= positions[:, -1:]  # scale x, y by depth
    positions[:, 1] *= -1  # image rows grow downwards; flip so +y is up
    colors = im.reshape(-1, 3)

    # Vectorised face construction. For every kept pixel p = (x, y):
    #   upper-left triangle  [p, p - width, p - 1]  when x > 0 and y > 0
    #   lower-right triangle [p, p + width, p + 1]  when x < W-1 and y < H-1
    index = np.arange(height * width).reshape(height, width)
    keep = ~np.asarray(pick_edges, dtype=bool)

    has_upper = keep.copy()
    has_upper[:1, :] = False
    has_upper[:, :1] = False
    has_lower = keep.copy()
    has_lower[-1:, :] = False
    has_lower[:, -1:] = False

    upper = np.stack((index, index - width, index - 1), axis=-1)
    lower = np.stack((index, index + width, index + 1), axis=-1)
    candidates = np.stack((upper, lower), axis=2)        # (H, W, 2, 3)
    valid = np.stack((has_upper, has_lower), axis=-1)    # (H, W, 2)
    # Boolean selection walks (y, x, upper/lower) in C order, i.e. row by row
    # with the upper triangle before the lower one for each pixel.
    faces = candidates[valid]

    face_colors = colors[faces[:, 0]]
    face_rgba = np.concatenate(
        (face_colors, np.full_like(face_colors[:, :1], 255)), axis=-1)

    # z is negated so the surface lies at negative z.
    tri_mesh = trimesh.Trimesh(vertices=positions * np.array([1.0, 1.0, -1.0]),
                               faces=faces,
                               face_colors=face_rgba.reshape(-1, 4),
                               smooth=False,
                               )

    return tri_mesh


def args_to_mat(tx, ty, tz, rx, ry, rz):
    """Build a 4x4 homogeneous transform from a translation and Euler angles.

    Args:
        tx, ty, tz: Translation components, written to the last column.
        rx, ry, rz: Rotation angles in radians, applied as intrinsic
            rotations about X, then Y, then Z (``"XYZ"`` in SciPy).

    Returns:
        A ``(4, 4)`` float ``numpy`` array ``[[R, t], [0, 1]]``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.spatial.transform`` is loaded.
    from scipy.spatial.transform import Rotation

    mat = np.eye(4)
    mat[:3, :3] = Rotation.from_euler("XYZ", (rx, ry, rz)).as_matrix()
    mat[:3, 3] = tx, ty, tz
    return mat


def render(mesh, mat):
    """Render ``mesh`` off-screen from the camera pose ``mat``.

    The scene uses full white ambient light and flat shading, a perspective
    camera with a vertical field of view of ``pi / 2`` and aspect ratio 1,
    and a 1024x1024 viewport.

    Args:
        mesh: A ``trimesh`` mesh to draw.
        mat: 4x4 camera pose matrix passed to ``scene.add``.

    Returns:
        A ``PIL.Image`` with four channels: RGB plus an alpha channel that is
        0 where the rendered depth is 0 and 255 elsewhere.
    """
    mesh = pyrender.mesh.Mesh.from_trimesh(mesh, smooth=False)
    scene = pyrender.Scene(ambient_light=np.array([1.0, 1.0, 1.0]))
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 2, aspectRatio=1.0)
    scene.add(camera, pose=mat)
    scene.add(mesh)

    r = pyrender.OffscreenRenderer(1024, 1024)
    try:
        rgb, d = r.render(scene, pyrender.constants.RenderFlags.FLAT)
    finally:
        # Release the OpenGL resources; otherwise every call leaks a renderer.
        r.delete()

    mask = d == 0  # depth 0 marks pixels where nothing was drawn
    rgb = rgb.copy()
    rgb[mask] = 0
    alpha = np.where(mask, 0, 255).astype(np.uint8)[..., np.newaxis]
    res = Image.fromarray(np.concatenate((rgb, alpha), axis=-1))
    return res


def main():
    """Launch the Gradio demo that re-projects an image from a new camera pose.

    The UI takes a source image plus six numbers (tx, ty, tz, rx, ry, rz) and
    shows the re-rendered view, a depth preview, and an (always empty) video
    slot.
    """
    # NOTE(review): pyrender is already imported at module level when this
    # runs; confirm that setting the platform here is early enough.
    os.environ["PYOPENGL_PLATFORM"] = "egl"  # "osmesa"

    midas = MidasDepth()

    def fn(pic, *args):
        # depth estimation -> mesh -> camera pose -> render
        raw_prediction = midas.get_depth(pic)
        depth, pick_edges = process_depth(raw_prediction)
        mesh = make_mesh(pic, depth, pick_edges)
        camera_pose = args_to_mat(*args)
        frame = render(mesh, camera_pose)
        depth_preview = (255 / np.asarray(depth)).astype(np.uint8)
        return np.asarray(frame), depth_preview, None

    pose_labels = ("tx", "ty", "tz", "rx", "ry", "rz")
    input_widgets = [gr.inputs.Image(label="src", type="numpy")]
    input_widgets.extend(
        gr.inputs.Number(label=name, default=0.0) for name in pose_labels
    )
    output_widgets = [
        gr.outputs.Image(type="numpy", label="result"),
        gr.outputs.Image(type="numpy", label="depth"),
        gr.outputs.Video(label="interpolated"),
    ]

    interface = gr.Interface(
        fn=fn,
        inputs=input_widgets,
        outputs=output_widgets,
        title="DALL·E 6D",
        description="Lift DALL·E 2 (or any other model) into 3D!",
    )
    tabs = gr.TabbedInterface([interface], ["Warp 3D images"])
    tabs.launch()


# Start the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()