| """ |
| SLIP: Sensor Language Integrated Pre-training |
| Self-contained model file for HuggingFace Hub (trust_remote_code=True). |
| |
| Usage: |
| from transformers import AutoModel, AutoTokenizer |
| model = AutoModel.from_pretrained("LeoChen085/SLIP", trust_remote_code=True, device_map="auto") |
| tokenizer = AutoTokenizer.from_pretrained("LeoChen085/SLIP", trust_remote_code=True) |
| |
| # Task-specific checkpoint (download manually): |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file |
| state_dict = load_file(hf_hub_download("LeoChen085/SLIP", "har.safetensors")) |
| model.load_state_dict(state_dict, strict=False) |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, List |
| from einops import rearrange, repeat, reduce |
| from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel |
| from transformers.activations import ACT2FN |
| from configuration_slip import SLIPConfig |
|
|
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) tables with a cos/sin cache that grows on demand.

    Args:
        dim: Size of the rotated feature dimension (expected to be even so the
            ``arange(0, dim, 2)`` frequency grid, duplicated, spans exactly ``dim``).
        max_position_embeddings: Number of positions cached at construction.
        base: Frequency base of the rotary schedule.
        device: Device on which ``inv_freq`` is created.
    """

    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
        )
        # Non-persistent: derived purely from hyper-parameters, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate the frequencies so the table matches the two halves used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim)`` cast to ``x.dtype``.

        Args:
            x: Reference tensor. It supplies dtype/device. Its second-to-last dimension is
                used as ``seq_len`` when that argument is omitted.
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``; previously
                omitting it raised ``TypeError`` (``None > int``).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype))
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``x.shape[-1] // 2`` into a leading part ``a`` and a
    trailing part ``b``. This is the pairing expected by ``apply_rotary_pos_emb``.
    """
    half = x.shape[-1] // 2
    leading = x[..., :half]
    trailing = x[..., half:]
    return torch.cat((trailing.neg(), leading), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: Query / key tensors whose last dimension matches the ``cos``/``sin`` tables.
        cos, sin: Tables indexed by position along their first dimension.
        position_ids: Integer positions used to gather rows of ``cos``/``sin``.
        unsqueeze_dim: Axis inserted into the gathered tables so they broadcast against
            ``q``/``k`` (e.g. 1 when q/k are laid out as ``(batch, heads, seq, dim)``).

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_sel) + (rotate_half(tensor) * sin_sel)

    return _rotate(q), _rotate(k)
|
|
|
|
def apply_rotary_pos_emb_2d(q, k, cos_h, sin_h, cos_w, sin_w, pos_h, pos_w, unsqueeze_dim=1):
    """Apply 2-D rotary embeddings by rotating the two halves of the last dim independently.

    The split point is ``q.shape[-1] // 2`` and is used for both ``q`` and ``k``. The
    first half is rotated with ``(cos_h, sin_h)`` gathered at ``pos_h``. The second half
    is rotated with ``(cos_w, sin_w)`` gathered at ``pos_w``.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the halves concatenated back in order.
    """
    split = q.shape[-1] // 2
    q_first, q_second = q.split(split, dim=-1)
    k_first, k_second = k.split(split, dim=-1)
    q_first, k_first = apply_rotary_pos_emb(
        q_first, k_first, cos_h, sin_h, pos_h.long(), unsqueeze_dim=unsqueeze_dim)
    q_second, k_second = apply_rotary_pos_emb(
        q_second, k_second, cos_w, sin_w, pos_w.long(), unsqueeze_dim=unsqueeze_dim)
    rotated_q = torch.cat([q_first, q_second], dim=-1)
    rotated_k = torch.cat([k_first, k_second], dim=-1)
    return rotated_q, rotated_k
|
|
|
|
def build_2d_position_ids(attention_mask, flatten=True):
    """Build (variable, patch) position ids from a ``(B, V, P)`` validity mask.

    * Patch position: running count of valid patches inside each variable, minus one.
    * Variable position: running count of variables with at least one valid patch,
      minus one, broadcast over that variable's patches.

    Masked-out entries get position 0 in both outputs.

    Returns:
        ``(pos_var, pos_patch)`` as ``long`` tensors, shaped ``(B, V * P)`` when
        ``flatten`` is true and ``(B, V, P)`` otherwise.
    """
    batch, n_var, n_patch = attention_mask.shape
    valid = attention_mask.to(dtype=torch.long)
    pos_patch = valid.cumsum(dim=-1).sub(1).mul(valid)
    var_has_data = valid.any(dim=-1).to(dtype=torch.long)
    var_rank = var_has_data.cumsum(dim=1).sub(1).mul(var_has_data)
    pos_var = var_rank.unsqueeze(-1).expand(batch, n_var, n_patch).mul(valid)
    if not flatten:
        return pos_var.long(), pos_patch.long()
    flat_len = n_var * n_patch
    return pos_var.reshape(batch, flat_len).long(), pos_patch.reshape(batch, flat_len).long()
|
|
|
|
| |
| |
| |
|
|
def flatten_list(input_list):
    """Flatten exactly one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``."""
    flat = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
|
|
|
|
class MLP(nn.Module):
    """Gated feed-forward block computing ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of the gated hidden layer.
        hidden_act: Key into ``ACT2FN`` selecting the gate activation.
    """

    def __init__(self, hidden_size, intermediate_size, hidden_act):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        projected = self.up_proj(hidden_state)
        return self.down_proj(gated * projected)
|
|
|
|
class TsRoPEAttention(nn.Module):
    """Multi-head self-attention over a flattened (variable, patch) sequence with 2-D RoPE.

    Half of each head's channels is rotated by the variable index and the other half by
    the patch index (positions from ``build_2d_position_ids``). This lets one attention
    mix tokens across both time and channels.

    Config keys (``**cfg``):
        embed_dim: Model width (default 768).
        num_heads: Number of attention heads (default 12).
        dropout_rate: Attention dropout (default 0.1), applied only while training.
        max_position_embeddings: Initial RoPE cache length (default 10000, matching
            ``RotaryEmbedding``). The cache grows on demand.
    """

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # A missing/None value previously reached torch.arange(None) and crashed.
        max_pos = cfg.get("max_position_embeddings")
        if max_pos is None:
            max_pos = 10000
        # Each RoPE axis covers half of the head dimension.
        self.rotary_emb = RotaryEmbedding(self.head_dim // 2, max_position_embeddings=max_pos)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        """
        Args:
            hidden_states: ``(batch, q_len, hidden)``; ``q_len`` must equal ``nvar * p``.
            attention_mask: ``(batch, nvar, p)`` key-validity mask.

        Returns:
            ``(batch, q_len, hidden)``.
        """
        bsz, q_len, _ = hidden_states.size()
        # Key-padding mask broadcast over heads and query positions (True = attend).
        key_mask = rearrange(attention_mask, 'b nvar p -> b (nvar p)')
        key_mask = key_mask.unsqueeze(1).unsqueeze(2).expand(-1, 1, q_len, q_len).bool()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        pos_var, pos_patch = build_2d_position_ids(attention_mask, flatten=True)
        cos_h, sin_h = self.rotary_emb(query_states, seq_len=int(pos_var.max().item()) + 1)
        cos_w, sin_w = self.rotary_emb(query_states, seq_len=int(pos_patch.max().item()) + 1)
        query_states, key_states = apply_rotary_pos_emb_2d(
            query_states, key_states, cos_h, sin_h, cos_w, sin_w, pos_var, pos_patch)

        # SDPA applies dropout whenever dropout_p > 0 and has no train/eval notion,
        # so it must be disabled explicitly outside training.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, key_mask, dropout_p=dropout_p)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output)
|
|
|
|
class MultiSizePatchEmbed(nn.Module):
    """Patch embedding that accepts a different patch size per variable.

    The shared projections are defined for inputs of width ``3 * base_patch`` (values,
    mask and time index concatenated). For each patch size ``psize`` encountered they are
    linearly interpolated along the input dimension to width ``3 * psize``.

    Config keys (``**cfg``): ``embed_dim``, ``mlp_ratio`` (hidden-width multiplier),
    ``dropout_rate``.
    """

    def __init__(self, base_patch=32, **cfg):
        super().__init__()
        self.base_patch = base_patch
        hidden_size = cfg['embed_dim']
        intermediate_size = cfg['mlp_ratio'] * hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.shared_linear = nn.Linear(base_patch * 3, intermediate_size)
        self.shared_residual = nn.Linear(base_patch * 3, hidden_size)
        self.dropout = nn.Dropout(cfg['dropout_rate'])
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(intermediate_size, hidden_size)

    def resize_weight(self, patch_size):
        """Interpolate both input projections to input width ``patch_size``.

        ``patch_size`` is the full concatenated width (``3 * psize`` at the call site).

        Returns:
            ``(weight, bias, residual_weight, residual_bias)``; the biases are returned unchanged.
        """
        base_w, base_b = self.shared_linear.weight, self.shared_linear.bias
        res_w, res_b = self.shared_residual.weight, self.shared_residual.bias
        new_w = F.interpolate(base_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False).squeeze(1).to(base_w.dtype)
        new_res_w = F.interpolate(res_w.unsqueeze(1), size=patch_size, mode="linear", align_corners=False).squeeze(1).to(res_w.dtype)
        return new_w, base_b, new_res_w, res_b

    def forward(self, x_list, attention_mask, time_idx):
        """Embed per-variable patch tensors whose last dimension (patch size) may differ.

        Args:
            x_list: Sequence of tensors. The leading size ``N`` is taken from ``x_list[0]``
                and must be shared by all entries (presumably ``(N, psize)`` each).
            attention_mask, time_idx: Sequences aligned with ``x_list``, concatenated onto
                the values along the last dimension.

        Returns:
            ``(len(x_list), N, hidden_size)`` embeddings in the module's weight dtype.
        """
        device = self.shared_linear.weight.device
        dtype = self.shared_linear.weight.dtype
        sizes = torch.tensor([x.shape[-1] for x in x_list])
        N = x_list[0].shape[0]
        outputs = torch.empty(len(x_list), N, self.intermediate_size, device=device, dtype=dtype)
        res_outputs = torch.empty(len(x_list), N, self.hidden_size, device=device, dtype=dtype)
        for psize in sizes.unique(sorted=True).tolist():
            # All variables sharing this patch size go through one batched projection.
            idxs = (sizes == psize).nonzero(as_tuple=True)[0]
            # Cast to the weight dtype: F.linear does not promote, so float32 signals or
            # integer/bool masks would fail against half/bfloat16 weights.
            xs = torch.stack([x_list[i] for i in idxs]).to(device=device, dtype=dtype, non_blocking=True)
            mask = torch.stack([attention_mask[i] for i in idxs]).to(device=device, dtype=dtype, non_blocking=True)
            ti = torch.stack([time_idx[i] for i in idxs]).to(device=device, dtype=dtype, non_blocking=True)
            xs = torch.cat([xs, mask, ti], dim=-1)
            w, b, r_w, r_b = self.resize_weight(psize * 3)
            res_outputs[idxs] = F.linear(xs, r_w, r_b)
            outputs[idxs] = F.linear(xs, w, b)
        return self.dropout(self.output_layer(self.act(outputs))) + res_outputs
|
|
|
|
class PatchEmbedding(nn.Module):
    """Fixed-size patch embedding: values, mask and time index are patched and concatenated.

    Output = ``dropout(output_layer(silu(hidden_layer(z)))) + residual_layer(z)``, where
    ``z`` is the ``3 * patch_size``-wide concatenation for each patch.

    Config keys (``**cfg``): ``patch_size``, ``embed_dim``, ``dropout_rate`` (default 0.1).
    """

    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg['patch_size']
        self.patch_size = patch_size
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))
        hidden_size = cfg['embed_dim']
        self.hidden_layer = nn.Linear(patch_size * 3, hidden_size)
        self.act = ACT2FN['silu']
        self.output_layer = nn.Linear(hidden_size, hidden_size)
        self.residual_layer = nn.Linear(patch_size * 3, hidden_size)

    def forward(self, x, mask, time_idx):
        """
        Args:
            x, mask, time_idx: ``(bs, nvar, nump * patch_size)`` tensors.

        Returns:
            ``(bs * nvar, nump, embed_dim)``.
        """
        # Match the weight dtype (F.linear does not promote), consistent with MultiSizePatchEmbed.
        dtype = self.hidden_layer.weight.dtype
        x = rearrange(x, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)
        mask = rearrange(mask, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)
        time_idx = rearrange(time_idx, 'bs nvar (nump ps) -> (bs nvar) nump ps', ps=self.patch_size)
        x = torch.cat([x.to(dtype), mask.to(dtype), time_idx.to(dtype)], dim=-1)
        return self.dropout(self.output_layer(self.act(self.hidden_layer(x)))) + self.residual_layer(x)
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with optional 1-D rotary position embeddings.

    Config keys (``**cfg``):
        embed_dim: Model width (default 768).
        num_heads: Number of attention heads (default 12).
        dropout_rate: Attention dropout (default 0.1), applied only while training.
        sensor_max_len: Initial RoPE cache length when ``is_rope`` (default 2880).
    """

    def __init__(self, layer_idx, is_rope=True, **cfg):
        super().__init__()
        self.is_rope = is_rope
        self.hidden_size = cfg.get("embed_dim", 768)
        self.num_heads = cfg.get("num_heads", 12)
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = cfg.get("dropout_rate", 0.1)
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        if self.is_rope:
            self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=cfg.get("sensor_max_len", 2880))

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """
        Args:
            hidden_states: ``(batch, seq, hidden)``.
            attention_mask: Optional mask passed straight to SDPA (boolean: True = attend).
            position_ids: Positions used to gather RoPE tables; only read when ``is_rope``.

        Returns:
            ``(batch, seq, hidden)``.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        if self.is_rope:
            cos, sin = self.rotary_emb(value_states, seq_len=key_states.shape[-2])
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # SDPA has no train/eval notion; without this gate dropout also fires at inference.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attention_mask, dropout_p=dropout_p)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size))
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention from ``query`` tokens to ``context`` tokens.

    Both inputs are layer-normalised first (separate norms). Keys and values are projected
    from ``context_dim`` to ``dim``. Attention dropout is applied only while training.
    """

    def __init__(self, dim=768, *, context_dim=384, num_heads=12, dropout_rate=0.1):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attn_dropout = dropout_rate
        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(context_dim)
        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(context_dim, dim, bias=True)
        self.v_proj = nn.Linear(context_dim, dim, bias=True)
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, query, context, attention_mask=None, **kwargs):
        """
        Args:
            query: ``(batch, q_len, dim)``.
            context: ``(batch, k_len, context_dim)``; the batch size must match ``query``.
            attention_mask: Optional mask passed straight to SDPA.

        Returns:
            ``(batch, q_len, dim)``.
        """
        bsz, q_len, _ = query.size()
        assert context.size(0) == bsz, (
            f"Context batch size ({context.size(0)}) must match query batch size ({bsz}). "
            f"Ensure sensor and text inputs have the same batch size."
        )
        k_len = context.size(1)
        query = self.norm(query)
        context = self.context_norm(context)
        q = self.q_proj(query).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).view(bsz, k_len, self.num_heads, self.head_dim).transpose(1, 2)
        # SDPA has no train/eval notion; disable dropout explicitly at inference.
        dropout_p = self.attn_dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=dropout_p)
        return self.o_proj(attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.dim))
|
|
|
|
class AllAttention(nn.Module):
    """Pre-norm residual wrapper around ``TsRoPEAttention`` (joint time/channel attention).

    ``forward`` returns ``x + dropout(attn(layer_norm(x), mask))``.
    """

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = TsRoPEAttention(layer_idx=layer_idx, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim'))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def forward(self, hidden_states, attention_mask):
        normed = self.layer_norm(hidden_states)
        attended = self.self_attention(normed, attention_mask)
        return hidden_states + self.dropout(attended)
|
|
|
|
class TimeSelfAttention(nn.Module):
    """Pre-norm residual self-attention along time, run independently for each series.

    Uses ``Attention`` with rotary embeddings. The ``(b, nvar, p)`` mask is regrouped to
    one row per ``(b, nvar)`` series and broadcast to a ``(series, 1, q_len, q_len)``
    key mask.
    """

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = Attention(layer_idx=layer_idx, is_rope=True, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def forward(self, hidden_states, attention_mask, position_ids):
        seq_len = hidden_states.size(1)
        per_series = rearrange(attention_mask, 'b nvar p -> (b nvar) p')
        key_mask = per_series[:, None, None, :].expand(-1, 1, seq_len, seq_len).bool()
        attended = self.self_attention(self.layer_norm(hidden_states), key_mask, position_ids)
        return hidden_states + self.dropout(attended)
|
|
|
|
class GroupSelfAttention(nn.Module):
    """Pre-norm residual attention across variables at each patch position (no RoPE).

    Tokens are regrouped from ``(bs*nvar, l, d)`` to ``(bs*l, nvar, d)``, attended over the
    variable axis, and then restored to the original layout. ``group_ids`` is accepted for
    interface compatibility but is not used.
    """

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        self.self_attention = Attention(layer_idx, is_rope=False, **cfg)
        self.layer_norm = nn.LayerNorm(cfg.get('embed_dim', 768))
        self.dropout = nn.Dropout(cfg.get('dropout_rate', 0.1))

    def forward(self, hidden_states, attention_mask, group_ids):
        n_batch, n_var, _ = attention_mask.shape
        by_patch = rearrange(hidden_states, '(bs nvar) l d -> (bs l) nvar d', bs=n_batch, nvar=n_var)
        channel_mask = rearrange(attention_mask, 'bs nvar l -> (bs l) nvar')
        channel_mask = channel_mask[:, None, None, :].expand(-1, 1, n_var, n_var).bool()
        attended = self.self_attention(self.layer_norm(by_patch), channel_mask)
        by_patch = by_patch + self.dropout(attended)
        return rearrange(by_patch, '(bs l) nvar d -> (bs nvar) l d', bs=n_batch, nvar=n_var)
|
|
|
|
class AttentionPooling(nn.Module):
    """Pool a context sequence into a set of queries via cross-attention and a gated MLP.

    The cross-attention output replaces the queries (there is no residual around it).
    The MLP is residual, and the result is layer-normalised.
    """

    def __init__(self, dim=768, mlp_ratio=4, context_dim=384, num_heads=12, dropout_rate=0.1):
        super().__init__()
        self.cross_attn = CrossAttention(dim=dim, context_dim=context_dim, num_heads=num_heads, dropout_rate=dropout_rate)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn_layer = MLP(hidden_size=dim, intermediate_size=dim * mlp_ratio, hidden_act='silu')
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, context, attn_mask=None):
        """
        Args:
            x: Queries ``(b, n, dim)``.
            context: Tokens ``(b, kv_len, context_dim)``.
            attn_mask: Optional ``(b, nvar, p)`` key mask with ``nvar * p == kv_len``.
                ``None`` (the default, which previously crashed in ``rearrange``) lets
                every query attend to every context token.

        Returns:
            ``(b, n, dim)``.
        """
        b, n, _ = x.shape
        kv_len = context.shape[1]
        if attn_mask is not None:
            attn_mask = rearrange(attn_mask, 'b nvar p -> b (nvar p)')
            attn_mask = attn_mask.view(b, 1, 1, kv_len).expand(b, 1, n, kv_len).bool()
        x = self.cross_attn(x, context, attn_mask)
        x = x + self.ffn_layer(self.ffn_norm(x))
        return self.post_norm(x)
|
|
|
|
class SensorEncoderLayer(nn.Module):
    """One encoder block: attention sub-layer(s), then a pre-norm residual gated MLP.

    ``channel_attn_type`` (cfg, default ``'group_attn'``) selects the attention:
      * ``'group_attn'``: per-series time attention, then attention across variables;
      * ``'univariate'``: per-series time attention only;
      * anything else: joint attention over all (variable, patch) tokens (``AllAttention``).
    """

    def __init__(self, layer_idx, **cfg):
        super().__init__()
        embed_dim = cfg['embed_dim']
        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn')
        if self.channel_attn_type in ('group_attn', 'univariate'):
            self.ts_attn = TimeSelfAttention(layer_idx=layer_idx, **cfg)
            if self.channel_attn_type == 'group_attn':
                self.group_attn = GroupSelfAttention(layer_idx=layer_idx, **cfg)
        else:
            self.ts_attn = AllAttention(layer_idx=layer_idx, **cfg)
        self.norm = nn.LayerNorm(embed_dim)
        self.ffn_layer = MLP(hidden_size=embed_dim, intermediate_size=cfg['mlp_ratio'] * embed_dim, hidden_act='silu')

    def forward(self, hidden_states, attention_mask=None, group_ids=None, position_ids=None):
        if self.channel_attn_type in ('group_attn', 'univariate'):
            hidden_states = self.ts_attn(hidden_states, attention_mask, position_ids)
            if self.channel_attn_type == 'group_attn':
                hidden_states = self.group_attn(hidden_states, attention_mask, group_ids)
        else:
            hidden_states = self.ts_attn(hidden_states, attention_mask)
        return hidden_states + self.ffn_layer(self.norm(hidden_states))
|
|
|
|
class SensorTransformerModel(nn.Module):
    """Patch-based sensor encoder.

    Config keys (``**cfg``):
        patch_size: When set (truthy), fixed-size patches via ``PatchEmbedding``.
            Otherwise per-variable patch sizes via ``MultiSizePatchEmbed`` with nested lists.
        depth: Number of ``SensorEncoderLayer`` blocks.
        embed_dim: Model width.
        channel_attn_type: ``'group_attn'`` (default), ``'univariate'``, or any other value
            for joint attention (conventionally ``'all_attn'``).

    ``forward`` returns ``(tokens, mask)``: ``tokens`` is ``(b, nvar * p, embed_dim)`` and
    ``mask`` is the ``(b, nvar, p)`` patch-validity mask.
    """

    def __init__(self, **cfg):
        super().__init__()
        patch_size = cfg.get('patch_size', None)
        self.patch_size = patch_size
        self.patch_embed = PatchEmbedding(**cfg) if patch_size else MultiSizePatchEmbed(**cfg)
        self.blocks = nn.ModuleList([SensorEncoderLayer(i, **cfg) for i in range(cfg['depth'])])
        self.norm = nn.LayerNorm(cfg['embed_dim'])
        self.embed_dim = cfg['embed_dim']
        self.channel_attn_type = cfg.get('channel_attn_type', 'group_attn')

    def forward(self, input_ids, attention_mask, time_index):
        # Same truthiness test as __init__, so the embed module and the branch always agree.
        if not self.patch_size:
            # Variable patch sizes: nested (sample -> variable) lists of tensors.
            batch_size = len(input_ids)
            hidden_states = self.patch_embed(
                flatten_list(input_ids), flatten_list(attention_mask), flatten_list(time_index))
            attention_mask = self._get_self_attn_mask(attention_mask).to(hidden_states.device)
        else:
            batch_size = input_ids.shape[0]
            hidden_states = self.patch_embed(input_ids, attention_mask, time_index)
            # A patch is valid if any of its raw samples is valid.
            attention_mask = reduce(attention_mask, 'b v (p ps) -> b v p', 'max', ps=self.patch_size)
        position_ids = rearrange(self._build_rope_position_ids(attention_mask), 'b nvar p -> (b nvar) p')

        # hidden_states is ((b nvar), p, d) here. Mirror SensorEncoderLayer's dispatch:
        # only 'group_attn'/'univariate' keep the per-series layout inside the blocks;
        # every other mode uses AllAttention, which needs one flat sequence per sample.
        per_series = self.channel_attn_type in ('group_attn', 'univariate')
        if not per_series:
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=batch_size)
        for blk in self.blocks:
            hidden_states = blk(hidden_states, attention_mask=attention_mask, group_ids=None, position_ids=position_ids)
        if per_series:
            # Previously only 'group_attn' was flattened back, leaving 'univariate' outputs
            # in ((b nvar), p, d) layout, inconsistent with the returned (b, nvar, p) mask.
            hidden_states = rearrange(hidden_states, '(b nvar) l d -> b (nvar l) d', b=batch_size)
        return self.norm(hidden_states), attention_mask

    def _build_rope_position_ids(self, attention_mask):
        """Running count of valid patches per series, minus one; 0 at masked positions."""
        mask = attention_mask.to(torch.long)
        return (mask.cumsum(dim=-1) - 1) * mask

    def _get_self_attn_mask(self, attn_mask_list):
        """Collapse nested per-variable masks into one ``(b, nvar, N)`` patch-validity mask.

        A patch counts as valid when the sum over its last dimension is positive.
        """
        collapsed = []
        for sample_masks in attn_mask_list:
            collapsed.append(torch.stack([(m.sum(dim=-1) > 0).to(m.dtype) for m in sample_masks], dim=0))
        return torch.stack(collapsed, dim=0)
|
|
|
|
| |
| |
| |
|
|
class Residual(nn.Module):
    """Wrap a callable so that ``forward(x, *args, **kwargs) == fn(x, *args, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
|
|
|
|
class Gemma3MultimodalLayer(nn.Module):
    """Wrap a decoder layer with a trailing cross-attention block.

    The wrapped layer runs first. Its hidden states then go through
    ``cross_attn_block(hidden, context=vis_x)``. ``vis_x`` must be set via
    ``condition_vis_x`` before ``forward`` is called. Attributes not found on this module
    are looked up on the wrapped layer, so code that inspects decoder layers keeps working.
    """

    def __init__(self, original_layer, cross_attn_block):
        super().__init__()
        self.original_layer = original_layer
        self.cross_attn_block = cross_attn_block
        self.vis_x = None

    def condition_vis_x(self, vis_x):
        """Store the context tensor consumed by the cross-attention block."""
        self.vis_x = vis_x

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If ``original_layer`` itself is missing (e.g. a partially constructed instance),
            # delegating would re-enter __getattr__ forever; surface the AttributeError instead.
            if name == "original_layer":
                raise
            return getattr(self.original_layer, name)

    def forward(self, hidden_states, **kwargs):
        assert self.vis_x is not None, "vis_x must be set before forward pass."
        outputs = self.original_layer(hidden_states, **kwargs)
        # Depending on the transformers version, decoder layers may return a bare tensor or
        # a tuple whose first element is the hidden states. Indexing a bare tensor with [0]
        # would silently take the first batch row, so mirror whichever form was returned.
        if isinstance(outputs, torch.Tensor):
            return self.cross_attn_block(outputs, context=self.vis_x)
        hidden_states = self.cross_attn_block(outputs[0], context=self.vis_x)
        return (hidden_states,) + tuple(outputs[1:])
|
|
|
|
class Gemma3MultimodalModel(nn.Module):
    """Causal LM whose upper decoder layers are augmented with cross-attention.

    Layers ``split_layer ..`` of ``self.model.model.layers`` are wrapped in
    ``Gemma3MultimodalLayer`` with a residual ``CrossAttention`` block whose query and
    context width are both the LM hidden size. Lower layers are left untouched.

    Args:
        model_id: Hub id / path passed to ``AutoConfig`` / ``AutoModelForCausalLM``.
        init_from_pretrained: Load pretrained LM weights instead of building from config.
        split_layer: Index of the first decoder layer that receives cross-attention.
        dtype: ``torch_dtype`` written onto the config when building from config
            (``torch.float32`` if falsy). Not used when loading pretrained weights.
    """

    def __init__(self, model_id="google/gemma-3-270m", init_from_pretrained=False, split_layer=12, dtype=None):
        super().__init__()
        if not init_from_pretrained:
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            config.torch_dtype = dtype or torch.float32
            self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

        self.split_layer = split_layer
        lm_config = self.model.config
        num_heads = lm_config.num_attention_heads
        self.hidden_size = lm_config.hidden_size

        layers = self.model.model.layers
        for idx in range(split_layer, len(layers)):
            cross_block = Residual(CrossAttention(
                dim=self.hidden_size, context_dim=self.hidden_size,
                num_heads=num_heads, dropout_rate=0.1))
            layers[idx] = Gemma3MultimodalLayer(layers[idx], cross_block)

    def condition_image(self, image_embeds):
        """Store ``image_embeds`` and hand it to every cross-attention layer as context."""
        self.image_embeds = image_embeds
        for layer in self.model.model.layers:
            if not isinstance(layer, Gemma3MultimodalLayer):
                continue
            layer.condition_vis_x(self.image_embeds)

    def forward(self, input_ids, attention_mask=None, return_embeddings=False, **kwargs):
        """Run the LM with ``output_hidden_states=True``.

        Returns:
            The raw model output when ``return_embeddings`` is true. Otherwise
            ``(sentence_embedding, logits)``, where the sentence embedding is
            ``hidden_states[split_layer]`` at the last sequence position.
            NOTE(review): that position is a pad token for right-padded batches; confirm
            the tokenizer's padding side.
        """
        lm_out = self.model(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, **kwargs)
        sentence_embedding = lm_out.hidden_states[self.split_layer][:, -1, :]
        return lm_out if return_embeddings else (sentence_embedding, lm_out.logits)
|
|
|
|
| |
| |
| |
|
|
def masked_mean(t, mask, dim=1, eps=1e-6):
    """Mean of ``t`` over ``dim``, counting only positions where ``mask`` is set.

    Args:
        t: Values to average.
        mask: Broadcastable to ``t``. Any dtype is accepted: non-zero / True entries are
            kept. Integer and float masks previously failed in ``~mask`` / ``masked_fill``.
        dim: Dimension to reduce.
        eps: Lower bound on the per-slice count, to avoid division by zero.

    Returns:
        Sum of the kept entries of ``t`` divided by the number of kept entries along ``dim``.
    """
    mask = mask.bool()  # no-op for bool masks
    numer = t.masked_fill(~mask, 0.).sum(dim=dim)
    denom = mask.sum(dim=dim).clamp(min=eps)
    return numer / denom
|
|
|
|
class EmbedToLatents(nn.Module):
    """Bias-free linear projection followed by L2 normalisation over the last dimension."""

    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        latents = self.to_latents(x)
        return F.normalize(latents, p=2.0, dim=-1)
|
|
|
|
class SLIPPreTrainedModel(PreTrainedModel):
    """Base class that plugs ``SLIPConfig`` into the HF ``PreTrainedModel`` machinery."""

    config_class = SLIPConfig
    base_model_prefix = "slip"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        ``nn.Linear`` gets a Xavier-uniform weight and a zero bias. ``nn.LayerNorm`` gets
        a unit weight and a zero bias. Absent affine parameters (``bias=False`` /
        ``elementwise_affine=False``, which are ``None``) are skipped instead of crashing.
        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            if module.weight is not None:
                nn.init.constant_(module.weight, 1.0)
|
|
|
|
class SLIPModel(SLIPPreTrainedModel):
    """
    SLIP: Sensor Language Integrated Pre-training.

    A sensor encoder (optionally followed by attention pooling with learned queries)
    produces context tokens. These condition the cross-attention layers of a causal LM.

    Usage:
        model = AutoModel.from_pretrained("LeoChen085/SLIP", trust_remote_code=True)
    """

    def __init__(self, config: SLIPConfig):
        super().__init__(config)

        # Sensor encoder.
        sensor_cfg = config.sensor_encoder
        self.sensor_encoder = SensorTransformerModel(**sensor_cfg)
        dim = self.sensor_encoder.embed_dim

        # Language model with cross-attention in its upper layers.
        self.multimodalModel = Gemma3MultimodalModel(
            config.llm_model_name,
            init_from_pretrained=False,
            split_layer=config.split_layer,
            dtype=getattr(config, "torch_dtype", None),
        )
        lm_dim = self.multimodalModel.hidden_size
        common_dim = config.common_dim

        # Optional attention pooling with ``num_img_queries + 1`` learned queries. The output
        # of query 0 is the pooled sensor embedding used by get_embedding/get_sensor_embedding.
        num_img_queries = config.num_img_queries
        if num_img_queries > 0:
            self.img_queries = nn.Parameter(torch.randn(num_img_queries + 1, common_dim))
            self.img_attn_pool = AttentionPooling(
                dim=common_dim, context_dim=dim, num_heads=config.num_heads)
            dim = common_dim

        # Contrastive projection heads. The text head reads LM hidden states, so its input
        # width is the LM hidden size. This equals the previously hard-coded common_dim in
        # every configuration that could run, so existing checkpoints load unchanged.
        self.img_to_latents = EmbedToLatents(dim, common_dim)
        self.text_to_latents = EmbedToLatents(lm_dim, common_dim)

        # Learnable log inverse temperature initialised to log(1/0.07); presumably consumed
        # (and clipped at temperature_max) by an external contrastive loss.
        self.temperature = nn.Parameter(torch.tensor(math.log(1 / 0.07)))
        self.temperature_max = math.log(1 / 0.07)

    def embed_sensor(self, sensors, sensor_attn_mask=None, time_index=None):
        """Encode sensors into context tokens.

        Returns:
            ``(tokens, mask)``. With attention pooling, ``tokens`` are the pooled query outputs
            ``(batch, num_img_queries + 1, common_dim)``; otherwise they are the encoder's
            per-token outputs. ``mask`` is the encoder's patch mask as bool and describes
            the encoder tokens, not the pooled queries.
        """
        sensor_tokens, attn_mask = self.sensor_encoder(sensors, sensor_attn_mask, time_index=time_index)
        if hasattr(self, "img_attn_pool"):
            img_queries = repeat(self.img_queries, "n d -> b n d", b=sensor_tokens.shape[0])
            sensor_tokens = self.img_attn_pool(img_queries, sensor_tokens, attn_mask)
        return sensor_tokens, attn_mask.bool()

    def _condition_on_sensors(self, sensors):
        """Encode the ``sensors`` dict, install the result as the LM's cross-attention
        context, and return ``(sensor_hidden, sensor_mask)``."""
        sensor_hidden, sensor_mask = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        self.multimodalModel.condition_image(sensor_hidden)
        return sensor_hidden, sensor_mask

    def _pool_sensor_latents(self, sensor_latents, sensor_mask):
        """Reduce per-token sensor latents to one vector per sample: query 0 with attention
        pooling, otherwise a masked mean over all (variable, patch) tokens."""
        if hasattr(self, "img_attn_pool"):
            return sensor_latents[:, 0, :]
        return masked_mean(sensor_latents, rearrange(sensor_mask, "b n p -> b (n p) 1"), dim=1)

    def forward(self, text=None, sensors=None, **kwargs):
        """
        Forward pass for contrastive + captioning training.
        For inference, use get_embedding(), get_sensor_embedding(), or generate().
        """
        sensor_hidden, _ = self._condition_on_sensors(sensors)
        # All but the final text position are fed to the LM.
        text_hidden, logits = self.multimodalModel(
            input_ids=text["input_ids"][:, :-1], attention_mask=text["attention_mask"][:, :-1])
        text_hidden = self.text_to_latents(text_hidden)
        sensor_hidden = self.img_to_latents(sensor_hidden)
        return {"text_hidden": text_hidden, "sensor_hidden": sensor_hidden, "logits": logits}

    @torch.no_grad()
    def get_embedding(self, text, sensors):
        """Return ``(text_latent, sensor_latent)`` L2-normalised embeddings, one per sample."""
        sensor_hidden, sensor_mask = self._condition_on_sensors(sensors)
        text_hidden, _ = self.multimodalModel(
            input_ids=text["input_ids"][:, :-1], attention_mask=text["attention_mask"][:, :-1])
        text_hidden = self.text_to_latents(text_hidden)
        sensor_latents = self.img_to_latents(sensor_hidden)
        return text_hidden, self._pool_sensor_latents(sensor_latents, sensor_mask)

    @torch.no_grad()
    def get_sensor_embedding(self, input_ids, mask, time_index):
        """Return one pooled sensor latent per sample (no text / LM involved)."""
        sensor_hidden, sensor_mask = self.embed_sensor(sensors=input_ids, sensor_attn_mask=mask, time_index=time_index)
        return self._pool_sensor_latents(self.img_to_latents(sensor_hidden), sensor_mask)

    @torch.no_grad()
    def generate(self, text, sensors, **generate_kwargs):
        """Generate text conditioned on ``sensors``.

        Only ``max_new_tokens`` (default 300), ``do_sample`` (default False) and
        ``num_beams`` (default 1) are read from ``generate_kwargs``.
        """
        num_beams = generate_kwargs.get("num_beams", 1)
        sensor_hidden, _ = self.embed_sensor(
            sensors=sensors["input_ids"], sensor_attn_mask=sensors["attention_mask"],
            time_index=sensors["time_index"])
        # Beam search expands the text batch to batch * num_beams rows (HF repeats each sample
        # consecutively via repeat_interleave). The cross-attention context must be expanded
        # the same way, otherwise CrossAttention's batch-size assertion always fails.
        if num_beams and num_beams > 1:
            sensor_hidden = sensor_hidden.repeat_interleave(num_beams, dim=0)
        self.multimodalModel.condition_image(sensor_hidden)
        return self.multimodalModel.model.generate(
            input_ids=text["input_ids"], attention_mask=text["attention_mask"],
            max_new_tokens=generate_kwargs.get("max_new_tokens", 300),
            do_sample=generate_kwargs.get("do_sample", False),
            num_beams=num_beams)

    def sft_training(self, text, sensors, return_output=False):
        """Supervised fine-tuning step conditioned on sensors.

        Returns the raw LM output when ``return_output`` is true. Otherwise returns
        ``{"loss": loss}``: next-token cross-entropy against ``text["labels"]``, ignoring
        positions labelled ``-100``. With ``text["loss_weights"]`` the loss is the weighted
        mean of per-token losses; without it, the mean over non-ignored targets.
        """
        self._condition_on_sensors(sensors)
        outputs = self.multimodalModel.model(
            input_ids=text["input_ids"], attention_mask=text["attention_mask"], return_dict=True)
        if return_output:
            return outputs
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = text["labels"][:, 1:].contiguous()
        ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1),
            reduction="none", ignore_index=-100)
        if "loss_weights" in text:
            loss_weights = text["loss_weights"][:, 1:].contiguous().view(-1)
            loss = (ce * loss_weights).sum() / loss_weights.sum()
        else:
            # ce is 0 at ignored positions; ce.mean() would divide by them too and shrink
            # the loss in proportion to padding/prompt length.
            num_targets = (shift_labels.view(-1) != -100).sum().clamp(min=1)
            loss = ce.sum() / num_targets
        return {"loss": loss}
|
|