#!/usr/bin/env python3
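"""Log a HuggingFace dataset to Rerun.

Usage sketch (the script filename below is illustrative; the flags are the
ones defined in main()):

    python log_hf_dataset_to_rerun.py --dataset lerobot/pusht --episode-id 1
"""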

from __future__ import annotations

import argparse
import logging
from typing import Any

import numpy as np
import rerun as rr
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)


def to_rerun(column_name: str, value: Any) -> Any:
    """Do our best to interpret the value and convert it to a Rerun-compatible archetype."""
    if isinstance(value, Image.Image):
        # Columns with "depth" in their name are treated as depth maps rather than color images.
        if "depth" in column_name:
            return rr.DepthImage(value)
        else:
            return rr.Image(value)
    elif isinstance(value, np.ndarray):
        return rr.Tensor(value)
    elif isinstance(value, list):
        # Guard against empty lists before peeking at the first element.
        if value and isinstance(value[0], float):
            return rr.BarChart(value)
        else:
            return rr.TextDocument(str(value))  # Fallback to text
    elif isinstance(value, (int, float)):
        return rr.Scalar(value)
    else:
        return rr.TextDocument(str(value))  # Fallback to text


def log_dataset_to_rerun(dataset) -> None:
    """Log each row of the dataset, using time-like columns to drive Rerun's timelines."""
    # Special time-like columns for LeRobot datasets (https://huggingface.co/datasets/lerobot/):
    TIME_LIKE = {"index", "frame_id", "timestamp"}

    # Ignore these columns (again, LeRobot-specific):
    IGNORE = {"episode_data_index_from", "episode_data_index_to", "episode_id"}

    for row in tqdm(dataset):
        # Handle time-like columns first, since they set a state (time is an index in Rerun):
        for column_name in TIME_LIKE:
            if column_name in row:
                cell = row[column_name]
                if isinstance(cell, int):
                    rr.set_time_sequence(column_name, cell)
                elif isinstance(cell, float):
                    rr.set_time_seconds(column_name, cell)  # assume seconds
                else:
                    print(f"Unknown time-like column {column_name} with value {cell}")

        # Now log actual data columns:
        for column_name, cell in row.items():
            if column_name in TIME_LIKE or column_name in IGNORE:
                continue

            # The column name doubles as the entity path in the Rerun viewer.
            rr.log(column_name, to_rerun(column_name, cell))


def main() -> None:
    # Ensure the logging gets written to stderr:
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser(description="Log a HuggingFace dataset to Rerun.")
    parser.add_argument("--dataset", default="lerobot/pusht", help="The name of the dataset to load")
    parser.add_argument("--episode-id", default=1, help="Which episode to select")
    args = parser.parse_args()

    print("Loading dataset…")
    dataset = load_dataset(args.dataset, split="train", streaming=True)

    # This is for LeRobot datasets (https://huggingface.co/lerobot):
    ds_subset = dataset.filter(lambda frame: "episode_id" not in frame or frame["episode_id"] == args.episode_id)

    print("Starting Rerun…")
    rr.init(f"rerun_example_huggingface {args.dataset}", spawn=True)

    print("Logging to Rerun…")
    log_dataset_to_rerun(ds_subset)


if __name__ == "__main__":
    main()