| |
| |
|
|
| |
| import math |
| import os |
| import threading |
| import time |
| from typing import Optional |
|
|
| from kimodo.constraints import load_constraints_lst, save_constraints_lst |
| from kimodo.exports.bvh import motion_to_bvh_bytes, save_motion_bvh |
| from kimodo.exports.motion_io import ( |
| amass_npz_to_bytes, |
| g1_csv_to_bytes, |
| kimodo_npz_to_bytes, |
| load_motion_file, |
| save_kimodo_npz, |
| ) |
| from kimodo.model.registry import kimodo_short_key_for_skeleton_dataset, registry_skeleton_for_joint_count |
| from kimodo.tools import to_torch |
| from kimodo.viz import viser_utils |
| from kimodo.viz.viser_utils import GuiElements |
| import numpy as np |
| import torch |
| import viser |
| from viser._timeline_api import PROMPT_COLORS |
|
|
| from . import generation |
| from ._qwen_prompts import call_qwen_for_prompts |
| from .config import ( |
| DEFAULT_CUR_DURATION, |
| DEMO_UI_INSTRUCTIONS_TAB_MD, |
| get_datasets, |
| get_model_info, |
| get_models_for_dataset_skeleton, |
| get_skeleton_display_name, |
| get_skeleton_display_names_for_dataset, |
| get_skeleton_key_from_display_name, |
| get_short_key_from_display_name, |
| HF_MODE, |
| INIT_POSTPROCESSING, |
| MODEL_NAMES, |
| NB_TRANSITION_FRAMES, |
| SHOW_TRANSITION_PARAMS, |
| ) |
| from .state import ClientSession |
| from kimodo.skeleton import G1Skeleton34, SOMASkeleton30, SOMASkeleton77 |
|
|
|
|
# Pseudo-example name for the Qwen agentic prompt plan; it is appended to the
# on-disk example names when the Examples dropdown is built.
QWEN_EXAMPLE_NAME = "09_qwen_agentic_actions"
# Older name for the same Qwen entry; it is also appended when the Examples
# dropdown options are refreshed. NOTE(review): presumably kept so previously
# saved selections keep working — confirm.
QWEN_EXAMPLE_LEGACY_NAME = "09_qwen_agentic_10_actions"
|
|
|
|
def extract_intervals_and_singles(t: torch.Tensor):
    """Split sorted frame indices into contiguous intervals and isolated frames.

    A run of consecutive values (each exactly one more than the previous) of
    length >= 2 becomes an interval; a run of length 1 becomes a single frame.

    Args:
        t: 1-D sequence of frame indices, assumed sorted ascending. Accepts a
            ``torch.Tensor``, a NumPy array, or any plain Python sequence.

    Returns:
        A tuple ``(intervals, intervals_indices, single_frames,
        single_frames_indices)``:
        - ``intervals``: inclusive ``(first_value, last_value)`` pairs,
        - ``intervals_indices``: the matching inclusive ``(first_pos, last_pos)``
          positions inside ``t``,
        - ``single_frames``: values of isolated frames,
        - ``single_frames_indices``: their positions inside ``t``.
    """
    # Convert once to Python scalars: indexing and comparing tensor elements
    # one by one allocates a tensor per operation (and syncs on GPU tensors).
    values = t.tolist() if hasattr(t, "tolist") else list(t)

    intervals = []
    intervals_indices = []
    single_frames = []
    single_frames_indices = []

    n = len(values)
    start_idx = 0
    for i in range(1, n + 1):
        # Close the current run at the end of input, or when the next value is
        # not exactly one greater than the previous one.
        if i == n or values[i] != values[i - 1] + 1:
            if i - start_idx >= 2:
                intervals.append((int(values[start_idx]), int(values[i - 1])))
                intervals_indices.append((start_idx, i - 1))
            else:
                single_frames.append(int(values[start_idx]))
                single_frames_indices.append(start_idx)
            start_idx = i

    return intervals, intervals_indices, single_frames, single_frames_indices
|
|
|
|
| def create_gui( |
| demo, |
| client: viser.ClientHandle, |
| model_name: str, |
| model_fps: float, |
| ): |
| """Create GUI elements for a specific client.""" |
| client_id = client.client_id |
|
|
def get_active_session(event_client: viser.ClientHandle | None):
    """Look up the session belonging to ``event_client``.

    Returns ``None`` when the event carries no client, or when the demo no
    longer reports that client as active.
    """
    is_live = event_client is not None and demo.client_active(event_client.client_id)
    return demo.client_sessions[event_client.client_id] if is_live else None
|
|
def build_timeline_tracks():
    """Reset this client's timeline and create the six constraint tracks.

    Returns:
        ``(timeline, timeline_tracks)``, where ``timeline_tracks`` maps each
        track id returned by ``timeline.add_track`` to the name, track type,
        color and height scale used to create that track.
    """
    timeline = client.timeline
    demo.set_timeline_defaults(timeline, model_fps)
    timeline.set_visible(True)
    timeline.set_current_frame(0)

    # (name, RGB color) per track, in creation order; every track is a
    # half-height keyframe track.
    track_specs = (
        ("Full-Body", (219, 148, 86)),
        ("2D Root", (150, 100, 200)),
        ("Left Hand", (100, 200, 150)),
        ("Right Hand", (200, 100, 150)),
        ("Left Foot", (219, 148, 86)),
        ("Right Foot", (150, 100, 200)),
    )

    timeline_tracks = {}
    for track_name, track_color in track_specs:
        props = {
            "name": track_name,
            "track_type": "keyframe",
            "color": track_color,
            "height_scale": 0.5,
        }
        track_id = timeline.add_track(
            track_name,
            track_type=props["track_type"],
            color=props["color"],
            height_scale=props["height_scale"],
        )
        timeline_tracks[track_id] = props
    return timeline, timeline_tracks
|
|
| timeline, timeline_tracks = build_timeline_tracks() |
| |
| |
| gui_play_pause_button = None |
| gui_next_frame_button = None |
| gui_prev_frame_button = None |
| gui_timeline = None |
| gui_duration_slider = None |
|
|
| |
| tab_group = client.gui.add_tab_group() |
|
|
| |
| |
| |
| with tab_group.add_tab("Generate", viser.Icon.WALK): |
| with client.gui.add_folder("Model Selection", expand_by_default=True): |
| info = get_model_info(model_name) |
| if info is None: |
| info = get_model_info(next(iter(MODEL_NAMES))) |
|
|
| def get_allowed_skeleton_labels(dataset_ui_label: str) -> list[str]: |
| labels = get_skeleton_display_names_for_dataset(dataset_ui_label, family="Kimodo") |
| if HF_MODE: |
| labels = [label for label in labels if get_skeleton_key_from_display_name(label) != "SMPLX"] |
| return labels |
|
|
| dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label |
| datasets = ["Rigplay"] if HF_MODE else get_datasets(family="Kimodo") |
| skeleton_labels = get_allowed_skeleton_labels(dataset_ui_label) |
| initial_skeleton_label = get_skeleton_display_name(info.skeleton) |
| if initial_skeleton_label not in skeleton_labels and skeleton_labels: |
| initial_skeleton_label = skeleton_labels[0] |
| initial_skeleton_key = ( |
| get_skeleton_key_from_display_name(initial_skeleton_label) if skeleton_labels else None |
| ) |
| models_for_pair = ( |
| get_models_for_dataset_skeleton(dataset_ui_label, initial_skeleton_key, family="Kimodo") |
| if initial_skeleton_key is not None |
| else [] |
| ) |
| version_options = [m.display_name for m in models_for_pair] |
| initial_version = ( |
| info.display_name |
| if info.display_name in version_options |
| else (version_options[0] if version_options else "") |
| ) |
| gui_dataset_selector = client.gui.add_dropdown( |
| "Training dataset", |
| options=datasets, |
| initial_value=dataset_ui_label, |
| visible=not HF_MODE, |
| ) |
| gui_skeleton_selector = client.gui.add_dropdown( |
| "Model" if HF_MODE else "Skeleton", |
| options=skeleton_labels, |
| initial_value=initial_skeleton_label, |
| ) |
| gui_version_selector = client.gui.add_dropdown( |
| "Version", |
| options=version_options, |
| initial_value=initial_version, |
| ) |
| gui_version_selector.visible = len(models_for_pair) > 1 |
| gui_model_display = client.gui.add_markdown( |
| content=f"**Model:** {initial_version}", |
| ) |
| gui_load_model_button = client.gui.add_button( |
| "Load model", |
| hint="Load the selected model (dataset, skeleton, version).", |
| ) |
|
|
class ModelSelectorHandle:
    """Facade over the dataset / skeleton / version dropdowns and model label.

    Lets the session and callbacks read and set the selected model through one
    object, as if the three dropdowns were a single selector.
    """

    def __init__(self):
        self._dataset = gui_dataset_selector
        self._skeleton = gui_skeleton_selector
        self._version = gui_version_selector
        self._display = gui_model_display

    @property
    def value(self) -> str:
        """Short model key of the selected version, or ``""`` if unresolvable."""
        return get_short_key_from_display_name(self._version.value) or ""

    def set_from_short_key(self, short_key: str) -> None:
        """Point the dropdowns and the model label at ``short_key``.

        Unknown keys are ignored. If no skeleton key, or no model version, can
        be resolved for the target dataset/skeleton, the version dropdown is
        left unchanged instead of raising.
        """
        info = get_model_info(short_key)
        if info is None:
            return
        dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label
        self._dataset.value = dataset_ui_label
        self._skeleton.options = get_allowed_skeleton_labels(dataset_ui_label)
        skeleton_label = get_skeleton_display_name(info.skeleton)
        # Fall back to the first allowed skeleton when the model's own one is
        # filtered out (e.g. SMPLX in HF mode).
        if skeleton_label not in self._skeleton.options and self._skeleton.options:
            skeleton_label = self._skeleton.options[0]
        self._skeleton.value = skeleton_label
        skeleton_key = get_skeleton_key_from_display_name(skeleton_label)
        if skeleton_key is None:
            return
        models = get_models_for_dataset_skeleton(dataset_ui_label, skeleton_key, family="Kimodo")
        version_options = [m.display_name for m in models]
        if not version_options:
            # No versions for this dataset/skeleton pair (possible when HF mode
            # forces "Rigplay"); picking options[0] would raise IndexError.
            return
        self._version.options = version_options
        self._version.value = (
            info.display_name if info.display_name in version_options else version_options[0]
        )
        self._version.visible = len(models) > 1
        self._display.content = f"**Model:** {self._version.value}"
|
|
| gui_model_selector = ModelSelectorHandle() |
|
|
| with client.gui.add_folder("Examples", expand_by_default=True): |
| examples_base_dir = demo.get_examples_base_dir(model_name, absolute=True) |
| example_dict = viser_utils.load_example_cases(examples_base_dir) |
| example_names = list(example_dict.keys()) |
| example_names.append(QWEN_EXAMPLE_NAME) |
| gui_examples_dropdown = client.gui.add_dropdown( |
| "Example", |
| options=example_names, |
| initial_value=example_names[0], |
| ) |
| gui_load_example_button = client.gui.add_button( |
| "Load Example", |
| hint="Load the selected example (or Qwen agentic prompt plan).", |
| disabled=False, |
| ) |
|
|
def update_examples_dropdown(
    new_example_dict: dict[str, str],
    keep_selection: bool = True,
) -> None:
    """Refresh the Examples dropdown options from ``new_example_dict``.

    The Qwen entries (current and legacy names) are appended when missing.
    The current selection is kept if ``keep_selection`` is true and it is
    still an option; otherwise the first option is selected.
    """
    # NOTE(review): the dropdown built at creation time only appends
    # QWEN_EXAMPLE_NAME; confirm the legacy alias should be listed here too.
    options = list(new_example_dict)
    for extra_name in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME):
        if extra_name not in options:
            options.append(extra_name)
    gui_examples_dropdown.options = options
    if not (keep_selection and gui_examples_dropdown.value in options):
        gui_examples_dropdown.value = options[0]
|
|
| with client.gui.add_folder("Generate", expand_by_default=True): |
| gui_duration = client.gui.add_markdown(content=f"Total duration: {DEFAULT_CUR_DURATION:.1f} (sec)") |
|
|
| def update_duration_gui(duration): |
| gui_duration.content = f"Total duration: {duration:.1f} (sec)" |
|
|
def compute_prompt_num_frames(prompt_values):
    """Return the number of frames covered by each timeline prompt.

    Convention in this demo:
    - Every prompt except the last spans [start_frame, end_frame)
      (end exclusive).
    - The last prompt spans [start_frame, end_frame] (end inclusive).
    - ``prompt_values`` is assumed to be sorted by ``start_frame``.
    """
    last_pos = len(prompt_values) - 1
    return [
        prompt.end_frame - prompt.start_frame + (1 if pos == last_pos else 0)
        for pos, prompt in enumerate(prompt_values)
    ]
|
|
| def update_duration_auto(): |
| session = demo.client_sessions[client_id] |
| prompt_values = sorted( |
| [x for x in timeline._prompts.values()], |
| key=lambda x: x.start_frame, |
| ) |
| num_frames = compute_prompt_num_frames(prompt_values) |
| total_nb_frames = sum(num_frames) |
| cur_duration = total_nb_frames / session.model_fps |
| set_new_duration(client_id, cur_duration) |
| update_duration_gui(cur_duration) |
|
|
| gui_num_samples_slider = client.gui.add_slider( |
| "Num Samples", |
| min=1, |
| max=10, |
| step=1, |
| initial_value=1, |
| visible=not HF_MODE, |
| ) |
|
|
| gui_use_soma_layer_checkbox = client.gui.add_checkbox( |
| "SOMA layer", |
| initial_value=False, |
| visible="soma" in (model_name or ""), |
| ) |
|
|
| with client.gui.add_folder("Model Parameters", expand_by_default=False): |
| gui_seed = client.gui.add_number("Seed", initial_value=42) |
|
|
| with client.gui.add_folder("Diffusion", expand_by_default=False): |
| gui_diffusion_steps_slider = client.gui.add_slider( |
| "Denoising Steps", |
| min=2, |
| max=1000, |
| step=10, |
| initial_value=100, |
| ) |
| with client.gui.add_folder("Classifier-Free Guidance", expand_by_default=False): |
| gui_cfg_checkbox = client.gui.add_checkbox( |
| "Enable", |
| initial_value=True, |
| visible=True, |
| ) |
|
|
| gui_cfg_text_weight_slider = client.gui.add_slider( |
| "Text Weight", |
| min=0.0, |
| max=5.0, |
| step=0.1, |
| initial_value=2.0, |
| visible=True, |
| ) |
| gui_cfg_constraint_weight_slider = client.gui.add_slider( |
| "Constraint Weight", |
| min=0.0, |
| max=5.0, |
| step=0.1, |
| initial_value=2.0, |
| visible=True, |
| ) |
| with client.gui.add_folder( |
| "Transitions", |
| expand_by_default=False, |
| visible=SHOW_TRANSITION_PARAMS, |
| ): |
| gui_num_transition_frames_slider = client.gui.add_slider( |
| "Transition frames", |
| min=1, |
| max=10, |
| step=1, |
| initial_value=NB_TRANSITION_FRAMES, |
| visible=True, |
| ) |
| gui_share_transition_checkbox = client.gui.add_checkbox( |
| "Override previous frames", |
| initial_value=False, |
| visible=True, |
| ) |
| gui_percentage_transition_sharing_slider = client.gui.add_slider( |
| "Percentage overriding frames", |
| min=0, |
| max=30, |
| step=1, |
| initial_value=10, |
| visible=True, |
| ) |
|
|
| @gui_share_transition_checkbox.on_update |
| def _(event: viser.GuiEvent) -> None: |
| if get_active_session(event.client) is None: |
| return |
| |
| gui_percentage_transition_sharing_slider.visible = gui_share_transition_checkbox.value |
|
|
| with client.gui.add_folder("Post Processing", expand_by_default=False): |
| _model_name = model_name or "" |
| _postprocess_visible = "g1" not in _model_name |
| gui_postprocess_checkbox = client.gui.add_checkbox( |
| "Enable", |
| initial_value=INIT_POSTPROCESSING, |
| hint="Apply motion post-processing (not available for G1)", |
| visible=_postprocess_visible, |
| ) |
| gui_root_margin = client.gui.add_number( |
| "Root Margin", |
| min=0.0, |
| |
| step=0.01, |
| initial_value=0.04, |
| hint="Margin for root position (meters). Lower values pin root closer to target.", |
| visible=INIT_POSTPROCESSING and _postprocess_visible, |
| ) |
|
|
| @gui_postprocess_checkbox.on_update |
| def _(event: viser.GuiEvent) -> None: |
| if get_active_session(event.client) is None: |
| return |
| |
| gui_root_margin.visible = gui_postprocess_checkbox.value |
|
|
| gui_real_robot_rotations_checkbox = client.gui.add_checkbox( |
| "Real robot rotations", |
| initial_value=False, |
| hint="Project joint rotations to G1 real robot DoF (1-DoF per joint) and clamp to axis limits from the MuJoCo XML.", |
| visible="g1" in _model_name, |
| ) |
|
|
| with client.gui.add_folder("Qwen Auto-Prompts", expand_by_default=True): |
| gui_qwen_scene = client.gui.add_text( |
| "Scene context", |
| initial_value="A lone figure moving through an empty plaza", |
| hint="Describe the scene or character context for Qwen to generate motion prompts.", |
| ) |
| gui_qwen_actions = client.gui.add_slider( |
| "Target actions", |
| min=1, |
| max=10, |
| step=1, |
| initial_value=6, |
| hint="Number of prompt segments to place on the timeline.", |
| ) |
| gui_qwen_auto_run = client.gui.add_checkbox( |
| "Auto-run Generate after loading prompts", |
| initial_value=False, |
| ) |
| gui_qwen_status = client.gui.add_markdown(content="") |
| gui_qwen_button = client.gui.add_button("Fill Timeline via Qwen", color="blue") |
|
|
| @gui_qwen_button.on_click |
| def _(event: viser.GuiEvent) -> None: |
| event_client = event.client |
| session = get_active_session(event_client) |
| if session is None: |
| return |
|
|
| def _run_qwen_fill_and_maybe_generate() -> None: |
| gui_qwen_button.disabled = True |
| gui_qwen_status.content = "⏳ Calling Qwen…" |
| target_actions = int(gui_qwen_actions.value) |
| history: list[str] = [] |
| all_texts: list[str] = [] |
| all_durations: list[float] = [] |
| rounds = 0 |
|
|
| |
| while len(all_texts) < target_actions and rounds < 6: |
| remaining = target_actions - len(all_texts) |
| batch, history = call_qwen_for_prompts( |
| scene=gui_qwen_scene.value, |
| history=history, |
| requested_actions=min(5, remaining), |
| ) |
| batch_texts = batch.get("texts", []) |
| batch_durations = batch.get("durations", []) |
| for t, d in zip(batch_texts, batch_durations): |
| if len(all_texts) >= target_actions: |
| break |
| all_texts.append(t) |
| all_durations.append(float(d)) |
| rounds += 1 |
|
|
| if len(all_texts) == 0: |
| gui_qwen_status.content = "⚠ Qwen did not return usable prompts" |
| gui_qwen_button.disabled = False |
| return |
|
|
| fps = session.model_fps |
| event_client.timeline.clear_prompts() |
| frame_cursor = 0 |
| for i, (txt, dur) in enumerate(zip(all_texts, all_durations)): |
| n = max(1, int(round(dur * fps))) |
| start_f = frame_cursor |
| end_f = frame_cursor + n if i < len(all_texts) - 1 else frame_cursor + n - 1 |
| color = PROMPT_COLORS[i % len(PROMPT_COLORS)] |
| event_client.timeline.add_prompt(txt, start_f, end_f, color=color) |
| frame_cursor += n |
|
|
| |
| target_visible_frames = int(math.ceil(1.10 * frame_cursor)) |
| event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames)) |
|
|
| update_duration_auto() |
| gui_qwen_status.content = f"✓ Loaded {len(all_texts)} Qwen prompt segments" |
|
|
| if gui_qwen_auto_run.value: |
| gui_qwen_status.content = f"✓ Loaded {len(all_texts)} prompts, generating motion…" |
| try: |
| demo.generate( |
| event_client, |
| all_texts, |
| [max(1, int(round(d * fps))) for d in all_durations], |
| gui_num_samples_slider.value, |
| gui_seed.value, |
| gui_diffusion_steps_slider.value, |
| cfg_weight=[ |
| gui_cfg_text_weight_slider.value, |
| gui_cfg_constraint_weight_slider.value, |
| ], |
| cfg_type="separated" if gui_cfg_checkbox.value else "nocfg", |
| postprocess_parameters={ |
| "post_processing": gui_postprocess_checkbox.value, |
| "root_margin": gui_root_margin.value, |
| }, |
| transitions_parameters={ |
| "num_transition_frames": gui_num_transition_frames_slider.value, |
| "share_transition": gui_share_transition_checkbox.value, |
| "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100, |
| }, |
| real_robot_rotations=gui_real_robot_rotations_checkbox.value, |
| ) |
| gui_qwen_status.content = f"✓ Generated motion from {len(all_texts)} Qwen actions" |
| except Exception as exc: |
| gui_qwen_status.content = f"⚠ Generate error: {exc}" |
|
|
| gui_qwen_button.disabled = False |
|
|
| threading.Thread(target=_run_qwen_fill_and_maybe_generate, daemon=True).start() |
|
|
| gui_generate_button = client.gui.add_button("Generate", color="green") |
| with client.gui.add_folder("Constraints", expand_by_default=False): |
| gui_gizmo_space_dropdown = client.gui.add_dropdown( |
| "Gizmo space", |
| ("Local", "World"), |
| initial_value="Local", |
| visible="g1" not in _model_name, |
| ) |
| gui_edit_constraint_button = client.gui.add_button("Enter Editing Mode") |
| gui_snap_to_constraint_button = client.gui.add_button( |
| "Snap to Constraint", |
| disabled=True, |
| ) |
| gui_reset_constraint_button = client.gui.add_button( |
| "Reset Constraint", |
| disabled=True, |
| ) |
| gui_undo_drag_button = client.gui.add_button( |
| "Undo Move", |
| disabled=True, |
| ) |
|
|
| with client.gui.add_folder("Root 2D Options", expand_by_default=True): |
| gui_dense_path_checkbox = client.gui.add_checkbox( |
| "Make Smooth Path", |
| initial_value=False, |
| visible=True, |
| ) |
|
|
| gui_show_only_current_constraint_checkbox = client.gui.add_checkbox( |
| "Show only Current", |
| initial_value=False, |
| hint="Show only constraint overlays at the current frame; uncheck to show all.", |
| ) |
|
|
| def apply_constraint_overlay_visibility(session: ClientSession) -> None: |
| demo._apply_constraint_overlay_visibility(session) |
|
|
| @gui_show_only_current_constraint_checkbox.on_update |
| def _(event: viser.GuiEvent) -> None: |
| session = get_active_session(event.client) |
| if session is None: |
| return |
| session.show_only_current_constraint = gui_show_only_current_constraint_checkbox.value |
| apply_constraint_overlay_visibility(session) |
|
|
| gui_clear_all_constraints_button = client.gui.add_button( |
| "Clear All Constraints", |
| color="red", |
| ) |
|
|
def has_constraint_at_frame(session: ClientSession, frame_idx: int) -> bool:
    """Tell whether any tracked constraint has a keyframe at ``frame_idx``.

    Only the "Full-Body", "End-Effectors" and "2D Root" entries of
    ``session.constraints`` are consulted; absent entries are skipped.
    """
    constraints = session.constraints
    candidates = (constraints.get(name) for name in ("Full-Body", "End-Effectors", "2D Root"))
    return any(
        frame_idx in constraint.keyframes
        for constraint in candidates
        if constraint is not None
    )
|
|
| def update_snap_to_constraint_button(session: ClientSession) -> None: |
| gui_snap_to_constraint_button.disabled = not has_constraint_at_frame(session, session.frame_idx) |
|
|
def ensure_edit_snapshot(session: ClientSession, motion, frame_idx: int) -> None:
    """Record the pose of ``motion`` at ``frame_idx`` the first time it is edited.

    Snapshots are stored in ``session.edit_mode_snapshot`` (created on demand)
    keyed by frame index; an existing snapshot for a frame is never replaced.
    """
    snapshots = session.edit_mode_snapshot
    if snapshots is None:
        snapshots = session.edit_mode_snapshot = {}
    if frame_idx not in snapshots:
        snapshots[frame_idx] = {
            "joints_pos": motion.get_joints_pos(frame_idx),
            "joints_rot": motion.get_joints_rot(frame_idx),
        }
|
|
| def _update_dense_path(motion, session): |
| constraint_info = session.constraints["2D Root"].get_constraint_info() |
|
|
| if len(constraint_info["frame_idx"]) > 0: |
| min_root_frame = min(constraint_info["frame_idx"]) |
| max_root_frame = max(constraint_info["frame_idx"]) |
| motion.set_projected_root_pos_path( |
| constraint_info["root_pos"][:, [0, 2]], |
| min_frame_idx=min_root_frame, |
| max_frame_idx=max_root_frame, |
| ) |
|
|
| |
# Delay in milliseconds between the last drag event and the dense root-path
# recomputation; converted to seconds when the timer is created.
DENSE_PATH_AFTER_RELEASE_MS = 300


def _schedule_dense_path_after_release(session):
    """Schedule a single path update to run after user stops dragging.

    Each call cancels any pending timer stored in
    ``session.timeline_data["dense_path_after_release_timer"]`` and starts a
    new one, so only the last call in a burst triggers an update. Does nothing
    unless a "2D Root" constraint exists with ``dense_path`` enabled. The
    update runs on a ``threading.Timer`` thread, not on the caller's thread.
    """
    if "2D Root" not in session.constraints or not session.constraints["2D Root"].dense_path:
        return
    tdata = session.timeline_data
    # Replace any pending update with a fresh one.
    if tdata.get("dense_path_after_release_timer"):
        tdata["dense_path_after_release_timer"].cancel()
    delay = DENSE_PATH_AFTER_RELEASE_MS / 1000.0

    def run():
        # Timer callback: skip the update if the client has gone inactive.
        # NOTE(review): in that case the timer slot in `tdata` is not cleared.
        if not demo.client_active(client_id):
            return
        # Look the session up again instead of reusing the captured one.
        sess = demo.client_sessions[client_id]
        # NOTE(review): clears the slot on the timeline_data captured at
        # scheduling time, not on `sess.timeline_data`; presumably the same
        # dict — confirm.
        tdata["dense_path_after_release_timer"] = None
        # Dense path may have been switched off while the timer was pending.
        if "2D Root" not in sess.constraints or not sess.constraints["2D Root"].dense_path:
            return
        # Only the first motion of the session is updated.
        mot = list(sess.motions.values())[0]
        _update_dense_path(mot, sess)

    t = threading.Timer(delay, run)
    tdata["dense_path_after_release_timer"] = t
    t.start()
|
|
| @gui_dense_path_checkbox.on_update |
| def _(event: viser.GuiEvent) -> None: |
| session = get_active_session(event.client) |
| if session is None: |
| return |
|
|
| if gui_dense_path_checkbox.value: |
| |
| |
| root_2d_track = session.timeline_data["tracks_ids"]["2D Root"] |
|
|
| |
| start_keyframe_id = client.timeline.add_locked_keyframe( |
| root_2d_track, |
| 0, |
| opacity=0.0, |
| ) |
| session.timeline_data["keyframes"][start_keyframe_id] = { |
| "frame": 0, |
| "track_id": root_2d_track, |
| "locked": True, |
| "opacity": 0.0, |
| "value": None, |
| } |
| add_constraint_callback( |
| start_keyframe_id, |
| "2D Root", |
| (0, 0), |
| verbose=False, |
| ) |
|
|
| |
| end_keyframe_id = client.timeline.add_locked_keyframe( |
| root_2d_track, |
| session.max_frame_idx, |
| opacity=0.0, |
| ) |
| session.timeline_data["keyframes"][end_keyframe_id] = { |
| "frame": session.max_frame_idx, |
| "track_id": root_2d_track, |
| "locked": True, |
| "opacity": 0.0, |
| "value": None, |
| } |
| add_constraint_callback( |
| end_keyframe_id, |
| "2D Root", |
| (session.max_frame_idx, session.max_frame_idx), |
| verbose=False, |
| ) |
|
|
| |
| locked_interval = client.timeline.add_locked_interval( |
| root_2d_track, |
| start_frame=0, |
| end_frame=session.max_frame_idx, |
| ) |
| session.timeline_data["intervals"][locked_interval] = { |
| "track_id": root_2d_track, |
| "start_frame_idx": 0, |
| "end_frame_idx": session.max_frame_idx, |
| "locked": True, |
| "opacity": 0.3, |
| "value": None, |
| } |
|
|
| session.constraints["2D Root"].set_dense_path(gui_dense_path_checkbox.value) |
| if session.constraints["2D Root"].dense_path: |
| |
| |
| motion = list(session.motions.values())[0] |
| _update_dense_path(motion, session) |
|
|
| |
| if not gui_dense_path_checkbox.value: |
| |
| keyframes_to_remove = [] |
| for uuid, keyframe in client.timeline._keyframes.items(): |
| if keyframe.locked: |
| keyframes_to_remove.append(uuid) |
| _data = session.timeline_data["keyframes"][uuid] |
| remove_constraint_callback( |
| uuid, |
| constraint_type=session.timeline_data["tracks"][_data["track_id"]]["name"], |
| frame_range=(_data["frame"], _data["frame"]), |
| verbose=False, |
| ) |
|
|
| intervals_to_remove = [] |
| |
| for uuid, interval in client.timeline._intervals.items(): |
| if interval.locked: |
| intervals_to_remove.append(uuid) |
|
|
| |
| for uuid in keyframes_to_remove: |
| client.timeline.remove_keyframe(uuid) |
|
|
| for uuid in intervals_to_remove: |
| client.timeline.remove_interval(uuid) |
|
|
| apply_constraint_overlay_visibility(session) |
|
|
| with client.gui.add_folder( |
| "Load/Save", |
| expand_by_default=False, |
| visible=not HF_MODE, |
| ): |
| with client.gui.add_folder("Motion", expand_by_default=False): |
| gui_save_motion_path_text = client.gui.add_text("Save Path", initial_value="output") |
| gui_save_motion_format_dropdown = client.gui.add_dropdown( |
| "Save Format", |
| options=( |
| ["NPZ", "CSV"] |
| if "g1" in model_name.lower() |
| else ["NPZ", "AMASS NPZ"] |
| if "smplx" in model_name.lower() |
| else ["NPZ", "BVH"] |
| ), |
| initial_value="NPZ", |
| ) |
| gui_save_motion_button = client.gui.add_button( |
| "Save Motion", |
| hint="Save the current motion (format + path above)", |
| ) |
| gui_load_motion_path_text = client.gui.add_text( |
| "Load Path", |
| initial_value="output.npz", |
| hint="SOMA .bvh, Kimodo or AMASS .npz, or G1 MuJoCo .csv", |
| ) |
| gui_load_motion_button = client.gui.add_button( |
| "Load Motion", |
| hint="Load the selected motion", |
| ) |
| with client.gui.add_folder("Constraints", expand_by_default=False): |
| gui_save_constraints_path_text = client.gui.add_text( |
| "Save Path", initial_value="output_constraints.json" |
| ) |
| gui_save_constraints_button = client.gui.add_button("Save Constraints") |
| gui_load_constraints_path_text = client.gui.add_text( |
| "Load Path", initial_value="output_constraints.json" |
| ) |
| gui_load_constraints_button = client.gui.add_button("Load Constraints") |
| with client.gui.add_folder("Example", expand_by_default=False): |
| gui_save_example_path_text = client.gui.add_text( |
| "Save Dir", |
| initial_value=os.path.join( |
| demo.get_examples_base_dir(model_name, absolute=True), |
| "custom_example_1", |
| ), |
| ) |
| gui_save_example_button = client.gui.add_button("Save Example") |
| gui_load_example_path_text = client.gui.add_text( |
| "Load Dir", |
| initial_value=os.path.join( |
| demo.get_examples_base_dir(model_name, absolute=True), |
| "custom_example_1", |
| ), |
| ) |
| gui_load_gt_checkbox = client.gui.add_checkbox( |
| "Load GT instead", |
| initial_value=False, |
| ) |
| gui_load_example_from_path_button = client.gui.add_button("Load Example") |
|
|
def _get_primary_motion(session: ClientSession):
    """Return the first motion of ``session.motions`` in insertion order.

    Raises IndexError when the session has no motions.
    """
    motions = list(session.motions.values())
    return motions[0]
|
|
| def _motion_to_numpy_dict(motion) -> dict[str, np.ndarray]: |
| joints_pos = motion.joints_pos.detach().cpu().numpy() |
| joints_rot = motion.joints_rot.detach().cpu().numpy() |
| joints_local_rot = motion.joints_local_rot.detach().cpu().numpy() |
|
|
| if joints_pos.ndim != 3: |
| raise ValueError(f"Expected unbatched joints_pos with shape [T, J, 3], got {joints_pos.shape}") |
| if joints_rot.ndim != 4: |
| raise ValueError(f"Expected unbatched joints_rot with shape [T, J, 3, 3], got {joints_rot.shape}") |
| if joints_local_rot.ndim != 4: |
| raise ValueError( |
| "Expected unbatched joints_local_rot with shape " f"[T, J, 3, 3], got {joints_local_rot.shape}" |
| ) |
|
|
| motion_data = { |
| "posed_joints": joints_pos, |
| "global_rot_mats": joints_rot, |
| "local_rot_mats": joints_local_rot, |
| "root_positions": joints_pos[:, motion.skeleton.root_idx, :], |
| } |
| if motion.foot_contacts is not None: |
| foot_contacts = motion.foot_contacts.detach().cpu().numpy() |
| if foot_contacts.ndim != 2: |
| raise ValueError( |
| f"Expected unbatched foot_contacts with shape [T, C], got {foot_contacts.shape}" |
| ) |
| motion_data["foot_contacts"] = foot_contacts |
| return motion_data |
|
|
| def _coerce_save_path(raw_path: str, *, ext: str) -> str: |
| """Ensure the save path ends with the correct extension for the chosen format.""" |
| name = (raw_path or "").strip() |
| if name == "": |
| return f"output{ext}" |
| known_exts = (".npz", ".bvh", ".csv") |
| if name.lower().endswith(known_exts): |
| return os.path.splitext(name)[0] + ext |
| if os.path.splitext(name)[1] == "": |
| return name + ext |
| return name |
|
|
def save_motion(client, save_path, fmt):
    """Write the client's primary motion to disk in format ``fmt``.

    ``fmt`` is one of "BVH", "CSV" or "AMASS NPZ"; any other value saves a
    Kimodo NPZ. ``save_path`` is first coerced to the format's extension.

    Returns:
        The path that was written.
    """
    session = demo.client_sessions[client.client_id]
    motion = _get_primary_motion(session)
    # Built for every format; this also rejects batched motions early.
    motion_data = _motion_to_numpy_dict(motion)

    if fmt == "BVH":
        out_path = _coerce_save_path(save_path, ext=".bvh")
        save_motion_bvh(
            out_path,
            motion.joints_local_rot,
            motion.joints_pos[:, session.skeleton.root_idx, :],
            skeleton=session.skeleton,
            fps=float(session.model_fps),
        )
        return out_path

    if fmt == "CSV":
        out_path = _coerce_save_path(save_path, ext=".csv")
        payload = g1_csv_to_bytes(motion_data, session.skeleton, demo.device)
    elif fmt == "AMASS NPZ":
        out_path = _coerce_save_path(save_path, ext=".npz")
        payload = amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps)
    else:
        out_path = _coerce_save_path(save_path, ext=".npz")
        save_kimodo_npz(out_path, motion_data)
        return out_path

    # CSV and AMASS exporters return raw bytes that we write ourselves.
    with open(out_path, "wb") as fh:
        fh.write(payload)
    return out_path
|
|
| @gui_save_motion_button.on_click |
| def _(event: viser.GuiEvent) -> None: |
| event_client = event.client |
| if get_active_session(event_client) is None: |
| return |
|
|
| raw_path = gui_save_motion_path_text.value |
| fmt = str(gui_save_motion_format_dropdown.value).upper() |
| try: |
| saved_path = save_motion(event_client, raw_path, fmt) |
| event_client.add_notification( |
| title="Motion saved!", |
| body=f"Saved motion to {saved_path}", |
| auto_close_seconds=5.0, |
| color="green", |
| ) |
| except Exception as e: |
| import traceback |
|
|
| traceback.print_exc() |
| event_client.add_notification( |
| title="Failed to save motion!", |
| body=str(e), |
| auto_close_seconds=5.0, |
| color="red", |
| ) |
|
|
def load_motion(client, load_path):
    """Load a motion file and show it as the client's current motion.

    If the joint count of the loaded motion maps to a different skeleton than the
    active model's, a matching Kimodo model is selected first (same dataset when
    possible, otherwise "RP"). When the model switch changes the frame rate, the
    file is re-read at the new model's fps so playback speed and duration match.
    A 30-joint motion loaded into a ``SOMASkeleton77`` model is converted through
    forward kinematics.

    Args:
        client: viser client handle whose session receives the motion.
        load_path: path of the motion file.

    Raises:
        ValueError: if no Kimodo model matches the motion's skeleton, or the joint
            counts of motion and model cannot be reconciled.
    """
    session = demo.client_sessions[client.client_id]

    def _read(fps):
        # Only ask the loader to resample when the fps is a known positive value.
        return load_motion_file(load_path, target_fps=fps if fps and fps > 0 else None)

    read_fps = session.model_fps
    motion_dict, num_joints_motion = _read(read_fps)

    # Make sure the active model uses the skeleton implied by the motion's joint count.
    target_skel = registry_skeleton_for_joint_count(num_joints_motion)
    current_info = get_model_info(session.model_name)
    current_skel = current_info.skeleton if current_info is not None else None

    if current_skel != target_skel:
        # Prefer a model for the current model's dataset, then fall back to "RP".
        dataset = current_info.dataset if current_info is not None else "RP"
        new_key = kimodo_short_key_for_skeleton_dataset(target_skel, dataset)
        if new_key is None:
            new_key = kimodo_short_key_for_skeleton_dataset(target_skel, "RP")
        if new_key is None:
            raise ValueError(
                f"No Kimodo model found for skeleton {target_skel} (motion has J={num_joints_motion})."
            )
        if new_key != session.model_name:
            gui_model_selector.set_from_short_key(new_key)
            apply_model_selection(new_key)
            _update_visibility_for_loaded_model(new_key)
            client.add_notification(
                title="Model switched",
                body=f"Switched to {new_key} to match loaded motion (J={num_joints_motion}).",
                auto_close_seconds=5.0,
                color="blue",
            )
        session = demo.client_sessions[client.client_id]
        if session.model_fps != read_fps:
            # The file was resampled for the previous model's fps, but frames are
            # played back (and the duration computed) with the new model's fps.
            read_fps = session.model_fps
            motion_dict, num_joints_motion = _read(read_fps)

    joints_pos = motion_dict["posed_joints"].to(device=demo.device, dtype=torch.float32)
    joints_rot = motion_dict["global_rot_mats"].to(device=demo.device, dtype=torch.float32)
    foot_contacts = motion_dict.get("foot_contacts")
    if foot_contacts is not None:
        foot_contacts = foot_contacts.to(device=demo.device, dtype=torch.float32)

    # Drop a leading batch dimension if present (take the first sample).
    if joints_pos.ndim == 4:
        joints_pos = joints_pos[0]
    if joints_rot.ndim == 5:
        joints_rot = joints_rot[0]
    if foot_contacts is not None and foot_contacts.ndim == 3:
        foot_contacts = foot_contacts[0]

    num_joints_loaded = joints_pos.shape[1]
    num_joints_skeleton = session.skeleton.nbjoints
    if num_joints_loaded != num_joints_skeleton:
        if (
            num_joints_loaded == 30
            and num_joints_skeleton == 77
            and isinstance(session.skeleton, SOMASkeleton77)
        ):
            # Convert a 30-joint SOMA motion to the 77-joint skeleton via local
            # rotations + forward kinematics.
            from kimodo.skeleton import global_rots_to_local_rots

            skel30 = SOMASkeleton30().to(demo.device)
            if "local_rot_mats" in motion_dict:
                local_rot_30 = motion_dict["local_rot_mats"].to(device=demo.device, dtype=torch.float32)
                # Rotation matrices carry two trailing (3, 3) dims, so a batched
                # tensor is 5-D -- same rule as for global_rot_mats above. (The
                # previous ``ndim == 4`` test kept only frame 0 of unbatched data.)
                if local_rot_30.ndim == 5:
                    local_rot_30 = local_rot_30[0]
            else:
                local_rot_30 = global_rots_to_local_rots(joints_rot, skel30)
            local_rot_77 = skel30.to_SOMASkeleton77(local_rot_30)
            root_positions = joints_pos[:, skel30.root_idx, :]
            joints_rot, joints_pos, _ = session.skeleton.fk(local_rot_77, root_positions)

            # Expand 4 contact channels to 6 by duplicating channels 1 and 3.
            if foot_contacts is not None and foot_contacts.shape[-1] == 4:
                foot_contacts = torch.cat(
                    [
                        foot_contacts[..., :2],
                        foot_contacts[..., 1:2],
                        foot_contacts[..., 2:4],
                        foot_contacts[..., 3:4],
                    ],
                    dim=-1,
                )
        else:
            raise ValueError(
                f"The loaded motion has {num_joints_loaded} joints but the current model "
                f"({session.model_name}) has {num_joints_skeleton} joints. "
                "Load a motion generated with the same skeleton, or switch the model to match the motion."
            )
    elif joints_rot.shape[1] != num_joints_skeleton:
        raise ValueError(
            f"Rotation data has {joints_rot.shape[1]} joints but the current model has "
            f"{num_joints_skeleton} joints. The NPZ may be corrupted or from a different skeleton."
        )

    # Optionally project G1 motions onto real-robot joint rotations.
    if (
        "g1" in session.model_name
        and isinstance(session.skeleton, G1Skeleton34)
        and gui_real_robot_rotations_checkbox.value
    ):
        joints_pos, joints_rot = generation.apply_g1_real_robot_projection(
            session.skeleton, joints_pos, joints_rot
        )

    num_frames = joints_pos.shape[0]
    duration = num_frames / session.model_fps

    session.cur_duration = duration
    session.max_frame_idx = num_frames - 1

    # Replace whatever is displayed with the loaded motion and rewind to frame 0.
    demo.clear_motions(client.client_id)
    demo.add_character_motion(
        client,
        session.skeleton,
        joints_pos,
        joints_rot,
        foot_contacts,
    )

    demo.set_frame(client.client_id, 0)
|
|
@gui_load_motion_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load the motion at the path in the text box, reporting progress in one notification."""
    requester = event.client
    session = get_active_session(requester)
    if session is None:
        return

    motion_path = gui_load_motion_path_text.value
    notif = requester.add_notification(
        title="Loading motion...",
        body=f"Loading from {motion_path}",
        loading=True,
        with_close_button=False,
        auto_close_seconds=None,
    )

    def _settle(title, body, close_after, color):
        # Turn the spinner notification into a closable final status message.
        notif.title = title
        notif.body = body
        notif.loading = False
        notif.with_close_button = True
        notif.auto_close_seconds = close_after
        notif.color = color

    try:
        load_motion(requester, motion_path)
        _settle(
            "Motion loaded!",
            f"Loaded motion from {motion_path} "
            f"({session.max_frame_idx + 1} frames, {session.cur_duration:.2f}s)",
            5.0,
            "green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        _settle("Failed to load motion!", str(exc), 10.0, "red")
|
|
def save_constraints(client, save_path):
    """Serialize the client's current timeline constraints to ``save_path``.

    Constraints are computed for the full current motion length
    (``max_frame_idx + 1`` frames) using the session's active model.
    """
    session = demo.client_sessions[client.client_id]
    frame_count = session.max_frame_idx + 1
    bundle = demo.load_model(session.model_name)
    save_constraints_lst(
        save_path,
        demo.compute_model_constraints_lst(session, bundle, frame_count),
    )
|
|
@gui_save_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Save constraints to the path in the text box and notify the user of the outcome."""
    requester = event.client
    if get_active_session(requester) is None:
        return

    try:
        target_path = gui_save_constraints_path_text.value
        save_constraints(requester, target_path)
        requester.add_notification(
            title="Constraints saved!",
            body=f"Saved constraints to {target_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to save constraints!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
def load_constraints(client, load_path):
    """Replace the client's timeline constraints with those stored in ``load_path``.

    Existing constraint tracks, timeline keyframes/intervals and their bookkeeping
    entries in ``session.timeline_data`` are cleared first. For each loaded
    constraint, runs of consecutive frame indices become timeline intervals and
    isolated frames become keyframes, on the "2D Root", "Full-Body" or one of the
    end-effector tracks.

    Args:
        client: viser client handle whose session is updated.
        load_path: path of the saved constraints file.

    Raises:
        KeyError: for an unsupported constraint type, or an end-effector
            constraint with no recognized joint.
    """
    session = demo.client_sessions[client.client_id]
    constraints_lst = load_constraints_lst(load_path, skeleton=session.skeleton)

    with session.timeline_data["keyframe_update_lock"]:
        for constraint in list(session.constraints.values()):
            constraint.clear()
        # Also drop the id -> metadata entries for the keyframes/intervals being
        # removed (apply_model_selection resets these too); otherwise stale ids
        # remain in session.timeline_data after the timeline is cleared.
        session.timeline_data["keyframes"].clear()
        session.timeline_data["intervals"].clear()
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()

    device = demo.device
    for constraint_obj in constraints_lst:
        constraint_type = constraint_obj.name

        # Split frame indices into contiguous runs (-> intervals) and lone frames
        # (-> keyframes). The *_indices lists are positions into the constraint's
        # per-frame tensors, the others are timeline frame numbers.
        (
            intervals,
            intervals_indices,
            single_frames,
            single_frames_indices,
        ) = extract_intervals_and_singles(constraint_obj.frame_indices)

        load_targets: list[dict] = []
        root_pos = None

        if constraint_type == "root2d":
            # Lift the 2D root path to 3D: its two columns become x and z, y stays 0.
            smooth_root_2d = constraint_obj.smooth_root_2d
            root_pos = torch.zeros(smooth_root_2d.shape[0], 3, device=device)
            root_pos[:, 0] = smooth_root_2d[:, 0]
            root_pos[:, 2] = smooth_root_2d[:, 1]
            load_targets = [
                {
                    "track_name": "2D Root",
                    "constraint_track": session.constraints["2D Root"],
                }
            ]
        elif constraint_type == "fullbody":
            load_targets = [
                {
                    "track_name": "Full-Body",
                    "constraint_track": session.constraints["Full-Body"],
                }
            ]
        elif constraint_type in {"left-hand", "right-hand", "left-foot", "right-foot"}:
            track_name = {
                "left-hand": "Left Hand",
                "right-hand": "Right Hand",
                "left-foot": "Left Foot",
                "right-foot": "Right Foot",
            }[constraint_type]
            load_targets = [
                {
                    "track_name": track_name,
                    "constraint_track": session.constraints["End-Effectors"],
                    "joint_names": constraint_obj.joint_names,
                    "end_effector_type": constraint_type,
                }
            ]
        elif constraint_type in {"end-effector", "end-effectors"}:
            # A combined end-effector constraint is split into one target per
            # recognized limb joint; "Hips" is carried along when present.
            joint_names_set = set(constraint_obj.joint_names)
            for jname, track_name, eff_type in [
                ("LeftHand", "Left Hand", "left-hand"),
                ("RightHand", "Right Hand", "right-hand"),
                ("LeftFoot", "Left Foot", "left-foot"),
                ("RightFoot", "Right Foot", "right-foot"),
            ]:
                if jname not in joint_names_set:
                    continue
                target_joint_names = [jname]
                if "Hips" in joint_names_set:
                    target_joint_names.append("Hips")
                load_targets.append(
                    {
                        "track_name": track_name,
                        "constraint_track": session.constraints["End-Effectors"],
                        "joint_names": target_joint_names,
                        "end_effector_type": eff_type,
                    }
                )
            if not load_targets:
                raise KeyError(
                    "No recognized end-effector joint in constraint "
                    f"joint_names={constraint_obj.joint_names}"
                )
        else:
            raise KeyError(f"Unsupported constraint type in loader: {constraint_type}")

        def _track_values(sel, target):
            # Constraint payload for an index or slice ``sel``; the trailing
            # arguments expected by the track depend on the constraint type.
            if constraint_type == "root2d":
                return (root_pos[sel],)
            positions = constraint_obj.global_joints_positions[sel]
            rotations = constraint_obj.global_joints_rots[sel]
            if constraint_type == "fullbody":
                return (positions, rotations)
            return (positions, rotations, target["joint_names"], target["end_effector_type"])

        for target in load_targets:
            track_id = session.timeline_data["tracks_ids"][target["track_name"]]
            constraint_track = target["constraint_track"]

            for (start_idx, end_idx), (start_idx_t, end_idx_t) in zip(intervals, intervals_indices):
                interval_id = client.timeline.add_interval(track_id, start_idx, end_idx)
                session.timeline_data["intervals"][interval_id] = {
                    "track_id": track_id,
                    "start_frame_idx": start_idx,
                    "end_frame_idx": end_idx,
                    "locked": False,
                    "opacity": 1.0,
                    "value": None,
                }
                constraint_track.add_interval(
                    interval_id,
                    start_idx,
                    end_idx,
                    *_track_values(slice(start_idx_t, end_idx_t + 1), target),
                )

            for frame, frame_t in zip(single_frames, single_frames_indices):
                keyframe_id = client.timeline.add_keyframe(track_id, frame)
                session.timeline_data["keyframes"][keyframe_id] = {
                    "track_id": track_id,
                    "frame": frame,
                    "locked": False,
                    "opacity": 1.0,
                    "value": None,
                }
                constraint_track.add_keyframe(
                    keyframe_id,
                    frame,
                    *_track_values(frame_t, target),
                )
|
|
@gui_load_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load constraints from the path in the text box, then refresh overlay visibility."""
    requester = event.client
    if get_active_session(requester) is None:
        return

    try:
        constraints_file = gui_load_constraints_path_text.value
        load_constraints(requester, constraints_file)
        apply_constraint_overlay_visibility(demo.client_sessions[requester.client_id])
        requester.add_notification(
            title="Constraints loaded!",
            body=f"Loaded constraints from {constraints_file}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to load constraints!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
# Build the "Exports" panel: screenshot, video and motion-download controls.
with client.gui.add_folder("Exports", expand_by_default=False):
    # The screenshot and video folders are hidden when HF_MODE is set.
    with client.gui.add_folder("Screenshot", expand_by_default=False, visible=not HF_MODE):
        gui_screenshot_path_text = client.gui.add_text(
            "Save Path",
            initial_value="render.png",
            hint="Filename for the screenshot (PNG).",
        )
        gui_screenshot_button = client.gui.add_button(
            "Download Screenshot",
            hint="Capture the current canvas and download a PNG.",
        )
    with client.gui.add_folder("Video", expand_by_default=False, visible=not HF_MODE):
        gui_video_path_text = client.gui.add_text(
            "Save Path",
            initial_value="render.mp4",
            hint="Filename for the video (MP4).",
        )
        gui_video_button = client.gui.add_button(
            "Download Video",
            hint="Render every frame and download as MP4.",
        )
    with client.gui.add_folder("Motion", expand_by_default=True):
        gui_download_name_text = client.gui.add_text(
            "Name",
            initial_value="output",
            hint="Base filename to save as (extension will be added based on format if omitted).",
        )
        # Initial format choices by model family: G1 -> NPZ/CSV, SMPL-X -> NPZ/AMASS NPZ,
        # anything else -> NPZ/BVH. These are the same lists as
        # _get_motion_export_formats further down, which refreshes the options when
        # the loaded model changes; keep the two in sync.
        # NOTE(review): model_name is assumed to be the initially selected model's
        # name from the enclosing scope -- confirm.
        gui_download_format_dropdown = client.gui.add_dropdown(
            "Format",
            options=(
                ["NPZ", "CSV"]
                if "g1" in model_name.lower()
                else ["NPZ", "AMASS NPZ"]
                if "smplx" in model_name.lower()
                else ["NPZ", "BVH"]
            ),
            initial_value="NPZ",
        )
        gui_download_button = client.gui.add_button(
            "Download",
            hint="Download the current motion (format + name above).",
        )
|
|
def _download_bytes_to_browser(
    event_client: viser.ClientHandle,
    *,
    data: bytes,
    filename: str,
    mime_type: str = "application/octet-stream",
) -> None:
    """Trigger a browser download for an in-memory byte payload.

    The payload is base64-embedded in a JavaScript snippet that the client runs:
    it rebuilds the bytes as a Blob, clicks a temporary ``<a download>`` link and
    releases the object URL afterwards.

    Important: this intentionally does NOT use `showSaveFilePicker()` to avoid
    Chrome/Edge's file-write permission prompt ("this site can see edits you make").
    If you want "always ask where to save", configure your browser download settings.

    Args:
        event_client: client whose browser performs the download.
        data: raw file contents.
        filename: suggested download filename.
        mime_type: MIME type attached to the Blob.
    """
    import base64
    import json

    from viser import _messages as _viser_messages

    # json.dumps produces properly escaped JS string literals, so user-provided
    # filenames cannot break out of the script.
    b64 = base64.b64encode(data).decode("ascii")
    js = f"""
(() => {{
  const filename = {json.dumps(filename)};
  const mimeType = {json.dumps(mime_type)};
  const b64 = {json.dumps(b64)};

  // Decode base64 -> Uint8Array.
  const binStr = atob(b64);
  const bytes = new Uint8Array(binStr.length);
  for (let i = 0; i < binStr.length; i++) bytes[i] = binStr.charCodeAt(i);
  const blob = new Blob([bytes], {{ type: mimeType }});

  // Standard browser download behavior.
  const url = URL.createObjectURL(blob);
  const a = document.createElement("a");
  a.href = url;
  a.download = filename;
  document.body.appendChild(a);
  a.click();
  a.remove();
  // click() starts the download asynchronously; revoking the URL right away can
  // cancel it in some browsers, so release it later instead.
  setTimeout(() => URL.revokeObjectURL(url), 40000);
}})();
"""
    # NOTE: relies on viser's private websocket interface/message types.
    event_client.gui._websock_interface.queue_message(
        _viser_messages.RunJavascriptMessage(source=js)
    )
|
|
def _motion_to_npz_bytes(motion) -> bytes:
    """Serialize ``motion`` into Kimodo NPZ file bytes."""
    return kimodo_npz_to_bytes(_motion_to_numpy_dict(motion))
|
|
def _motion_to_csv_bytes(motion, session: ClientSession) -> bytes:
    """Serialize ``motion`` into G1 CSV bytes using the session's skeleton."""
    return g1_csv_to_bytes(_motion_to_numpy_dict(motion), session.skeleton, demo.device)
|
|
def _motion_to_amass_npz_bytes(motion, session: ClientSession) -> bytes:
    """Serialize ``motion`` into AMASS-style NPZ bytes at the session's model fps."""
    return amass_npz_to_bytes(
        _motion_to_numpy_dict(motion),
        session.skeleton,
        session.model_fps,
    )
|
|
| def _get_motion_export_formats(loaded_model_name: str) -> list[str]: |
| model_name_lower = (loaded_model_name or "").lower() |
| if "g1" in model_name_lower: |
| return ["NPZ", "CSV"] |
| if "smplx" in model_name_lower: |
| return ["NPZ", "AMASS NPZ"] |
| return ["NPZ", "BVH"] |
|
|
def _update_format_dropdown(dropdown, loaded_model_name: str) -> None:
    """Reset ``dropdown`` options to the formats valid for ``loaded_model_name``.

    The previous selection is kept when it is still offered; otherwise the first
    option is selected.
    """
    allowed = _get_motion_export_formats(loaded_model_name)
    previous = str(dropdown.value)
    dropdown.options = allowed
    if previous in allowed:
        dropdown.value = previous
    else:
        dropdown.value = allowed[0]
|
|
def _update_motion_export_dropdown(loaded_model_name: str) -> None:
    """Refresh both motion-format dropdowns (download and save) for the loaded model."""
    for dropdown in (gui_download_format_dropdown, gui_save_motion_format_dropdown):
        _update_format_dropdown(dropdown, loaded_model_name)
|
|
| def _coerce_download_filename(raw_name: str, *, ext: str) -> str: |
| """Coerce a user-entered filename to a safe basename with the desired extension. |
| |
| - If empty: uses "output{ext}" |
| - If no extension: appends ext |
| - If endswith a known export extension: rewrites extension to ext (prevents mismatches) |
| - Any provided directory components are stripped |
| """ |
| import os |
|
|
| name = (raw_name or "").strip() |
| name = os.path.basename(name.replace("\\", "/")) |
| if name == "": |
| return f"output{ext}" |
|
|
| known_exts = (".npz", ".bvh", ".csv", ".png", ".mp4") |
| lower = name.lower() |
| if lower.endswith(known_exts): |
| return os.path.splitext(name)[0] + ext |
|
|
| root, cur_ext = os.path.splitext(name) |
| if cur_ext == "": |
| return name + ext |
| return name |
|
|
def _get_render_size(event_client: viser.ClientHandle) -> tuple[int, int]:
    """Return the client's camera image size as ``(width, height)``.

    Falls back to ``(1280, 720)`` if either reported dimension is not positive.
    """
    camera = event_client.camera
    size = (int(camera.image_width), int(camera.image_height))
    return size if min(size) > 0 else (1280, 720)
|
|
| def _round_up_to_multiple(value: int, multiple: int) -> int: |
| if multiple <= 0: |
| return value |
| return ((value + multiple - 1) // multiple) * multiple |
|
|
def _download_canvas_to_browser(event_client: viser.ClientHandle, *, filename: str) -> None:
    """Save the largest ``<canvas>`` on the client page as a PNG download.

    The capture happens entirely in the browser (``canvas.toBlob``), which avoids
    a server-side render.

    Args:
        event_client: client whose page is captured.
        filename: suggested download filename.
    """
    import json

    from viser import _messages as _viser_messages

    js = f"""
(() => {{
  const filename = {json.dumps(filename)};
  const canvases = Array.from(document.querySelectorAll("canvas"));
  if (!canvases.length) {{
    console.error("No canvases found to save.");
    return;
  }}
  // Pick the largest canvas by area (usually the main 3D view).
  const canvas = canvases.reduce((best, cur) => {{
    const bestArea = (best?.width || 0) * (best?.height || 0);
    const curArea = (cur?.width || 0) * (cur?.height || 0);
    return curArea > bestArea ? cur : best;
  }}, null);
  if (!canvas) {{
    console.error("No canvas selected to save.");
    return;
  }}
  canvas.toBlob((blob) => {{
    if (!blob) {{
      console.error("Export failed");
      return;
    }}
    const url = URL.createObjectURL(blob);
    const a = document.createElement("a");
    a.href = url;
    a.download = filename;
    document.body.appendChild(a);
    a.click();
    a.remove();
    // click() starts the download asynchronously; revoking the URL right away
    // can cancel it in some browsers, so release it later instead.
    setTimeout(() => URL.revokeObjectURL(url), 40000);
  }}, "image/png");
}})();
"""
    # NOTE: relies on viser's private websocket interface/message types.
    event_client.gui._websock_interface.queue_message(
        _viser_messages.RunJavascriptMessage(source=js)
    )
|
|
@gui_screenshot_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Download a PNG screenshot of the 3D view under the name in the text box."""
    requester = event.client
    if get_active_session(requester) is None:
        return

    try:
        png_name = _coerce_download_filename(str(gui_screenshot_path_text.value), ext=".png")
        _download_canvas_to_browser(requester, filename=png_name)
        requester.add_notification(
            title="Screenshot download started",
            body=f"Saving {png_name}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to download screenshot!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
@gui_video_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Render every frame of the current motion and download it as an MP4.

    Constraints are disabled on the timeline while recording and re-enabled
    afterwards. The playhead is returned to its original frame once rendering
    ends, including when a render fails part-way through.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    recording_notification: viser.NotificationHandle | None = None
    try:
        recording_notification = event_client.add_notification(
            title="Recording video...",
            body="Saving frames, please wait.",
            loading=True,
            with_close_button=False,
            auto_close_seconds=None,
            color="blue",
        )
        event_client.timeline.disable_constraints()
        width, height = _get_render_size(event_client)
        # Round up to a multiple of 16 -- presumably so the H.264 encoder accepts
        # the frame size.
        width = _round_up_to_multiple(width, 16)
        height = _round_up_to_multiple(height, 16)

        original_frame = session.frame_idx
        frames = []
        try:
            for frame_idx in range(session.max_frame_idx + 1):
                demo.set_frame(
                    event_client.client_id,
                    frame_idx,
                    update_timeline=True,
                )
                frames.append(
                    event_client.get_render(
                        height=height,
                        width=width,
                        transport_format="jpeg",
                    )
                )
        finally:
            # Put the playhead back where the user left it, even on failure.
            demo.set_frame(event_client.client_id, original_frame)

        import imageio.v3 as iio

        filename = _coerce_download_filename(
            str(gui_video_path_text.value),
            ext=".mp4",
        )
        payload = iio.imwrite(
            "<bytes>",
            frames,
            extension=".mp4",
            fps=float(session.model_fps),
            codec="h264",
            plugin="pyav",
        )
        event_client.send_file_download(filename, payload, save_immediately=True)
        event_client.add_notification(
            title="Video download started",
            body=f"Saving {filename}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to download video!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
    finally:
        event_client.timeline.enable_constraints()
        if recording_notification is not None:
            recording_notification.remove()
|
|
@gui_download_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Download the current motion in the format chosen in the Format dropdown.

    BVH, CSV (G1), AMASS NPZ (SMPL-X) or Kimodo NPZ (default); the filename's
    extension is coerced to match the format.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    try:
        # Fetched inside the try so a missing/invalid motion is reported to the
        # user instead of escaping the callback.
        motion = _get_primary_motion(session)
        if motion is None:
            raise ValueError("No motion to download.")
        fmt = str(gui_download_format_dropdown.value).upper()
        raw_name = str(gui_download_name_text.value)

        if fmt == "BVH":
            filename = _coerce_download_filename(raw_name, ext=".bvh")
            payload = motion_to_bvh_bytes(
                motion.joints_local_rot,
                motion.joints_pos[:, session.skeleton.root_idx, :],
                skeleton=session.skeleton,
                fps=float(session.model_fps),
            )
            mime = "text/plain"
        elif fmt == "CSV":
            filename = _coerce_download_filename(raw_name, ext=".csv")
            payload = _motion_to_csv_bytes(motion, session)
            mime = "text/csv"
        elif fmt == "AMASS NPZ":
            filename = _coerce_download_filename(raw_name, ext=".npz")
            payload = _motion_to_amass_npz_bytes(motion, session)
            mime = "application/octet-stream"
        else:
            # Default: Kimodo NPZ.
            filename = _coerce_download_filename(raw_name, ext=".npz")
            payload = _motion_to_npz_bytes(motion)
            mime = "application/octet-stream"

        _download_bytes_to_browser(
            event_client,
            data=payload,
            filename=filename,
            mime_type=mime,
        )

        event_client.add_notification(
            title="Download started",
            body=f"Saving {filename}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to download motion!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
@gui_save_example_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Save the current constraints, motion and generation settings as a new example.

    The directory from the save-path text box must not exist yet. It receives
    ``constraints.json``, ``motion.npz`` and ``meta.json``, and the examples
    dropdown is then reloaded.
    """
    from kimodo.tools import save_json

    requester = event.client
    session = get_active_session(requester)
    if session is None:
        return

    example_dir = gui_save_example_path_text.value
    if os.path.exists(example_dir):
        requester.add_notification(
            title="Failed to save example!",
            body="Example directory already exists",
            auto_close_seconds=10.0,
            color="red",
        )
        return

    try:
        os.makedirs(example_dir)
        save_constraints(requester, os.path.join(example_dir, "constraints.json"))
        save_motion(requester, os.path.join(example_dir, "motion.npz"), "NPZ")

        ordered = sorted(client.timeline._prompts.values(), key=lambda p: p.start_frame)
        last_idx = len(ordered) - 1
        texts = [p.text for p in ordered]
        # Each prompt spans end_frame - start_frame frames; the last one also
        # counts its end frame.
        durations = [
            (p.end_frame - p.start_frame + (1 if idx == last_idx else 0)) / session.model_fps
            for idx, p in enumerate(ordered)
        ]

        if len(texts) == 1:
            meta_info = {"text": texts[0], "duration": durations[0]}
        else:
            meta_info = {"texts": texts, "durations": durations}
        meta_info.update(
            num_samples=gui_num_samples_slider.value,
            seed=gui_seed.value,
            diffusion_steps=gui_diffusion_steps_slider.value,
            cfg={
                "enabled": gui_cfg_checkbox.value,
                "text_weight": gui_cfg_text_weight_slider.value,
                "constraint_weight": gui_cfg_constraint_weight_slider.value,
            },
        )
        save_json(os.path.join(example_dir, "meta.json"), meta_info)

        # Reload so the new example shows up in the dropdown.
        session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)
        update_examples_dropdown(session.example_dict, keep_selection=True)

        requester.add_notification(
            title="Example saved!",
            body=f"Saved example to {example_dir}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to save example!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
def set_new_duration(client_id, new_duration):
    """Set the session's motion duration and keep dependent state consistent.

    Updates the duration GUI, recomputes the last valid frame index from
    ``duration * fps`` and moves the playhead back if it is now past the end.

    Args:
        client_id: id of the client whose session is updated.
        new_duration: new duration in seconds.
    """
    session = demo.client_sessions[client_id]
    session.cur_duration = new_duration
    update_duration_gui(new_duration)
    # duration * fps can land just below an integer (4.1 * 30 == 122.99999999999999);
    # plain truncation would then drop a frame, so floor with a small epsilon.
    num_frames = math.floor(session.cur_duration * session.model_fps + 1e-6)
    session.max_frame_idx = max(num_frames - 1, 0)
    if session.frame_idx > session.max_frame_idx:
        demo.set_frame(client_id, session.max_frame_idx)
|
|
def apply_model_selection(new_model_name: str) -> None:
    """Switch this client's session to ``new_model_name``.

    Loads the model, discards current motions and constraints (rebuilding the
    constraint tracks for the new skeleton), keeps the current duration in
    seconds, rescales timeline zoom and prompt frame ranges from the old fps to
    the new one, and points the example paths at the new model's examples. Does
    nothing if the model is already active.
    """
    session = demo.client_sessions[client_id]
    if new_model_name == session.model_name:
        return

    session.playing = False

    # Capture timing-related state before switching so it can be rescaled.
    old_model_fps = session.model_fps
    old_duration = session.cur_duration
    old_prompts = [
        (prompt.text, prompt.start_frame, prompt.end_frame) for prompt in client.timeline._prompts.values()
    ]
    old_default_zoom_frames = client.timeline._default_num_frames_zoom
    old_max_zoom_frames = client.timeline._max_frames_zoom
    # Same positivity guard as load_motion; also tolerates an unset (None) fps.
    old_fps_valid = bool(old_model_fps) and old_model_fps > 0

    model_bundle = demo.load_model(new_model_name)

    # Drop everything tied to the previous skeleton.
    if session.edit_mode and session.motions:
        exit_editing_mode(session)
        session.edit_mode = False
    demo.clear_motions(client_id)
    with session.timeline_data["keyframe_update_lock"]:
        for constraint in list(session.constraints.values()):
            constraint.clear()
        session.constraints = demo.build_constraint_tracks(client, model_bundle.skeleton)
        session.timeline_data["keyframes"] = {}
        session.timeline_data["intervals"] = {}
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()

    session.model_name = new_model_name
    session.model_fps = model_bundle.model_fps
    session.skeleton = model_bundle.skeleton
    session.motion_rep = model_bundle.motion_rep
    session.cur_duration = old_duration
    # duration * fps can land just below an integer; floor with a small epsilon
    # so float noise does not drop the last frame.
    num_frames = math.floor(session.cur_duration * session.model_fps + 1e-6)
    session.max_frame_idx = max(num_frames - 1, 0)
    session.frame_idx = 0
    session.edit_mode = False

    demo.set_timeline_defaults(client.timeline, session.model_fps)
    client.timeline.set_current_frame(0)
    gui_model_fps.value = session.model_fps
    update_duration_gui(session.cur_duration)

    # Keep the same zoom window in seconds.
    if old_fps_valid:
        default_zoom_seconds = old_default_zoom_frames / old_model_fps
        max_zoom_seconds = old_max_zoom_frames / old_model_fps
        new_default_zoom = max(1, int(round(default_zoom_seconds * session.model_fps)))
        new_max_zoom = max(new_default_zoom, int(round(max_zoom_seconds * session.model_fps)))
        client.timeline.set_zoom_settings(
            default_num_frames_zoom=new_default_zoom,
            max_frames_zoom=new_max_zoom,
        )

    # Re-add the prompts at the same times (in seconds), clamped to the timeline.
    client.timeline.clear_prompts()
    if old_prompts and old_fps_valid:
        for i, (prompt_text, start_frame, end_frame) in enumerate(old_prompts):
            new_start = int(round(start_frame / old_model_fps * session.model_fps))
            new_end = int(round(end_frame / old_model_fps * session.model_fps))
            new_start = max(0, min(new_start, session.max_frame_idx))
            new_end = max(new_start, min(new_end, session.max_frame_idx))
            color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
            client.timeline.add_prompt(prompt_text, new_start, new_end, color=color)

    session.examples_base_dir = demo.get_examples_base_dir(new_model_name, absolute=True)
    session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)
    update_examples_dropdown(session.example_dict, keep_selection=False)
    default_example_path = os.path.join(session.examples_base_dir, "custom_example_1")
    gui_save_example_path_text.value = default_example_path
    gui_load_example_path_text.value = default_example_path

    # Show the new skeleton in its default pose.
    demo.add_character_motion(client, session.skeleton)
    apply_constraint_overlay_visibility(session)
|
|
def _update_version_and_display_from_dataset_skeleton() -> None:
    """Fill the version selector with Kimodo models for the chosen dataset + skeleton.

    The first model is selected and shown in the model label; the selector is only
    visible when more than one version exists. Nothing happens if the skeleton
    label is unknown or no models match.
    """
    dataset_ui = gui_dataset_selector.value
    skeleton_key = get_skeleton_key_from_display_name(gui_skeleton_selector.value)
    if skeleton_key is None:
        return
    candidates = get_models_for_dataset_skeleton(dataset_ui, skeleton_key, family="Kimodo")
    if not candidates:
        return
    names = [candidate.display_name for candidate in candidates]
    gui_version_selector.options = names
    gui_version_selector.value = names[0]
    gui_version_selector.visible = len(names) > 1
    gui_model_display.content = f"**Model:** {names[0]}"
|
|
def _update_visibility_for_loaded_model(loaded_model_name: str) -> None:
    """Show/hide model-specific controls according to the loaded model's name.

    Refreshes the export-format dropdowns; SOMA models get the SOMA-layer toggle;
    G1 models get the real-robot rotations toggle, hide postprocessing/root margin
    and force a hidden, disabled "Local" gizmo space. Empty names are ignored.
    """
    if loaded_model_name:
        _update_motion_export_dropdown(loaded_model_name)
        is_soma = "soma" in loaded_model_name
        is_g1 = "g1" in loaded_model_name
        gui_use_soma_layer_checkbox.visible = is_soma
        gui_real_robot_rotations_checkbox.visible = is_g1
        gui_postprocess_checkbox.visible = not is_g1
        gui_root_margin.visible = not is_g1 and gui_postprocess_checkbox.value
        if is_g1:
            gui_gizmo_space_dropdown.value = "Local"
        gui_gizmo_space_dropdown.visible = not is_g1
        gui_gizmo_space_dropdown.disabled = is_g1
|
|
def _on_load_model_click(event: viser.GuiEvent) -> None:
    """Load the currently selected model (called from the Load model button).

    A single notification tracks progress and is turned into a success or failure
    message. On failure the model selector is reset to the session's model.
    """
    if get_active_session(event.client) is None:
        return
    new_model_name = gui_model_selector.value
    if not new_model_name:
        return
    info = get_model_info(new_model_name)
    if info is None:
        return
    session = demo.client_sessions[event.client.client_id]
    if new_model_name == session.model_name:
        return
    loading_notif = event.client.add_notification(
        title="Loading model...",
        body=f"Loading {info.display_name}",
        loading=True,
        with_close_button=False,
    )
    try:
        apply_model_selection(new_model_name)
        _update_visibility_for_loaded_model(new_model_name)
    except Exception as e:
        import traceback

        traceback.print_exc()
        # Reuse the progress notification so no stale "Loading model..." entry is
        # left behind next to a separate error toast.
        loading_notif.title = "Model failed to load"
        loading_notif.body = str(e)
        loading_notif.loading = False
        loading_notif.with_close_button = True
        loading_notif.auto_close_seconds = 10.0
        loading_notif.color = "red"
        gui_model_selector.set_from_short_key(session.model_name)
    else:
        loading_notif.title = "Model loaded"
        loading_notif.body = f"{info.display_name} is ready."
        loading_notif.loading = False
        loading_notif.with_close_button = True
        loading_notif.auto_close_seconds = 5.0
        loading_notif.color = "green"
|
|
@gui_load_model_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Delegate the Load-model button press to ``_on_load_model_click``."""
    return _on_load_model_click(event)
|
|
@gui_dataset_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Repopulate skeleton choices for the new dataset, then refresh the version list."""
    if get_active_session(event.client) is not None:
        labels = get_allowed_skeleton_labels(gui_dataset_selector.value)
        gui_skeleton_selector.options = labels
        gui_skeleton_selector.value = labels[0] if len(labels) > 0 else ""
        _update_version_and_display_from_dataset_skeleton()
|
|
@gui_skeleton_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Refresh the version choices after the skeleton selection changes."""
    if get_active_session(event.client) is not None:
        _update_version_and_display_from_dataset_skeleton()
|
|
@gui_version_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Refresh the model label from the model selector's current value."""
    if get_active_session(event.client) is None:
        return
    picked = get_model_info(gui_model_selector.value)
    if picked is None:
        return
    gui_model_display.content = f"**Model:** {picked.display_name}"
|
|
@gui_use_soma_layer_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    """Rebuild the character so the SOMA-layer toggle takes effect.

    Only acts for SOMA models. The first current motion (if any) and the current
    frame are preserved. On failure the error is logged and reported, and the
    checkbox is flipped back.
    """
    session = get_active_session(event.client)
    if session is None or "soma" not in (session.model_name or ""):
        return

    loading_notif = event.client.add_notification(
        title="Applying SOMA layer...",
        body="Updating mesh.",
        loading=True,
        with_close_button=False,
    )
    try:
        current_motion = next(iter(session.motions.values())) if session.motions else None
        current_frame_idx = session.frame_idx

        # Recreate the character with the same motion data.
        demo.clear_motions(event.client.client_id)
        if current_motion is None:
            demo.add_character_motion(event.client, session.skeleton)
        else:
            demo.add_character_motion(
                event.client,
                session.skeleton,
                current_motion.joints_pos,
                current_motion.joints_rot,
                current_motion.foot_contacts,
            )

        demo.set_frame(event.client.client_id, current_frame_idx)
    except Exception as e:
        import traceback

        # Log the full traceback, consistent with the other handlers.
        traceback.print_exc()
        event.client.add_notification(
            title="SOMA layer failed",
            body=str(e),
            color="red",
            auto_close_seconds=10.0,
        )
        gui_use_soma_layer_checkbox.value = not gui_use_soma_layer_checkbox.value
    finally:
        loading_notif.loading = False
        loading_notif.with_close_button = True
        loading_notif.auto_close_seconds = 2.0
|
|
@gui_real_robot_rotations_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    """Project the session's G1 motions onto the real robot's DoF when the checkbox is enabled.

    Does nothing unless the active session runs a G1 model with a
    ``G1Skeleton34`` skeleton, has at least one motion, and the checkbox is now
    checked. Unchecking does not undo a previous projection.

    Motions with at most one frame are left untouched, as are motions whose frames
    all match the first frame within 1e-6. Every other motion has its global joint
    positions/rotations replaced by ``generation.apply_g1_real_robot_projection``.
    Its local rotations and mesh data are then recomputed. Finally the current frame
    is re-displayed and a notification is shown.
    """
    session = get_active_session(event.client)
    # model_name may be None; guard the substring test the same way the SOMA handler does.
    if session is None or "g1" not in (session.model_name or ""):
        return
    if not isinstance(session.skeleton, G1Skeleton34) or not session.motions:
        return
    if not gui_real_robot_rotations_checkbox.value:
        return

    from kimodo.skeleton import global_rots_to_local_rots

    def _is_static(motion, tol: float = 1e-6) -> bool:
        """Return True when every frame's joint positions and rotations match frame 0 within *tol*."""
        pos_dev = (motion.joints_pos - motion.joints_pos[0:1]).abs().max().item()
        if pos_dev >= tol:
            return False
        rot_dev = (motion.joints_rot - motion.joints_rot[0:1]).abs().max().item()
        return rot_dev < tol

    current_frame_idx = session.frame_idx
    for motion in session.motions.values():
        # Nothing meaningful to project for single-frame or rest-pose-only motions.
        if motion.length <= 1 or _is_static(motion):
            continue
        new_pos, new_rot = generation.apply_g1_real_robot_projection(
            session.skeleton,
            motion.joints_pos,
            motion.joints_rot,
        )
        motion.joints_pos = new_pos
        motion.joints_rot = new_rot
        # Keep the local rotations consistent with the projected global rotations.
        motion.joints_local_rot = global_rots_to_local_rots(new_rot, session.skeleton)
        motion.precompute_mesh_info()

    demo.set_frame(event.client.client_id, current_frame_idx)
    event.client.add_notification(
        title="Real robot projection applied",
        body="The motion is projected to G1 real robot DoF (1-DoF per joint, clamped to axis limits).",
        auto_close_seconds=4.0,
        color="green",
    )
|
|
def _example_prompt_frame_bounds(durations_sec, fps) -> tuple[list[tuple[int, int]], int]:
    """Convert per-prompt durations (seconds) into timeline ``(start, end)`` frame bounds.

    Each prompt spans at least one frame. A non-final prompt ends on the start
    frame of the next prompt; the final prompt ends on its own last frame.
    Returns the bounds and the total number of frames.
    """
    frame_bounds = []
    num_frames = 0
    for i, duration in enumerate(durations_sec):
        n_frames = max(1, int(round(duration * fps)))
        is_last = i == len(durations_sec) - 1
        end_frame = num_frames + n_frames - 1 if is_last else num_frames + n_frames
        frame_bounds.append((num_frames, end_frame))
        num_frames += n_frames
    return frame_bounds, num_frames


def _apply_example_generation_settings(meta_info: dict) -> None:
    """Copy optional generation settings from an example's meta.json into the GUI controls."""
    if "num_samples" in meta_info:
        gui_num_samples_slider.value = meta_info["num_samples"]
    if "seed" in meta_info:
        gui_seed.value = meta_info["seed"]
    if "diffusion_steps" in meta_info:
        gui_diffusion_steps_slider.value = meta_info["diffusion_steps"]
    if "cfg" in meta_info:
        cfg = meta_info["cfg"]
        if "enabled" in cfg:
            gui_cfg_checkbox.value = cfg["enabled"]
        if "text_weight" in cfg:
            gui_cfg_text_weight_slider.value = cfg["text_weight"]
        if "constraint_weight" in cfg:
            gui_cfg_constraint_weight_slider.value = cfg["constraint_weight"]


def load_example_from_path(
    event_client: viser.ClientHandle,
    example_path: str,
    load_gt: bool = False,
) -> None:
    """Load an example directory (constraints, motion, prompts, settings) into a client.

    The directory may contain:
      - ``constraints.json``: loaded with ``load_constraints``. When it is absent,
        all existing constraints and timeline keyframes/intervals are cleared.
      - ``motion.npz``, or ``gt_motion.npz`` when ``load_gt`` is True: loaded with
        ``load_motion``.
      - ``meta.json``: its prompts replace the timeline prompts. Any generation
        settings it stores (num_samples, seed, diffusion_steps, cfg) are copied
        into the GUI.

    Missing files are skipped. Afterwards the playhead is reset to frame 0.
    Success or failure is reported to the client as a notification. Exceptions
    are printed with their traceback and not re-raised.

    Args:
        event_client: Client whose session and timeline are updated.
        example_path: Path to the example directory.
        load_gt: Load ``gt_motion.npz`` instead of ``motion.npz``.
    """
    from kimodo.meta import parse_prompts_from_meta
    from kimodo.tools import load_json

    session = get_active_session(event_client)
    if session is None:
        return

    # Pause playback before the session contents are replaced.
    session.playing = False

    if not os.path.isdir(example_path):
        event_client.add_notification(
            title="Example path not found",
            body=f"Directory does not exist: {example_path}",
            auto_close_seconds=5.0,
            color="red",
        )
        return

    try:
        constraints_path = os.path.join(example_path, "constraints.json")
        if os.path.exists(constraints_path):
            load_constraints(event_client, constraints_path)
        else:
            # No saved constraints: start from an empty constraint set.
            with session.timeline_data["keyframe_update_lock"]:
                for constraint in list(session.constraints.values()):
                    constraint.clear()
                event_client.timeline.clear_keyframes()
                event_client.timeline.clear_intervals()

        motion_filename = "gt_motion.npz" if load_gt else "motion.npz"
        motion_path = os.path.join(example_path, motion_filename)
        if os.path.exists(motion_path):
            load_motion(event_client, motion_path)

        meta_path = os.path.join(example_path, "meta.json")
        if os.path.exists(meta_path):
            meta_info = load_json(meta_path)
            event_client.timeline.clear_prompts()

            texts, durations_sec = parse_prompts_from_meta(meta_info)
            frame_bounds, num_frames = _example_prompt_frame_bounds(durations_sec, session.model_fps)

            # Zoom so the whole prompt plan is visible with ~10% slack at the end.
            target_visible_frames = int(math.ceil(1.10 * num_frames))
            event_client.timeline.set_zoom_settings(
                default_num_frames_zoom=target_visible_frames,
            )

            for i, (prompt_text, (start_frame, end_frame)) in enumerate(zip(texts, frame_bounds)):
                color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
                event_client.timeline.add_prompt(prompt_text, start_frame, end_frame, color=color)

            update_duration_auto()

            # meta_info only exists when meta.json was found, so settings are restored here.
            _apply_example_generation_settings(meta_info)

        session.frame_idx = 0
        event_client.timeline.set_current_frame(0)
        demo.set_frame(event_client.client_id, 0)

        event_client.add_notification(
            title="Example loaded!",
            body=f"Loaded example from {example_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to load example!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
|
|
def load_qwen_example_plan(
    event_client: viser.ClientHandle,
    target_actions: int = 10,
    max_rounds: int = 8,
    actions_per_request: int = 5,
) -> None:
    """Load a Qwen-generated prompt plan (10 actions by default) into the timeline.

    This preserves the native UI flow:
    1) Load Example -> fills timeline prompt segments
    2) Generate -> synthesizes motion from loaded prompts

    Prompts are requested from ``call_qwen_for_prompts`` on a background daemon
    thread, so this function returns immediately. The outcome is reported to the
    client as a notification. Requests stop once ``target_actions`` segments are
    collected or after ``max_rounds`` calls, whichever comes first.

    Args:
        event_client: Client whose timeline receives the prompt segments.
        target_actions: Number of prompt segments to collect.
        max_rounds: Maximum number of Qwen requests.
        actions_per_request: Upper bound on the actions requested per call.
    """
    session = get_active_session(event_client)
    if session is None:
        return

    def _collect_prompts() -> tuple[list[str], list[float]]:
        """Query Qwen in batches until enough (text, duration) pairs are gathered."""
        history: list[str] = []
        texts_out: list[str] = []
        durations_out: list[float] = []
        rounds = 0
        while len(texts_out) < target_actions and rounds < max_rounds:
            remaining = target_actions - len(texts_out)
            batch, history = call_qwen_for_prompts(
                scene="Agentic demo: keep one character in continuous motion",
                history=history,
                requested_actions=min(actions_per_request, remaining),
            )
            # Texts without a matching duration (or vice versa) are dropped by zip.
            for text, duration in zip(batch.get("texts", []), batch.get("durations", [])):
                if len(texts_out) >= target_actions:
                    break
                texts_out.append(text)
                durations_out.append(float(duration))
            rounds += 1
        return texts_out, durations_out

    def _thread_fn() -> None:
        try:
            all_texts, all_durations = _collect_prompts()

            if len(all_texts) == 0:
                event_client.add_notification(
                    title="Qwen example load failed",
                    body="No prompt segments were produced.",
                    auto_close_seconds=6.0,
                    color="red",
                )
                return

            fps = session.model_fps
            event_client.timeline.clear_prompts()
            frame_cursor = 0
            for i, (txt, dur) in enumerate(zip(all_texts, all_durations)):
                n_frames = max(1, int(round(dur * fps)))
                start_frame = frame_cursor
                # Non-final segments end on the next segment's start frame; the last ends on its own last frame.
                end_frame = frame_cursor + n_frames if i < len(all_texts) - 1 else frame_cursor + n_frames - 1
                color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
                event_client.timeline.add_prompt(txt, start_frame, end_frame, color=color)
                frame_cursor += n_frames

            # Show the full plan plus ~10% slack, but never zoom in below 60 frames.
            target_visible_frames = int(math.ceil(1.10 * frame_cursor))
            event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames))
            update_duration_auto()

            event_client.add_notification(
                title="Qwen example loaded",
                body=f"Loaded {len(all_texts)} prompt segments. Click Generate to synthesize motion.",
                auto_close_seconds=6.0,
                color="green",
            )
        except Exception as e:
            import traceback

            # Match load_example_from_path: keep the traceback for debugging.
            traceback.print_exc()
            event_client.add_notification(
                title="Qwen example load failed",
                body=str(e),
                auto_close_seconds=8.0,
                color="red",
            )

    threading.Thread(target=_thread_fn, daemon=True).start()
|
|
@gui_load_example_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load the example currently selected in the examples dropdown.

    The Qwen agentic entry (current or legacy name) is special-cased: instead of
    reading a directory from disk, it builds a prompt plan.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return

    selected = gui_examples_dropdown.value
    if selected in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME):
        load_qwen_example_plan(event_client)
        return

    examples = session.example_dict
    if examples and selected in examples:
        load_example_from_path(event_client, examples[selected], gui_load_gt_checkbox.value)
        return

    event_client.add_notification(
        title="No examples available",
        body="No examples found for the selected model.",
        auto_close_seconds=5.0,
        color="red",
    )
|
|
@gui_load_example_from_path_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load an example from the directory typed into the example-path text box."""
    event_client = event.client
    if get_active_session(event_client) is None:
        return

    example_path = gui_load_example_path_text.value
    if example_path:
        load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value)
    else:
        event_client.add_notification(
            title="No example path",
            body="Please provide an example directory.",
            auto_close_seconds=5.0,
            color="red",
        )
|
|
@gui_cfg_checkbox.on_update
def _(_) -> None:
    """Show the CFG weight sliders only while the CFG checkbox is checked."""
    if demo.client_active(client_id):
        enabled = gui_cfg_checkbox.value
        for slider in (gui_cfg_text_weight_slider, gui_cfg_constraint_weight_slider):
            slider.visible = enabled
|
|
def exit_editing_mode(session: ClientSession):
    """Restore the regular (non-editing) UI and character appearance for *session*.

    Re-enables generation and the SOMA-layer toggle, and disables the reset/undo
    buttons. Edit and undo snapshots are discarded. Gizmos are removed from the
    session's first motion, and its skinned mesh is made fully opaque with the
    skeleton hidden. The dense root path is refreshed when the "2D Root"
    constraint uses one.
    """
    # Buttons back to their viewing-mode state.
    gui_edit_constraint_button.label = "Enter Editing Mode"
    gui_generate_button.disabled = False
    gui_generate_button.label = "Generate"
    gui_reset_constraint_button.disabled = True

    # G1 models always use local gizmos, so the space dropdown is pinned and hidden for them.
    is_g1 = "g1" in session.model_name
    if is_g1:
        gui_gizmo_space_dropdown.value = "Local"
    gui_gizmo_space_dropdown.disabled = is_g1
    gui_gizmo_space_dropdown.visible = not is_g1

    gui_undo_drag_button.disabled = True
    gui_use_soma_layer_checkbox.disabled = False
    session.edit_mode_snapshot = None
    session.undo_drag_snapshot = None

    motion = list(session.motions.values())[0]
    motion.clear_all_gizmos()
    character = motion.character
    character.set_skinned_mesh_wireframe(False)
    character.set_skeleton_visibility(False)
    character.set_skinned_mesh_visibility(True)
    character.set_skinned_mesh_opacity(1.0)
    session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 1.0

    constraints = session.constraints
    if "2D Root" in constraints and constraints["2D Root"].dense_path:
        _update_dense_path(motion, session)

    gui_viz_skinned_mesh_checkbox.value = True
    gui_viz_skeleton_checkbox.value = False
|
|
| |
@gui_edit_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Toggle constraint editing mode for the active session.

    Entering editing mode requires exactly one motion. It disables generation and
    the SOMA-layer toggle, snapshots the pose at the current frame for the reset
    button, and shows the skeleton over a 65%-opaque mesh. It also attaches
    root-translation and joint gizmos to the motion. Leaving editing mode
    delegates to ``exit_editing_mode``.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return

    # Validate before mutating any state, so a rejected request leaves edit_mode and the
    # Generate button untouched (an assert would also be stripped under ``python -O``).
    if not session.edit_mode and len(session.motions) != 1:
        event_client.add_notification(
            title="Cannot enter editing mode",
            body=f"Editing mode requires exactly one motion (found {len(session.motions)}).",
            auto_close_seconds=5.0,
            color="red",
        )
        return

    session.edit_mode = not session.edit_mode

    edit_alert = "Entered editing mode"
    no_edit_alert = "Exited editing mode"
    edit_message = "You can now modify pose or path constraints."
    no_edit_message = "Can now generate motions."
    event_client.add_notification(
        title=edit_alert if session.edit_mode else no_edit_alert,
        body=edit_message if session.edit_mode else no_edit_message,
        auto_close_seconds=10.0,
        color="blue",
    )

    if session.edit_mode:
        gui_edit_constraint_button.label = "Exit Editing Mode"
        gui_generate_button.disabled = True
        gui_generate_button.label = "Generate Disabled In Editing Mode"
        if "g1" in session.model_name:
            gui_gizmo_space_dropdown.value = "Local"
            gui_gizmo_space_dropdown.disabled = True
        gui_use_soma_layer_checkbox.disabled = True

        motion = list(session.motions.values())[0]
        snapshot_frame_idx = min(session.frame_idx, motion.length - 1)
        # Fresh snapshot store for the reset button, seeded with the current frame.
        session.edit_mode_snapshot = {}
        ensure_edit_snapshot(session, motion, snapshot_frame_idx)
        gui_reset_constraint_button.disabled = False

        motion.character.set_skeleton_visibility(True)
        # Semi-transparent mesh so the skeleton underneath stays visible.
        motion.character.set_skinned_mesh_opacity(0.65)
        session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 0.65
        motion.character.set_skinned_mesh_visibility(True)
        gui_viz_skinned_mesh_checkbox.value = True
        gui_viz_skeleton_checkbox.value = True

        def _on_root2d_gizmo_release():
            # Recompute the dense root path once a 2D-root drag finishes.
            if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path:
                mot = list(session.motions.values())[0]
                _update_dense_path(mot, session)

        def _on_gizmo_drag_start():
            # Remember the pose before the drag so "Undo drag" can restore it.
            mot = list(session.motions.values())[0]
            frame_idx = min(session.frame_idx, mot.length - 1)
            session.undo_drag_snapshot = {
                "frame_idx": frame_idx,
                "joints_pos": mot.get_joints_pos(frame_idx),
                "joints_rot": mot.get_joints_rot(frame_idx),
            }
            gui_undo_drag_button.disabled = False

        motion.add_root_translation_gizmo(
            session.constraints,
            on_2d_root_drag_end=_on_root2d_gizmo_release,
            on_drag_start=_on_gizmo_drag_start,
        )
        # G1 models only get local-space joint gizmos.
        gizmo_space = "local" if "g1" in session.model_name else gui_gizmo_space_dropdown.value.lower()
        motion.add_joint_gizmos(
            session.constraints,
            space=gizmo_space,
            on_drag_start=_on_gizmo_drag_start,
        )
    else:
        exit_editing_mode(session)
|
|
@gui_reset_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Restore the current frame's pose from the snapshot taken in editing mode.

    Does nothing without an active session, a non-empty snapshot store, a
    motion, or a snapshot for the (clamped) current frame.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None or not session.edit_mode_snapshot or not session.motions:
        return

    motion = list(session.motions.values())[0]
    frame_idx = min(session.frame_idx, motion.length - 1)
    if frame_idx not in session.edit_mode_snapshot:
        return

    saved_pose = session.edit_mode_snapshot[frame_idx]
    motion.update_pose_at_frame(
        frame_idx,
        joints_pos=saved_pose["joints_pos"],
        joints_rot=saved_pose["joints_rot"],
    )
    demo.set_frame(event_client.client_id, frame_idx, update_timeline=False)
|
|
@gui_undo_drag_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Undo the most recent gizmo drag by restoring the pose captured when it started.

    The snapshot is single-use: it is cleared and the button disabled afterwards.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None or session.undo_drag_snapshot is None or not session.motions:
        return

    snapshot = session.undo_drag_snapshot
    frame_idx = snapshot["frame_idx"]
    list(session.motions.values())[0].update_pose_at_frame(
        frame_idx,
        joints_pos=snapshot["joints_pos"],
        joints_rot=snapshot["joints_rot"],
    )
    demo.set_frame(event_client.client_id, frame_idx, update_timeline=False)

    session.undo_drag_snapshot = None
    gui_undo_drag_button.disabled = True
|
|
def validate_interval(start_frame_idx: int, end_frame_idx: int, max_frame_idx: int) -> bool:
    """Check that ``[start_frame_idx, end_frame_idx]`` is an ordered range inside ``[0, max_frame_idx]``.

    Both ends must lie in ``[0, max_frame_idx]`` and the end may not precede the start.
    A single frame (``start == end``) is valid.
    """
    violations = (
        start_frame_idx < 0,
        start_frame_idx > max_frame_idx,
        end_frame_idx < 0,
        end_frame_idx > max_frame_idx,
        end_frame_idx < start_frame_idx,
    )
    return not any(violations)
|
|
def clamp_interval_to_range(
    start_frame_idx: int, end_frame_idx: int, max_frame_idx: int
) -> Optional[tuple[int, int]]:
    """Clip ``[start_frame_idx, end_frame_idx]`` to ``[0, max_frame_idx]``.

    Returns the clipped ``(start, end)`` pair. Returns None when the interval lies
    entirely outside the range or is reversed after clipping.
    """
    if end_frame_idx >= 0 and start_frame_idx <= max_frame_idx:
        lo = max(0, start_frame_idx)
        hi = min(max_frame_idx, end_frame_idx)
        if lo <= hi:
            return lo, hi
    return None
|
|
| |
def add_constraint_callback(
    constraint_id: str,
    constraint_type: str,
    frame_range: tuple[int, int],
    joint_names: Optional[list[str]] = None,
    verbose: bool = True,
):
    """Add a keyframe or interval constraint to the session.

    A ``frame_range`` whose two ends differ is added as an interval, clamped to
    ``[0, session.max_frame_idx]``. Equal ends add a single keyframe, which must
    already lie in that range. Out-of-range requests print a message and add
    nothing. The constraint data is sampled from the session's single motion over
    the (clamped) range.

    Args:
        constraint_id: Timeline id of the keyframe/interval. It is passed through
            as ``keyframe_id`` or ``interval_id``.
        constraint_type: Timeline track name. "Left Hand", "Right Hand",
            "Left Foot" and "Right Foot" become "End-Effectors" constraints on that
            joint plus "Hips". Any other name is used directly as a key into
            ``session.constraints``; only "Full-Body", "End-Effectors" and
            "2D Root" get pose data attached.
        frame_range: ``(start, end)`` frame indices (presumably inclusive, matching
            the clamp/validate helpers -- confirm against ``get_joints_pos``).
        joint_names: Joints to constrain for "End-Effectors". Ignored (overridden)
            for the per-limb track names above.
        verbose: Show a "Constraint added" notification once the constraint is added.
    """

    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]

    # NOTE(review): assert is stripped under -O; multiple un-committed samples would trip it.
    assert len(session.motions) == 1, "Only one motion allowed for adding constraints"
    motion = list(session.motions.values())[0]

    # Per-limb tracks map onto a two-joint End-Effectors constraint, e.g.
    # "Left Hand" -> joints ["LeftHand", "Hips"], end_effector_type "left-hand".
    end_effector_type = None
    if constraint_type in [
        "Left Hand",
        "Right Hand",
        "Left Foot",
        "Right Foot",
    ]:
        joint_names = [constraint_type.replace(" ", ""), "Hips"]
        end_effector_type = constraint_type.replace(" ", "-").lower()
        constraint_type = "End-Effectors"

    # Equal ends mean a single-frame keyframe; anything else is an interval.
    is_interval = frame_range[1] != frame_range[0]
    start_frame_idx = int(frame_range[0])
    end_frame_idx = int(frame_range[1])

    # Intervals are clipped to the motion; keyframes must already be in range.
    if is_interval:
        clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx)
        if clamped is None:
            print("Interval outside range! Couldn't add constraint.")
            return
        start_frame_idx, end_frame_idx = clamped
    else:
        if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx):
            print("Invalid interval! Couldn't add constraint.")
            return

    # Keyword names differ between add_interval and add_keyframe.
    if is_interval:
        constraint_kwargs = {
            "interval_id": constraint_id,
            "start_frame_idx": start_frame_idx,
            "end_frame_idx": end_frame_idx,
        }
    else:
        constraint_kwargs = {
            "keyframe_id": constraint_id,
            "frame_idx": start_frame_idx,
        }

    # Sample the constrained quantities from the current motion over the range.
    if constraint_type in ["Full-Body", "End-Effectors"]:
        constraint_kwargs["joints_pos"] = motion.get_joints_pos(start_frame_idx, end_frame_idx)
        constraint_kwargs["joints_rot"] = motion.get_joints_rot(start_frame_idx, end_frame_idx)
        if constraint_type == "End-Effectors":
            constraint_kwargs["joint_names"] = joint_names
            constraint_kwargs["end_effector_type"] = end_effector_type

    elif constraint_type == "2D Root":
        constraint_kwargs["root_pos"] = motion.get_projected_root_pos(start_frame_idx, end_frame_idx)

    # Raises KeyError for a track name that has no matching constraint set.
    constraint = session.constraints[constraint_type]
    if is_interval:
        constraint.add_interval(**constraint_kwargs)
    else:
        constraint.add_keyframe(**constraint_kwargs)

    apply_constraint_overlay_visibility(session)

    if verbose:
        client.add_notification(
            title="Constraint added",
            body="",
            auto_close_seconds=5.0,
            color="blue",
        )
|
|
| |
@client.timeline.on_keyframe_add
def _(keyframe_id: str, track_id: str, frame: int):
    """Called when a keyframe is added to a track.

    Under ``keyframe_update_lock``, this adds a single-frame constraint for the
    track's constraint type. It also records the keyframe in
    ``session.timeline_data["keyframes"]`` so later move/delete events can find it.
    Afterwards, if the "2D Root" constraint has dense path enabled, the dense path
    is recomputed.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    with session.timeline_data["keyframe_update_lock"]:
        # The track's name doubles as the constraint type.
        constraint_type = session.timeline_data["tracks"][track_id]["name"]
        add_constraint_callback(
            keyframe_id,
            constraint_type,
            (frame, frame),
            verbose=False,
        )
        # Mirror viser's keyframe flags (private ``_keyframes`` map), with defaults if it is missing.
        # NOTE(review): the keyframe is recorded even if add_constraint_callback rejected the frame.
        keyframe_data = client.timeline._keyframes.get(keyframe_id)
        session.timeline_data["keyframes"][keyframe_id] = {
            "frame": frame,
            "track_id": track_id,
            "locked": bool(keyframe_data.locked) if keyframe_data is not None else False,
            "opacity": keyframe_data.opacity if keyframe_data is not None else 1.0,
            "value": keyframe_data.value if keyframe_data is not None else None,
        }
    # A new root keyframe changes the dense root path, if one is in use.
    if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path:
        motion = list(session.motions.values())[0]
        _update_dense_path(motion, session)
|
|
@client.timeline.on_interval_add
def handle_interval_add(interval_id: str, track_id: str, start_frame: int, end_frame: int):
    """Called when an interval is added to a track.

    Under ``keyframe_update_lock``, this adds an interval constraint over
    ``[start_frame, end_frame]``. It also records the interval (with its
    unclamped frames) in ``session.timeline_data["intervals"]`` for later
    move/delete events. Afterwards, if the "2D Root" constraint has dense path
    enabled, the dense path is recomputed.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    with session.timeline_data["keyframe_update_lock"]:
        # The track's name doubles as the constraint type.
        constraint_type = session.timeline_data["tracks"][track_id]["name"]
        add_constraint_callback(
            interval_id,
            constraint_type,
            (start_frame, end_frame),
            verbose=False,
        )
        # Mirror viser's interval flags (private ``_intervals`` map), with defaults if it is missing.
        interval_data = client.timeline._intervals.get(interval_id)
        session.timeline_data["intervals"][interval_id] = {
            "track_id": track_id,
            "start_frame_idx": start_frame,
            "end_frame_idx": end_frame,
            "locked": bool(interval_data.locked) if interval_data is not None else False,
            "opacity": interval_data.opacity if interval_data is not None else 1.0,
            "value": interval_data.value if interval_data is not None else None,
        }
    # A new root interval changes the dense root path, if one is in use.
    if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path:
        motion = list(session.motions.values())[0]
        _update_dense_path(motion, session)
|
|
def remove_constraint_callback(
    constraint_id: str,
    constraint_type: str,
    frame_range: tuple[int, int],
    verbose: bool = True,
) -> None:
    """Remove a keyframe or interval constraint from the session.

    Mirrors ``add_constraint_callback``. Differing range ends remove an interval,
    clamped to ``[0, session.max_frame_idx]``; an interval entirely out of range is
    ignored silently. Equal ends remove a keyframe, which must be in range (a
    message is printed otherwise). The per-limb track names are resolved to the
    "End-Effectors" constraint set.

    ``session.updating_motions`` is set to True on every call that reaches the
    session. NOTE(review): it is not reset here -- confirm who clears it.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    session.updating_motions = True

    is_interval = frame_range[1] != frame_range[0]
    first, last = int(frame_range[0]), int(frame_range[1])

    if is_interval:
        clamped = clamp_interval_to_range(first, last, session.max_frame_idx)
        if clamped is None:
            return
        first, last = clamped
    elif not validate_interval(first, last, session.max_frame_idx):
        print("Invalid interval! Couldn't remove constraint.")
        return

    limb_tracks = ("Left Hand", "Right Hand", "Left Foot", "Right Foot")
    resolved_type = "End-Effectors" if constraint_type in limb_tracks else constraint_type

    constraint = session.constraints[resolved_type]
    if is_interval:
        constraint.remove_interval(constraint_id, first, last)
    else:
        constraint.remove_keyframe(constraint_id, first)

    if verbose:
        client.add_notification(
            title="Constraint removed",
            body="",
            auto_close_seconds=5.0,
            color="blue",
        )
|
|
@client.timeline.on_keyframe_move
def handle_keyframe_move(keyframe_id: str, new_frame: int):
    """Called when a keyframe is moved to a new frame.

    The constraint update is debounced; move events are presumably emitted
    repeatedly while dragging. Each call cancels the keyframe's previous timer,
    records ``new_frame`` as the latest pending move, and schedules
    ``_execute_keyframe_move`` on a ``threading.Timer`` after 0.03 s.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]

    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Supersede a move still waiting to fire for this keyframe. cancel() cannot stop a
        # timer that already started; the executor's pending-value check covers that case.
        if keyframe_id in timeline_data["keyframe_move_timers"]:
            timeline_data["keyframe_move_timers"][keyframe_id].cancel()

        # The executor compares its new_frame against this to detect that it is stale.
        timeline_data["pending_keyframe_moves"][keyframe_id] = new_frame

        timer = threading.Timer(
            0.03,
            _execute_keyframe_move,
            args=(client_id, keyframe_id, new_frame, session),
        )
        timeline_data["keyframe_move_timers"][keyframe_id] = timer
        timer.start()
|
|
def _execute_keyframe_move(
    client_id: int,
    keyframe_id: str,
    new_frame: int,
    session: ClientSession,
):
    """Actually execute the keyframe move operation (called after debounce delay).

    Runs on the ``threading.Timer`` thread started by ``handle_keyframe_move``.
    Under ``keyframe_update_lock``, it returns early when this timer is stale
    (no pending move, or a newer ``new_frame`` is pending). Otherwise it consumes
    the pending entry and re-creates the constraint: removed at the recorded
    frame, added at ``new_frame``. A 2D Root move then asks for a dense-path
    refresh. ``client_id`` is not referenced in this body.
    """
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Nothing pending for this keyframe (already applied or never recorded).
        if keyframe_id not in timeline_data["pending_keyframe_moves"]:
            return

        # A newer move superseded this one; its own timer will apply it.
        if timeline_data["pending_keyframe_moves"][keyframe_id] != new_frame:
            return

        # This is the latest move: consume the pending entry and its timer.
        del timeline_data["pending_keyframe_moves"][keyframe_id]
        if keyframe_id in timeline_data["keyframe_move_timers"]:
            del timeline_data["keyframe_move_timers"][keyframe_id]

        # The keyframe may have been deleted while the timer was pending.
        if keyframe_id not in timeline_data["keyframes"]:
            return
        keyframe_data = timeline_data["keyframes"][keyframe_id]
        if not keyframe_data:
            return

        # Dropped back onto its current frame: no constraint change needed.
        if keyframe_data["frame"] == new_frame:
            return

        track_id = keyframe_data["track_id"]
        constraint_type = timeline_data["tracks"][track_id]["name"]
        cur_frame = keyframe_data["frame"]

        # Re-create the constraint: remove at the old frame, then add at the new one.
        remove_constraint_callback(
            keyframe_id,
            constraint_type,
            (cur_frame, cur_frame),
            verbose=False,
        )
        add_constraint_callback(
            keyframe_id,
            constraint_type,
            (new_frame, new_frame),
            verbose=False,
        )

        # Record the new frame so later moves/deletes start from it.
        keyframe_data["frame"] = new_frame

    # Root keyframes may feed the dense path; the helper decides when to refresh it.
    if constraint_type == "2D Root":
        _schedule_dense_path_after_release(session)
|
|
@client.timeline.on_keyframe_delete
def handle_keyframe_delete(keyframe_id: str):
    """Called when a keyframe is deleted.

    Under ``keyframe_update_lock``, this removes the keyframe's constraint at its
    recorded frame and forgets the keyframe. Keyframes that were never recorded
    are ignored. Afterwards, if the "2D Root" constraint has dense path enabled,
    the dense path is recomputed.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    with session.timeline_data["keyframe_update_lock"]:
        if keyframe_id not in session.timeline_data["keyframes"]:
            return
        keyframe_data = session.timeline_data["keyframes"][keyframe_id]
        track_id = keyframe_data["track_id"]
        constraint_type = session.timeline_data["tracks"][track_id]["name"]
        cur_frame = keyframe_data["frame"]
        remove_constraint_callback(
            keyframe_id,
            constraint_type,
            (cur_frame, cur_frame),
            verbose=False,
        )
        # Dropping the record also makes any pending debounced move for it a no-op.
        del session.timeline_data["keyframes"][keyframe_id]
    if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path:
        motion = list(session.motions.values())[0]
        _update_dense_path(motion, session)
|
|
@client.timeline.on_interval_move
def handle_interval_move(interval_id: str, new_start: int, new_end: int):
    """Called when an interval is moved or resized.

    Debounced like ``handle_keyframe_move``, but with a 0.5 s delay. It shares
    the ``keyframe_move_timers`` / ``pending_keyframe_moves`` dicts, keyed by
    interval id. The pending value is the ``(new_start, new_end)`` tuple that
    ``_execute_interval_move`` later compares against.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]

    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Supersede a move still waiting to fire for this interval.
        if interval_id in timeline_data["keyframe_move_timers"]:
            timeline_data["keyframe_move_timers"][interval_id].cancel()

        # Latest requested range; older timers compare against this and bail out.
        new_interval = (new_start, new_end)
        timeline_data["pending_keyframe_moves"][interval_id] = new_interval

        timer = threading.Timer(
            0.5,
            _execute_interval_move,
            args=(client_id, interval_id, new_interval, session),
        )
        timeline_data["keyframe_move_timers"][interval_id] = timer
        timer.start()
|
|
def _execute_interval_move(
    client_id: int,
    interval_id: str,
    new_interval: tuple[int, int],
    session: ClientSession,
):
    """Actually execute the interval move operation (called after debounce delay).

    Runs on the ``threading.Timer`` thread started by ``handle_interval_move``.
    Under ``keyframe_update_lock``, it returns early when this timer is stale.
    Otherwise it consumes the pending entry and re-creates the constraint: removed
    over the recorded range, added over ``new_interval``. A 2D Root move then asks
    for a dense-path refresh. ``client_id`` is not referenced in this body.
    """
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Nothing pending for this interval (already applied or never recorded).
        if interval_id not in timeline_data["pending_keyframe_moves"]:
            return

        # A newer move superseded this one; its own timer will apply it.
        if timeline_data["pending_keyframe_moves"][interval_id] != new_interval:
            return

        # This is the latest move: consume the pending entry and its timer.
        del timeline_data["pending_keyframe_moves"][interval_id]
        if interval_id in timeline_data["keyframe_move_timers"]:
            del timeline_data["keyframe_move_timers"][interval_id]

        # The interval may have been deleted while the timer was pending.
        if interval_id not in timeline_data["intervals"]:
            return
        interval_data = timeline_data["intervals"][interval_id]
        if not interval_data:
            return

        # Range unchanged: no constraint change needed.
        if (
            interval_data["start_frame_idx"] == new_interval[0]
            and interval_data["end_frame_idx"] == new_interval[1]
        ):
            return

        track_id = interval_data["track_id"]
        constraint_type = timeline_data["tracks"][track_id]["name"]
        cur_range = (
            interval_data["start_frame_idx"],
            interval_data["end_frame_idx"],
        )

        # Re-create the constraint: remove over the old range, then add over the new one.
        remove_constraint_callback(
            interval_id,
            constraint_type,
            cur_range,
            verbose=False,
        )
        add_constraint_callback(
            interval_id,
            constraint_type,
            new_interval,
            verbose=False,
        )

        # Record the new range so later moves/deletes start from it.
        interval_data["start_frame_idx"] = new_interval[0]
        interval_data["end_frame_idx"] = new_interval[1]

    # Root intervals may feed the dense path; the helper decides when to refresh it.
    if constraint_type == "2D Root":
        _schedule_dense_path_after_release(session)
|
|
@client.timeline.on_interval_delete
def handle_interval_delete(interval_id: str):
    """Called when an interval is deleted.

    Under ``keyframe_update_lock``, this removes the interval's constraint over
    its recorded range and forgets the interval. Intervals that were never
    recorded are ignored. Afterwards, if the "2D Root" constraint has dense path
    enabled, the dense path is recomputed.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    with session.timeline_data["keyframe_update_lock"]:
        if interval_id not in session.timeline_data["intervals"]:
            return
        interval_data = session.timeline_data["intervals"][interval_id]
        track_id = interval_data["track_id"]
        constraint_type = session.timeline_data["tracks"][track_id]["name"]
        remove_constraint_callback(
            interval_id,
            constraint_type,
            (
                interval_data["start_frame_idx"],
                interval_data["end_frame_idx"],
            ),
            verbose=False,
        )
        # Dropping the record also makes any pending debounced move for it a no-op.
        del session.timeline_data["intervals"][interval_id]
    if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path:
        motion = list(session.motions.values())[0]
        _update_dense_path(motion, session)
|
|
@gui_snap_to_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Snap the pose at the current frame onto a keyframe constraint at that frame.

    Full-Body and End-Effectors keyframes (checked in that order) replace the
    pose with the stored joint positions/rotations. Otherwise, a "2D Root"
    keyframe translates every joint at that frame by the offset between the
    constrained root and the motion's projected root. Component 1 of that offset
    is zeroed (presumably the vertical axis -- confirm), and rotations are kept.

    Does nothing when there is no motion, when the current frame is past the end
    of the first motion, or when no keyframe exists at that frame.
    """
    event_client = event.client
    session = get_active_session(event_client)
    # Guard like the reset/undo handlers: indexing an empty motions dict would raise.
    if session is None or not session.motions:
        return

    target_character_motion = list(session.motions.values())[0]
    frame_idx = session.frame_idx

    if frame_idx >= target_character_motion.length:
        return

    for constraint_name in ["Full-Body", "End-Effectors"]:
        if (
            constraint_name in session.constraints
            and frame_idx in session.constraints[constraint_name].keyframes
        ):
            pos = session.constraints[constraint_name].keyframes[frame_idx]["joints_pos"]
            rot = session.constraints[constraint_name].keyframes[frame_idx]["joints_rot"]
            target_character_motion.update_pose_at_frame(
                frame_idx,
                joints_pos=pos,
                joints_rot=rot,
            )
            target_character_motion.set_frame(frame_idx)
            return

    if "2D Root" in session.constraints and frame_idx in session.constraints["2D Root"].keyframes:
        new_root_pos = session.constraints["2D Root"].keyframes[frame_idx]
        old_root_pos = target_character_motion.get_projected_root_pos(frame_idx)
        root_diff = new_root_pos - old_root_pos
        # Keep the motion's own value on axis 1; only the other components are snapped.
        root_diff[1] = 0.0

        # Rigidly shift all joints of this frame by the root offset.
        new_joints_pos = (
            target_character_motion.joints_pos[frame_idx]
            + to_torch(
                root_diff,
                device=target_character_motion.joints_pos.device,
                dtype=target_character_motion.joints_pos.dtype,
            )[None]
        )
        rot = target_character_motion.joints_rot[frame_idx]

        target_character_motion.update_pose_at_frame(
            frame_idx,
            joints_pos=new_joints_pos,
            joints_rot=rot,
        )
        target_character_motion.set_frame(frame_idx)
|
|
@gui_clear_all_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Remove every constraint and timeline keyframe/interval, then turn dense path off."""
    session = get_active_session(event.client)
    if session is None:
        return

    lock = session.timeline_data["keyframe_update_lock"]
    with lock:
        for constraint in list(session.constraints.values()):
            constraint.clear()
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()

    if gui_dense_path_checkbox.value:
        gui_dense_path_checkbox.value = False
    constraints = session.constraints
    if "2D Root" in constraints:
        constraints["2D Root"].set_dense_path(False)
|
|
| |
@gui_generate_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Generate motion for the timeline prompts of the clicking client.

    Prompts are read from the session's timeline in start-frame order. The
    playback duration is resized to their total length, and ``demo.generate``
    is called with the current sampling, CFG, post-processing and transition
    settings. With more than one sample, each generated character becomes
    clickable, and the editing/saving controls stay disabled until the user
    commits one sample.

    Any exception raised after the controls are locked is caught:
    - it is printed;
    - the main controls and the timeline constraints are re-enabled;
    - the notification is turned into an error message;
    - ``demo.check_cuda_health()`` is called.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return

    generating_notif = event_client.add_notification(
        title="Generating motion...",
        body="Generating motions for the given prompt!",
        loading=True,
        with_close_button=False,
    )
    gui_generate_button.disabled = True
    client.timeline.disable_constraints()

    # Request preparation is inside ``try`` as well. An error while computing
    # frame counts or resizing the duration must still unlock the UI and
    # resolve the loading notification.
    try:
        num_samples = gui_num_samples_slider.value
        timeline = session.client.timeline

        # NOTE(review): reads the private ``_prompts`` mapping of the timeline.
        prompt_values = sorted(timeline._prompts.values(), key=lambda p: p.start_frame)
        texts = [p.text for p in prompt_values]
        num_frames = compute_prompt_num_frames(prompt_values)

        # Resize playback so it covers all prompts back to back.
        total_duration = sum(num_frames) / session.model_fps
        set_new_duration(client_id, total_duration)

        transitions_parameters = {
            "num_transition_frames": gui_num_transition_frames_slider.value,
            "share_transition": gui_share_transition_checkbox.value,
            # Divided by 100: the slider value is presumably a percentage.
            "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100,
        }
        # Post-processing is always off for "g1" models, regardless of the checkbox.
        postprocess_parameters = {
            "post_processing": (False if "g1" in session.model_name else gui_postprocess_checkbox.value),
            "root_margin": gui_root_margin.value,
        }

        demo.generate(
            event_client,
            texts,
            num_frames,
            num_samples,
            gui_seed.value,
            gui_diffusion_steps_slider.value,
            cfg_weight=[
                gui_cfg_text_weight_slider.value,
                gui_cfg_constraint_weight_slider.value,
            ],
            cfg_type="separated" if gui_cfg_checkbox.value else "nocfg",
            postprocess_parameters=postprocess_parameters,
            transitions_parameters=transitions_parameters,
            real_robot_rotations=gui_real_robot_rotations_checkbox.value,
        )

        # Last valid frame index for the (possibly new) duration; clamp the cursor to it.
        session.max_frame_idx = int(session.cur_duration * session.model_fps) - 1
        if session.frame_idx > session.max_frame_idx:
            session.frame_idx = session.max_frame_idx

        if num_samples > 1:

            def commit_motion(event: viser.GuiEvent) -> None:
                """Keep the clicked sample as the only motion and unlock the controls."""
                target = event.target
                # The second path segment of the clicked scene node is compared
                # against the keys of ``session.motions``.
                commit_name = target.name.split("/")[1]
                print(f"Committing motion for character: {commit_name}")

                selected = None
                for character_name, motion in session.motions.items():
                    if character_name == commit_name:
                        selected = motion
                        break
                if selected is None:
                    # Previously this crashed on ``**None``; ignore the stray click instead.
                    print(f"No motion found for character: {commit_name}; ignoring click.")
                    return

                # Shift the chosen sample along x so its first root position sits at
                # x = 0 (samples are presumably offset along x for side-by-side display).
                root_x_offset = selected.joints_pos[0, session.skeleton.root_idx, 0]
                new_joints_pos = selected.joints_pos.clone()
                new_joints_pos[..., 0] -= root_x_offset
                new_motion_kwargs = {
                    "skeleton": session.skeleton,
                    "joints_rot": selected.joints_rot,
                    "foot_contacts": selected.foot_contacts,
                    "joints_pos": new_joints_pos,
                }

                demo.clear_motions(event_client.client_id)
                demo.add_character_motion(event_client, **new_motion_kwargs)
                gui_edit_constraint_button.disabled = False
                gui_generate_button.disabled = False
                gui_snap_to_constraint_button.disabled = False
                client.timeline.enable_constraints()
                gui_generate_button.label = "Generate"
                gui_save_example_button.disabled = False
                gui_save_motion_button.disabled = False
                gui_download_button.disabled = False
                gui_save_constraints_button.disabled = False
                gui_load_example_button.disabled = False

            # Make each generated sample clickable so the user can commit it.
            for motion in session.motions.values():
                char = motion.character
                character_name = char.name
                if char.skinned_mesh is not None:
                    char.skinned_mesh.on_click(commit_motion)
                elif char.g1_mesh_rig is not None:
                    # The G1 rig is made of several meshes; highlight them as one group.
                    for handle in char.g1_mesh_rig.mesh_handles:
                        handle.on_click(commit_motion, highlight_group=character_name)

            # Sample-choice mode: lock everything until a sample is committed.
            gui_edit_constraint_button.disabled = True
            gui_generate_button.disabled = True
            gui_snap_to_constraint_button.disabled = True
            gui_generate_button.label = "Choose Sample Before Generating"
            gui_save_example_button.disabled = True
            gui_save_motion_button.disabled = True
            gui_download_button.disabled = True
            gui_save_constraints_button.disabled = True
            gui_load_example_button.disabled = True
        else:
            gui_edit_constraint_button.disabled = False
            gui_generate_button.disabled = False
            gui_snap_to_constraint_button.disabled = False
            client.timeline.enable_constraints()

        generating_notif.title = "Motion generation finished!"
        generating_notif.body = "Motions have been generated successfully for the given prompt."
        if num_samples > 1:
            generating_notif.body += " Now choose which sample to commit."
        generating_notif.loading = False
        generating_notif.with_close_button = True
        generating_notif.auto_close_seconds = 5.0
        generating_notif.color = "green"

        # Rewind playback to the start of the new motion.
        demo.set_frame(client_id, 0)

    except Exception as e:
        import traceback

        traceback.print_exc()
        print(f"Error during generation for client {event_client.client_id}: {e}")
        # Only touch the UI if the client's session still exists.
        if event_client.client_id in demo.client_sessions:
            gui_generate_button.disabled = False
            gui_load_example_button.disabled = False
            gui_save_example_button.disabled = False
            gui_save_motion_button.disabled = False
            gui_download_button.disabled = False
            # Constraints were disabled before generation started; without this
            # the constraint editor stayed locked after any failure.
            client.timeline.enable_constraints()

        try:
            generating_notif.title = "Generation failed!"
            generating_notif.body = f"Error: {str(e)}"
            generating_notif.loading = False
            generating_notif.with_close_button = True
            generating_notif.auto_close_seconds = 6.0
            generating_notif.color = "red"
        except Exception:
            # Best effort: the notification may no longer be updatable
            # (e.g. the client disconnected mid-generation).
            pass
        demo.check_cuda_health()
|
|
| |
| |
| |
| with tab_group.add_tab("Visualize", viser.Icon.EYE): |
| with client.gui.add_folder("Playback", expand_by_default=True): |
| gui_model_fps = client.gui.add_number("Model FPS", initial_value=model_fps, disabled=True) |
| gui_playback_speed_buttons = client.gui.add_button_group( |
| "Playback Speed", |
| options=[ |
| "0.5x", |
| "1x", |
| "2x", |
| ], |
| ) |
| gui_playback_speed_buttons.value = "1x" |
|
|
| @client.timeline.on_frame_change
| def handle_timeline_frame_change(new_frame_idx: int):
| """Move playback to the frame picked on the timeline and refresh edit-mode/snap state."""
| demo.set_frame(client_id, new_frame_idx, update_timeline=False)  # presumably avoids echoing the frame back to the timeline that sent it - confirm
| session = demo.client_sessions.get(client_id)  # the session may be absent (e.g. already torn down)
| if session is not None:
| if session.edit_mode and session.motions:  # edit mode works on the first motion only
| motion = list(session.motions.values())[0]
| snapshot_frame_idx = min(session.frame_idx, motion.length - 1)  # clamp to the motion's last frame
| ensure_edit_snapshot(session, motion, snapshot_frame_idx)
| update_snap_to_constraint_button(session)  # refresh the snap-to-constraint button for the current frame
|
|
@client.timeline.on_prompt_add
async def _on_add(
    prompt_id: str,
    start_frame: int,
    end_frame: int,
    text: str,
    color: tuple[int, int, int] | None,
) -> None:
    """A prompt was added: recompute the automatic duration."""
    del prompt_id, start_frame, end_frame, text, color  # only the fact of the change matters
    update_duration_auto()


@client.timeline.on_prompt_update
async def _on_update(prompt_id: str, new_text: str) -> None:
    """A prompt's text changed: recompute the automatic duration."""
    del prompt_id, new_text  # unused
    update_duration_auto()


@client.timeline.on_prompt_resize
async def _on_resize(prompt_id: str, new_start: int, new_end: int) -> None:
    """A prompt was resized: recompute the automatic duration."""
    del prompt_id, new_start, new_end  # unused
    update_duration_auto()


@client.timeline.on_prompt_move
async def _on_move(prompt_id: str, new_start: int, new_end: int) -> None:
    """A prompt was moved: recompute the automatic duration."""
    del prompt_id, new_start, new_end  # unused
    update_duration_auto()


@client.timeline.on_prompt_delete
async def _on_delete(prompt_id: str) -> None:
    """A prompt was deleted: recompute the automatic duration."""
    del prompt_id  # unused
    update_duration_auto()
|
|
def play_pause_button_callback(session: ClientSession):
    """Flip the session between playing and paused."""
    currently_playing = session.playing
    session.playing = not currently_playing
|
|
def next_frame_callback(session: ClientSession):
    """Advance one frame (never past ``session.max_frame_idx``) and redraw.

    The dead ``if ...: pass`` boundary branch of the previous version was removed.
    """
    if session.frame_idx < session.max_frame_idx:
        session.frame_idx += 1
    demo.set_frame(client_id, session.frame_idx)


def prev_frame_callback(session: ClientSession):
    """Step back one frame (never below 0) and redraw.

    The dead ``if ...: pass`` boundary branch of the previous version was removed.
    """
    if session.frame_idx > 0:
        session.frame_idx -= 1
    demo.set_frame(client_id, session.frame_idx)
|
|
@gui_playback_speed_buttons.on_click
def _(_) -> None:
    """Map the selected speed label to a playback multiplier on the session."""
    if demo.client_active(client_id):
        multiplier_by_label = {"0.5x": 0.5, "1x": 1.0, "2x": 2.0}
        active_session = demo.client_sessions[client_id]
        active_session.playback_speed = multiplier_by_label[gui_playback_speed_buttons.value]
|
|
| with client.gui.add_folder("Body options", expand_by_default=True): |
| gui_viz_skinned_mesh_checkbox = client.gui.add_checkbox("Show Mesh", initial_value=True) |
| gui_viz_skinned_mesh_opacity_slider = client.gui.add_slider( |
| "Mesh Opacity", min=0.0, max=1.0, step=0.01, initial_value=1.0 |
| ) |
| gui_viz_skeleton_checkbox = client.gui.add_checkbox("Show Skeleton", initial_value=False) |
| gui_viz_foot_contacts_checkbox = client.gui.add_checkbox("Show Foot Contacts", initial_value=False) |
| gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value |
| with client.gui.add_folder("Camera options", expand_by_default=True): |
| gui_camera_fov_slider = client.gui.add_slider( |
| "Camera FOV (deg)", |
| min=30.0, |
| max=90.0, |
| step=1.0, |
| initial_value=45.0, |
| ) |
| client.camera.fov = np.deg2rad(gui_camera_fov_slider.value) |
| with client.gui.add_folder("Interface options", expand_by_default=True): |
| gui_show_timeline_checkbox = client.gui.add_checkbox( |
| "Show Timeline", |
| initial_value=True, |
| ) |
| gui_show_constraint_tracks_checkbox = client.gui.add_checkbox( |
| "Show Constraint tracks", |
| initial_value=True, |
| ) |
| gui_show_constraint_labels_checkbox = client.gui.add_checkbox( |
| "Show Constraint labels", |
| initial_value=True, |
| ) |
| gui_show_starting_direction_checkbox = client.gui.add_checkbox( |
| "Show Starting Direction", |
| initial_value=True, |
| ) |
| gui_dark_mode_checkbox = client.gui.add_checkbox( |
| "Dark Mode", |
| initial_value=False, |
| ) |
| gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value |
| demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value) |
|
|
@gui_dark_mode_checkbox.on_update
def _(_):
    """Apply the light/dark theme to the client UI and to every displayed character."""
    demo.configure_theme(
        client,
        gui_dark_mode_checkbox.value,
        titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid,
    )
    # The session may already be gone. Every other handler guards against
    # that, whereas the old direct indexing raised KeyError here.
    session = demo.client_sessions.get(client.client_id)
    if session is None:
        return
    for motion in session.motions.values():
        motion.character.change_theme(gui_dark_mode_checkbox.value)
|
|
| |
| demo.configure_theme( |
| client, |
| gui_dark_mode_checkbox.value, |
| titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid, |
| ) |
| gui_dark_mode_checkbox.visible = False |
|
|
@gui_show_constraint_labels_checkbox.on_update
def _(_):
    """Show or hide the text labels of every constraint in the session."""
    if not demo.client_active(client_id):
        return
    show_labels = gui_show_constraint_labels_checkbox.value
    for constraint_track in demo.client_sessions[client_id].constraints.values():
        constraint_track.set_label_visibility(show_labels)
|
|
@gui_show_timeline_checkbox.on_update
def _(_):
    """Toggle the timeline and keep the constraint-track option in step with it."""
    if not demo.client_active(client_id):
        return
    active_session = demo.client_sessions[client_id]
    timeline_shown = gui_show_timeline_checkbox.value
    active_session.client.timeline.set_visible(timeline_shown)
    # The constraint-track option only makes sense while the timeline is shown.
    gui_show_constraint_tracks_checkbox.visible = timeline_shown
    if not timeline_shown:
        return
    demo.set_constraint_tracks_visible(active_session, gui_show_constraint_tracks_checkbox.value)
|
|
@gui_show_constraint_tracks_checkbox.on_update
def _(_):
    """Show or hide the constraint tracks on the timeline."""
    if demo.client_active(client_id):
        demo.set_constraint_tracks_visible(
            demo.client_sessions[client_id],
            gui_show_constraint_tracks_checkbox.value,
        )
|
|
@gui_show_starting_direction_checkbox.on_update
def _(_):
    """Toggle the starting-direction indicator for this client."""
    if demo.client_active(client_id):
        demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value)
|
|
@gui_viz_skeleton_checkbox.on_update
def _(_) -> None:
    """Toggle skeleton rendering; the foot-contact option follows the skeleton."""
    if not demo.client_active(client_id):
        return
    active_session = demo.client_sessions[client_id]
    skeleton_on = gui_viz_skeleton_checkbox.value
    # The foot-contact toggle is only offered (and only left on) while the skeleton is shown.
    gui_viz_foot_contacts_checkbox.visible = skeleton_on
    if not skeleton_on:
        gui_viz_foot_contacts_checkbox.value = False
    for shown_motion in active_session.motions.values():
        shown_motion.character.set_skeleton_visibility(skeleton_on)
|
|
@gui_viz_foot_contacts_checkbox.on_update
def _(_) -> None:
    """Show or hide foot-contact markers on every character at its current frame."""
    if not demo.client_active(client_id):
        return
    show_contacts = gui_viz_foot_contacts_checkbox.value
    for shown_motion in demo.client_sessions[client_id].motions.values():
        shown_motion.character.set_show_foot_contacts(
            show_contacts,
            frame_idx=shown_motion.cur_frame_idx,
        )
|
|
@gui_viz_skinned_mesh_checkbox.on_update
def _(_) -> None:
    """Show or hide the skinned mesh of every character."""
    if not demo.client_active(client_id):
        return
    mesh_on = gui_viz_skinned_mesh_checkbox.value
    for shown_motion in demo.client_sessions[client_id].motions.values():
        shown_motion.character.set_skinned_mesh_visibility(mesh_on)
|
|
@gui_viz_skinned_mesh_opacity_slider.on_update
def _(_) -> None:
    """Apply the mesh-opacity slider value to every character's skinned mesh."""
    if not demo.client_active(client_id):
        return
    opacity = gui_viz_skinned_mesh_opacity_slider.value
    for shown_motion in demo.client_sessions[client_id].motions.values():
        shown_motion.character.set_skinned_mesh_opacity(opacity)
|
|
@gui_camera_fov_slider.on_update
def _(_) -> None:
    """Apply the FOV slider (degrees) to the client camera (radians)."""
    if demo.client_active(client_id):
        client.camera.fov = np.deg2rad(gui_camera_fov_slider.value)
|
|
| |
|
|
| |
| |
| with tab_group.add_tab("Instructions", viser.Icon.INFO_CIRCLE): |
| client.gui.add_markdown(DEMO_UI_INSTRUCTIONS_TAB_MD) |
|
|
| |
| |
| |
@client.scene.on_keyboard_event("keydown", debounce_ms=100)
def handle_key(event: viser.KeyboardEvent) -> None:
    """Keyboard shortcuts: space toggles play/pause, left/right arrows step one frame."""
    if client_id not in demo.client_sessions:
        return
    session = demo.client_sessions[client_id]

    pressed = event.key
    if pressed == "ArrowLeft":
        prev_frame_callback(session)
    elif pressed == "ArrowRight":
        next_frame_callback(session)
    elif pressed == " ":
        # Ignore repeated space presses within 0.2 s of the last accepted toggle.
        now = time.monotonic()
        if now - session.last_space_toggle_time >= 0.2:
            session.last_space_toggle_time = now
            play_pause_button_callback(session)
|
|
| gui_elements = GuiElements( |
| gui_play_pause_button=gui_play_pause_button, |
| gui_next_frame_button=gui_next_frame_button, |
| gui_prev_frame_button=gui_prev_frame_button, |
| gui_generate_button=gui_generate_button, |
| gui_model_fps=gui_model_fps, |
| gui_timeline=gui_timeline, |
| gui_viz_skeleton_checkbox=gui_viz_skeleton_checkbox, |
| gui_viz_foot_contacts_checkbox=gui_viz_foot_contacts_checkbox, |
| gui_viz_skinned_mesh_checkbox=gui_viz_skinned_mesh_checkbox, |
| gui_viz_skinned_mesh_opacity_slider=gui_viz_skinned_mesh_opacity_slider, |
| gui_camera_fov_slider=gui_camera_fov_slider, |
| gui_duration_slider=gui_duration_slider, |
| gui_num_samples_slider=gui_num_samples_slider, |
| gui_cfg_checkbox=gui_cfg_checkbox, |
| gui_cfg_text_weight_slider=gui_cfg_text_weight_slider, |
| gui_cfg_constraint_weight_slider=gui_cfg_constraint_weight_slider, |
| gui_diffusion_steps_slider=gui_diffusion_steps_slider, |
| gui_seed=gui_seed, |
| gui_postprocess_checkbox=gui_postprocess_checkbox, |
| gui_root_margin=gui_root_margin, |
| gui_real_robot_rotations_checkbox=gui_real_robot_rotations_checkbox, |
| gui_dark_mode_checkbox=gui_dark_mode_checkbox, |
| gui_use_soma_layer_checkbox=gui_use_soma_layer_checkbox, |
| ) |
| return ( |
| gui_elements, |
| timeline_tracks, |
| example_dict, |
| gui_examples_dropdown, |
| gui_save_example_path_text, |
| gui_model_selector, |
| ) |
|
|