ac5113 commited on
Commit
9f6d75b
1 Parent(s): 63e7312

plotly test

Browse files
Files changed (2) hide show
  1. app.py +144 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -14,6 +14,7 @@ import gradio as gr
14
 
15
  import trimesh
16
  import pyrender
 
17
 
18
  from models.deco import DECO
19
  from common import constants
@@ -62,7 +63,16 @@ description = '''
62
  - [ECON](https://huggingface.co/spaces/Yuliang/ECON)
63
 
64
  </details>
65
- '''
 
 
 
 
 
 
 
 
 
66
 
67
  def initiate_model(model_path):
68
  deco_model = DECO('hrnet', True, device)
@@ -175,6 +185,134 @@ def initiate_model(model_path):
175
  # IMG.thumbnail((3000, 3000))
176
  # return IMG
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def main(pil_img, out_dir='demo_out', model_path='checkpoint/deco_best.pth', mesh_colour=[130, 130, 130, 255], annot_colour=[0, 255, 0, 255]):
179
  deco_model = initiate_model(model_path)
180
 
@@ -219,7 +357,10 @@ def main(pil_img, out_dir='demo_out', model_path='checkpoint/deco_best.pth', mes
219
  print(f'Saving mesh to {mesh_out_dir}')
220
  body_model_smpl.export(os.path.join(mesh_out_dir, 'pred.obj'))
221
 
222
- return rend, os.path.join(mesh_out_dir, 'pred.obj')
 
 
 
223
 
224
  with gr.Blocks(title="DECO", css=".gradio-container") as demo:
225
  gr.Markdown(description)
@@ -230,7 +371,7 @@ with gr.Blocks(title="DECO", css=".gradio-container") as demo:
230
  with gr.Column():
231
  input_image = gr.Image(label="Input image", type="pil")
232
  with gr.Column():
233
- output_image = gr.Image(label="Renders", type="pil")
234
  output_meshes = gr.File(label="3D meshes")
235
 
236
  gr.HTML("""<br/>""")
 
14
 
15
  import trimesh
16
  import pyrender
17
+ import plotly.graph_objects as go
18
 
19
  from models.deco import DECO
20
  from common import constants
 
63
  - [ECON](https://huggingface.co/spaces/Yuliang/ECON)
64
 
65
  </details>
66
+ '''
67
+
68
+ DEFAULT_LIGHTING = dict(
69
+ ambient=0.6,
70
+ diffuse=0.5,
71
+ fresnel=0.01,
72
+ specular=0.1,
73
+ roughness=0.001)
74
+
75
+ DEFAULT_LIGHT_POSITION = dict(x=6, y=0, z=10)
76
 
77
  def initiate_model(model_path):
78
  deco_model = DECO('hrnet', True, device)
 
185
  # IMG.thumbnail((3000, 3000))
186
  # return IMG
187
 
188
+ def create_layout(dummy, camera=None):
189
+ if camera is None:
190
+ camera = dict(
191
+ up=dict(x=0, y=1, z=0),
192
+ center=dict(x=0, y=0, z=0),
193
+ eye=dict(x=dummy.x.mean(), y=0, z=3),
194
+ projection=dict(type='perspective'))
195
+
196
+ layout = dict(
197
+ scene={
198
+ "xaxis": {
199
+ 'showgrid': False,
200
+ 'zeroline': False,
201
+ 'visible': False,
202
+ "range": [dummy.x.min(), dummy.x.max()]
203
+ },
204
+ "yaxis": {
205
+ 'showgrid': False,
206
+ 'zeroline': False,
207
+ 'visible': False,
208
+ "range": [dummy.y.min(), dummy.y.max()]
209
+ },
210
+ "zaxis": {
211
+ 'showgrid': False,
212
+ 'zeroline': False,
213
+ 'visible': False,
214
+ "range": [dummy.z.min(), dummy.z.max()]
215
+ },
216
+ },
217
+ autosize=False,
218
+ width=1000, height=1000,
219
+ scene_camera=camera,
220
+ scene_aspectmode="data",
221
+ clickmode="event+select",
222
+ margin={'l': 0, 't': 0}
223
+ )
224
+
225
+ return layout
226
+
227
+ def create_fig(dummy, colors=[], camera=None):
228
+ fig = go.Figure(
229
+ data=dummy.mesh_3d(colors),
230
+ layout=create_layout(dummy, camera))
231
+ return fig
232
+
233
+ class Dummy:
234
+
235
+ def __init__(self, mesh_path):
236
+ """A simple polygonal dummy with colored patches."""
237
+ self._load_trimesh(mesh_path)
238
+
239
+ def _load_trimesh(self, path):
240
+ """Load a mesh given a path to a .PLY file."""
241
+ self._trimesh = trimesh.load(path, process=False)
242
+ self._vertices = np.array(self._trimesh.vertices)
243
+ self._faces = np.array(self._trimesh.faces)
244
+ self.colors = self._trimesh.visual.vertex_colors
245
+
246
+ @property
247
+ def vertices(self):
248
+ """All the mesh vertices."""
249
+ return self._vertices
250
+
251
+ @property
252
+ def faces(self):
253
+ """All the mesh faces."""
254
+ return self._faces
255
+
256
+ @property
257
+ def n_vertices(self):
258
+ """Number of vertices in a mesh."""
259
+ return self._vertices.shape[0]
260
+
261
+ @property
262
+ def n_faces(self):
263
+ """Number of faces in a mesh."""
264
+ return self._faces.shape[0]
265
+
266
+ @property
267
+ def x(self):
268
+ """An array of vertex x coordinates"""
269
+ return self._vertices[:, 0]
270
+
271
+ @property
272
+ def y(self):
273
+ """An array of vertex y coordinates"""
274
+ return self._vertices[:, 1]
275
+
276
+ @property
277
+ def z(self):
278
+ """An array of vertex z coordinates"""
279
+ return self._vertices[:, 2]
280
+
281
+ @property
282
+ def i(self):
283
+ """An array of the first face vertices"""
284
+ return self._faces[:, 0]
285
+
286
+ @property
287
+ def j(self):
288
+ """An array of the second face vertices"""
289
+ return self._faces[:, 1]
290
+
291
+ @property
292
+ def k(self):
293
+ """An array of the third face vertices"""
294
+ return self._faces[:, 2]
295
+
296
+ @property
297
+ def default_selection(self):
298
+ """Default patch selection mask."""
299
+ return dict(vertices=[])
300
+
301
+ def mesh_3d(
302
+ self,
303
+ lighting=DEFAULT_LIGHTING,
304
+ light_position=DEFAULT_LIGHT_POSITION
305
+ ):
306
+ """Construct a Mesh3D object give a clickmask for patch coloring."""
307
+
308
+ return go.Mesh3d(
309
+ x=self.x, y=self.y, z=self.z,
310
+ i=self.i, j=self.j, k=self.k,
311
+ vertexcolor=self.colors,
312
+ lighting=lighting,
313
+ lightposition=light_position,
314
+ hoverinfo='none')
315
+
316
  def main(pil_img, out_dir='demo_out', model_path='checkpoint/deco_best.pth', mesh_colour=[130, 130, 130, 255], annot_colour=[0, 255, 0, 255]):
317
  deco_model = initiate_model(model_path)
318
 
 
357
  print(f'Saving mesh to {mesh_out_dir}')
358
  body_model_smpl.export(os.path.join(mesh_out_dir, 'pred.obj'))
359
 
360
+ dummy = Dummy(os.path.join(mesh_out_dir, 'pred.obj'))
361
+ fig = create_fig(dummy)
362
+
363
+ return fig, os.path.join(mesh_out_dir, 'pred.obj')
364
 
365
  with gr.Blocks(title="DECO", css=".gradio-container") as demo:
366
  gr.Markdown(description)
 
371
  with gr.Column():
372
  input_image = gr.Image(label="Input image", type="pil")
373
  with gr.Column():
374
+ output_image = gr.Plot(label="Renders")
375
  output_meshes = gr.File(label="3D meshes")
376
 
377
  gr.HTML("""<br/>""")
requirements.txt CHANGED
@@ -18,4 +18,5 @@ chumpy
18
  numpy==1.23.1
19
  yacs
20
  gradio
21
- ipykernel
 
 
18
  numpy==1.23.1
19
  yacs
20
  gradio
21
+ ipykernel
22
+ plotly