|
from typing import Dict |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
|
|
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig |
|
from navsim.common.dataclasses import AgentInput, Scene |
|
from navsim.common.dataclasses import Cameras |
|
from navsim.planning.training.abstract_feature_target_builder import ( |
|
AbstractFeatureBuilder, |
|
AbstractTargetBuilder, |
|
) |
|
|
|
|
|
def cat_flr_imgs(camera: Cameras, config: HydraDreamerConfig):
    """
    Stitch the left, front and right camera images of a single frame into one tensor.

    The side views (cam_l0, cam_r0) are trimmed by 28 entries on both ends of axis 0
    and by 416 entries on both ends of axis 1; the front view (cam_f0) is trimmed on
    axis 0 only. The three crops are joined along axis 1 in left, front, right order
    (the width axis, assuming HWC images — TODO confirm), resized to
    (config.camera_width, config.camera_height) and converted with
    torchvision's ToTensor.

    :param camera: camera dataclass of one frame; reads cam_l0, cam_f0 and cam_r0
    :param config: agent configuration providing camera_width and camera_height
    :return: stitched image as torch tensor
    """
    rows = slice(28, -28)
    side_cols = slice(416, -416)

    left = camera.cam_l0.image[rows, side_cols]
    front = camera.cam_f0.image[rows]
    right = camera.cam_r0.image[rows, side_cols]

    panorama = np.concatenate((left, front, right), axis=1)
    # cv2.resize takes its target size as (width, height).
    target_size = (config.camera_width, config.camera_height)
    return transforms.ToTensor()(cv2.resize(panorama, target_size))
|
|
|
|
|
class HydraDreamerWmFeatureBuilder(AbstractFeatureBuilder):
    """Builds world-model input features: stacked ego status and three stitched camera frames."""

    # Feature keys for the camera frames taken from agent_input.cameras[:3], in that order.
    _IMG_KEYS = ("img_3", "img_2", "img_1")

    def __init__(self, config: HydraDreamerConfig):
        """
        Initialize the feature builder.

        :param config: agent configuration; num_ego_status is read here and the
            camera size fields are forwarded to cat_flr_imgs
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_feature"

    def _get_camera_feature(self, agent_input: AgentInput) -> list:
        """
        Extract stitched camera images for the first three camera frames of the AgentInput.

        :param agent_input: input dataclass
        :return: list of stitched image tensors (see cat_flr_imgs), one per entry of
            agent_input.cameras[:3], in the same order; shorter if fewer frames exist
        """
        return [cat_flr_imgs(camera, self._config) for camera in agent_input.cameras[:3]]

    def compute_features(self, agent_input: AgentInput) -> Dict[str, torch.Tensor]:
        """
        Inherited, see superclass.

        Produces "status_feature" (concatenated ego-status tensors) and the camera
        features "img_3", "img_2", "img_1".

        :raises IndexError: if agent_input provides fewer than three camera frames
        """
        features = {}

        # Walk ego_statuses from the last entry backward, num_ego_status entries in total;
        # each entry contributes driving_command, ego_velocity, ego_acceleration in that order.
        ego_statuses = agent_input.ego_statuses
        status_parts = []
        for i in range(self._config.num_ego_status):
            status = ego_statuses[-(i + 1)]
            status_parts.extend(
                torch.tensor(value, dtype=torch.float32)
                for value in (status.driving_command, status.ego_velocity, status.ego_acceleration)
            )
        features["status_feature"] = torch.cat(status_parts)

        imgs = self._get_camera_feature(agent_input)
        if len(imgs) < len(self._IMG_KEYS):
            raise IndexError(
                f"HydraDreamerWmFeatureBuilder needs at least {len(self._IMG_KEYS)} camera frames, "
                f"got {len(imgs)}"
            )
        # NOTE(review): inputs come from cameras[:3] while the target builder uses cameras[-1];
        # if AgentInput holds only 3 frames, "img_1" is the same frame as the target — confirm
        # the configured history length.
        for key, img in zip(self._IMG_KEYS, imgs):
            features[key] = img

        return features
|
|
|
|
|
class HydraDreamerWmTargetBuilder(AbstractTargetBuilder):
    """Builds the world-model target: the stitched camera image of the last camera frame."""

    def __init__(self, config: HydraDreamerConfig):
        """
        Initialize the target builder.

        :param config: agent configuration, forwarded to cat_flr_imgs
        """
        super().__init__()
        self._config = config

    def get_unique_name(self) -> str:
        """Inherited, see superclass."""
        return "hydra_dreamer_wm_target"

    def compute_targets(self, scene: Scene) -> Dict[str, torch.Tensor]:
        """
        Inherited, see superclass.

        :return: dict with "img_gt", the stitched image of the last entry of
            scene.get_agent_input().cameras
        """
        last_frame_cameras = scene.get_agent_input().cameras[-1]
        target_image = cat_flr_imgs(last_frame_cameras, self._config)
        return {"img_gt": target_image}
|
|