xiexh20 committed
Commit: cd9aff3
Parent: 0732f1c

add examples models
app.py CHANGED
@@ -22,6 +22,7 @@ import imageio
 import gradio as gr
 import plotly.graph_objs as go
 import training_utils
+import traceback
 
 from configs.structured import ProjectConfig
 from demo import DemoRunner
@@ -91,7 +92,7 @@ def plot_points(colors, coords):
     return fig
 
 
-def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed):
+def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
     """
     given user input, run inference
     :param runner:
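Note that `partial(inference, runner, cfg)` in the click wiring further down pre-binds only the first two parameters, so the new `input_cls` argument must also be appended to the `inputs=[...]` list there (see the hunk below); Gradio fills the remaining parameters positionally. A toy sketch of that binding, with stand-in values rather than the real runner and config:

import from functools the binding helper and echo what Gradio would pass in:

    from functools import partial

    def inference(runner, cfg, rgb, mask_hum, mask_obj, std, seed, cls_name):
        # stand-in for the real inference: just echo the arguments
        return f"{runner}|{cfg}: std={std} seed={seed} cls={cls_name}"

    fn = partial(inference, "runner", "cfg")  # pre-bind runner and cfg
    # Gradio supplies the rest, in the order of the `inputs` component list:
    print(fn("rgb.png", "hum.png", "obj.png", 3.5, 42, "chair"))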
@@ -101,26 +102,38 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, s
     :param mask_obj: (h, w, 3), np array
     :param std_coverage: float value, used to estimate camera translation
     :param input_seed: random seed
+    :param input_cls: the object category of the input image
     :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
     """
-    # Set random seed
-    training_utils.set_seed(int(input_seed))
-
-    data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
-                       std_coverage)
-    batch = data.image2batch(rgb, mask_hum, mask_obj)
-
-    out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
-    points = out_stage2.points_packed().cpu().numpy()
-    colors = out_stage2.features_packed().cpu().numpy()
-    fig = plot_points(colors, points)
-    # save tmp point cloud
-    outdir = './results'
-    os.makedirs(outdir, exist_ok=True)
-    trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply")
-    trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
-                       out_stage1.features_packed().cpu().numpy()).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1.ply")
-    return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply"
+    log = ""
+    try:
+        # Set random seed
+        training_utils.set_seed(int(input_seed))
+
+        data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
+                           std_coverage)
+        batch = data.image2batch(rgb, mask_hum, mask_obj)
+
+        if input_cls != 'general':
+            log += f"Reloading fine-tuned checkpoint of category {input_cls}\n"
+            runner.reload_checkpoint(input_cls)
+
+        out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
+        points = out_stage2.points_packed().cpu().numpy()
+        colors = out_stage2.features_packed().cpu().numpy()
+        fig = plot_points(colors, points)
+        # save tmp point cloud
+        outdir = './results'
+        os.makedirs(outdir, exist_ok=True)
+        trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2_{input_cls}.ply")
+        trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
+                           out_stage1.features_packed().cpu().numpy()).export(
+            outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1_{input_cls}.ply")
+        log += 'Successfully reconstructed the image.'
+    except Exception as e:
+        log = traceback.format_exc()
+
+    return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2_{input_cls}.ply", log
 
 
 @hydra.main(config_path='configs', config_name='configs', version_base='1.1')
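One caveat with the new error handling: `fig` and `outdir` are bound only inside the `try` block, so if an exception fires before they are assigned (for example in `set_seed` or `image2batch`), the final `return` itself raises `NameError` instead of surfacing the captured log. A minimal, self-contained sketch of the pre-binding pattern that avoids this; it is illustrative, not part of the commit:

    import traceback

    def run_with_log(task):
        # Pre-bind the result so the except path can still return cleanly.
        result, log = None, ""
        try:
            result = task()
            log = "Success."
        except Exception:
            log = traceback.format_exc()
        return result, log

    # The failing branch returns (None, <traceback text>) instead of raising:
    _, log = run_with_log(lambda: 1 / 0)
    print(log.splitlines()[-1])  # ZeroDivisionError: division by zero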
@@ -129,6 +142,8 @@ def main(cfg: ProjectConfig):
     runner = DemoRunner(cfg)
 
     # runner = None # without model initialization, it shows one line of thumbnail
+    # TODO: add instructions on how to get masks
+    # TODO: add instructions on how to use the demo, input output, example outputs etc.
 
     # Setup interface
     demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
@@ -147,33 +162,43 @@ def main(cfg: ProjectConfig):
         # TODO: add hint for this value here
         input_std = gr.Number(label='Gaussian std coverage', value=3.5)
         input_seed = gr.Number(label='Random seed', value=42)
+        # TODO: add description outside label
+        input_cls = gr.Dropdown(label='Object category (we have fine-tuned the model for specific categories; '
+                                      'reconstructing with these models should lead to better results '
+                                      'for those categories.)',
+                                choices=['general', 'backpack', 'ball', 'bottle', 'box',
+                                         'chair', 'skateboard', 'suitcase', 'table'],
+                                value='general')
         # Output visualization
         with gr.Row():
             pc_plot = gr.Plot(label="Reconstructed point cloud")
             out_pc_download = gr.File(label="3D reconstruction for download")  # this allows downloading
+        with gr.Row():
+            out_log = gr.TextArea(label='Output log')
+
 
         gr.HTML("""<br/>""")
         # Control
         with gr.Row():
             button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
             button_recon.click(fn=partial(inference, runner, cfg),
-                               inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],
-                               outputs=[pc_plot, out_pc_download])
+                               inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],
+                               outputs=[pc_plot, out_pc_download, out_log])
         gr.HTML("""<br/>""")
         # Example input
        example_dir = cfg.run.code_dir_abs + "/examples"
        rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png'
        example_images = gr.Examples([
-            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42],
-            [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42],
-            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42],
-            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42],
+            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard'],
+            [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42, 'ball'],
+            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair'],
+            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair'],
 
-        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],)
+        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],)
 
     # demo.launch(share=True)
     # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
-    demo.queue(concurrency_count=3).launch(share=True)
+    demo.queue().launch(share=True)
 
 if __name__ == '__main__':
     main()
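The dropped `concurrency_count` argument suggests an upgrade to Gradio 4, where `queue()` no longer accepts that parameter. Assuming Gradio 4, the closest equivalent to the old three-worker behavior would use the per-queue default concurrency limit instead; a self-contained sketch, not what this commit does:

    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown("queue demo")

    # Gradio >= 4: default_concurrency_limit replaces concurrency_count.
    demo.queue(default_concurrency_limit=3).launch(share=True)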
 
demo.py CHANGED
@@ -65,8 +65,8 @@ class DemoRunner:
         self.rend_size = cfg.dataset.image_size
         self.device = 'cuda'
 
-    def load_checkpoint(self, ckpt_file1, model_stage1):
-        checkpoint = torch.load(ckpt_file1, map_location='cpu')
+    def load_checkpoint(self, ckpt_file1, model_stage1, device='cpu'):
+        checkpoint = torch.load(ckpt_file1, map_location=device)
         state_dict, key = checkpoint['model'], 'model'
         if any(k.startswith('module.') for k in state_dict.keys()):
             state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
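For context on the `map_location` change: it decides where checkpoint tensors are materialized during unpickling, so passing the runner's device loads the weights straight onto the GPU and skips a CPU round-trip before `load_state_dict`. A self-contained illustration (the file path is hypothetical):

    import torch

    ckpt = {'model': {'w': torch.zeros(2, 2)}}
    torch.save(ckpt, '/tmp/toy_ckpt.pth')

    # map_location='cpu' keeps tensors on CPU; a device string such as
    # 'cuda' (or a torch.device) places them directly on that device.
    restored = torch.load('/tmp/toy_ckpt.pth', map_location='cpu')
    print(restored['model']['w'].device)  # cpu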
@@ -78,6 +78,13 @@ class DemoRunner:
         if len(unexpected_keys):
             print(f' - Unexpected_keys: {unexpected_keys}')
 
+    def reload_checkpoint(self, cat_name):
+        "load checkpoint of models fine tuned on specific categories"
+        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage1_name}-{cat_name}.pth')
+        self.load_checkpoint(ckpt_file1, self.model_stage1, device=self.device)
+        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage2_name}-{cat_name}.pth')
+        self.load_checkpoint(ckpt_file2, self.model_stage2, device=self.device)
+
     @torch.no_grad()
     def run(self):
         "simply run the demo on given images, and save the results"
 
examples/002446/k1.color.jpg ADDED
examples/002446/k1.color.json ADDED
@@ -0,0 +1,79 @@
+{
+    "body_joints": [
+        362.91015625, 159.39576721191406, 0.9023686647415161,
+        373.57745361328125, 180.60316467285156, 0.8592674136161804,
+        333.528564453125, 179.45702362060547, 0.7867028713226318,
+        278.2209167480469, 207.63121032714844, 0.8840203285217285,
+        228.78005981445312, 234.69793701171875, 0.8324164152145386,
+        417.08209228515625, 181.77294921875, 0.7164953947067261,
+        477.138427734375, 199.3846893310547, 0.7733086347579956,
+        539.4710083007812, 219.44891357421875, 0.8321817517280579,
+        401.8182678222656, 288.8574676513672, 0.61277836561203,
+        382.9984436035156, 294.7460632324219, 0.5884051322937012,
+        388.8341979980469, 377.1164245605469, 0.8282020092010498,
+        488.86529541015625, 404.145751953125, 0.6257187724113464,
+        420.6218566894531, 282.9443664550781, 0.5774698257446289,
+        455.9610290527344, 361.8221130371094, 0.8058001399040222,
+        557.13916015625, 339.43017578125, 0.69627445936203,
+        352.3575134277344, 151.14682006835938, 0.9335765242576599,
+        371.185791015625, 146.48798370361328, 0.8626495003700256,
+        342.9620666503906, 150.00089263916016, 0.0641486719250679,
+        390.03204345703125, 135.8568878173828, 0.8869808316230774,
+        595.938720703125, 338.2825012207031, 0.25365617871284485,
+        594.7731323242188, 334.75506591796875, 0.23056654632091522,
+        561.8401489257812, 331.20794677734375, 0.29395991563796997,
+        484.1672058105469, 435.9705810546875, 0.6335450410842896,
+        479.44921875, 433.6032409667969, 0.5307492017745972,
+        501.7928466796875, 398.28533935546875, 0.5881072878837585
+    ]
+}
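The 75 values in `body_joints` read naturally as 25 keypoints in (x, y, confidence) triplets, consistent with the OpenPose BODY_25 layout (an assumption; the commit does not name the detector). A short sketch of loading them:

    import json
    import numpy as np

    with open('examples/002446/k1.color.json') as f:
        joints = np.array(json.load(f)['body_joints']).reshape(-1, 3)

    print(joints.shape)           # (25, 3): x, y, confidence per keypoint
    print(joints[:, 2].argmin())  # index of the least confident detection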
examples/002446/k1.obj_rend_mask.png ADDED
examples/002446/k1.person_mask.png ADDED
examples/053431/k1.color.jpg ADDED
examples/053431/k1.obj_rend_mask.png ADDED
examples/053431/k1.person_mask.png ADDED
examples/158107/k1.color.jpg ADDED
examples/158107/k1.obj_rend_mask.png ADDED
examples/158107/k1.person_mask.png ADDED