"""Gradio demo for Anymate: auto-rigging 3D objects.

Loads the joint, connectivity and skinning checkpoints at import time
(running the bundled download script first if any file is missing), then
serves a UI that preprocesses an input mesh, runs all three models and
offers the result as a Blender file.
"""
# Kept as the first import, as in the original: ZeroGPU Spaces expect
# `spaces` to be imported before CUDA-initialising packages.
import spaces
import gradio as gr
import os
import subprocess

from Anymate.args import ui_args, anymate_args
from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, vis_all, prepare_blender_file
from Anymate.utils.ui_utils import get_result_joint, get_result_connectivity, get_result_skinning
from Anymate.utils.utils import load_checkpoint

SAMPLE_OBJECTS_DIR = './samples'
CHECKPOINT_SCRIPT = 'Anymate/get_checkpoints.sh'


def _checkpoints_present():
    """Return True when all three checkpoint paths configured in ``ui_args`` exist."""
    return all(
        os.path.exists(path)
        for path in (ui_args.checkpoint_joint, ui_args.checkpoint_conn, ui_args.checkpoint_skin)
    )


def _ensure_checkpoints():
    """Run the checkpoint download script if any checkpoint file is missing.

    Raises:
        FileNotFoundError: if checkpoints are still missing after the script ran.
    """
    if _checkpoints_present():
        return
    print("Missing checkpoints, downloading them...")
    result = subprocess.run(["bash", CHECKPOINT_SCRIPT])
    # Verify the outcome instead of trusting the exit status alone, so a failed
    # download is reported here rather than as an opaque load_checkpoint error.
    if not _checkpoints_present():
        raise FileNotFoundError(
            f"Checkpoints are still missing after running {CHECKPOINT_SCRIPT} "
            f"(exit code {result.returncode})."
        )


def _load_model(checkpoint_path):
    """Load ``checkpoint_path`` on CPU, then move the model to ``anymate_args.device``."""
    return load_checkpoint(checkpoint_path, 'cpu', anymate_args.num_joints).to(anymate_args.device)


_ensure_checkpoints()
model_joint = _load_model(ui_args.checkpoint_joint)
model_connectivity = _load_model(ui_args.checkpoint_conn)
model_skinning = _load_model(ui_args.checkpoint_skin)


@spaces.GPU
def get_all_results(mesh_file, pc, eps=0.03, min_samples=1):
    """Run joint, connectivity and skinning inference on a preprocessed mesh.

    Args:
        mesh_file: mesh file produced by ``process_input`` (the ``normalized_mesh_file`` state).
        pc: point cloud produced by ``process_input`` for the same mesh.
        eps: DBSCAN epsilon forwarded to ``get_result_joint``.
        min_samples: DBSCAN minimum cluster size forwarded to ``get_result_joint``.

    Returns nothing. The UI next calls ``vis_all(mesh_file)`` with only the mesh
    file, so the ``get_result_*`` helpers presumably persist their outputs keyed
    on ``mesh_file`` (TODO: confirm in Anymate.utils.ui_utils).
    """
    joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
    conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
    get_result_skinning(mesh_file, model_skinning, pc, joints, conns)
    print("Finish Inference")


def run_all_models(mesh_file, pc, eps, min_samples):
    """Validate UI inputs, then run the full inference pipeline.

    Validation happens before the ``@spaces.GPU`` function is called, so bad
    input is rejected without entering GPU inference.

    Raises:
        gr.Error: if no mesh has been loaded or a clustering parameter is empty.
    """
    if mesh_file is None:
        raise gr.Error("Please load a mesh or a sample object first.")
    if eps is None or min_samples is None:
        raise gr.Error("Epsilon and Min Samples must both be set.")
    # min_samples is a cluster-size count; make sure an int reaches the clustering.
    get_all_results(mesh_file, pc, eps=eps, min_samples=int(min_samples))


def _list_sample_objects(directory=SAMPLE_OBJECTS_DIR):
    """Return the sorted paths of the ``.obj`` files directly inside ``directory``.

    Returns an empty list when ``directory`` does not exist.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.obj') and os.path.isfile(os.path.join(directory, name))
    )


with gr.Blocks() as demo:
    gr.Markdown(
        "# Anymate: Auto-rigging 3D Objects\n"
        "[Project](https://anymate3d.github.io/)"
    )

    # Per-session state filled by process_input and consumed by inference/visualisation.
    pc = gr.State(value=None)
    normalized_mesh_file = gr.State(value=None)

    with gr.Row():
        with gr.Column():
            # Input section
            gr.Markdown("### Input")
            mesh_input = gr.Model3D(label="Input 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0])

            # Sample 3D objects section
            gr.Markdown("### Sample Objects")
            sample_objects = _list_sample_objects()
            sample_dropdown = gr.Dropdown(
                label="Select Sample Object",
                choices=sample_objects,
                interactive=True,
                # Avoid IndexError when no samples are shipped.
                value=sample_objects[0] if sample_objects else None,
            )
            load_sample_btn = gr.Button("Load Sample")

        with gr.Column():
            # Output section
            gr.Markdown("### Output (wireframe display mode)")
            mesh_output = gr.Model3D(
                label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="wireframe"
            )

        with gr.Column():
            # Output section
            gr.Markdown("### (solid display mode & blender file)")
            mesh_output2 = gr.Model3D(
                label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="solid"
            )
            blender_file = gr.File(label="Output Blender File", scale=1)

    # NOTE: the former per-stage checkpoint selectors and Load/Process buttons were
    # disabled (commented out) and have been removed; see version control history.
    with gr.Row():
        process_all_btn = gr.Button("Run all models", scale=1)

        # Parameters for DBSCAN clustering algorithm used to adjust joint clustering
        eps = gr.Number(
            label="Epsilon",
            value=0.03,
            interactive=True,
            info="Controls the maximum distance between joints in a cluster",
        )
        min_samples = gr.Number(
            label="Min Samples",
            value=1,
            precision=0,  # a count: have Gradio deliver an int, not a float
            interactive=True,
            info="Minimum number of joints required to form a cluster",
        )

    mesh_input.change(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc],
    )

    load_sample_btn.click(
        fn=lambda sample_path: sample_path if sample_path else None,
        inputs=[sample_dropdown],
        outputs=[mesh_input],
    ).then(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc],
    )

    # Mirror the processed mesh back into the input viewer.
    normalized_mesh_file.change(
        lambda x: x,
        inputs=normalized_mesh_file,
        outputs=mesh_input,
    )

    # .success (not .then): skip visualisation and export when inference failed.
    process_all_btn.click(
        run_all_models,
        inputs=[normalized_mesh_file, pc, eps, min_samples],
        outputs=[],
    ).success(
        vis_all,
        inputs=[normalized_mesh_file],
        outputs=[mesh_output, mesh_output2],
    ).success(
        prepare_blender_file,
        inputs=[normalized_mesh_file],
        outputs=blender_file,
    )


if __name__ == "__main__":
    demo.launch()