"""evaluate_prediction_planning_stack.py --load_from <wandb ID> --seed <seed> | |
--num_episodes <num_episodes> --risk_level <a list of risk-levels> | |
--num_samples <a list of numbers of prediction samples> | |
This script loads a trained predictor from <wand ID>, runs a batch of open-loop MPC evaluations | |
(i.e., without replanning) using <num_episodes> episodes while varying risk-levels and numbers of | |
prediction samples. Results are stored in scripts/logs/planner_eval/run-<wandb ID>_<seed> as a | |
collection of pickle files. | |
""" | |
import argparse
import os
import pickle
import sys
from time import time_ns
from typing import List, Tuple

import torch
from mmcv import Config
from pytorch_lightning import seed_everything
from tqdm import trange

from risk_biased.mpc_planner.planner import MPCPlanner, MPCPlannerParams
from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
from risk_biased.utils.callbacks import get_fast_slow_scenes
from risk_biased.utils.load_model import load_from_config, config_argparse
from risk_biased.utils.planner_utils import (
    evaluate_control_sequence,
    get_interaction_cost,
    AbstractState,
    to_state,
)


def evaluate_main(
    load_from: str,
    seed: int,
    num_episodes: int,
    risk_level_list: List[float],
    num_prediction_samples_list: List[int],
):
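    """Run the open-loop evaluation sweep for one trained predictor.

    For each scene type ("safer_slow", "safer_fast"), each risk level, and each number of
    prediction samples, the prediction-planning stack is evaluated and the resulting statistics
    dictionary is pickled into the save directory.
    """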
print(f"Risk-sensitivity level(s) to test: {risk_level_list}") | |
print(f"Number(s) of prediction samples to test: {num_prediction_samples_list} ") | |
save_dir = os.path.join( | |
os.path.dirname(os.path.realpath(__file__)), | |
"logs", | |
"planner_eval", | |
f"run-{load_from}_{seed}", | |
) | |
if not os.path.exists(save_dir): | |
os.makedirs(save_dir) | |
cfg, planner = get_cfg_and_planner(load_from=load_from) | |
if not planner.solver.params.mean_warm_start == False: | |
print( | |
"switching to mean_warm_start = False for open-loop evaluation (i.e. without re-planning)" | |
) | |
planner.solver.params.mean_warm_start = False | |
for scene_type in [ | |
"safer_slow", | |
"safer_fast", | |
]: | |
print(f"\nTesting {scene_type} scenes") | |
seed_everything(seed) | |
( | |
scene, | |
ado_state_history_batch, | |
ado_state_future_batch, | |
) = get_scene_and_ado_trajectory( | |
cfg, scene_type=scene_type, num_episodes=num_episodes | |
) | |
( | |
ego_state_history, | |
ego_state_target_trajectory, | |
) = get_ego_state_history_and_target_trajectory(cfg, scene) | |
for stack_risk_level in risk_level_list: | |
print(f" Risk_level: {stack_risk_level}") | |
file_name = f"{scene_type}_no_policy_opt_risk_level_{stack_risk_level}" | |
print(f" {file_name}") | |
stats_dict_no_policy_opt = evaluate_prediction_planning_stack( | |
planner, | |
ado_state_history_batch, | |
ado_state_future_batch, | |
ego_state_history, | |
ego_state_target_trajectory, | |
optimize_policy=False, | |
stack_risk_level=stack_risk_level, | |
risk_in_predictor=False, | |
) | |
with open(os.path.join(save_dir, file_name + ".pkl"), "wb") as f: | |
pickle.dump(stats_dict_no_policy_opt, f) | |
for num_prediction_samples in num_prediction_samples_list: | |
file_name = f"{scene_type}_{num_prediction_samples}_samples_risk_level_{stack_risk_level}" | |
if stack_risk_level == 0.0: | |
print(f" {file_name}") | |
stats_dict_risk_neutral = evaluate_prediction_planning_stack( | |
planner, | |
ado_state_history_batch, | |
ado_state_future_batch, | |
ego_state_history, | |
ego_state_target_trajectory, | |
optimize_policy=True, | |
stack_risk_level=stack_risk_level, | |
risk_in_predictor=False, | |
num_prediction_samples=num_prediction_samples, | |
) | |
with open(os.path.join(save_dir, file_name + ".pkl"), "wb") as f: | |
pickle.dump(stats_dict_risk_neutral, f) | |
else: | |
file_name_in_predictor = file_name + "_in_predictor" | |
print(f" {file_name_in_predictor}") | |
stats_dict_risk_in_predictor = evaluate_prediction_planning_stack( | |
planner, | |
ado_state_history_batch, | |
ado_state_future_batch, | |
ego_state_history, | |
ego_state_target_trajectory, | |
optimize_policy=True, | |
stack_risk_level=stack_risk_level, | |
risk_in_predictor=True, | |
num_prediction_samples=num_prediction_samples, | |
) | |
with open( | |
os.path.join(save_dir, file_name_in_predictor + ".pkl"), "wb" | |
) as f: | |
pickle.dump(stats_dict_risk_in_predictor, f) | |
file_name_in_planner = file_name + "_in_planner" | |
print(f" {file_name_in_planner}") | |
stats_dict_risk_in_planner = evaluate_prediction_planning_stack( | |
planner, | |
ado_state_history_batch, | |
ado_state_future_batch, | |
ego_state_history, | |
ego_state_target_trajectory, | |
optimize_policy=True, | |
stack_risk_level=stack_risk_level, | |
risk_in_predictor=False, | |
num_prediction_samples=num_prediction_samples, | |
) | |
with open( | |
os.path.join(save_dir, file_name_in_planner + ".pkl"), "wb" | |
) as f: | |
pickle.dump(stats_dict_risk_in_planner, f) | |


def evaluate_prediction_planning_stack(
    planner: MPCPlanner,
    ado_state_history_batch: AbstractState,
    ado_state_future_batch: AbstractState,
    ego_state_history: AbstractState,
    ego_state_target_trajectory: AbstractState,
    optimize_policy: bool = True,
    stack_risk_level: float = 0.0,
    risk_in_predictor: bool = False,
    num_prediction_samples: int = 128,
    num_prediction_samples_for_policy_eval: int = 4096,
) -> dict:
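    """Evaluate the prediction-planning stack on a batch of episodes.

    If risk_in_predictor is True the risk level is applied in the predictor and the planner is
    risk-neutral; otherwise the predictor is risk-neutral and the planner applies the risk level.
    Returns a dictionary holding the risk-level settings and, for each episode index, the
    computation time, ground-truth interaction cost, interaction risk, tracking cost, control
    sequence, solver info, and the unbiased prediction samples used for policy evaluation.
    """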
    assert not planner.solver.params.mean_warm_start
    if risk_in_predictor:
        predictor_risk_level, planner_risk_level = stack_risk_level, 0.0
    else:
        predictor_risk_level, planner_risk_level = 0.0, stack_risk_level
    stats_dict = {
        "stack_risk_level": stack_risk_level,
        "predictor_risk_level": predictor_risk_level,
        "planner_risk_level": planner_risk_level,
    }
    num_episodes = ado_state_history_batch.shape[0]
    assert num_episodes == ado_state_future_batch.shape[0]
    for episode_id in trange(num_episodes, desc="episodes", leave=False):
        ado_state_history = ado_state_history_batch[episode_id]
        ado_state_future = ado_state_future_batch[episode_id]
        # Unbiased (risk-neutral) prediction samples used only to evaluate the resulting policy.
        (
            ado_state_future_samples_for_policy_eval,
            sample_weights,
        ) = planner.solver.sample_prediction(
            planner.predictor,
            ado_state_history,
            planner.normalizer,
            ego_state_history=ego_state_history,
            ego_state_future=ego_state_target_trajectory,
            num_prediction_samples=num_prediction_samples_for_policy_eval,
            risk_level=0.0,
        )
        if optimize_policy:
            start = time_ns()
            solver_info = planner.solver.solve(
                planner.predictor,
                ego_state_history,
                ego_state_target_trajectory,
                ado_state_history,
                planner.normalizer,
                num_prediction_samples=num_prediction_samples,
                verbose=False,
                risk_level=stack_risk_level,
                resample_prediction=False,
                risk_in_predictor=risk_in_predictor,
            )
            end = time_ns()
            computation_time_ms = (end - start) * 1e-6
        else:
            planner.solver.reset()
            computation_time_ms = 0.0
            solver_info = None
        interaction_cost_gt = get_ground_truth_interaction_cost(
            planner, ado_state_future, ego_state_history
        )
        interaction_risk, tracking_cost = evaluate_control_sequence(
            planner.solver.control_sequence,
            planner.solver.dynamics_model,
            ego_state_history,
            ego_state_target_trajectory,
            ado_state_future_samples_for_policy_eval,
            sample_weights,
            planner.solver.interaction_cost,
            planner.solver.tracking_cost,
            risk_level=stack_risk_level,
            risk_estimator=planner.solver.risk_estimator,
        )
        stats_dict_this_run = {
            "computation_time_ms": computation_time_ms,
            "interaction_cost_ground_truth": interaction_cost_gt,
            "interaction_risk": interaction_risk,
            "tracking_cost": tracking_cost,
            "control_sequence": planner.solver.control_sequence,
            "solver_info": solver_info,
            "ado_unbiased_predictions": ado_state_future_samples_for_policy_eval.position.detach()
            .cpu()
            .numpy(),
            "sample_weights": sample_weights.detach().cpu().numpy(),
            "ado_position_future": ado_state_future.position.detach().cpu().numpy(),
            "ado_position_history": ado_state_history.position.detach().cpu().numpy(),
        }
        stats_dict[episode_id] = stats_dict_this_run
    return stats_dict


def get_cfg_and_predictor() -> Tuple[Config, LitTrajectoryPredictor]:
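    """Parse the learning config (with command-line overrides) and load the trained predictor."""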
    config_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "..",
        "..",
        "risk_biased",
        "config",
        "learning_config.py",
    )
    cfg = config_argparse(config_path)
    predictor, _, cfg = load_from_config(cfg)
    return cfg, predictor


def get_cfg_and_planner(load_from: str) -> Tuple[Config, MPCPlanner]:
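    """Build the MPC planner around the loaded predictor, merging the planning config into the
    learning config."""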
    planner_config_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "..",
        "..",
        "risk_biased",
        "config",
        "planning_config.py",
    )
    planner_cfg = Config.fromfile(planner_config_path)
    cfg, predictor = get_cfg_and_predictor()
    # joint_dict = {**dict(cfg), **dict(planner_cfg)}
    # assert joint_dict == {
    #     **dict(planner_cfg),
    #     **dict(cfg),
    # }, f"some of the entries conflict between {cfg.filename} and {planner_cfg.filename}"
    # joint_cfg = Config(joint_dict)
    cfg.update(planner_cfg)
    planner_params = MPCPlannerParams.from_config(cfg)
    normalizer = SceneDataLoaders.normalize_trajectory
    planner = MPCPlanner(planner_params, predictor, normalizer)
    return cfg, planner


def get_scene_and_ado_trajectory(
    cfg: Config, scene_type: str, num_episodes: int
) -> Tuple[RandomScene, AbstractState, AbstractState]:
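    """Build a batch of "safer_fast" or "safer_slow" random scenes and split the pedestrian (ado)
    trajectories into history and future states."""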
    scene_params = RandomSceneParams.from_config(cfg)
    safer_fast_scene, safer_slow_scene = get_fast_slow_scenes(
        scene_params, num_episodes
    )
    assert scene_type in ["safer_fast", "safer_slow"]
    if scene_type == "safer_fast":
        scene = safer_fast_scene
    elif scene_type == "safer_slow":
        scene = safer_slow_scene
    else:
        raise ValueError(f"unknown scene type {scene_type}")
    ado_trajectory = torch.from_numpy(
        scene.get_pedestrians_trajectories().astype("float32")
    )
    ado_position_history = to_state(ado_trajectory[..., : cfg.num_steps, :], cfg.dt)
    ado_position_future = to_state(ado_trajectory[..., cfg.num_steps :, :], cfg.dt)
    return scene, ado_position_history, ado_position_future


def get_ego_state_history_and_target_trajectory(
    cfg: Config, scene: RandomScene
) -> Tuple[AbstractState, AbstractState]:
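    """Split the ego reference trajectory of the scene into a state history and a target
    trajectory for the planner to track."""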
    ego_state_traj = to_state(
        torch.from_numpy(
            scene.get_ego_ref_trajectory(cfg.sample_times).astype("float32")
        ),
        cfg.dt,
    )
    ego_state_history = ego_state_traj[0, :, : cfg.num_steps]
    ego_state_target_trajectory = ego_state_traj[0, :, cfg.num_steps :]
    return ego_state_history, ego_state_target_trajectory


def get_ground_truth_interaction_cost(
    planner: MPCPlanner,
    ado_state_future: AbstractState,  # (num_agents, num_steps_future)
    ego_state_history: AbstractState,  # (1, 1, num_steps)
) -> float:
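    """Simulate the ego future under the planner's current control sequence and return its
    interaction cost against the ground-truth ado future."""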
    ego_state_future = planner.solver.dynamics_model.simulate(
        ego_state_history[..., -1], planner.solver.control_sequence.unsqueeze(0)
    )
    interaction_cost = get_interaction_cost(
        ego_state_future,
        ado_state_future,
        planner.solver.interaction_cost,
    )
    return interaction_cost.item()
if __name__ == "__main__": | |
parser = argparse.ArgumentParser( | |
description="evaluate prediciton-planning stack using safer_fast and safer_slow scenes" | |
) | |
parser.add_argument( | |
"--load_from", | |
type=str, | |
required=True, | |
help="WandB ID to load trained predictor from", | |
) | |
parser.add_argument( | |
"--seed", | |
type=int, | |
default=0, | |
help="Random seed for evaluation", | |
) | |
parser.add_argument( | |
"--num_episodes", | |
type=int, | |
default=100, | |
help="Number of episodes", | |
) | |
parser.add_argument( | |
"--risk_level", | |
type=float, | |
nargs="+", | |
help="Risk-sensitivity level(s) to test", | |
default=[0.95, 1], | |
) | |
parser.add_argument( | |
"--num_samples", | |
type=int, | |
nargs="+", | |
help="Number(s) of prediction samples to test", | |
default=[1, 4, 16, 64, 256, 1024], | |
) | |
parser.add_argument( | |
"--force_config", | |
action="store_true", | |
help="""Use this flag to force the use of the local config file | |
when loading a model from a checkpoint. Otherwise the checkpoint config file is used. | |
In any case the parameters can be overwritten with an argparse argument.""", | |
) | |
parser.add_argument( | |
"--load_last", | |
action="store_true", | |
help="""Use this flag to force the use of the last checkpoint instead of the best one | |
when loading a model.""", | |
) | |
args = parser.parse_args() | |
# Args will be re-parsed, this keeps only the arguments that are compatible with the second parser. | |
keep_list = ["--load_from", "--seed", "--load_last", "--force_config"] | |
sys.argv = [ss for s in sys.argv for ss in s.split("=")] | |
sys.argv = [ | |
sys.argv[i] | |
for i in range(len(sys.argv)) | |
if sys.argv[i] in keep_list or sys.argv[i - 1] in keep_list or i == 0 | |
] | |
evaluate_main( | |
args.load_from, args.seed, args.num_episodes, args.risk_level, args.num_samples | |
) | |