File size: 4,982 Bytes
bd86ed9
 
 
 
 
 
 
 
3fb6608
bd86ed9
 
0f1bbf6
bd86ed9
 
3fbdaa2
8b0757c
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
c19a5d3
92224a7
 
 
 
 
 
 
 
c19a5d3
 
 
 
92224a7
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fb6608
 
 
 
3034d2d
f92af9d
3fb6608
 
7a8232f
dccb91c
 
c9de5c0
dccb91c
 
 
 
 
c9de5c0
e911b6f
 
e39f6fa
e911b6f
 
 
 
 
e39f6fa
e911b6f
f5ce9f2
 
b3a5f53
7a4e452
 
b3a5f53
e911b6f
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Normalize
import tempfile
from gradio_imageslider import ImageSlider
import matplotlib.pyplot as plt

from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.util.transfrom import Resize, NormalizeImage, PrepareForNet
from iebins.utils import post_process_depth, flip_lr

# Custom stylesheet handed to gr.Blocks(css=...) below. The selectors are
# element ids; 'img-display-input' and 'img-display-output' are the elem_id
# values given to the input image and the output slider in the UI definition.
# Each rule only caps the element height relative to the viewport.
css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Single source of truth for the weights file, so the load call and the log
# line below cannot drift apart.
CHECKPOINT_PATH = 'checkpoints/nyu_L.pth'

model = NewCRFDepth(version='large07', inv_depth=False,
                    max_depth=10, pretrained=None).to(DEVICE)

# Report the model size. Counting parameters does not depend on the
# train/eval mode, so the model is deliberately NOT switched to training mode
# here (the previous `model.train()` call left the demo running inference in
# training mode).
num_params = sum([np.prod(p.size()) for p in model.parameters()])
print("== Total number of parameters: {}".format(num_params))
num_params_update = sum([np.prod(p.shape)
                        for p in model.parameters() if p.requires_grad])
print("== Total number of learning parameters: {}".format(num_params_update))

# The state dict is loaded into the DataParallel wrapper rather than the bare
# network -- presumably the checkpoint was saved from a wrapped model, so the
# parameter names match; TODO confirm against the checkpoint.
model = torch.nn.DataParallel(model)
checkpoint = torch.load(CHECKPOINT_PATH,
                        map_location=torch.device(DEVICE))
model.load_state_dict(checkpoint['model'])
print("== Loaded checkpoint '{}'".format(CHECKPOINT_PATH))

# This script only performs inference: make sure every sub-module (including
# the wrapped network) is in evaluation mode before any request is served.
model.eval()

# Markdown snippets rendered at the top of the page via gr.Markdown():
# `title` is a level-1 heading, `description` is the blurb with project links.
title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""

# Pre-processing pipeline: resize, normalise, then PrepareForNet, applied in
# that order by Compose. The steps are instantiated in the same order in which
# they run.
_preprocess_steps = [
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    PrepareForNet(),
]
transform = Compose(_preprocess_steps)


@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    """Call ``model`` on ``image`` with gradient tracking disabled.

    Args:
        model: Callable network; it is invoked with ``image`` as its only
            argument.
        image: Input forwarded to ``model`` unchanged.

    Returns:
        Whatever ``model(image)`` returns, unmodified.
    """
    prediction = model(image)
    return prediction


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
    # NOTE(review): the label promises a 16-bit raw depth file, but on_submit
    # writes a jet-colormapped PNG via plt.imsave -- confirm which is intended.
    raw_file = gr.File(
        label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Estimate depth for ``image`` and build the two UI outputs.

        The image is normalised, run through the model twice (as-is and
        left-right flipped), and the two final predictions are merged with
        ``post_process_depth``.

        Args:
            image: Array delivered by the ``gr.Image(type='numpy')`` input;
                indexed as (height, width, channel). Assumes 8-bit values
                with 3 channels -- TODO confirm for non-RGB uploads.

        Returns:
            A two-element list: ``(original_image, colorized_depth)`` for the
            image slider, and the path of a PNG of the depth map for download.

        Raises:
            gr.Error: If no image was provided.
        """
        # Gradio passes None when Submit is clicked without an image; report
        # that to the user instead of failing on `None.copy()`.
        if image is None:
            raise gr.Error("Please provide an input image first.")

        original_image = image.copy()

        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        # image = transform({'image': image})['image']
        # image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

        # Scale to [0, 1], move channels first, then apply the per-channel
        # mean/std normalisation.
        tensor = np.asarray(image, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(tensor.transpose((2, 0, 1)))
        tensor = Normalize(mean=[0.485, 0.456, 0.406], std=[
            0.229, 0.224, 0.225])(tensor)

        with torch.no_grad():
            # Add the batch dimension; a plain tensor is enough here (the
            # deprecated autograd.Variable wrapper is not needed).
            batch = tensor.unsqueeze(0)
            print("== Processing image")
            pred_depths_r_list, _, _ = model(batch)
            # Second pass on the mirrored image; post_process_depth combines
            # the last prediction of each pass.
            pred_depths_r_list_flipped, _, _ = model(flip_lr(batch))
            pred_depth = post_process_depth(
                pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
            print("== Finished processing image")

        # Convert the PyTorch tensor to a NumPy array and drop unit dimensions.
        pred_depth = pred_depth.cpu().numpy().squeeze()

        # Stretch the depth range to the full 0-255 span before colormapping.
        # Casting the raw values straight to uint8 (the model is built with
        # max_depth=10) would use only a handful of the 256 colour levels and
        # yield an almost uniformly dark image.
        depth_min = float(pred_depth.min())
        depth_span = float(pred_depth.max()) - depth_min
        if depth_span > 0:
            scaled = (pred_depth - depth_min) / depth_span * 255.0
        else:
            # Constant (or non-finite) depth map: nothing to stretch.
            scaled = np.zeros_like(pred_depth)
        pred_output_depth = scaled.astype(np.uint8)

        # Apply the colour map; [:, :, ::-1] reverses the channel order of
        # the OpenCV result.
        output_image = cv2.applyColorMap(
            pred_output_depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]

        # Reserve a persistent temp file name and close the handle right away
        # so it is not leaked (and so the path can be reopened for writing).
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        # cv2.imwrite(raw_path, output_image)
        plt.imsave(raw_path, pred_depth, cmap='jet')

        return [(original_image, output_image), raw_path]

    submit.click(on_submit, inputs=[input_image], outputs=[
                 depth_image_slider, raw_file])

    # Offer every file in ./examples as a clickable example. A missing or
    # empty directory simply means no examples, not a start-up failure.
    example_files = []
    if os.path.isdir('examples'):
        example_files = [os.path.join('examples', filename)
                         for filename in sorted(os.listdir('examples'))]
    examples = None
    if example_files:
        examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
                               depth_image_slider, raw_file], fn=on_submit, cache_examples=False)


if __name__ == '__main__':
    # Script entry point: call queue() on the app, then launch() on the
    # object it returns.
    queued_demo = demo.queue()
    queued_demo.launch()