File size: 3,460 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Dict, Tuple, List

import pytorch_lightning as pl
import torch
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from torch import Tensor

from navsim.agents.abstract_agent import AbstractAgent
from navsim.agents.vadv2.vadv2_agent import Vadv2Agent
from navsim.common.dataclasses import Trajectory


class AgentLightningModuleMap(pl.LightningModule):
    """Lightning wrapper that trains, validates and runs inference for a navsim agent.

    Training and validation delegate the forward pass and the loss to the wrapped
    agent and log every entry of the agent's loss dict. Prediction returns, per
    scenario token, the predicted trajectory plus the per-sample log-scores emitted
    by the agent's prediction heads.
    """

    # Time horizon passed to TrajectorySampling for every predicted trajectory.
    _TIME_HORIZON = 4
    # Pose count that is sampled at a 0.1 interval; every other count uses 0.5.
    _DENSE_NUM_POSES = 40
    _DENSE_INTERVAL = 0.1
    _DEFAULT_INTERVAL = 0.5

    def __init__(self, agent: AbstractAgent):
        """
        :param agent: agent whose forward / compute_loss / get_optimizers methods
            are driven by the Lightning hooks below.
        """
        super().__init__()
        self.agent = agent

    def _step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        logging_prefix: str,
    ):
        """Run one forward + loss pass and log all loss terms under ``logging_prefix``.

        :param batch: ``(features, targets)`` or ``(features, targets, tokens)``;
            any elements after ``targets`` are ignored here.
        :param logging_prefix: metric namespace, "train" or "val" in this class.
        :return: the loss returned by ``agent.compute_loss``.
        """
        # Index rather than unpack so both 2- and 3-element batches are accepted;
        # plain `features, targets = batch` raises on the 3-tuple annotated above.
        features, targets = batch[0], batch[1]
        if logging_prefix in ("train", "val") and isinstance(self.agent, Vadv2Agent):
            # Vadv2 agents receive targets["interpolated_traj"] as an extra input.
            prediction = self.agent.forward_train(features, targets["interpolated_traj"])
        else:
            prediction = self.agent.forward(features)

        loss, loss_dict = self.agent.compute_loss(features, targets, prediction)

        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for name, value in loss_dict.items():
            self.log(f"{logging_prefix}/{name}", value, **log_kwargs)
        self.log(f"{logging_prefix}/loss", loss, **log_kwargs)
        return loss

    def training_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning training hook; returns the training loss (see ``_step``)."""
        return self._step(batch, "train")

    def validation_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Lightning validation hook; returns the validation loss (see ``_step``)."""
        return self._step(batch, "val")

    def configure_optimizers(self):
        """Delegate optimizer (and scheduler) construction to the agent."""
        return self.agent.get_optimizers()

    def predict_step(
        self,
        batch: Tuple[Dict[str, Tensor], Dict[str, Tensor], List[str]],
        batch_idx: int,
    ):
        """Predict trajectories and log-scores for a batch, keyed by scenario token.

        :param batch: ``(features, targets, tokens)``; ``targets`` is not used.
        :param batch_idx: unused; required by the Lightning hook signature.
        :return: dict mapping each token to a dict holding its ``Trajectory`` under
            ``'trajectory'`` and its log-scores (numpy arrays) under ``'imi'``,
            ``'noc'``, ``'da'``, ``'ttc'``, ``'comfort'`` and ``'progress'``.
        """
        features, _, tokens = batch
        self.agent.eval()
        with torch.no_grad():
            predictions = self.agent.forward(features)
            poses = predictions["trajectory"].cpu().numpy()

            # log_softmax is the numerically stable form of softmax(...).log():
            # for finite logits it stays finite instead of underflowing to -inf.
            imis = predictions["imi"].log_softmax(-1).cpu().numpy()
            # NOTE(review): these heads are log-transformed directly, which assumes
            # they are already probabilities (e.g. sigmoid outputs) — TODO confirm.
            nocs = predictions["noc"].log().cpu().numpy()
            das = predictions["da"].log().cpu().numpy()
            ttcs = predictions["ttc"].log().cpu().numpy()
            comforts = predictions["comfort"].log().cpu().numpy()
            progresses = predictions["progress"].log().cpu().numpy()

        # poses is indexed per batch element below, so shape[1] is the pose count.
        if poses.shape[1] == self._DENSE_NUM_POSES:
            interval_length = self._DENSE_INTERVAL
        else:
            interval_length = self._DEFAULT_INTERVAL

        return {
            token: {
                "trajectory": Trajectory(
                    pose,
                    TrajectorySampling(time_horizon=self._TIME_HORIZON, interval_length=interval_length),
                ),
                "imi": imi,
                "noc": noc,
                "da": da,
                "ttc": ttc,
                "comfort": comfort,
                "progress": progress,
            }
            for pose, imi, noc, da, ttc, comfort, progress, token in zip(
                poses, imis, nocs, das, ttcs, comforts, progresses, tokens
            )
        }