# navsim_ours/det_map/agent_lightning.py
# Retrieved from a Hugging Face file view: uploaded by lkllkl ("Upload folder using huggingface_hub"),
# commit da2e2ac (verified), 3.46 kB.
from typing import Dict, Tuple, List
import pytorch_lightning as pl
import torch
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from torch import Tensor
from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
from navsim.common.dataclasses import Trajectory
class AgentLightningModuleMap(pl.LightningModule):
    """PyTorch Lightning wrapper that drives a navsim agent through train / val / predict.

    Training and validation delegate to the agent's forward pass and its
    ``compute_loss``; ``predict_step`` converts the agent's outputs into
    per-token numpy results plus a ``Trajectory``.
    """

    def __init__(
        self,
        agent: AbstractAgent,
    ):
        """
        :param agent: agent whose forward pass, loss and optimizers this module uses.
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        logging_prefix: str,
    ):
        """Run one forward pass, compute the loss and log every loss term.

        :param batch: ``(features, targets)``; any trailing elements (e.g. scene
            tokens) are ignored.
        :param logging_prefix: metric namespace, e.g. ``"train"`` or ``"val"``.
        :return: the loss returned by ``agent.compute_loss``.
        """
        # Only the first two entries are used; tolerate loaders that also yield tokens.
        features, targets = batch[0], batch[1]

        if logging_prefix in ('train', 'val') and isinstance(self.agent, Vadv2Agent):
            # Vadv2Agent has a dedicated training forward that also consumes
            # targets['interpolated_traj'].
            prediction = self.agent.forward_train(features, targets['interpolated_traj'])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction)
        for name, value in loss_dict.items():
            self.log(f"{logging_prefix}/{name}", value, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int
    ):
        """Lightning training hook; ``batch_idx`` is unused."""
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
        batch_idx: int
    ):
        """Lightning validation hook; ``batch_idx`` is unused."""
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Delegate optimizer (and scheduler) construction to the agent."""
        return self.agent.get_optimizers()

    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int
    ):
        """Run inference and package per-sample outputs keyed by token.

        :param batch: ``(features, targets, tokens)``; targets are unused.
        :param batch_idx: unused.
        :return: ``{token: {'trajectory': Trajectory, 'imi', 'noc', 'da', 'ttc',
            'comfort', 'progress': numpy log-scores}}``.
        """
        features, _targets, tokens = batch
        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()
            # log_softmax is the numerically stable form of softmax(...).log(): it
            # does not collapse to -inf when a softmax probability underflows to 0.
            scores = {"imi": predictions["imi"].log_softmax(-1).cpu().numpy()}
            # NOTE(review): these heads are log-transformed directly, so they are
            # presumably probabilities already — confirm; log(0) yields -inf.
            for head in ("noc", "da", "ttc", "comfort", "progress"):
                scores[head] = predictions[head].log().cpu().numpy()

        # With a 4 s horizon, 40 poses means 0.1 s spacing; any other pose count
        # is treated as 0.5 s spacing (8 poses).
        interval_length = 0.1 if poses.shape[1] == 40 else 0.5

        results = {}
        for i, (pose, token) in enumerate(zip(poses, tokens)):
            entry = {
                'trajectory': Trajectory(pose, TrajectorySampling(time_horizon=4, interval_length=interval_length)),
            }
            entry.update({head: values[i] for head, values in scores.items()})
            results[token] = entry
        return results

    # Debug aid (disabled): list parameters that received no gradient.
    # def on_after_backward(self) -> None:
    #     print("on_after_backward enter")
    #     for name, param in self.named_parameters():
    #         if param.grad is None:
    #             print(name)
    #     print("on_after_backward exit")