import os
import io
import pickle
import sys
from functools import partial
from inspect import signature

import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import repeat
import fire
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
import torch

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams, get_cost
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.load_model import load_from_config


def to_device(batch, device):
    """Move every item of the batch to the given device."""
    output = []
    for item in batch:
        output.append(item.to(device))
    return output


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps torch storages to the CPU so that GPU-saved pickles can be loaded on any machine."""

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        else:
            return super().find_class(module, name)


def distance(pred, truth):
    """Compute the Euclidean distance between predicted and ground-truth xy positions.

    Args:
        pred (Tensor): (..., time, xy)
        truth (Tensor): (..., time, xy)

    Returns:
        Tensor: (..., time) distance at each time step
    """
    return torch.sqrt(torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1))
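

# Illustration (comment only, shapes are hypothetical): `distance` returns the per-step
# Euclidean distance over the xy channels, so the value at the last time step is the
# final displacement error (FDE) and the mean over time is the average displacement
# error (ADE) used in `compute_stats` below:
#   pred = torch.zeros(8, 16, 30, 2)   # (batch, samples, time, xy)
#   truth = torch.ones(8, 1, 30, 2)    # broadcasts against the samples dimension
#   dist = distance(pred, truth)       # (8, 16, 30), every entry equals sqrt(2)
#   fde = dist[..., -1]                # (8, 16)
#   ade = dist.mean(-1)                # (8, 16)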


def compute_metrics(
    predictor,
    batch,
    cost,
    risk_levels,
    risk_estimator,
    dt,
    unnormalizer,
    n_samples_risk,
    n_samples_stats,
):
    """Compute metrics for one batch:
    - risk: unbiased Monte-Carlo estimate and a biased estimate at each risk level
    - cost: kept for all samples
    - FDE/ADE distances: kept for all samples so that minFDE and minADE can be computed later
    - latent distributions: mean and log-std of the latent encodings
    """
    x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
    mask_z = mask_x.any(-1)
    _, z_mean_inference, z_log_std_inference = predictor.model(
        x,
        mask_x,
        map,
        mask_map,
        offset=offset,
        x_ego=x_ego,
        y_ego=y_ego,
        risk_level=None,
    )
    latent_distribs = {
        "inference": {
            "mean": z_mean_inference[:, 1].detach().cpu(),
            "log_std": z_log_std_inference[:, 1].detach().cpu(),
        }
    }

    inference_distances = []
    cost_list = []
    # Cut the number of samples into packs to avoid out-of-memory problems,
    # then compute and store the cost for all packs.
    for _ in range(n_samples_risk // n_samples_stats):
        z_samples_inference = predictor.model.inference_encoder.sample(
            z_mean_inference,
            z_log_std_inference,
            n_samples=n_samples_stats,
        )
        y_samples = predictor.model.decode(
            z_samples=z_samples_inference,
            mask_z=mask_z,
            x=x,
            mask_x=mask_x,
            map=map,
            mask_map=mask_map,
            offset=offset,
        )
        mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=n_samples_stats)
        # Compute the unbiased cost
        cost_list.append(
            get_cost(
                cost,
                x,
                y_samples,
                offset,
                x_ego,
                y_ego,
                dt,
                unnormalizer,
                mask_loss_samples,
            )[:, 1:2]
        )
        inference_distances.append(distance(y_samples, y.unsqueeze(2))[:, 1:2])
    cost_dic = {}
    cost_dic["inference"] = torch.cat(cost_list, 2).detach().cpu()
    distance_dic = {}
    distance_dic["inference"] = torch.cat(inference_distances, 2).detach().cpu()

    # Set up the output risk dictionary
    risk_dic = {}
    # Loop over the risk_level values to fill in the risk estimation and the stats at each risk level
    for rl in risk_levels:
        risk_level = (
            torch.ones(
                (x.shape[0], x.shape[1]),
                device=x.device,
            )
            * rl
        )
        risk_dic[f"biased_{rl}"] = risk_estimator(
            risk_level[:, 1:2].detach().cpu(), cost_dic["inference"]
        )

        y_samples_biased, z_mean_biased, z_log_std_biased = predictor.model(
            x,
            mask_x,
            map,
            mask_map,
            offset=offset,
            x_ego=x_ego,
            y_ego=y_ego,
            risk_level=risk_level,
            n_samples=n_samples_stats,
        )
        latent_distribs[f"biased_{rl}"] = {
            "mean": z_mean_biased[:, 1].detach().cpu(),
            "log_std": z_log_std_biased[:, 1].detach().cpu(),
        }
        distance_dic[f"biased_{rl}"] = (
            distance(y_samples_biased, y.unsqueeze(2))[:, 1].detach().cpu()
        )
        # mask_loss_samples from the last pack above matches the n_samples_stats biased samples
        cost_dic[f"biased_{rl}"] = (
            get_cost(
                cost,
                x,
                y_samples_biased,
                offset,
                x_ego,
                y_ego,
                dt,
                unnormalizer,
                mask_loss_samples,
            )[:, 1]
            .detach()
            .cpu()
        )

    # Return the risks, costs, and distances for the batch at all risk values
    return {
        "risk": risk_dic,
        "cost": cost_dic,
        "distance": distance_dic,
        "latent_distribs": latent_distribs,
        "mask": mask_loss[:, 1].detach().cpu(),
    }


def cat_metrics_rec(metrics1, metrics2, cat_to):
    """Recursively concatenate the tensors of two nested metric dictionaries into cat_to."""
    for key in metrics1.keys():
        if key not in metrics2.keys():
            raise RuntimeError(
                f"Trying to concatenate objects with different keys: {key} is not in second argument keys."
            )
        elif isinstance(metrics1[key], dict):
            if key not in cat_to.keys():
                cat_to[key] = {}
            cat_metrics_rec(metrics1[key], metrics2[key], cat_to[key])
        elif isinstance(metrics1[key], torch.Tensor):
            cat_to[key] = torch.cat((metrics1[key], metrics2[key]), 0)


def cat_metrics(metrics1, metrics2):
    """Concatenate two nested metric dictionaries along the batch dimension."""
    out = {}
    cat_metrics_rec(metrics1, metrics2, out)
    return out
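

# Illustration (comment only, values are hypothetical): `cat_metrics` concatenates two
# nested metric dictionaries along the batch dimension; this is how the per-batch
# outputs of `compute_metrics` are accumulated over the validation set in `main`:
#   m1 = {"cost": {"inference": torch.zeros(4, 1, 8)}, "mask": torch.ones(4, 30)}
#   m2 = {"cost": {"inference": torch.zeros(2, 1, 8)}, "mask": torch.ones(2, 30)}
#   merged = cat_metrics(m1, m2)
#   merged["cost"]["inference"].shape  # torch.Size([6, 1, 8])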


def masked_mean_std_ste(data, mask):
    """Return the mean, standard deviation, and standard error of data over the entries where mask is nonzero."""
    mask = mask.view(data.shape)
    norm = mask.sum().clamp_min(1)
    mean = (data * mask).sum() / norm
    std = torch.sqrt(((data - mean) * mask).square().sum() / norm)
    return mean.item(), std.item(), (std / torch.sqrt(norm)).item()


def masked_mean_range(data, mask):
    """Return the mean and the 5%-95% quantile range of the masked entries of data."""
    data = data[mask]
    mean = data.mean()
    low = torch.quantile(data, 0.05)
    high = torch.quantile(data, 0.95)
    return mean, low, high


def masked_mean_dim(data, mask, dim):
    """Return the masked mean of data along the given dimension."""
    norm = mask.sum(dim).clamp_min(1)
    mean = (data * mask).sum(dim) / norm
    return mean
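

# Illustration (comment only, values are hypothetical): the masked statistics only use
# the entries where the mask is nonzero, and the standard error is the masked standard
# deviation divided by the square root of the number of valid entries:
#   data = torch.tensor([1.0, 2.0, 3.0, 100.0])
#   mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
#   mean, std, ste = masked_mean_std_ste(data, mask)
#   # mean == 2.0, std == sqrt(2 / 3), ste == std / sqrt(3); the masked-out 100.0 is ignored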


def plot_risk_error(metrics, risk_levels, risk_estimator, max_n_samples, path_save):
    """Plot the risk estimation error as a function of the number of samples and save the figures to path_save."""
    cost_inference = metrics["cost"]["inference"]
    cost_biased_0 = metrics["cost"]["biased_0"]
    mask = metrics["mask"].any(1)
    ones_tensor = torch.ones(mask.shape[0])
    n_samples = np.minimum(cost_biased_0.shape[1], max_n_samples)
    for rl in risk_levels:
        key = f"biased_{rl}"
        reference_risk = metrics["risk"][key]
        mean_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        min_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        max_inference_risk_error_per_samples = np.zeros(n_samples - 1)
        # mean_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        # min_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        # max_biased_0_risk_error_per_samples = np.zeros(n_samples - 1)
        mean_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        min_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        max_biased_risk_error_per_samples = np.zeros(n_samples - 1)
        risk_level_tensor = ones_tensor * rl
        for sub_samples in range(1, n_samples):
            perm = torch.randperm(metrics["cost"][key].shape[1])[:sub_samples]
            risk_error_biased = metrics["cost"][key][:, perm].mean(1) - reference_risk
            (
                mean_biased_risk_error_per_samples[sub_samples - 1],
                min_biased_risk_error_per_samples[sub_samples - 1],
                max_biased_risk_error_per_samples[sub_samples - 1],
            ) = masked_mean_range(risk_error_biased, mask)
            risk_error_inference = (
                risk_estimator(risk_level_tensor, cost_inference[:, :, :sub_samples])
                - reference_risk
            )
            (
                mean_inference_risk_error_per_samples[sub_samples - 1],
                min_inference_risk_error_per_samples[sub_samples - 1],
                max_inference_risk_error_per_samples[sub_samples - 1],
            ) = masked_mean_range(risk_error_inference, mask)
            # risk_error_biased_0 = risk_estimator(risk_level_tensor, cost_biased_0[:, :sub_samples]) - reference_risk
            # (
            #     mean_biased_0_risk_error_per_samples[sub_samples - 1],
            #     min_biased_0_risk_error_per_samples[sub_samples - 1],
            #     max_biased_0_risk_error_per_samples[sub_samples - 1],
            # ) = masked_mean_range(risk_error_biased_0, mask)
        plt.plot(
            range(1, n_samples),
            mean_inference_risk_error_per_samples,
            label="Inference",
        )
        plt.fill_between(
            range(1, n_samples),
            min_inference_risk_error_per_samples,
            max_inference_risk_error_per_samples,
            alpha=0.3,
        )
        # plt.plot(range(1, n_samples), mean_biased_0_risk_error_per_samples, label="Unbiased")
        # plt.fill_between(range(1, n_samples), min_biased_0_risk_error_per_samples, max_biased_0_risk_error_per_samples, alpha=0.3)
        plt.plot(
            range(1, n_samples), mean_biased_risk_error_per_samples, label="Biased"
        )
        plt.fill_between(
            range(1, n_samples),
            min_biased_risk_error_per_samples,
            max_biased_risk_error_per_samples,
            alpha=0.3,
        )
        plt.ylim(
            np.min(min_inference_risk_error_per_samples),
            np.max(max_biased_risk_error_per_samples),
        )
        plt.hlines(y=0, xmin=0, xmax=n_samples, colors="black", linestyles="--", lw=0.3)
        plt.xlabel("Number of samples")
        plt.ylabel("Risk estimation error")
        plt.title(f"Risk estimation error at risk-level={rl}")
        plt.gcf().set_size_inches(4, 3)
        plt.legend(loc="lower right")
        plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.svg"))
        plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.png"))
        plt.clf()
        # plt.show()


def compute_stats(metrics, n_samples_mean_cost=4):
    """Aggregate the per-batch metrics into masked statistics for the risk, cost, and distance metrics."""
    biased_risk_estimate = {}
    for key in metrics["cost"].keys():
        if key == "inference":
            continue
        risk = metrics["risk"][key]
        mean_cost = metrics["cost"][key][:, :n_samples_mean_cost].mean(1)
        risk_error = mean_cost - risk
        biased_risk_estimate[key] = {}
        (
            biased_risk_estimate[key]["mean"],
            biased_risk_estimate[key]["std"],
            biased_risk_estimate[key]["ste"],
        ) = masked_mean_std_ste(risk_error, metrics["mask"].any(1))
        (
            biased_risk_estimate[key]["mean_abs"],
            biased_risk_estimate[key]["std_abs"],
            biased_risk_estimate[key]["ste_abs"],
        ) = masked_mean_std_ste(risk_error.abs(), metrics["mask"].any(1))

    risk_stats = {}
    for key in metrics["risk"].keys():
        risk_stats[key] = {}
        (
            risk_stats[key]["mean"],
            risk_stats[key]["std"],
            risk_stats[key]["ste"],
        ) = masked_mean_std_ste(metrics["risk"][key], metrics["mask"].any(1))

    cost_stats = {}
    for key in metrics["cost"].keys():
        cost_stats[key] = {}
        (
            cost_stats[key]["mean"],
            cost_stats[key]["std"],
            cost_stats[key]["ste"],
        ) = masked_mean_std_ste(
            metrics["cost"][key], metrics["mask"].any(-1, keepdim=True)
        )

    distance_stats = {}
    for key in metrics["distance"].keys():
        distance_stats[key] = {"FDE": {}, "ADE": {}, "minFDE": {}, "minADE": {}}
        (
            distance_stats[key]["FDE"]["mean"],
            distance_stats[key]["FDE"]["std"],
            distance_stats[key]["FDE"]["ste"],
        ) = masked_mean_std_ste(
            metrics["distance"][key][..., -1], metrics["mask"][:, None, -1]
        )
        mean_dist = masked_mean_dim(
            metrics["distance"][key], metrics["mask"][:, None, :], -1
        )
        (
            distance_stats[key]["ADE"]["mean"],
            distance_stats[key]["ADE"]["std"],
            distance_stats[key]["ADE"]["ste"],
        ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1, keepdim=True))
        for i in [6, 16, 32]:
            distance_stats[key]["minFDE"][i] = {}
            min_dist, _ = metrics["distance"][key][:, :i, -1].min(1)
            (
                distance_stats[key]["minFDE"][i]["mean"],
                distance_stats[key]["minFDE"][i]["std"],
                distance_stats[key]["minFDE"][i]["ste"],
            ) = masked_mean_std_ste(min_dist, metrics["mask"][:, -1])
            distance_stats[key]["minADE"][i] = {}
            mean_dist, _ = masked_mean_dim(
                metrics["distance"][key][:, :i], metrics["mask"][:, None, :], -1
            ).min(1)
            (
                distance_stats[key]["minADE"][i]["mean"],
                distance_stats[key]["minADE"][i]["std"],
                distance_stats[key]["minADE"][i]["ste"],
            ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1))

    return {
        "risk": risk_stats,
        "biased_risk_estimate": biased_risk_estimate,
        "cost": cost_stats,
        "distance": distance_stats,
    }
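

# Illustration (comment only, shapes are hypothetical): for a distance tensor of shape
# (batch, samples, time), the FDE is the distance at the last time step and minFDE(k)
# keeps the best of the first k samples, as done in `compute_stats` above:
#   dist = torch.rand(8, 32, 30)             # (batch, samples, time)
#   fde = dist[..., -1]                      # (8, 32)
#   min_fde_16, _ = dist[:, :16, -1].min(1)  # (8,)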


def print_stats(stats, n_samples_mean_cost=4):
    slash = "\\"
    brace_open = "{"
    brace_close = "}"
    print("\\begin{tabular}{lccccc}")
    print("\\hline")
    print(
        f"Predictive Model & ${slash}sigma$ & minFDE(16) & FDE (1) & Risk est. error ({n_samples_mean_cost}) & Risk est. $|$error$|$ ({n_samples_mean_cost}) {slash}{slash}"
    )
    print("\\hline")
    for key in stats["distance"].keys():
        strg = (
            f" ${stats['distance'][key]['minFDE'][16]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['minFDE'][16]['ste']:.2f}${brace_close}"
            + f"& ${stats['distance'][key]['FDE']['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['FDE']['ste']:.2f}${brace_close}"
        )
        if key == "inference":
            strg = (
                "Unbiased CVAE & "
                + f"{slash}scriptsize{brace_open}NA{brace_close} &"
                + strg
                + f"& {slash}scriptsize{brace_open}NA{brace_close} & {slash}scriptsize{brace_open}NA{brace_close} {slash}{slash}"
            )
            print(strg)
            print("\\hline")
        else:
            strg = (
                "Biased CVAE & "
                + f"{key[7:]} & "
                + strg
                + f"& ${stats['biased_risk_estimate'][key]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste']:.2f}${brace_close}"
                + f"& ${stats['biased_risk_estimate'][key]['mean_abs']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste_abs']:.2f}${brace_close}"
                + f"{slash}{slash}"
            )
            print(strg)
    print("\\hline")
    print("\\end{tabular}")


def main(
    log_path,
    force_recompute,
    n_samples_risk=256,
    n_samples_stats=32,
    n_samples_plot=16,
    args_to_parser=[],
):
    # Overwrite sys.argv so it doesn't mess up the parser.
    sys.argv = sys.argv[0:1] + args_to_parser
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
    )
    waymo_config_path = os.path.join(
        working_dir, "..", "..", "risk_biased", "config", "waymo_config.py"
    )
    cfg = config_argparse([config_path, waymo_config_path])

    file_path = os.path.join(log_path, f"metrics_{cfg.load_from}.pickle")
    fig_path = os.path.join(log_path, f"plots_{cfg.load_from}")
    if not os.path.exists(fig_path):
        os.makedirs(fig_path)
    risk_levels = [0, 0.3, 0.5, 0.8, 0.95, 1]
    cost = TTCCostTorch(TTCCostParams.from_config(cfg))
    risk_estimator = get_risk_estimator(cfg.risk_estimator)
    n_samples_mean_cost = 4

    if not os.path.exists(file_path) or force_recompute:
        with torch.no_grad():
            if cfg.seed is not None:
                seed_everything(cfg.seed)
            predictor, dataloaders, cfg = load_from_config(cfg)
            device = torch.device(cfg.gpus[0])
            predictor = predictor.to(device)
            val_loader = dataloaders.val_dataloader(shuffle=False, drop_last=False)
            # This loops over batches in the validation dataset
            beg = 0
            metrics_all = None
            for val_batch in tqdm(val_loader):
                end = beg + val_batch[0].shape[0]
                metrics = compute_metrics(
                    predictor=predictor,
                    batch=to_device(val_batch, device),
                    cost=cost,
                    risk_levels=risk_levels,
                    risk_estimator=risk_estimator,
                    dt=cfg.dt,
                    unnormalizer=dataloaders.unnormalize_trajectory,
                    n_samples_risk=n_samples_risk,
                    n_samples_stats=n_samples_stats,
                )
                if metrics_all is None:
                    metrics_all = metrics
                else:
                    metrics_all = cat_metrics(metrics_all, metrics)
                beg = end

            with open(file_path, "wb") as handle:
                pickle.dump(metrics_all, handle)
    else:
        print(f"Loading pre-computed metrics from {file_path}")
        with open(file_path, "rb") as handle:
            metrics_all = CPU_Unpickler(handle).load()

    stats = compute_stats(metrics_all, n_samples_mean_cost=n_samples_mean_cost)
    print_stats(stats, n_samples_mean_cost=n_samples_mean_cost)
    plot_risk_error(metrics_all, risk_levels, risk_estimator, n_samples_plot, fig_path)


if __name__ == "__main__":
    # main("./logs/002/", False, 256, 32, 16)
    # Fire turns the main function into a script, then the risk_biased module argparse reads the other arguments.
    # Thus, the way to use it would be:
    # >python compute_stats.py
    # This is a hack to separate the Fire script args from the argparse arguments
    args_to_parser = sys.argv[len(signature(main).parameters) :]
    partial_main = partial(main, args_to_parser=args_to_parser)
    sys.argv = sys.argv[: len(signature(main).parameters)]
    # Runs the main as a script
    fire.Fire(partial_main)
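

# Illustration of the argument split above (the command line is hypothetical): `main`
# takes six parameters, so the first entries of sys.argv (the script name plus the Fire
# arguments) are kept for Fire while everything after them is forwarded to
# config_argparse through `args_to_parser`. For example,
#   python compute_stats.py ./logs/002/ False 256 32 16 --some_config_override 1
# would pass "./logs/002/", False, 256, 32, 16 to `main` and "--some_config_override 1"
# to the config parser.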