File size: 6,281 Bytes
231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf bcb3a7b 231a1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import dataclasses
import enum
import logging
import pathlib
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich
import rich.console
import rich.table
import tqdm
import tyro
# Module logger; when run as a script, output is configured by logging.basicConfig in the __main__ guard.
logger = logging.getLogger(__name__)
class EnvMode(enum.Enum):
    """Supported environments.

    The selected member picks the random-observation generator used by ``main``;
    ALOHA and ALOHA_SIM share the same (ALOHA) observation format.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host of the policy server to connect to.
    host: str = "0.0.0.0"
    # Port of the policy server. None is passed through to the websocket client unchanged.
    port: int | None = 8000
    # Optional API key passed to the websocket client.
    api_key: str | None = None
    # Number of timed inference steps (the warm-up requests are not counted).
    num_steps: int = 20
    # Optional path of a parquet file to save the raw timings to (e.g., timing.parquet).
    timing_file: pathlib.Path | None = None
    # Environment whose observation format is used to generate random observations.
    env: EnvMode = EnvMode.ALOHA_SIM
class TimingRecorder:
    """Records timing measurements for different keys and summarizes them.

    Measurements are kept per key in recording order. ``get_stats`` summarizes one
    key, ``print_all_stats`` renders every key as a rich table, and
    ``write_parquet`` dumps the raw series with one column per key.
    """

    def __init__(self) -> None:
        # Raw measurements per key, in the order they were recorded.
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Record a timing measurement for the given key, creating the key on first use."""
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Return summary statistics for the given key.

        The result contains ``mean``, ``std`` (population std, NumPy's default
        ``ddof=0``) and the percentiles ``p25``, ``p50``, ``p75``, ``p90``, ``p95``
        and ``p99`` (NumPy's default linear interpolation), all as Python floats.

        Raises:
            KeyError: If nothing has been recorded for ``key``.
        """
        times = self._timings[key]
        percentiles = (
            ("p25", 0.25),
            ("p50", 0.50),
            ("p75", 0.75),
            ("p90", 0.90),
            ("p95", 0.95),
            ("p99", 0.99),
        )
        # A single np.quantile call computes every requested level at once.
        levels = np.quantile(times, [q for _, q in percentiles])
        stats = {"mean": float(np.mean(times)), "std": float(np.std(times))}
        stats.update({name: float(value) for (name, _), value in zip(percentiles, levels)})
        return stats

    def print_all_stats(self) -> None:
        """Print one table row per key (sorted by name) with one column per statistic."""
        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # (column header, column style, key into get_stats() result), in display order.
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]
        for header, style, _ in stat_columns:
            table.add_column(header, justify="right", style=style, no_wrap=True)

        for metric in sorted(self._timings):
            stats = self.get_stats(metric)
            # Distinct name for the stat key so it cannot be confused with `metric`.
            values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
            table.add_row(metric, *values)

        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the raw timings to a parquet file, one column per key.

        Parent directories are created as needed. ``pl.DataFrame`` expects
        equal-length columns, so this assumes every key was recorded the same
        number of times.
        """
        logger.info("Writing timings to %s", path)
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
def main(args: Args) -> None:
    """Send random observations to a remote policy and report request timings.

    Connects to the websocket policy server at ``args.host``/``args.port``, sends
    two untimed warm-up requests, then times ``args.num_steps`` calls to
    ``policy.infer``. Besides the client-side round trip (``client_infer_ms``), any
    entries the response carries under ``server_timing`` / ``policy_timing`` are
    recorded with a ``server_`` / ``policy_`` prefix. A statistics table is printed
    at the end, and the raw timings are written to ``args.timing_file`` if given.
    """
    observation_factories = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }
    obs_fn = observation_factories[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info("Server metadata: %s", policy.get_server_metadata())

    # Send a few untimed observations first to make sure the model is loaded,
    # so one-off startup costs do not skew the statistics.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()
    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        # Build the observation before starting the clock so random-data
        # generation is not counted as inference latency.
        observation = obs_fn()
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock
        # time and can jump (e.g. clock adjustments), corrupting short intervals.
        inference_start = time.perf_counter()
        action = policy.infer(observation)
        timing_recorder.record("client_infer_ms", 1000 * (time.perf_counter() - inference_start))

        # Forward timings reported in the response, namespaced by their source.
        for source in ("server", "policy"):
            for key, value in action.get(f"{source}_timing", {}).items():
                timing_recorder.record(f"{source}_{key}", value)

    timing_recorder.print_all_stats()
    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
def _random_observation_aloha() -> dict:
return {
"state": np.ones((14,)),
"images": {
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
},
"prompt": "do something",
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
if __name__ == "__main__":
    # Script entry point: enable INFO logging, parse the CLI into Args, then run.
    logging.basicConfig(level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
|