jleibs commited on
Commit
3f4e81a
1 Parent(s): a235944

Integrate with rerun dataset converter

Browse files
Files changed (2) hide show
  1. app.py +27 -57
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import urllib
2
  from collections import namedtuple
3
  from math import cos, sin
@@ -5,10 +6,13 @@ from typing import Any
5
 
6
  import gradio as gr
7
  import numpy as np
 
8
  import rerun as rr
9
  import rerun.blueprint as rrb
 
10
  from fastapi import FastAPI
11
  from fastapi.middleware.cors import CORSMiddleware
 
12
 
13
  CUSTOM_PATH = "/"
14
 
@@ -24,81 +28,47 @@ app.add_middleware(
24
  )
25
 
26
 
27
- ColorGrid = namedtuple("ColorGrid", ["positions", "colors"])
28
-
29
-
30
- def build_color_grid(x_count: int = 10, y_count: int = 10, z_count: int = 10, twist: float = 0) -> ColorGrid:
31
- """
32
- Create a cube of points with colors.
33
-
34
- The total point cloud will have x_count * y_count * z_count points.
35
-
36
- Parameters
37
- ----------
38
- x_count, y_count, z_count:
39
- Number of points in each dimension.
40
- twist:
41
- Angle to twist from bottom to top of the cube
42
-
43
- """
44
-
45
- grid = np.mgrid[
46
- slice(-x_count, x_count, x_count * 1j),
47
- slice(-y_count, y_count, y_count * 1j),
48
- slice(-z_count, z_count, z_count * 1j),
49
- ]
50
-
51
- angle = np.linspace(-float(twist) / 2, float(twist) / 2, z_count)
52
- for z in range(z_count):
53
- xv, yv, zv = grid[:, :, :, z]
54
- rot_xv = xv * cos(angle[z]) - yv * sin(angle[z])
55
- rot_yv = xv * sin(angle[z]) + yv * cos(angle[z])
56
- grid[:, :, :, z] = [rot_xv, rot_yv, zv]
57
-
58
- positions = np.vstack([xyz.ravel() for xyz in grid])
59
-
60
- colors = np.vstack([
61
- xyz.ravel()
62
- for xyz in np.mgrid[
63
- slice(0, 255, x_count * 1j),
64
- slice(0, 255, y_count * 1j),
65
- slice(0, 255, z_count * 1j),
66
- ]
67
- ])
68
-
69
- return ColorGrid(positions.T, colors.T.astype(np.uint8))
70
-
71
-
72
  def html_template(rrd: str, app_url: str = "https://app.rerun.io") -> str:
73
  encoded_url = urllib.parse.quote(rrd)
74
  return f"""<div style="width:100%; height:70vh;"><iframe style="width:100%; height:100%;" src="{app_url}?url={encoded_url}" frameborder="0" allowfullscreen=""></iframe></div>"""
75
 
76
 
77
- def show_cube(x: int, y: int, z: int) -> str:
78
- rr.init("my data")
 
 
 
 
 
 
 
79
 
80
- cube = build_color_grid(int(x), int(y), int(z), twist=0)
81
- rr.log("cube", rr.Points3D(cube.positions, colors=cube.colors, radii=0.5))
82
 
83
- blueprint = rrb.Spatial3DView(origin="cube")
 
84
 
85
- rr.save("cube.rrd", default_blueprint=blueprint)
86
 
87
- return "cube.rrd"
88
 
89
 
90
  with gr.Blocks() as demo:
91
  with gr.Row():
92
- x_count = gr.Number(minimum=1, maximum=10, value=5, precision=0, label="X Count")
93
- y_count = gr.Number(minimum=1, maximum=10, value=5, precision=0, label="Y Count")
94
- z_count = gr.Number(minimum=1, maximum=10, value=5, precision=0, label="Z Count")
95
- button = gr.Button("Show Cube")
 
 
 
 
96
  with gr.Row():
97
  rrd = gr.File()
98
  with gr.Row():
99
  viewer = gr.HTML()
100
 
101
- button.click(show_cube, inputs=[x_count, y_count, z_count], outputs=rrd)
102
  rrd.change(
103
  html_template,
104
  js="""(rrd) => { console.log(rrd.url); return rrd.url}""",
 
1
+ from pathlib import Path
2
  import urllib
3
  from collections import namedtuple
4
  from math import cos, sin
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ from dataset_conversion import log_dataset_to_rerun
10
  import rerun as rr
11
  import rerun.blueprint as rrb
12
+ from datasets import load_dataset
13
  from fastapi import FastAPI
14
  from fastapi.middleware.cors import CORSMiddleware
15
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
16
 
17
  CUSTOM_PATH = "/"
18
 
 
28
  )
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def html_template(rrd: str, app_url: str = "https://app.rerun.io") -> str:
32
  encoded_url = urllib.parse.quote(rrd)
33
  return f"""<div style="width:100%; height:70vh;"><iframe style="width:100%; height:100%;" src="{app_url}?url={encoded_url}" frameborder="0" allowfullscreen=""></iframe></div>"""
34
 
35
 
36
+ def show_dataset(dataset_id: str, episode_id: int) -> str:
37
+ rr.init("dataset")
38
+
39
+ # TODO(jleibs): manage cache better and put in proper storage
40
+ filename = Path(f"tmp/{dataset_id}_{episode_id}.rrd")
41
+ if not filename.exists():
42
+ filename.parent.mkdir(parents=True, exist_ok=True)
43
+
44
+ rr.save(filename.as_posix())
45
 
46
+ dataset = load_dataset(dataset_id, split="train", streaming=True)
 
47
 
48
+ # This is for LeRobot datasets (https://huggingface.co/lerobot):
49
+ ds_subset = dataset.filter(lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_id)
50
 
51
+ log_dataset_to_rerun(ds_subset)
52
 
53
+ return filename.as_posix()
54
 
55
 
56
  with gr.Blocks() as demo:
57
  with gr.Row():
58
+ search_in = HuggingfaceHubSearch(
59
+ "lerobot/pusht",
60
+ label="Search Huggingface Hub",
61
+ placeholder="Search for models on Huggingface",
62
+ search_type="dataset",
63
+ )
64
+ episode_id = gr.Number(1, label="Episode ID")
65
+ button = gr.Button("Show Dataset")
66
  with gr.Row():
67
  rrd = gr.File()
68
  with gr.Row():
69
  viewer = gr.HTML()
70
 
71
+ button.click(show_dataset, inputs=[search_in, episode_id], outputs=rrd)
72
  rrd.change(
73
  html_template,
74
  js="""(rrd) => { console.log(rrd.url); return rrd.url}""",
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  datasets
2
  h5py
3
  gradio==4.27.0
 
4
  pillow
5
  rerun-sdk>=0.15.0,<0.16.0
6
  tqdm
 
1
  datasets
2
  h5py
3
  gradio==4.27.0
4
+ gradio_huggingfacehub_search
5
  pillow
6
  rerun-sdk>=0.15.0,<0.16.0
7
  tqdm