|
import os |
|
import pickle |
|
from typing import Any, Union |
|
|
|
import numpy as np |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from torch.optim import Optimizer |
|
from torch.optim.lr_scheduler import LRScheduler |
|
|
|
from navsim.agents.hydra.hydra_config import HydraConfig |
|
from navsim.agents.hydra.hydra_features import HydraFeatureBuilder, HydraTargetBuilder |
|
from navsim.agents.hydra.hydra_model_pe_nodet_beta import HydraModelPENoDetBeta |
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config |
|
from navsim.agents.vadv2.vadv2_loss import three_to_two_classes |
|
from navsim.common.dataclasses import SensorConfig |
|
from navsim.planning.training.abstract_feature_target_builder import ( |
|
AbstractFeatureBuilder, |
|
AbstractTargetBuilder, |
|
) |
|
|
|
# Filesystem roots taken from the environment; each is None when its variable is unset.
DEVKIT_ROOT = os.environ.get('NAVSIM_DEVKIT_ROOT')
# Directory under which the `vocab_score_full_*` pickles read by the agent live.
TRAJ_PDM_ROOT = os.environ.get('NAVSIM_TRAJPDM_ROOT')
|
|
|
from typing import Dict, List |
|
|
|
# Optional import: the module keeps loading when the positional encoding cannot be imported.
# NOTE(review): presumably imported for a registration side effect -- confirm.
try:
    from navsim.agents.utils.positional_encoding import SinePositionalEncoding3D
except Exception:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt / SystemExit still
    # propagate; kept broader than ImportError because the import may fail for reasons
    # other than a missing module and the original treated every failure as non-fatal.
    print('sine pe not registered')
|
|
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn.functional as F |
|
from navsim.agents.abstract_agent import AbstractAgent |
|
|
|
|
|
def hydra_nodet_beta_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the combined imitation + PDM-score loss of the Hydra agent.

    :param targets: dictionary of name tensor pairings; only ``"trajectory"`` is read
    :param predictions: dictionary of name tensor pairings; reads ``"noc"``, ``"da"``, ``"ttc"``,
        ``"comfort"``, ``"progress"``, ``"imi"`` and ``"trajectory_vocab"``
    :param config: config providing ``sigma``, ``trajectory_imi_weight`` and the
        ``trajectory_pdm_weight`` dict (keys ``noc``, ``da``, ``ttc``, ``progress``, ``comfort``)
    :param vocab_pdm_score: mapping with keys ``noc``, ``da``, ``ttc``, ``comfort`` and ``progress``
        holding the target score tensor for the matching prediction head
    :return: tuple ``(loss, loss_dict)`` -- the summed weighted loss and a dict of the
        individual weighted terms
    """
    noc, da, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                       predictions['ttc'],
                                       predictions['comfort'], predictions['progress'])
    imi = predictions['imi']

    # Each target is cast to the dtype of the head it supervises (the original cast several
    # of them to `da.dtype`, which only works while every head shares one dtype).
    # `binary_cross_entropy` expects probabilities, not logits, as its first argument.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(ttc.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(comfort.dtype))
    # NOTE(review): `three_to_two_classes` is defined in vadv2_loss; presumably it collapses a
    # three-valued score into two classes -- confirm against that module.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(noc.dtype)))
    progress_loss = F.l1_loss(progress, vocab_pdm_score['progress'].to(progress.dtype))

    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]

    # Eight indices 4, 9, ..., 39 along dim 1 of the vocabulary.
    # NOTE(review): assumes vocab is (V, T>=40, D) and target_traj is (B, 8, D) -- confirm.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting (1, V, 8, D) against (B, 1, 8, D) yields the same (B, V, 8, D) result as
    # explicitly repeating the vocabulary B times, without materialising the copies.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # Soft targets: softmax over the vocabulary of the negative scaled squared distance.
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    imi_loss_final = config.trajectory_imi_weight * imi_loss

    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss

    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final
    }
|
|
|
|
|
class HydraAgentPENoDetBeta(AbstractAgent):
    """
    Agent wrapping ``HydraModelPENoDetBeta``.

    Training supervises the model with ``hydra_nodet_beta_loss`` using per-token vocabulary
    scores that are loaded from a pickle file at construction time.
    """

    def __init__(
            self,
            config: HydraConfig,
            lr: float,
            checkpoint_path: str = None,
            pdm_split=None,
            metrics=None,
    ):
        """
        Build the model and load the precomputed vocabulary scores.

        :param config: agent config; ``trajectory_pdm_weight`` is (over)written on it in place
        :param lr: base learning rate handed to the optimizer
        :param checkpoint_path: checkpoint file read by :meth:`initialize`
        :param pdm_split: file stem of the score pickle inside
            ``{NAVSIM_TRAJPDM_ROOT}/vocab_score_full_{vocab_size}_navtrain/``
        :param metrics: iterable of score names looked up per token in :meth:`compute_loss`
        """
        super().__init__()
        # Per-metric loss weights consumed by `hydra_nodet_beta_loss`.
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': config.ttc_weight,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        self.vadv2_model = HydraModelPENoDetBeta(config)
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd
        new_pkl_dir = f'vocab_score_full_{self.vocab_size}_navtrain'
        # NOTE(security): unpickling runs arbitrary code from the file; only point
        # NAVSIM_TRAJPDM_ROOT at trusted data.
        # The context manager closes the handle, which the original bare `open()` leaked.
        with open(f'{TRAJ_PDM_ROOT}/{new_pkl_dir}/{pdm_split}.pkl', 'rb') as score_file:
            self.vocab_pdm_score_full = pickle.load(score_file)

    def name(self) -> str:
        """Inherited, see superclass."""

        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        # Load on CPU; every "agent." occurrence in the checkpoint keys is removed so they
        # match this module's own parameter names.
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))["state_dict"]
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        # All eight cameras with indices 0-3 and no lidar.
        # NOTE(review): the indices presumably select history frames -- confirm in SensorConfig.
        return SensorConfig(
            cam_f0=[0, 1, 2, 3],
            cam_l0=[0, 1, 2, 3],
            cam_l1=[0, 1, 2, 3],
            cam_l2=[0, 1, 2, 3],
            cam_r0=[0, 1, 2, 3],
            cam_r1=[0, 1, 2, 3],
            cam_r2=[0, 1, 2, 3],
            cam_b0=[0, 1, 2, 3],
            lidar_pc=[],
        )

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the single Hydra target builder configured with this agent's config."""
        return [HydraTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the single Hydra feature builder configured with this agent's config."""
        return [HydraFeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features``."""
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        """Run the wrapped model on ``features`` together with ``interpolated_traj``."""
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
            self,
            features: Dict[str, torch.Tensor],
            targets: Dict[str, torch.Tensor],
            predictions: Dict[str, torch.Tensor],
            tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Gather the precomputed scores for ``tokens`` and evaluate ``hydra_nodet_beta_loss``.

        :param features: unused here
        :param targets: dictionary of name tensor pairings forwarded to the loss
        :param predictions: model outputs; ``predictions['trajectory']`` fixes the device
            the score tensors are moved to
        :param tokens: keys into the loaded score table, one per batch sample; despite the
            ``None`` default it must be an iterable
        :return: whatever ``hydra_nodet_beta_loss`` returns, i.e. a ``(loss, loss_dict)``
            tuple (the annotation is kept as in the original for interface compatibility)
        """
        device = predictions['trajectory'].device
        scores = {}
        for k in self.metrics:
            # Add a leading batch axis to every per-token array, then join along it.
            tmp = [self.vocab_pdm_score_full[token][k][None] for token in tokens]
            scores[k] = torch.from_numpy(np.concatenate(tmp, axis=0)).to(device)
        return hydra_nodet_beta_loss(targets, predictions, self._config, scores)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """
        Build an Adam optimizer with two parameter groups.

        Parameters whose name contains ``'_backbone.image_encoder'`` get the learning rate
        ``lr * config.lr_mult_backbone`` and weight decay ``config.backbone_wd``; all other
        parameters use the optimizer defaults with learning rate ``lr``.
        """
        backbone_params_name = '_backbone.image_encoder'
        img_backbone_params = []
        default_params = []
        # Single pass over the parameters; order inside each group matches iteration order.
        for param_name, param in self.vadv2_model.named_parameters():
            if backbone_params_name in param_name:
                img_backbone_params.append(param)
            else:
                default_params.append(param)
        params_lr_dict = [
            {'params': default_params},
            {
                'params': img_backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd
            }
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """
        Return the checkpoint callback used during training.

        Keeps the 30 checkpoints with the lowest ``val/loss_epoch`` under
        ``{NAVSIM_EXP_ROOT}/{config.ckpt_path}/``.
        """
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]
|
|