from datasets import load_dataset, Dataset
import fire
from functools import partial, update_wrapper
import numpy
import os
from typing import Dict, Iterable, List, Optional, Tuple, Union
import sys
import time

import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from mmcv import Config
import plotly.graph_objects as go
from torch.utils.data.dataloader import DataLoader

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictor,
)


_DESCRIPTION = """
# Risk-Aware Prediction

Make predictions for the green agent with a risk-seeking bias towards the ego vehicle in blue.
The risk level is a value between 0 and 1, where 0 is not risk-seeking and 1 is the most risk-seeking.
Once the sliders are set, click the "Run" button to see the predictions.
The play button will animate the prediction over time (it is slow especially with many samples).

For more information, see the paper [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) published at [CoRL 2022](https://corl2022.org/).
"""


def to_numpy(**kwargs):
    """Detach every keyword tensor, move it to the CPU and convert it to numpy.

    Returns:
        A dict with the same keys as ``kwargs`` mapping to numpy arrays.
    """
    return {key: value.detach().cpu().numpy() for key, value in kwargs.items()}


def get_scatter_data(x, mask_x, name, **kwargs):
    """Build one ``go.Scatter`` trace per row (first axis) of ``x``.

    Args:
        x: Array indexed as ``x[k, mask, 0]`` / ``x[k, mask, 1]``; the last axis
            holds the plotted x and y coordinates.
        mask_x: Boolean array selecting which points of row ``k`` are drawn.
        name: Legend name; only the first trace (``k == 0``) shows in the legend.
        **kwargs: Forwarded unchanged to every ``go.Scatter``.

    Returns:
        A list of ``x.shape[0]`` scatter traces.
    """
    return [
        go.Scatter(
            x=x[k, mask_x[k], 0],
            y=x[k, mask_x[k], 1],
            showlegend=k == 0,
            name=name,
            **kwargs,
        )
        for k in range(x.shape[0])
    ]


def configuration_paths() -> List[str]:
    """Return the paths of the learning and Waymo config files of the repository."""
    working_dir = os.path.dirname(os.path.realpath(__file__))
    return [
        os.path.join(
            working_dir,
            "../../risk_biased/config",
            config_file,
        )
        for config_file in ("learning_config.py", "waymo_config.py")
    ]


def load_item(index: int, dataset: Dataset, device: str = "cpu") -> Tuple:
    """Load sample ``index`` of ``dataset`` as tensors on ``device``.

    Returns:
        ``(batch, y, mask_y, mask_loss)`` where ``batch`` is the tuple
        ``(x, mask_x, map_data, mask_map, offset, x_ego, y_ego)`` passed to
        ``LitTrajectoryPredictor.predict_step``.
    """
    # Indexing a datasets.Dataset with an int returns the full row as a dict;
    # fetch it once instead of once per column.
    item = dataset[index]

    def as_tensor(key: str, dtype) -> torch.Tensor:
        return torch.from_numpy(numpy.array(item[key]).astype(dtype)).to(device)

    x = as_tensor("x", numpy.float32)
    mask_x = as_tensor("mask_x", numpy.bool_)
    y = as_tensor("y", numpy.float32)
    mask_y = as_tensor("mask_y", numpy.bool_)
    mask_loss = as_tensor("mask_loss", numpy.bool_)
    map_data = as_tensor("map_data", numpy.float32)
    mask_map = as_tensor("mask_map", numpy.bool_)
    offset = as_tensor("offset", numpy.float32)
    x_ego = as_tensor("x_ego", numpy.float32)
    y_ego = as_tensor("y_ego", numpy.float32)
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss


def _empty_marker_trace() -> go.Scatter:
    """Return an empty, legend-less marker trace used as an animation slot."""
    return go.Scatter(x=[], y=[], mode="markers", showlegend=False)


def build_data(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: Optional[float],
    n_samples: int,
) -> Dict[str, list]:
    """Run the predictor on one sample and build the Plotly traces and frames.

    Args:
        predictor: Trained trajectory predictor.
        dataset: Dataset the sample is read from.
        index: Index of the sample in ``dataset``.
        risk_level: Risk level forwarded to ``predict_step`` (``None`` disables it).
        n_samples: Number of forecast samples to draw; must be at least 1.

    Returns:
        ``{"data": static traces, "frames": animation frames}``, suitable for
        ``go.Figure(**result)``.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")
    batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
    predictions = predictor.predict_step(
        batch=batch,
        risk_level=risk_level,
        n_samples=n_samples,
    )
    # Unnormalize the past and ground-truth future trajectories with the sample offset.
    offset = batch[4]
    y = predictor._unnormalize_trajectory(y, offset)
    x = predictor._unnormalize_trajectory(batch[0], offset)
    numpy_data = to_numpy(
        predictions=predictions,
        y=y,
        mask_y=mask_y,
        x=x,
        mask_x=batch[1],
        map_data=batch[2],
        mask_map=batch[3],
        mask_pred=mask_loss,
    )
    # Plot the first element of the batch dimension.
    x = numpy_data["x"][0]
    mask_x = numpy_data["mask_x"][0]
    y = numpy_data["y"][0]
    mask_y = numpy_data["mask_y"][0]
    pred = numpy_data["predictions"][0]
    mask_pred = numpy_data["mask_pred"][0]
    map_data = numpy_data["map_data"][0]
    mask_map = numpy_data["mask_map"][0]

    marker_size = 12

    # --- Static traces -----------------------------------------------------
    # Agent 0 is drawn as the ego (blue) and agent 1 as the predicted agent (green).
    data_x = get_scatter_data(
        x,
        mask_x,
        mode="lines",
        line=dict(width=2, color="black"),
        name="Past",
    )
    ego_present = get_scatter_data(
        x=x[0:1, -1:],
        mask_x=mask_x[0:1, -1:],
        mode="markers",
        marker=dict(color="blue", size=marker_size, opacity=0.5),
        name="Ego",
    )
    agent_present = get_scatter_data(
        x=x[1:2, -1:],
        mask_x=mask_x[1:2, -1:],
        mode="markers",
        marker=dict(color="green", size=marker_size, opacity=0.5),
        name="Agent",
    )
    data_y = get_scatter_data(
        y,
        mask_y,
        mode="lines",
        line=dict(width=2, color="green"),
        name="Ground truth",
    )
    data_map = get_scatter_data(
        map_data,
        mask_map,
        mode="lines",
        line=dict(width=15, color="gray"),
        opacity=0.3,
        name="Centerline",
    )

    data_pred = []
    forecasts_end = []
    for i in range(n_samples):
        data_pred += get_scatter_data(
            pred[:, i],
            mask_pred,
            mode="lines",
            line=dict(width=2, color="red"),
            name="Forecast",
        )
        forecasts_end += get_scatter_data(
            pred[:, i, -1:],
            mask_pred[:, -1:],
            mode="markers",
            marker=dict(color="red", size=marker_size / 2, opacity=0.5, symbol="x"),
            name="Forecast end",
        )

    # Plotly applies frame.data[i] to figure trace i when a frame does not set
    # `traces`. Reserve the first `n_animated` traces as empty slots for the
    # moving markers; otherwise the frames would overwrite the map centerline
    # traces that come first in the static data.
    n_animated = 3 + n_samples  # agent GT, other GT, one per forecast sample, ego
    static_data = (
        [_empty_marker_trace() for _ in range(n_animated)]
        + data_map
        + data_x
        + data_y
        + data_pred
        + ego_present
        + agent_present
        + forecasts_end
    )

    # --- Animation frames --------------------------------------------------
    animation_opacity = 0.5

    # Past time steps: all agents in black, ego in blue.
    frames_x = []
    for k in range(x.shape[1]):
        frame_data = [
            go.Scatter(
                x=x[mask_x[:, k], k, 0],
                y=x[mask_x[:, k], k, 1],
                mode="markers",
                opacity=animation_opacity,
                marker=dict(color="black", size=marker_size),
                showlegend=False,
            ),
            go.Scatter(
                x=x[0:1, k, 0],
                y=x[0:1, k, 1],
                mode="markers",
                opacity=animation_opacity,
                marker=dict(color="blue", size=marker_size),
                showlegend=False,
            ),
        ]
        # Clear the slots only used by future frames (e.g. left over from a previous play).
        frame_data += [_empty_marker_trace() for _ in range(n_animated - len(frame_data))]
        frames_x.append(go.Frame(data=frame_data))

    # Future time steps: agent ground truth, other agents, forecasts, ego.
    frames_y_pred = []
    for k in range(y.shape[1]):
        gt_agent = go.Scatter(
            x=y[1:2][mask_y[1:2, k], k, 0],
            y=y[1:2][mask_y[1:2, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="green", size=marker_size),
            showlegend=False,
        )
        gt_others = go.Scatter(
            x=y[2:][mask_y[2:, k], k, 0],
            y=y[2:][mask_y[2:, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="black", size=marker_size),
            showlegend=False,
        )
        forecasts = [
            go.Scatter(
                x=pred[mask_pred[:, k], i, k, 0],
                y=pred[mask_pred[:, k], i, k, 1],
                mode="markers",
                opacity=animation_opacity,
                marker=dict(color="red", size=marker_size),
                showlegend=False,
            )
            for i in range(n_samples)
        ]
        ego = go.Scatter(
            x=y[0:1, k, 0],
            y=y[0:1, k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="blue", size=marker_size),
            showlegend=False,
        )
        frames_y_pred.append(go.Frame(data=[gt_agent, gt_others, *forecasts, ego]))

    return {"frames": frames_x + frames_y_pred, "data": static_data}


def prediction_plot(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int = 1,
    use_biaser: bool = True,
) -> go.Figure:
    """Build the animated road-scene figure for one dataset sample.

    Args:
        predictor: Trained trajectory predictor.
        dataset: Dataset the sample is read from.
        index: Index of the sample in ``dataset``.
        risk_level: Risk level used when ``use_biaser`` is True.
        n_samples: Number of forecast samples to draw.
        use_biaser: When False, ``risk_level`` is replaced by ``None``.

    Returns:
        A Plotly figure with Play/Pause animation buttons.
    """
    range_radius = 80
    risk_level = float(risk_level) if use_biaser else None
    # UI slider values may not be Python ints; dataset indexing and range() need ints.
    index = int(index)
    n_samples = int(n_samples)

    layout = go.Layout(
        xaxis=dict(
            range=[-0.5 * range_radius, 1.5 * range_radius],
            autorange=False,
            zeroline=False,
        ),
        yaxis=dict(
            range=[-range_radius, range_radius],
            autorange=False,
            zeroline=False,
        ),
        title_text="Road Scene",
        hovermode="closest",
        width=800,
        height=600,
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                frame=dict(duration=100, redraw=False),
                                mode="immediate",
                                fromcurrent=True,
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            {
                                "frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0},
                            },
                        ],
                    ),
                ],
            )
        ],
    )
    fig = go.Figure(
        **build_data(predictor, dataset, index, risk_level, n_samples),
        layout=layout,
    )
    # NOTE(review): no geo traces are added to this figure, so this call appears
    # to have no visible effect; kept for parity with the original behavior.
    fig.update_geos(projection_type="equirectangular", visible=True, resolution=110)
    return fig


def get_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> go.Figure:
    """Build the risk-biased prediction figure for sample ``index``."""
    return prediction_plot(
        predictor, dataset, index, risk_level, n_samples, use_biaser=True
    )


def update_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
    image=None,
) -> go.Figure:
    """Gradio callback: rebuild the figure from the current slider values.

    ``image`` receives the current plot because the click handler lists the
    plot component as an input; it is not used.
    """
    return get_figure(predictor, dataset, index, risk_level, n_samples)


def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: Union[str, torch.device] = "cpu",
) -> LitTrajectoryPredictor:
    """Download a config and checkpoint from the Hugging Face Hub and build the predictor.

    The optional ``SECRET_AUTH_TOKEN`` environment variable is used as the Hub token.

    Returns:
        The predictor with loaded weights, in eval mode, on ``device``.
    """
    auth_token = os.getenv("SECRET_AUTH_TOKEN")
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=auth_token)
    checkpoint_file = hf_hub_download(
        model_source, filename=checkpoint_name, use_auth_token=auth_token
    )
    # SECURITY: torch.load unpickles the checkpoint and Config.fromfile executes
    # the Python config file; only load from trusted repositories.
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    predictor = predictor.to(device)
    return predictor


def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
    """Load the ``test`` split of ``data_source`` from the Hugging Face Hub."""
    return load_dataset(data_source, split="test", trust_remote_code=True)


def main(load_from=None, cfg_path=None):
    """Launch the Gradio risk-aware prediction demo.

    Args:
        load_from: Optional path to a local checkpoint. When omitted, the model
            is downloaded from the Hugging Face Hub.
        cfg_path: Config file matching ``load_from``; required when
            ``load_from`` is given.

    Raises:
        ValueError: If ``load_from`` is given without ``cfg_path``.
    """
    if load_from is not None and cfg_path is None:
        raise ValueError("cfg_path must be provided together with load_from")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Getting dataset")
    dataset = load_dataset_from_hf()

    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        # Same preparation as load_predictor_from_hf: inference mode on `device`.
        predictor.eval()
        predictor = predictor.to(device)
    else:
        print("Getting model.")
        predictor = load_predictor_from_hf(device=device)

    ui_update_fn = partial(update_figure, predictor, dataset)

    with gr.Blocks() as interface:
        gr.Markdown(_DESCRIPTION)
        initial_index = 27
        initial_n_samples = 10
        image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
        interface.queue()
        index = gr.Slider(
            minimum=0,
            maximum=len(dataset) - 1,
            step=1,
            value=initial_index,
            label="Index",
        )
        risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
        n_samples = gr.Slider(
            minimum=1,
            maximum=20,
            step=1,
            value=initial_n_samples,
            label="Number of prediction samples",
        )
        # `value` is the button text; Gradio 4's gr.Button does not accept `label`.
        button = gr.Button(value="Run")
        # Predictions run only on click. Per-slider `.change` handlers were removed:
        # they fired on the first change and ignored changes made during
        # computation, leaving the plot out of sync with the sliders.
        button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)

    interface.launch(debug=False)


if __name__ == "__main__":
    fire.Fire(main)