Spaces:
Sleeping
Sleeping
File size: 3,123 Bytes
1602469 6442fbd 23aef68 6442fbd 23aef68 1602469 4337a31 1602469 23aef68 4337a31 f44e37f 4337a31 f44e37f 4337a31 23aef68 4337a31 23aef68 4337a31 23aef68 4337a31 23aef68 6442fbd 23aef68 6442fbd 4337a31 23aef68 4337a31 9901ded 23aef68 4337a31 5245583 4337a31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import logging
from typing import Any
import numpy as np
import rerun as rr
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
logger = logging.getLogger(__name__)


def to_rerun(column_name: str, value: Any) -> Any:
    """Do our best to interpret the value and convert it to a Rerun-compatible archetype.

    Args:
        column_name: Name of the dataset column the value came from. A PIL image
            from a column whose name contains ``"depth"`` is logged as a depth image.
        value: A single cell from a dataset row.

    Returns:
        A Rerun archetype:

        * ``rr.DepthImage`` / ``rr.Image`` for PIL images,
        * ``rr.Tensor`` for NumPy arrays,
        * ``rr.BarChart`` for non-empty lists whose first element is a float,
        * ``rr.Scalar`` for ints and floats (``bool`` is an ``int`` subclass, so
          booleans land here too),
        * ``rr.TextDocument`` holding ``str(value)`` for anything else
          (including empty lists).
    """
    if isinstance(value, Image.Image):
        if "depth" in column_name:
            return rr.DepthImage(value)
        return rr.Image(value)
    if isinstance(value, np.ndarray):
        return rr.Tensor(value)
    if isinstance(value, list):
        # An empty list has no value[0]; fall through to the text fallback instead
        # of raising IndexError. Only the first element is inspected, so the list
        # is assumed to be homogeneous.
        if value and isinstance(value[0], float):
            return rr.BarChart(value)
        return rr.TextDocument(str(value))  # Fallback to text
    if isinstance(value, (float, int)):
        return rr.Scalar(value)
    return rr.TextDocument(str(value))  # Fallback to text
def log_dataset_to_rerun(dataset) -> None:
    """Log every row of ``dataset`` to the active Rerun recording.

    Rows are processed in two passes:

    1. Time-like columns (``index``, ``frame_id``, ``timestamp``) are applied
       as Rerun timelines. Integers become sequence timelines and floats
       become seconds timelines.
    2. Every remaining column is converted with ``to_rerun`` and logged under
       an entity path equal to the column name. The LeRobot bookkeeping
       columns in ``IGNORE`` are skipped.

    A time-like column holding any other type is skipped. A warning is logged
    for it once per column, not once per row.

    Args:
        dataset: An iterable of rows. Each row is a mapping from column name to
            cell value, e.g. the streaming HuggingFace dataset built in ``main``.
    """
    # Special time-like columns for LeRobot datasets (https://huggingface.co/datasets/lerobot/):
    TIME_LIKE = frozenset({"index", "frame_id", "timestamp"})
    # Ignore these columns (again, LeRobot-specific):
    IGNORE = frozenset({"episode_data_index_from", "episode_data_index_to", "episode_id"})
    # Time-like columns already reported as unusable, so each is warned about only once.
    warned: set[str] = set()

    for row in tqdm(dataset):
        # Handle time-like columns first, since they set a state (time is an index in Rerun):
        for column_name in TIME_LIKE:
            if column_name not in row:
                continue
            cell = row[column_name]
            # Check integers before floats. The NumPy scalar types are accepted
            # explicitly: np.int64 is not a subclass of the built-in int, and
            # np.float32 is not a subclass of float.
            if isinstance(cell, (int, np.integer)):
                rr.set_time_sequence(column_name, int(cell))
            elif isinstance(cell, (float, np.floating)):
                rr.set_time_seconds(column_name, float(cell))  # assume seconds
            elif column_name not in warned:
                warned.add(column_name)
                logger.warning("Unknown time-like column %s with value %s", column_name, cell)

        # Now log actual data columns:
        for column_name, cell in row.items():
            if column_name in TIME_LIKE or column_name in IGNORE:
                continue
            rr.log(column_name, to_rerun(column_name, cell))
def main() -> None:
    """Command-line entry point: stream a HuggingFace dataset and log it to Rerun.

    Steps:

    * Loads the ``train`` split of ``--dataset`` in streaming mode.
    * Keeps only rows whose ``episode_id`` equals ``--episode-id``. Rows that
      have no ``episode_id`` column are all kept.
    * Spawns a Rerun viewer and logs the selected rows to it.
    """
    # Ensure the logging gets written to stderr:
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser(description="Log a HuggingFace dataset to Rerun.")
    parser.add_argument("--dataset", default="lerobot/pusht", help="The name of the dataset to load")
    # type=int matters: the default (1) is an int, but without a converter a value
    # given on the command line stays a str and never compares equal to an
    # integer episode_id, which would silently filter out every row.
    parser.add_argument("--episode-id", type=int, default=1, help="Which episode to select")
    args = parser.parse_args()

    logger.info("Loading dataset…")
    dataset = load_dataset(args.dataset, split="train", streaming=True)

    # This is for LeRobot datasets (https://huggingface.co/lerobot):
    episode_id = args.episode_id
    ds_subset = dataset.filter(lambda frame: "episode_id" not in frame or frame["episode_id"] == episode_id)

    logger.info("Starting Rerun…")
    rr.init(f"rerun_example_huggingface {args.dataset}", spawn=True)

    logger.info("Logging to Rerun…")
    log_dataset_to_rerun(ds_subset)


if __name__ == "__main__":
    main()
|