# movimento / ui.py
# Last commit (rydlrKE): Fix CUDA import order - import spaces before torch (commit e28bffd)
# Revision 2a5255e (verified)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: I001
import math
import os
import threading
import time
from typing import Optional
from kimodo.constraints import load_constraints_lst, save_constraints_lst
from kimodo.exports.bvh import motion_to_bvh_bytes, save_motion_bvh
from kimodo.exports.motion_io import (
amass_npz_to_bytes,
g1_csv_to_bytes,
kimodo_npz_to_bytes,
load_motion_file,
save_kimodo_npz,
)
from kimodo.model.registry import kimodo_short_key_for_skeleton_dataset, registry_skeleton_for_joint_count
from kimodo.tools import to_torch
from kimodo.viz import viser_utils
from kimodo.viz.viser_utils import GuiElements
import numpy as np
import torch
import viser
from viser._timeline_api import PROMPT_COLORS
from . import generation
from ._qwen_prompts import call_qwen_for_prompts
from .config import (
DEFAULT_CUR_DURATION,
DEMO_UI_INSTRUCTIONS_TAB_MD,
get_datasets,
get_model_info,
get_models_for_dataset_skeleton,
get_skeleton_display_name,
get_skeleton_display_names_for_dataset,
get_skeleton_key_from_display_name,
get_short_key_from_display_name,
HF_MODE,
INIT_POSTPROCESSING,
MODEL_NAMES,
NB_TRANSITION_FRAMES,
SHOW_TRANSITION_PARAMS,
)
from .state import ClientSession
from kimodo.skeleton import G1Skeleton34, SOMASkeleton30, SOMASkeleton77
# Extra entry appended to the "Example" dropdown; the "Load Example" button hint says the
# selection may be a Qwen agentic prompt plan rather than an on-disk example.
QWEN_EXAMPLE_NAME = "09_qwen_agentic_actions"
# Older name of the Qwen entry; update_examples_dropdown also keeps it in the options
# (presumably so previously saved selections still resolve — TODO confirm).
QWEN_EXAMPLE_LEGACY_NAME = "09_qwen_agentic_10_actions"
def extract_intervals_and_singles(t: torch.Tensor):
    """Split a sequence of frame indices into consecutive runs and isolated frames.

    A run continues while each value equals the previous value + 1. Runs of two or
    more values become intervals; runs of length one become single frames.

    Args:
        t: 1-D sequence of integer frame indices (torch tensor, numpy array or any
            plain sequence).

    Returns:
        ``(intervals, intervals_indices, single_frames, single_frames_indices)``:
        ``intervals`` holds inclusive ``(first_value, last_value)`` pairs and
        ``intervals_indices`` the matching inclusive ``(start_pos, end_pos)``
        positions in ``t``; ``single_frames`` holds isolated values and
        ``single_frames_indices`` their positions in ``t``. All entries are ints.
    """
    # Convert once to plain Python numbers: indexing a tensor element by element
    # builds a new 0-d tensor per access (and syncs the device per comparison on GPU).
    values = t.tolist() if hasattr(t, "tolist") else list(t)
    n = len(values)
    intervals = []
    intervals_indices = []
    single_frames = []
    single_frames_indices = []
    start_idx = 0
    for i in range(1, n + 1):
        # A run ends at the end of the sequence or at a non-consecutive value.
        if i == n or values[i] != values[i - 1] + 1:
            if i - start_idx >= 2:
                intervals.append((int(values[start_idx]), int(values[i - 1])))
                intervals_indices.append((start_idx, i - 1))
            else:
                single_frames.append(int(values[start_idx]))
                single_frames_indices.append(start_idx)
            start_idx = i
    return intervals, intervals_indices, single_frames, single_frames_indices
def create_gui(
demo,
client: viser.ClientHandle,
model_name: str,
model_fps: float,
):
"""Create GUI elements for a specific client."""
client_id = client.client_id
def get_active_session(event_client: viser.ClientHandle | None):
    """Return the session of ``event_client`` while it is active, otherwise ``None``."""
    if event_client is not None and demo.client_active(event_client.client_id):
        return demo.client_sessions[event_client.client_id]
    return None
def build_timeline_tracks():
    """Reset the client timeline and create the six keyframe constraint tracks.

    Returns:
        ``(timeline, timeline_tracks)``, where ``timeline_tracks`` maps each id
        returned by ``timeline.add_track`` to the name/type/color/height used for it.
    """
    timeline = client.timeline
    demo.set_timeline_defaults(timeline, model_fps)
    timeline.set_visible(True)
    timeline.set_current_frame(0)

    # (track name, RGB color) in creation order; all are half-height keyframe tracks.
    track_specs = (
        ("Full-Body", (219, 148, 86)),
        ("2D Root", (150, 100, 200)),
        ("Left Hand", (100, 200, 150)),
        ("Right Hand", (200, 100, 150)),
        ("Left Foot", (219, 148, 86)),
        ("Right Foot", (150, 100, 200)),
    )
    timeline_tracks = {}
    for track_name, track_color in track_specs:
        new_id = timeline.add_track(
            track_name,
            track_type="keyframe",
            color=track_color,
            height_scale=0.5,
        )
        timeline_tracks[new_id] = {
            "name": track_name,
            "track_type": "keyframe",
            "color": track_color,
            "height_scale": 0.5,
        }
    return timeline, timeline_tracks
timeline, timeline_tracks = build_timeline_tracks()
# These handles are part of GuiElements, but the demo currently uses timeline + buttons
# embedded in the Viser UI instead of custom controls.
gui_play_pause_button = None
gui_next_frame_button = None
gui_prev_frame_button = None
gui_timeline = None
gui_duration_slider = None
# now other gui elements
tab_group = client.gui.add_tab_group()
#
# Playback and Motion generation controls
#
with tab_group.add_tab("Generate", viser.Icon.WALK):
with client.gui.add_folder("Model Selection", expand_by_default=True):
info = get_model_info(model_name)
if info is None:
info = get_model_info(next(iter(MODEL_NAMES)))
def get_allowed_skeleton_labels(dataset_ui_label: str) -> list[str]:
    """List Kimodo skeleton display names for a dataset, hiding SMPLX in HF mode."""
    all_labels = get_skeleton_display_names_for_dataset(dataset_ui_label, family="Kimodo")
    if not HF_MODE:
        return all_labels
    return [lbl for lbl in all_labels if get_skeleton_key_from_display_name(lbl) != "SMPLX"]
dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label
datasets = ["Rigplay"] if HF_MODE else get_datasets(family="Kimodo")
skeleton_labels = get_allowed_skeleton_labels(dataset_ui_label)
initial_skeleton_label = get_skeleton_display_name(info.skeleton)
if initial_skeleton_label not in skeleton_labels and skeleton_labels:
initial_skeleton_label = skeleton_labels[0]
initial_skeleton_key = (
get_skeleton_key_from_display_name(initial_skeleton_label) if skeleton_labels else None
)
models_for_pair = (
get_models_for_dataset_skeleton(dataset_ui_label, initial_skeleton_key, family="Kimodo")
if initial_skeleton_key is not None
else []
)
version_options = [m.display_name for m in models_for_pair]
initial_version = (
info.display_name
if info.display_name in version_options
else (version_options[0] if version_options else "")
)
gui_dataset_selector = client.gui.add_dropdown(
"Training dataset",
options=datasets,
initial_value=dataset_ui_label,
visible=not HF_MODE,
)
gui_skeleton_selector = client.gui.add_dropdown(
"Model" if HF_MODE else "Skeleton",
options=skeleton_labels,
initial_value=initial_skeleton_label,
)
gui_version_selector = client.gui.add_dropdown(
"Version",
options=version_options,
initial_value=initial_version,
)
gui_version_selector.visible = len(models_for_pair) > 1
gui_model_display = client.gui.add_markdown(
content=f"**Model:** {initial_version}",
)
gui_load_model_button = client.gui.add_button(
"Load model",
hint="Load the selected model (dataset, skeleton, version).",
)
class ModelSelectorHandle:
    """Wrapper so session and callbacks can treat three dropdowns as one.

    Bundles the dataset, skeleton and version dropdowns together with the
    "**Model:**" markdown label created alongside them.
    """

    def __init__(self):
        self._dataset = gui_dataset_selector
        self._skeleton = gui_skeleton_selector
        self._version = gui_version_selector
        self._display = gui_model_display

    @property
    def value(self) -> str:
        """Short model key of the selected version ("" when it cannot be resolved)."""
        return get_short_key_from_display_name(self._version.value) or ""

    def set_from_short_key(self, short_key: str) -> None:
        """Sync all dropdowns and the model label to the model ``short_key``.

        Unknown keys are ignored. When no model exists for the resolved
        dataset/skeleton pair, the version dropdown is emptied (value "") instead of
        raising IndexError.
        """
        info = get_model_info(short_key)
        if info is None:
            return
        dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label
        self._dataset.value = dataset_ui_label
        self._skeleton.options = get_allowed_skeleton_labels(dataset_ui_label)
        skeleton_label = get_skeleton_display_name(info.skeleton)
        if skeleton_label not in self._skeleton.options and self._skeleton.options:
            skeleton_label = self._skeleton.options[0]
        self._skeleton.value = skeleton_label
        skeleton_key = get_skeleton_key_from_display_name(skeleton_label)
        if skeleton_key is None:
            return
        models = get_models_for_dataset_skeleton(dataset_ui_label, skeleton_key, family="Kimodo")
        version_options = [m.display_name for m in models]
        self._version.options = version_options
        # Same fallback as the initial dropdown setup: requested version if offered,
        # else the first option, else "" when there are no versions at all.
        if info.display_name in version_options:
            self._version.value = info.display_name
        else:
            self._version.value = version_options[0] if version_options else ""
        self._version.visible = len(models) > 1
        self._display.content = f"**Model:** {self._version.value}"
gui_model_selector = ModelSelectorHandle()
with client.gui.add_folder("Examples", expand_by_default=True):
examples_base_dir = demo.get_examples_base_dir(model_name, absolute=True)
example_dict = viser_utils.load_example_cases(examples_base_dir)
example_names = list(example_dict.keys())
example_names.append(QWEN_EXAMPLE_NAME)
gui_examples_dropdown = client.gui.add_dropdown(
"Example",
options=example_names,
initial_value=example_names[0],
)
gui_load_example_button = client.gui.add_button(
"Load Example",
hint="Load the selected example (or Qwen agentic prompt plan).",
disabled=False,
)
def update_examples_dropdown(
    new_example_dict: dict[str, str],
    keep_selection: bool = True,
) -> None:
    """Refresh the "Example" dropdown from ``new_example_dict``.

    Both Qwen entries (current and legacy names) are appended if missing. The current
    selection survives when ``keep_selection`` is true and it is still offered;
    otherwise the first option becomes selected.
    """
    names = list(new_example_dict)
    for qwen_name in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME):
        if qwen_name not in names:
            names.append(qwen_name)
    gui_examples_dropdown.options = names
    if not (keep_selection and gui_examples_dropdown.value in names):
        gui_examples_dropdown.value = names[0]
with client.gui.add_folder("Generate", expand_by_default=True):
gui_duration = client.gui.add_markdown(content=f"Total duration: {DEFAULT_CUR_DURATION:.1f} (sec)")
def update_duration_gui(duration):
    """Show ``duration`` (seconds, one decimal) in the "Total duration" markdown."""
    label = f"Total duration: {duration:.1f} (sec)"
    gui_duration.content = label
def compute_prompt_num_frames(prompt_values):
    """Convert timeline prompt bounds to per-prompt frame counts.

    Convention used by this demo: every prompt except the last spans
    [start_frame, end_frame) (exclusive end), while the last spans
    [start_frame, end_frame] (inclusive end) and therefore gets one extra frame.
    ``prompt_values`` is assumed to be sorted by ``start_frame``.
    """
    last_pos = len(prompt_values) - 1
    return [
        p.end_frame - p.start_frame + (1 if pos == last_pos else 0)
        for pos, p in enumerate(prompt_values)
    ]
def update_duration_auto():
    """Recompute the total duration from the timeline prompts; update session and label."""
    session = demo.client_sessions[client_id]
    ordered_prompts = sorted(timeline._prompts.values(), key=lambda p: p.start_frame)
    total_nb_frames = sum(compute_prompt_num_frames(ordered_prompts))
    cur_duration = total_nb_frames / session.model_fps
    set_new_duration(client_id, cur_duration)
    update_duration_gui(cur_duration)
gui_num_samples_slider = client.gui.add_slider(
"Num Samples",
min=1,
max=10,
step=1,
initial_value=1,
visible=not HF_MODE,
)
gui_use_soma_layer_checkbox = client.gui.add_checkbox(
"SOMA layer",
initial_value=False,
visible="soma" in (model_name or ""),
)
with client.gui.add_folder("Model Parameters", expand_by_default=False):
gui_seed = client.gui.add_number("Seed", initial_value=42)
with client.gui.add_folder("Diffusion", expand_by_default=False):
gui_diffusion_steps_slider = client.gui.add_slider(
"Denoising Steps",
min=2,
max=1000,
step=10,
initial_value=100,
)
with client.gui.add_folder("Classifier-Free Guidance", expand_by_default=False):
gui_cfg_checkbox = client.gui.add_checkbox(
"Enable",
initial_value=True,
visible=True,
)
gui_cfg_text_weight_slider = client.gui.add_slider(
"Text Weight",
min=0.0,
max=5.0,
step=0.1,
initial_value=2.0,
visible=True,
)
gui_cfg_constraint_weight_slider = client.gui.add_slider(
"Constraint Weight",
min=0.0,
max=5.0,
step=0.1,
initial_value=2.0,
visible=True,
)
with client.gui.add_folder(
"Transitions",
expand_by_default=False,
visible=SHOW_TRANSITION_PARAMS,
):
gui_num_transition_frames_slider = client.gui.add_slider(
"Transition frames",
min=1,
max=10,
step=1,
initial_value=NB_TRANSITION_FRAMES,
visible=True,
)
gui_share_transition_checkbox = client.gui.add_checkbox( # noqa
"Override previous frames",
initial_value=False,
visible=True,
)
gui_percentage_transition_sharing_slider = client.gui.add_slider(
"Percentage overriding frames",
min=0,
max=30,
step=1,
initial_value=10,
visible=True,
)
@gui_share_transition_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    # Show the override-percentage slider only while "Override previous frames" is on.
    if get_active_session(event.client) is not None:
        gui_percentage_transition_sharing_slider.visible = gui_share_transition_checkbox.value
with client.gui.add_folder("Post Processing", expand_by_default=False):
_model_name = model_name or ""
_postprocess_visible = "g1" not in _model_name
gui_postprocess_checkbox = client.gui.add_checkbox(
"Enable",
initial_value=INIT_POSTPROCESSING,
hint="Apply motion post-processing (not available for G1)",
visible=_postprocess_visible,
)
gui_root_margin = client.gui.add_number(
"Root Margin",
min=0.0,
# max=0.5,
step=0.01,
initial_value=0.04,
hint="Margin for root position (meters). Lower values pin root closer to target.",
visible=INIT_POSTPROCESSING and _postprocess_visible,
)
@gui_postprocess_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    # "Root Margin" only applies to post-processing, so mirror the Enable checkbox.
    if get_active_session(event.client) is not None:
        gui_root_margin.visible = gui_postprocess_checkbox.value
gui_real_robot_rotations_checkbox = client.gui.add_checkbox(
"Real robot rotations",
initial_value=False,
hint="Project joint rotations to G1 real robot DoF (1-DoF per joint) and clamp to axis limits from the MuJoCo XML.",
visible="g1" in _model_name,
)
with client.gui.add_folder("Qwen Auto-Prompts", expand_by_default=True):
gui_qwen_scene = client.gui.add_text(
"Scene context",
initial_value="A lone figure moving through an empty plaza",
hint="Describe the scene or character context for Qwen to generate motion prompts.",
)
gui_qwen_actions = client.gui.add_slider(
"Target actions",
min=1,
max=10,
step=1,
initial_value=6,
hint="Number of prompt segments to place on the timeline.",
)
gui_qwen_auto_run = client.gui.add_checkbox(
"Auto-run Generate after loading prompts",
initial_value=False,
)
gui_qwen_status = client.gui.add_markdown(content="")
gui_qwen_button = client.gui.add_button("Fill Timeline via Qwen", color="blue")
@gui_qwen_button.on_click
def _(event: viser.GuiEvent) -> None:
    # Fill the timeline with a Qwen-generated prompt plan and optionally run Generate.
    # The slow work runs on a daemon thread so this GUI callback returns immediately.
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return

    def _collect_qwen_prompts(target_actions: int) -> tuple[list[str], list[float]]:
        """Request prompts from Qwen until ``target_actions`` have been collected.

        Asks for at most 5 actions per call and stops after 6 calls, passing the
        conversation ``history`` returned by each call into the next one.

        Returns:
            ``(texts, durations)`` of equal length (at most ``target_actions``);
            durations are seconds, converted to float.
        """
        history: list[str] = []
        texts: list[str] = []
        durations: list[float] = []
        rounds = 0
        while len(texts) < target_actions and rounds < 6:
            remaining = target_actions - len(texts)
            batch, history = call_qwen_for_prompts(
                scene=gui_qwen_scene.value,
                history=history,
                requested_actions=min(5, remaining),
            )
            for text, duration in zip(batch.get("texts", []), batch.get("durations", [])):
                if len(texts) >= target_actions:
                    break
                # Convert first so a bad duration cannot leave the lists out of sync.
                seconds = float(duration)
                texts.append(text)
                durations.append(seconds)
            rounds += 1
        return texts, durations

    def _place_prompts(texts: list[str], durations: list[float], fps) -> int:
        """Replace the timeline prompts with back-to-back segments; return frames used.

        Each segment lasts ``max(1, round(duration * fps))`` frames. Matching
        compute_prompt_num_frames, every end frame is exclusive except the last
        segment's, which is inclusive.
        """
        event_client.timeline.clear_prompts()
        frame_cursor = 0
        last = len(texts) - 1
        for i, (txt, dur) in enumerate(zip(texts, durations)):
            n = max(1, int(round(dur * fps)))
            end_f = frame_cursor + n if i < last else frame_cursor + n - 1
            color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
            event_client.timeline.add_prompt(txt, frame_cursor, end_f, color=color)
            frame_cursor += n
        return frame_cursor

    def _fill_and_maybe_generate() -> None:
        """Load the Qwen plan onto the timeline, then run Generate if auto-run is on."""
        texts, durations = _collect_qwen_prompts(int(gui_qwen_actions.value))
        if not texts:
            gui_qwen_status.content = "⚠ Qwen did not return usable prompts"
            return
        fps = session.model_fps
        frame_cursor = _place_prompts(texts, durations, fps)
        # Keep timeline readable by expanding zoom to planned sequence length.
        target_visible_frames = int(math.ceil(1.10 * frame_cursor))
        event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames))
        update_duration_auto()
        gui_qwen_status.content = f"✓ Loaded {len(texts)} Qwen prompt segments"
        if not gui_qwen_auto_run.value:
            return
        gui_qwen_status.content = f"✓ Loaded {len(texts)} prompts, generating motion…"
        try:
            demo.generate(
                event_client,
                texts,
                [max(1, int(round(d * fps))) for d in durations],
                gui_num_samples_slider.value,
                gui_seed.value,
                gui_diffusion_steps_slider.value,
                cfg_weight=[
                    gui_cfg_text_weight_slider.value,
                    gui_cfg_constraint_weight_slider.value,
                ],
                cfg_type="separated" if gui_cfg_checkbox.value else "nocfg",
                postprocess_parameters={
                    "post_processing": gui_postprocess_checkbox.value,
                    "root_margin": gui_root_margin.value,
                },
                transitions_parameters={
                    "num_transition_frames": gui_num_transition_frames_slider.value,
                    "share_transition": gui_share_transition_checkbox.value,
                    "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100,
                },
                real_robot_rotations=gui_real_robot_rotations_checkbox.value,
            )
            gui_qwen_status.content = f"✓ Generated motion from {len(texts)} Qwen actions"
        except Exception as exc:
            gui_qwen_status.content = f"⚠ Generate error: {exc}"

    def _run_qwen_fill_and_maybe_generate() -> None:
        gui_qwen_button.disabled = True
        gui_qwen_status.content = "⏳ Calling Qwen…"
        try:
            _fill_and_maybe_generate()
        except Exception as exc:
            # Without this, an error from Qwen or the timeline update would end the
            # worker thread silently and leave the button disabled for good.
            import traceback

            traceback.print_exc()
            gui_qwen_status.content = f"⚠ Qwen error: {exc}"
        finally:
            gui_qwen_button.disabled = False

    threading.Thread(target=_run_qwen_fill_and_maybe_generate, daemon=True).start()
gui_generate_button = client.gui.add_button("Generate", color="green")
with client.gui.add_folder("Constraints", expand_by_default=False):
gui_gizmo_space_dropdown = client.gui.add_dropdown(
"Gizmo space",
("Local", "World"),
initial_value="Local",
visible="g1" not in _model_name,
)
gui_edit_constraint_button = client.gui.add_button("Enter Editing Mode")
gui_snap_to_constraint_button = client.gui.add_button(
"Snap to Constraint",
disabled=True,
)
gui_reset_constraint_button = client.gui.add_button(
"Reset Constraint",
disabled=True,
)
gui_undo_drag_button = client.gui.add_button(
"Undo Move",
disabled=True,
)
with client.gui.add_folder("Root 2D Options", expand_by_default=True):
gui_dense_path_checkbox = client.gui.add_checkbox(
"Make Smooth Path",
initial_value=False,
visible=True,
)
gui_show_only_current_constraint_checkbox = client.gui.add_checkbox(
"Show only Current",
initial_value=False,
hint="Show only constraint overlays at the current frame; uncheck to show all.",
)
def apply_constraint_overlay_visibility(session: ClientSession) -> None:
    """Refresh constraint overlay visibility for ``session`` via the demo."""
    refresh_overlays = demo._apply_constraint_overlay_visibility
    refresh_overlays(session)
@gui_show_only_current_constraint_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    # Persist the "Show only Current" preference on the session, then refresh overlays.
    session = get_active_session(event.client)
    if session is not None:
        session.show_only_current_constraint = gui_show_only_current_constraint_checkbox.value
        apply_constraint_overlay_visibility(session)
gui_clear_all_constraints_button = client.gui.add_button(
"Clear All Constraints",
color="red",
)
def has_constraint_at_frame(session: ClientSession, frame_idx: int) -> bool:
    """True if a Full-Body, End-Effectors or 2D Root constraint has a keyframe at ``frame_idx``."""
    checked_names = ("Full-Body", "End-Effectors", "2D Root")
    candidates = (session.constraints.get(name) for name in checked_names)
    return any(frame_idx in c.keyframes for c in candidates if c is not None)
def update_snap_to_constraint_button(session: ClientSession) -> None:
    """Enable "Snap to Constraint" only when the current frame holds a constraint keyframe."""
    on_constraint = has_constraint_at_frame(session, session.frame_idx)
    gui_snap_to_constraint_button.disabled = not on_constraint
def ensure_edit_snapshot(session: ClientSession, motion, frame_idx: int) -> None:
    """Record ``motion``'s pose at ``frame_idx`` in ``session.edit_mode_snapshot`` once.

    The snapshot dict is created on demand; an existing entry for the frame is never
    overwritten, so the first captured joints_pos/joints_rot are kept.
    """
    snapshots = session.edit_mode_snapshot
    if snapshots is None:
        snapshots = session.edit_mode_snapshot = {}
    if frame_idx not in snapshots:
        snapshots[frame_idx] = {
            "joints_pos": motion.get_joints_pos(frame_idx),
            "joints_rot": motion.get_joints_rot(frame_idx),
        }
def _update_dense_path(motion, session):
constraint_info = session.constraints["2D Root"].get_constraint_info()
if len(constraint_info["frame_idx"]) > 0:
min_root_frame = min(constraint_info["frame_idx"])
max_root_frame = max(constraint_info["frame_idx"])
motion.set_projected_root_pos_path(
constraint_info["root_pos"][:, [0, 2]],
min_frame_idx=min_root_frame,
max_frame_idx=max_root_frame,
)
# Delay (ms) after last keyframe/interval move before updating path = "on release".
DENSE_PATH_AFTER_RELEASE_MS = 300
def _schedule_dense_path_after_release(session):
    """Debounce dense-path updates until DENSE_PATH_AFTER_RELEASE_MS after the last call.

    Any timer still pending in ``session.timeline_data["dense_path_after_release_timer"]``
    is cancelled and replaced. Nothing is scheduled unless a "2D Root" constraint with
    ``dense_path`` enabled exists; the check is repeated when the timer fires.
    """
    def _dense_path_enabled(sess):
        return "2D Root" in sess.constraints and sess.constraints["2D Root"].dense_path

    if not _dense_path_enabled(session):
        return
    timeline_data = session.timeline_data
    pending = timeline_data.get("dense_path_after_release_timer")
    if pending:
        pending.cancel()

    def _fire():
        if not demo.client_active(client_id):
            return
        current = demo.client_sessions[client_id]
        timeline_data["dense_path_after_release_timer"] = None
        if _dense_path_enabled(current):
            _update_dense_path(list(current.motions.values())[0], current)

    timer = threading.Timer(DENSE_PATH_AFTER_RELEASE_MS / 1000.0, _fire)
    timeline_data["dense_path_after_release_timer"] = timer
    timer.start()
@gui_dense_path_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    # Handler for the "Make Smooth Path" checkbox (2D Root dense-path mode).
    # Checked: add fully transparent (opacity 0.0) locked keyframes at frame 0 and at
    # max_frame_idx on the 2D Root track, register both through add_constraint_callback,
    # add a locked interval over the whole motion for display, then enable dense_path and
    # re-project the root path.
    # Unchecked: disable dense_path and remove locked keyframes/intervals from the timeline.
    # Both cases finish by refreshing the constraint overlays.
    session = get_active_session(event.client)
    if session is None:
        return
    if gui_dense_path_checkbox.value:
        # Make sure 0 and max_frame_idx keyframes are added to the constraint
        # since dense path should cover full duration for best model performance
        root_2d_track = session.timeline_data["tracks_ids"]["2D Root"]
        # add a locked keyframe at 0
        start_keyframe_id = client.timeline.add_locked_keyframe(  # noqa
            root_2d_track,
            0,
            opacity=0.0,
        )
        # Mirror the new keyframe in the session's timeline bookkeeping.
        session.timeline_data["keyframes"][start_keyframe_id] = {
            "frame": 0,
            "track_id": root_2d_track,
            "locked": True,
            "opacity": 0.0,
            "value": None,
        }
        add_constraint_callback(
            start_keyframe_id,
            "2D Root",
            (0, 0),
            verbose=False,
        )
        # add a locked keyframe at max_frame_idx
        end_keyframe_id = client.timeline.add_locked_keyframe(
            root_2d_track,
            session.max_frame_idx,
            opacity=0.0,
        )
        session.timeline_data["keyframes"][end_keyframe_id] = {
            "frame": session.max_frame_idx,
            "track_id": root_2d_track,
            "locked": True,
            "opacity": 0.0,
            "value": None,
        }
        add_constraint_callback(
            end_keyframe_id,
            "2D Root",
            (session.max_frame_idx, session.max_frame_idx),
            verbose=False,
        )
        # add a locked interval only for visual purposes
        locked_interval = client.timeline.add_locked_interval(  # noqa
            root_2d_track,
            start_frame=0,
            end_frame=session.max_frame_idx,
        )
        session.timeline_data["intervals"][locked_interval] = {
            "track_id": root_2d_track,
            "start_frame_idx": 0,
            "end_frame_idx": session.max_frame_idx,
            "locked": True,
            "opacity": 0.3,
            "value": None,
        }
    # Runs in both states: turns dense_path on when checked, off when unchecked.
    session.constraints["2D Root"].set_dense_path(gui_dense_path_checkbox.value)
    if session.constraints["2D Root"].dense_path:
        # update the character motion to reflect the full path
        # will be full length by construction, no need to specify min/max frame idx
        motion = list(session.motions.values())[0]
        _update_dense_path(motion, session)
    # remove locked interval and locked keyframes
    if not gui_dense_path_checkbox.value:
        # NOTE(review): every locked keyframe/interval on the timeline is removed here, not
        # only those on the 2D Root track, and removed locked intervals are left in
        # session.timeline_data["intervals"]. Confirm both are intended.
        # Get all locked keyframes
        keyframes_to_remove = []
        for uuid, keyframe in client.timeline._keyframes.items():
            if keyframe.locked:
                keyframes_to_remove.append(uuid)
                _data = session.timeline_data["keyframes"][uuid]
                # Remove the constraint covering just this keyframe's frame.
                remove_constraint_callback(
                    uuid,
                    constraint_type=session.timeline_data["tracks"][_data["track_id"]]["name"],
                    frame_range=(_data["frame"], _data["frame"]),
                    verbose=False,
                )
        intervals_to_remove = []
        # remove all locked intervals
        for uuid, interval in client.timeline._intervals.items():
            if interval.locked:
                intervals_to_remove.append(uuid)
        # removing keyframes and intervals
        # (ids are collected first; presumably remove_* mutates the dicts iterated above)
        for uuid in keyframes_to_remove:
            client.timeline.remove_keyframe(uuid)
        for uuid in intervals_to_remove:
            client.timeline.remove_interval(uuid)
    apply_constraint_overlay_visibility(session)
with client.gui.add_folder(
"Load/Save",
expand_by_default=False,
visible=not HF_MODE,
):
with client.gui.add_folder("Motion", expand_by_default=False):
gui_save_motion_path_text = client.gui.add_text("Save Path", initial_value="output")
gui_save_motion_format_dropdown = client.gui.add_dropdown(
"Save Format",
options=(
["NPZ", "CSV"]
if "g1" in model_name.lower()
else ["NPZ", "AMASS NPZ"]
if "smplx" in model_name.lower()
else ["NPZ", "BVH"]
),
initial_value="NPZ",
)
gui_save_motion_button = client.gui.add_button(
"Save Motion",
hint="Save the current motion (format + path above)",
)
gui_load_motion_path_text = client.gui.add_text(
"Load Path",
initial_value="output.npz",
hint="SOMA .bvh, Kimodo or AMASS .npz, or G1 MuJoCo .csv",
)
gui_load_motion_button = client.gui.add_button(
"Load Motion",
hint="Load the selected motion",
)
with client.gui.add_folder("Constraints", expand_by_default=False):
gui_save_constraints_path_text = client.gui.add_text(
"Save Path", initial_value="output_constraints.json"
)
gui_save_constraints_button = client.gui.add_button("Save Constraints")
gui_load_constraints_path_text = client.gui.add_text(
"Load Path", initial_value="output_constraints.json"
)
gui_load_constraints_button = client.gui.add_button("Load Constraints")
with client.gui.add_folder("Example", expand_by_default=False):
gui_save_example_path_text = client.gui.add_text(
"Save Dir",
initial_value=os.path.join(
demo.get_examples_base_dir(model_name, absolute=True),
"custom_example_1",
),
)
gui_save_example_button = client.gui.add_button("Save Example")
gui_load_example_path_text = client.gui.add_text(
"Load Dir",
initial_value=os.path.join(
demo.get_examples_base_dir(model_name, absolute=True),
"custom_example_1",
),
)
gui_load_gt_checkbox = client.gui.add_checkbox(
"Load GT instead",
initial_value=False,
)
gui_load_example_from_path_button = client.gui.add_button("Load Example")
def _get_primary_motion(session: ClientSession):
    """Return the first motion in ``session.motions`` (IndexError when there is none)."""
    all_motions = [*session.motions.values()]
    return all_motions[0]
def _motion_to_numpy_dict(motion) -> dict[str, np.ndarray]:
joints_pos = motion.joints_pos.detach().cpu().numpy()
joints_rot = motion.joints_rot.detach().cpu().numpy()
joints_local_rot = motion.joints_local_rot.detach().cpu().numpy()
if joints_pos.ndim != 3:
raise ValueError(f"Expected unbatched joints_pos with shape [T, J, 3], got {joints_pos.shape}")
if joints_rot.ndim != 4:
raise ValueError(f"Expected unbatched joints_rot with shape [T, J, 3, 3], got {joints_rot.shape}")
if joints_local_rot.ndim != 4:
raise ValueError(
"Expected unbatched joints_local_rot with shape " f"[T, J, 3, 3], got {joints_local_rot.shape}"
)
motion_data = {
"posed_joints": joints_pos,
"global_rot_mats": joints_rot,
"local_rot_mats": joints_local_rot,
"root_positions": joints_pos[:, motion.skeleton.root_idx, :],
}
if motion.foot_contacts is not None:
foot_contacts = motion.foot_contacts.detach().cpu().numpy()
if foot_contacts.ndim != 2:
raise ValueError(
f"Expected unbatched foot_contacts with shape [T, C], got {foot_contacts.shape}"
)
motion_data["foot_contacts"] = foot_contacts
return motion_data
def _coerce_save_path(raw_path: str, *, ext: str) -> str:
"""Ensure the save path ends with the correct extension for the chosen format."""
name = (raw_path or "").strip()
if name == "":
return f"output{ext}"
known_exts = (".npz", ".bvh", ".csv")
if name.lower().endswith(known_exts):
return os.path.splitext(name)[0] + ext
if os.path.splitext(name)[1] == "":
return name + ext
return name
def save_motion(client, save_path, fmt):
    """Write the client's primary motion to disk in format ``fmt``; return the final path.

    ``fmt`` is "BVH", "CSV" (G1), "AMASS NPZ", or anything else for Kimodo NPZ. The
    path's extension is adjusted by ``_coerce_save_path`` for the chosen format.
    """
    session = demo.client_sessions[client.client_id]
    motion = _get_primary_motion(session)
    # Built for every format: it also validates that the tensors are unbatched.
    motion_data = _motion_to_numpy_dict(motion)

    def _write_bytes(path, payload):
        with open(path, "wb") as fh:
            fh.write(payload)

    if fmt == "BVH":
        out_path = _coerce_save_path(save_path, ext=".bvh")
        save_motion_bvh(
            out_path,
            motion.joints_local_rot,
            motion.joints_pos[:, session.skeleton.root_idx, :],
            skeleton=session.skeleton,
            fps=float(session.model_fps),
        )
    elif fmt == "CSV":
        out_path = _coerce_save_path(save_path, ext=".csv")
        _write_bytes(out_path, g1_csv_to_bytes(motion_data, session.skeleton, demo.device))
    elif fmt == "AMASS NPZ":
        out_path = _coerce_save_path(save_path, ext=".npz")
        _write_bytes(out_path, amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps))
    else:
        out_path = _coerce_save_path(save_path, ext=".npz")
        save_kimodo_npz(out_path, motion_data)
    return out_path
@gui_save_motion_button.on_click
def _(event: viser.GuiEvent) -> None:
    # Save the primary motion using the Save Path / Save Format fields and notify the user.
    requester = event.client
    if get_active_session(requester) is None:
        return
    requested_path = gui_save_motion_path_text.value
    save_format = str(gui_save_motion_format_dropdown.value).upper()
    try:
        written_path = save_motion(requester, requested_path, save_format)
        requester.add_notification(
            title="Motion saved!",
            body=f"Saved motion to {written_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as err:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to save motion!",
            body=str(err),
            auto_close_seconds=5.0,
            color="red",
        )
def load_motion(client, load_path):
    """Load a motion file into ``client``'s session, replacing its current motions.

    Steps: read the file (passing the session fps as ``target_fps`` when positive),
    switch to a Kimodo model whose skeleton matches the file's joint count if needed,
    move tensors to the demo device, drop a leading batch dimension, upgrade 30-joint
    SOMA motions to the 77-joint SOMA skeleton, optionally apply the G1 real-robot
    projection, then update the session duration/frame range and show frame 0.

    Raises:
        ValueError: no Kimodo model exists for the file's skeleton, or the loaded joint
            count does not match the (possibly switched) model skeleton.
    """
    session = demo.client_sessions[client.client_id]
    fps_arg = session.model_fps if session.model_fps and session.model_fps > 0 else None
    motion_dict, num_joints_motion = load_motion_file(load_path, target_fps=fps_arg)

    # Auto-switch the model when the file targets a different skeleton.
    target_skel = registry_skeleton_for_joint_count(num_joints_motion)
    current_info = get_model_info(session.model_name)
    current_skel = current_info.skeleton if current_info is not None else None
    if current_skel != target_skel:
        dataset = current_info.dataset if current_info is not None else "RP"
        new_key = kimodo_short_key_for_skeleton_dataset(target_skel, dataset)
        if new_key is None:
            # Fall back to the "RP" dataset before giving up.
            new_key = kimodo_short_key_for_skeleton_dataset(target_skel, "RP")
        if new_key is None:
            raise ValueError(
                f"No Kimodo model found for skeleton {target_skel} (motion has J={num_joints_motion})."
            )
        if new_key != session.model_name:
            gui_model_selector.set_from_short_key(new_key)
            apply_model_selection(new_key)
            _update_visibility_for_loaded_model(new_key)
            client.add_notification(
                title="Model switched",
                body=f"Switched to {new_key} to match loaded motion (J={num_joints_motion}).",
                auto_close_seconds=5.0,
                color="blue",
            )
    # Re-fetch: a model switch may have replaced the session object.
    session = demo.client_sessions[client.client_id]

    joints_pos = motion_dict["posed_joints"].to(device=demo.device, dtype=torch.float32)
    joints_rot = motion_dict["global_rot_mats"].to(device=demo.device, dtype=torch.float32)
    foot_contacts = motion_dict.get("foot_contacts")
    if foot_contacts is not None:
        foot_contacts = foot_contacts.to(device=demo.device, dtype=torch.float32)
    # Support both batched [B, T, J, 3] and unbatched [T, J, 3]; take first sample if batched
    if joints_pos.ndim == 4:
        joints_pos = joints_pos[0]
    if joints_rot.ndim == 5:
        joints_rot = joints_rot[0]
    if foot_contacts is not None and foot_contacts.ndim == 3:
        foot_contacts = foot_contacts[0]

    # Motion must match the current model's skeleton after auto-switch
    num_joints_loaded = joints_pos.shape[1]
    num_joints_skeleton = session.skeleton.nbjoints
    if num_joints_loaded != num_joints_skeleton:
        # Backward compat: expand 30-joint SOMA motion to 77
        if (
            num_joints_loaded == 30
            and num_joints_skeleton == 77
            and isinstance(session.skeleton, SOMASkeleton77)
        ):
            from kimodo.skeleton import global_rots_to_local_rots

            skel30 = SOMASkeleton30().to(demo.device)
            if "local_rot_mats" in motion_dict:
                local_rot_30 = motion_dict["local_rot_mats"].to(device=demo.device, dtype=torch.float32)
                # Batched local rotations are [B, T, J, 3, 3] (rank 5), like global_rot_mats
                # above. An unbatched [T, J, 3, 3] array must be kept whole: indexing a
                # rank-4 array here would keep only its first frame.
                if local_rot_30.ndim == 5:
                    local_rot_30 = local_rot_30[0]
            else:
                local_rot_30 = global_rots_to_local_rots(joints_rot, skel30)
            local_rot_77 = skel30.to_SOMASkeleton77(local_rot_30)
            root_positions = joints_pos[:, skel30.root_idx, :]
            joints_rot, joints_pos, _ = session.skeleton.fk(local_rot_77, root_positions)
            if foot_contacts is not None and foot_contacts.shape[-1] == 4:
                # Expand 4 contact channels to 6 by repeating channels 1 and 3.
                foot_contacts = torch.cat(
                    [
                        foot_contacts[..., :2],
                        foot_contacts[..., 1:2],
                        foot_contacts[..., 2:4],
                        foot_contacts[..., 3:4],
                    ],
                    dim=-1,
                )
        else:
            raise ValueError(
                f"The loaded motion has {num_joints_loaded} joints but the current model "
                f"({session.model_name}) has {num_joints_skeleton} joints. "
                "Load a motion generated with the same skeleton, or switch the model to match the motion."
            )
    elif joints_rot.shape[1] != num_joints_skeleton:
        raise ValueError(
            f"Rotation data has {joints_rot.shape[1]} joints but the current model has "
            f"{num_joints_skeleton} joints. The NPZ may be corrupted or from a different skeleton."
        )

    # Apply G1 real robot projection (1-DoF per joint + axis limits) if enabled.
    if (
        "g1" in session.model_name
        and isinstance(session.skeleton, G1Skeleton34)
        and gui_real_robot_rotations_checkbox.value
    ):
        joints_pos, joints_rot = generation.apply_g1_real_robot_projection(
            session.skeleton, joints_pos, joints_rot
        )

    # Update duration and frame range based on loaded motion
    num_frames = joints_pos.shape[0]
    duration = num_frames / session.model_fps
    session.cur_duration = duration
    session.max_frame_idx = num_frames - 1
    # Clear existing motions and add the loaded one
    demo.clear_motions(client.client_id)
    demo.add_character_motion(
        client,
        session.skeleton,
        joints_pos,
        joints_rot,
        foot_contacts,
    )
    # Reset to frame 0
    demo.set_frame(client.client_id, 0)
@gui_load_motion_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load the motion file named in the path text box into the viewer.

    A loading notification is shown right away and is then turned into either
    a green success message (5 s) or a red failure message (10 s).
    """
    requester = event.client
    active = get_active_session(requester)
    if active is None:
        return
    source_path = gui_load_motion_path_text.value
    notif = requester.add_notification(
        title="Loading motion...",
        body=f"Loading from {source_path}",
        loading=True,
        with_close_button=False,
        auto_close_seconds=None,
    )
    try:
        load_motion(requester, source_path)
        frame_count = active.max_frame_idx + 1
        notif.title = "Motion loaded!"
        notif.body = f"Loaded motion from {source_path} ({frame_count} frames, {active.cur_duration:.2f}s)"
        notif.loading = False
        notif.with_close_button = True
        notif.auto_close_seconds = 5.0
        notif.color = "green"
    except Exception as exc:
        import traceback

        traceback.print_exc()
        notif.title = "Failed to load motion!"
        notif.body = str(exc)
        notif.loading = False
        notif.with_close_button = True
        notif.auto_close_seconds = 10.0
        notif.color = "red"
def save_constraints(client, save_path):
    """Write the client's timeline constraints to ``save_path``.

    Valid frame indices span ``[0, max_frame_idx]``, so the frame count passed
    to the model-level constraint builder is ``max_frame_idx + 1``.
    """
    sess = demo.client_sessions[client.client_id]
    frame_count = sess.max_frame_idx + 1
    bundle = demo.load_model(sess.model_name)
    save_constraints_lst(
        save_path,
        demo.compute_model_constraints_lst(sess, bundle, frame_count),
    )
@gui_save_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Save constraints to the path in the text box and notify the user."""
    requester = event.client
    if get_active_session(requester) is None:
        return
    try:
        target_path = gui_save_constraints_path_text.value
        save_constraints(requester, target_path)
        requester.add_notification(
            title="Constraints saved!",
            body=f"Saved constraints to {target_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to save constraints!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
def load_constraints(client, load_path):
    """Replace the session's constraints with those loaded from ``load_path``.

    Every existing constraint track is cleared and the client timeline's
    keyframes and intervals are removed. Each loaded constraint's
    ``frame_indices`` are then split by ``extract_intervals_and_singles`` into
    runs of consecutive frames (added as timeline intervals) and isolated
    frames (added as timeline keyframes). The matching per-frame data is pushed
    into the corresponding session constraint track.

    Constraint names handled:
        ``root2d`` -> "2D Root" track;
        ``fullbody`` -> "Full-Body" track;
        ``left-hand``/``right-hand``/``left-foot``/``right-foot`` -> the
        per-limb timeline track, backed by the "End-Effectors" constraint;
        ``end-effector``/``end-effectors`` (legacy) -> split into one target per
        recognized limb joint (LeftHand, RightHand, LeftFoot, RightFoot), each
        also keeping "Hips" if present.

    Raises:
        KeyError: for an unsupported constraint type, or a legacy end-effector
            constraint that names none of the recognized limb joints.
    """
    session = demo.client_sessions[client.client_id]
    constraints_lst = load_constraints_lst(load_path, skeleton=session.skeleton)
    # Clear existing constraints first
    # NOTE(review): session.timeline_data["keyframes"] / ["intervals"] are not
    # reset here (apply_model_selection resets them explicitly) — confirm that
    # stale ids left in those dicts are harmless.
    with session.timeline_data["keyframe_update_lock"]:
        for constraint in list(session.constraints.values()):
            constraint.clear()
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()
    # Add loaded constraints to the session
    # We need to directly add constraint data, not read from current motion
    device = demo.device
    for constraint_obj in constraints_lst:
        constraint_type = constraint_obj.name
        # decompose the frame indices into intervals or single keyframes
        # intervals / single_frames hold frame numbers; the *_indices variants
        # hold positions into frame_indices (inclusive ends), used below to
        # slice the constraint's per-frame data.
        frame_indices = constraint_obj.frame_indices
        (
            intervals,
            intervals_indices,
            single_frames,
            single_frames_indices,
        ) = extract_intervals_and_singles(frame_indices)
        # Each target maps this constraint onto one timeline track.
        load_targets: list[dict] = []
        root_pos = None
        if constraint_type == "root2d":
            # smooth_root_2d is [T, 2] (x, z), convert to [T, 3] (x, 0, z)
            num_frames = constraint_obj.smooth_root_2d.shape[0]
            root_pos = torch.zeros(num_frames, 3, device=device)
            root_pos[:, 0] = constraint_obj.smooth_root_2d[:, 0]
            root_pos[:, 2] = constraint_obj.smooth_root_2d[:, 1]
            load_targets = [
                {
                    "track_name": "2D Root",
                    "constraint_track": session.constraints["2D Root"],
                }
            ]
        elif constraint_type == "fullbody":
            load_targets = [
                {
                    "track_name": "Full-Body",
                    "constraint_track": session.constraints["Full-Body"],
                }
            ]
        elif constraint_type in {
            "left-hand",
            "right-hand",
            "left-foot",
            "right-foot",
        }:
            track_name = {
                "left-hand": "Left Hand",
                "right-hand": "Right Hand",
                "left-foot": "Left Foot",
                "right-foot": "Right Foot",
            }[constraint_type]
            load_targets = [
                {
                    "track_name": track_name,
                    "constraint_track": session.constraints["End-Effectors"],
                    "joint_names": constraint_obj.joint_names,
                    "end_effector_type": constraint_type,
                }
            ]
        elif constraint_type in {"end-effector", "end-effectors"}:
            # Backward-compatible loader:
            # split a generic end-effector constraint into per-limb timeline tracks.
            joint_names_set = set(constraint_obj.joint_names)
            for jname, track_name, eff_type in [
                ("LeftHand", "Left Hand", "left-hand"),
                ("RightHand", "Right Hand", "right-hand"),
                ("LeftFoot", "Left Foot", "left-foot"),
                ("RightFoot", "Right Foot", "right-foot"),
            ]:
                if jname not in joint_names_set:
                    continue
                target_joint_names = [jname]
                if "Hips" in joint_names_set:
                    target_joint_names.append("Hips")
                load_targets.append(
                    {
                        "track_name": track_name,
                        "constraint_track": session.constraints["End-Effectors"],
                        "joint_names": target_joint_names,
                        "end_effector_type": eff_type,
                    }
                )
            if not load_targets:
                raise KeyError(
                    "No recognized end-effector joint in constraint "
                    f"joint_names={constraint_obj.joint_names}"
                )
        else:
            raise KeyError(f"Unsupported constraint type in loader: {constraint_type}")
        for target in load_targets:
            track_id = session.timeline_data["tracks_ids"][target["track_name"]]
            constraint_track = target["constraint_track"]
            # add intervals
            for (start_idx, end_idx), (start_idx_t, end_idx_t) in zip(intervals, intervals_indices):
                # Add to timeline
                interval_id = client.timeline.add_interval(track_id, start_idx, end_idx)
                session.timeline_data["intervals"][interval_id] = {
                    "track_id": track_id,
                    "start_frame_idx": start_idx,
                    "end_frame_idx": end_idx,
                    "locked": False,
                    "opacity": 1.0,
                    "value": None,
                }
                # Slices use end_idx_t + 1 because interval position ends are inclusive.
                if constraint_type == "root2d":
                    constraint_track.add_interval(
                        interval_id,
                        start_idx,
                        end_idx,
                        root_pos[start_idx_t : end_idx_t + 1],
                    )
                elif constraint_type == "fullbody":
                    constraint_track.add_interval(
                        interval_id,
                        start_idx,
                        end_idx,
                        constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1],
                        constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1],
                    )
                else:
                    # End-effector tracks additionally take the joint subset and limb type.
                    constraint_track.add_interval(
                        interval_id,
                        start_idx,
                        end_idx,
                        constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1],
                        constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1],
                        target["joint_names"],
                        target["end_effector_type"],
                    )
            # add keyframes
            for frame, frame_t in zip(single_frames, single_frames_indices):
                # Add to timeline
                keyframe_id = client.timeline.add_keyframe(track_id, frame)
                session.timeline_data["keyframes"][keyframe_id] = {
                    "track_id": track_id,
                    "frame": frame,
                    "locked": False,
                    "opacity": 1.0,
                    "value": None,
                }
                if constraint_type == "root2d":
                    constraint_track.add_keyframe(
                        keyframe_id,
                        frame,
                        root_pos[frame_t],
                    )
                elif constraint_type == "fullbody":
                    constraint_track.add_keyframe(
                        keyframe_id,
                        frame,
                        constraint_obj.global_joints_positions[frame_t],
                        constraint_obj.global_joints_rots[frame_t],
                    )
                else:
                    constraint_track.add_keyframe(
                        keyframe_id,
                        frame,
                        constraint_obj.global_joints_positions[frame_t],
                        constraint_obj.global_joints_rots[frame_t],
                        target["joint_names"],
                        target["end_effector_type"],
                    )
@gui_load_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load constraints from the text-box path, refresh overlays, and notify."""
    requester = event.client
    if get_active_session(requester) is None:
        return
    try:
        source_path = gui_load_constraints_path_text.value
        load_constraints(requester, source_path)
        apply_constraint_overlay_visibility(demo.client_sessions[requester.client_id])
        requester.add_notification(
            title="Constraints loaded!",
            body=f"Loaded constraints from {source_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to load constraints!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
# Export controls: screenshot, video and motion download widgets.
with client.gui.add_folder("Exports", expand_by_default=False):
    # Screenshot export via the client-side canvas; hidden in HF_MODE.
    with client.gui.add_folder("Screenshot", expand_by_default=False, visible=not HF_MODE):
        gui_screenshot_path_text = client.gui.add_text(
            "Save Path",
            initial_value="render.png",
            hint="Filename for the screenshot (PNG).",
        )
        gui_screenshot_button = client.gui.add_button(
            "Download Screenshot",
            hint="Capture the current canvas and download a PNG.",
        )
    # Video export (renders every frame, encodes an MP4); hidden in HF_MODE.
    with client.gui.add_folder("Video", expand_by_default=False, visible=not HF_MODE):
        gui_video_path_text = client.gui.add_text(
            "Save Path",
            initial_value="render.mp4",
            hint="Filename for the video (MP4).",
        )
        gui_video_button = client.gui.add_button(
            "Download Video",
            hint="Render every frame and download as MP4.",
        )
    # Motion download.
    with client.gui.add_folder("Motion", expand_by_default=True):
        gui_download_name_text = client.gui.add_text(
            "Name",
            initial_value="output",
            hint="Base filename to save as (extension will be added based on format if omitted).",
        )
        # Initial formats follow model_name (from the enclosing scope):
        # G1 -> NPZ/CSV, SMPL-X -> NPZ/AMASS NPZ, otherwise NPZ/BVH.
        # After a model switch, _update_motion_export_dropdown() refreshes them.
        gui_download_format_dropdown = client.gui.add_dropdown(
            "Format",
            options=(
                ["NPZ", "CSV"]
                if "g1" in model_name.lower()
                else ["NPZ", "AMASS NPZ"]
                if "smplx" in model_name.lower()
                else ["NPZ", "BVH"]
            ),
            initial_value="NPZ",
        )
        gui_download_button = client.gui.add_button(
            "Download",
            hint="Download the current motion (format + name above).",
        )
def _download_bytes_to_browser(
    event_client: viser.ClientHandle,
    *,
    data: bytes,
    filename: str,
    mime_type: str = "application/octet-stream",
) -> None:
    """Trigger a browser download for an in-memory byte payload.

    The payload is base64-encoded into a JavaScript snippet that the client
    runs. The snippet rebuilds the bytes as a Blob and clicks a temporary
    ``<a download>`` link.

    Important: this intentionally does NOT use `showSaveFilePicker()` to avoid
    Chrome/Edge's file-write permission prompt ("this site can see edits you make").
    If you want "always ask where to save", configure your browser download settings.

    The Blob URL is revoked after a delay rather than right after ``click()``.
    Revoking it synchronously can cancel the download in some browsers before
    the Blob has been read.
    """
    import base64
    import json

    # Base64 is the most robust way to move binary over our websocket JS channel.
    b64 = base64.b64encode(data).decode("ascii")
    js = f"""
(() => {{
  const filename = {json.dumps(filename)};
  const mimeType = {json.dumps(mime_type)};
  const b64 = {json.dumps(b64)};
  // Decode base64 -> Uint8Array.
  const binStr = atob(b64);
  const bytes = new Uint8Array(binStr.length);
  for (let i = 0; i < binStr.length; i++) bytes[i] = binStr.charCodeAt(i);
  const blob = new Blob([bytes], {{ type: mimeType }});
  // Standard browser download behavior.
  const url = URL.createObjectURL(blob);
  const a = document.createElement("a");
  a.href = url;
  a.download = filename;
  document.body.appendChild(a);
  a.click();
  a.remove();
  // Revoke later: revoking right after click() can abort the download.
  setTimeout(() => URL.revokeObjectURL(url), 40000);
}})();
"""
    # Reuse viser's JS execution mechanism (used for Plotly setup).
    from viser import _messages as _viser_messages

    event_client.gui._websock_interface.queue_message(  # type: ignore[attr-defined]
        _viser_messages.RunJavascriptMessage(source=js)
    )
def _motion_to_npz_bytes(motion) -> bytes:
    """Serialize ``motion`` into in-memory Kimodo NPZ bytes."""
    return kimodo_npz_to_bytes(_motion_to_numpy_dict(motion))
def _motion_to_csv_bytes(motion, session: ClientSession) -> bytes:
    """Serialize ``motion`` into G1 CSV bytes using the session's skeleton."""
    numpy_motion = _motion_to_numpy_dict(motion)
    skeleton = session.skeleton
    return g1_csv_to_bytes(numpy_motion, skeleton, demo.device)
def _motion_to_amass_npz_bytes(motion, session: ClientSession) -> bytes:
    """Serialize ``motion`` into AMASS-style NPZ bytes at the session's model fps."""
    return amass_npz_to_bytes(
        _motion_to_numpy_dict(motion),
        session.skeleton,
        session.model_fps,
    )
def _get_motion_export_formats(loaded_model_name: str) -> list[str]:
model_name_lower = (loaded_model_name or "").lower()
if "g1" in model_name_lower:
return ["NPZ", "CSV"]
if "smplx" in model_name_lower:
return ["NPZ", "AMASS NPZ"]
return ["NPZ", "BVH"]
def _update_format_dropdown(dropdown, loaded_model_name: str) -> None:
    """Refresh ``dropdown``'s options for the model, keeping the selection if still valid."""
    allowed = _get_motion_export_formats(loaded_model_name)
    previous = str(dropdown.value)
    dropdown.options = allowed
    if previous in allowed:
        dropdown.value = previous
    else:
        dropdown.value = allowed[0]
def _update_motion_export_dropdown(loaded_model_name: str) -> None:
    """Sync the download and save motion-format dropdowns with the loaded model."""
    for format_dropdown in (gui_download_format_dropdown, gui_save_motion_format_dropdown):
        _update_format_dropdown(format_dropdown, loaded_model_name)
def _coerce_download_filename(raw_name: str, *, ext: str) -> str:
"""Coerce a user-entered filename to a safe basename with the desired extension.
- If empty: uses "output{ext}"
- If no extension: appends ext
- If endswith a known export extension: rewrites extension to ext (prevents mismatches)
- Any provided directory components are stripped
"""
import os
name = (raw_name or "").strip()
name = os.path.basename(name.replace("\\", "/"))
if name == "":
return f"output{ext}"
known_exts = (".npz", ".bvh", ".csv", ".png", ".mp4")
lower = name.lower()
if lower.endswith(known_exts):
return os.path.splitext(name)[0] + ext
root, cur_ext = os.path.splitext(name)
if cur_ext == "":
return name + ext
return name
def _get_render_size(event_client: viser.ClientHandle) -> tuple[int, int]:
width = int(event_client.camera.image_width)
height = int(event_client.camera.image_height)
if width <= 0 or height <= 0:
# Fall back to a reasonable default if the camera hasn't synced yet.
return (1280, 720)
return (width, height)
def _round_up_to_multiple(value: int, multiple: int) -> int:
if multiple <= 0:
return value
return ((value + multiple - 1) // multiple) * multiple
def _download_canvas_to_browser(event_client: viser.ClientHandle, *, filename: str) -> None:
    """Download the largest page canvas as a PNG, entirely on the client side.

    This uses the client-side canvas save path to avoid server-side renders.
    The snippet picks the canvas with the largest pixel area, encodes it with
    ``toBlob`` and clicks a temporary ``<a download>`` link. The Blob URL is
    revoked after a delay, because revoking it right after ``click()`` can
    cancel the download in some browsers.
    """
    import json

    js = f"""
(() => {{
  const filename = {json.dumps(filename)};
  const canvases = Array.from(document.querySelectorAll("canvas"));
  if (!canvases.length) {{
    console.error("No canvases found to save.");
    return;
  }}
  // Pick the largest canvas by area (usually the main 3D view).
  const canvas = canvases.reduce((best, cur) => {{
    const bestArea = (best?.width || 0) * (best?.height || 0);
    const curArea = (cur?.width || 0) * (cur?.height || 0);
    return curArea > bestArea ? cur : best;
  }}, null);
  if (!canvas) {{
    console.error("No canvas selected to save.");
    return;
  }}
  canvas.toBlob((blob) => {{
    if (!blob) {{
      console.error("Export failed");
      return;
    }}
    const url = URL.createObjectURL(blob);
    const a = document.createElement("a");
    a.href = url;
    a.download = filename;
    document.body.appendChild(a);
    a.click();
    a.remove();
    // Revoke later: revoking right after click() can abort the download.
    setTimeout(() => URL.revokeObjectURL(url), 40000);
  }}, "image/png");
}})();
"""
    from viser import _messages as _viser_messages

    event_client.gui._websock_interface.queue_message(  # type: ignore[attr-defined]
        _viser_messages.RunJavascriptMessage(source=js)
    )
@gui_screenshot_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Download the current canvas as a PNG named from the screenshot text box."""
    requester = event.client
    if get_active_session(requester) is None:
        return
    try:
        png_name = _coerce_download_filename(str(gui_screenshot_path_text.value), ext=".png")
        _download_canvas_to_browser(requester, filename=png_name)
        requester.add_notification(
            title="Screenshot download started",
            body=f"Saving {png_name}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as exc:
        import traceback

        traceback.print_exc()
        requester.add_notification(
            title="Failed to download screenshot!",
            body=str(exc),
            auto_close_seconds=10.0,
            color="red",
        )
@gui_video_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Render every frame of the current motion and download it as an MP4.

    Timeline constraints are disabled while recording and re-enabled
    afterwards. The playhead is restored to its original frame even if a
    render fails part-way through the frame loop. Frames are rendered at the
    client's camera size, rounded up to multiples of 16, and encoded as H.264
    at the session's model fps.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    recording_notification: viser.NotificationHandle | None = None
    try:
        recording_notification = event_client.add_notification(
            title="Recording video...",
            body="Saving frames, please wait.",
            loading=True,
            with_close_button=False,
            auto_close_seconds=None,
            color="blue",
        )
        event_client.timeline.disable_constraints()
        width, height = _get_render_size(event_client)
        # Avoid ffmpeg macro block resizing warnings.
        width = _round_up_to_multiple(width, 16)
        height = _round_up_to_multiple(height, 16)
        original_frame = session.frame_idx
        frames = []
        try:
            for frame_idx in range(session.max_frame_idx + 1):
                demo.set_frame(
                    event_client.client_id,
                    frame_idx,
                    update_timeline=True,
                )
                frames.append(
                    event_client.get_render(
                        height=height,
                        width=width,
                        transport_format="jpeg",
                    )
                )
        finally:
            # Restore the original frame (and timeline), even if a render failed.
            demo.set_frame(event_client.client_id, original_frame)
        import imageio.v3 as iio

        filename = _coerce_download_filename(
            str(gui_video_path_text.value),
            ext=".mp4",
        )
        payload = iio.imwrite(
            "<bytes>",
            frames,
            extension=".mp4",
            fps=float(session.model_fps),
            codec="h264",
            plugin="pyav",
        )
        event_client.send_file_download(filename, payload, save_immediately=True)
        event_client.add_notification(
            title="Video download started",
            body=f"Saving {filename}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to download video!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
    finally:
        event_client.timeline.enable_constraints()
        if recording_notification is not None:
            recording_notification.remove()
@gui_download_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Encode the primary motion in the selected format and download it.

    BVH, CSV and AMASS NPZ are handled explicitly, and any other selection
    falls back to the native NPZ. Every failure, including a missing motion,
    is reported to the user as a red notification.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    try:
        # Looked up inside the try so a failing lookup produces the error
        # notification instead of escaping the callback unreported.
        motion = _get_primary_motion(session)
        if motion is None:
            raise ValueError("No motion to download.")
        fmt = str(gui_download_format_dropdown.value).upper()
        raw_name = str(gui_download_name_text.value)
        if fmt == "BVH":
            filename = _coerce_download_filename(raw_name, ext=".bvh")
            payload = motion_to_bvh_bytes(
                motion.joints_local_rot,
                motion.joints_pos[:, session.skeleton.root_idx, :],  # root positions
                skeleton=session.skeleton,
                fps=float(session.model_fps),
            )
            mime = "text/plain"
        elif fmt == "CSV":
            filename = _coerce_download_filename(raw_name, ext=".csv")
            payload = _motion_to_csv_bytes(motion, session)
            mime = "text/csv"
        elif fmt == "AMASS NPZ":
            filename = _coerce_download_filename(raw_name, ext=".npz")
            payload = _motion_to_amass_npz_bytes(motion, session)
            mime = "application/octet-stream"
        else:
            # Default to NPZ (most common and matches existing save/load).
            filename = _coerce_download_filename(raw_name, ext=".npz")
            payload = _motion_to_npz_bytes(motion)
            mime = "application/octet-stream"
        _download_bytes_to_browser(
            event_client,
            data=payload,
            filename=filename,
            mime_type=mime,
        )
        event_client.add_notification(
            title="Download started",
            body=f"Saving {filename}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to download motion!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
@gui_save_example_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Save the current constraints, motion and GUI settings as a new example.

    Writes ``constraints.json``, ``motion.npz`` and ``meta.json`` into the
    directory named in the save-path text box, which must not already exist,
    then reloads the examples dropdown. If writing the files fails, the newly
    created directory is removed so the same path can be retried; otherwise
    the "already exists" check would reject every retry.
    """
    import shutil

    from kimodo.tools import save_json

    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    save_dir = gui_save_example_path_text.value
    if os.path.exists(save_dir):
        event_client.add_notification(
            title="Failed to save example!",
            body="Example directory already exists",
            auto_close_seconds=10.0,
            color="red",
        )
        return
    created_dir = False
    files_written = False
    try:
        os.makedirs(save_dir)
        created_dir = True
        # save the constraints
        constraint_path = os.path.join(save_dir, "constraints.json")
        save_constraints(event_client, constraint_path)
        # save the motion
        motion_path = os.path.join(save_dir, "motion.npz")
        save_motion(event_client, motion_path, "NPZ")
        # save the gui metadata
        meta_path = os.path.join(save_dir, "meta.json")
        prompt_texts = []
        prompt_durations_sec = []
        prompt_values = sorted(client.timeline._prompts.values(), key=lambda p: p.start_frame)
        last_prompt_idx = len(prompt_values) - 1
        for i, prompt in enumerate(prompt_values):
            prompt_texts.append(prompt.text)
            # Match demo/generation convention:
            # non-last prompts: [start, end) ; last prompt: [start, end].
            n_frames = prompt.end_frame - prompt.start_frame
            if i == last_prompt_idx:
                n_frames += 1
            prompt_durations_sec.append(n_frames / session.model_fps)
        if len(prompt_texts) == 1:
            meta_info = {
                "text": prompt_texts[0],
                "duration": prompt_durations_sec[0],
            }
        else:
            meta_info = {
                "texts": prompt_texts,
                "durations": prompt_durations_sec,
            }
        meta_info["num_samples"] = gui_num_samples_slider.value
        meta_info["seed"] = gui_seed.value
        meta_info["diffusion_steps"] = gui_diffusion_steps_slider.value
        meta_info["cfg"] = {
            "enabled": gui_cfg_checkbox.value,
            "text_weight": gui_cfg_text_weight_slider.value,
            "constraint_weight": gui_cfg_constraint_weight_slider.value,
        }
        save_json(meta_path, meta_info)
        files_written = True
        # update the example dropdown
        session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)
        update_examples_dropdown(session.example_dict, keep_selection=True)
        event_client.add_notification(
            title="Example saved!",
            body=f"Saved example to {save_dir}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        if created_dir and not files_written:
            # The directory is ours and incomplete: remove it so a retry with
            # the same path is not rejected by the existence check above.
            shutil.rmtree(save_dir, ignore_errors=True)
        event_client.add_notification(
            title="Failed to save example!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
def set_new_duration(client_id, new_duration):
    """Set the session duration, update the duration GUI, and clamp the playhead.

    ``max_frame_idx`` is ``duration * fps - 1`` truncated to an int. The
    product is first rounded to 6 decimals, so float noise does not drop the
    last frame (e.g. ``4.1 * 30 == 122.99999999999999`` would otherwise give
    121 instead of 122).
    """
    session = demo.client_sessions[client_id]
    session.cur_duration = new_duration
    update_duration_gui(new_duration)
    num_frames_float = round(session.cur_duration * session.model_fps, 6)
    session.max_frame_idx = int(num_frames_float - 1)
    if session.frame_idx > session.max_frame_idx:
        demo.set_frame(client_id, session.max_frame_idx)
def apply_model_selection(new_model_name: str) -> None:
    """Switch this client's session to ``new_model_name``.

    Pauses playback, leaves edit mode, clears motions and constraints, and
    rebuilds the constraint tracks for the new skeleton. The new model's fps,
    skeleton and motion representation are then applied while the current
    duration is kept. Timeline zoom and text prompts are rescaled by time from
    the old fps to the new one, and prompts are clamped to the new frame range.
    The example paths and dropdown are pointed at the new model's examples,
    and an empty character is added. No-op if the model is already selected.
    """
    session = demo.client_sessions[client_id]
    if new_model_name == session.model_name:
        return
    session.playing = False  # Pause playback when switching models.
    old_model_fps = session.model_fps
    old_duration = session.cur_duration
    old_prompts = [
        (prompt.text, prompt.start_frame, prompt.end_frame) for prompt in client.timeline._prompts.values()
    ]
    old_default_zoom_frames = client.timeline._default_num_frames_zoom
    old_max_zoom_frames = client.timeline._max_frames_zoom
    model_bundle = demo.load_model(new_model_name)
    # Clear motions and constraints when switching models.
    if session.edit_mode and session.motions:
        exit_editing_mode(session)
        session.edit_mode = False
    demo.clear_motions(client_id)
    with session.timeline_data["keyframe_update_lock"]:
        for constraint in list(session.constraints.values()):
            constraint.clear()
        session.constraints = demo.build_constraint_tracks(client, model_bundle.skeleton)
        session.timeline_data["keyframes"] = {}
        session.timeline_data["intervals"] = {}
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()
    session.model_name = new_model_name
    session.model_fps = model_bundle.model_fps
    session.skeleton = model_bundle.skeleton
    session.motion_rep = model_bundle.motion_rep
    session.cur_duration = old_duration
    # Round before truncating so float noise in duration * fps
    # (e.g. 4.1 * 30 == 122.99999999999999) does not drop the last frame.
    session.max_frame_idx = int(round(session.cur_duration * session.model_fps, 6) - 1)
    session.frame_idx = 0
    session.edit_mode = False
    demo.set_timeline_defaults(client.timeline, session.model_fps)
    client.timeline.set_current_frame(0)
    gui_model_fps.value = session.model_fps
    update_duration_gui(session.cur_duration)
    if old_model_fps > 0:
        # Preserve the zoom window in seconds across the fps change.
        default_zoom_seconds = old_default_zoom_frames / old_model_fps
        max_zoom_seconds = old_max_zoom_frames / old_model_fps
        new_default_zoom = int(round(default_zoom_seconds * session.model_fps))
        new_max_zoom = int(round(max_zoom_seconds * session.model_fps))
        new_default_zoom = max(1, new_default_zoom)
        new_max_zoom = max(new_default_zoom, new_max_zoom)
        client.timeline.set_zoom_settings(
            default_num_frames_zoom=new_default_zoom,
            max_frames_zoom=new_max_zoom,
        )
    client.timeline.clear_prompts()
    if old_prompts and old_model_fps > 0:
        # Re-add prompts at the same times (in seconds), clamped to the new range.
        for i, (prompt_text, start_frame, end_frame) in enumerate(old_prompts):
            start_sec = start_frame / old_model_fps
            end_sec = end_frame / old_model_fps
            new_start = int(round(start_sec * session.model_fps))
            new_end = int(round(end_sec * session.model_fps))
            new_start = max(0, min(new_start, session.max_frame_idx))
            new_end = max(new_start, min(new_end, session.max_frame_idx))
            color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
            client.timeline.add_prompt(prompt_text, new_start, new_end, color=color)
    examples_base_dir = demo.get_examples_base_dir(new_model_name, absolute=True)
    session.examples_base_dir = examples_base_dir
    session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)
    update_examples_dropdown(session.example_dict, keep_selection=False)
    default_example_path = os.path.join(examples_base_dir, "custom_example_1")
    gui_save_example_path_text.value = default_example_path
    gui_load_example_path_text.value = default_example_path
    demo.add_character_motion(client, session.skeleton)
    apply_constraint_overlay_visibility(session)
def _update_version_and_display_from_dataset_skeleton() -> None:
    """Refresh the version selector and model label from the dataset/skeleton pickers.

    Does nothing if the skeleton display name maps to no key, or if no Kimodo
    model exists for the dataset/skeleton pair. The version selector is only
    shown when more than one model matches.
    """
    dataset_ui = gui_dataset_selector.value
    skeleton_key = get_skeleton_key_from_display_name(gui_skeleton_selector.value)
    if skeleton_key is None:
        return
    candidates = get_models_for_dataset_skeleton(dataset_ui, skeleton_key, family="Kimodo")
    if not candidates:
        return
    default_name = candidates[0].display_name
    gui_version_selector.options = [candidate.display_name for candidate in candidates]
    gui_version_selector.value = default_name
    gui_version_selector.visible = len(candidates) > 1
    gui_model_display.content = f"**Model:** {default_name}"
def _update_visibility_for_loaded_model(loaded_model_name: str) -> None:
    """Show or hide model-specific controls based on the loaded model's name.

    Only the model that is actually loaded matters, not the pending selector
    value. An empty name leaves every control untouched.
    """
    if not loaded_model_name:
        return
    _update_motion_export_dropdown(loaded_model_name)
    is_soma = "soma" in loaded_model_name
    is_g1 = "g1" in loaded_model_name
    gui_use_soma_layer_checkbox.visible = is_soma
    gui_real_robot_rotations_checkbox.visible = is_g1
    gui_postprocess_checkbox.visible = not is_g1
    gui_root_margin.visible = (not is_g1) and gui_postprocess_checkbox.value
    if is_g1:
        # G1 gizmos are forced to local space; the dropdown is then hidden and locked.
        gui_gizmo_space_dropdown.value = "Local"
    gui_gizmo_space_dropdown.visible = not is_g1
    gui_gizmo_space_dropdown.disabled = is_g1
def _on_load_model_click(event: viser.GuiEvent) -> None:
    """Load the model currently chosen in the model selector.

    Called from the "Load model" button. Does nothing when there is no active
    session or no selection, when the model is unknown, or when it is already
    loaded. On failure the traceback is printed, the "Loading model..."
    notification is removed (it would otherwise linger with a stale title),
    an error notification is shown, and the selector is reverted to the model
    that is still loaded.
    """
    if get_active_session(event.client) is None:
        return
    new_model_name = gui_model_selector.value
    if not new_model_name:
        return
    info = get_model_info(new_model_name)
    if info is None:
        return
    session = demo.client_sessions[event.client.client_id]
    if new_model_name == session.model_name:
        return
    loading_notif = event.client.add_notification(
        title="Loading model...",
        body=f"Loading {info.display_name}",
        loading=True,
        with_close_button=False,
    )
    try:
        apply_model_selection(new_model_name)
        _update_visibility_for_loaded_model(new_model_name)
        loading_notif.title = "Model loaded"
        loading_notif.body = f"{info.display_name} is ready."
        loading_notif.loading = False
        loading_notif.with_close_button = True
        loading_notif.auto_close_seconds = 5.0
        loading_notif.color = "green"
    except Exception as e:
        import traceback

        traceback.print_exc()
        loading_notif.remove()
        event.client.add_notification(
            title="Model failed to load",
            body=str(e),
            color="red",
            auto_close_seconds=10.0,
        )
        # Revert the selector to the model that is actually loaded.
        gui_model_selector.set_from_short_key(session.model_name)
@gui_load_model_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Delegate "Load model" button clicks to ``_on_load_model_click``."""
    return _on_load_model_click(event)
@gui_dataset_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Repopulate the skeleton picker for the new dataset, then refresh the version selector."""
    if get_active_session(event.client) is None:
        return
    labels = get_allowed_skeleton_labels(gui_dataset_selector.value)
    gui_skeleton_selector.options = labels
    # First label, or "" when the dataset allows no skeleton.
    gui_skeleton_selector.value = (labels or [""])[0]
    _update_version_and_display_from_dataset_skeleton()
@gui_skeleton_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Refresh the version selector and model label when the skeleton changes."""
    if get_active_session(event.client) is not None:
        _update_version_and_display_from_dataset_skeleton()
@gui_version_selector.on_update
def _(event: viser.GuiEvent) -> None:
    """Update the model label to the model currently resolved by the model selector."""
    if get_active_session(event.client) is None:
        return
    selected_info = get_model_info(gui_model_selector.value)
    if selected_info is None:
        return
    gui_model_display.content = f"**Model:** {selected_info.display_name}"
@gui_use_soma_layer_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    """Rebuild the displayed character after the SOMA-layer checkbox is toggled.

    Only acts for SOMA models. The first displayed motion (if any) and the
    current playhead are re-applied to the rebuilt character. On error the
    checkbox is flipped back and an error notification is shown.
    """
    session = get_active_session(event.client)
    if session is None or "soma" not in (session.model_name or ""):
        return
    client_handle = event.client
    notif = client_handle.add_notification(
        title="Applying SOMA layer...",
        body="Updating mesh.",
        loading=True,
        with_close_button=False,
    )
    try:
        motion = next(iter(session.motions.values())) if session.motions else None
        frame_to_restore = session.frame_idx
        # Recreate the character to apply the new SOMA mesh mode selection.
        demo.clear_motions(client_handle.client_id)
        if motion is not None:
            demo.add_character_motion(
                client_handle,
                session.skeleton,
                motion.joints_pos,
                motion.joints_rot,
                motion.foot_contacts,
            )
        else:
            demo.add_character_motion(client_handle, session.skeleton)
        demo.set_frame(client_handle.client_id, frame_to_restore)
    except Exception as exc:
        print(exc)
        client_handle.add_notification(
            title="SOMA layer failed",
            body=str(exc),
            color="red",
            auto_close_seconds=10.0,
        )
        gui_use_soma_layer_checkbox.value = not gui_use_soma_layer_checkbox.value
    finally:
        notif.loading = False
        notif.with_close_button = True
        notif.auto_close_seconds = 2.0
@gui_real_robot_rotations_checkbox.on_update
def _(event: viser.GuiEvent) -> None:
    """Project displayed G1 motions onto the real robot's DoF when the box is checked.

    Each joint is reduced to 1 DoF and clamped to its axis limits. Unchecking
    the box does not undo an earlier projection. Single-frame motions, and
    motions that never move away from their first frame (within 1e-6), are
    skipped.
    """
    session = get_active_session(event.client)
    if session is None or "g1" not in session.model_name:
        return
    if not (
        isinstance(session.skeleton, G1Skeleton34)
        and session.motions
        and gui_real_robot_rotations_checkbox.value
    ):
        return
    from kimodo.skeleton import global_rots_to_local_rots

    frame_to_restore = session.frame_idx
    for motion in session.motions.values():
        if motion.length <= 1:
            continue
        first_pos = motion.joints_pos[0:1]
        first_rot = motion.joints_rot[0:1]
        if (motion.joints_pos - first_pos).abs().max().item() < 1e-6 and (
            motion.joints_rot - first_rot
        ).abs().max().item() < 1e-6:
            # Motion never leaves its first pose; nothing to project.
            continue
        projected_pos, projected_rot = generation.apply_g1_real_robot_projection(
            session.skeleton,
            motion.joints_pos,
            motion.joints_rot,
        )
        motion.joints_pos = projected_pos
        motion.joints_rot = projected_rot
        motion.joints_local_rot = global_rots_to_local_rots(projected_rot, session.skeleton)
        # Refresh skeleton and skinned mesh caches so the viz uses new positions.
        motion.precompute_mesh_info()
    demo.set_frame(event.client.client_id, frame_to_restore)
    event.client.add_notification(
        title="Real robot projection applied",
        body="The motion is projected to G1 real robot DoF (1-DoF per joint, clamped to axis limits).",
        auto_close_seconds=4.0,
        color="green",
    )
def load_example_from_path(
    event_client: viser.ClientHandle,
    example_path: str,
    load_gt: bool = False,
) -> None:
    """Load a saved example directory into the client's session.

    Pauses playback, then loads each part only if its file exists:
    ``constraints.json`` replaces the constraints (when it is absent, all
    existing constraints and timeline keyframes/intervals are cleared);
    ``gt_motion.npz`` (when ``load_gt``) or ``motion.npz`` becomes the
    displayed motion; ``meta.json`` rebuilds the timeline prompts and restores
    the optional generation settings (num_samples, seed, diffusion_steps,
    cfg). Finally the playhead is reset to frame 0.

    Args:
        event_client: Client whose session is updated.
        example_path: Directory containing the example files.
        load_gt: Load ``gt_motion.npz`` instead of ``motion.npz``.

    Errors are reported to the client as notifications rather than raised.
    """
    from kimodo.meta import parse_prompts_from_meta
    from kimodo.tools import load_json

    session = get_active_session(event_client)
    if session is None:
        return
    # Pause playback when loading an example.
    session.playing = False
    if not os.path.isdir(example_path):
        event_client.add_notification(
            title="Example path not found",
            body=f"Directory does not exist: {example_path}",
            auto_close_seconds=5.0,
            color="red",
        )
        return
    try:
        # constraints
        constraints_path = os.path.join(example_path, "constraints.json")
        if os.path.exists(constraints_path):
            load_constraints(event_client, constraints_path)
        else:
            # clear all existing constraints
            with session.timeline_data["keyframe_update_lock"]:
                for constraint in list(session.constraints.values()):
                    constraint.clear()
                event_client.timeline.clear_keyframes()
                event_client.timeline.clear_intervals()
        # motion
        motion_filename = "gt_motion.npz" if load_gt else "motion.npz"
        motion_path = os.path.join(example_path, motion_filename)
        if os.path.exists(motion_path):
            load_motion(event_client, motion_path)
        # metadata
        meta_path = os.path.join(example_path, "meta.json")
        if os.path.exists(meta_path):
            meta_info = load_json(meta_path)
            event_client.timeline.clear_prompts()
            texts, durations_sec = parse_prompts_from_meta(meta_info)
            fps = session.model_fps
            # Convert durations (seconds) to consecutive frame bounds
            # (each prompt lasts at least one frame).
            num_frames = 0
            frame_bounds = []
            for i, d in enumerate(durations_sec):
                n_frames = max(1, int(round(d * fps)))
                start_frame = num_frames
                # Inverse of compute_prompt_num_frames():
                # non-last prompts end at next prompt start (exclusive),
                # last prompt includes its end frame.
                if i == len(durations_sec) - 1:
                    end_frame = num_frames + n_frames - 1
                else:
                    end_frame = num_frames + n_frames
                frame_bounds.append((start_frame, end_frame))
                num_frames += n_frames
            # Adapt timeline zoom to the loaded motion.
            # Show the total prompt length plus 10% headroom.
            target_visible_frames = int(math.ceil(1.10 * num_frames))
            event_client.timeline.set_zoom_settings(
                default_num_frames_zoom=target_visible_frames,
            )
            for i, (prompt_text, (start_frame, end_frame)) in enumerate(zip(texts, frame_bounds)):
                color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
                event_client.timeline.add_prompt(prompt_text, start_frame, end_frame, color=color)
            # NOTE(review): presumably recomputes the duration from the new prompts — confirm.
            update_duration_auto()
            # Only load optional fields if present
            if "num_samples" in meta_info:
                gui_num_samples_slider.value = meta_info["num_samples"]
            if "seed" in meta_info:
                gui_seed.value = meta_info["seed"]
            if "diffusion_steps" in meta_info:
                gui_diffusion_steps_slider.value = meta_info["diffusion_steps"]
            if "cfg" in meta_info:
                cfg = meta_info["cfg"]
                if "enabled" in cfg:
                    gui_cfg_checkbox.value = cfg["enabled"]
                if "text_weight" in cfg:
                    gui_cfg_text_weight_slider.value = cfg["text_weight"]
                if "constraint_weight" in cfg:
                    gui_cfg_constraint_weight_slider.value = cfg["constraint_weight"]
        # Set frame to 0 when example is loaded.
        session.frame_idx = 0
        event_client.timeline.set_current_frame(0)
        demo.set_frame(event_client.client_id, 0)
        event_client.add_notification(
            title="Example loaded!",
            body=f"Loaded example from {example_path}",
            auto_close_seconds=5.0,
            color="green",
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        event_client.add_notification(
            title="Failed to load example!",
            body=str(e),
            auto_close_seconds=10.0,
            color="red",
        )
def load_qwen_example_plan(event_client: viser.ClientHandle) -> None:
    """Load a Qwen-generated 10-action prompt plan into the timeline.

    This preserves the native UI flow:
    1) Load Example -> fills timeline prompt segments
    2) Generate -> synthesizes motion from loaded prompts

    Qwen is queried on a background daemon thread (at most 8 calls, up to 5
    actions per call) until 10 prompt segments are collected. The timeline's
    prompts are then replaced by consecutive segments. The client receives a
    success or failure notification.

    Args:
        event_client: client whose timeline receives the prompt segments.
    """
    session = get_active_session(event_client)
    if session is None:
        return

    target_actions = 10  # prompt segments to collect
    max_rounds = 8  # cap on Qwen calls in case batches come back short
    max_batch = 5  # actions requested per Qwen call

    def _collect_prompts() -> tuple[list[str], list[float]]:
        """Query Qwen in batches; return (texts, durations in seconds)."""
        history: list[str] = []
        all_texts: list[str] = []
        all_durations: list[float] = []
        rounds = 0
        while len(all_texts) < target_actions and rounds < max_rounds:
            remaining = target_actions - len(all_texts)
            batch, history = call_qwen_for_prompts(
                scene="Agentic demo: keep one character in continuous motion",
                history=history,
                requested_actions=min(max_batch, remaining),
            )
            texts = batch.get("texts", [])
            durations = batch.get("durations", [])
            for text, raw_duration in zip(texts, durations):
                if len(all_texts) >= target_actions:
                    break
                # Skip a single bad segment instead of aborting the whole plan.
                try:
                    duration = float(raw_duration)
                except (TypeError, ValueError):
                    continue
                if not math.isfinite(duration):
                    # NaN/inf would make int(round(duration * fps)) raise below.
                    continue
                all_texts.append(text)
                all_durations.append(duration)
            rounds += 1
        return all_texts, all_durations

    def _thread_fn() -> None:
        try:
            all_texts, all_durations = _collect_prompts()
            if not all_texts:
                event_client.add_notification(
                    title="Qwen example load failed",
                    body="No prompt segments were produced.",
                    auto_close_seconds=6.0,
                    color="red",
                )
                return
            fps = session.model_fps
            event_client.timeline.clear_prompts()
            frame_cursor = 0
            last_idx = len(all_texts) - 1
            for i, (txt, dur) in enumerate(zip(all_texts, all_durations)):
                n_frames = max(1, int(round(dur * fps)))
                start_frame = frame_cursor
                # Non-last prompts end at the next prompt's start (exclusive);
                # the last prompt includes its end frame.
                end_frame = frame_cursor + n_frames if i < last_idx else frame_cursor + n_frames - 1
                color = PROMPT_COLORS[i % len(PROMPT_COLORS)]
                event_client.timeline.add_prompt(txt, start_frame, end_frame, color=color)
                frame_cursor += n_frames
            # Zoom so the whole plan is visible with ~10% margin (at least 60 frames).
            target_visible_frames = int(math.ceil(1.10 * frame_cursor))
            event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames))
            update_duration_auto()
            event_client.add_notification(
                title="Qwen example loaded",
                body=f"Loaded {len(all_texts)} prompt segments. Click Generate to synthesize motion.",
                auto_close_seconds=6.0,
                color="green",
            )
        except Exception as e:
            # Keep a server-side trace, like the other load/generate handlers do.
            import traceback

            traceback.print_exc()
            event_client.add_notification(
                title="Qwen example load failed",
                body=str(e),
                auto_close_seconds=8.0,
                color="red",
            )

    # Qwen calls can be slow; don't block the GUI callback thread.
    threading.Thread(target=_thread_fn, daemon=True).start()
@gui_load_example_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load the example currently selected in the dropdown.

    The Qwen entries are built on the fly by load_qwen_example_plan. Any other
    entry must exist in the session's example_dict; otherwise the user is
    told that no examples are available.
    """
    client_handle = event.client
    active = get_active_session(client_handle)
    if active is None:
        return
    choice = gui_examples_dropdown.value
    if choice in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME):
        load_qwen_example_plan(client_handle)
        return
    examples = active.example_dict
    if examples and choice in examples:
        load_example_from_path(client_handle, examples[choice], gui_load_gt_checkbox.value)
        return
    client_handle.add_notification(
        title="No examples available",
        body="No examples found for the selected model.",
        auto_close_seconds=5.0,
        color="red",
    )
@gui_load_example_from_path_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Load an example from the directory typed into the path text box."""
    client_handle = event.client
    if get_active_session(client_handle) is None:
        return
    requested_dir = gui_load_example_path_text.value
    if requested_dir:
        load_example_from_path(client_handle, requested_dir, gui_load_gt_checkbox.value)
        return
    # Empty text box: tell the user instead of trying to load "".
    client_handle.add_notification(
        title="No example path",
        body="Please provide an example directory.",
        auto_close_seconds=5.0,
        color="red",
    )
@gui_cfg_checkbox.on_update
def _(_) -> None:
    """Show the CFG weight sliders only while the CFG checkbox is checked."""
    if demo.client_active(client_id):
        cfg_enabled = gui_cfg_checkbox.value
        for weight_slider in (gui_cfg_text_weight_slider, gui_cfg_constraint_weight_slider):
            weight_slider.visible = cfg_enabled
def exit_editing_mode(session: ClientSession):
    """Leave constraint-editing mode and restore the normal generation UI.

    Re-enables generation and resets the gizmo-space dropdown. For G1 models
    the dropdown is forced to "Local", disabled and hidden. The method also
    drops the edit/undo snapshots, removes all gizmos, and returns the
    character to an opaque mesh with the skeleton hidden.
    """
    gui_edit_constraint_button.label = "Enter Editing Mode"
    gui_generate_button.disabled = False
    gui_generate_button.label = "Generate"
    gui_reset_constraint_button.disabled = True

    is_g1 = "g1" in session.model_name
    if is_g1:
        gui_gizmo_space_dropdown.value = "Local"
    gui_gizmo_space_dropdown.disabled = is_g1
    gui_gizmo_space_dropdown.visible = not is_g1

    gui_undo_drag_button.disabled = True
    gui_use_soma_layer_checkbox.disabled = False
    session.edit_mode_snapshot = None
    session.undo_drag_snapshot = None

    active_motion = list(session.motions.values())[0]
    active_motion.clear_all_gizmos()
    char = active_motion.character
    char.set_skinned_mesh_wireframe(False)
    char.set_skeleton_visibility(False)
    char.set_skinned_mesh_visibility(True)
    char.set_skinned_mesh_opacity(1.0)
    session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 1.0

    # If the path is dense, put the motion back on the path
    if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path:
        _update_dense_path(active_motion, session)

    gui_viz_skinned_mesh_checkbox.value = True
    gui_viz_skeleton_checkbox.value = False
# enter editing mode callback
@gui_edit_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Toggle constraint-editing mode for the clicking client.

    Entering edit mode disables generation and snapshots the current frame's
    pose. It also shows the skeleton over a semi-transparent mesh and attaches
    root-translation and per-joint gizmos. Leaving edit mode delegates to
    exit_editing_mode.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    session.edit_mode = not session.edit_mode
    edit_alert = "Entered editing mode"
    no_edit_alert = "Exited editing mode"
    edit_message = "You can now modify pose or path constraints."
    no_edit_message = "Can now generate motions."
    event_client.add_notification(
        title=edit_alert if session.edit_mode else no_edit_alert,
        body=edit_message if session.edit_mode else no_edit_message,
        auto_close_seconds=10.0,
        color="blue",
    )
    if session.edit_mode:
        gui_edit_constraint_button.label = "Exit Editing Mode"
        gui_generate_button.disabled = True
        gui_generate_button.label = "Generate Disabled In Editing Mode"
        # G1 models are forced to local-space gizmos.
        if "g1" in session.model_name:
            gui_gizmo_space_dropdown.value = "Local"
            gui_gizmo_space_dropdown.disabled = True
        gui_use_soma_layer_checkbox.disabled = True
        assert len(session.motions) == 1, "Only one motion allowed in edit mode"
        motion = list(session.motions.values())[0]
        # Clamp in case the playhead is past the last frame of the motion.
        snapshot_frame_idx = min(session.frame_idx, motion.length - 1)
        # Fresh per-frame snapshot store; the reset button restores poses from it.
        session.edit_mode_snapshot = {}
        ensure_edit_snapshot(session, motion, snapshot_frame_idx)
        gui_reset_constraint_button.disabled = False
        motion.character.set_skeleton_visibility(True)
        # motion.character.set_skinned_mesh_wireframe(True)
        motion.character.set_skinned_mesh_opacity(0.65)
        session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 0.65
        motion.character.set_skinned_mesh_visibility(True)
        gui_viz_skinned_mesh_checkbox.value = True
        gui_viz_skeleton_checkbox.value = True

        # need gizmos for root translation and individual joints
        def _on_root2d_gizmo_release():
            # After a 2D-root drag ends, refresh the dense path if it is enabled.
            if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path:
                mot = list(session.motions.values())[0]
                _update_dense_path(mot, session)

        def _on_gizmo_drag_start():
            # Record the pre-drag pose so the "undo drag" button can restore it.
            mot = list(session.motions.values())[0]
            frame_idx = min(session.frame_idx, mot.length - 1)
            session.undo_drag_snapshot = {
                "frame_idx": frame_idx,
                "joints_pos": mot.get_joints_pos(frame_idx),
                "joints_rot": mot.get_joints_rot(frame_idx),
            }
            gui_undo_drag_button.disabled = False

        motion.add_root_translation_gizmo(
            session.constraints,
            on_2d_root_drag_end=_on_root2d_gizmo_release,
            on_drag_start=_on_gizmo_drag_start,
        )
        # G1 always uses local space; other models follow the dropdown.
        gizmo_space = "local" if "g1" in session.model_name else gui_gizmo_space_dropdown.value.lower()
        motion.add_joint_gizmos(
            session.constraints,
            space=gizmo_space,
            on_drag_start=_on_gizmo_drag_start,
        )
    else:
        exit_editing_mode(session)
@gui_reset_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Restore the current frame's pose from the snapshot taken in edit mode."""
    client_handle = event.client
    active = get_active_session(client_handle)
    if active is None or not active.edit_mode_snapshot or not active.motions:
        return
    target = list(active.motions.values())[0]
    # The playhead may be past the end of the motion; snapshots use clamped frames.
    frame = min(active.frame_idx, target.length - 1)
    snapshots = active.edit_mode_snapshot
    if frame not in snapshots:
        return
    saved_pose = snapshots[frame]
    target.update_pose_at_frame(
        frame,
        joints_pos=saved_pose["joints_pos"],
        joints_rot=saved_pose["joints_rot"],
    )
    demo.set_frame(client_handle.client_id, frame, update_timeline=False)
@gui_undo_drag_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Revert the pose changed by the most recent gizmo drag, then disarm undo."""
    client_handle = event.client
    active = get_active_session(client_handle)
    if active is None or active.undo_drag_snapshot is None or not active.motions:
        return
    snap = active.undo_drag_snapshot
    frame = snap["frame_idx"]
    list(active.motions.values())[0].update_pose_at_frame(
        frame,
        joints_pos=snap["joints_pos"],
        joints_rot=snap["joints_rot"],
    )
    demo.set_frame(client_handle.client_id, frame, update_timeline=False)
    # Only a single level of undo is kept.
    active.undo_drag_snapshot = None
    gui_undo_drag_button.disabled = True
def validate_interval(start_frame_idx: int, end_frame_idx: int, max_frame_idx: int) -> bool:
    """Return True iff 0 <= start_frame_idx <= end_frame_idx <= max_frame_idx.

    Both endpoints must lie in [0, max_frame_idx] and the interval must not
    be reversed.
    """
    return 0 <= start_frame_idx <= end_frame_idx <= max_frame_idx
def clamp_interval_to_range(
    start_frame_idx: int, end_frame_idx: int, max_frame_idx: int
) -> Optional[tuple[int, int]]:
    """Intersect [start_frame_idx, end_frame_idx] with [0, max_frame_idx].

    Returns the clamped (start, end) pair. Returns None when the two ranges do
    not overlap; this includes reversed intervals and intervals lying entirely
    before 0 or after max_frame_idx.
    """
    lo = max(0, start_frame_idx)
    hi = min(max_frame_idx, end_frame_idx)
    return (lo, hi) if lo <= hi else None
# add constraint callback
def add_constraint_callback(
    constraint_id: str,
    constraint_type: str,
    frame_range: tuple[int, int],
    joint_names: Optional[list[str]] = None,
    verbose: bool = True,
):
    """Add a constraint to the session.

    Args:
        constraint_id: id of the timeline keyframe/interval. It is passed to
            the constraint as ``keyframe_id`` or ``interval_id``.
        constraint_type: the type of constraint to add (the track name).
            "Left Hand", "Right Hand", "Left Foot" and "Right Foot" are stored
            on the "End-Effectors" constraint.
        frame_range: (start, end) frame indices. Equal values add a keyframe,
            which is rejected if out of range. Otherwise an interval is added
            after clamping it to [0, session.max_frame_idx].
        joint_names: the names of the joints to constrain for End-Effectors.
            For the per-limb types it is replaced by [<limb joint>, "Hips"].
        verbose: when True, show a "Constraint added" notification.
    """
    # Check if session still exists
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    assert len(session.motions) == 1, "Only one motion allowed for adding constraints"
    motion = list(session.motions.values())[0]

    end_effector_type = None
    if constraint_type in ("Left Hand", "Right Hand", "Left Foot", "Right Foot"):
        # Hips are required because of smooth root representation
        joint_names = [constraint_type.replace(" ", ""), "Hips"]
        end_effector_type = constraint_type.replace(" ", "-").lower()
        constraint_type = "End-Effectors"

    # check to make sure interval is valid
    is_interval = frame_range[1] != frame_range[0]
    start_frame_idx = int(frame_range[0])
    end_frame_idx = int(frame_range[1])
    if is_interval:
        # Intervals partially outside the motion are clamped to it.
        clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx)
        if clamped is None:
            print("Interval outside range! Couldn't add constraint.")
            return
        start_frame_idx, end_frame_idx = clamped
    elif not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx):
        # Single-frame keyframes must lie inside the motion.
        print("Invalid interval! Couldn't add constraint.")
        return

    # collect input args for the constraint based on which track it is
    if is_interval:
        constraint_kwargs = {
            "interval_id": constraint_id,
            "start_frame_idx": start_frame_idx,
            "end_frame_idx": end_frame_idx,
        }
    else:
        constraint_kwargs = {
            "keyframe_id": constraint_id,
            "frame_idx": start_frame_idx,
        }
    if constraint_type in ("Full-Body", "End-Effectors"):
        # Pose constraints capture the current motion's pose over the range.
        constraint_kwargs["joints_pos"] = motion.get_joints_pos(start_frame_idx, end_frame_idx)
        constraint_kwargs["joints_rot"] = motion.get_joints_rot(start_frame_idx, end_frame_idx)
        if constraint_type == "End-Effectors":
            constraint_kwargs["joint_names"] = joint_names
            constraint_kwargs["end_effector_type"] = end_effector_type
    elif constraint_type == "2D Root":
        constraint_kwargs["root_pos"] = motion.get_projected_root_pos(start_frame_idx, end_frame_idx)

    # add the keyframe(s) to the constraint track
    constraint = session.constraints[constraint_type]
    if is_interval:
        constraint.add_interval(**constraint_kwargs)
    else:
        constraint.add_keyframe(**constraint_kwargs)
    apply_constraint_overlay_visibility(session)
    if verbose:
        client.add_notification(
            title="Constraint added",
            body="",
            auto_close_seconds=5.0,
            color="blue",
        )
# timeline callbacks for keyframes and intervals
@client.timeline.on_keyframe_add
def _(keyframe_id: str, track_id: str, frame: int):
    """Called when a keyframe is added to a track."""
    if not demo.client_active(client_id):
        return
    active = demo.client_sessions[client_id]
    tl_data = active.timeline_data
    with tl_data["keyframe_update_lock"]:
        track_name = tl_data["tracks"][track_id]["name"]
        add_constraint_callback(keyframe_id, track_name, (frame, frame), verbose=False)
        # Mirror the viser-side keyframe attributes, with defaults if it is unknown.
        kf = client.timeline._keyframes.get(keyframe_id)
        if kf is None:
            locked, opacity, value = False, 1.0, None
        else:
            locked, opacity, value = bool(kf.locked), kf.opacity, kf.value
        tl_data["keyframes"][keyframe_id] = {
            "frame": frame,
            "track_id": track_id,
            "locked": locked,
            "opacity": opacity,
            "value": value,
        }
    # Update smooth path when adding a keyframe (single action, not drag).
    if track_name == "2D Root" and active.constraints["2D Root"].dense_path:
        _update_dense_path(list(active.motions.values())[0], active)
@client.timeline.on_interval_add
def handle_interval_add(interval_id: str, track_id: str, start_frame: int, end_frame: int):
    """Called when an interval is added to a track."""
    if not demo.client_active(client_id):
        return
    active = demo.client_sessions[client_id]
    tl_data = active.timeline_data
    with tl_data["keyframe_update_lock"]:
        track_name = tl_data["tracks"][track_id]["name"]
        add_constraint_callback(interval_id, track_name, (start_frame, end_frame), verbose=False)
        # Mirror the viser-side interval attributes, with defaults if it is unknown.
        iv = client.timeline._intervals.get(interval_id)
        if iv is None:
            locked, opacity, value = False, 1.0, None
        else:
            locked, opacity, value = bool(iv.locked), iv.opacity, iv.value
        tl_data["intervals"][interval_id] = {
            "track_id": track_id,
            "start_frame_idx": start_frame,
            "end_frame_idx": end_frame,
            "locked": locked,
            "opacity": opacity,
            "value": value,
        }
    if track_name == "2D Root" and active.constraints["2D Root"].dense_path:
        _update_dense_path(list(active.motions.values())[0], active)
def remove_constraint_callback(
    constraint_id: str,
    constraint_type: str,
    frame_range: tuple[int, int],
    verbose: bool = True,
) -> None:
    """Remove the constraint stored under ``constraint_id`` from the session.

    Equal start/end frames denote a keyframe, which must be in range.
    Otherwise the range is an interval clamped to [0, max_frame_idx]; an
    interval entirely out of range is silently ignored. Per-limb track names
    map to the "End-Effectors" constraint.
    """
    if not demo.client_active(client_id):
        return
    active = demo.client_sessions[client_id]
    # NOTE(review): set here but never cleared in this function; presumably reset elsewhere - confirm.
    active.updating_motions = True
    is_interval = frame_range[1] != frame_range[0]
    first = int(frame_range[0])
    last = int(frame_range[1])
    if not is_interval:
        if not validate_interval(first, last, active.max_frame_idx):
            print("Invalid interval! Couldn't remove constraint.")
            return
    else:
        window = clamp_interval_to_range(first, last, active.max_frame_idx)
        if window is None:
            return
        first, last = window
    limb_tracks = ("Left Hand", "Right Hand", "Left Foot", "Right Foot")
    target_key = "End-Effectors" if constraint_type in limb_tracks else constraint_type
    track_constraint = active.constraints[target_key]
    if is_interval:
        track_constraint.remove_interval(constraint_id, first, last)
    else:
        track_constraint.remove_keyframe(constraint_id, first)
    if verbose:
        client.add_notification(
            title="Constraint removed",
            body="",
            auto_close_seconds=5.0,
            color="blue",
        )
@client.timeline.on_keyframe_move
def handle_keyframe_move(keyframe_id: str, new_frame: int):
    """Called when a keyframe is moved to a new frame.

    Moves are debounced. Each call cancels the keyframe's pending timer,
    records ``new_frame`` as the latest target, and schedules
    _execute_keyframe_move to run after 30 ms. Only the final position of a
    drag is therefore applied.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Cancel any pending timer for this keyframe
        if keyframe_id in timeline_data["keyframe_move_timers"]:
            timeline_data["keyframe_move_timers"][keyframe_id].cancel()
        # Store the latest target frame; the timer callback compares against it
        # to drop moves that were superseded.
        timeline_data["pending_keyframe_moves"][keyframe_id] = new_frame
        # This debounces rapid movements - only execute when user stops moving
        timer = threading.Timer(
            0.03,  # 30 ms delay - adjust as needed
            _execute_keyframe_move,
            args=(client_id, keyframe_id, new_frame, session),
        )
        timeline_data["keyframe_move_timers"][keyframe_id] = timer
        timer.start()
def _execute_keyframe_move(
    client_id: int,
    keyframe_id: str,
    new_frame: int,
    session: ClientSession,
):
    """Actually execute the keyframe move operation (called after debounce delay).

    Runs on the threading.Timer scheduled by handle_keyframe_move. The move
    is dropped if it was cancelled or superseded by a newer target frame.
    Otherwise the constraint is removed at the old frame, re-added at
    ``new_frame``, and the stored frame is updated.

    Args:
        client_id: id of the owning client. It is not read in this body; the
            constraint callbacks use the enclosing scope's client_id.
        keyframe_id: timeline keyframe being moved.
        new_frame: target frame captured when the timer was scheduled.
        session: session whose timeline_data and constraints are updated.
    """
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Check if this move is still the latest one
        if keyframe_id not in timeline_data["pending_keyframe_moves"]:
            return  # Move was cancelled
        if timeline_data["pending_keyframe_moves"][keyframe_id] != new_frame:
            return  # A newer move superseded this one
        # Remove from pending
        del timeline_data["pending_keyframe_moves"][keyframe_id]
        if keyframe_id in timeline_data["keyframe_move_timers"]:
            del timeline_data["keyframe_move_timers"][keyframe_id]
        # Now execute the actual move (keep it in the lock so we don't delete it while moving)
        if keyframe_id not in timeline_data["keyframes"]:
            # double check: the keyframe may have been deleted meanwhile
            return
        keyframe_data = timeline_data["keyframes"][keyframe_id]
        if not keyframe_data:
            return
        # if the frame did not move, don't do anything
        if keyframe_data["frame"] == new_frame:
            return
        track_id = keyframe_data["track_id"]
        constraint_type = timeline_data["tracks"][track_id]["name"]
        cur_frame = keyframe_data["frame"]
        # Remove constraint at old frame
        remove_constraint_callback(
            keyframe_id,
            constraint_type,
            (cur_frame, cur_frame),
            verbose=False,
        )
        # Add constraint at new frame (re-samples the pose at new_frame)
        add_constraint_callback(
            keyframe_id,
            constraint_type,
            (new_frame, new_frame),
            verbose=False,
        )
        # update our data
        keyframe_data["frame"] = new_frame
    # Schedule path update only after user stops dragging (no move for 300ms).
    if constraint_type == "2D Root":
        _schedule_dense_path_after_release(session)
@client.timeline.on_keyframe_delete
def handle_keyframe_delete(keyframe_id: str):
    """Called when a keyframe is deleted."""
    if not demo.client_active(client_id):
        return
    active = demo.client_sessions[client_id]
    tl_data = active.timeline_data
    with tl_data["keyframe_update_lock"]:
        known = tl_data["keyframes"]
        if keyframe_id not in known:
            return
        entry = known[keyframe_id]
        track_name = tl_data["tracks"][entry["track_id"]]["name"]
        at_frame = entry["frame"]
        remove_constraint_callback(keyframe_id, track_name, (at_frame, at_frame), verbose=False)
        del known[keyframe_id]
    if track_name == "2D Root" and active.constraints["2D Root"].dense_path:
        _update_dense_path(list(active.motions.values())[0], active)
@client.timeline.on_interval_move
def handle_interval_move(interval_id: str, new_start: int, new_end: int):
    """Called when an interval is moved or resized.

    Debounced like keyframe moves. The latest (start, end) is stored, and
    _execute_interval_move runs only after 0.5 s without further moves of
    this interval.
    """
    if not demo.client_active(client_id):
        return
    session = demo.client_sessions[client_id]
    # We share the same lock (and timer/pending dicts) for keyframe and interval
    # moves, assuming the user can't move both at the same time.
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Cancel any pending timer for this interval
        if interval_id in timeline_data["keyframe_move_timers"]:
            timeline_data["keyframe_move_timers"][interval_id].cancel()
        # Store the latest target range; the timer callback compares against it
        # to drop moves that were superseded.
        new_interval = (new_start, new_end)
        timeline_data["pending_keyframe_moves"][interval_id] = new_interval
        # This debounces rapid movements - only execute when user stops moving
        timer = threading.Timer(
            0.5,  # 500 ms delay - adding an interval is much slower than moving a keyframe
            _execute_interval_move,
            args=(client_id, interval_id, new_interval, session),
        )
        timeline_data["keyframe_move_timers"][interval_id] = timer
        timer.start()
def _execute_interval_move(
    client_id: int,
    interval_id: str,
    new_interval: tuple[int, int],
    session: ClientSession,
):
    """Actually execute the interval move operation (called after debounce delay).

    Runs on the threading.Timer scheduled by handle_interval_move. The move
    is dropped if it was cancelled or superseded by a newer target range.
    Otherwise the constraint is removed over the old range, re-added over
    ``new_interval``, and the stored range is updated.

    Args:
        client_id: id of the owning client. It is not read in this body; the
            constraint callbacks use the enclosing scope's client_id.
        interval_id: timeline interval being moved/resized.
        new_interval: (start, end) target captured when the timer was scheduled.
        session: session whose timeline_data and constraints are updated.
    """
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # Check if this move is still the latest one
        if interval_id not in timeline_data["pending_keyframe_moves"]:
            return  # Move was cancelled
        if timeline_data["pending_keyframe_moves"][interval_id] != new_interval:
            return  # A newer move superseded this one
        # Remove from pending
        del timeline_data["pending_keyframe_moves"][interval_id]
        if interval_id in timeline_data["keyframe_move_timers"]:
            del timeline_data["keyframe_move_timers"][interval_id]
        # Now execute the actual move
        if interval_id not in timeline_data["intervals"]:
            return
        interval_data = timeline_data["intervals"][interval_id]
        if not interval_data:
            return
        # if the interval did not move, don't do anything
        if (
            interval_data["start_frame_idx"] == new_interval[0]
            and interval_data["end_frame_idx"] == new_interval[1]
        ):
            return
        track_id = interval_data["track_id"]
        constraint_type = timeline_data["tracks"][track_id]["name"]
        cur_range = (
            interval_data["start_frame_idx"],
            interval_data["end_frame_idx"],
        )
        # Remove constraint at old frame
        remove_constraint_callback(
            interval_id,
            constraint_type,
            cur_range,
            verbose=False,
        )
        # Add constraint at new frame (re-samples the motion over the new range)
        add_constraint_callback(
            interval_id,
            constraint_type,
            new_interval,
            verbose=False,
        )
        # update our data
        interval_data["start_frame_idx"] = new_interval[0]
        interval_data["end_frame_idx"] = new_interval[1]
    # Schedule path update only after user stops dragging (no move for 300ms).
    if constraint_type == "2D Root":
        _schedule_dense_path_after_release(session)
@client.timeline.on_interval_delete
def handle_interval_delete(interval_id: str):
    """Called when an interval is deleted."""
    if not demo.client_active(client_id):
        return
    active = demo.client_sessions[client_id]
    tl_data = active.timeline_data
    with tl_data["keyframe_update_lock"]:
        known = tl_data["intervals"]
        if interval_id not in known:
            return
        entry = known[interval_id]
        track_name = tl_data["tracks"][entry["track_id"]]["name"]
        span = (entry["start_frame_idx"], entry["end_frame_idx"])
        remove_constraint_callback(interval_id, track_name, span, verbose=False)
        del known[interval_id]
    if track_name == "2D Root" and active.constraints["2D Root"].dense_path:
        _update_dense_path(list(active.motions.values())[0], active)
@gui_snap_to_constraint_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Snap the character's pose at the current frame onto the constraint there.

    A Full-Body or End-Effectors keyframe at the frame replaces the whole
    pose; the first match in that order wins. Otherwise, a 2D Root keyframe
    shifts every joint by the same offset, with height unchanged, so the
    projected root lands on the constrained position.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None or not session.motions:
        # No session, or no motion to snap (would raise IndexError below).
        return
    target_character_motion = list(session.motions.values())[0]
    frame_idx = session.frame_idx
    if frame_idx >= target_character_motion.length:
        # frame idx larger than the motion, could not snap
        return
    for constraint_name in ("Full-Body", "End-Effectors"):
        if (
            constraint_name in session.constraints
            and frame_idx in session.constraints[constraint_name].keyframes
        ):
            keyframe = session.constraints[constraint_name].keyframes[frame_idx]
            # update the full joints_pos of the character to match the constraints
            target_character_motion.update_pose_at_frame(
                frame_idx,
                joints_pos=keyframe["joints_pos"],
                joints_rot=keyframe["joints_rot"],
            )
            target_character_motion.set_frame(frame_idx)
            return  # motion already fully changed
    if "2D Root" in session.constraints and frame_idx in session.constraints["2D Root"].keyframes:
        # update only the root position
        # NOTE(review): assumes 2D Root keyframes store the projected root
        # position array directly (not a dict) - confirm against the constraint class.
        new_root_pos = session.constraints["2D Root"].keyframes[frame_idx]
        old_root_pos = target_character_motion.get_projected_root_pos(frame_idx)
        root_diff = new_root_pos - old_root_pos
        root_diff[1] = 0.0  # don't change height
        # Broadcast the same offset to every joint of this frame.
        new_joints_pos = (
            target_character_motion.joints_pos[frame_idx]
            + to_torch(
                root_diff,
                device=target_character_motion.joints_pos.device,
                dtype=target_character_motion.joints_pos.dtype,
            )[None]
        )
        rot = target_character_motion.joints_rot[frame_idx]
        target_character_motion.update_pose_at_frame(
            frame_idx,
            joints_pos=new_joints_pos,
            joints_rot=rot,
        )
        target_character_motion.set_frame(frame_idx)
@gui_clear_all_constraints_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Remove every constraint from the session and from the timeline.

    The session's keyframe/interval bookkeeping is also reset, and pending
    debounced move timers are cancelled. Without this, a move whose timer
    fires after the clear would find its stale entry and re-add a constraint
    that was just removed.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    timeline_data = session.timeline_data
    with timeline_data["keyframe_update_lock"]:
        # use the lock here to wait for any constraint updates to finish
        for constraint in list(session.constraints.values()):
            constraint.clear()
        client.timeline.clear_keyframes()
        client.timeline.clear_intervals()
        # A timer already waiting on this lock will see no pending move and return.
        for timer in list(timeline_data["keyframe_move_timers"].values()):
            timer.cancel()
        timeline_data["keyframe_move_timers"].clear()
        timeline_data["pending_keyframe_moves"].clear()
        timeline_data["keyframes"].clear()
        timeline_data["intervals"].clear()
    if gui_dense_path_checkbox.value:
        gui_dense_path_checkbox.value = False
    if "2D Root" in session.constraints:
        session.constraints["2D Root"].set_dense_path(False)
# generation callback
@gui_generate_button.on_click
def _(event: viser.GuiEvent) -> None:
    """Generate motion(s) for the prompts currently on the timeline.

    Generation and timeline constraint editing are disabled while this runs.
    demo.generate is called with the current GUI settings. Afterwards, with a
    single sample the UI is re-enabled; with several samples the user clicks
    one character to commit it. On any failure, the controls disabled here
    are restored and the notification turns red.
    """
    event_client = event.client
    session = get_active_session(event_client)
    if session is None:
        return
    generating_notif = event_client.add_notification(
        title="Generating motion...",
        body="Generating motions for the given prompt!",
        loading=True,
        with_close_button=False,
    )
    gui_generate_button.disabled = True
    client.timeline.disable_constraints()
    num_samples = gui_num_samples_slider.value
    try:
        # Everything that can fail runs inside the try so the UI disabled
        # above is always restored by the except branch.
        timeline = session.client.timeline
        # sort prompts by start frame so texts and frame counts are in timeline order
        prompt_values = sorted(timeline._prompts.values(), key=lambda x: x.start_frame)
        texts = [x.text for x in prompt_values]
        num_frames = compute_prompt_num_frames(prompt_values)
        # compute the total duration
        total_duration = sum(num_frames) / session.model_fps
        # update just in case
        set_new_duration(client_id, total_duration)
        transitions_parameters = {
            "num_transition_frames": gui_num_transition_frames_slider.value,
            "share_transition": gui_share_transition_checkbox.value,
            "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100,
        }
        # G1: postprocessing is disabled (does not work well for this model).
        postprocess_parameters = {
            "post_processing": (False if "g1" in session.model_name else gui_postprocess_checkbox.value),
            "root_margin": gui_root_margin.value,
        }
        demo.generate(
            event_client,
            texts,
            num_frames,
            num_samples,
            gui_seed.value,
            gui_diffusion_steps_slider.value,
            cfg_weight=[
                gui_cfg_text_weight_slider.value,
                gui_cfg_constraint_weight_slider.value,
            ],
            cfg_type="separated" if gui_cfg_checkbox.value else "nocfg",
            postprocess_parameters=postprocess_parameters,
            transitions_parameters=transitions_parameters,
            real_robot_rotations=gui_real_robot_rotations_checkbox.value,
        )
        session.max_frame_idx = int(session.cur_duration * session.model_fps) - 1
        if session.frame_idx > session.max_frame_idx:
            session.frame_idx = session.max_frame_idx
        if num_samples > 1:
            # add mesh selector to choose character to commit
            def commit_motion(event: viser.GuiEvent) -> None:
                """Keep only the clicked sample (shifted so its first-frame root x is 0)."""
                target = event.target
                commit_name = target.name.split("/")[1]  # e.g. /character0/simple_skinned
                print(f"Committing motion for character: {commit_name}")
                # delete non-selected motions
                new_motion_kwargs = None
                for character_name, motion in session.motions.items():
                    if character_name == commit_name:
                        new_motion_kwargs = {
                            "skeleton": session.skeleton,
                            "joints_rot": motion.joints_rot,
                            "foot_contacts": motion.foot_contacts,
                        }
                        root_x_offset = motion.joints_pos[0, session.skeleton.root_idx, 0]
                        new_joints_pos = motion.joints_pos.clone()
                        new_joints_pos[..., 0] -= root_x_offset
                        new_motion_kwargs["joints_pos"] = new_joints_pos
                        break
                if new_motion_kwargs is None:
                    # The clicked mesh matches no current sample; committing
                    # would fail on **None, so ignore the click.
                    print(f"No motion found for character: {commit_name}")
                    return
                # clear and re-add the selected motion
                demo.clear_motions(event_client.client_id)
                demo.add_character_motion(event_client, **new_motion_kwargs)
                gui_edit_constraint_button.disabled = False
                gui_generate_button.disabled = False
                gui_snap_to_constraint_button.disabled = False
                client.timeline.enable_constraints()
                gui_generate_button.label = "Generate"
                gui_save_example_button.disabled = False
                gui_save_motion_button.disabled = False
                gui_download_button.disabled = False
                gui_save_constraints_button.disabled = False
                gui_load_example_button.disabled = False

            for motion in session.motions.values():
                char = motion.character
                character_name = char.name  # e.g. "character0"
                if char.skinned_mesh is not None:
                    char.skinned_mesh.on_click(commit_motion)
                elif char.g1_mesh_rig is not None:
                    # Register click on every part so any part can be clicked,
                    # and use highlight_group so the whole robot highlights together.
                    for handle in char.g1_mesh_rig.mesh_handles:
                        handle.on_click(commit_motion, highlight_group=character_name)
            # Lock the UI until one sample has been committed.
            gui_edit_constraint_button.disabled = True
            gui_generate_button.disabled = True
            gui_snap_to_constraint_button.disabled = True
            gui_generate_button.label = "Choose Sample Before Generating"
            gui_save_example_button.disabled = True
            gui_save_motion_button.disabled = True
            gui_download_button.disabled = True
            gui_save_constraints_button.disabled = True
            gui_load_example_button.disabled = True
        else:
            gui_edit_constraint_button.disabled = False
            gui_generate_button.disabled = False
            gui_snap_to_constraint_button.disabled = False
            client.timeline.enable_constraints()
        generating_notif.title = "Motion generation finished!"
        generating_notif.body = "Motions have been generated successfully for the given prompt."
        if num_samples > 1:
            generating_notif.body += " Now choose which sample to commit."
        generating_notif.loading = False
        generating_notif.with_close_button = True
        generating_notif.auto_close_seconds = 5.0
        generating_notif.color = "green"
        # put the motion at zero
        demo.set_frame(client_id, 0)
    except Exception as e:
        import traceback

        traceback.print_exc()
        print(f"Error during generation for client {event_client.client_id}: {e}")
        # Re-enable buttons and notify the user
        if event_client.client_id in demo.client_sessions:
            gui_generate_button.disabled = False
            gui_generate_button.label = "Generate"
            gui_load_example_button.disabled = False
            gui_save_example_button.disabled = False
            gui_save_motion_button.disabled = False
            gui_download_button.disabled = False
            gui_save_constraints_button.disabled = False
            # Constraints were disabled before generating; without this they
            # would stay locked after a failure.
            client.timeline.enable_constraints()
            # Reuse persistent notification instead of creating a new one
            try:
                generating_notif.title = "Generation failed!"
                generating_notif.body = f"Error: {str(e)}"
                generating_notif.loading = False
                generating_notif.with_close_button = True
                generating_notif.auto_close_seconds = 6.0
                generating_notif.color = "red"
            except Exception:
                pass  # best effort: the client may have disconnected
        demo.check_cuda_health()
#
# Visualization settings
#
with tab_group.add_tab("Visualize", viser.Icon.EYE):
with client.gui.add_folder("Playback", expand_by_default=True):
gui_model_fps = client.gui.add_number("Model FPS", initial_value=model_fps, disabled=True)
gui_playback_speed_buttons = client.gui.add_button_group(
"Playback Speed",
options=[
"0.5x",
"1x",
"2x",
],
)
gui_playback_speed_buttons.value = "1x"
@client.timeline.on_frame_change
def handle_timeline_frame_change(new_frame_idx: int):
    """Update the frame when the user clicks on the timeline."""
    demo.set_frame(client_id, new_frame_idx, update_timeline=False)
    active = demo.client_sessions.get(client_id)
    if active is None:
        return
    if active.edit_mode and active.motions:
        # In edit mode, make sure the newly shown (clamped) frame has a snapshot.
        shown = list(active.motions.values())[0]
        ensure_edit_snapshot(active, shown, min(active.frame_idx, shown.length - 1))
    update_snap_to_constraint_button(active)
# Every prompt add/update/resize/move/delete re-runs update_duration_auto(),
# presumably to keep the displayed duration in sync with the prompts.
@client.timeline.on_prompt_add
async def _on_add(
    prompt_id: str,
    start_frame: int,
    end_frame: int,
    text: str,
    color: tuple[int, int, int] | None,
) -> None:
    update_duration_auto()

@client.timeline.on_prompt_update
async def _on_update(prompt_id: str, new_text: str) -> None:
    update_duration_auto()

@client.timeline.on_prompt_resize
async def _on_resize(prompt_id: str, new_start: int, new_end: int) -> None:
    update_duration_auto()

@client.timeline.on_prompt_move
async def _on_move(prompt_id: str, new_start: int, new_end: int) -> None:
    update_duration_auto()

@client.timeline.on_prompt_delete
async def _on_delete(prompt_id: str) -> None:
    update_duration_auto()
def play_pause_button_callback(session: ClientSession):
    """Flip ``session.playing``: a playing session pauses, a paused one plays."""
    was_playing = bool(session.playing)
    session.playing = False if was_playing else True
def next_frame_callback(session: ClientSession):
    """Step ``session`` forward one frame, never past ``max_frame_idx``, and redraw.

    ``session.frame_idx`` is incremented only while it is below
    ``session.max_frame_idx``. The resulting index is then handed to
    ``demo.set_frame`` for this client.
    """
    if session.frame_idx < session.max_frame_idx:
        session.frame_idx += 1
    # Called even at the last frame, so the viewer is always re-synced to the
    # session's current index.
    demo.set_frame(client_id, session.frame_idx)
def prev_frame_callback(session: ClientSession):
    """Step ``session`` back one frame, never below frame 0, and redraw.

    ``session.frame_idx`` is decremented only while it is above 0. The
    resulting index is then handed to ``demo.set_frame`` for this client.
    """
    if session.frame_idx > 0:
        session.frame_idx -= 1
    # Called even at frame 0, so the viewer is always re-synced to the
    # session's current index.
    demo.set_frame(client_id, session.frame_idx)
@gui_playback_speed_buttons.on_click
def _(_) -> None:
    """Set this client's playback speed from the selected speed button."""
    if demo.client_active(client_id):
        target = demo.client_sessions[client_id]
        # Button labels and the multipliers they stand for, position by position.
        labels = ("0.5x", "1x", "2x")
        factors = (0.5, 1.0, 2.0)
        target.playback_speed = dict(zip(labels, factors))[gui_playback_speed_buttons.value]
with client.gui.add_folder("Body options", expand_by_default=True):
gui_viz_skinned_mesh_checkbox = client.gui.add_checkbox("Show Mesh", initial_value=True)
gui_viz_skinned_mesh_opacity_slider = client.gui.add_slider(
"Mesh Opacity", min=0.0, max=1.0, step=0.01, initial_value=1.0
)
gui_viz_skeleton_checkbox = client.gui.add_checkbox("Show Skeleton", initial_value=False)
gui_viz_foot_contacts_checkbox = client.gui.add_checkbox("Show Foot Contacts", initial_value=False)
gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value
with client.gui.add_folder("Camera options", expand_by_default=True):
gui_camera_fov_slider = client.gui.add_slider(
"Camera FOV (deg)",
min=30.0,
max=90.0,
step=1.0,
initial_value=45.0,
)
client.camera.fov = np.deg2rad(gui_camera_fov_slider.value)
with client.gui.add_folder("Interface options", expand_by_default=True):
gui_show_timeline_checkbox = client.gui.add_checkbox(
"Show Timeline",
initial_value=True,
)
gui_show_constraint_tracks_checkbox = client.gui.add_checkbox(
"Show Constraint tracks",
initial_value=True,
)
gui_show_constraint_labels_checkbox = client.gui.add_checkbox(
"Show Constraint labels",
initial_value=True,
)
gui_show_starting_direction_checkbox = client.gui.add_checkbox(
"Show Starting Direction",
initial_value=True,
)
gui_dark_mode_checkbox = client.gui.add_checkbox(
"Dark Mode",
initial_value=False, # Default to light mode
)
gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value
demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value)
@gui_dark_mode_checkbox.on_update
def _(_):
    """Apply the dark/light theme and recolor every character in the session.

    The theme is configured on ``client`` first; the checkbox uuid is passed
    so the titlebar toggle stays bound to this checkbox. If the client's
    session no longer exists, the character recoloring is skipped instead of
    raising KeyError.
    """
    is_dark = gui_dark_mode_checkbox.value
    demo.configure_theme(
        client,
        is_dark,
        titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid,
    )
    session = demo.client_sessions.get(client.client_id)
    if session is None:
        return
    for motion in session.motions.values():
        motion.character.change_theme(is_dark)
# Show dark mode toggle in titlebar (right of Github), hide sidebar checkbox
demo.configure_theme(
client,
gui_dark_mode_checkbox.value,
titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid,
)
gui_dark_mode_checkbox.visible = False
@gui_show_constraint_labels_checkbox.on_update
def _(_):
    """Show or hide the label of every constraint in this client's session."""
    if demo.client_active(client_id):
        active_session = demo.client_sessions[client_id]
        for item in active_session.constraints.values():
            item.set_label_visibility(gui_show_constraint_labels_checkbox.value)
@gui_show_timeline_checkbox.on_update
def _(_):
    """Toggle the timeline and keep the constraint-track option in step with it.

    The constraint-track checkbox is only visible while the timeline is shown.
    When the timeline is turned on, the tracks are re-applied from that checkbox.
    """
    if not demo.client_active(client_id):
        return
    current = demo.client_sessions[client_id]
    current.client.timeline.set_visible(gui_show_timeline_checkbox.value)
    gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value
    if not gui_show_timeline_checkbox.value:
        return
    demo.set_constraint_tracks_visible(current, gui_show_constraint_tracks_checkbox.value)
@gui_show_constraint_tracks_checkbox.on_update
def _(_):
    """Forward the constraint-track checkbox state to the demo for this client."""
    if demo.client_active(client_id):
        demo.set_constraint_tracks_visible(
            demo.client_sessions[client_id],
            gui_show_constraint_tracks_checkbox.value,
        )
@gui_show_starting_direction_checkbox.on_update
def _(_):
    """Forward the "Show Starting Direction" checkbox state for this client."""
    if demo.client_active(client_id):
        demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value)
@gui_viz_skeleton_checkbox.on_update
def _(_) -> None:
    """Toggle skeleton rendering; the foot-contact option exists only alongside it.

    Unchecking the skeleton also hides and unchecks the foot-contact checkbox.
    """
    if not demo.client_active(client_id):
        return
    current = demo.client_sessions[client_id]
    gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value
    if gui_viz_skeleton_checkbox.value:
        pass
    else:
        gui_viz_foot_contacts_checkbox.value = False
    for entry in current.motions.values():
        entry.character.set_skeleton_visibility(gui_viz_skeleton_checkbox.value)
@gui_viz_foot_contacts_checkbox.on_update
def _(_) -> None:
    """Show or hide foot contacts on every character, at that motion's current frame."""
    if demo.client_active(client_id):
        for entry in demo.client_sessions[client_id].motions.values():
            entry.character.set_show_foot_contacts(
                gui_viz_foot_contacts_checkbox.value,
                frame_idx=entry.cur_frame_idx,
            )
@gui_viz_skinned_mesh_checkbox.on_update
def _(_) -> None:
    """Show or hide the skinned mesh of every character in this client's session."""
    if demo.client_active(client_id):
        for entry in demo.client_sessions[client_id].motions.values():
            entry.character.set_skinned_mesh_visibility(gui_viz_skinned_mesh_checkbox.value)
@gui_viz_skinned_mesh_opacity_slider.on_update
def _(_) -> None:
    """Apply the mesh-opacity slider value to every character's skinned mesh."""
    if demo.client_active(client_id):
        for entry in demo.client_sessions[client_id].motions.values():
            entry.character.set_skinned_mesh_opacity(gui_viz_skinned_mesh_opacity_slider.value)
@gui_camera_fov_slider.on_update
def _(_) -> None:
    """Copy the slider's field of view (degrees) to the client camera (radians)."""
    if demo.client_active(client_id):
        client.camera.fov = np.deg2rad(gui_camera_fov_slider.value)
#
# Instructions tab
#
with tab_group.add_tab("Instructions", viser.Icon.INFO_CIRCLE):
client.gui.add_markdown(DEMO_UI_INSTRUCTIONS_TAB_MD)
#
# Keyboard events
#
@client.scene.on_keyboard_event("keydown", debounce_ms=100)
def handle_key(event: viser.KeyboardEvent) -> None:
    """Keyboard shortcuts: Space toggles play/pause, Left/Right step one frame.

    The handler is registered for "keydown" with ``debounce_ms=100``. Space is
    additionally time-gated: it toggles at most once every 0.2 s (tracked in
    ``session.last_space_toggle_time``), so holding the key through OS
    auto-repeat does not flip playback back and forth. Arrow keys are not
    gated beyond the registration debounce, so holding them keeps stepping.
    """
    # Ignore events that arrive after this client's session was removed.
    session = demo.client_sessions.get(client_id)
    if session is None:
        return
    key = event.key
    if key == " ":
        now = time.monotonic()
        if now - session.last_space_toggle_time >= 0.2:
            session.last_space_toggle_time = now
            play_pause_button_callback(session)
    elif key == "ArrowLeft":
        prev_frame_callback(session)
    elif key == "ArrowRight":
        next_frame_callback(session)
gui_elements = GuiElements(
gui_play_pause_button=gui_play_pause_button,
gui_next_frame_button=gui_next_frame_button,
gui_prev_frame_button=gui_prev_frame_button,
gui_generate_button=gui_generate_button,
gui_model_fps=gui_model_fps,
gui_timeline=gui_timeline,
gui_viz_skeleton_checkbox=gui_viz_skeleton_checkbox,
gui_viz_foot_contacts_checkbox=gui_viz_foot_contacts_checkbox,
gui_viz_skinned_mesh_checkbox=gui_viz_skinned_mesh_checkbox,
gui_viz_skinned_mesh_opacity_slider=gui_viz_skinned_mesh_opacity_slider,
gui_camera_fov_slider=gui_camera_fov_slider,
gui_duration_slider=gui_duration_slider,
gui_num_samples_slider=gui_num_samples_slider,
gui_cfg_checkbox=gui_cfg_checkbox,
gui_cfg_text_weight_slider=gui_cfg_text_weight_slider,
gui_cfg_constraint_weight_slider=gui_cfg_constraint_weight_slider,
gui_diffusion_steps_slider=gui_diffusion_steps_slider,
gui_seed=gui_seed,
gui_postprocess_checkbox=gui_postprocess_checkbox,
gui_root_margin=gui_root_margin,
gui_real_robot_rotations_checkbox=gui_real_robot_rotations_checkbox,
gui_dark_mode_checkbox=gui_dark_mode_checkbox,
gui_use_soma_layer_checkbox=gui_use_soma_layer_checkbox,
)
return (
gui_elements,
timeline_tracks,
example_dict,
gui_examples_dropdown,
gui_save_example_path_text,
gui_model_selector,
)