|
import dataclasses
import enum
import logging
import pathlib
import time

import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich
import rich.console
import rich.table
import tqdm
import tyro
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class EnvMode(enum.Enum):
    """Environment whose observation format the benchmark client imitates.

    ``main`` maps each member to a random-observation generator; ``ALOHA`` and
    ``ALOHA_SIM`` share the same generator.
    """

    # Member order is kept stable; values are the string identifiers.
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
|
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command line arguments for the policy-server benchmark client."""

    # Host of the policy server to connect to.
    host: str = "0.0.0.0"

    # Port of the policy server (passed through to the websocket client as-is).
    port: int | None = 8000

    # Optional API key passed to the websocket client.
    api_key: str | None = None

    # Number of timed inference calls (the untimed warm-up calls are not counted).
    num_steps: int = 20

    # If set, the raw timing samples are written to this parquet file.
    timing_file: pathlib.Path | None = None

    # Environment whose observation format is used for the random inputs.
    env: EnvMode = EnvMode.ALOHA_SIM
|
|
|
|
|
class TimingRecorder:
    """Accumulates named timing samples and reports summary statistics.

    Each key (e.g. ``"client_infer_ms"``) maps to the list of samples recorded
    for it, in recording order. Summaries are computed with NumPy and printed as
    a ``rich`` table; the raw samples can also be written to a parquet file.
    """

    # Columns shown by ``print_all_stats``: (header, rich style, key into get_stats()).
    _STAT_COLUMNS: tuple[tuple[str, str, str], ...] = (
        ("Mean", "yellow", "mean"),
        ("Std", "yellow", "std"),
        ("P25", "magenta", "p25"),
        ("P50", "magenta", "p50"),
        ("P75", "magenta", "p75"),
        ("P90", "magenta", "p90"),
        ("P95", "magenta", "p95"),
        ("P99", "magenta", "p99"),
    )

    # Quantile levels reported by ``get_stats``, as (stats key, level) pairs.
    _QUANTILES: tuple[tuple[str, float], ...] = (
        ("p25", 0.25),
        ("p50", 0.50),
        ("p75", 0.75),
        ("p90", 0.90),
        ("p95", 0.95),
        ("p99", 0.99),
    )

    def __init__(self) -> None:
        # Maps each key to its samples, in the order they were recorded.
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Append one timing sample (in milliseconds) under ``key``."""
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Return mean, std and selected quantiles of the samples for ``key``.

        The std is NumPy's default (population, ``ddof=0``); quantiles use
        ``np.quantile``'s default linear interpolation.

        Raises:
            KeyError: if nothing has been recorded for ``key``.
        """
        times = self._timings[key]
        stats = {"mean": float(np.mean(times)), "std": float(np.std(times))}
        for name, level in self._QUANTILES:
            stats[name] = float(np.quantile(times, level))
        return stats

    def print_all_stats(self) -> None:
        """Print a table with one row of summary statistics per key, sorted by key."""
        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
        for header, style, _ in self._STAT_COLUMNS:
            table.add_column(header, justify="right", style=style, no_wrap=True)

        for metric in sorted(self._timings):
            stats = self.get_stats(metric)
            # Distinct names for the metric and the stats key: the original reused
            # `key` for both, which only worked thanks to comprehension scoping.
            cells = [f"{stats[stat_key]:.1f}" for _, _, stat_key in self._STAT_COLUMNS]
            table.add_row(metric, *cells)

        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Write the raw samples to ``path`` as parquet, one column per key.

        Parent directories of ``path`` are created if they do not exist.
        """
        logger.info("Writing timings to %s", path)
        # NOTE(review): pl.DataFrame presumably requires equal-length columns, so keys
        # recorded a different number of times would fail here — confirm if that matters.
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
|
|
|
|
|
def main(args: Args) -> None:
    """Benchmark a policy server with random observations.

    Connects to the websocket policy server described by ``args``, sends two
    untimed warm-up requests, then times ``args.num_steps`` inference calls.
    Client-side latency and any ``server_timing`` / ``policy_timing`` entries
    returned with the action are summarized in a table. If ``args.timing_file``
    is set, the raw samples are also written to it.
    """
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info(f"Server metadata: {policy.get_server_metadata()}")

    # Untimed warm-up calls so one-time start-up costs do not skew the statistics.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        # Build the observation before starting the clock so that generating the
        # random inputs is not counted as inference latency.
        obs = obs_fn()
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock
        # time and can jump if the system clock is adjusted.
        inference_start = time.perf_counter()
        action = policy.infer(obs)
        timing_recorder.record("client_infer_ms", 1000 * (time.perf_counter() - inference_start))
        # Timing breakdowns returned alongside the action (if any) are recorded
        # under prefixed keys.
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
|
|
|
|
|
def _random_observation_aloha() -> dict:
    """Build a random ALOHA-style observation.

    The result has a constant 14-element ``state`` of ones, four channel-first
    uint8 images of shape (3, 224, 224) under ``images``, and a fixed prompt.
    """
    # Cameras are generated in this order, so the global RNG is consumed identically.
    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    images = {
        camera: np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)
        for camera in camera_names
    }
    observation = {"state": np.ones((14,))}
    observation["images"] = images
    observation["prompt"] = "do something"
    return observation
|
|
|
|
|
def _random_observation_droid() -> dict:
    """Build a random DROID-style observation.

    Two channel-last uint8 images of shape (224, 224, 3), a 7-element joint
    position vector, a 1-element gripper position, and a fixed prompt.
    """

    def _image() -> np.ndarray:
        return np.random.randint(256, size=(224, 224, 3), dtype=np.uint8)

    # Entries are filled in the original order, so the global RNG is consumed identically.
    observation = {}
    observation["observation/exterior_image_1_left"] = _image()
    observation["observation/wrist_image_left"] = _image()
    observation["observation/joint_position"] = np.random.rand(7)
    observation["observation/gripper_position"] = np.random.rand(1)
    observation["prompt"] = "do something"
    return observation
|
|
|
|
|
def _random_observation_libero() -> dict:
    """Build a random LIBERO-style observation.

    An 8-element state vector, two channel-last uint8 images of shape
    (224, 224, 3), and a fixed prompt.
    """
    # Draw in the original order (state, image, wrist image) to consume the RNG identically.
    state = np.random.rand(8)
    image, wrist_image = (
        np.random.randint(256, size=(224, 224, 3), dtype=np.uint8) for _ in range(2)
    )
    return dict(
        [
            ("observation/state", state),
            ("observation/image", image),
            ("observation/wrist_image", wrist_image),
            ("prompt", "do something"),
        ]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Emit INFO-level messages (server metadata, parquet path) from this script.
    logging.basicConfig(level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
|
|