# uryzhpi's picture
# Improve inference benchmark
# bcb3a7b
import dataclasses
import enum
import logging
import pathlib
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich
import tqdm
import tyro
# Module-level logger named after this module; the log level is configured via
# logging.basicConfig in the __main__ guard at the bottom of the file.
logger = logging.getLogger(__name__)
class EnvMode(enum.Enum):
    """Supported environments.

    The selected member decides which random-observation generator ``main``
    uses. ``ALOHA`` and ``ALOHA_SIM`` are mapped to the same generator there.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host of the policy server to connect to.
    host: str = "0.0.0.0"
    # Port of the policy server to connect to. The value (including None) is passed
    # unchanged to the websocket client policy.
    # NOTE(review): the original note says None selects a default port - confirm
    # against WebsocketClientPolicy.
    port: int | None = 8000
    # API key forwarded to the websocket client policy (None sends no key value).
    api_key: str | None = None
    # Number of timed inference steps to run (two warm-up steps are run in addition).
    num_steps: int = 20
    # Path to save the timings to a parquet file (e.g., timing.parquet). Nothing is
    # written when this is None.
    timing_file: pathlib.Path | None = None
    # Environment whose observation format is used for the random inputs.
    env: EnvMode = EnvMode.ALOHA_SIM
class TimingRecorder:
    """Accumulates timing samples per key and reports summary statistics.

    Samples are kept in memory as one list per key, in the order they were
    recorded. Statistics are computed on demand with NumPy.
    """

    # (quantile level, key in the stats dict) pairs reported by ``get_stats``.
    _QUANTILES = (
        (0.25, "p25"),
        (0.50, "p50"),
        (0.75, "p75"),
        (0.90, "p90"),
        (0.95, "p95"),
        (0.99, "p99"),
    )

    def __init__(self) -> None:
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Append one timing sample to the series stored under ``key``.

        Args:
            key: Name of the measurement; a new series is created on first use.
            time_ms: The measured value to store.
        """
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Return summary statistics for the samples recorded under ``key``.

        Args:
            key: Name of a measurement that has been recorded at least once.

        Returns:
            A dict with the keys ``mean``, ``std``, ``p25``, ``p50``, ``p75``,
            ``p90``, ``p95`` and ``p99`` (in that order), all plain floats.

        Raises:
            KeyError: If nothing has been recorded under ``key``.
        """
        samples = self._timings[key]
        # One vectorised call computes every percentile from the same data.
        levels = [level for level, _ in self._QUANTILES]
        quantiles = np.quantile(samples, levels)

        stats = {
            "mean": float(np.mean(samples)),
            "std": float(np.std(samples)),
        }
        for (_, name), value in zip(self._QUANTILES, quantiles):
            stats[name] = float(value)
        return stats

    def print_all_stats(self) -> None:
        """Print statistics for all keys as a table, one row per key (sorted)."""
        # Import the submodules explicitly: a bare ``import rich`` only exposes
        # ``rich.table`` / ``rich.console`` if some other module happened to
        # import them first.
        from rich.console import Console
        from rich.table import Table

        table = Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # (column header, column style, key in the stats dict)
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]
        for header, style, _ in stat_columns:
            table.add_column(header, justify="right", style=style, no_wrap=True)

        # Add one row per metric, values formatted with one decimal place.
        for key in sorted(self._timings):
            stats = self.get_stats(key)
            # ``stat_key`` is deliberately distinct from the row's ``key``.
            values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file, one column per key.

        Missing parent directories of ``path`` are created.

        NOTE(review): the frame is built directly from the per-key lists, so
        this assumes every key was recorded the same number of times; polars is
        expected to reject columns of differing lengths - confirm.
        """
        logger.info("Writing timings to %s", path)
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
def main(args: Args) -> None:
    """Benchmark policy inference latency against a policy server.

    Connects to the server given by ``args``, sends two untimed warm-up
    requests, then sends ``args.num_steps`` timed requests built from random
    observations for ``args.env``. Client-side latency is recorded as
    ``client_infer_ms``; any ``server_timing`` / ``policy_timing`` entries in
    the response are recorded with ``server_`` / ``policy_`` prefixes. The
    statistics are printed and, if ``args.timing_file`` is set, written to it.

    Args:
        args: Parsed command line arguments.
    """
    # Pick the random-observation generator matching the chosen environment.
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info("Server metadata: %s", policy.get_server_metadata())

    # Send a few observations to make sure the model is loaded.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()
    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        # Build the observation before starting the clock so that generating
        # random data is not counted as inference latency.
        observation = obs_fn()

        # perf_counter is a monotonic, high-resolution clock; time.time() follows
        # the wall clock and can jump if the system time is adjusted mid-run.
        inference_start = time.perf_counter()
        action = policy.infer(observation)
        elapsed_ms = 1000 * (time.perf_counter() - inference_start)
        timing_recorder.record("client_infer_ms", elapsed_ms)

        # Timing breakdowns are optional in the response; skip them when absent.
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
def _random_observation_aloha() -> dict:
return {
"state": np.ones((14,)),
"images": {
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
},
"prompt": "do something",
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"prompt": "do something",
}
if __name__ == "__main__":
    # Show INFO-level messages (server metadata, parquet path) when run as a script.
    logging.basicConfig(level=logging.INFO)
    # Parse the command line into an Args instance, then run the benchmark.
    cli_args = tyro.cli(Args)
    main(cli_args)