# navsim_ours/det_map/map/map_agent.py
# Uploaded to the Hugging Face Hub by lkllkl
# ("Upload folder using huggingface_hub", commit da2e2ac, 6.26 kB).
from __future__ import annotations
from typing import Any, List, Dict
import torch
import torch.optim as optim
import copy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import torch.nn as nn
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
import torch.nn.functional as F
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.det.dal.mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version
from typing import Any, List, Dict
import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from det_map.map.map_target import MapTargetBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
import torch.optim as optim
# Star-import the det_map.map submodules; presumably this is done for their
# side effects (e.g. registering components) — TODO confirm.
try:
    from det_map.map.assigners import *  # noqa: F401,F403
    from det_map.map.dense_heads import *  # noqa: F401,F403
    from det_map.map.losses import *  # noqa: F401,F403
    from det_map.map.modules import *  # noqa: F401,F403
except Exception as err:
    # Keep the original failure as __cause__ instead of replacing it with an
    # empty, message-less Exception.
    raise ImportError(
        "Failed to import det_map.map submodules (assigners, dense_heads, "
        "losses, modules) required by MapAgent"
    ) from err
class MapAgent(AbstractAgent):
    """NavSim agent wrapping a map-prediction model.

    ``forward`` delegates to ``self.model``. Features come from
    ``LiDARCameraFeatureBuilder`` and targets from ``MapTargetBuilder``.
    ``compute_loss`` combines the one-to-one and one-to-many losses of
    ``self.model.pts_bbox_head``.
    """

    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        """
        :param model: network module; called as ``model(features)`` and must
            expose ``pts_bbox_head`` (used by ``compute_loss``).
        :param pipelines: data pipelines passed to ``LiDARCameraFeatureBuilder``.
        :param lr: base learning rate used by ``get_optimizers``.
        :param checkpoint_path: checkpoint file loaded by ``initialize``.
        :param kwargs: accepted and ignored.
        """
        super().__init__()
        # TODO: eval everything
        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads ``["state_dict"]`` from ``self._checkpoint_path``, strips a
        leading ``"agent."`` from every key, and loads the result into this
        module.

        :raises ValueError: if the agent was built without a checkpoint path.
        """
        if self._checkpoint_path is None:
            raise ValueError(f"{self.name()} requires a checkpoint_path to initialize.")
        # Load onto CPU so checkpoints saved from GPU also load on CPU-only
        # hosts; load_state_dict copies values onto the parameters' devices.
        checkpoint = torch.load(self._checkpoint_path, map_location=torch.device("cpu"))
        state_dict: Dict[str, Any] = checkpoint["state_dict"]
        prefix = "agent."
        # Strip only a *leading* prefix: str.replace would also rewrite any
        # "agent." that occurs inside a key (e.g. "...reagent.weight").
        self.load_state_dict(
            {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass. Requests ``SensorConfig.build_all_sensors(True)``."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Inherited, see superclass. Uses a single ``MapTargetBuilder``."""
        return [MapTargetBuilder()]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Inherited, see superclass. Uses ``LiDARCameraFeatureBuilder`` with this agent's pipelines."""
        return [LiDARCameraFeatureBuilder(self.pipelines)]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features``."""
        return self.model(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the one-to-one and one-to-many map losses.

        :param features: unused; part of the agent interface.
        :param targets: must contain ``"gt_bboxes_3d"`` and ``"gt_labels_3d"``,
            each iterable per sample.
        :param predictions: head outputs; must contain ``"one2many_outs"``.
        :param tokens: unused.
        :return: ``(total_loss, loss_dict)``, where ``total_loss`` is the sum
            of all values in ``loss_dict``. One-to-many entries are suffixed
            with ``"_one2many"`` and scaled by ``pts_bbox_head.lambda_one2many``.
        """
        head = self.model.pts_bbox_head
        gt_bboxes_3d = targets["gt_bboxes_3d"]
        gt_labels_3d = targets["gt_labels_3d"]
        # NOTE: segmentation (gt_seg_mask / gt_pv_seg_mask) and depth
        # supervision are disabled in this agent; the head receives None.
        gt_seg_mask = None
        gt_pv_seg_mask = None

        # One-to-one loss on the primary predictions.
        losses: Dict[str, torch.Tensor] = dict(
            head.loss(gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, predictions, img_metas=None)
        )

        # One-to-many branch: repeat every GT instance k_one2many times
        # (presumably H-DETR-style hybrid matching — TODO confirm).
        k_one2many = head.k_one2many
        # Deep-copy so the GT objects used for the one-to-one loss stay intact.
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        for sample_bboxes in multi_gt_bboxes_3d:
            sample_bboxes.instance_list = sample_bboxes.instance_list * k_one2many
            sample_bboxes.instance_labels = sample_bboxes.instance_labels * k_one2many
        multi_gt_labels_3d = [sample_labels.repeat(k_one2many) for sample_labels in gt_labels_3d]

        loss_dict_one2many = head.loss(
            multi_gt_bboxes_3d,
            multi_gt_labels_3d,
            gt_seg_mask,
            gt_pv_seg_mask,
            predictions["one2many_outs"],
            img_metas=None,
        )
        lambda_one2many = head.lambda_one2many
        for key, value in loss_dict_one2many.items():
            loss_name = key + "_one2many"
            weighted = value * lambda_one2many
            # Out-of-place add: avoids mutating a tensor that is part of the graph.
            losses[loss_name] = losses[loss_name] + weighted if loss_name in losses else weighted

        loss = sum(losses.values())
        return loss, losses

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Inherited, see superclass. Returns ``{"optimizer": AdamW}`` from ``initialize_optimizer``."""
        return {"optimizer": initialize_optimizer(self.model, self._lr)}
def initialize_optimizer(model, lr, *, backbone_lr_scale=0.1, weight_decay=0.01, backbone_key="img_backbone"):
    """Build an AdamW optimizer with a reduced learning rate for the image backbone.

    A single pass over ``model.named_parameters()`` splits the parameters into
    two groups:

    * group 0: parameters whose name contains ``backbone_key``, with learning
      rate ``lr * backbone_lr_scale``;
    * group 1: all other parameters, with learning rate ``lr``.

    Both groups use ``weight_decay``.

    :param model: module whose parameters are optimized.
    :param lr: base learning rate.
    :param backbone_lr_scale: multiplier on ``lr`` for backbone parameters (default 0.1).
    :param weight_decay: AdamW weight decay (default 0.01).
    :param backbone_key: substring identifying backbone parameter names
        (default ``"img_backbone"``).
    :return: the configured ``optim.AdamW`` instance.
    """
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        (backbone_params if backbone_key in name else other_params).append(param)
    return optim.AdamW(
        [
            {"params": backbone_params, "lr": lr * backbone_lr_scale},
            {"params": other_params, "lr": lr},
        ],
        weight_decay=weight_decay,
    )