# model_definition.py
# ============================================================================
# Core imports
# ============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from functools import partial
from typing import Optional, List
from torch import Tensor
from torch.nn import MultiheadAttention

from timm.models.resnet import resnet50d, resnet26d, resnet18d, resnet101d
try:
    from timm.layers import trunc_normal_
except ImportError:
    from timm.models.layers import trunc_normal_

from huggingface_hub import hf_hub_download, HfApi, HfFolder

# Additional standard-library modules
import os
import json
import logging
import math
import copy
from pathlib import Path
from collections import deque, OrderedDict

# Data-processing libraries
import numpy as np
import cv2
from PIL import Image

# Optional libraries (can be disabled if unavailable)
try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, fall back to a no-op wrapper
    def tqdm(iterable, *args, **kwargs):
        return iterable


# ============================================================================
# Helper functions
# ============================================================================
def to_2tuple(x):
    """Convert a value to a 2-element tuple."""
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return (x, x)


# ============================================================================
# ============================================================================
class HybridEmbed(nn.Module):
    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
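
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): how HybridEmbed is
# typically used. It wraps a CNN backbone, projects the backbone's last feature
# map to `embed_dim` channels, and also returns a mean-pooled global token.
# The backbone choice and sizes below are assumptions for the demo only.
# ----------------------------------------------------------------------------
def _demo_hybrid_embed():
    backbone = resnet18d(pretrained=False, features_only=True, out_indices=[4])
    embed = HybridEmbed(backbone, img_size=224, patch_size=8, in_chans=3, embed_dim=256)
    dummy = torch.zeros(2, 3, 224, 224)
    tokens, global_token = embed(dummy)
    # tokens:       (2, 256, 7, 7) -> spatial feature map projected to embed_dim
    # global_token: (2, 256, 1)    -> mean-pooled summary of the feature map
    return tokens.shape, global_token.shape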
class HyperDimensionalPositionalEncoding(nn.Module):
    """
    [GCPE v1.1 - Professional & Corrected Implementation]
    A novel positional encoding scheme based on geometric centrality.

    This class is designed as a drop-in replacement for the standard
    PositionEmbeddingSine, accepting similar arguments and producing an output
    of the same shape. This version corrects a type error in the distance
    calculation.
    """

    def __init__(self, num_pos_feats=256, temperature=10000, normalize=True, scale=None):
        """
        Args:
            num_pos_feats (int): The desired number of output channels for the
                positional encoding. This must be an even number.
            temperature (int): A constant used to scale the frequencies.
            normalize (bool): If True, normalizes the coordinates to the range [0, scale].
            scale (float, optional): The scaling factor for normalization. Defaults to 2*pi.
        """
        super().__init__()
        if num_pos_feats % 2 != 0:
            raise ValueError(f"num_pos_feats must be an even number, but got {num_pos_feats}")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tensor (torch.Tensor): A 4D tensor of shape (B, C, H, W). The content
                is not used, only its shape and device.

        Returns:
            torch.Tensor: A 4D tensor of positional encodings with shape
                (B, num_pos_feats, H, W).
        """
        batch_size, _, h, w = tensor.shape
        device = tensor.device

        # 1. Create coordinate grids
        y_embed = torch.arange(h, dtype=torch.float32, device=device).view(h, 1)
        x_embed = torch.arange(w, dtype=torch.float32, device=device).view(1, w)

        # 2. Calculate normalized distance from the center
        # Use floating point division for center calculation
        center_y, center_x = (h - 1) / 2.0, (w - 1) / 2.0

        # Calculate the Euclidean distance for each pixel from the center
        dist_map = torch.sqrt((y_embed - center_y) ** 2 + (x_embed - center_x) ** 2)

        # ✅ CORRECTION: the max distance is a scalar, so there is no need to call
        # torch.sqrt on a Python float. To keep everything in tensors for
        # consistency, wrap the squared value first.
        max_dist_sq = torch.tensor(center_y ** 2 + center_x ** 2, device=device)
        max_dist = torch.sqrt(max_dist_sq)

        # Normalize the distance map to the range [0, 1]
        normalized_dist_map = dist_map / (max_dist + 1e-6)
        if self.normalize:
            normalized_dist_map = normalized_dist_map * self.scale

        pos_dist = normalized_dist_map.unsqueeze(0).repeat(batch_size, 1, 1)

        # 3. Create the frequency-based embedding
        # This part operates on tensors directly.
        dim_t = torch.arange(self.num_pos_feats // 2, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * dim_t / (self.num_pos_feats // 2))

        pos = pos_dist.unsqueeze(-1) / dim_t
        pos_sin = pos.sin()
        pos_cos = pos.cos()

        # 4. Concatenate and reshape to match the desired output format
        pos = torch.cat((pos_sin, pos_cos), dim=3)
        pos = pos.permute(0, 3, 1, 2)
        return pos
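
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed sizes, not from the original code): the encoding
# depends only on the input's shape and device, so any dummy tensor works.
# ----------------------------------------------------------------------------
def _demo_positional_encoding():
    pe = HyperDimensionalPositionalEncoding(num_pos_feats=256, normalize=True)
    feature_map = torch.zeros(2, 256, 7, 7)  # (B, C, H, W); content is ignored
    pos = pe(feature_map)
    # pos: (2, 256, 7, 7) -- same spatial size, num_pos_feats channels, so it
    # can be added directly to a feature map with 256 channels.
    return pos.shape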
class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output


class SpatialSoftmax(nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        # Output:
        #   (N, C*2) x_0 y_0 ...
        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)

        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints


class MultiPath_Generator(nn.Module):
    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )
        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x
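
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed sizes): SpatialSoftmax implements a soft-argmax.
# Each channel of the heatmap is turned into one (x, y) keypoint, which is how
# the heatmap waypoint head converts dense maps into waypoint coordinates.
# ----------------------------------------------------------------------------
def _demo_spatial_softmax():
    soft_argmax = SpatialSoftmax(height=100, width=100, channel=10)
    heatmaps = torch.randn(4, 10, 100, 100)  # (B, num_waypoints, H, W)
    keypoints = soft_argmax(heatmaps)
    # keypoints: (4, 10, 2) -- one (x, y) location per heatmap channel,
    # rescaled from the [-1, 1] grid to metric coordinates by the *12 factors.
    return keypoints.shape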
class LinearWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        # input shape: n 10 embed_dim
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)

        mask = measurements[:, :6]
        mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)

        rs = []
        for i in range(6):
            res = self.head_fc1_list[i](x)
            res = self.head_relu(res)
            res = self.head_fc2_list[i](res)
            rs.append(res)
        rs = torch.stack(rs, 1)
        x = torch.sum(rs * mask, dim=1)

        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x


class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output


class GRUWaypointsPredictorWithCommand(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)

        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
            output = torch.cumsum(output, 1)
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        output = torch.sum(outputs * mask, dim=1)
        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
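
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed shapes): the GRU head decodes one hidden state
# per waypoint query and accumulates the per-step offsets with cumsum, so the
# output is an absolute trajectory in the ego frame.
# ----------------------------------------------------------------------------
def _demo_gru_waypoints():
    head = GRUWaypointsPredictor(input_dim=256, waypoints=10)
    waypoint_tokens = torch.randn(2, 10, 256)  # (B, num_waypoints, embed_dim)
    target_point = torch.randn(2, 2)           # (B, 2) goal location from the route planner
    waypoints = head(waypoint_tokens, target_point)
    # waypoints: (2, 10, 2) -- cumulative (x, y) offsets, i.e. absolute positions.
    return waypoints.shape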
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # `activation` is expected to be a module class (e.g. nn.ReLU or nn.GELU)
        # and is instantiated here.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
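
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed sizes): the encoder/decoder layers follow the
# DETR convention of sequence-first tensors, (S, B, C), because
# nn.MultiheadAttention is used without batch_first=True.
# ----------------------------------------------------------------------------
def _demo_encoder_layer():
    layer = TransformerEncoderLayer(
        d_model=256, nhead=8, dim_feedforward=2048, dropout=0.1, activation=nn.GELU
    )
    tokens = torch.randn(50, 2, 256)  # (sequence, batch, embed_dim)
    pos = torch.randn(50, 2, 256)     # positional encoding, same shape as the tokens
    out = layer(tokens, pos=pos)
    # out: (50, 2, 256) -- same shape as the input sequence.
    return out.shape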
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # `activation` is expected to be a module class (e.g. nn.ReLU or nn.GELU)
        # and is instantiated here.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def build_attn_mask(mask_type):
    mask = torch.ones((151, 151), dtype=torch.bool).cuda()
    if mask_type == "seperate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "seperate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
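
# ----------------------------------------------------------------------------
# Illustrative sketch (only runs when a GPU is present, because build_attn_mask
# allocates on CUDA): the boolean mask is passed to the encoder's attention;
# True entries are *blocked* attention pairs, following the convention of
# nn.MultiheadAttention's attn_mask argument.
# ----------------------------------------------------------------------------
def _demo_attn_mask():
    if not torch.cuda.is_available():
        return None
    mask = build_attn_mask("seperate_view")
    # mask: (151, 151) boolean matrix over the encoder token sequence.
    return mask.shape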
# class InterfuserModel(nn.Module):
class InterfuserHDPE(nn.Module):
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r50",
        lidar_backbone_name="r50",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=False,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=False,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=False,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("seperate_all")
        else:
            self.attn_mask = None

        if use_different_backbone:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

            if use_mmad_pretrain:
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)

            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)

            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    # nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )

        self.position_encoding = HyperDimensionalPositionalEncoding(
            embed_dim, normalize=True
        )

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)
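
    # ------------------------------------------------------------------------
    # Note (added for clarity): forward_features builds the encoder input
    # sequence. For every enabled sensor it appends (a) the flattened patch
    # tokens of that view and (b) one pooled "global" token, each with its
    # positional/view embeddings added. The result is a sequence-first tensor
    # of shape (S, B, C), the layout expected by the DETR-style encoder above.
    # ------------------------------------------------------------------------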
    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        features = []

        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            # Left view processing
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)

            # Right view processing
            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)

            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )

        if self.with_center_sensor:
            # Front center view processing
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )
            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )

            features.extend(
                [front_center_image_token, front_center_image_token_global]
            )

        if self.with_lidar:
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)

            features.extend([lidar_token, lidar_token_global])

        features = torch.cat(features, 0)
        return features
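
    # ------------------------------------------------------------------------
    # Note (added for clarity): in forward, the decoder is queried with 400
    # "traffic map" queries (a 20x20 BEV grid) followed by extra queries for
    # the junction / traffic-light / stop-sign state and the waypoints. The
    # slicing of `hs` below (hs[:, :400], hs[:, 400], hs[:, 401:...]) follows
    # exactly that query ordering.
    # ------------------------------------------------------------------------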
    def forward(self, x):
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(
                left_image, size=(img_size, img_size)
            )
            right_image = torch.nn.functional.interpolate(
                right_image, size=(img_size, img_size)
            )
            front_center_image = torch.nn.functional.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )

        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )

        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]

        hs = hs.permute(1, 0, 2)  # Batchsize, N, C

        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints

        if self.waypoints_pred_head != "heatmap":
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:411]
        else:
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:405]

        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(
                waypoints_feature, target_point, measurements
            )
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)

        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)

        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature

    def load_pretrained(self, model_path, strict=False):
        """
        Load pretrained weights - improved version.

        Args:
            model_path (str): Path to the weights file.
            strict (bool): If True, requires an exact match of the state-dict keys.
        """
        if not model_path or not Path(model_path).exists():
            logging.warning(f"Weights file not found: {model_path}")
            logging.info("Random weights will be used")
            return False

        try:
            logging.info(f"Attempting to load weights from: {model_path}")

            # Load the file, handling different kinds of checkpoint formats
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

            # Extract the state_dict from the various checkpoint layouts
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    logging.info("Found 'model_state_dict' in the checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    logging.info("Found 'state_dict' in the checkpoint")
                elif 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    logging.info("Found 'model' in the checkpoint")
                else:
                    state_dict = checkpoint
                    logging.info("Using the checkpoint directly as a state_dict")
            else:
                state_dict = checkpoint
                logging.info("Using the checkpoint directly as a state_dict")

            # Clean up key names (strip a leading 'module.' if present)
            clean_state_dict = OrderedDict()
            for k, v in state_dict.items():
                clean_key = k[7:] if k.startswith('module.') else k
                clean_state_dict[clean_key] = v

            # Load the weights
            missing_keys, unexpected_keys = self.load_state_dict(
                clean_state_dict, strict=strict
            )

            # Report the loading status
            if missing_keys:
                logging.warning(
                    f"Missing keys ({len(missing_keys)}): {missing_keys[:5]}..."
                    if len(missing_keys) > 5
                    else f"Missing keys: {missing_keys}"
                )
            if unexpected_keys:
                logging.warning(
                    f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}..."
                    if len(unexpected_keys) > 5
                    else f"Unexpected keys: {unexpected_keys}"
                )
            if not missing_keys and not unexpected_keys:
                logging.info("✅ All weights loaded successfully")
            elif not strict:
                logging.info("✅ Weights loaded successfully (mismatches ignored)")

            return True

        except Exception as e:
            logging.error(f"❌ Error while loading weights: {str(e)}")
            logging.info("Random weights will be used")
            return False
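
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed tensor sizes, not part of the original file):
# what a single dummy forward pass through InterfuserHDPE looks like with the
# master config below. Building the model downloads ImageNet-pretrained
# backbone weights via timm.
# ----------------------------------------------------------------------------
def _demo_interfuser_forward():
    config = get_master_config()
    model = InterfuserHDPE(**config['model_params']).eval()
    dummy = {
        "rgb": torch.zeros(1, 3, 224, 224),
        "rgb_left": torch.zeros(1, 3, 224, 224),
        "rgb_right": torch.zeros(1, 3, 224, 224),
        "rgb_center": torch.zeros(1, 3, 224, 224),
        "lidar": torch.zeros(1, 3, 224, 224),
        "measurements": torch.zeros(1, 7),
        "target_point": torch.zeros(1, 2),
    }
    with torch.no_grad():
        traffic, waypoints, is_junction, light, stop_sign, traffic_feature = model(dummy)
    # traffic:   (1, 400, 7)  per-cell predictions for the 20x20 BEV grid
    # waypoints: (1, 10, 2)   ego-frame trajectory from the GRU head
    # is_junction / light / stop_sign: (1, 2) logits each
    return waypoints.shape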
# ==============================================================================
# Function 1: get_master_config
# ==============================================================================
def get_master_config():
    """
    [Professional version]
    Returns a comprehensive dictionary containing all of the application's static
    settings. This function is the single source of truth for configuration.
    """
    # --- Section 1: Model repository information on the Hugging Face Hub ---
    huggingface_repo = {
        'repo_id': "BaseerAI/Interfuser-Baseer-v1",  # Replace with your own model repo name
        'filename': "pytorch_model.bin"              # Name of the weights file inside the repo
    }

    # --- Section 2: Interfuser model architecture settings ---
    model_params = {
        "img_size": 224,
        "embed_dim": 256,
        "enc_depth": 6,
        "dec_depth": 6,
        "rgb_backbone_name": 'r50',
        "lidar_backbone_name": 'r18',
        "waypoints_pred_head": 'gru',
        "use_different_backbone": True,
        "with_lidar": False,
        "with_right_left_sensors": False,
        "with_center_sensor": False,
        "multi_view_img_size": 112,
        "patch_size": 8,
        "in_chans": 3,
        "dim_feedforward": 2048,
        "normalize_before": False,
        "num_heads": 8,
        "dropout": 0.1,
        "end2end": False,
        "direct_concat": False,
        "separate_view_attention": False,
        "separate_all_attention": False,
        "freeze_num": -1,
        "traffic_pred_head_type": "det",
        "reverse_pos": True,
        "use_view_embed": False,
        "use_mmad_pretrain": None,
    }

    # --- Section 3: Grid / bird's-eye-view (BEV) settings ---
    grid_conf = {
        'h': 20, 'w': 20,
        'x_res': 1.0, 'y_res': 1.0,
        'y_min': 0.0, 'y_max': 20.0,
        'x_min': -10.0, 'x_max': 10.0,
    }

    # --- Section 4: Controller and tracker settings ---
    controller_params = {
        'turn_KP': 0.75, 'turn_KI': 0.05, 'turn_KD': 0.25, 'turn_n': 20,
        'speed_KP': 0.55, 'speed_KI': 0.05, 'speed_KD': 0.15, 'speed_n': 20,
        'max_speed': 8.0, 'max_throttle': 0.75, 'min_speed': 0.1,
        'brake_sensitivity': 0.3,
        'light_threshold': 0.5, 'stop_threshold': 0.6,
        'stop_sign_duration': 20, 'max_stop_time': 250,
        'forced_move_duration': 20, 'forced_throttle': 0.5,
        'max_red_light_time': 150, 'red_light_block_duration': 80,
        'accel_rate': 0.1, 'decel_rate': 0.2,
        'critical_distance': 4.0, 'follow_distance': 10.0,
        'speed_match_factor': 0.9,
        'tracker_match_thresh': 2.5, 'tracker_prune_age': 5,
        'follow_grace_period': 20
    }

    # --- Section 5: Assemble everything into one master dictionary ---
    master_config = {
        'huggingface_repo': huggingface_repo,
        'model_params': model_params,
        'grid_conf': grid_conf,
        'controller_params': controller_params,
        'simulation': {
            'frequency': 10.0
        }
    }

    return master_config
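
# ----------------------------------------------------------------------------
# Illustrative sketch: how other modules are expected to consume the master
# config. The keys shown are the ones defined above; everything else is read
# the same way.
# ----------------------------------------------------------------------------
def _demo_read_config():
    config = get_master_config()
    repo_id = config['huggingface_repo']['repo_id']                   # "BaseerAI/Interfuser-Baseer-v1"
    embed_dim = config['model_params']['embed_dim']                   # 256
    bev_cells = config['grid_conf']['h'] * config['grid_conf']['w']   # 400 BEV queries
    dt = 1.0 / config['simulation']['frequency']                      # control period in seconds
    return repo_id, embed_dim, bev_cells, dt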
# ==============================================================================
# Function 2: load_and_prepare_model
# ==============================================================================
def load_and_prepare_model(device: torch.device) -> InterfuserHDPE:
    """
    [Professional version]
    Uses the master settings from `get_master_config` to build and load the model.
    Resolves the Hugging Face Hub model identifier to an actual local file path.

    Args:
        device (torch.device): Target device (CPU/GPU).

    Returns:
        InterfuserHDPE: The loaded model, ready for inference.
    """
    try:
        logging.info("Initializing model loading process...")

        # 1. Get all settings from the single source of truth
        config = get_master_config()

        # 2. Download the weights file from the Hugging Face Hub
        repo_info = config['huggingface_repo']
        logging.info(f"Downloading model weights from repo: '{repo_info['repo_id']}'")

        # Use a token if the repository is private
        # token = HfFolder.get_token()  # or pass it directly
        actual_model_path = hf_hub_download(
            repo_id=repo_info['repo_id'],
            filename=repo_info['filename'],
            # token=token,  # uncomment if the repository is private
        )
        logging.info(f"Model weights are available at local path: {actual_model_path}")

        # 3. Instantiate the model with the correct parameters
        logging.info("Instantiating model with specified parameters...")
        model = InterfuserHDPE(**config['model_params']).to(device)

        # 4. Load the downloaded weights into the model
        # We use the helper method defined on the model class itself.
        success = model.load_pretrained(actual_model_path, strict=False)
        if not success:
            logging.warning(
                "⚠️ Model weights were not loaded successfully. "
                "The model will use random weights."
            )

        # 5. Put the model in evaluation mode (a critical step)
        model.eval()

        logging.info("✅ Model prepared and set to evaluation mode. Ready for inference.")
        return model

    except Exception as e:
        # Log the error in detail, then re-raise so it can be handled at a higher level
        logging.error(f"❌ CRITICAL ERROR during model initialization: {e}", exc_info=True)
        raise
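
# ----------------------------------------------------------------------------
# Illustrative sketch (assumed entry point, not part of the original module):
# typical usage from a driver script. Requires network access for the
# Hugging Face Hub download and, optionally, a CUDA device.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_and_prepare_model(device)
    print(f"Loaded {model.__class__.__name__} with embed_dim={model.embed_dim} on {device}")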