File size: 2,904 Bytes
9f1563e
6c0bd0b
9f1563e
 
 
 
 
 
 
bd1896f
 
573145d
bf7f9b5
573145d
 
 
3f4e81a
573145d
 
3f4e81a
586d2b6
573145d
586d2b6
bf7f9b5
573145d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a162ed1
3f4e81a
 
 
a162ed1
3f4e81a
 
 
 
573145d
586d2b6
 
 
 
 
 
 
 
 
 
 
573145d
3f4e81a
573145d
 
 
 
3f4e81a
 
 
 
 
 
a162ed1
3f4e81a
573145d
 
 
 
 
a162ed1
573145d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
A Gradio app that uses Rerun to visualize a Hugging Face dataset.

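This app mounts the Gradio app inside FastAPI so that it can set the CORS headers required by the Rerun web viewer (https://app.rerun.io), which loads the generated recordings from this server.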

Run this from the terminal as you would normally start a FastAPI app: `uvicorn app:app`
and navigate to http://localhost:8000 in your browser.
"""

from __future__ import annotations

import urllib
from pathlib import Path

import gradio as gr
import rerun as rr
from datasets import load_dataset
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

from dataset_conversion import log_dataset_to_rerun, log_lerobot_dataset_to_rerun

# URL path at which the Gradio UI is mounted inside the FastAPI app
# (passed to gr.mount_gradio_app at the end of this module).
CUSTOM_PATH = "/"

# FastAPI wrapper app; it exists so CORS middleware can be attached before the
# Gradio UI is mounted into it (the name is rebound by gr.mount_gradio_app below).
app = FastAPI()

# Origins allowed to make cross-origin requests to this server. The embedded Rerun
# web viewer is loaded from app.rerun.io and is handed a URL pointing back here
# (see html_template), so that origin must be permitted.
origins = [
    "https://app.rerun.io",
]

# Only allow_origins is configured; all other CORSMiddleware options keep their defaults.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
)


def html_template(rrd: str, app_url: str = "https://app.rerun.io") -> str:
    """Build an HTML snippet that embeds the Rerun web viewer pointed at a recording.

    Args:
        rrd: URL of the ``.rrd`` recording to load. Falsy values (``None`` or ``""``,
            e.g. when the file component has been cleared) yield an empty snippet
            instead of a viewer with nothing to show.
        app_url: Base URL of the Rerun web viewer.

    Returns:
        A ``<div>`` wrapping a full-width ``<iframe>`` whose ``src`` is
        ``{app_url}?url={percent-encoded rrd}``, or ``""`` when *rrd* is empty.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not guarantee
    # that ``urllib.parse`` has been loaded.
    import html
    from urllib.parse import quote

    if not rrd:
        return ""

    # Percent-encode the recording URL so it travels as a single query-string value,
    # and HTML-escape the viewer URL because it is interpolated into an attribute.
    encoded_url = quote(rrd)
    viewer_url = html.escape(app_url, quote=True)
    return (
        '<div style="width:100%; height:70vh;">'
        f'<iframe style="width:100%; height:100%;" src="{viewer_url}?url={encoded_url}" '
        'frameborder="0" allowfullscreen=""></iframe></div>'
    )


def show_dataset(dataset_id: str, episode_index: int) -> str:
    rr.init("dataset")

    # TODO(jleibs): manage cache better and put in proper storage
    filename = Path(f"tmp/{dataset_id}_{episode_index}.rrd")
    if not filename.exists():
        filename.parent.mkdir(parents=True, exist_ok=True)

        rr.save(filename.as_posix())

        if "/" in dataset_id and dataset_id.split("/")[0] == "lerobot":
            dataset = LeRobotDataset(dataset_id)
            log_lerobot_dataset_to_rerun(dataset, episode_index)
        else:
            dataset = load_dataset(dataset_id, split="train", streaming=True)

            # This is for LeRobot datasets (https://huggingface.co/lerobot):
            ds_subset = dataset.filter(
                lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index
            )
            log_dataset_to_rerun(ds_subset)

    return filename.as_posix()


with gr.Blocks() as demo:
    # Inputs: dataset picker, episode selector and the button that triggers logging.
    with gr.Row():
        search_in = HuggingfaceHubSearch(
            "lerobot/pusht",
            label="Search Huggingface Hub",
            placeholder="Search for datasets on Huggingface",
            search_type="dataset",
        )
        # precision=0 asks Gradio for a whole number instead of a float such as 1.0;
        # episode indices are never negative.
        episode_index = gr.Number(1, label="Episode Index", precision=0, minimum=0)
        button = gr.Button("Show Dataset")
    # Output slot for the generated .rrd recording.
    with gr.Row():
        rrd = gr.File()
    # Embedded Rerun web viewer, rendered by html_template.
    with gr.Row():
        viewer = gr.HTML()

    button.click(show_dataset, inputs=[search_in, episode_index], outputs=rrd)
    # The JS hook turns the file value into its URL before html_template runs, and
    # preprocess=False forwards that string unchanged. When the file is cleared the
    # value is null, so pass "" instead of throwing on `.url`.
    rrd.change(
        html_template,
        js="""(rrd) => (rrd ? rrd.url : "")""",
        inputs=[rrd],
        outputs=viewer,
        preprocess=False,
    )


# Serve the Gradio UI from inside the CORS-enabled FastAPI app.
app = gr.mount_gradio_app(app, demo, path=CUSTOM_PATH)