Spaces: Running on T4
kyleleey committed · commit f09d510 · 1 parent: d7c4a03
init version of app
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: blue
 colorTo: green
 sdk: gradio
 python_version: 3.9.13
-sdk_version:
+sdk_version: 3.50.2
 app_file: app.py
 pinned: false
 license: cc-by-nc-sa-4.0
app.py CHANGED
@@ -98,7 +98,7 @@ def expand2square(pil_img, background_color):
     return result


-def preprocess(predictor, input_image, chk_group=None, segment=
+def preprocess(predictor, input_image, chk_group=None, segment=False):
     RES = 1024
     input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
     if chk_group is not None:
@@ -403,13 +403,76 @@ def create_bones_scene(bones, joint_color=[66, 91, 140], bone_color=[119, 144, 1
     return mesh


-def run_pipeline(model_items, cfgs, input_img, device):
+def save_mesh(mesh, file_path):
+    obj_file = file_path
+    idx = 0
+    print("Writing mesh: ", obj_file)
+    with open(obj_file, "w") as f:
+        # f.write(f"mtllib {fname}.mtl\n")
+        f.write("g default\n")
+
+        v_pos = mesh.v_pos[idx].detach().cpu().numpy() if mesh.v_pos is not None else None
+        v_nrm = mesh.v_nrm[idx].detach().cpu().numpy() if mesh.v_nrm is not None else None
+        v_tex = mesh.v_tex[idx].detach().cpu().numpy() if mesh.v_tex is not None else None
+
+        t_pos_idx = mesh.t_pos_idx[0].detach().cpu().numpy() if mesh.t_pos_idx is not None else None
+        t_nrm_idx = mesh.t_nrm_idx[0].detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
+        t_tex_idx = mesh.t_tex_idx[0].detach().cpu().numpy() if mesh.t_tex_idx is not None else None
+
+        print("    writing %d vertices" % len(v_pos))
+        for v in v_pos:
+            f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
+
+        if v_nrm is not None:
+            print("    writing %d normals" % len(v_nrm))
+            assert(len(t_pos_idx) == len(t_nrm_idx))
+            for v in v_nrm:
+                f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
+
+        # faces
+        f.write("s 1 \n")
+        f.write("g pMesh1\n")
+        f.write("usemtl defaultMat\n")
+
+        # Write faces
+        print("    writing %d faces" % len(t_pos_idx))
+        for i in range(len(t_pos_idx)):
+            f.write("f ")
+            for j in range(3):
+                f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
+            f.write("\n")
+
+
+def process_mesh(shape, name):
+    mesh = shape.clone()
+    output_glb = f'./{name}.glb'
+    output_obj = f'./{name}.obj'
+
+    # save the obj file for download
+    save_mesh(mesh, output_obj)
+
+    # save the glb for visualize
+    mesh_tri = trimesh.Trimesh(
+        vertices=mesh.v_pos[0].detach().cpu().numpy(),
+        faces=mesh.t_pos_idx[0][..., [2,1,0]].detach().cpu().numpy(),
+        process=False,
+        maintain_order=True
+    )
+    mesh_tri.visual.vertex_colors = (mesh.v_nrm[0][..., [2,1,0]].detach().cpu().numpy() + 1.0) * 0.5 * 255.0
+    mesh_tri.export(file_obj=output_glb)
+
+    return output_glb, output_obj
+
+
+def run_pipeline(model_items, cfgs, input_img):
     epoch = 999
     total_iter = 999999
     model = model_items[0]
     memory_bank = model_items[1]
     memory_bank_keys = model_items[2]

+    device = f'cuda:{_GPU_ID}'
+
     input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)

     with torch.no_grad():
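Note on the new save_mesh: OBJ face indices are 1-based, hence the +1 offsets, and the vt/vn slots of each f record are left empty when v_tex or v_nrm is missing. A minimal sketch of exercising it on a stand-in mesh object; the SimpleNamespace stub and tetrahedron data below are illustrative, not part of the commit:

import torch
import torch.nn.functional as F
from types import SimpleNamespace

# Stand-in exposing only the attributes save_mesh() reads:
# batched (1, V, 3) positions/normals and (1, F, 3) triangle indices.
v_pos = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])
t_idx = torch.tensor([[[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]])
stub = SimpleNamespace(
    v_pos=v_pos,
    v_nrm=F.normalize(v_pos, dim=-1),  # fake per-vertex normals
    v_tex=None,                        # no UVs: the vt slot stays empty
    t_pos_idx=t_idx,
    t_nrm_idx=t_idx,                   # per-vertex normals share face indices
    t_tex_idx=None,
)
save_mesh(stub, './debug_tetrahedron.obj')  # writes v/vn lines and "f i//k" faces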
@@ -455,7 +518,7 @@ def run_pipeline(model_items, cfgs, input_img, device):
     gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)

     image_pred, mask_pred, _, _, _, shading = model.render(
-        shape, texture_pred, mvp, w2c, campos, 256, background=model.background_mode,
+        shape, texture_pred, mvp, w2c, campos, (256, 256), background=model.background_mode,
         im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
         render_flow=False, dino_pred=None, im_features_map=im_features_map
     )
@@ -469,7 +532,7 @@ def run_pipeline(model_items, cfgs, input_img, device):
     nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
                           uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
                           material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
-    buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=256, bsdf="diffuse")
+    buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=(256,256), bsdf="diffuse")

     shaded = buffers["shaded"].permute(0, 3, 1, 2)
     bone_image = shaded[:, :3, :, :]
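Both resolution arguments in this commit change from a bare int to an explicit (height, width) tuple, here and in the model.render call above. If other call sites still pass an int, a small normalizer like the following could bridge the two forms; it is an illustration, not something the commit adds:

def as_resolution(res):
    # Accept a bare int or an existing (height, width) pair.
    if isinstance(res, int):
        return (res, res)
    h, w = res
    return (int(h), int(w))

assert as_resolution(256) == (256, 256)
assert as_resolution((256, 256)) == (256, 256)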
@@ -481,20 +544,10 @@ def run_pipeline(model_items, cfgs, input_img, device):
     mesh_image = save_images(shading, mask_pred)
     mesh_bones_image = save_images(image_with_bones, mask_final)

-
-
-
-    final_mesh_tri = trimesh.Trimesh(
-        vertices=final_shape.v_pos[0].detach().cpu().numpy(),
-        faces=final_shape.t_pos_idx[0].detach().cpu().numpy(),
-        process=False,
-        maintain_order=True)
-    prior_mesh_tri = trimesh.Trimesh(
-        vertices=prior_shape.v_pos[0].detach().cpu().numpy(),
-        faces=prior_shape.t_pos_idx[0].detach().cpu().numpy(),
-        process=False,
-        maintain_order=True)
+    shape_glb, shape_obj = process_mesh(shape, 'reconstruced_shape')
+    base_shape_glb, base_shape_obj = process_mesh(prior_shape, 'reconstructed_base_shape')

+    return mesh_image, mesh_bones_image, shape_glb, shape_obj, base_shape_glb, base_shape_obj


 def run_demo():
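The deleted inline trimesh construction now lives in process_mesh, which also bakes normals into vertex colors for the GLB preview: each component of a unit normal in [-1, 1] maps to a channel in [0, 255] via (n + 1) * 0.5 * 255, applied after the [2, 1, 0] channel swap. A standalone check of that scalar mapping, for illustration only:

import numpy as np

# Unit normals in [-1, 1] -> RGB vertex colors in [0, 255],
# mirroring (mesh.v_nrm[0] + 1.0) * 0.5 * 255.0 in process_mesh.
normals = np.array([[0.0, 0.0, 1.0],    # +Z -> [127.5, 127.5, 255.0]
                    [-1.0, 0.0, 0.0]])  # -X -> [  0.0, 127.5, 127.5]
print((normals + 1.0) * 0.5 * 255.0)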
@@ -582,7 +635,6 @@ def run_demo():
         with gr.Column():
             input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
                                                 label='Input Image Preprocessing',
-                                                value=['Use SAM to center animal'],
                                                 info='untick this, if animal is already centered, e.g. in example images')
         # with gr.Column():
         #     output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
@@ -599,23 +651,26 @@
         # with gr.Column():
         #     crop_size = gr.Number(192, label='Crop size')
         # crop_size = 192
-        run_btn = gr.Button('
+        run_btn = gr.Button('Reconstruct', variant='primary', interactive=True)
         with gr.Row():
             view_1 = gr.Image(interactive=False, height=256, show_label=False)
             view_2 = gr.Image(interactive=False, height=256, show_label=False)
         with gr.Row():
-            shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
-
+            shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], height=512, label="Reconstructed Model")
+            shape_1_download = gr.File(label="Download Full Reconstructed Model")
+        with gr.Row():
+            shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], height=512, label="Bank Base Shape Model")
+            shape_2_download = gr.File(label="Download Full Bank Base Shape Model")

         run_btn.click(fn=partial(preprocess, predictor),
                       inputs=[input_image, input_processing],
                       outputs=[processed_image_highres, processed_image], queue=True
                       ).success(fn=partial(run_pipeline, model_items, model_cfgs),
-                                inputs=[processed_image
-                                outputs=[view_1, view_2, shape_1, shape_2]
+                                inputs=[processed_image],
+                                outputs=[view_1, view_2, shape_1, shape_1_download, shape_2, shape_2_download]
                                 )
     demo.queue().launch(share=True, max_threads=80)
-    # _, local_url, share_url = demo.launch(share=True, server_name="0.0.0.0", server_port=23425)
+    # _, local_url, share_url = demo.queue().launch(share=True, server_name="0.0.0.0", server_port=23425)
     # print('local_url: ', local_url)
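The event wiring above relies on Gradio's chaining: .success() runs the second callback only when the first completes without raising, so run_pipeline never fires on a failed preprocess, and partial() pins the leading arguments that are not Gradio components. A stripped-down sketch of the same pattern with stub callbacks (the stage names are placeholders, not the app's functions):

import gradio as gr
from functools import partial

def stage_a(tag, text):
    return f'{tag}: {text}'

def stage_b(tag, text):
    return f'{tag}: {text.upper()}'

with gr.Blocks() as demo:
    inp = gr.Textbox(label='input')
    mid = gr.Textbox(label='preprocessed')
    out = gr.Textbox(label='result')
    btn = gr.Button('Run')
    # .success() gates stage_b on stage_a finishing cleanly,
    # mirroring the preprocess -> run_pipeline chain above.
    btn.click(fn=partial(stage_a, 'pre'), inputs=[inp], outputs=[mid], queue=True
              ).success(fn=partial(stage_b, 'run'), inputs=[mid], outputs=[out])

demo.queue().launch()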