|
from __future__ import annotations
|
|
|
|
from typing import Any, List, Dict
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
|
|
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
|
|
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.common.dataclasses import AgentInput, SensorConfig
|
|
from navsim.planning.training.abstract_feature_target_builder import (
|
|
AbstractFeatureBuilder,
|
|
AbstractTargetBuilder,
|
|
)
|
|
from navsim.common.dataclasses import Scene
|
|
|
|
import torch
|
|
|
|
|
|
class EgoStatusFeatureBuilder(AbstractFeatureBuilder):
    """Feature builder that flattens the latest ego status into a single tensor."""

    def __init__(self):
        """Create the builder; it carries no configuration or state."""
        pass

    def get_unique_name(self) -> str:
        """Return the identifier under which this builder's features are registered."""
        return "ego_status_feature"

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """
        Build the ego-status feature from the last entry of ``agent_input.ego_statuses``.

        :param agent_input: agent input whose final ego status is read.
        :return: dict with the single key ``"ego_status"`` mapping to the concatenation
            (along the last dimension) of velocity, acceleration and driving command,
            in that order.
        """
        latest_status = agent_input.ego_statuses[-1]

        # Order matters: the consumer of this feature sees the components in exactly this sequence.
        components = [
            torch.tensor(latest_status.ego_velocity),
            torch.tensor(latest_status.ego_acceleration),
            torch.tensor(latest_status.driving_command),
        ]

        return {"ego_status": torch.cat(components, dim=-1)}
|
|
|
|
|
|
class TrajectoryTargetBuilder(AbstractTargetBuilder):
    """Target builder that exposes a scene's future trajectory poses as a tensor."""

    def __init__(self, trajectory_sampling: TrajectorySampling):
        """
        :param trajectory_sampling: sampling specification; its ``num_poses`` determines
            how many future frames are requested from the scene.
        """
        self._trajectory_sampling = trajectory_sampling

    def get_unique_name(self) -> str:
        """Return the identifier under which this builder's targets are registered."""
        return "trajectory_target"

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """
        Query the scene for its future trajectory and convert the poses to a tensor.

        :param scene: scene providing ``get_future_trajectory``.
        :return: dict with the single key ``"trajectory"`` holding the poses tensor.
        """
        num_frames = self._trajectory_sampling.num_poses
        trajectory = scene.get_future_trajectory(num_trajectory_frames=num_frames)
        poses = torch.tensor(trajectory.poses)
        return {"trajectory": poses}
|
|
|
|
|
|
class EgoStatusMLPAgent(AbstractAgent):
    """
    Agent that maps the "ego_status" feature to a trajectory with a small MLP.

    The network is four linear layers with ReLU activations in between; its output is
    reshaped to ``(-1, num_poses, 3)`` and returned under the key ``"trajectory"``.
    """

    # Input width of the first linear layer; must equal the length of the "ego_status" feature.
    # NOTE(review): presumably velocity + acceleration + driving command sizes — confirm against
    # the feature builder's inputs before changing.
    _EGO_STATUS_DIM = 8

    # Number of values predicted per pose (last dimension of the reshaped output).
    _POSE_DIM = 3

    def __init__(
        self,
        trajectory_sampling: TrajectorySampling,
        hidden_layer_dim: int,
        lr: float,
        checkpoint_path: str | None = None,
    ):
        """
        :param trajectory_sampling: sampling specification; ``num_poses`` sizes the output layer.
        :param hidden_layer_dim: width of every hidden layer of the MLP.
        :param lr: learning rate handed to the Adam optimizer in ``get_optimizers``.
        :param checkpoint_path: checkpoint file read by ``initialize``; may be omitted when
            ``initialize`` is never called.
        """
        super().__init__()
        self._trajectory_sampling = trajectory_sampling
        self._checkpoint_path = checkpoint_path

        self._lr = lr

        output_dim = self._trajectory_sampling.num_poses * self._POSE_DIM
        self._mlp = torch.nn.Sequential(
            torch.nn.Linear(self._EGO_STATUS_DIM, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, hidden_layer_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_layer_dim, output_dim),
        )

    def name(self) -> str:
        """Inherited, see superclass."""

        return self.__class__.__name__

    def initialize(self) -> None:
        """
        Inherited, see superclass.

        Loads the ``"state_dict"`` entry of the checkpoint at ``checkpoint_path`` into this agent,
        removing every ``"agent."`` occurrence from the parameter names first.

        :raises ValueError: if no ``checkpoint_path`` was given at construction time.
        """
        if self._checkpoint_path is None:
            raise ValueError(
                f"{self.__class__.__name__}.initialize() requires a checkpoint_path, but none was provided."
            )

        # Without CUDA, remap all tensors to the CPU; with CUDA keep torch.load's default placement.
        map_location = None if torch.cuda.is_available() else torch.device("cpu")

        # NOTE(security): torch.load is pickle-based; only load checkpoints from trusted sources.
        checkpoint = torch.load(self._checkpoint_path, map_location=map_location)
        state_dict: Dict[str, Any] = checkpoint["state_dict"]

        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_no_sensors()

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the target builders for training: a single ``TrajectoryTargetBuilder``."""
        return [
            TrajectoryTargetBuilder(trajectory_sampling=self._trajectory_sampling),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the feature builders for training: a single ``EgoStatusFeatureBuilder``."""
        return [EgoStatusFeatureBuilder()]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run the MLP on ``features["ego_status"]``.

        :param features: dict containing the ``"ego_status"`` tensor.
        :return: dict with key ``"trajectory"`` holding the output reshaped to
            ``(-1, num_poses, 3)``.
        """
        poses: torch.Tensor = self._mlp(features["ego_status"])
        return {"trajectory": poses.reshape(-1, self._trajectory_sampling.num_poses, self._POSE_DIM)}

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute the L1 loss between the predicted and the target trajectory.

        :param features: input features (not used by this loss).
        :param targets: dict containing the ``"trajectory"`` target tensor.
        :param predictions: dict containing the ``"trajectory"`` prediction tensor.
        :return: the L1 loss tensor.
        """
        return torch.nn.functional.l1_loss(predictions["trajectory"], targets["trajectory"])

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Return an Adam optimizer over the MLP parameters using the configured learning rate."""
        return torch.optim.Adam(self._mlp.parameters(), lr=self._lr)
|
|
|