from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
from mmcv import Config

from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.planner_utils import AbstractState, to_state
from risk_biased.utils.risk import get_risk_estimator


@dataclass
class MPCPlannerParams:
    """Dataclass for MPC-Planner Parameters

    Args:
        dt: discrete time interval in seconds that is used for planning
        num_steps: number of time steps for which the history of the ego's and
            the other actor's trajectories is stored
        num_steps_future: number of time steps into the future for which the
            ego's and the other actor's trajectories are considered
        acceleration_std_x_m_s2: acceleration noise standard deviation (m/s^2) in the
            x-direction, used to initialize the cross-entropy solver
        acceleration_std_y_m_s2: acceleration noise standard deviation (m/s^2) in the
            y-direction, used to initialize the cross-entropy solver
        risk_estimator_params: parameters for the Monte Carlo risk estimator used in
            the planner for the ego's control optimization
        solver_params: parameters for the CrossEntropySolver
        tracking_cost_params: parameters for the TrackingCost
        ttc_cost_params: parameters for the TTCCost (i.e., collision cost between the
            ego and the other actor)
    """

    dt: float
    num_steps: int
    num_steps_future: int
    acceleration_std_x_m_s2: float
    acceleration_std_y_m_s2: float
    risk_estimator_params: dict
    solver_params: CrossEntropySolverParams
    tracking_cost_params: TrackingCostParams
    ttc_cost_params: TTCCostParams

    @staticmethod
    def from_config(cfg: Config):
        return MPCPlannerParams(
            cfg.dt,
            cfg.num_steps,
            cfg.num_steps_future,
            cfg.acceleration_std_x_m_s2,
            cfg.acceleration_std_y_m_s2,
            cfg.risk_estimator,
            CrossEntropySolverParams.from_config(cfg),
            TrackingCostParams.from_config(cfg),
            TTCCostParams.from_config(cfg),
        )


class MPCPlanner:
    """MPC Planner with a Cross Entropy solver

    Args:
        params: MPCPlannerParams object
        predictor: LitTrajectoryPredictor object
        normalizer: function that takes an unnormalized trajectory and returns the
            normalized trajectory and the offset, in this order
    """

    def __init__(
        self,
        params: MPCPlannerParams,
        predictor: LitTrajectoryPredictor,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        self.params = params
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.control_input_mean_init = torch.zeros(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_init = torch.Tensor(
            [
                params.acceleration_std_x_m_s2,
                params.acceleration_std_y_m_s2,
            ]
        ).expand_as(self.control_input_mean_init)
        self.solver = CrossEntropySolver(
            params=params.solver_params,
            dynamics_model=self.dynamics_model,
            control_input_mean=self.control_input_mean_init,
            control_input_std=self.control_input_std_init,
            tracking_cost_function=TrackingCost(params.tracking_cost_params),
            interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator_params),
        )
        self.predictor = predictor
        self.normalizer = normalizer

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def replan(
        self,
        current_ado_state: AbstractState,
        current_ego_state: AbstractState,
        target_velocity: torch.Tensor,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> None:
        """Performs re-planning given the current_ado_state, current_ego_state, and
        target_velocity, and updates ego_state_planned_trajectory. Note that all the
        information given to solver.solve(...) is expressed in the ego-centric frame,
        whose origin is the initial ego position in ego_state_history and whose
        x-direction is parallel to the initial ego velocity.

        Args:
            current_ado_state: ado state
            current_ego_state: ego state
            target_velocity: ((1), 2) tensor
            num_prediction_samples (optional): number of prediction samples.
                Defaults to 1.
            risk_level (optional): a risk-level float for the entire
                prediction-planning pipeline. If 0.0, risk-neutral prediction and
                planning are used. Defaults to 0.0.
            resample_prediction (optional): if True, prediction is re-sampled in
                each cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): if True, risk-biased prediction is used
                and the solver becomes risk-neutral. If False, risk-neutral
                prediction is used and the solver becomes risk-sensitive.
                Defaults to False.
        """
        self._update_ado_state_history(current_ado_state)
        self._update_ego_state_history(current_ego_state)
        self._update_ego_state_target_trajectory(current_ego_state, target_velocity)
        # Only solve once the history buffers hold num_steps entries.
        if self.ado_state_history.shape[-1] >= self.params.num_steps:
            self.solver.solve(
                self.predictor,
                self._map_to_ego_centric_frame(self.ego_state_history),
                self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
                self._map_to_ego_centric_frame(self.ado_state_history),
                self.normalizer,
                num_prediction_samples=num_prediction_samples,
                risk_level=risk_level,
                resample_prediction=resample_prediction,
                risk_in_predictor=risk_in_predictor,
            )
            ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
                self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
                self.solver.control_sequence,
            )
            self._ego_state_planned_trajectory = self._map_to_world_frame(
                ego_state_planned_trajectory_in_ego_frame
            )
            latest_ado_position_future_samples_in_ego_frame = (
                self.solver.fetch_latest_prediction()
            )
            if latest_ado_position_future_samples_in_ego_frame is not None:
                self._latest_ado_position_future_samples = self._map_to_world_frame(
                    latest_ado_position_future_samples_in_ego_frame
                )
            else:
                self._latest_ado_position_future_samples = None

    def get_planned_next_ego_state(self) -> AbstractState:
        """Returns the next ego state according to the ego_state_planned_trajectory

        Returns:
            Planned state
        """
        assert (
            self._ego_state_planned_trajectory is not None
        ), "call self.replan(...) first"
        return self._ego_state_planned_trajectory[..., 0]

    def reset(self) -> None:
        """Resets the planner's internal state.
        This fully resets the solver's internal state, including
        solver.control_input_mean_init and solver.control_input_std_init.
        """
        self.solver.control_input_mean_init = (
            self.control_input_mean_init.detach().clone()
        )
        self.solver.control_input_std_init = (
            self.control_input_std_init.detach().clone()
        )
        self.solver.reset()
        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def fetch_latest_prediction(self) -> Optional[torch.Tensor]:
        """Returns the latest ado position prediction samples in the world frame,
        or None if no prediction has been made yet."""
        return self._latest_ado_position_future_samples

    @property
    def ego_state_history(self) -> AbstractState:
        """Returns ego_state_history as a concatenated state sequence

        Returns:
            ego_state_history state sequence
        """
        assert len(self._ego_state_history) > 0
        return to_state(
            torch.stack(
                [ego_state.get_states(4) for ego_state in self._ego_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    @property
    def ado_state_history(self) -> AbstractState:
        """Returns ado_state_history as a concatenated state sequence

        Returns:
            ado_state_history state sequence
        """
        assert len(self._ado_state_history) > 0
        return to_state(
            torch.stack(
                [ado_state.get_states(4) for ado_state in self._ado_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
        """Updates ego_state_history with the current_ego_state

        Args:
            current_ego_state: current ego state
        """
        if len(self._ego_state_history) >= self.params.num_steps:
            self._ego_state_history = self._ego_state_history[1:]
        self._ego_state_history.append(current_ego_state)
        assert len(self._ego_state_history) <= self.params.num_steps

    def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
        """Updates ado_state_history with the current_ado_state

        Args:
            current_ado_state: states of the current non-ego vehicles
        """
        if len(self._ado_state_history) >= self.params.num_steps:
            self._ado_state_history = self._ado_state_history[1:]
        self._ado_state_history.append(current_ado_state)
        assert len(self._ado_state_history) <= self.params.num_steps

    def _update_ego_state_target_trajectory(
        self, current_ego_state: AbstractState, target_velocity: torch.Tensor
    ) -> None:
        """Updates ego_state_target_trajectory based on the current_ego_state and
        the target_velocity

        Args:
            current_ego_state: current ego state
            target_velocity: (1, 2) tensor
        """
        target_displacement = self.params.dt * target_velocity
        target_position_list = [current_ego_state.position]
        for _ in range(self.params.num_steps_future):
            target_position_list.append(target_position_list[-1] + target_displacement)
        target_position_list = target_position_list[1:]
        target_position = torch.cat(target_position_list, dim=-2)
        target_state = to_state(
            torch.cat(
                (target_position, target_velocity.expand_as(target_position)), dim=-1
            ),
            self.params.dt,
        )
        self._ego_state_target_trajectory = target_state

    def _map_to_ego_centric_frame(
        self, trajectory_in_world_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the world frame to the ego-centric frame,
        whose origin is the initial ego position in ego_state_history and whose
        x-direction is parallel to the initial ego velocity

        Args:
            trajectory_in_world_frame: sequence of states

        Returns:
            trajectory mapped to the ego-centric frame
        """
        # If trajectory_in_world_frame is of shape (..., state_dim), use the
        # associated dynamics model in translate_position and rotate_angle.
        # Otherwise, assume that the trajectory is in the 2D position space.
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_vel_init = self.ego_state_history.velocity[..., -1, :]
        ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
        trajectory_in_ego_frame = trajectory_in_world_frame.translate(
            -ego_pos_init
        ).rotate(-ego_rot_init)
        return trajectory_in_ego_frame

    def _map_to_world_frame(
        self, trajectory_in_ego_frame: AbstractState
    ) -> AbstractState:
        """Maps a trajectory expressed in the ego-centric frame to the world frame

        Args:
            trajectory_in_ego_frame: (..., 2) position trajectory or
                (..., markov_state_dim) state trajectory expressed in the
                ego-centric frame, whose origin is the initial ego position in
                ego_state_history and whose x-direction is parallel to the initial
                ego velocity

        Returns:
            trajectory mapped to the world frame
        """
        # The state starts with x, y, angle.
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_rot_init = self.ego_state_history.angle[..., -1, :]
        trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
            ego_rot_init
        ).translate(ego_pos_init)
        return trajectory_in_world_frame
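

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a caller might drive the planner in a
# receding-horizon loop. The config path, the checkpoint path, and the
# `normalizer` below are hypothetical placeholders; `load_from_checkpoint` is
# assumed to exist because LitTrajectoryPredictor appears to be a PyTorch
# Lightning module. The [x, y, vx, vy] state layout is inferred from how this
# module calls `to_state`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = Config.fromfile("path/to/planner_config.py")  # placeholder path
    params = MPCPlannerParams.from_config(cfg)
    predictor = LitTrajectoryPredictor.load_from_checkpoint(
        "path/to/predictor.ckpt"  # placeholder path
    )

    def normalizer(trajectory: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # One possible normalizer satisfying the documented contract: shift the
        # trajectory so it starts at the origin and return the removed offset.
        offset = trajectory[..., 0:1, :].clone()
        return trajectory - offset, offset

    planner = MPCPlanner(params, predictor, normalizer)
    ego_state = to_state(torch.tensor([[0.0, 0.0, 5.0, 0.0]]), params.dt)
    ado_state = to_state(torch.tensor([[20.0, 2.0, -5.0, 0.0]]), params.dt)
    target_velocity = torch.tensor([[5.0, 0.0]])

    # The solver only runs once the history buffers hold num_steps entries, so
    # the first num_steps calls to replan(...) merely warm up the buffers.
    for _ in range(params.num_steps):
        planner.replan(ado_state, ego_state, target_velocity)

    for _ in range(10):
        planner.replan(
            ado_state,
            ego_state,
            target_velocity,
            num_prediction_samples=32,
            risk_level=0.5,  # risk-sensitive; 0.0 would be risk-neutral
        )
        ego_state = planner.get_planned_next_ego_state()
        # In a real loop, ado_state would be refreshed from new observations here.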