Jiading Fang commited on
Commit
2512c83
·
1 Parent(s): 68d536e

add app file for gradio

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ import matplotlib as mpl
8
+ import matplotlib.cm as cm
9
+
10
+ from vidar.core.wrapper import Wrapper
11
+ from vidar.utils.config import read_config
12
+
13
+
14
+ def colormap_depth(depth_map):
15
+ # Input: depth_map -> HxW numpy array with depth values
16
+ # Output: colormapped_im -> HxW numpy array with colorcoded depth values
17
+ mask = depth_map!=0
18
+ disp_map = 1/depth_map
19
+ vmax = np.percentile(disp_map[mask], 95)
20
+ vmin = np.percentile(disp_map[mask], 5)
21
+ normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
22
+ mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
23
+ mask = np.repeat(np.expand_dims(mask,-1), 3, -1)
24
+ colormapped_im = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8)
25
+ colormapped_im[~mask] = 255
26
+ return colormapped_im
27
+
28
+ def data_to_batch(data):
29
+ batch = data.copy()
30
+ batch['rgb'][0] = batch['rgb'][0].unsqueeze(0).unsqueeze(0)
31
+ batch['rgb'][1] = batch['rgb'][1].unsqueeze(0).unsqueeze(0)
32
+ batch['intrinsics'][0] = batch['intrinsics'][0].unsqueeze(0).unsqueeze(0)
33
+ batch['pose'][0] = batch['pose'][0].unsqueeze(0).unsqueeze(0)
34
+ batch['pose'][1] = batch['pose'][1].unsqueeze(0).unsqueeze(0)
35
+ batch['depth'][0] = batch['depth'][0].unsqueeze(0).unsqueeze(0)
36
+ batch['depth'][1] = batch['depth'][1].unsqueeze(0).unsqueeze(0)
37
+
38
+ return batch
39
+
40
+
41
+ os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'
42
+ cfg_file_path = 'configs/papers/define/scannet_temporal_test_context_1.yaml'
43
+ cfg = read_config(cfg_file_path)
44
+
45
+ wrapper = Wrapper(cfg, verbose=True)
46
+
47
+ # print('arch: ', wrapper.arch)
48
+ # print('datasets: ', wrapper.datasets)
49
+
50
+ arch = wrapper.arch
51
+ arch.eval()
52
+ val_dataset = wrapper.datasets['validation'][0]
53
+ len_val_dataset = len(val_dataset)
54
+ # print('val datasets length: ', len_val_dataset)
55
+
56
+ # data_sample = val_dataset[0]
57
+ # batch = data_to_batch(data_sample)
58
+ # output = arch(batch, epoch=0)
59
+ # print('output: ', output)
60
+
61
+ # output_depth = output['predictions']['depth'][0][0]
62
+ # print('output_depth: ', output_depth)
63
+ # output_depth = output_depth.squeeze(0).squeeze(0).permute(1,2,0)
64
+ # print('output_depth shape: ', output_depth.shape)
65
+
66
+ def sample_data_idx():
67
+ return random.randint(0, len_val_dataset-1)
68
+
69
+ def display_images_from_idx(idx):
70
+ rgbs = val_dataset[int(idx)]['rgb']
71
+ return [np.array(rgb.permute(1,2,0)) for rgb in rgbs.values()]
72
+
73
+ def infer_depth_from_idx(idx):
74
+ data_sample = val_dataset[int(idx)]
75
+ batch = data_to_batch(data_sample)
76
+ output = arch(batch, epoch=0)
77
+ output_depths = output['predictions']['depth']
78
+ return [colormap_depth(output_depth[0].squeeze(0).squeeze(0).squeeze(0).detach().numpy()) for output_depth in output_depths.values()]
79
+
80
+ with gr.Blocks() as demo:
81
+
82
+ # layout
83
+ img_box = gr.Gallery(label="Sampled Images").style(grid=[2], height="auto")
84
+ data_idx_box = gr.Textbox(
85
+ label="Sampled Data Index",
86
+ placeholder="Number between {} and {}".format(0, len_val_dataset-1),
87
+ interactive=True
88
+ )
89
+ sample_btn = gr.Button('Sample Dataset')
90
+
91
+ depth_box = gr.Gallery(label="Infered Depth").style(grid=[2], height="auto")
92
+ infer_btn = gr.Button('Depth Infer')
93
+
94
+ # actions
95
+ sample_btn.click(
96
+ fn=sample_data_idx,
97
+ inputs=None,
98
+ outputs=data_idx_box
99
+ ).success(
100
+ fn=display_images_from_idx,
101
+ inputs=data_idx_box,
102
+ outputs=img_box,
103
+ )
104
+
105
+ infer_btn.click(
106
+ fn=infer_depth_from_idx,
107
+ inputs=data_idx_box,
108
+ outputs=depth_box
109
+ )
110
+
111
+ demo.launch(server_name="0.0.0.0", server_port=7860)