import math
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from diffusers import DDIMScheduler
from navsim.agents.dm.backbone import DMBackbone
from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.utils import VerletStandardizer
from navsim.agents.transfuser.transfuser_model import AgentHead
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
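# Illustrative shape check (not part of the original module): with dim=256, a
# (B,)-shaped tensor of timesteps maps to a (B, 256) embedding whose first half
# holds sine terms and second half cosine terms at geometrically spaced
# frequencies, e.g.:
#   SinusoidalPosEmb(256)(torch.tensor([0.1, 0.5])).shape  # torch.Size([2, 256])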
class DMModel(nn.Module):
def __init__(self, config: DMConfig):
super().__init__()
        # a single query group: one query per predicted agent bounding box
        self._query_splits = [
            config.num_bounding_boxes,
        ]
self._config = config
        assert config.backbone_type in ['vit', 'intern', 'vov', 'resnet', 'eva', 'moe', 'moe_ult32', 'swin']
        if config.backbone_type in ('intern', 'vov', 'swin', 'vit'):
            self._backbone = DMBackbone(config)
        else:
            # 'eva', 'moe', 'moe_ult32', and 'resnet' pass the assert but are not
            # implemented here; fail loudly instead of leaving self._backbone unset
            raise ValueError(f'{config.backbone_type} not supported')
img_num = 2 if config.use_back_view else 1
        self._keyval_embedding = nn.Embedding(
            config.img_vert_anchors * config.img_horz_anchors * img_num, config.tf_d_model
        )  # one learned positional embedding per image token (front view, plus back view if enabled)
self._query_embedding = nn.Embedding(sum(self._query_splits), config.tf_d_model)
        # 1x1 conv to project the backbone feature channels down to the transformer width
        self.downscale_layer = nn.Conv2d(self._backbone.img_feat_c, config.tf_d_model, kernel_size=1)
        # per-frame status vector: driving command (4) + velocity (2) + acceleration (2)
        self._status_encoding = nn.Linear((4 + 2 + 2) * config.num_ego_status, config.tf_d_model)
tf_decoder_layer = nn.TransformerDecoderLayer(
d_model=config.tf_d_model,
nhead=config.tf_num_head,
dim_feedforward=config.tf_d_ffn,
dropout=config.tf_dropout,
batch_first=True,
)
self._tf_decoder = nn.TransformerDecoder(tf_decoder_layer, config.tf_num_layers)
self._agent_head = AgentHead(
num_agents=config.num_bounding_boxes,
d_ffn=config.tf_d_ffn,
d_model=config.tf_d_model,
)
self._trajectory_head = DMTrajHead(
num_poses=config.trajectory_sampling.num_poses,
d_ffn=config.tf_d_ffn,
d_model=config.tf_d_model,
nhead=config.vadv2_head_nhead,
nlayers=config.vadv2_head_nlayers,
vocab_path=config.vocab_path,
config=config
)
def img_feat_blc(self, camera_feature):
img_features = self._backbone(camera_feature)
img_features = self.downscale_layer(img_features).flatten(-2, -1)
img_features = img_features.permute(0, 2, 1)
return img_features
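    # img_feat_blc: backbone feature map (B, C_img, H', W') -> projected and
    # flattened token sequence in "BLC" layout (B, H'*W', tf_d_model).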
def forward(self, features: Dict[str, torch.Tensor],
interpolated_traj=None) -> Dict[str, torch.Tensor]:
camera_feature: torch.Tensor = features["camera_feature"]
status_feature: torch.Tensor = features["status_feature"]
if isinstance(camera_feature, list):
camera_feature = camera_feature[-1]
        # TODO: temporary fix (disabled override of the first four status dims):
        # status_feature[:, 0] = 0.0
        # status_feature[:, 1] = 1.0
        # status_feature[:, 2] = 0.0
        # status_feature[:, 3] = 0.0
batch_size = status_feature.shape[0]
img_features = self.img_feat_blc(camera_feature)
if self._config.use_back_view:
img_features_back = self.img_feat_blc(features["camera_feature_back"])
img_features = torch.cat([img_features, img_features_back], 1)
        if self._config.num_ego_status == 1 and status_feature.shape[1] == 32:
            # input stacks four 8-dim status frames but only one is configured;
            # keep the first frame's 8 dims
            status_encoding = self._status_encoding(status_feature[:, :8])
        else:
            status_encoding = self._status_encoding(status_feature)
        # image tokens plus learned positional embeddings
        keyval = img_features + self._keyval_embedding.weight[None, ...]
query = self._query_embedding.weight[None, ...].repeat(batch_size, 1, 1)
agents_query = self._tf_decoder(query, keyval)
output: Dict[str, torch.Tensor] = {}
trajectory = self._trajectory_head(keyval, status_encoding, features['history_waypoints'])
output.update(trajectory)
agents = self._agent_head(agents_query)
output.update(agents)
return output
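# Minimal usage sketch (hypothetical shapes; assumes a populated DMConfig and
# the feature keys consumed by DMModel.forward):
#   model = DMModel(config).eval()
#   out = model({
#       "camera_feature": torch.randn(1, 3, 256, 1024),
#       "status_feature": torch.randn(1, 8),
#       "history_waypoints": torch.randn(1, 4, 3),
#   })
#   out["trajectory"]  # (1, num_poses, 3) when config.is_training is False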
class DMTrajHead(nn.Module):
def __init__(self, num_poses: int, d_ffn: int, d_model: int, vocab_path: str,
nhead: int, nlayers: int, config: DMConfig = None
):
super().__init__()
self.d_model = d_model
self.config = config
self._num_poses = num_poses
self.transformer = nn.TransformerDecoder(
nn.TransformerDecoderLayer(
d_model, nhead, d_ffn,
dropout=0.0, batch_first=True
), nlayers
)
        # frozen trajectory vocabulary loaded from disk (4096 x num_poses x 3 here)
        self.vocab = nn.Parameter(
            torch.from_numpy(np.load(vocab_path)),
            requires_grad=False
        )
self.H = config.trajectory_sampling.num_poses
self.T = config.T
self.standardizer = VerletStandardizer()
self.decoder_mlp = nn.Sequential(
nn.Linear(self.d_model, self.d_model),
nn.ReLU(),
nn.Linear(self.d_model, self.d_model),
nn.ReLU(),
nn.Linear(self.d_model, self.H * 3)
)
self.encoder_mlp = nn.Sequential(
nn.Linear(self.H * 3, self.d_model),
nn.ReLU(),
nn.Linear(self.d_model, self.d_model),
)
self.sigma_encoder = nn.Sequential(
SinusoidalPosEmb(self.d_model),
)
self.scheduler = DDIMScheduler(
num_train_timesteps=self.T,
beta_schedule='scaled_linear',
prediction_type='epsilon',
)
self.scheduler.set_timesteps(self.T)
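        # prediction_type='epsilon' means the network predicts the added noise;
        # set_timesteps(self.T) schedules all T steps at inference (no step skipping,
        # since num_inference_steps equals num_train_timesteps).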
def denoise(self, ego_trajectory, env_features, status_encoding, timesteps):
B = ego_trajectory.shape[0]
ego_trajectory = ego_trajectory.reshape(B, -1).to(torch.float32)
sigma = timesteps.reshape(-1, 1)
if sigma.numel() == 1:
sigma = sigma.repeat(B, 1)
sigma = sigma.float() / self.T
sigma_embeddings = self.sigma_encoder(sigma).squeeze(1)
ego_emb = self.encoder_mlp(ego_trajectory) + status_encoding + sigma_embeddings
ego_attn = self.transformer(ego_emb[:, None], env_features)
out = self.decoder_mlp(ego_attn).reshape(B, -1)
return out
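    # denoise() flattens the (noisy) trajectory, embeds the normalized timestep
    # sigma = t / T with SinusoidalPosEmb, sums it with the status encoding, and
    # cross-attends to the image tokens; the output matches the (B, H*3) input shape.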
def forward(self, bev_feature, status_encoding, history_waypoints) -> Dict[str, torch.Tensor]:
        # bev_feature: (B, num_tokens, C); self.vocab: (4096, num_poses, 3)
B = bev_feature.shape[0]
result = {}
if not self.config.is_training:
ego_trajectory = torch.randn((B, self.H * 3),
device=bev_feature.device)
timesteps = self.scheduler.timesteps
            # reverse diffusion: start from Gaussian noise and step back through
            # all T timesteps. Note the epsilon prediction is accumulated across
            # steps; the running sum is what scheduler.step receives each iteration.
            residual = torch.zeros_like(ego_trajectory)
            for t in timesteps:
with torch.no_grad():
residual += self.denoise(
ego_trajectory,
bev_feature,
status_encoding,
t.to(ego_trajectory.device)
)
out = self.scheduler.step(residual, t, ego_trajectory)
ego_trajectory = out.prev_sample
ego_trajectory = self.standardizer.untransform_features(ego_trajectory, history_waypoints)
result["trajectory"] = ego_trajectory.reshape(B, self.H, 3)
            # dummy scores, uniform over the 4096-entry vocabulary (the scoring
            # heads are not evaluated by this model)
            result['imi'], result['noc'], result['da'], result['ttc'], result['comfort'], result['progress'] = (
                torch.ones((B, 4096)) for _ in range(6)
            )
        # pass the conditioning inputs through with the outputs (consumed
        # externally, e.g. when computing the diffusion training loss)
        result['history_waypoints'] = history_waypoints
        result['env_features'] = bev_feature
        result['status_encoding'] = status_encoding
return result