02alexander committed
Commit a0fdd41
Parent: 9e0db6d

debug printing for make3d
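Beyond the debug prints, the commit moves Rerun logging onto a producer/consumer queue: the `@spaces.GPU` functions (`generate_mvs`, `make3d`) run outside the thread that owns the Rerun recording stream, so instead of calling `rr.log` directly they push `("log", entity_path, entity)` tuples into a `SimpleQueue`, and the `log_to_rr` generator drains the queue on the stream-owning thread, calls `rr.log`, and yields `stream.read()` chunks to the viewer. A minimal, self-contained sketch of that pattern (names like `worker` and `drain_logs` are illustrative, not from this commit):

import threading
from queue import SimpleQueue

import rerun as rr

def worker(q: SimpleQueue) -> None:
    # Producer: runs where the Rerun thread-local stream is NOT active,
    # so it enqueues log payloads instead of calling rr.log directly.
    for step in range(3):
        q.put(("log", "debug/step", rr.TextLog(f"step {step}")))
    q.put(("done", None))  # sentinel so the consumer knows to stop

@rr.thread_local_stream("queue-demo")
def drain_logs():
    # Consumer: owns the stream; forwards queued payloads via rr.log
    # and yields encoded chunks for a streaming consumer (e.g. gradio_rerun).
    stream = rr.binary_stream()
    q: SimpleQueue = SimpleQueue()
    t = threading.Thread(target=worker, args=(q,))
    t.start()
    while True:
        kind, *payload = q.get()
        if kind == "done":
            break
        entity_path, entity = payload
        rr.log(entity_path, entity)
        yield stream.read()
    t.join()

for chunk in drain_logs():
    print(f"streamed {len(chunk)} bytes")

`SimpleQueue` is unbounded and thread-safe, so the producer never blocks; the sentinel tuple tells the consumer when to stop draining.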

Files changed (1): app.py (+77, −55)
app.py CHANGED
@@ -12,11 +12,13 @@ from omegaconf import OmegaConf
 from einops import rearrange, repeat
 from tqdm import tqdm
 import threading
+from queue import SimpleQueue
 from typing import Any
 from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
 import rerun as rr
 from gradio_rerun import Rerun
 
+import src
 from src.utils.train_util import instantiate_from_config
 from src.utils.camera_util import (
     FOV_to_intrinsics,
@@ -25,6 +27,7 @@ from src.utils.camera_util import (
 )
 from src.utils.mesh_util import save_obj, save_glb
 from src.utils.infer_util import remove_background, resize_foreground, images_to_video
+from src.models.lrm_mesh import InstantMesh
 
 import tempfile
 from functools import partial
@@ -126,7 +129,7 @@ print(f'type(pipeline)={type(pipeline)}')
 # load reconstruction model
 print('Loading reconstruction model ...')
 model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
-model = instantiate_from_config(model_config)
+model: InstantMesh = instantiate_from_config(model_config)
 state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
 state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
 model.load_state_dict(state_dict, strict=True)
@@ -152,29 +155,30 @@ def preprocess(input_image, do_remove_background):
     return input_image
 
 
-def pipeline_callback(pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]) -> dict[str, Any]:
+def pipeline_callback(output_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]) -> dict[str, Any]:
     rr.set_time_sequence("iteration", step_index)
     rr.set_time_seconds("timestep", timestep)
     latents = callback_kwargs["latents"]
     image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]  # type: ignore[attr-defined]
     image = pipe.image_processor.postprocess(image, output_type="np").squeeze()  # type: ignore[attr-defined]
 
-    rr.log("output", rr.Image(image))
-    rr.log("latents", rr.Tensor(latents.squeeze()))
+    output_queue.put(("log", "mvs/image", rr.Image(image)))
+    output_queue.put(("log", "mvs/latents", rr.Tensor(latents.squeeze())))
+
     return callback_kwargs
 
 @spaces.GPU
-def generate_mvs(input_image, sample_steps, sample_seed):
-
-    print(threading.get_ident())
+def generate_mvs(input_image, sample_steps, sample_seed, output_queue: SimpleQueue):
 
     seed_everything(sample_seed)
 
-    return pipeline(
+    z123_image = pipeline(
         input_image,
         num_inference_steps=sample_steps,
-        callback_on_step_end=pipeline_callback,
-    )
+        callback_on_step_end=lambda *args, **kwargs: pipeline_callback(output_queue, *args, **kwargs),
+    ).images[0]
+
+    output_queue.put(("z123_image", z123_image))
 
     # sampling
     # z123_image = pipeline(
@@ -190,10 +194,9 @@ def generate_mvs(input_image, sample_steps, sample_seed):
 
     # return z123_image, show_image
 
-
 @spaces.GPU
-def make3d(images):
-
+def make3d(output_queue: SimpleQueue, images: Image.Image):
+    print(f'type(images)={type(images)}')
     global model
     if IS_FLEXICUBES:
         model.init_flexicubes_geometry(device, use_renderer=False)
@@ -205,9 +208,12 @@ def make3d(images):
 
     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
     render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+    print(f'type(input_cameras)={type(input_cameras)}')
 
     images = images.unsqueeze(0).to(device)
    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+    print(f'type(images)={type(images)}')
+
 
     mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
     print(mesh_fpath)
@@ -219,26 +225,31 @@ def make3d(images):
     with torch.no_grad():
         # get triplane
         planes = model.forward_planes(images, input_cameras)
+        print(f'type(planes)={type(planes)}')
 
         # # get video
-        # chunk_size = 20 if IS_FLEXICUBES else 1
-        # render_size = 384
+        chunk_size = 20 if IS_FLEXICUBES else 1
+        render_size = 384
 
         # frames = []
-        # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
-        #     if IS_FLEXICUBES:
-        #         frame = model.forward_geometry(
-        #             planes,
-        #             render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['img']
-        #     else:
-        #         frame = model.synthesizer(
-        #             planes,
-        #             cameras=render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['images_rgb']
-        #     frames.append(frame)
+        for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
+            if IS_FLEXICUBES:
+                frame = model.forward_geometry(
+                    planes,
+                    render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['img']
+            else:
+                frame = model.synthesizer(
+                    planes,
+                    cameras=render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                )['images_rgb']
+
+            print(f'type(frame)={type(frame)}')
+            output_queue.put(("log", "3dvideo", rr.Image(frame)))
+            # frames.append(frame)
+
         # frames = torch.cat(frames, dim=1)
 
         # images_to_video(
@@ -255,10 +266,13 @@ def make3d(images):
             use_texture_map=False,
             **infer_config,
         )
+        print(f'type(mesh_out)={type(mesh_out)}')
+
 
        vertices, faces, vertex_colors = mesh_out
         vertices = vertices[:, [1, 2, 0]]
-
+        print(f'type(vertices)={type(vertices)}')
+
         save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
         save_obj(vertices, faces, vertex_colors, mesh_fpath)
 
@@ -266,31 +280,47 @@ def make3d(images):
 
     return mesh_fpath, mesh_glb_fpath
 
-@spaces.GPU
-def print_thread_ident_from_gpu():
-    print(threading.get_ident())
-
 @rr.thread_local_stream("InstantMesh")
 def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
+    preprocessed_image = preprocess(input_image, do_remove_background)
+
+    stream = rr.binary_stream()
+
+    rr.log("preprocessed_image", rr.Image(preprocessed_image))
 
-    print(threading.get_ident())
-    print_thread_ident_from_gpu()
+    yield stream.read()
 
-    # preprocessed_image = preprocess(input_image, do_remove_background)
+    output_queue = SimpleQueue()
 
-    # stream = rr.binary_stream()
+    mvs_thread = threading.Thread(target=generate_mvs, args=[input_image, sample_steps, sample_seed, output_queue])
+    mvs_thread.start()
 
-    # rr.log("preprocessed_image", rr.Image(preprocessed_image))
+    while True:
+        msg = output_queue.get()
+        if msg[0] == "z123_image":
+            z123_image = msg[1]
+            break
+        elif msg[0] == "log":
+            entity_path = msg[1]
+            entity = msg[2]
+            rr.log(entity_path, entity)
+            yield stream.read()
+
+    mvs_thread.join()
 
-    # yield stream.read()
+    rr.log("z123image", rr.Image(z123_image))
+    yield stream.read()
 
-    # z123_out = generate_mvs(input_image, sample_steps, sample_seed)
-    # print(z123_out)
-    # for image in z123_out.images:
-    #     rr.log("z123image", rr.Image(image))
-    #     yield stream.read()
-    # yield stream.read()
-    # pass
+    mesh_fpath, mesh_glb_fpath = make3d(output_queue, z123_image)
+
+    while not output_queue.empty():
+        msg = output_queue.get()
+        if msg[0] == "log":
+            entity_path = msg[1]
+            entity = msg[2]
+            rr.log(entity_path, entity)
+            yield stream.read()
+
 
 _HEADER_ = '''
 <h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
@@ -343,14 +373,6 @@ with gr.Blocks() as demo:
                     type="pil",
                     elem_id="content_image",
                 )
-                processed_image = gr.Image(
-                    label="Processed Image",
-                    image_mode="RGBA",
-                    #width=256,
-                    #height=256,
-                    type="pil",
-                    interactive=False
-                )
             with gr.Row():
                 with gr.Group():
                     do_remove_background = gr.Checkbox(
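Not shown in these hunks is how the chunks yielded by `log_to_rr` reach the browser. A typical binding, following the streaming examples in `gradio_rerun` (the `streaming=True` flag, the `viewer` name, and the `submit` button are assumptions, not part of this diff):

# Hypothetical wiring, based on gradio_rerun's streaming examples:
# the Rerun viewer component consumes the byte chunks yielded by log_to_rr.
viewer = Rerun(streaming=True)
submit.click(
    log_to_rr,
    inputs=[input_image, do_remove_background, sample_steps, sample_seed],
    outputs=[viewer],
)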