Anymate / app.py
yfdeng's picture
add link to project
28a2da9
import spaces
import gradio as gr
import os
from Anymate.args import ui_args, anymate_args
from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, vis_all, prepare_blender_file
from Anymate.utils.ui_utils import get_result_joint, get_result_connectivity, get_result_skinning
from Anymate.utils.utils import load_checkpoint
# Make sure all three checkpoints are on disk before loading them; if any is
# missing, run the project's download script.
_checkpoint_paths = (
    ui_args.checkpoint_joint,
    ui_args.checkpoint_conn,
    ui_args.checkpoint_skin,
)
if not all(os.path.exists(path) for path in _checkpoint_paths):
    print("Missing checkpoints, downloading them...")
    # os.system returns the command's exit status. Report a failed download
    # here so the cause is visible, rather than leaving only a later
    # load_checkpoint error to explain it. Deliberately not raising: the
    # load below remains the single point of failure, as before.
    _download_status = os.system("bash Anymate/get_checkpoints.sh")
    if _download_status != 0:
        print(f"Checkpoint download script exited with status {_download_status}")

# Each model is built via load_checkpoint with 'cpu' as its second argument
# (presumably the load device — confirm in Anymate.utils.utils) and then moved
# to the device configured in anymate_args.
model_joint = load_checkpoint(ui_args.checkpoint_joint, 'cpu', anymate_args.num_joints).to(anymate_args.device)
model_connectivity = load_checkpoint(ui_args.checkpoint_conn, 'cpu', anymate_args.num_joints).to(anymate_args.device)
model_skinning = load_checkpoint(ui_args.checkpoint_skin, 'cpu', anymate_args.num_joints).to(anymate_args.device)
@spaces.GPU
def get_all_results(mesh_file, pc, eps=0.03, min_samples=1):
    """Run the joint, connectivity and skinning stages in order for one mesh.

    Each stage receives the outputs of the stages before it. The function
    returns nothing.
    NOTE(review): the get_result_* helpers presumably persist their results
    somewhere derived from ``mesh_file`` — confirm in Anymate.utils.ui_utils.

    Args:
        mesh_file: Forwarded unchanged to every stage.
        pc: Forwarded unchanged to every stage (presumably the point cloud
            produced when the mesh was processed — confirm against caller).
        eps: Passed to the joint stage as its ``eps`` keyword.
        min_samples: Passed to the joint stage as its ``min_samples`` keyword.
    """
    predicted_joints = get_result_joint(
        mesh_file,
        model_joint,
        pc,
        eps=eps,
        min_samples=min_samples,
    )
    predicted_conns = get_result_connectivity(
        mesh_file,
        model_connectivity,
        pc,
        predicted_joints,
    )
    # The skinning result is not used here; the call is made for its effects.
    get_result_skinning(
        mesh_file,
        model_skinning,
        pc,
        predicted_joints,
        predicted_conns,
    )
    print("Finish Inference")
# Gradio UI: one input viewer with a sample picker, two output viewers
# (wireframe and solid) plus a Blender file download, and a single button that
# runs all three models.
#
# NOTE(review): the widget nesting below was reconstructed from a source whose
# indentation was lost. In particular, confirm whether the "Run all models" row
# and the two number fields belong at the top level of the Blocks.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Anymate: Auto-rigging 3D Objects
    [Project](https://anymate3d.github.io/)
    """)

    # State values written by process_input and read by the inference chain.
    pc = gr.State(value=None)
    normalized_mesh_file = gr.State(value=None)

    with gr.Row():
        with gr.Column():
            # Input section
            gr.Markdown("### Input")
            mesh_input = gr.Model3D(label="Input 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0])

            # Sample 3D objects section
            gr.Markdown("### Sample Objects")
            sample_objects_dir = './samples'
            # A missing or empty samples directory must not stop the app from
            # starting; it just leaves the dropdown empty.
            if os.path.isdir(sample_objects_dir):
                sample_objects = sorted(
                    os.path.join(sample_objects_dir, f)
                    for f in os.listdir(sample_objects_dir)
                    if f.endswith('.obj') and os.path.isfile(os.path.join(sample_objects_dir, f))
                )
            else:
                sample_objects = []
            sample_dropdown = gr.Dropdown(
                label="Select Sample Object",
                choices=sample_objects,
                interactive=True,
                value=sample_objects[0] if sample_objects else None,
            )
            load_sample_btn = gr.Button("Load Sample")

        with gr.Column():
            # Output section
            gr.Markdown("### Output (wireframe display mode)")
            mesh_output = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="wireframe")

        with gr.Column():
            # Output section
            gr.Markdown("### (solid display mode & blender file)")
            mesh_output2 = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="solid")
            blender_file = gr.File(label="Output Blender File", scale=1)

    # There are no per-stage checkpoint pickers or per-stage buttons: the three
    # models are loaded once at import time and run together.
    with gr.Row():
        process_all_btn = gr.Button("Run all models", scale=1)

    # Parameters for DBSCAN clustering algorithm used to adjust joint clustering
    eps = gr.Number(label="Epsilon", value=0.03, interactive=True, info="Controls the maximum distance between joints in a cluster")
    # precision=0 makes Gradio round the value and pass it on as an int, since
    # min_samples is a count.
    min_samples = gr.Number(label="Min Samples", value=1, interactive=True, precision=0, info="Minimum number of joints required to form a cluster")

    # Any change to the input mesh re-runs preprocessing, which refreshes both
    # output viewers, the Blender file slot and the two state values.
    mesh_input.change(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
    )

    # Load the selected sample into the input viewer, then preprocess it.
    # NOTE(review): mesh_input.change above may also fire when this click
    # updates mesh_input, which would run process_input twice — confirm
    # against Gradio's change-event semantics.
    load_sample_btn.click(
        fn=lambda sample_path: sample_path if sample_path else None,
        inputs=[sample_dropdown],
        outputs=[mesh_input]
    ).then(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
    )

    # Show the normalized mesh in the input viewer whenever it changes.
    normalized_mesh_file.change(
        lambda x: x,
        inputs=normalized_mesh_file,
        outputs=mesh_input
    )

    # Inference, then visualisation, then Blender export. The later steps take
    # only the normalized mesh path as input; get_all_results has no outputs.
    process_all_btn.click(
        get_all_results,
        inputs=[normalized_mesh_file, pc, eps, min_samples],
        outputs=[]
    ).then(
        vis_all,
        inputs=[normalized_mesh_file],
        outputs=[mesh_output, mesh_output2]
    ).then(
        prepare_blender_file,
        inputs=[normalized_mesh_file],
        outputs=blender_file
    )

if __name__ == "__main__":
    demo.launch()