import html
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from huggingface_hub import HfApi, snapshot_download
| |
|
# Configure the root logger: INFO and above, timestamped, written to stdout.
# (Per the logging docs, basicConfig is a no-op if the root logger already
# has handlers.)
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    level=logging.INFO,
)
# Logger for messages emitted by this module.
logger = logging.getLogger(__name__)
| |
|
# Hugging Face dataset repo to browse; the DATASET_ID environment variable
# overrides the built-in default.
DEFAULT_DATASET_ID = os.environ.get(
    "DATASET_ID", "raffaelkultyshev/humanoid-robots-training-dataset"
)
# Directory the dataset snapshot is downloaded into (relative to the CWD).
LOCAL_DATASET_DIR = Path("dataset_cache")
# Optional Hub access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
| |
|
# Short joint keys -> human-readable legend labels. Insertion order is the
# order traces are drawn in each plot.
JOINT_ALIASES = dict(
    wrist="Wrist",
    thumb_tip="Thumb Tip",
    index_mcp="Index MCP",
    index_tip="Index Tip",
)

# Short joint keys -> values expected in the parquet ``joint_name`` column.
JOINT_NAME_MAP = dict(
    wrist="WRIST",
    thumb_tip="THUMB_TIP",
    index_mcp="INDEX_FINGER_MCP",
    index_tip="INDEX_FINGER_TIP",
)
| |
|
# Per-joint metric columns -> axis labels (positions in cm, angles in degrees).
METRIC_LABELS = dict(
    x_cm="X (cm)",
    y_cm="Y (cm)",
    z_cm="Z (cm)",
    yaw_deg="Yaw (°)",
    pitch_deg="Pitch (°)",
    roll_deg="Roll (°)",
)

# On-screen layout of the trajectory plots: one inner list per row.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]

# PLOT_GRID flattened row by row; this is the order plot outputs are emitted in.
PLOT_ORDER = [name for grid_row in PLOT_GRID for name in grid_row]
| |
|
| | CUSTOM_CSS = """ |
| | :root, .gradio-container, body { |
| | background-color: #050a18 !important; |
| | color: #f8fafc !important; |
| | font-family: 'Inter', 'Segoe UI', system-ui, sans-serif; |
| | } |
| | .side-panel { |
| | background: #0f172a; |
| | padding: 20px; |
| | border-radius: 18px; |
| | border: 1px solid #1f2b47; |
| | min-height: 100%; |
| | } |
| | .stats-card ul { |
| | list-style: none; |
| | padding: 0; |
| | margin: 0; |
| | font-size: 0.92rem; |
| | } |
| | .stats-card li { |
| | margin-bottom: 10px; |
| | color: #e2e8f0; |
| | } |
| | .stats-card span { |
| | display: inline-block; |
| | margin-right: 6px; |
| | color: #7dd3fc; |
| | } |
| | .episodes-title { |
| | margin: 18px 0 8px; |
| | font-size: 0.78rem; |
| | text-transform: uppercase; |
| | letter-spacing: 0.14em; |
| | color: #94a3b8; |
| | } |
| | .episode-list .gr-form { |
| | padding: 0; |
| | } |
| | .episode-list .gr-form > div { |
| | gap: 0; |
| | } |
| | .episode-list input[type="radio"] { |
| | display: none; |
| | } |
| | .episode-list label { |
| | background: transparent !important; |
| | border: none !important; |
| | color: #cbd5f5 !important; |
| | padding: 3px 0 !important; |
| | justify-content: flex-start; |
| | font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; |
| | font-size: 0.9rem; |
| | text-decoration: underline; |
| | } |
| | .episode-list label:hover { |
| | color: #67e8f9 !important; |
| | cursor: pointer; |
| | } |
| | .episode-list input[type="radio"]:checked + label { |
| | color: #facc15 !important; |
| | font-weight: 700; |
| | margin-left: -2px; |
| | } |
| | .main-panel { |
| | padding-top: 8px; |
| | } |
| | .instruction-card { |
| | background: #0f172a; |
| | padding: 18px 20px; |
| | border-radius: 18px; |
| | border: 1px solid #1f2b47; |
| | } |
| | .instruction-label { |
| | font-size: 0.75rem; |
| | letter-spacing: 0.12em; |
| | text-transform: uppercase; |
| | color: #94a3b8; |
| | margin-bottom: 10px; |
| | } |
| | .instruction-text { |
| | font-size: 1.1rem; |
| | line-height: 1.5; |
| | } |
| | .video-card { |
| | background: #0f172a; |
| | border: 1px solid #1f2b47; |
| | border-radius: 18px; |
| | padding: 18px 20px; |
| | margin-top: 18px; |
| | } |
| | .video-title { |
| | font-size: 0.78rem; |
| | text-transform: uppercase; |
| | letter-spacing: 0.18em; |
| | color: #94a3b8; |
| | margin-bottom: 8px; |
| | } |
| | .video-panel video { |
| | border-radius: 12px; |
| | border: 1px solid #1f2b47; |
| | background: #030712; |
| | } |
| | .download-button button { |
| | border-radius: 999px; |
| | border: 1px solid #334155; |
| | background: #1e293b; |
| | color: #f8fafc; |
| | font-size: 0.85rem; |
| | padding: 8px 24px; |
| | } |
| | .download-button button:hover { |
| | border-color: #67e8f9; |
| | color: #67e8f9; |
| | } |
| | .plots-wrap { |
| | margin-top: 18px; |
| | } |
| | .plots-wrap .gr-row { |
| | gap: 16px; |
| | } |
| | .plot-html { |
| | background: #111a2c; |
| | border-radius: 12px; |
| | padding: 10px; |
| | border: 1px solid #1f2b47; |
| | min-height: 320px; |
| | } |
| | .plot-html iframe { |
| | width: 100%; |
| | height: 300px; |
| | border: none; |
| | } |
| | """ |
| |
|
| |
|
@lru_cache(maxsize=1)
def get_dataset_revision(repo_id: str) -> Optional[str]:
    """Return the commit SHA of dataset *repo_id* on the Hub, or None.

    Any error from the Hub query is logged as a warning and turned into
    ``None``. The result (including a ``None`` from a failed lookup) is
    memoized, so the Hub is queried at most once per process for a repo_id.
    """
    try:
        client = HfApi(token=HF_TOKEN)
        dataset_info = client.repo_info(repo_id=repo_id, repo_type="dataset")
        return dataset_info.sha
    except Exception as err:
        logger.warning(f"Could not fetch dataset revision for {repo_id}: {err}")
        return None
| |
|
| |
|
@lru_cache(maxsize=2)
def get_dataset_root(repo_id: str, revision: Optional[str]) -> Path:
    """Download dataset *repo_id* at *revision* into LOCAL_DATASET_DIR.

    Requests real files rather than symlinks (``local_dir_use_symlinks=False``).
    Memoized per ``(repo_id, revision)``, so repeated calls in one process do
    not call ``snapshot_download`` again.

    Returns:
        The directory returned by ``snapshot_download``, as a Path.
    """
    snapshot_dir = snapshot_download(
        repo_type="dataset",
        repo_id=repo_id,
        revision=revision,
        local_dir=LOCAL_DATASET_DIR,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
    )
    return Path(snapshot_dir)
| |
|
| |
|
@lru_cache(maxsize=2)
def load_info(repo_id: str, revision: Optional[str]) -> Dict:
    """Parse the dataset's ``meta/info.json`` (memoized per repo/revision).

    The cached dict is shared by every caller, so treat it as read-only.
    """
    info_file = get_dataset_root(repo_id, revision).joinpath("meta", "info.json")
    return json.loads(info_file.read_text(encoding="utf-8"))
| |
|
| |
|
def resolve_path(
    root: Path,
    template: Union[str, Dict[str, str]],
    episode_chunk: int,
    episode_index: int,
    key: str = "rgb",
) -> Path:
    """Expand a metadata path template for one episode and anchor it at *root*.

    Args:
        root: Local dataset directory.
        template: ``str.format`` pattern that may use ``{episode_chunk}`` and
            ``{episode_index}``, or a mapping of such patterns from which
            ``template[key]`` is used.
        episode_chunk: Value substituted for ``{episode_chunk}``.
        episode_index: Value substituted for ``{episode_index}``.
        key: Entry to pick when *template* is a mapping; defaults to ``"rgb"``.

    Returns:
        ``root / <expanded template>``.

    Raises:
        ValueError: if *template* is a mapping without *key*.
    """
    if isinstance(template, dict):
        selected = template.get(key)
        if selected is None:
            # Yields "RGB template missing from metadata" for the default key.
            raise ValueError(f"{key.upper()} template missing from metadata")
        template = selected
    return root / template.format(episode_chunk=episode_chunk, episode_index=episode_index)
| |
|
| |
|
def _resolve_instruction(episode_meta: Dict, df: pd.DataFrame, info: Dict) -> str:
    """Return the first available language instruction for an episode.

    Lookup order: the episode's metadata entry (if truthy), then the first
    non-null value of the parquet's ``language_instruction`` column, then the
    dataset-level ``task`` with a final default of ``"Tape roll to bowl"``.
    """
    from_meta = episode_meta.get("language_instruction")
    if from_meta:
        return from_meta
    if "language_instruction" in df.columns:
        non_null = df["language_instruction"].dropna()
        if not non_null.empty:
            return non_null.iloc[0]
    return info.get("task", "Tape roll to bowl")


@lru_cache(maxsize=64)
def load_episode(repo_id: str, episode_index: int, revision: Optional[str]) -> Dict:
    """Load the trajectories, video path and instruction for one episode.

    Args:
        repo_id: Hub dataset repo.
        episode_index: ``episode_index`` of an entry in ``info["episodes"]``.
        revision: Dataset revision passed through to the loaders.

    Returns:
        Dict with ``timestamps`` and ``state_df`` (from
        ``build_state_dataframe``), ``rgb_path`` (Path of the episode video)
        and ``instruction``. Results are memoized and shared between callers,
        so treat them as read-only.

    Raises:
        ValueError: if the episode is not listed in the metadata.
        FileNotFoundError: if the episode's parquet file is missing.
    """
    info = load_info(repo_id, revision)
    root = get_dataset_root(repo_id, revision)
    episode_meta = next(
        (ep for ep in info["episodes"] if ep["episode_index"] == episode_index), None
    )
    if episode_meta is None:
        raise ValueError(f"Episode {episode_index} not found in metadata")

    chunk = episode_meta["episode_chunk"]
    parquet_path = resolve_path(root, info["data_path"], chunk, episode_index)
    if not parquet_path.exists():
        raise FileNotFoundError(f"Parquet file not found: {parquet_path}")

    df = pd.read_parquet(parquet_path)
    timestamps, state_df = build_state_dataframe(df)

    rgb_path = resolve_path(root, info["video_path"], chunk, episode_index)
    if not rgb_path.exists():
        # Not fatal: trajectories are still usable, but the video component
        # will have nothing to show. Surface it in the logs early.
        logger.warning("Video file not found for episode %s: %s", episode_index, rgb_path)

    return {
        "timestamps": timestamps,
        "state_df": state_df,
        "rgb_path": rgb_path,
        "instruction": _resolve_instruction(episode_meta, df, info),
    }
| |
|
| |
|
def build_state_dataframe(df: pd.DataFrame) -> Tuple[List[float], pd.DataFrame]:
    """Pivot per-joint rows into one row per frame.

    Rows are matched by ``frame_idx`` and ``joint_name``. For each joint in
    ``JOINT_NAME_MAP``, every ``METRIC_LABELS`` column present in *df* is
    copied to a column named ``"<alias>_<metric>"`` (e.g. ``"wrist_x_cm"``).
    Frames with no row for a joint get NaN in that joint's columns.

    Args:
        df: Episode rows; must contain ``frame_idx``, ``timestamp_s`` and
            ``joint_name`` columns.

    Returns:
        ``(timestamps, state_df)``. ``timestamps`` holds one value per
        distinct ``frame_idx`` (taken from that frame's first row), sorted by
        ``frame_idx``. ``state_df`` has a 0..n-1 index whose rows line up
        with ``timestamps``.

    Raises:
        ValueError: if a required column is missing.
    """
    if "frame_idx" not in df.columns or "timestamp_s" not in df.columns:
        raise ValueError("Episode parquet missing frame timing information.")
    if "joint_name" not in df.columns:
        raise ValueError("Episode parquet missing joint_name column.")

    frame_times = (
        df[["frame_idx", "timestamp_s"]]
        .drop_duplicates("frame_idx")
        .set_index("frame_idx")
        .sort_index()
    )
    frame_indices = frame_times.index.to_list()

    state_df = pd.DataFrame(index=frame_indices)
    for alias, joint_name in JOINT_NAME_MAP.items():
        joint_df = (
            df[df["joint_name"] == joint_name]
            # reindex() raises on duplicate labels, so if a joint appears more
            # than once in the same frame keep only its first row.
            .drop_duplicates("frame_idx")
            .set_index("frame_idx")
            .sort_index()
            .reindex(frame_indices)
        )
        for metric in METRIC_LABELS:
            if metric in joint_df.columns:
                state_df[f"{alias}_{metric}"] = joint_df[metric].astype(float)

    state_df.reset_index(drop=True, inplace=True)
    timestamps = frame_times["timestamp_s"].to_list()
    return timestamps, state_df
| |
|
| |
|
def build_plot_fig(data: Dict, metric: str) -> go.Figure:
    """Line chart of *metric* against time, one trace per joint in JOINT_ALIASES.

    Joints whose ``"<alias>_<metric>"`` column is absent from
    ``data["state_df"]`` are skipped. The x axis is ``data["timestamps"]``.
    """
    x_values = data["timestamps"]
    frame = data["state_df"]
    traces = []
    for joint_key, display_name in JOINT_ALIASES.items():
        column = f"{joint_key}_{metric}"
        if column in frame.columns:
            traces.append(
                go.Scatter(x=x_values, y=frame[column], mode="lines", name=display_name)
            )

    figure = go.Figure(data=traces)
    figure.update_layout(
        margin=dict(l=20, r=20, t=30, b=20),
        height=250,
        template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )
    # Same faint grid on both axes.
    grid_style = dict(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    figure.update_xaxes(**grid_style)
    figure.update_yaxes(**grid_style)
    return figure
| |
|
| |
|
def build_plot_html(data: Dict, metric: str) -> str:
    """Render one metric's trajectory plot as a self-contained ``<iframe>``.

    ``gr.HTML`` inserts its value as markup, and browsers do not execute
    ``<script>`` elements inserted that way, so a bare Plotly fragment (CDN
    script plus inline ``newPlot`` call) would never draw. A full document
    passed through ``srcdoc`` runs in its own browsing context, where the
    scripts do execute.

    Args:
        data: Episode payload with ``timestamps`` and ``state_df``.
        metric: Key of ``METRIC_LABELS`` to plot.

    Returns:
        HTML for a single iframe containing the plot.
    """
    fig = build_plot_fig(data, metric)
    page = pio.to_html(fig, include_plotlyjs="cdn", full_html=True)
    # Drop the browser's default body margin so no border shows around the plot.
    page = page.replace("<body>", '<body style="margin:0">', 1)
    # srcdoc is an attribute value: escape quotes/ampersands so the page survives.
    return (
        f'<iframe title="{html.escape(METRIC_LABELS[metric])}" '
        f'srcdoc="{html.escape(page, quote=True)}"></iframe>'
    )
| |
|
| |
|
def format_episode_label(idx: int) -> str:
    """Return the radio label for episode *idx*, zero-padded to two digits."""
    return "Episode " + format(idx, "02d")
| |
|
| |
|
def parse_episode_label(label: str) -> int:
    """Inverse of ``format_episode_label``: ``"Episode 03"`` -> ``3``."""
    digits = "".join(label.split("Episode")).strip()
    return int(digits)
| |
|
| |
|
def format_instruction_html(text: str) -> str:
    """Wrap *text* (HTML-escaped) in the instruction card markup."""
    escaped = html.escape(text)
    fragments = (
        '<div class="instruction-card">',
        '<p class="instruction-label">Language Instruction</p>',
        f'<p class="instruction-text">{escaped}</p>',
        "</div>",
    )
    return "".join(fragments)
| |
|
| |
|
def _build_stats_html(info: Dict, num_episodes: int) -> str:
    """Render the dataset summary card (frames, episodes, fps) for the side panel."""
    total_frames = sum(ep.get("num_frames", 0) for ep in info["episodes"])
    fps = info.get("fps", 30.0)
    return f"""
<div class="stats-card">
    <ul>
        <li><span>Number of samples/frames:</span> {total_frames:,}</li>
        <li><span>Number of episodes:</span> {num_episodes}</li>
        <li><span>Frames per second:</span> {fps:.1f}</li>
    </ul>
</div>
"""


def _build_theme():
    """Return the dark-coloured Soft theme used by the app."""
    return gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )


def build_interface():
    """Build the Gradio app for browsing episodes of DEFAULT_DATASET_ID.

    Loads the dataset metadata and the first episode eagerly so the page has
    content on load, then wires the episode selector to refresh the
    instruction, video, download target and trajectory plots.

    Returns:
        The ``gr.Blocks`` app (not yet launched).

    Raises:
        RuntimeError: if the metadata lists no episodes.
    """
    revision = get_dataset_revision(DEFAULT_DATASET_ID)
    info = load_info(DEFAULT_DATASET_ID, revision)
    episode_indices = sorted(ep["episode_index"] for ep in info["episodes"])
    if not episode_indices:
        raise RuntimeError("No episodes found in dataset metadata.")

    # Pre-render the first episode so the initial page is populated.
    default_idx = episode_indices[0]
    default_label = format_episode_label(default_idx)
    default_data = load_episode(DEFAULT_DATASET_ID, default_idx, revision)
    default_video = str(default_data["rgb_path"])
    default_instruction = default_data["instruction"]
    default_figs = {metric: build_plot_html(default_data, metric) for metric in PLOT_ORDER}

    stats_html = _build_stats_html(info, len(episode_indices))

    with gr.Blocks(theme=_build_theme(), css=CUSTOM_CSS) as demo:
        gr.Markdown("# Humanoid Robots Hand Pose Viewer")
        gr.Markdown(
            "Visualize RGB + 6DoF hand trajectories for all Moving_Mini tasks "
            "(humanoid-robots-training-dataset)."
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
                gr.HTML('<div class="episodes-title">Episodes</div>')
                episode_radio = gr.Radio(
                    choices=[format_episode_label(i) for i in episode_indices],
                    value=default_label,
                    label="Episodes",
                    elem_classes=["episode-list"],
                )
            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                instruction_box = gr.HTML(
                    format_instruction_html(default_instruction),
                    label="Language Instruction",
                )
                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB</div>')
                    video = gr.Video(
                        height=360,
                        value=default_video,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    download_button = gr.DownloadButton(
                        label="Download",
                        value=default_video,
                        elem_classes=["download-button"],
                    )

                plot_outputs = []
                gr.Markdown("### Joint trajectories", elem_classes=["plots-title"])
                with gr.Column(elem_classes=["plots-wrap"]):
                    for row in PLOT_GRID:
                        with gr.Row():
                            for metric in row:
                                plot = gr.HTML(value=default_figs[metric], elem_classes=["plot-html"])
                                plot_outputs.append(plot)

        # Plots were appended row by row, i.e. in PLOT_ORDER.
        outputs = [instruction_box, video, download_button] + plot_outputs

        def load_episode_payload(label: Optional[str]):
            """Return new values for ``outputs`` for the episode named *label*."""
            if label is None:
                # Nothing selected (e.g. the radio was cleared): keep every output as is.
                return [gr.update() for _ in outputs]
            idx = parse_episode_label(label)
            data = load_episode(DEFAULT_DATASET_ID, idx, revision)
            video_path = str(data["rgb_path"])
            figs = [build_plot_html(data, metric) for metric in PLOT_ORDER]
            return [
                format_instruction_html(data["instruction"]),
                video_path,
                # Component.update() classmethods were removed in Gradio 4;
                # gr.update() is the supported way to patch component props.
                gr.update(value=video_path),
                *figs,
            ]

        episode_radio.change(fn=load_episode_payload, inputs=episode_radio, outputs=outputs)

    return demo
| |
|
| |
|
| |
|
def main():
    """Build the viewer, enable request queuing, and serve it (API page hidden)."""
    app = build_interface()
    app.queue()
    app.launch(show_api=False)


if __name__ == "__main__":
    main()
| |
|