File size: 2,852 Bytes
a611825
c2a846f
 
a611825
c2a846f
a611825
c2a846f
 
 
 
 
a611825
c2a846f
a611825
c2a846f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
from gradio_modal import Modal
from gradio_imageslider import ImageSlider

import numpy as np

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import utils.utils as utils

from models.dsine import DSINE

# Module-level model setup: executed once at import so every request reuses
# the same DSINE instance.
# Inference in this file runs on the CPU.
device = torch.device("cpu")

model = DSINE().to(device)
# pixel_coords is moved explicitly; presumably it is a plain tensor attribute
# that `.to(device)` on the model does not carry along — TODO confirm in
# models/dsine.py.
model.pixel_coords = model.pixel_coords.to(device)
# NOTE(review): load_checkpoint presumably loads the weights stored at this
# path into `model` and returns it — confirm against utils/utils.py.
model = utils.load_checkpoint("./checkpoints/dsine.pt", model)
# Switch to evaluation mode before serving predictions.
model.eval()


def predict_normal(img_np: np.ndarray):
    """Predict a surface-normal map for an RGB image with the DSINE model.

    Args:
        img_np: Image as an (H, W, 3) array with values in 0-255, as
            delivered by the ``gr.Image`` input component.

    Returns:
        A ``(img_np, pred_norm_np)`` tuple: the untouched input image and the
        predicted normals as an (H, W, 3) uint8 array in which each component
        is rescaled from [-1, 1] to [0, 255].

    Raises:
        gr.Error: If no image was provided (``img_np`` is ``None``).
    """
    # Gradio passes None when the button is clicked without an image; fail
    # with a readable message instead of a conversion error further down.
    if img_np is None:
        raise gr.Error("Please upload an image before clicking Predict.")

    # Per-channel normalization with the commonly used ImageNet statistics.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    with torch.no_grad():
        # (H, W, 3) in 0-255  ->  (1, 3, H, W) float32 in 0-1, placed on the
        # same device as the model.
        img = np.asarray(img_np).astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)
        _, _, orig_H, orig_W = img.shape

        # zero-pad the input image so that both the width and height are multiples of 32
        l, r, t, b = utils.pad_input(orig_H, orig_W)
        img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
        img = normalize(img)

        # NOTE: if intrins is not given, we just assume that the principal point is at the center
        # and that the field-of-view is 60 degrees (feel free to modify this assumption)
        intrins = utils.get_intrins_from_fov(
            new_fov=60.0, H=orig_H, W=orig_W, device=device
        ).unsqueeze(0)

        # Padding on the left/top shifts the principal point by the same amount.
        intrins[:, 0, 2] += l
        intrins[:, 1, 2] += t

        # The model returns a sequence; the last entry is used as the prediction.
        pred_norm = model(img, intrins=intrins)[-1]
        # Crop the padding away again so the output matches the input size.
        pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]

        # NOTE: converting the prediction to uint8 loses a lot of precision
        # if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 NPY files
        pred_norm_np = (
            pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
        )  # (H, W, 3)
        # Clamp before the cast so an out-of-range value cannot wrap around
        # in uint8.
        pred_norm_np = np.clip(
            (pred_norm_np + 1.0) / 2.0 * 255.0, 0.0, 255.0
        ).astype(np.uint8)

    return (img_np, pred_norm_np)


# UI definition: an input image beside a before/after slider, plus a licence
# dialog that is shown on load.
with gr.Blocks() as demo:
    with gr.Group():
        with gr.Row():
            input_img = gr.Image(label="Input image", image_mode="RGB")
            # The slider is fed the (input, normals) pair returned by
            # predict_normal.
            output_img = ImageSlider(label="Surface Normal", type="numpy")

        predict_btn = gr.Button("Predict")
        predict_btn.click(
            fn=predict_normal, inputs=[input_img], outputs=[output_img]
        )

    # Visible from the start and not closable by the user directly; only the
    # "I agree" button below hides it.
    with Modal(visible=True, allow_user_close=False) as modal:
        gr.Markdown(
            "To use this space, you must agree to the terms and conditions "
            "found [here](https://github.com/baegwangbin/DSINE/blob/main/LICENSE)."
        )
        agree_btn = gr.Button("I agree")
        # Returning a hidden Modal as the update for `modal` closes the dialog.
        agree_btn.click(lambda: Modal(visible=False), inputs=None, outputs=modal)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )