|
import os
|
|
import pickle
|
|
from typing import Any, Union
|
|
|
|
import numpy as np
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
|
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config
|
|
from navsim.agents.vadv2.vadv2_features import (
|
|
Vadv2FeatureBuilder,
|
|
Vadv2TargetBuilder,
|
|
)
|
|
from navsim.agents.vadv2.vadv2_loss import vadv2_loss_pdm_w_progress
|
|
from navsim.agents.vadv2.vadv2_pdm_model_progress import Vadv2ModelPDMProgress
|
|
from navsim.common.dataclasses import SensorConfig
|
|
from navsim.planning.training.abstract_feature_target_builder import (
|
|
AbstractFeatureBuilder,
|
|
AbstractTargetBuilder,
|
|
)
|
|
|
|
# Filesystem roots taken from the environment; each evaluates to None when the variable is unset.
# TRAJ_PDM_ROOT is the base directory of the cached vocabulary PDM-score pickles.
DEVKIT_ROOT = os.environ.get('NAVSIM_DEVKIT_ROOT')
TRAJ_PDM_ROOT = os.environ.get('NAVSIM_TRAJPDM_ROOT')
|
|
|
|
from typing import Dict, List
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
|
|
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.common.dataclasses import Trajectory
|
|
|
|
|
|
class Vadv2AgentPDMProgress(AbstractAgent):
    """VADv2 agent that scores a trajectory vocabulary with PDM sub-score heads (incl. progress).

    Training supervises the sub-score predictions with per-token PDM scores read from a
    pickle cache; inference returns the vocabulary trajectory with the best weighted
    combination of imitation and predicted sub-score log-probabilities.
    """

    def __init__(
        self,
        config: Vadv2Config,
        lr: float,
        checkpoint_path: Union[str, None] = None,
        pdm_split=None,
        metrics=None,
    ):
        """Build the model and load the cached per-token PDM scores.

        :param config: agent configuration. Note: ``config.trajectory_pdm_weight`` is
            overwritten in place by this constructor.
        :param lr: base learning rate (non-backbone parameter group).
        :param checkpoint_path: Lightning checkpoint loaded by :meth:`initialize`.
        :param pdm_split: file stem of the score cache, loaded from
            ``{TRAJ_PDM_ROOT}/vocab_score_full_{vocab_size}_navtrain/{pdm_split}.pkl``.
        :param metrics: metric keys looked up in the score cache by :meth:`compute_loss`.
        """
        super().__init__()
        # Mutates the caller's config; presumably consumed by the loss/model code — verify.
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': 2.0,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        self.vadv2_model = Vadv2ModelPDMProgress(config)
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd

        new_pkl_dir = f'vocab_score_full_{self.vocab_size}_navtrain'
        pdm_score_path = f'{TRAJ_PDM_ROOT}/{new_pkl_dir}/{pdm_split}.pkl'
        # NOTE(review): pickle.load can execute arbitrary code; only point this at
        # trusted, locally generated score caches.
        # The context manager guarantees the file handle is closed after loading.
        with open(pdm_score_path, 'rb') as f:
            # Indexed as [token][metric] in compute_loss.
            self.vocab_pdm_score_full = pickle.load(f)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads the ``state_dict`` entry of the checkpoint at ``self._checkpoint_path``
        (mapped to CPU) and strips a leading ``agent.`` prefix from each key.
        """
        prefix = "agent."
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))["state_dict"]
        # Strip only a *leading* prefix: str.replace would also rewrite any "agent."
        # that happens to occur inside a parameter name.
        self.load_state_dict(
            {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_mm_sensors()

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single VADv2 target builder for this config."""
        return [Vadv2TargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single VADv2 feature builder for this config."""
        return [Vadv2FeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on a batch of features."""
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        """Run the wrapped model with an extra ``interpolated_traj`` input (training path)."""
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the training loss using cached PDM sub-scores as supervision.

        For every key in ``self.metrics``, the cached arrays of each token in ``tokens``
        are stacked along a new leading (batch) axis, converted to a tensor on the
        device of ``predictions['trajectory']`` and passed to
        ``vadv2_loss_pdm_w_progress``.

        :param tokens: scene tokens of the batch, in batch order; required in practice
            (iterating ``None`` raises ``TypeError``).
        """
        device = predictions['trajectory'].device
        scores = {}
        for metric in self.metrics:
            # [None] adds a leading axis so concatenation yields one row per token.
            per_token = [self.vocab_pdm_score_full[token][metric][None] for token in tokens]
            scores[metric] = torch.from_numpy(np.concatenate(per_token, axis=0)).to(device)
        return vadv2_loss_pdm_w_progress(targets, predictions, self._config, scores)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """Build an Adam optimizer with a separate parameter group for the image backbone.

        Backbone parameters use ``lr * config.lr_mult_backbone`` and weight decay
        ``config.backbone_wd``; all other parameters use ``self._lr`` and Adam's default
        weight decay.
        """
        if self._config.backbone_type == 'moe':
            # For the 'moe' backbone only the eva and davit encoder sub-trees count as backbone.
            backbone_keys = ('_backbone.image_encoder.eva', '_backbone.image_encoder.davit')
        else:
            backbone_keys = ('_backbone.image_encoder',)

        # Single pass over named_parameters(); relative order within each group is kept.
        default_params, backbone_params = [], []
        for param_name, param in self.vadv2_model.named_parameters():
            if any(key in param_name for key in backbone_keys):
                backbone_params.append(param)
            else:
                default_params.append(param)

        params_lr_dict = [
            {'params': default_params},
            {
                'params': backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd,
            },
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Keep the 30 best checkpoints by ``val/loss_epoch`` under ``$NAVSIM_EXP_ROOT/<ckpt_path>/``."""
        # NOTE(review): if NAVSIM_EXP_ROOT is unset, the directory starts with "None/".
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]

    def compute_trajectory(self, agent_input):
        """Plan a trajectory for a single scene (submission path).

        Builds features from ``agent_input``, runs the model on a batch of one, and
        returns the vocabulary trajectory maximising a weighted sum of the log imitation
        probability and the log predicted PDM sub-scores.
        """
        self.eval()
        # Prefer CUDA; fall back to CPU so inference also works on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        features: Dict[str, torch.Tensor] = {}
        for builder in self.get_feature_builders():
            features.update(builder.compute_features(agent_input))
        # Add a leading batch dimension of size 1.
        features = {k: v.unsqueeze(0).to(device) for k, v in features.items()}

        self.vadv2_model = self.vadv2_model.to(device)
        # Read the vocabulary after the move so it reflects the model's current device.
        vocab = self.vadv2_model._trajectory_head.vocab

        with torch.no_grad():
            predictions = self.vadv2_model(features)

        # log_softmax is the numerically stable equivalent of softmax(-1).log().
        imis = predictions["imi"].log_softmax(-1).cpu().numpy()
        nocs = predictions["noc"].log().cpu().numpy()
        das = predictions["da"].log().cpu().numpy()
        ttcs = predictions["ttc"].log().cpu().numpy()
        comforts = predictions["comfort"].log().cpu().numpy()
        progresses = predictions["progress"].log().cpu().numpy()

        # Inference-time weighting of the heads (hand-set constants).
        imi_weight = 0.1
        noc_weight = 0.25
        da_weight = 2.0
        ttc_weight = 3.0
        progress_weight = 5.0
        comfort_weight = 1.0
        # Shared multiplier applied to the ttc/progress/comfort sub-sum.
        tpc_weight = 2.25

        total = (
            imi_weight * imis
            + noc_weight * nocs
            + da_weight * das
            + tpc_weight * (
                ttc_weight * ttcs
                + comfort_weight * comforts
                + progress_weight * progresses
            )
        )
        # [0] selects the only sample in the batch; the argmax index selects a vocab entry
        # (assumes predictions are (1, vocab_size) — TODO confirm).
        best_idx = int(total[0].argmax(0))
        traj = vocab[best_idx].cpu().numpy()
        return Trajectory(traj, TrajectorySampling(time_horizon=4, interval_length=0.1))
|
|
|