navsim_ours / navsim /agents /hydra /hydra_agent_pe_nodet.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
7.97 kB
import os
import pickle
from typing import Any, Union
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from navsim.agents.hydra.hydra_config import HydraConfig
from navsim.agents.hydra.hydra_features import HydraFeatureBuilder, HydraTargetBuilder
from navsim.agents.hydra.hydra_model_pe_nodet import HydraModelPENoDet
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import three_to_two_classes
from navsim.common.dataclasses import SensorConfig
from navsim.planning.training.abstract_feature_target_builder import (
AbstractFeatureBuilder,
AbstractTargetBuilder,
)
# Root directories resolved from the environment (None when the variable is unset).
# TRAJ_PDM_ROOT is the directory holding the pre-computed vocabulary PDM-score pickles.
DEVKIT_ROOT = os.environ.get('NAVSIM_DEVKIT_ROOT')
TRAJ_PDM_ROOT = os.environ.get('NAVSIM_TRAJPDM_ROOT')
from typing import Dict, List
# Optional import: if the 3D sine positional encoding cannot be imported (presumably the
# import also registers it somewhere, per the message below), keep the module importable.
# ``except Exception`` instead of a bare ``except:`` so KeyboardInterrupt / SystemExit
# are never swallowed.
try:
    from navsim.agents.utils.positional_encoding import SinePositionalEncoding3D
except Exception:
    print('sine pe not registered')
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from navsim.agents.abstract_agent import AbstractAgent
def hydra_nodet_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Compute the combined imitation + PDM sub-score loss of the Hydra (no-detection) agent.

    :param targets: dictionary of name -> tensor pairings; ``targets["trajectory"]`` is the
        ground-truth ego trajectory.
    :param predictions: dictionary of name -> tensor pairings produced by the model. Must contain
        ``'noc'``, ``'da'``, ``'ttc'``, ``'comfort'``, ``'progress'`` (values in [0, 1], fed to
        binary cross-entropy), ``'imi'`` (unnormalized scores over the trajectory vocabulary, fed
        to cross-entropy) and ``'trajectory_vocab'``.
    :param config: agent config providing ``sigma``, ``trajectory_imi_weight`` and the
        ``trajectory_pdm_weight`` dict keyed by ``'noc'``, ``'da'``, ``'ttc'``, ``'progress'``,
        ``'comfort'``.
    :param vocab_pdm_score: dictionary of sub-score name -> target tensor for every vocabulary
        entry, with the same keys as the PDM predictions above.
    :return: tuple ``(loss, loss_dict)``: the weighted total loss and each weighted term.
    """
    noc, da, ttc, comfort, progress = (
        predictions['noc'], predictions['da'], predictions['ttc'],
        predictions['comfort'], predictions['progress'],
    )
    imi = predictions['imi']

    # Binary (2-class) PDM sub-score heads; each target is cast to its own prediction's dtype.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(ttc.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(comfort.dtype))
    # The NOC target is mapped onto a binary target by `three_to_two_classes` before BCE.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(noc.dtype)))
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))

    # Assumed (V, T, 3) with T >= 40 -- TODO confirm against the vocabulary file.
    vocab = predictions["trajectory_vocab"]
    # Assumed (B, 8, 3): 8 poses over 4 s, i.e. one pose every 0.5 s (2 Hz).
    target_traj = targets["trajectory"]
    # Vocabulary time indices 4, 9, ..., 39: every 5th step, aligned with the 8 target poses.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # (1, V, 8, 3) - (B, 1, 8, 3) broadcasts to (B, V, 8, 3); no need to materialize a
    # per-sample copy of the vocabulary with .repeat(B, ...).
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # Soft imitation target: softmax over the vocabulary of the negative scaled squared distance.
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss

    loss = (
        imi_loss_final
        + noc_loss_final
        + da_loss_final
        + ttc_loss_final
        + progress_loss_final
        + comfort_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final
    }
class HydraAgentPENoDet(AbstractAgent):
    """
    Hydra planning agent wrapping :class:`HydraModelPENoDet`.

    The ``PE`` / ``NoDet`` suffixes presumably denote positional encoding and the absence of a
    detection head. The agent is trained with :func:`hydra_nodet_loss` against pre-computed
    per-vocabulary PDM sub-scores loaded from ``TRAJ_PDM_ROOT``.
    """

    def __init__(
        self,
        config: HydraConfig,
        lr: float,
        checkpoint_path: str = None,
        pdm_split=None,
        metrics=None,
    ):
        """
        :param config: Hydra agent config. NOTE: its ``trajectory_pdm_weight`` attribute is
            overwritten in place with the per-sub-score loss weights set below.
        :param lr: base learning rate for the optimizer.
        :param checkpoint_path: checkpoint file loaded by :meth:`initialize`.
        :param pdm_split: split name; scores are read from ``{pdm_split}.pkl``.
        :param metrics: iterable of PDM sub-score names gathered in :meth:`compute_loss`.
        """
        super().__init__()
        config.trajectory_pdm_weight = {
            'noc': 3.0,
            'da': 3.0,
            'ttc': config.ttc_weight,
            'progress': config.progress_weight,
            'comfort': 1.0,
        }
        self._config = config
        self._lr = lr
        self.metrics = metrics
        self._checkpoint_path = checkpoint_path
        # Attribute name is part of the state_dict keys; do not rename.
        self.vadv2_model = HydraModelPENoDet(config)
        self.vocab_size = config.vocab_size
        self.backbone_wd = config.backbone_wd
        new_pkl_dir = f'vocab_score_full_{self.vocab_size}_navtrain'
        # Context manager closes the handle deterministically (the old open() was never closed).
        # NOTE: pickle can execute arbitrary code -- only load trusted, locally generated files.
        with open(f'{TRAJ_PDM_ROOT}/{new_pkl_dir}/{pdm_split}.pkl', 'rb') as f:
            self.vocab_pdm_score_full = pickle.load(f)

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass. Loads ``self._checkpoint_path`` (mapped to CPU)."""
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))["state_dict"]
        # Strip the "agent." prefix presumably added by the training wrapper.
        # NOTE(review): str.replace removes every occurrence, not just a leading prefix.
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass. Uses frames 0-3 of all eight cameras and no lidar."""
        return SensorConfig(
            cam_f0=[0, 1, 2, 3],
            cam_l0=[0, 1, 2, 3],
            cam_l1=[0, 1, 2, 3],
            cam_l2=[0, 1, 2, 3],
            cam_r0=[0, 1, 2, 3],
            cam_r1=[0, 1, 2, 3],
            cam_r2=[0, 1, 2, 3],
            cam_b0=[0, 1, 2, 3],
            lidar_pc=[],
        )

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Inherited, see superclass."""
        return [HydraTargetBuilder(config=self._config)]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Inherited, see superclass."""
        return [HydraFeatureBuilder(config=self._config)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features``."""
        return self.vadv2_model(features)

    def forward_train(self, features, interpolated_traj):
        """Run the wrapped model with an additional ``interpolated_traj`` input."""
        return self.vadv2_model(features, interpolated_traj)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Gather the pre-computed PDM sub-scores for the batch and compute the training loss.

        :param features: unused; kept for the superclass interface.
        :param targets: target tensors forwarded to :func:`hydra_nodet_loss`.
        :param predictions: model outputs; ``predictions['trajectory'].device`` decides where
            the gathered scores are placed.
        :param tokens: per-sample keys into the loaded score pickle, in batch order.
        :return: the ``(loss, loss_dict)`` tuple from :func:`hydra_nodet_loss`.
        """
        device = predictions['trajectory'].device
        scores = {}
        for k in self.metrics:
            # Add a leading axis per sample, then concatenate into a batch-first array.
            tmp = [self.vocab_pdm_score_full[token][k][None] for token in tokens]
            scores[k] = torch.from_numpy(np.concatenate(tmp, axis=0)).to(device)
        return hydra_nodet_loss(targets, predictions, self._config, scores)

    def get_optimizers(self) -> Union[Optimizer, Dict[str, Union[Optimizer, LRScheduler]]]:
        """
        Build an Adam optimizer with a separate parameter group for the image backbone.

        Parameters whose name contains ``'_backbone.image_encoder'`` use
        ``lr * config.lr_mult_backbone`` and ``weight_decay=config.backbone_wd``. All other
        parameters use the base ``lr`` and Adam's default weight decay.
        """
        backbone_params_name = '_backbone.image_encoder'
        img_backbone_params, default_params = [], []
        # Single pass over named_parameters(); registration order is preserved within each group.
        for param_name, param in self.vadv2_model.named_parameters():
            if backbone_params_name in param_name:
                img_backbone_params.append(param)
            else:
                default_params.append(param)
        params_lr_dict = [
            {'params': default_params},
            {
                'params': img_backbone_params,
                'lr': self._lr * self._config.lr_mult_backbone,
                'weight_decay': self.backbone_wd
            }
        ]
        return torch.optim.Adam(params_lr_dict, lr=self._lr)

    def get_training_callbacks(self) -> List[pl.Callback]:
        """Keep the 30 best checkpoints by ``val/loss_epoch`` under ``$NAVSIM_EXP_ROOT/<ckpt_path>/``."""
        return [
            ModelCheckpoint(
                save_top_k=30,
                monitor="val/loss_epoch",
                mode="min",
                dirpath=f"{os.environ.get('NAVSIM_EXP_ROOT')}/{self._config.ckpt_path}/",
                filename="{epoch:02d}-{step:04d}",
            )
        ]