dalle-6D / app.py
neverix
Modern PyOpenGL, last attempt
7b9ea01
raw
history blame
5.27 kB
from tqdm.auto import trange
from PIL import Image
import gradio as gr
import numpy as np
import pyrender
import trimesh
import scipy
import torch
import cv2
import os
class MidasDepth(object):
def __init__(self, model_type="DPT_Large", device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
self.device = device
self.midas = torch.hub.load("intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
self.transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform
def get_depth(self, image):
if not isinstance(image, np.ndarray):
image = np.asarray(image)
if (image > 1).any():
image = image.astype("float64") / 255.
with torch.inference_mode():
batch = self.transform(image[..., :3]).to(self.device)
prediction = self.midas(batch)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=image.shape[:2],
mode="bicubic",
align_corners=False,
).squeeze()
return prediction.detach().cpu().numpy()
def process_depth(dep):
depth = dep.copy()
depth -= depth.min()
depth /= depth.max()
depth = 1 / np.clip(depth, 0.2, 1)
blurred = cv2.medianBlur(depth, 5) # 9 not available because it requires 8-bit
maxd = cv2.dilate(blurred, np.ones((3, 3)))
mind = cv2.erode(blurred, np.ones((3, 3)))
edges = maxd - mind
threshold = .05 # Better to have false positives
pick_edges = edges > threshold
return depth, pick_edges
def make_mesh(pic, depth, pick_edges):
faces = []
im = np.asarray(pic)
grid = np.mgrid[0:im.shape[0], 0:im.shape[1]].transpose(1, 2, 0
).reshape(-1, 2)[..., ::-1]
flat_grid = grid[:, 1] * im.shape[1] + grid[:, 0]
positions = np.concatenate(((grid - np.array(im.shape[:-1])[np.newaxis, :]
/ 2) / im.shape[1] * 2,
depth.flatten()[flat_grid][..., np.newaxis]),
axis=-1)
positions[:, :-1] *= positions[:, -1:]
positions[:, 1] *= -1
colors = im.reshape(-1, 3)[flat_grid]
c = lambda x, y: y * im.shape[1] + x
for y in trange(im.shape[0]):
for x in range(im.shape[1]):
if pick_edges[y, x]:
continue
if x > 0 and y > 0:
faces.append([c(x, y), c(x, y - 1), c(x - 1, y)])
if x < im.shape[1] - 1 and y < im.shape[0] - 1:
faces.append([c(x, y), c(x, y + 1), c(x + 1, y)])
face_colors = np.asarray([colors[i[0]] for i in faces])
tri_mesh = trimesh.Trimesh(vertices=positions * np.array([1.0, 1.0, -1.0]),
faces=faces,
face_colors=np.concatenate((face_colors,
face_colors[..., -1:]
* 0 + 255),
axis=-1).reshape(-1, 4),
smooth=False,
)
return tri_mesh
def args_to_mat(tx, ty, tz, rx, ry, rz):
mat = np.eye(4)
mat[:3, :3] = scipy.spatial.transform.Rotation.from_euler("XYZ", (rx, ry, rz)).as_matrix()
mat[:3, 3] = tx, ty, tz
return mat
def render(mesh, mat):
mesh = pyrender.mesh.Mesh.from_trimesh(mesh, smooth=False)
scene = pyrender.Scene(ambient_light=np.array([1.0, 1.0, 1.0]))
camera = pyrender.PerspectiveCamera(yfov=np.pi / 2, aspectRatio=1.0)
scene.add(camera, pose=mat)
scene.add(mesh)
r = pyrender.OffscreenRenderer(1024, 1024)
rgb, d = r.render(scene, pyrender.constants.RenderFlags.FLAT)
mask = d == 0
rgb = rgb.copy()
rgb[mask] = 0
res = Image.fromarray(np.concatenate((rgb,
((mask[..., np.newaxis]) == 0)
.astype(np.uint8) * 255), axis=-1))
return res
def main():
os.environ["PYOPENGL_PLATFORM"] = "egl" # "osmesa"
midas = MidasDepth()
def fn(pic, *args):
depth, pick_edges = process_depth(midas.get_depth(pic))
mesh = make_mesh(pic, depth, pick_edges)
frame = render(mesh, args_to_mat(*args))
return np.asarray(frame), (255 / np.asarray(depth)).astype(np.uint8), None
interface = gr.Interface(fn=fn, inputs=[
gr.inputs.Image(label="src", type="numpy"),
gr.inputs.Number(label="tx", default=0.0),
gr.inputs.Number(label="ty", default=0.0),
gr.inputs.Number(label="tz", default=0.0),
gr.inputs.Number(label="rx", default=0.0),
gr.inputs.Number(label="ry", default=0.0),
gr.inputs.Number(label="rz", default=0.0)
], outputs=[
gr.outputs.Image(type="numpy", label="result"),
gr.outputs.Image(type="numpy", label="depth"),
gr.outputs.Video(label="interpolated")
], title="DALL·E 6D", description="Lift DALL·E 2 (or any other model) into 3D!")
gr.TabbedInterface([interface], ["Warp 3D images"]).launch()
if __name__ == '__main__':
main()