"""evaluate_prediction_planning_stack.py --load_from <wandb ID> --seed <seed>
--num_episodes <num_episodes> --risk_level <a list of risk-levels>
--num_samples <a list of numbers of prediction samples>
This script loads a trained predictor from <wandb ID> and runs a batch of open-loop MPC
evaluations (i.e., without replanning) over <num_episodes> episodes while varying the
risk levels and numbers of prediction samples. Results are stored in
scripts/eval_scripts/logs/planner_eval/run-<wandb ID>_<seed> as a collection of pickle files.
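
Example (the wandb run ID below is a placeholder):
    python evaluate_prediction_planning_stack.py --load_from 1a2b3c4d \
        --seed 0 --num_episodes 100 --risk_level 0.95 1.0 --num_samples 16 64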
"""

import argparse
import os
import pickle
import sys
from time import time_ns
from typing import List, Tuple

import torch
from mmcv import Config
from pytorch_lightning import seed_everything
from tqdm import trange

from risk_biased.mpc_planner.planner import MPCPlanner, MPCPlannerParams
from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.utils.callbacks import get_fast_slow_scenes
from risk_biased.utils.load_model import load_from_config, config_argparse
from risk_biased.utils.planner_utils import (
    evaluate_control_sequence,
    get_interaction_cost,
    AbstractState,
    to_state,
)


def evaluate_main(
load_from: str,
seed: int,
num_episodes: int,
risk_level_list: List[float],
num_prediction_samples_list: List[int],
):
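    """Run the full open-loop evaluation sweep.

    For each scene type ("safer_slow", "safer_fast") and each risk level, a
    no-policy-optimization baseline is evaluated first; then, for each number
    of prediction samples, the stack is evaluated with the risk level applied
    either in the predictor or in the planner (or risk-neutrally when the
    level is 0.0). Each result is pickled into its own file.
    """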
print(f"Risk-sensitivity level(s) to test: {risk_level_list}")
print(f"Number(s) of prediction samples to test: {num_prediction_samples_list} ")
save_dir = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"logs",
"planner_eval",
f"run-{load_from}_{seed}",
)
    os.makedirs(save_dir, exist_ok=True)
cfg, planner = get_cfg_and_planner(load_from=load_from)
    if planner.solver.params.mean_warm_start:
print(
"switching to mean_warm_start = False for open-loop evaluation (i.e. without re-planning)"
)
planner.solver.params.mean_warm_start = False
for scene_type in [
"safer_slow",
"safer_fast",
]:
print(f"\nTesting {scene_type} scenes")
seed_everything(seed)
(
scene,
ado_state_history_batch,
ado_state_future_batch,
) = get_scene_and_ado_trajectory(
cfg, scene_type=scene_type, num_episodes=num_episodes
)
(
ego_state_history,
ego_state_target_trajectory,
) = get_ego_state_history_and_target_trajectory(cfg, scene)
for stack_risk_level in risk_level_list:
print(f" Risk_level: {stack_risk_level}")
file_name = f"{scene_type}_no_policy_opt_risk_level_{stack_risk_level}"
print(f" {file_name}")
stats_dict_no_policy_opt = evaluate_prediction_planning_stack(
planner,
ado_state_history_batch,
ado_state_future_batch,
ego_state_history,
ego_state_target_trajectory,
optimize_policy=False,
stack_risk_level=stack_risk_level,
risk_in_predictor=False,
)
with open(os.path.join(save_dir, file_name + ".pkl"), "wb") as f:
pickle.dump(stats_dict_no_policy_opt, f)
for num_prediction_samples in num_prediction_samples_list:
file_name = f"{scene_type}_{num_prediction_samples}_samples_risk_level_{stack_risk_level}"
if stack_risk_level == 0.0:
print(f" {file_name}")
stats_dict_risk_neutral = evaluate_prediction_planning_stack(
planner,
ado_state_history_batch,
ado_state_future_batch,
ego_state_history,
ego_state_target_trajectory,
optimize_policy=True,
stack_risk_level=stack_risk_level,
risk_in_predictor=False,
num_prediction_samples=num_prediction_samples,
)
with open(os.path.join(save_dir, file_name + ".pkl"), "wb") as f:
pickle.dump(stats_dict_risk_neutral, f)
else:
file_name_in_predictor = file_name + "_in_predictor"
print(f" {file_name_in_predictor}")
stats_dict_risk_in_predictor = evaluate_prediction_planning_stack(
planner,
ado_state_history_batch,
ado_state_future_batch,
ego_state_history,
ego_state_target_trajectory,
optimize_policy=True,
stack_risk_level=stack_risk_level,
risk_in_predictor=True,
num_prediction_samples=num_prediction_samples,
)
with open(
os.path.join(save_dir, file_name_in_predictor + ".pkl"), "wb"
) as f:
pickle.dump(stats_dict_risk_in_predictor, f)
file_name_in_planner = file_name + "_in_planner"
print(f" {file_name_in_planner}")
stats_dict_risk_in_planner = evaluate_prediction_planning_stack(
planner,
ado_state_history_batch,
ado_state_future_batch,
ego_state_history,
ego_state_target_trajectory,
optimize_policy=True,
stack_risk_level=stack_risk_level,
risk_in_predictor=False,
num_prediction_samples=num_prediction_samples,
)
with open(
os.path.join(save_dir, file_name_in_planner + ".pkl"), "wb"
) as f:
pickle.dump(stats_dict_risk_in_planner, f)


def evaluate_prediction_planning_stack(
planner: MPCPlanner,
ado_state_history_batch: AbstractState,
ado_state_future_batch: AbstractState,
ego_state_history: AbstractState,
ego_state_target_trajectory: AbstractState,
optimize_policy: bool = True,
stack_risk_level: float = 0.0,
risk_in_predictor: bool = False,
num_prediction_samples: int = 128,
num_prediction_samples_for_policy_eval: int = 4096,
) -> dict:
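    """Evaluate the stack over a batch of episodes for one risk configuration.

    Depending on risk_in_predictor, the stack's risk level is applied either at
    prediction time or at planning time, never both. The returned dict maps
    each episode id to its computation time, ground-truth interaction cost,
    interaction risk, tracking cost, control sequence, solver info, and the
    unbiased prediction samples (with weights) used for policy evaluation.
    """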
    assert (
        not planner.solver.params.mean_warm_start
    ), "open-loop evaluation requires mean_warm_start = False"
if risk_in_predictor:
predictor_risk_level, planner_risk_level = stack_risk_level, 0.0
else:
predictor_risk_level, planner_risk_level = 0.0, stack_risk_level
stats_dict = {
"stack_risk_level": stack_risk_level,
"predictor_risk_level": predictor_risk_level,
"planner_risk_level": planner_risk_level,
}
num_episodes = ado_state_history_batch.shape[0]
assert num_episodes == ado_state_future_batch.shape[0]
for episode_id in trange(num_episodes, desc="episodes", leave=False):
ado_state_history = ado_state_history_batch[episode_id]
ado_state_future = ado_state_future_batch[episode_id]
(
ado_state_future_samples_for_policy_eval,
sample_weights,
) = planner.solver.sample_prediction(
planner.predictor,
ado_state_history,
planner.normalizer,
ego_state_history=ego_state_history,
ego_state_future=ego_state_target_trajectory,
num_prediction_samples=num_prediction_samples_for_policy_eval,
risk_level=0.0,
)
if optimize_policy:
start = time_ns()
solver_info = planner.solver.solve(
planner.predictor,
ego_state_history,
ego_state_target_trajectory,
ado_state_history,
planner.normalizer,
num_prediction_samples=num_prediction_samples,
verbose=False,
risk_level=stack_risk_level,
resample_prediction=False,
risk_in_predictor=risk_in_predictor,
)
end = time_ns()
computation_time_ms = (end - start) * 1e-6
else:
planner.solver.reset()
computation_time_ms = 0.0
solver_info = None
interaction_cost_gt = get_ground_truth_interaction_cost(
planner, ado_state_future, ego_state_history
)
interaction_risk, tracking_cost = evaluate_control_sequence(
planner.solver.control_sequence,
planner.solver.dynamics_model,
ego_state_history,
ego_state_target_trajectory,
ado_state_future_samples_for_policy_eval,
sample_weights,
planner.solver.interaction_cost,
planner.solver.tracking_cost,
risk_level=stack_risk_level,
risk_estimator=planner.solver.risk_estimator,
)
stats_dict_this_run = {
"computation_time_ms": computation_time_ms,
"interaction_cost_ground_truth": interaction_cost_gt,
"interaction_risk": interaction_risk,
"tracking_cost": tracking_cost,
"control_sequence": planner.solver.control_sequence,
"solver_info": solver_info,
"ado_unbiased_predictions": ado_state_future_samples_for_policy_eval.position.detach()
.cpu()
.numpy(),
"sample_weights": sample_weights.detach().cpu().numpy(),
"ado_position_future": ado_state_future.position.detach().cpu().numpy(),
"ado_position_history": ado_state_history.position.detach().cpu().numpy(),
}
stats_dict[episode_id] = stats_dict_this_run
return stats_dict


def get_cfg_and_predictor() -> Tuple[Config, LitTrajectoryPredictor]:
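    """Load the learning config and the trained predictor.

    config_argparse re-parses the command line here, which is why the __main__
    block filters sys.argv down to the flags this second parser understands.
    """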
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"risk_biased",
"config",
"learning_config.py",
)
cfg = config_argparse(config_path)
predictor, _, cfg = load_from_config(cfg)
return cfg, predictor


def get_cfg_and_planner(load_from: str) -> Tuple[Config, MPCPlanner]:
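    """Merge the planning config into the learning config and wrap the loaded
    predictor in an MPC planner. Note that load_from is not read directly here;
    the run ID is instead expected to come from sys.argv when the predictor
    config is parsed.
    """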
planner_config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
"..",
"risk_biased",
"config",
"planning_config.py",
)
planner_cfg = Config.fromfile(planner_config_path)
cfg, predictor = get_cfg_and_predictor()
    # Note: cfg.update() silently overrides any entries that appear in both the
    # learning config and the planning config with the planning-config values.
    cfg.update(planner_cfg)
planner_params = MPCPlannerParams.from_config(cfg)
normalizer = SceneDataLoaders.normalize_trajectory
planner = MPCPlanner(planner_params, predictor, normalizer)
return cfg, planner


def get_scene_and_ado_trajectory(
    cfg: Config, scene_type: str, num_episodes: int
) -> Tuple[RandomScene, AbstractState, AbstractState]:
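    """Build a batch of scenes of the requested type and split the pedestrian
    (ado) trajectories into history and future states at cfg.num_steps.
    """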
scene_params = RandomSceneParams.from_config(cfg)
safer_fast_scene, safer_slow_scene = get_fast_slow_scenes(
scene_params, num_episodes
)
    if scene_type == "safer_fast":
        scene = safer_fast_scene
    elif scene_type == "safer_slow":
        scene = safer_slow_scene
    else:
        raise ValueError(f"unknown scene type {scene_type}")
    ado_trajectory = torch.from_numpy(
        scene.get_pedestrians_trajectories().astype("float32")
    )
    ado_state_history = to_state(ado_trajectory[..., : cfg.num_steps, :], cfg.dt)
    ado_state_future = to_state(ado_trajectory[..., cfg.num_steps :, :], cfg.dt)
    return scene, ado_state_history, ado_state_future


def get_ego_state_history_and_target_trajectory(
    cfg: Config, scene: RandomScene
) -> Tuple[AbstractState, AbstractState]:
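    """Split the ego reference trajectory into an initial history segment and
    the target trajectory to track, cut at cfg.num_steps.
    """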
ego_state_traj = to_state(
torch.from_numpy(
scene.get_ego_ref_trajectory(cfg.sample_times).astype("float32")
),
cfg.dt,
)
ego_state_history = ego_state_traj[0, :, : cfg.num_steps]
ego_state_target_trajectory = ego_state_traj[0, :, cfg.num_steps :]
return ego_state_history, ego_state_target_trajectory


def get_ground_truth_interaction_cost(
planner: MPCPlanner,
ado_state_future: AbstractState, # (num_agents, num_steps_future)
ego_state_history: AbstractState, # (1, 1, num_steps)
) -> float:
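    """Roll out the ego dynamics under the optimized control sequence and score
    the rollout against the ground-truth ado future using the planner's
    interaction cost.
    """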
ego_state_future = planner.solver.dynamics_model.simulate(
ego_state_history[..., -1], planner.solver.control_sequence.unsqueeze(0)
)
interaction_cost = get_interaction_cost(
ego_state_future,
ado_state_future,
planner.solver.interaction_cost,
)
return interaction_cost.item()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="evaluate prediciton-planning stack using safer_fast and safer_slow scenes"
)
parser.add_argument(
"--load_from",
type=str,
required=True,
help="WandB ID to load trained predictor from",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed for evaluation",
)
parser.add_argument(
"--num_episodes",
type=int,
default=100,
help="Number of episodes",
)
parser.add_argument(
"--risk_level",
type=float,
nargs="+",
help="Risk-sensitivity level(s) to test",
        default=[0.95, 1.0],
)
parser.add_argument(
"--num_samples",
type=int,
nargs="+",
help="Number(s) of prediction samples to test",
default=[1, 4, 16, 64, 256, 1024],
)
parser.add_argument(
"--force_config",
action="store_true",
help="""Use this flag to force the use of the local config file
when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
        In any case, individual parameters can still be overridden with argparse arguments.""",
)
parser.add_argument(
"--load_last",
action="store_true",
help="""Use this flag to force the use of the last checkpoint instead of the best one
when loading a model.""",
)
args = parser.parse_args()
    # The config loader re-parses the command line, so keep only the arguments
    # that its second parser understands.
    keep_list = ["--load_from", "--seed", "--load_last", "--force_config"]
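    # e.g. ["script.py", "--load_from=abc123", "--num_episodes", "50"] becomes
    # ["script.py", "--load_from", "abc123"]: "=" tokens are split first, then
    # every token that is neither in keep_list nor the value of a kept flag is
    # dropped.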
sys.argv = [ss for s in sys.argv for ss in s.split("=")]
    sys.argv = [
        sys.argv[i]
        for i in range(len(sys.argv))
        if i == 0 or sys.argv[i] in keep_list or sys.argv[i - 1] in keep_list
    ]
evaluate_main(
args.load_from, args.seed, args.num_episodes, args.risk_level, args.num_samples
)