ror's picture
ror HF Staff
Probably v1
79e7993
import json
from copy import deepcopy
import numpy as np
def make_id(config: dict, keys_to_ignore: list[str]) -> str:
    """Build a deterministic identifier from the values of ``config``.

    Keys are visited in sorted order, keys listed in ``keys_to_ignore`` are
    skipped, and the remaining values are converted with ``str`` and joined
    with ``"_"``. Only values (not key names) appear in the result.
    """
    parts: list[str] = []
    for key in sorted(config):
        if key in keys_to_ignore:
            continue
        parts.append(str(config[key]))
    return "_".join(parts)
class ModelBenchmarkData:
    """Benchmark results for one device, loaded from a JSON file.

    ``self.data`` maps a scenario name to a dict that the methods below read
    through its ``"config"`` and ``"measures"`` entries.
    """

    def __init__(self, json_path: str) -> None:
        """Load the benchmark JSON at ``json_path`` into ``self.data``."""
        # JSON is UTF-8 by spec (RFC 8259); don't depend on the locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data: dict = json.load(f)

    def compute_ttft(self, measures: dict) -> list[float]:
        """Return the first ``dt_tokens`` entry of each run.

        NOTE(review): assumes ``dt_tokens`` holds per-token time offsets from
        the start of the request, so the first one is the time to first
        token — confirm against the benchmark producer.
        Raises IndexError if a run has an empty ``dt_tokens`` list.
        """
        return [dts[0] for dts in measures["dt_tokens"]]

    def compute_itl(self, measures: dict) -> list[float]:
        """Return the mean gap between consecutive ``dt_tokens`` entries per run.

        The mean gap is ``(last - first) / (n - 1)``. Runs with fewer than two
        entries have no interval and get ``0.0``.
        """
        # Two entries already define one interval; only n < 2 would divide by zero.
        return [
            (dts[-1] - dts[0]) / (len(dts) - 1) if len(dts) >= 2 else 0.0
            for dts in measures["dt_tokens"]
        ]

    def compute_throughput(self, measures: dict, batch_size: int) -> list[float]:
        """Return ``batch_size * len(dt_tokens) / e2e_latency`` for each run.

        Runs whose end-to-end latency is not positive get ``0.0``. Runs are
        paired with ``zip``, so extra entries in the longer list are ignored.
        NOTE(review): this is tokens per unit of ``e2e_latency`` time only if
        there is one ``dt_tokens`` entry per generated token — confirm.
        """
        return [
            batch_size * len(dts) / e2e if e2e > 0 else 0.0
            for e2e, dts in zip(measures["e2e_latency"], measures["dt_tokens"])
        ]

    def compute_e2e_latency(self, measures: dict) -> list[float]:
        """Return a fresh list copy of the ``e2e_latency`` measurements."""
        return list(measures["e2e_latency"])

    def ensure_coherence(self) -> tuple[int, int, int]:
        """Check that every config shares the same workload shape.

        Returns:
            The shared ``(batch_size, sequence_length, num_tokens_to_generate)``.

        Raises:
            ValueError: if no configs are loaded, or if configs disagree.
        """
        all_hyperparams = {
            (
                data["config"]["batch_size"],
                data["config"]["sequence_length"],
                data["config"]["num_tokens_to_generate"],
            )
            for data in self.data.values()
        }
        if not all_hyperparams:
            raise ValueError("No benchmark configs loaded; cannot check coherence")
        if len(all_hyperparams) > 1:
            raise ValueError(
                f"Different batch size, sequence length or nb of tokens to generate between configs: {all_hyperparams}"
            )
        return all_hyperparams.pop()

    def get_bar_plot_data(
        self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True
    ) -> dict:
        """Gather TTFT / ITL / E2E measurements per scenario, optionally collapsed.

        When collapsing, scenarios whose configs differ only in the ignored
        keys (``name`` always, plus ``use_cache`` and/or ``compile_mode``
        depending on the flags) are grouped together. Only the scenario with
        the lowest mean end-to-end latency in each group is kept. The first
        scenario seen wins ties.

        Args:
            collapse_on_cache: ignore ``use_cache`` when grouping scenarios.
            collapse_on_compile_mode: ignore the exact ``compile_mode`` when
                grouping. Compiled vs. non-compiled runs are still kept apart.

        Returns:
            Mapping from scenario name to a dict with ``"ttft"``, ``"itl"``,
            ``"e2e"`` (per-run lists) and the scenario's ``"config"``.
        """
        per_scenario_data = {
            cfg_name: {
                "ttft": self.compute_ttft(data["measures"]),
                "itl": self.compute_itl(data["measures"]),
                "e2e": self.compute_e2e_latency(data["measures"]),
                "config": data["config"],
            }
            for cfg_name, data in self.data.items()
        }
        # Either flag alone must trigger collapsing; previously only the cache
        # flag did, which made collapse_on_compile_mode ineffective on its own.
        if not (collapse_on_cache or collapse_on_compile_mode):
            return per_scenario_data

        keys_to_ignore = ["name"]
        if collapse_on_cache:
            keys_to_ignore.append("use_cache")
        if collapse_on_compile_mode:
            keys_to_ignore.append("compile_mode")

        best_per_id: dict[str, tuple[str, float]] = {}
        for cfg_name, scenario in per_scenario_data.items():
            config = scenario["config"]
            # Shallow copy: we only add a top-level key. "compiled" keeps
            # compiled vs. eager runs apart even when compile_mode is ignored.
            id_cfg = {**config, "compiled": config["compile_mode"] is not None}
            cfg_id = make_id(id_cfg, keys_to_ignore)
            mean_e2e = np.mean(scenario["e2e"])
            best = best_per_id.get(cfg_id)
            # No magic sentinel: the first scenario of a group is always
            # recorded, so a NaN/huge latency can no longer drop the group.
            if best is None or mean_e2e < best[1]:
                best_per_id[cfg_id] = (cfg_name, mean_e2e)
        return {name: per_scenario_data[name] for name, _ in best_per_id.values()}
def load_data(
    keep_common_scenarios_only: bool = False,
) -> dict[str, ModelBenchmarkData]:
    """Load the benchmark results for every device.

    Args:
        keep_common_scenarios_only: if True, drop every scenario that is not
            present for all loaded devices, so the devices can be compared on
            the same set of scenarios.

    Returns:
        Mapping from device name to its ``ModelBenchmarkData``.
    """
    # File paths are relative to the current working directory.
    data = {
        "MI325": ModelBenchmarkData("mi325_data.json"),
        "H100": ModelBenchmarkData("h100_data.json"),
    }
    if keep_common_scenarios_only:
        # Intersect over all devices rather than a hard-coded pair, so adding a
        # device to the mapping above keeps this filter correct.
        common_scenarios = set.intersection(
            *(set(device_data.data) for device_data in data.values())
        )
        for device_data in data.values():
            device_data.data = {
                k: v for k, v in device_data.data.items() if k in common_scenarios
            }
    return data