Include pointcloud2mesh components to additionally view 3d model

#14
by RamAnanth1 - opened
Files changed (1) hide show
  1. app.py +31 -2
app.py CHANGED
@@ -9,6 +9,9 @@ from point_e.diffusion.sampler import PointCloudSampler
9
  from point_e.models.download import load_checkpoint
10
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
11
  from point_e.util.plotting import plot_point_cloud
 
 
 
12
 
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
@@ -29,6 +32,14 @@ base_model.load_state_dict(load_checkpoint(base_name, device))
29
  print('downloading upsampler checkpoint...')
30
  upsampler_model.load_state_dict(load_checkpoint('upsample', device))
31
 
 
 
 
 
 
 
 
 
32
  sampler = PointCloudSampler(
33
  device=device,
34
  models=[base_model, upsampler_model],
@@ -65,12 +76,30 @@ def inference(prompt):
65
  )
66
  ),
67
  )
68
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  demo = gr.Interface(
71
  fn=inference,
72
  inputs="text",
73
- outputs=gr.Plot(),
74
  examples=[
75
  ["a red motorcycle"],
76
  ["a RED pumpkin"],
9
  from point_e.models.download import load_checkpoint
10
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
11
  from point_e.util.plotting import plot_point_cloud
12
+ from point_e.util.pc_to_mesh import marching_cubes_mesh
13
+
14
+ import trimesh
15
 
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
32
  print('downloading upsampler checkpoint...')
33
  upsampler_model.load_state_dict(load_checkpoint('upsample', device))
34
 
35
+ print('creating SDF model...')
36
+ name = 'sdf'
37
+ sdf_model = model_from_config(MODEL_CONFIGS[name], device)
38
+ sdf_model.eval()
39
+
40
+ print('loading SDF model...')
41
+ sdf_model.load_state_dict(load_checkpoint(name, device))
42
+
43
  sampler = PointCloudSampler(
44
  device=device,
45
  models=[base_model, upsampler_model],
76
  )
77
  ),
78
  )
79
+
80
+ # Produce a mesh (with vertex colors)
81
+ mesh = marching_cubes_mesh(
82
+ pc=pc,
83
+ model=sdf_model,
84
+ batch_size=4096,
85
+ grid_size=32, # increase to 128 for resolution used in evals
86
+ progress=True,
87
+ )
88
+
89
+ # Write the mesh to a PLY file to import into some other program.
90
+ with open("mesh.ply", 'wb') as f:
91
+ mesh.write_ply(f)
92
+
93
+ obj_file = '3d_model.obj'
94
+ mesh = trimesh.load('mesh.ply')
95
+ mesh.export(obj_file)
96
+
97
+ return fig, obj_file
98
 
99
  demo = gr.Interface(
100
  fn=inference,
101
  inputs="text",
102
+ outputs=[gr.Plot(),gr.Model3D(value=None)],
103
  examples=[
104
  ["a red motorcycle"],
105
  ["a RED pumpkin"],