Raffael-Kultyshev's picture
Update dataset to DynamicIntelligence/humanoid-robots-training-dataset
fb66e8f verified
import html
import json
import logging
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from huggingface_hub import HfApi, snapshot_download
# Emit INFO-and-above records on stdout using a timestamped line format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Dataset repo shown by the viewer; the DATASET_ID environment variable
# overrides the built-in default.
DEFAULT_DATASET_ID = os.environ.get(
    "DATASET_ID",
    "DynamicIntelligence/humanoid-robots-training-dataset",
)

# Directory handed to snapshot_download as its local download target.
LOCAL_DATASET_DIR = Path("dataset_cache")

# Hugging Face access token, or None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Short joint alias -> legend label used for the plot traces.
JOINT_ALIASES = dict(
    wrist="Wrist",
    thumb_tip="Thumb Tip",
    index_mcp="Index MCP",
    index_tip="Index Tip",
)

# Short joint alias -> value matched against the parquet "joint_name" column.
JOINT_NAME_MAP = dict(
    wrist="WRIST",
    thumb_tip="THUMB_TIP",
    index_mcp="INDEX_FINGER_MCP",
    index_tip="INDEX_FINGER_TIP",
)

# Metric column name -> axis label shown on the corresponding plot.
METRIC_LABELS = dict(
    x_cm="X (cm)",
    y_cm="Y (cm)",
    z_cm="Z (cm)",
    yaw_deg="Yaw (°)",
    pitch_deg="Pitch (°)",
    roll_deg="Roll (°)",
)

# On-screen arrangement of the plots: one inner list per UI row.
PLOT_GRID = [
    ["x_cm", "y_cm", "z_cm"],
    ["yaw_deg", "pitch_deg", "roll_deg"],
]

# PLOT_GRID flattened row by row; fixes the order of the plot outputs.
PLOT_ORDER = [name for grid_row in PLOT_GRID for name in grid_row]
# Stylesheet passed to gr.Blocks(css=...). The selectors target the
# elem_classes assigned in build_interface (side-panel, episode-list,
# main-panel, video-card, video-panel, download-button, plots-wrap,
# plot-html) plus the class names used in the hand-written HTML snippets
# (stats-card, episodes-title, instruction-card, instruction-label,
# instruction-text, video-title).
CUSTOM_CSS = """
:root, .gradio-container, body {
background-color: #050a18 !important;
color: #f8fafc !important;
font-family: 'Inter', 'Segoe UI', system-ui, sans-serif;
}
.side-panel {
background: #0f172a;
padding: 20px;
border-radius: 18px;
border: 1px solid #1f2b47;
min-height: 100%;
}
.stats-card ul {
list-style: none;
padding: 0;
margin: 0;
font-size: 0.92rem;
}
.stats-card li {
margin-bottom: 10px;
color: #e2e8f0;
}
.stats-card span {
display: inline-block;
margin-right: 6px;
color: #7dd3fc;
}
.episodes-title {
margin: 18px 0 8px;
font-size: 0.78rem;
text-transform: uppercase;
letter-spacing: 0.14em;
color: #94a3b8;
}
.episode-list .gr-form {
padding: 0;
}
.episode-list .gr-form > div {
gap: 0;
}
.episode-list input[type="radio"] {
display: none;
}
.episode-list label {
background: transparent !important;
border: none !important;
color: #cbd5f5 !important;
padding: 3px 0 !important;
justify-content: flex-start;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
font-size: 0.9rem;
text-decoration: underline;
}
.episode-list label:hover {
color: #67e8f9 !important;
cursor: pointer;
}
.episode-list input[type="radio"]:checked + label {
color: #facc15 !important;
font-weight: 700;
margin-left: -2px;
}
.main-panel {
padding-top: 8px;
}
.instruction-card {
background: #0f172a;
padding: 18px 20px;
border-radius: 18px;
border: 1px solid #1f2b47;
}
.instruction-label {
font-size: 0.75rem;
letter-spacing: 0.12em;
text-transform: uppercase;
color: #94a3b8;
margin-bottom: 10px;
}
.instruction-text {
font-size: 1.1rem;
line-height: 1.5;
}
.video-card {
background: #0f172a;
border: 1px solid #1f2b47;
border-radius: 18px;
padding: 18px 20px;
margin-top: 18px;
}
.video-title {
font-size: 0.78rem;
text-transform: uppercase;
letter-spacing: 0.18em;
color: #94a3b8;
margin-bottom: 8px;
}
.video-panel video {
border-radius: 12px;
border: 1px solid #1f2b47;
background: #030712;
}
.download-button button {
border-radius: 999px;
border: 1px solid #334155;
background: #1e293b;
color: #f8fafc;
font-size: 0.85rem;
padding: 8px 24px;
}
.download-button button:hover {
border-color: #67e8f9;
color: #67e8f9;
}
.plots-wrap {
margin-top: 18px;
}
.plots-wrap .gr-row {
gap: 16px;
}
.plot-html {
background: #111a2c;
border-radius: 12px;
padding: 10px;
border: 1px solid #1f2b47;
min-height: 320px;
}
.plot-html iframe {
width: 100%;
height: 300px;
border: none;
}
"""
@lru_cache(maxsize=1)
def get_dataset_revision(repo_id: str) -> Optional[str]:
    """Look up the ``sha`` that the Hub reports for a dataset repository.

    Args:
        repo_id: Dataset repository id to query.

    Returns:
        The ``sha`` attribute of the repo info, or ``None`` when the lookup
        raised any exception (the failure is logged as a warning). The
        returned value, ``None`` included, is memoized by ``lru_cache``.
    """
    try:
        api = HfApi(token=HF_TOKEN)
        repo_details = api.repo_info(repo_id=repo_id, repo_type="dataset")
        return repo_details.sha
    except Exception as exc:
        # Best effort: callers pass None on to snapshot_download as the revision.
        logger.warning(f"Could not fetch dataset revision for {repo_id}: {exc}")
    return None
@lru_cache(maxsize=2)
def get_dataset_root(repo_id: str, revision: Optional[str]) -> Path:
    """Download a dataset snapshot and return the directory it was placed in.

    Args:
        repo_id: Dataset repository id to download.
        revision: Revision forwarded to ``snapshot_download``; may be ``None``.

    Returns:
        The path returned by ``snapshot_download``, wrapped in ``Path``.
        Results are memoized per ``(repo_id, revision)`` pair.
    """
    download_options = dict(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=LOCAL_DATASET_DIR,
        local_dir_use_symlinks=False,
        revision=revision,
        token=HF_TOKEN,
    )
    snapshot_dir = snapshot_download(**download_options)
    return Path(snapshot_dir)
@lru_cache(maxsize=2)
def load_info(repo_id: str, revision: Optional[str]) -> Dict:
    """Read and parse ``meta/info.json`` from the downloaded dataset.

    Args:
        repo_id: Dataset repository id.
        revision: Revision used to locate the downloaded snapshot.

    Returns:
        The parsed JSON content. The same object is returned on repeated
        calls with the same arguments because the result is memoized.
    """
    info_file = get_dataset_root(repo_id, revision) / "meta" / "info.json"
    with info_file.open("r", encoding="utf-8") as handle:
        metadata = json.load(handle)
    return metadata
def resolve_path(root: Path, template: str, episode_chunk: int, episode_index: int) -> Path:
    """Expand a metadata path template into a path below ``root``.

    Args:
        root: Directory the resulting path is joined onto.
        template: A format string with ``episode_chunk`` / ``episode_index``
            placeholders. A dict is also accepted, in which case its
            ``"rgb"`` entry is used as the format string.
        episode_chunk: Value substituted for ``{episode_chunk}``.
        episode_index: Value substituted for ``{episode_index}``.

    Returns:
        ``root`` joined with the formatted template.

    Raises:
        ValueError: If ``template`` is a dict without an ``"rgb"`` entry.
    """
    pattern = template
    if isinstance(template, dict):
        pattern = template.get("rgb")
        if pattern is None:
            raise ValueError("RGB template missing from metadata")

    relative = pattern.format(episode_chunk=episode_chunk, episode_index=episode_index)
    return root / relative
@lru_cache(maxsize=64)
def load_episode(repo_id: str, episode_index: int, revision: Optional[str]) -> Dict:
    """Load everything the viewer needs to display one episode.

    Args:
        repo_id: Dataset repository id.
        episode_index: Value matched against ``episode_index`` in the
            ``episodes`` list of the dataset metadata.
        revision: Dataset revision forwarded to the metadata/snapshot loaders.

    Returns:
        A dict with keys ``timestamps``, ``state_df``, ``rgb_path`` and
        ``instruction``. The dict is memoized, so callers share one object.

    Raises:
        ValueError: If no metadata entry matches ``episode_index``.
        FileNotFoundError: If the resolved parquet file does not exist.
    """
    info = load_info(repo_id, revision)
    root = get_dataset_root(repo_id, revision)

    # Locate the first metadata entry for the requested episode.
    episode_meta = None
    for candidate in info["episodes"]:
        if candidate["episode_index"] == episode_index:
            episode_meta = candidate
            break
    if not episode_meta:
        raise ValueError(f"Episode {episode_index} not found in metadata")

    chunk = episode_meta["episode_chunk"]
    parquet_path = resolve_path(root, info["data_path"], chunk, episode_index)
    if not parquet_path.exists():
        raise FileNotFoundError(f"Parquet file not found: {parquet_path}")

    frame_table = pd.read_parquet(parquet_path)
    timestamps, state_df = build_state_dataframe(frame_table)
    rgb_path = resolve_path(root, info["video_path"], chunk, episode_index)

    # Instruction priority: episode metadata, then the first non-null value
    # in the parquet column, then the dataset-level task (with a fixed default).
    instruction = episode_meta.get("language_instruction")
    if not instruction:
        has_column_text = (
            "language_instruction" in frame_table.columns
            and not frame_table["language_instruction"].isna().all()
        )
        if has_column_text:
            instruction = frame_table["language_instruction"].dropna().iloc[0]
        else:
            instruction = info.get("task", "Tape roll to bowl")

    return {
        "timestamps": timestamps,
        "state_df": state_df,
        "rgb_path": rgb_path,
        "instruction": instruction,
    }
def build_state_dataframe(df: pd.DataFrame) -> Tuple[List[float], pd.DataFrame]:
    """Pivot long-format joint rows into one wide row per frame.

    The input is expected to hold one row per (frame, joint) with the columns
    ``frame_idx``, ``timestamp_s`` and ``joint_name`` plus any of the metric
    columns named in ``METRIC_LABELS``.

    Args:
        df: Episode table as read from the parquet file.

    Returns:
        A ``(timestamps, state_df)`` pair. ``timestamps`` lists
        ``timestamp_s`` for each distinct ``frame_idx`` in ascending frame
        order. ``state_df`` has one row per such frame (0-based positional
        index) and one float column ``"<alias>_<metric>"`` for every joint in
        ``JOINT_NAME_MAP`` and every metric column present in ``df``; frames
        without a row for a joint hold NaN.

    Raises:
        ValueError: If ``frame_idx`` or ``timestamp_s`` is missing.
    """
    if "frame_idx" not in df.columns or "timestamp_s" not in df.columns:
        raise ValueError("Episode parquet missing frame timing information.")

    frame_times = (
        df[["frame_idx", "timestamp_s"]]
        .drop_duplicates("frame_idx")
        .set_index("frame_idx")
        .sort_index()
    )
    frame_indices = frame_times.index.to_list()

    state_df = pd.DataFrame(index=frame_indices)
    for alias, joint_name in JOINT_NAME_MAP.items():
        joint_df = (
            df[df["joint_name"] == joint_name]
            # Keep the first row per frame, mirroring the timing table above;
            # reindex() raises on an index that contains duplicate labels.
            .drop_duplicates("frame_idx")
            .set_index("frame_idx")
            .sort_index()
            .reindex(frame_indices)
        )
        for metric in METRIC_LABELS:
            if metric in joint_df.columns:
                state_df[f"{alias}_{metric}"] = joint_df[metric].astype(float)

    # Replace the frame_idx labels with a plain positional index.
    state_df.reset_index(drop=True, inplace=True)
    timestamps = frame_times["timestamp_s"].to_list()
    return timestamps, state_df
def build_plot_fig(data: Dict, metric: str) -> go.Figure:
    """Create a line chart of one metric over time, one trace per joint.

    Args:
        data: Episode payload providing ``"timestamps"`` and ``"state_df"``.
        metric: A key of ``METRIC_LABELS``; selects the ``"<alias>_<metric>"``
            columns to draw and the y-axis title.

    Returns:
        A Plotly figure. Joints whose column is absent from ``state_df`` are
        left out.
    """
    timestamps = data["timestamps"]
    state_df = data["state_df"]

    fig = go.Figure()
    for alias, label in JOINT_ALIASES.items():
        column = f"{alias}_{metric}"
        if column in state_df.columns:
            trace = go.Scatter(
                x=timestamps,
                y=state_df[column],
                mode="lines",
                name=label,
            )
            fig.add_trace(trace)

    layout_options = dict(
        margin=dict(l=20, r=20, t=30, b=20),
        height=250,
        template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis_title="Time (s)",
        yaxis_title=METRIC_LABELS[metric],
    )
    fig.update_layout(**layout_options)

    # Same faint grid on both axes.
    grid_style = dict(showgrid=True, gridwidth=0.5, gridcolor="rgba(255,255,255,0.1)")
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)
    return fig
def build_plot_html(data: Dict, metric: str) -> str:
    """Render one metric's plot as HTML that can be shown through ``gr.HTML``.

    The Plotly output depends on ``<script>`` tags (the CDN loader plus the
    plotting call). Markup injected by ``gr.HTML`` is not script-executed, so
    the figure is placed in a self-contained document and embedded through an
    ``<iframe srcdoc=...>``, where the scripts run in the frame's own context.

    Args:
        data: Episode payload accepted by ``build_plot_fig``.
        metric: Metric key accepted by ``build_plot_fig``.

    Returns:
        An ``<iframe>`` element (as a string) containing the rendered plot.
    """
    fig = build_plot_fig(data, metric)
    fragment = pio.to_html(fig, include_plotlyjs="cdn", full_html=False)
    document = (
        '<!DOCTYPE html><html><head><meta charset="utf-8"></head>'
        '<body style="margin:0;background:transparent;">'
        f"{fragment}"
        "</body></html>"
    )
    # srcdoc is an attribute value, so the whole document must be escaped.
    escaped_document = html.escape(document, quote=True)
    return (
        f'<iframe srcdoc="{escaped_document}" '
        'style="width:100%;height:300px;border:none;"></iframe>'
    )
def format_episode_label(idx: int) -> str:
    """Return the display label for an episode, zero-padded to two digits."""
    padded_index = format(idx, "02d")
    return "Episode " + padded_index
def parse_episode_label(label: str) -> int:
    """Extract the integer episode index from a label such as ``"Episode 03"``.

    Every occurrence of the word ``Episode`` is removed and the remainder is
    stripped of surrounding whitespace before conversion.

    Raises:
        ValueError: If the remaining text is not a valid integer.
    """
    without_word = label.replace("Episode", "")
    digits = without_word.strip()
    return int(digits)
def format_instruction_html(text: str) -> str:
    """Wrap an instruction string in the viewer's instruction-card markup.

    Args:
        text: Raw instruction; it is HTML-escaped before being embedded.

    Returns:
        A ``<div class="instruction-card">`` snippet holding a fixed label
        and the escaped instruction text.
    """
    pieces = [
        '<div class="instruction-card">',
        '<p class="instruction-label">Language Instruction</p>',
        '<p class="instruction-text">',
        html.escape(text),
        "</p>",
        "</div>",
    ]
    return "".join(pieces)
def build_interface():
    """Assemble the Gradio Blocks app for browsing dataset episodes.

    The dataset metadata and the first episode are loaded eagerly so the page
    opens with content. Picking another episode in the radio list refreshes
    the instruction card, the video, the download button and all plots.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).

    Raises:
        RuntimeError: If the dataset metadata lists no episodes.
    """
    revision = get_dataset_revision(DEFAULT_DATASET_ID)
    info = load_info(DEFAULT_DATASET_ID, revision)
    episode_indices = sorted(ep["episode_index"] for ep in info["episodes"])
    if not episode_indices:
        raise RuntimeError("No episodes found in dataset metadata.")

    default_idx = episode_indices[0]
    default_label = format_episode_label(default_idx)
    default_data = load_episode(DEFAULT_DATASET_ID, default_idx, revision)
    default_video = str(default_data["rgb_path"])
    default_instruction = default_data["instruction"]
    default_figs = {metric: build_plot_html(default_data, metric) for metric in METRIC_LABELS}

    total_frames = sum(ep.get("num_frames", 0) for ep in info["episodes"])
    fps = info.get("fps", 30.0)
    stats_html = f"""
<div class="stats-card">
<ul>
<li><span>Number of samples/frames:</span> {total_frames:,}</li>
<li><span>Number of episodes:</span> {len(episode_indices)}</li>
<li><span>Frames per second:</span> {fps:.1f}</li>
</ul>
</div>
"""

    theme = gr.themes.Soft(
        primary_hue="cyan", secondary_hue="blue", neutral_hue="slate"
    ).set(
        body_background_fill="#0c1424",
        body_text_color="#f8fafc",
        block_background_fill="#111a2c",
        block_title_text_color="#f8fafc",
        input_background_fill="#151f33",
        border_color_primary="#1f2b47",
        shadow_drop="none",
    )

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file without indentation; confirm it
    # matches the intended page layout.
    with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
        gr.Markdown("# Humanoid Robots Hand Pose Viewer")
        gr.Markdown(
            "Visualize RGB + 6DoF hand trajectories for all Moving_Mini tasks "
            "(humanoid-robots-training-dataset)."
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=260, elem_classes=["side-panel"]):
                gr.HTML(stats_html)
                gr.HTML('<div class="episodes-title">Episodes</div>')
                episode_radio = gr.Radio(
                    choices=[format_episode_label(i) for i in episode_indices],
                    value=default_label,
                    label="Episodes",
                    elem_classes=["episode-list"],
                )
            with gr.Column(scale=2, min_width=640, elem_classes=["main-panel"]):
                instruction_box = gr.HTML(
                    format_instruction_html(default_instruction),
                    label="Language Instruction",
                )
                with gr.Column(elem_classes=["video-card"]):
                    gr.HTML('<div class="video-title">RGB</div>')
                    video = gr.Video(
                        height=360,
                        value=default_video,
                        elem_classes=["video-panel"],
                        show_label=False,
                        show_download_button=False,
                    )
                    download_button = gr.DownloadButton(
                        label="Download",
                        value=default_video,
                        elem_classes=["download-button"],
                    )
                plot_outputs = []
                gr.Markdown("### Joint trajectories", elem_classes=["plots-title"])
                with gr.Column(elem_classes=["plots-wrap"]):
                    # Created in PLOT_GRID order, which is the order PLOT_ORDER
                    # uses when the change handler returns the refreshed plots.
                    for row in PLOT_GRID:
                        with gr.Row():
                            for metric in row:
                                plot = gr.HTML(value=default_figs[metric], elem_classes=["plot-html"])
                                plot_outputs.append(plot)

        outputs = [instruction_box, video, download_button] + plot_outputs

        def load_episode_payload(label: str):
            """Return fresh values for every output component for ``label``."""
            idx = parse_episode_label(label)
            data = load_episode(DEFAULT_DATASET_ID, idx, revision)
            video_path = str(data["rgb_path"])
            figs = [build_plot_html(data, metric) for metric in PLOT_ORDER]
            return [
                format_instruction_html(data["instruction"]),
                video_path,
                # gr.update() is the generic property-update helper; the
                # per-component ``.update()`` classmethods are not available
                # on gr.DownloadButton.
                gr.update(value=video_path),
                *figs,
            ]

        episode_radio.change(fn=load_episode_payload, inputs=episode_radio, outputs=outputs)

    return demo
def main():
    """Build the viewer, enable its queue and launch it without the API page."""
    app = build_interface()
    queued_app = app.queue()
    queued_app.launch(show_api=False)


if __name__ == "__main__":
    main()