jonathanjordan21 committed on
Commit
9adf831
1 Parent(s): 4aa8273

Update modeling_mos_mamba.py

Files changed (1)
  1. modeling_mos_mamba.py +995 -983
modeling_mos_mamba.py CHANGED
@@ -1,984 +1,996 @@
1
- # coding=utf-8
2
- # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch MAMBA model."""
16
-
17
- import math
18
- from dataclasses import dataclass
19
- from typing import Any, Dict, Optional, Tuple, Union
20
-
21
- import torch
22
- import torch.utils.checkpoint
23
- from torch import nn
24
- from torch.nn import CrossEntropyLoss
25
-
26
- from transformers.activations import ACT2FN
27
- from transformers.modeling_utils import PreTrainedModel
28
- from transformers.utils import ModelOutput
29
- from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
30
- from .configuration_mos_mamba import MoSMambaConfig
31
-
32
- import torch.nn.functional as F
33
-
34
-
35
- if is_mamba_ssm_available():
36
- from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
- else:
39
- selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
-
41
- if is_causal_conv1d_available():
42
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
- else:
44
- causal_conv1d_update, causal_conv1d_fn = None, None
45
-
46
- is_fast_path_available = all(
47
- (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
48
- )
49
-
50
- _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
51
- _CONFIG_FOR_DOC = "MoSMambaConfig"
52
-
53
-
54
- def load_balancing_loss_func(
55
- gate_logits: torch.Tensor, num_selectivities: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
56
- ) -> float:
57
- r"""
58
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
59
-
60
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
61
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
62
- experts is too unbalanced.
63
-
64
- Args:
65
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
66
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
67
- shape [batch_size X sequence_length, num_selectivities].
68
- attention_mask (`torch.Tensor`, None):
69
- The attention_mask used in forward function
70
- shape [batch_size X sequence_length] if not None.
71
- num_selectivities (`int`, *optional*):
72
- Number of experts
73
-
74
- Returns:
75
- The auxiliary loss.
76
- """
77
- if gate_logits is None or not isinstance(gate_logits, tuple):
78
- return 0
79
-
80
- if isinstance(gate_logits, tuple):
81
- compute_device = gate_logits[0].device
82
- concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
83
-
84
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
85
-
86
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
87
-
88
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_selectivities)
89
-
90
- if attention_mask is None:
91
- # Compute the percentage of tokens routed to each experts
92
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
93
-
94
- # Compute the average probability of routing to these experts
95
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
96
- else:
97
- batch_size, sequence_length = attention_mask.shape
98
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
99
-
100
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
101
- expert_attention_mask = (
102
- attention_mask[None, :, :, None, None]
103
- .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_selectivities))
104
- .reshape(-1, top_k, num_selectivities)
105
- .to(compute_device)
106
- )
107
-
108
- # Compute the percentage of tokens routed to each experts
109
- tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
110
- expert_attention_mask, dim=0
111
- )
112
-
113
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
114
- router_per_expert_attention_mask = (
115
- attention_mask[None, :, :, None]
116
- .expand((num_hidden_layers, batch_size, sequence_length, num_selectivities))
117
- .reshape(-1, num_selectivities)
118
- .to(compute_device)
119
- )
120
-
121
- # Compute the average probability of routing to these experts
122
- router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
123
- router_per_expert_attention_mask, dim=0
124
- )
125
-
126
- overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
127
- return overall_loss * num_selectivities
128
-
129
-
130
- class MixtralBlockSparseTop2MLP(nn.Module):
131
- def __init__(self, intermediate_size, hidden_size, ssm_size):
132
- super().__init__()
133
- self.ffn_dim = intermediate_size
134
- self.hidden_dim = hidden_size
135
- self.ssm_dim = ssm_size
136
-
137
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
138
- self.w2 = nn.Linear(self.ffn_dim, self.ssm_dim, bias=False)
139
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
140
- self.w4 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
141
-
142
- self.act_fn = ACT2FN['silu']
143
-
144
- def forward(self, hidden_states):
145
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
146
- current_hidden_states = self.w4(current_hidden_states)
147
-
148
- return current_hidden_states
149
-
150
- class MixtureOfSelectivity(nn.Module):
151
- def __init__(self, intermediate_size, ssm_size):
152
- super().__init__()
153
- self.intermediate_size = intermediate_size
154
- self.ssm_dim = ssm_size
155
-
156
- # self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
157
- self.w2 = nn.Linear(self.intermediate_size, self.ssm_dim, bias=False)
158
-
159
-
160
- def forward(self, hidden_states):
161
- # current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
162
- return self.w2(hidden_states)
163
-
164
- class MoSMambaCache:
165
- """
166
- Arguments:
167
- config: MoSMambaConfig
168
- batch_size: int
169
- dtype: torch.dtype
170
- device: torch.device
171
-
172
- Attributes:
173
- seqlen_offset: int
174
- dtype: torch.dtype
175
- conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
176
- ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
177
- """
178
-
179
- def __init__(
180
- self, config: MoSMambaConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
181
- ):
182
- self.seqlen_offset = 0
183
- self.dtype = dtype
184
- intermediate_size = config.intermediate_size
185
- ssm_state_size = config.state_size
186
- conv_kernel_size = config.conv_kernel
187
-
188
- self.conv_states = {
189
- i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
190
- for i in range(config.num_hidden_layers)
191
- }
192
- self.ssm_states = {
193
- i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
194
- for i in range(config.num_hidden_layers)
195
- }
196
-
197
-
198
- class MoSMambaMixer(nn.Module):
199
- """
200
- Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
201
- A, D are input independent (see MoSMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
202
- ∆, B, C are input-dependent (this is a key difference between MoSMamba and the linear time invariant S4,
203
- and is why MoSMamba is called **selective** state spaces)
204
- """
205
-
206
- def __init__(self, config: MoSMambaConfig, layer_idx: int):
207
- super().__init__()
208
- self.hidden_size = config.hidden_size
209
- self.ssm_state_size = config.state_size
210
- self.conv_kernel_size = config.conv_kernel
211
- self.intermediate_size = config.intermediate_size
212
- self.time_step_rank = int(config.time_step_rank)
213
- self.layer_idx = layer_idx
214
- self.use_conv_bias = config.use_conv_bias
215
- self.conv1d = nn.Conv1d(
216
- in_channels=self.intermediate_size,
217
- out_channels=self.intermediate_size,
218
- bias=config.use_conv_bias,
219
- kernel_size=config.conv_kernel,
220
- groups=self.intermediate_size,
221
- padding=config.conv_kernel - 1,
222
- )
223
-
224
- self.activation = config.hidden_act
225
- self.act = ACT2FN[config.hidden_act]
226
-
227
- # num experts
228
- self.num_selectivities = config.num_selectivities
229
-
230
- # num selected experts
231
- self.top_k = config.num_selectivities_per_tok
232
-
233
- # projection of the input hidden states
234
- self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
235
- # selective projection used to make dt, B and C input dependent
236
- # self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False
237
-
238
- # self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(self.num_selectivities)])
239
- # for i in range(self.num_selectivities):
240
- # self.x_proj.add_module("x_proj_"+str(i), nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
241
-
242
- # self.x_proj_0 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
243
- # self.x_proj_1 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
244
- # self.x_proj_2 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
245
- # self.x_proj_3 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
246
- # self.x_proj_4 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
247
- # self.x_proj_5 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
248
-
249
-
250
- # self.x_proj2 = nn.ModuleList([MixtralBlockSparseTop2MLP(self.intermediate_size,self.hidden_size, self.time_step_rank + self.ssm_state_size * 2) for _ in range(self.num_selectivities)])
251
- self.x_proj = nn.ModuleList()
252
- for i in range(self.num_selectivities):
253
- self.x_proj.add_module(f"w{i}",nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
254
-
255
- self.gate = nn.Linear(self.hidden_size, self.num_selectivities, bias=False)
256
-
257
- # time step projection (discretization)
258
- self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
259
-
260
- # S4D real initialization. These are not discretized!
261
- # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
262
- A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
263
- A = A.expand(self.intermediate_size, -1).contiguous()
264
-
265
- self.A_log = nn.Parameter(torch.log(A))
266
- self.D = nn.Parameter(torch.ones(self.intermediate_size))
267
- self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
268
- self.use_bias = config.use_bias
269
-
270
- self.jitter_noise = 0.001
271
-
272
- self.register_parameter("A_log", self.A_log)
273
- self.register_parameter("D", self.D)
274
-
275
- # for i in enumerate(self.x_proj):
276
- # self.register_parameter("x_proj_"+str(i), x)
277
-
278
-
279
- def cuda_kernels_forward(self, hidden_states: torch.Tensor, x_proj, cache_params: Optional[MoSMambaCache] = None):
280
- # 1. Gated MLP's linear projection
281
- # router_logits =
282
- batch_size, seq_len, _ = hidden_states.shape
283
-
284
- projected_states = self.in_proj(hidden_states).transpose(1, 2)
285
-
286
- if projected_states.shape[-1] == 0:
287
- hidden_states, gate = projected_states.chunk(2, dim=1)
288
- dtype = hidden_states.dtype
289
-
290
- if cache_params is not None:
291
- ssm_state = cache_params.ssm_states[self.layer_idx].clone()
292
- if cache_params.seqlen_offset > 0:
293
- conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
294
- conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
295
- conv_state[:, :, -1] = hidden_states[:, :, 0]
296
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
297
- hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
298
- if self.use_conv_bias:
299
- hidden_states += self.conv1d.bias
300
- hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
301
- else:
302
- conv_state = nn.functional.pad(
303
- hidden_states,
304
- (self.conv_kernel_size - hidden_states.shape[-1], 0)
305
- )
306
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
307
- if hidden_states.shape[-1] == 0:
308
- hidden_states = hidden_states.permute(2,1,0)
309
- hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
310
- else:
311
- ssm_state = torch.zeros(
312
- (batch_size, self.intermediate_size, self.ssm_state_size),
313
- device=hidden_states.device, dtype=dtype
314
- )
315
- # print(hidden_states.shape)
316
- # print(self.conv1d)
317
- if hidden_states.shape[-1] == 0:
318
- hidden_states = hidden_states.permute(2,1,0)
319
- hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
320
-
321
- scan_output = (hidden_states * self.D[None, :, None])
322
- scan_output = (scan_output * self.act(gate))
323
- if cache_params is not None:
324
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
325
-
326
- # 4. Final linear projection
327
- contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
328
- return contextualized_states
329
-
330
- elif self.training and cache_params is None: # Doesn't support outputting the states -> used for training
331
- contextualized_states = mamba_inner_fn(
332
- projected_states,
333
- self.conv1d.weight,
334
- self.conv1d.bias if self.use_conv_bias else None,
335
- x_proj.weight,
336
- self.dt_proj.weight,
337
- self.out_proj.weight,
338
- self.out_proj.bias.float() if self.use_bias else None,
339
- -torch.exp(self.A_log.float()),
340
- None, # input-dependent B
341
- None, # input-dependent C
342
- self.D.float(),
343
- delta_bias=self.dt_proj.bias.float(),
344
- delta_softplus=True,
345
- )
346
-
347
- else:
348
- hidden_states, gate = projected_states.chunk(2, dim=1)
349
- conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
350
-
351
- # print("NON ZERO", hidden_states.shape)
352
- # 2. Convolution sequence transformation
353
- if cache_params is not None and cache_params.seqlen_offset > 0:
354
- hidden_states = causal_conv1d_update(
355
- hidden_states.squeeze(-1),
356
- cache_params.conv_states[self.layer_idx],
357
- conv_weights,
358
- self.conv1d.bias,
359
- self.activation,
360
- )
361
- hidden_states = hidden_states.unsqueeze(-1)
362
- else:
363
- if cache_params is not None:
364
- conv_states = nn.functional.pad(
365
- hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
366
- )
367
- # print(conv_states)
368
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
369
-
370
- hidden_states = causal_conv1d_fn(
371
- hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
372
- )
373
- # 3. State Space Model sequence transformation
374
- # 3.a. input varying initialization of time_step, B and C
375
- ssm_parameters = x_proj(hidden_states.transpose(1, 2))
376
- time_step, B, C = torch.split(
377
- ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
378
- )
379
- discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
380
-
381
- A = -torch.exp(self.A_log.float())
382
- # 3.c perform the recurrence y ← SSM(A, B, C)(x)
383
- time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
384
-
385
- if cache_params is not None and cache_params.seqlen_offset > 0:
386
- scan_outputs = selective_state_update(
387
- cache_params.ssm_states[self.layer_idx],
388
- hidden_states[..., 0],
389
- discrete_time_step[..., 0],
390
- A,
391
- B[:, 0],
392
- C[:, 0],
393
- self.D,
394
- gate[..., 0],
395
- time_proj_bias,
396
- dt_softplus=True,
397
- ).unsqueeze(-1)
398
- else:
399
- # print("A.shape",A.shape)
400
- # print("hidden_states", hidden_states.shape)
401
- # print("discrete_time_step", discrete_time_step.shape)
402
- # print("GATE.SHAOE", gate.shape)
403
-
404
- scan_outputs, ssm_state = selective_scan_fn(
405
- hidden_states,
406
- discrete_time_step,
407
- A,
408
- B.transpose(1, 2),
409
- C.transpose(1, 2),
410
- self.D.float(),
411
- gate,
412
- time_proj_bias,
413
- delta_softplus=True,
414
- return_last_state=True,
415
- )
416
- # print("SCANOUTPUTS | SSMSTATE", scan_outputs.shape, ssm_state.shape)
417
- if ssm_state is not None and cache_params is not None:
418
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
419
-
420
- # 4. Final linear projection
421
- contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
422
- return contextualized_states
423
-
424
- # fmt: off
425
- def slow_forward(self, input_states, x_proj, cache_params: Optional[MoSMambaCache]=None):
426
- batch_size, seq_len, _ = input_states.shape
427
- dtype = input_states.dtype
428
- # 1. Gated MLP's linear projection
429
- projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
430
- hidden_states, gate = projected_states.chunk(2, dim=1)
431
-
432
- # 2. Convolution sequence transformation
433
- if cache_params is not None:
434
- ssm_state = cache_params.ssm_states[self.layer_idx].clone()
435
- if cache_params.seqlen_offset > 0:
436
- conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
437
- conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
438
- conv_state[:, :, -1] = hidden_states[:, :, 0]
439
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
440
- hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
441
- if self.use_conv_bias:
442
- hidden_states += self.conv1d.bias
443
- hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
444
- else:
445
- conv_state = nn.functional.pad(
446
- hidden_states,
447
- (self.conv_kernel_size - hidden_states.shape[-1], 0)
448
- )
449
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
450
- if hidden_states.shape[-1] == 0:
451
- hidden_states = hidden_states.permute(2,1,0)
452
- hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
453
- else:
454
- ssm_state = torch.zeros(
455
- (batch_size, self.intermediate_size, self.ssm_state_size),
456
- device=hidden_states.device, dtype=dtype
457
- )
458
- # print(hidden_states.shape)
459
- # print(self.conv1d)
460
- if hidden_states.shape[-1] == 0:
461
- hidden_states = hidden_states.permute(2,1,0)
462
- hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
463
-
464
- # 3. State Space Model sequence transformation
465
- # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
466
- ssm_parameters = x_proj(hidden_states.transpose(1, 2))
467
- time_step, B, C = torch.split(
468
- ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
469
- )
470
- discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
471
- discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
472
-
473
- # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
474
- A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
475
- discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
476
- discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
477
- deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
478
-
479
- # 3.c perform the recurrence y ← SSM(A, B, C)(x)
480
- scan_outputs = []
481
- for i in range(seq_len):
482
- ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
483
- scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
484
- scan_outputs.append(scan_output[:, :, 0])
485
- # print(scan_outputs)
486
- scan_output = torch.stack(scan_outputs, dim=-1) if scan_outputs else torch.tensor(scan_outputs) # [batch, intermediate_size, seq_len]
487
- scan_output = scan_output + (hidden_states * self.D[None, :, None])
488
- scan_output = (scan_output * self.act(gate))
489
-
490
- if cache_params is not None:
491
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
492
-
493
- # 4. Final linear projection
494
- contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
495
- return contextualized_states
496
-
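# Editorial note (not in the original file) on the scan implemented above: with step size
# delta = softplus(dt_proj(time_step)), the continuous parameters are discretized as
#   A_bar = exp(delta * A)        and        B_bar * x_t = delta * B_t * x_t,
# and the per-timestep loop in section 3.c computes
#   h_t = A_bar * h_{t-1} + B_bar * x_t,      y_t = C_t . h_t + D * x_t,
# where the D skip connection and the act(gate) multiplication are applied after the loop.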
497
- def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None):
498
- batch_size, sequence_length, hidden_dim = hidden_states.shape
499
-
500
- if self.training and self.jitter_noise > 0:
501
- hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
502
-
503
- # print('BATCH_SIZE | SEQ LENGTH | HID DIM:',batch_size, sequence_length, hidden_dim)
504
-
505
- hidden_states = hidden_states.view(-1, hidden_dim)
506
-
507
- router_logits = self.gate(hidden_states)
508
-
509
- # print("ROUTER LOGITS:", router_logits, router_logits.size())
510
-
511
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
512
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
513
- # print("ROUTING WEIGHTS", routing_weights, routing_weights.shape)
514
- # print("SEL EXPERTS", selected_experts, selected_experts.shape)
515
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
516
- # we cast back to the input dtype
517
- routing_weights = routing_weights.to(hidden_states.dtype)
518
-
519
- # print(routing_weights .shape)
520
-
521
- final_hidden_states = torch.zeros(
522
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
523
- )
524
-
525
- # One hot encode the selected experts to create an expert mask
526
- # this will be used to easily index which expert is going to be solicited
527
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_selectivities).permute(2, 1, 0)
528
- # print("EXPERT MASK", expert_mask, expert_mask.shape)
529
-
530
- # Loop over all available experts in the model and perform the computation on each expert
531
- for expert_idx in range(self.num_selectivities):
532
- # expert_layer = self.x_proj[expert_idx]
533
- expert_layer = self.x_proj.get_submodule(f"w{expert_idx}")
534
- # expert_layer = getattr(self, f'x_proj_{expert_idx}')
535
- idx, top_x = torch.where(expert_mask[expert_idx])
536
- # print("expert_mask[expert_idx]:",expert_mask[expert_idx], expert_mask[expert_idx].shape)
537
-
538
-
539
- # Index the correct hidden states and compute the expert hidden state for
540
- # the current expert. We need to make sure to multiply the output hidden
541
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
542
- # print("TOP_x:",top_x)
543
- # print("TOP X.SHAPE:",top_x.shape)
544
- # print("HIDDEN STATES.SHAPE:",hidden_states.shape)
545
- # print("HIDDEN STATES[NONE, TOPX].SHAPE:", hidden_states[None, top_x].shape)
546
-
547
-
548
- # print("TOP_X | IDX", top_x, idx)
549
-
550
- current_state = hidden_states[None, top_x]
551
- # print("TOPX", top_x,top_x.shape)
552
- # print("CURRENT_STATE",current_state.shape)
553
- current_state = current_state.reshape(-1, hidden_dim)#.reshape(batch_size, sequence_length, hidden_dim )
554
-
555
- # if current_state.shape[1] == 0:
556
- # continue
557
-
558
-
559
- # print("CURRENT_STATE",current_state)
560
-
561
- # current_state = hidden_states.reshape(batch_size, sequence_length, hidden_dim )
562
-
563
- # print(current_state.shape)
564
- # if current_state.shape[0] < 1:
565
- # print(current_state)
566
- # current_state = current_state.reshape(batch_size, 1, hidden_dim)
567
- # else:
568
- # current_state = current_state.reshape(batch_size, sequence_length, hidden_dim)
569
-
570
- # print("current_state.shape", current_state.shape, "ROUTING WEIGHTS",routing_weights[top_x, idx, None].shape)
571
-
572
- current_state = current_state * routing_weights[top_x, idx, None]
573
-
574
- # print("current_hidden_states.shape", current_state.shape)
575
-
576
- current_hidden_states = current_state[None]
577
-
578
-
579
-
580
-
581
- # print("current_hidden_states[none].shape", current_hidden_states.shape)
582
-
583
- if current_hidden_states.shape[1] != 0:
584
-
585
- if is_fast_path_available and "cuda" in expert_layer.weight.device.type:
586
- # if is_fast_path_available and "cuda" in expert_layer.w2.weight.device.type:
587
- current_hidden_states = self.cuda_kernels_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
588
- else:
589
- current_hidden_states = self.slow_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
590
- # else:
591
- # expert_layer.grad = torch.zeros_like(expert_layer.weight)
592
- # current_hidden_states = expert_layer(current_state)
593
-
594
- current_hidden_states = current_hidden_states.reshape(-1, hidden_dim)
595
- # print(current_hidden_states.shape, final_hidden_states.shape)
596
-
597
- # However `index_add_` only supports torch tensors for indexing so we'll use
598
- # the `top_x` tensor here.
599
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
600
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
601
-
602
- return final_hidden_states, router_logits
603
-
604
-
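# Editorial sketch (not part of the original file): the routing arithmetic used in
# MoSMambaMixer.forward on a toy tensor, assuming 6 selectivities, top_k = 2, and 5 tokens.
_toy_logits = torch.randn(5, 6)                                    # [tokens, num_selectivities]
_toy_weights = F.softmax(_toy_logits, dim=1, dtype=torch.float)
_toy_weights, _toy_selected = torch.topk(_toy_weights, 2, dim=-1)
_toy_weights /= _toy_weights.sum(dim=-1, keepdim=True)             # renormalize the kept top-2
_toy_mask = F.one_hot(_toy_selected, num_classes=6).permute(2, 1, 0)
# _toy_weights: [tokens, top_k], _toy_mask: [num_selectivities, top_k, tokens]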
605
- class MoSMambaRMSNorm(nn.Module):
606
- def __init__(self, hidden_size, eps=1e-6):
607
- """
608
- MoSMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
609
- """
610
- super().__init__()
611
- self.weight = nn.Parameter(torch.ones(hidden_size))
612
- self.variance_epsilon = eps
613
-
614
- def forward(self, hidden_states):
615
- input_dtype = hidden_states.dtype
616
- hidden_states = hidden_states.to(torch.float32)
617
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
618
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
619
- return self.weight * hidden_states.to(input_dtype)
620
-
621
-
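# Editorial sketch (not part of the original file): MoSMambaRMSNorm computes
# y = weight * x / sqrt(mean(x**2, dim=-1) + eps), i.e. RMS normalization without
# mean-centering; a quick numerical check on toy shapes.
_x = torch.randn(2, 3, 8)
_rms = MoSMambaRMSNorm(8)
_manual = _rms.weight * (_x * torch.rsqrt(_x.pow(2).mean(-1, keepdim=True) + _rms.variance_epsilon))
assert torch.allclose(_rms(_x), _manual, atol=1e-6)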
622
- class MoSMambaBlock(nn.Module):
623
- def __init__(self, config, layer_idx):
624
- super().__init__()
625
- self.config = config
626
- self.layer_idx = layer_idx
627
- self.residual_in_fp32 = config.residual_in_fp32
628
- self.norm = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
629
- self.mixer = MoSMambaMixer(config, layer_idx=layer_idx)
630
-
631
- def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None, output_router_logits:Optional[bool] = False):
632
- residual = hidden_states
633
- hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
634
- if self.residual_in_fp32:
635
- residual = residual.to(torch.float32)
636
-
637
- hidden_states, router_logits = self.mixer(hidden_states, cache_params=cache_params)
638
- hidden_states = residual + hidden_states
639
- outputs = (hidden_states,)
640
-
641
- if output_router_logits:
642
- outputs += (router_logits,)
643
- return outputs
644
-
645
-
646
- class MoSMambaPreTrainedModel(PreTrainedModel):
647
- """
648
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
649
- models.
650
- """
651
-
652
- config_class = MoSMambaConfig
653
- base_model_prefix = "backbone"
654
- _no_split_modules = ["MoSMambaBlock"]
655
- supports_gradient_checkpointing = True
656
-
657
- def _init_weights(self, module):
658
- """Initialize the weights."""
659
- if isinstance(module, MoSMambaMixer):
660
- module.A_log._no_weight_decay = True
661
- module.D._no_weight_decay = True
662
-
663
- dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
664
- if self.config.time_step_init_scheme == "constant":
665
- nn.init.constant_(module.dt_proj.weight, dt_init_std)
666
- elif self.config.time_step_init_scheme == "random":
667
- nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
668
-
669
- dt = torch.exp(
670
- torch.rand(self.config.intermediate_size)
671
- * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
672
- + math.log(self.config.time_step_min)
673
- ).clamp(min=self.config.time_step_floor)
674
- # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
675
- inv_dt = dt + torch.log(-torch.expm1(-dt))
676
- with torch.no_grad():
677
- module.dt_proj.bias.copy_(inv_dt)
678
- module.dt_proj.bias._no_reinit = True
679
-
680
- if isinstance(module, nn.Linear):
681
- if module.bias is not None:
682
- if not getattr(module.bias, "_no_reinit", False):
683
- nn.init.zeros_(module.bias)
684
- elif isinstance(module, nn.Embedding):
685
- nn.init.normal_(module.weight, std=self.config.initializer_range)
686
-
687
- if self.config.rescale_prenorm_residual:
688
- # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
689
- # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
690
- # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
691
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/
692
- #
693
- # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
694
- for name, p in module.named_parameters():
695
- if name in ["out_proj.weight"]:
696
- # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
697
- # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
698
- # We need to reinit p since this code could be called multiple times
699
- # Having just p *= scale would repeatedly scale it down
700
- nn.init.kaiming_uniform_(p, a=math.sqrt(5))
701
- with torch.no_grad():
702
- p /= math.sqrt(self.config.num_layers)
703
-
704
-
705
- @dataclass
706
- class MoSMambaOutput(ModelOutput):
707
- """
708
- Class for the MAMBA model outputs.
709
-
710
- Args:
711
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
712
- Sequence of hidden-states at the output of the last layer of the model.
713
- cache_params (`MoSMambaCache`):
714
- The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
715
- avoid providing the old `input_ids`.
716
-
717
- Includes both the State space model state matrices after the selective scan, and the Convolutional states
718
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
719
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
720
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
721
-
722
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
723
- """
724
-
725
- last_hidden_state: Optional[torch.FloatTensor] = None
726
- cache_params: Optional[MoSMambaCache] = None
727
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
728
- router_logits: Optional[Tuple[torch.FloatTensor]] = None
729
-
730
-
731
- @dataclass
732
- class MoSMambaCausalLMOutput(ModelOutput):
733
- """
734
- Base class for causal language model (or autoregressive) outputs.
735
-
736
- Args:
737
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
738
- Language modeling loss (for next-token prediction).
739
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
740
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
741
- cache_params (`MoSMambaCache`):
742
- The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
743
- avoid providing the old `input_ids`.
744
-
745
- Includes both the State space model state matrices after the selective scan, and the Convolutional states
746
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
747
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
748
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
749
-
750
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
751
- """
752
-
753
- loss: Optional[torch.FloatTensor] = None
754
- logits: Optional[torch.FloatTensor] = None
755
- cache_params: Optional[MoSMambaCache] = None
756
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
757
- router_logits: Optional[Tuple[torch.FloatTensor]] = None
758
-
759
-
760
- class MoSMambaModel(MoSMambaPreTrainedModel):
761
- def __init__(self, config):
762
- super().__init__(config)
763
-
764
- self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
765
- self.layers = nn.ModuleList([MoSMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
766
-
767
- self.gradient_checkpointing = False
768
- self.norm_f = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
769
- # Initialize weights and apply final processing
770
- self._register_load_state_dict_pre_hook(self.load_hook)
771
- self.post_init()
772
- self.config.output_router_logits = True
773
-
774
- def load_hook(self, state_dict, prefix, *args):
775
- for k in state_dict:
776
- if "embedding." in k:
777
- state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
778
- break
779
-
780
- def get_input_embeddings(self):
781
- return self.embeddings
782
-
783
- def set_input_embeddings(self, new_embeddings):
784
- self.embeddings = new_embeddings
785
-
786
- def forward(
787
- self,
788
- input_ids: Optional[torch.LongTensor] = None,
789
- inputs_embeds: Optional[torch.LongTensor] = None,
790
- cache_params: Optional[MoSMambaCache] = None,
791
- use_cache: Optional[bool] = None,
792
- output_hidden_states: Optional[bool] = None,
793
- output_router_logits: Optional[bool] = None,
794
- return_dict: Optional[bool] = None,
795
- **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
796
- ) -> Union[Tuple, MoSMambaOutput]:
797
- output_hidden_states = (
798
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
799
- )
800
- output_router_logits = (
801
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
802
- )
803
- use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
804
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
805
-
806
- if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
807
- raise ValueError(
808
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
809
- )
810
-
811
- if inputs_embeds is None:
812
- inputs_embeds = self.embeddings(input_ids)
813
-
814
- if self.gradient_checkpointing and self.training and use_cache:
815
- use_cache = False
816
-
817
- if cache_params is None and use_cache:
818
- cache_params = MoSMambaCache(
819
- self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
820
- )
821
-
822
- hidden_states = inputs_embeds
823
- all_hidden_states = () if output_hidden_states else None
824
- all_router_logits = () if output_router_logits else None
825
- for mixer_block in self.layers:
826
- if self.gradient_checkpointing and self.training:
827
- layer_outputs = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params, output_router_logits)
828
- else:
829
- layer_outputs = mixer_block(hidden_states, cache_params=cache_params,output_router_logits=output_router_logits)
830
-
831
- hidden_states = layer_outputs[0]
832
-
833
- if output_hidden_states:
834
- all_hidden_states = all_hidden_states + (hidden_states,)
835
-
836
- if output_router_logits:
837
- all_router_logits += (layer_outputs[-1],)
838
-
839
- if use_cache:
840
- cache_params.seqlen_offset += inputs_embeds.shape[1]
841
-
842
- hidden_states = self.norm_f(hidden_states)
843
-
844
- if output_hidden_states:
845
- all_hidden_states = all_hidden_states + (hidden_states,)
846
-
847
-
848
- if not return_dict:
849
- return tuple(v for v in [hidden_states, cache_params, all_hidden_states, all_router_logits] if v is not None)
850
-
851
- return MoSMambaOutput(
852
- last_hidden_state=hidden_states,
853
- cache_params=cache_params if use_cache else None,
854
- hidden_states=all_hidden_states,
855
- router_logits=all_router_logits,
856
- )
857
-
858
-
859
- class MoSMambaForCausalLM(MoSMambaPreTrainedModel):
860
- _tied_weights_keys = ["lm_head.weight"]
861
-
862
- def __init__(self, config):
863
- super().__init__(config)
864
- self.backbone = MoSMambaModel(config)
865
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
866
- self.num_selectivities = config.num_selectivities
867
- self.num_selectivities_per_tok = config.num_selectivities_per_tok
868
- self.router_aux_loss_coef = 0.02
869
- # Initialize weights and apply final processing
870
- self.post_init()
871
-
872
- def get_output_embeddings(self):
873
- return self.lm_head
874
-
875
- def set_output_embeddings(self, new_embeddings):
876
- self.lm_head = new_embeddings
877
-
878
- def get_input_embeddings(self):
879
- return self.backbone.get_input_embeddings()
880
-
881
- def set_input_embeddings(self, new_embeddings):
882
- return self.backbone.set_input_embeddings(new_embeddings)
883
-
884
- def _update_model_kwargs_for_generation(
885
- self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
886
- ) -> Dict[str, Any]:
887
- model_kwargs["cache_params"] = outputs.get("cache_params", None)
888
- return model_kwargs
889
-
890
- def prepare_inputs_for_generation(
891
- self, input_ids, cache_params: Optional[MoSMambaCache] = None, inputs_embeds=None, attention_mask=None, output_router_logits=False, **kwargs
892
- ):
893
- # only last token for inputs_ids if the state is passed along.
894
- if cache_params is not None:
895
- input_ids = input_ids[:, -1].unsqueeze(-1)
896
-
897
- if inputs_embeds is not None and cache_params is None:
898
- model_inputs = {"inputs_embeds": inputs_embeds}
899
- else:
900
- model_inputs = {"input_ids": input_ids}
901
-
902
- model_inputs["cache_params"] = cache_params
903
- model_inputs['output_router_logits'] = output_router_logits
904
- return model_inputs
905
-
906
-
907
- def forward(
908
- self,
909
- input_ids: Optional[torch.LongTensor] = None,
910
- inputs_embeds: Optional[torch.FloatTensor] = None,
911
- cache_params: Optional[MoSMambaCache] = None,
912
- labels: Optional[torch.LongTensor] = None,
913
- output_hidden_states: Optional[bool] = None,
914
- output_router_logits: Optional[bool] = None,
915
- return_dict: Optional[bool] = None,
916
- use_cache: Optional[bool] = None,
917
- **kwargs, # for now we need this for generation
918
- ) -> Union[Tuple, MoSMambaCausalLMOutput]:
919
- r"""
920
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
921
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
922
- `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
923
- are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
924
- """
925
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
926
-
927
- output_router_logits = (
928
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
929
- )
930
-
931
- mamba_outputs = self.backbone(
932
- input_ids,
933
- cache_params=cache_params,
934
- inputs_embeds=inputs_embeds,
935
- output_hidden_states=output_hidden_states,
936
- return_dict=return_dict,
937
- use_cache=use_cache,
938
- )
939
- hidden_states = mamba_outputs[0]
940
-
941
- logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
942
-
943
- loss = None
944
- if labels is not None:
945
- # move labels to correct device to enable model parallelism
946
- labels = labels.to(logits.device)
947
- # Shift so that tokens < n predict n
948
- shift_logits = logits[..., :-1, :].contiguous()
949
- shift_labels = labels[..., 1:].contiguous()
950
- # Flatten the tokens
951
- loss_fct = CrossEntropyLoss()
952
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
953
-
954
- aux_loss = None
955
- if output_router_logits:
956
- aux_loss = load_balancing_loss_func(
957
- mamba_outputs.router_logits if return_dict else mamba_outputs[-1],
958
- self.num_selectivities,
959
- self.num_selectivities_per_tok,
960
- # attention_mask,
961
- )
962
- if labels is not None:
963
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
964
-
965
- # print("AUX LOSS:", aux_loss)
966
- # print("LOSS:", loss)
967
-
968
- if not return_dict:
969
- output = (logits,) + mamba_outputs[1:]
970
- if output_router_logits:
971
- output = (aux_loss,) + output
972
- return (loss,) + output if loss is not None else output
973
-
974
- # if not return_dict:
975
- # output = (logits,) + mamba_outputs[1:]
976
- # return ((loss,) + output) if loss is not None else output
977
-
978
- return MoSMambaCausalLMOutput(
979
- loss=loss,
980
- logits=logits,
981
- cache_params=mamba_outputs.cache_params,
982
- hidden_states=mamba_outputs.hidden_states,
983
- router_logits=mamba_outputs.router_logits,
984
  )
 
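# Editorial sketch (not part of the original file): the shifted next-token loss computed in
# MoSMambaForCausalLM.forward, on toy tensors (vocab 11, batch 2, length 4); position t is
# trained to predict token t + 1, so logits drop the last step and labels drop the first.
_lm_logits = torch.randn(2, 4, 11)
_lm_labels = torch.randint(0, 11, (2, 4))
_shift_logits = _lm_logits[..., :-1, :].contiguous()
_shift_labels = _lm_labels[..., 1:].contiguous()
_lm_loss = CrossEntropyLoss()(_shift_logits.view(-1, 11), _shift_labels.view(-1))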
1
+ # coding=utf-8
2
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MAMBA model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import ModelOutput
29
+ from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
30
+ from .configuration_mos_mamba import MoSMambaConfig
31
+
32
+ import torch.nn.functional as F
33
+
34
+
35
+ if is_mamba_ssm_available():
36
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
+ else:
39
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
+
41
+ if is_causal_conv1d_available():
42
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
+ else:
44
+ causal_conv1d_update, causal_conv1d_fn = None, None
45
+
46
+ is_fast_path_available = all(
47
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
48
+ )
49
+
50
+ _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
51
+ _CONFIG_FOR_DOC = "MoSMambaConfig"
52
+
53
+
54
+ def load_balancing_loss_func(
55
+ gate_logits: torch.Tensor, num_selectivities: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
56
+ ) -> float:
57
+ r"""
58
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
59
+
60
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
61
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
62
+ experts is too unbalanced.
63
+
64
+ Args:
65
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
66
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
67
+ shape [batch_size X sequence_length, num_selectivities].
68
+ attention_mask (`torch.Tensor`, None):
69
+ The attention_mask used in forward function
70
+ shape [batch_size X sequence_length] if not None.
71
+ num_selectivities (`int`, *optional*):
72
+ Number of experts
73
+
74
+ Returns:
75
+ The auxiliary loss.
76
+ """
77
+ if gate_logits is None or not isinstance(gate_logits, tuple):
78
+ return 0
79
+
80
+ if isinstance(gate_logits, tuple):
81
+ compute_device = gate_logits[0].device
82
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
83
+
84
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
85
+
86
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
87
+
88
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_selectivities)
89
+
90
+ if attention_mask is None:
91
+ # Compute the percentage of tokens routed to each experts
92
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
93
+
94
+ # Compute the average probability of routing to these experts
95
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
96
+ else:
97
+ batch_size, sequence_length = attention_mask.shape
98
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
99
+
100
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
101
+ expert_attention_mask = (
102
+ attention_mask[None, :, :, None, None]
103
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_selectivities))
104
+ .reshape(-1, top_k, num_selectivities)
105
+ .to(compute_device)
106
+ )
107
+
108
+ # Compute the percentage of tokens routed to each experts
109
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
110
+ expert_attention_mask, dim=0
111
+ )
112
+
113
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
114
+ router_per_expert_attention_mask = (
115
+ attention_mask[None, :, :, None]
116
+ .expand((num_hidden_layers, batch_size, sequence_length, num_selectivities))
117
+ .reshape(-1, num_selectivities)
118
+ .to(compute_device)
119
+ )
120
+
121
+ # Compute the average probability of routing to these experts
122
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
123
+ router_per_expert_attention_mask, dim=0
124
+ )
125
+
126
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
127
+ return overall_loss * num_selectivities
128
+
129
+
130
+ class MixtralBlockSparseTop2MLP(nn.Module):
131
+ def __init__(self, intermediate_size, hidden_size, ssm_size):
132
+ super().__init__()
133
+ self.ffn_dim = intermediate_size
134
+ self.hidden_dim = hidden_size
135
+ self.ssm_dim = ssm_size
136
+
137
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
138
+ self.w2 = nn.Linear(self.ffn_dim, self.ssm_dim, bias=False)
139
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
140
+ self.w4 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
141
+
142
+ self.act_fn = ACT2FN['silu']
143
+
144
+ def forward(self, hidden_states):
145
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
146
+ current_hidden_states = self.w4(current_hidden_states)
147
+
148
+ return current_hidden_states
149
+
150
+ class MixtureOfSelectivity(nn.Module):
151
+ def __init__(self, intermediate_size, ssm_size):
152
+ super().__init__()
153
+ self.intermediate_size = intermediate_size
154
+ self.ssm_dim = ssm_size
155
+
156
+ # self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
157
+ self.w2 = nn.Linear(self.intermediate_size, self.ssm_dim, bias=False)
158
+
159
+
160
+ def forward(self, hidden_states):
161
+ # current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
162
+ return self.w2(hidden_states)
163
+
164
+ class MoSMambaCache:
165
+ """
166
+ Arguments:
167
+ config: MoSMambaConfig
168
+ batch_size: int
169
+ dtype: torch.dtype
170
+ device: torch.device
171
+
172
+ Attributes:
173
+ seqlen_offset: int
174
+ dtype: torch.dtype
175
+ conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
176
+ ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
177
+ """
178
+
179
+ def __init__(
180
+ self, config: MoSMambaConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
181
+ ):
182
+ self.seqlen_offset = 0
183
+ self.dtype = dtype
184
+ intermediate_size = config.intermediate_size
185
+ ssm_state_size = config.state_size
186
+ conv_kernel_size = config.conv_kernel
187
+
188
+ self.conv_states = {
189
+ i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
190
+ for i in range(config.num_hidden_layers)
191
+ }
192
+ self.ssm_states = {
193
+ i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
194
+ for i in range(config.num_hidden_layers)
195
+ }
196
+
197
+
198
+ class MoSMambaMixer(nn.Module):
199
+ """
200
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
201
+ A, D are input independent (see MoSMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
202
+ ∆, B, C are input-dependent (this is a key difference between MoSMamba and the linear time invariant S4,
203
+ and is why MoSMamba is called **selective** state spaces)
204
+ """
205
+
206
+ def __init__(self, config: MoSMambaConfig, layer_idx: int):
207
+ super().__init__()
208
+ self.hidden_size = config.hidden_size
209
+ self.ssm_state_size = config.state_size
210
+ self.conv_kernel_size = config.conv_kernel
211
+ self.intermediate_size = config.intermediate_size
212
+ self.time_step_rank = int(config.time_step_rank)
213
+ self.layer_idx = layer_idx
214
+ self.use_conv_bias = config.use_conv_bias
215
+ self.conv1d = nn.Conv1d(
216
+ in_channels=self.intermediate_size,
217
+ out_channels=self.intermediate_size,
218
+ bias=config.use_conv_bias,
219
+ kernel_size=config.conv_kernel,
220
+ groups=self.intermediate_size,
221
+ padding=config.conv_kernel - 1,
222
+ )
223
+
224
+ self.activation = config.hidden_act
225
+ self.act = ACT2FN[config.hidden_act]
226
+
227
+ # num experts
228
+ self.num_selectivities = config.num_selectivities
229
+
230
+ # num selected experts
231
+ self.top_k = config.num_selectivities_per_tok
232
+
233
+ # projection of the input hidden states
234
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
235
+ # selective projection used to make dt, B and C input dependent
236
+ # self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False
237
+
238
+ # self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(self.num_selectivities)])
239
+ # for i in range(self.num_selectivities):
240
+ # self.x_proj.add_module("x_proj_"+str(i), nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
241
+
242
+ # self.x_proj_0 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
243
+ # self.x_proj_1 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
244
+ # self.x_proj_2 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
245
+ # self.x_proj_3 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
246
+ # self.x_proj_4 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
247
+ # self.x_proj_5 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
248
+
249
+
250
+ # self.x_proj2 = nn.ModuleList([MixtralBlockSparseTop2MLP(self.intermediate_size,self.hidden_size, self.time_step_rank + self.ssm_state_size * 2) for _ in range(self.num_selectivities)])
251
+ self.x_proj = nn.ModuleList()
252
+ for i in range(self.num_selectivities):
253
+ self.x_proj.add_module(f"w{i}",nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
254
+
255
+ self.gate = nn.Linear(self.hidden_size, self.num_selectivities, bias=False)
256
+
257
+ # time step projection (discretization)
258
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
259
+
260
+ # S4D real initialization. These are not discretized!
261
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
262
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
263
+ A = A.expand(self.intermediate_size, -1).contiguous()
264
+
265
+ self.A_log = nn.Parameter(torch.log(A))
266
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
267
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
268
+ self.use_bias = config.use_bias
269
+
270
+ self.jitter_noise = 0.001
271
+
272
+ self.register_parameter("A_log", self.A_log)
273
+ self.register_parameter("D", self.D)
274
+
275
+ # for i in enumerate(self.x_proj):
276
+ # self.register_parameter("x_proj_"+str(i), x)
277
+
278
+
279
+ def cuda_kernels_forward(self, hidden_states: torch.Tensor, x_proj, cache_params: Optional[MoSMambaCache] = None):
280
+ # 1. Gated MLP's linear projection
281
+ # router_logits =
282
+ batch_size, seq_len, _ = hidden_states.shape
283
+
284
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
285
+
286
+ if projected_states.shape[-1] == 0:
287
+ hidden_states, gate = projected_states.chunk(2, dim=1)
288
+ dtype = hidden_states.dtype
289
+
290
+ if cache_params is not None:
291
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
292
+ if cache_params.seqlen_offset > 0:
293
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
294
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
295
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
296
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
297
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
298
+ if self.use_conv_bias:
299
+ hidden_states += self.conv1d.bias
300
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
301
+ else:
302
+ conv_state = nn.functional.pad(
303
+ hidden_states,
304
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
305
+ )
306
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
307
+ if hidden_states.shape[-1] == 0:
308
+ hidden_states = hidden_states.permute(2,1,0)
309
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
310
+ else:
311
+ ssm_state = torch.zeros(
312
+ (batch_size, self.intermediate_size, self.ssm_state_size),
313
+ device=hidden_states.device, dtype=dtype
314
+ )
315
+ # print(hidden_states.shape)
316
+ # print(self.conv1d)
317
+ if hidden_states.shape[-1] == 0:
318
+ hidden_states = hidden_states.permute(2,1,0)
319
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
320
+
321
+ scan_output = (hidden_states * self.D[None, :, None])
322
+ scan_output = (scan_output * self.act(gate))
323
+ if cache_params is not None:
324
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
325
+
326
+ # 4. Final linear projection
327
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
328
+ return contextualized_states
329
+
330
+ elif self.training and cache_params is None: # Doesn't support outputting the states -> used for training
331
+ contextualized_states = mamba_inner_fn(
332
+ projected_states,
333
+ self.conv1d.weight,
334
+ self.conv1d.bias if self.use_conv_bias else None,
335
+ x_proj.weight,
336
+ self.dt_proj.weight,
337
+ self.out_proj.weight,
338
+ self.out_proj.bias.float() if self.use_bias else None,
339
+ -torch.exp(self.A_log.float()),
340
+ None, # input-dependent B
341
+ None, # input-dependent C
342
+ self.D.float(),
343
+ delta_bias=self.dt_proj.bias.float(),
344
+ delta_softplus=True,
345
+ )
346
+
347
+ else:
348
+ hidden_states, gate = projected_states.chunk(2, dim=1)
349
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
350
+
351
+ # print("NON ZERO", hidden_states.shape)
352
+ # 2. Convolution sequence transformation
353
+ if cache_params is not None and cache_params.seqlen_offset > 0:
354
+ hidden_states = causal_conv1d_update(
355
+ hidden_states.squeeze(-1),
356
+ cache_params.conv_states[self.layer_idx],
357
+ conv_weights,
358
+ self.conv1d.bias,
359
+ self.activation,
360
+ )
361
+ hidden_states = hidden_states.unsqueeze(-1)
362
+ else:
363
+ if cache_params is not None:
364
+ conv_states = nn.functional.pad(
365
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
366
+ )
367
+ # print(conv_states)
368
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
369
+
370
+ hidden_states = causal_conv1d_fn(
371
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
372
+ )
373
+ # 3. State Space Model sequence transformation
374
+ # 3.a. input varying initialization of time_step, B and C
375
+ ssm_parameters = x_proj(hidden_states.transpose(1, 2))
376
+ time_step, B, C = torch.split(
377
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
378
+ )
379
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
380
+
381
+ A = -torch.exp(self.A_log.float())
382
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
383
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
384
+
385
+ if cache_params is not None and cache_params.seqlen_offset > 0:
386
+ scan_outputs = selective_state_update(
387
+ cache_params.ssm_states[self.layer_idx],
388
+ hidden_states[..., 0],
389
+ discrete_time_step[..., 0],
390
+ A,
391
+ B[:, 0],
392
+ C[:, 0],
393
+ self.D,
394
+ gate[..., 0],
395
+ time_proj_bias,
396
+ dt_softplus=True,
397
+ ).unsqueeze(-1)
398
+ else:
399
+ # print("A.shape",A.shape)
400
+ # print("hidden_states", hidden_states.shape)
401
+ # print("discrete_time_step", discrete_time_step.shape)
402
+ # print("GATE.SHAOE", gate.shape)
403
+
404
+ scan_outputs, ssm_state = selective_scan_fn(
405
+ hidden_states,
406
+ discrete_time_step,
407
+ A,
408
+ B.transpose(1, 2),
409
+ C.transpose(1, 2),
410
+ self.D.float(),
411
+ gate,
412
+ time_proj_bias,
413
+ delta_softplus=True,
414
+ return_last_state=True,
415
+ )
416
+ # print("SCANOUTPUTS | SSMSTATE", scan_outputs.shape, ssm_state.shape)
417
+ if ssm_state is not None and cache_params is not None:
418
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
419
+
420
+ # 4. Final linear projection
421
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
422
+ return contextualized_states
423
+
424
+ # fmt: off
425
+ def slow_forward(self, input_states, x_proj, cache_params: Optional[MoSMambaCache]=None):
426
+ batch_size, seq_len, _ = input_states.shape
427
+ dtype = input_states.dtype
428
+ # 1. Gated MLP's linear projection
429
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
430
+ hidden_states, gate = projected_states.chunk(2, dim=1)
431
+
432
+ # 2. Convolution sequence transformation
433
+ if cache_params is not None:
434
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
435
+ if cache_params.seqlen_offset > 0:
436
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
437
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
438
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
439
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
440
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
441
+ if self.use_conv_bias:
442
+ hidden_states += self.conv1d.bias
443
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
444
+ else:
445
+ conv_state = nn.functional.pad(
446
+ hidden_states,
447
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
448
+ )
449
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
450
+ if hidden_states.shape[-1] == 0:
451
+ hidden_states = hidden_states.permute(2,1,0)
452
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
453
+ else:
454
+ ssm_state = torch.zeros(
455
+ (batch_size, self.intermediate_size, self.ssm_state_size),
456
+ device=hidden_states.device, dtype=dtype
457
+ )
458
+ # print(hidden_states.shape)
459
+ # print(self.conv1d)
460
+ if hidden_states.shape[-1] == 0:
461
+ hidden_states = hidden_states.permute(2,1,0)
462
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
463
+
464
+ # 3. State Space Model sequence transformation
465
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
466
+ ssm_parameters = x_proj(hidden_states.transpose(1, 2))
467
+ time_step, B, C = torch.split(
468
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
469
+ )
470
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
471
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
472
+
473
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
474
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
475
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
476
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
477
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
478
+
479
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
480
+ scan_outputs = []
481
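+ # Sequential scan over time: update the SSM state at step i, then project it through C to get that step's output.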
+ for i in range(seq_len):
482
+ ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state_size]
483
+ scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
484
+ scan_outputs.append(scan_output[:, :, 0])
485
+ # print(scan_outputs)
486
+ scan_output = torch.stack(scan_outputs, dim=-1) if scan_outputs else torch.tensor(scan_outputs) # [batch, intermediate_size, seq_len]
487
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
488
+ scan_output = (scan_output * self.act(gate))
489
+
490
+ if cache_params is not None:
491
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
492
+
493
+ # 4. Final linear projection
494
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
495
+ return contextualized_states
496
+
497
+ def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None):
498
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
499
+
500
+ if self.training and self.jitter_noise > 0:
501
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
502
+
503
+ # print('BATCH_SIZE | SEQ LENGTH | HID DIM:',batch_size, sequence_length, hidden_dim)
504
+
505
+ hidden_states = hidden_states.view(-1, hidden_dim)
506
+
507
+ router_logits = self.gate(hidden_states)
508
+
509
+ # print("ROUTER LOGITS:", router_logits, router_logits.size())
510
+
511
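+ # Softmax over the router logits, keep the top-k selectivities per token, then renormalize the kept weights to sum to 1.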
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
512
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
513
+ # print("ROUTING WEIGHTS", routing_weights, routing_weights.shape)
514
+ # print("SEL EXPERTS", selected_experts, selected_experts.shape)
515
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
516
+ # we cast back to the input dtype
517
+ routing_weights = routing_weights.to(hidden_states.dtype)
518
+
519
+ # print(routing_weights .shape)
520
+
521
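+ # Accumulator for the weighted expert outputs, flattened to (batch_size * sequence_length, hidden_dim).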
+ final_hidden_states = torch.zeros(
522
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
523
+ )
524
+
525
+ # One hot encode the selected experts to create an expert mask
526
+ # this will be used to easily index which expert is going to be solicited
527
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_selectivities).permute(2, 1, 0)
528
+ # print("EXPERT MASK", expert_mask, expert_mask.shape)
529
+
530
+ # Loop over all available experts in the model and perform the computation on each expert
531
+ for expert_idx in range(self.num_selectivities):
532
+ # expert_layer = self.x_proj[expert_idx]
533
+ expert_layer = self.x_proj.get_submodule(f"w{expert_idx}")
534
+ # expert_layer = getattr(self, f'x_proj_{expert_idx}')
535
+ idx, top_x = torch.where(expert_mask[expert_idx])
536
+ # print("expert_mask[expert_idx]:",expert_mask[expert_idx], expert_mask[expert_idx].shape)
537
+
538
+
539
+ # Index the correct hidden states and compute the expert hidden state for
540
+ # the current expert. We need to make sure to multiply the output hidden
541
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
542
+ # print("TOP_x:",top_x)
543
+ # print("TOP X.SHAPE:",top_x.shape)
544
+ # print("HIDDEN STATES.SHAPE:",hidden_states.shape)
545
+ # print("HIDDEN STATES[NONE, TOPX].SHAPE:", hidden_states[None, top_x].shape)
546
+
547
+
548
+ # print("TOP_X | IDX", top_x, idx)
549
+
550
+ current_state = hidden_states[None, top_x]
551
+ # print("TOPX", top_x,top_x.shape)
552
+ # print("CURRENT_STATE",current_state.shape)
553
+ current_state = current_state.reshape(-1, hidden_dim)#.reshape(batch_size, sequence_length, hidden_dim )
554
+
555
+ # if current_state.shape[1] == 0:
556
+ # continue
557
+
558
+
559
+ # print("CURRENT_STATE",current_state)
560
+
561
+ # current_state = hidden_states.reshape(batch_size, sequence_length, hidden_dim )
562
+
563
+ # print(current_state.shape)
564
+ # if current_state.shape[0] < 1:
565
+ # print(current_state)
566
+ # current_state = current_state.reshape(batch_size, 1, hidden_dim)
567
+ # else:
568
+ # current_state = current_state.reshape(batch_size, sequence_length, hidden_dim)
569
+
570
+ # print("current_state.shape", current_state.shape, "ROUTING WEIGHTS",routing_weights[top_x, idx, None].shape)
571
+
572
+ current_state = current_state * routing_weights[top_x, idx, None]
573
+
574
+ # print("current_hidden_states.shape", current_state.shape)
575
+
576
+ current_hidden_states = current_state[None]
577
+
578
+
579
+
580
+
581
+ # print("current_hidden_states[none].shape", current_hidden_states.shape)
582
+
583
+ if current_hidden_states.shape[1] != 0:
584
+
585
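+ # Note: the routing weight was already applied to current_state above and is applied again to the expert output below.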
+ if is_fast_path_available and "cuda" in expert_layer.weight.device.type:
586
+ # if is_fast_path_available and "cuda" in expert_layer.w2.weight.device.type:
587
+ current_hidden_states = self.cuda_kernels_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
588
+ else:
589
+ current_hidden_states = self.slow_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
590
+ # else:
591
+ # expert_layer.grad = torch.zeros_like(expert_layer.weight)
592
+ # current_hidden_states = expert_layer(current_state)
593
+
594
+ current_hidden_states = current_hidden_states.reshape(-1, hidden_dim)
595
+ # print(current_hidden_states.shape, final_hidden_states.shape)
596
+
597
+ # However `index_add_` only support torch tensors for indexing so we'll use
598
+ # the `top_x` tensor here.
599
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
600
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
601
+
602
+ return final_hidden_states, router_logits
603
+
604
+
605
+ class MoSMambaRMSNorm(nn.Module):
606
+ def __init__(self, hidden_size, eps=1e-6):
607
+ """
608
+ MoSMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
609
+ """
610
+ super().__init__()
611
+ self.weight = nn.Parameter(torch.ones(hidden_size))
612
+ self.variance_epsilon = eps
613
+
614
+ def forward(self, hidden_states):
615
+ input_dtype = hidden_states.dtype
616
+ hidden_states = hidden_states.to(torch.float32)
617
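+ # RMSNorm in fp32: scale by the reciprocal root mean square of the features (no mean subtraction, no bias).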
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
618
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
619
+ return self.weight * hidden_states.to(input_dtype)
620
+
621
+
622
+ class MoSMambaBlock(nn.Module):
623
+ def __init__(self, config, layer_idx):
624
+ super().__init__()
625
+ self.config = config
626
+ self.layer_idx = layer_idx
627
+ self.residual_in_fp32 = config.residual_in_fp32
628
+ self.norm = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
629
+ self.mixer = MoSMambaMixer(config, layer_idx=layer_idx)
630
+
631
+ def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None, output_router_logits:Optional[bool] = False):
632
+ residual = hidden_states
633
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
634
+ if self.residual_in_fp32:
635
+ residual = residual.to(torch.float32)
636
+
637
+ hidden_states, router_logits = self.mixer(hidden_states, cache_params=cache_params)
638
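+ # Pre-norm residual connection around the MoS-Mamba mixer.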
+ hidden_states = residual + hidden_states
639
+ outputs = (hidden_states,)
640
+
641
+ if output_router_logits:
642
+ outputs += (router_logits,)
643
+ return outputs
644
+
645
+
646
+ class MoSMambaPreTrainedModel(PreTrainedModel):
647
+ """
648
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
649
+ models.
650
+ """
651
+
652
+ config_class = MoSMambaConfig
653
+ base_model_prefix = "backbone"
654
+ _no_split_modules = ["MoSMambaBlock"]
655
+ supports_gradient_checkpointing = True
656
+
657
+ def make_tensors_contiguous(self):
658
+ for name, param in self.named_parameters():
659
+ if not param.is_contiguous():
660
+ param.data = param.data.contiguous()
661
+
662
+ def save_pretrained(self, save_directory, **kwargs):
663
+ # Make tensors contiguous
664
+ self.make_tensors_contiguous()
665
+
666
+ # Call the original save_pretrained method
667
+ super().save_pretrained(save_directory, **kwargs)
668
+
669
+ def _init_weights(self, module):
670
+ """Initialize the weights."""
671
+ if isinstance(module, MoSMambaMixer):
672
+ module.A_log._no_weight_decay = True
673
+ module.D._no_weight_decay = True
674
+
675
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
676
+ if self.config.time_step_init_scheme == "constant":
677
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
678
+ elif self.config.time_step_init_scheme == "random":
679
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
680
+
681
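+ # Sample dt log-uniformly in [time_step_min, time_step_max] and store its inverse softplus in dt_proj.bias, so softplus(bias) recovers dt.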
+ dt = torch.exp(
682
+ torch.rand(self.config.intermediate_size)
683
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
684
+ + math.log(self.config.time_step_min)
685
+ ).clamp(min=self.config.time_step_floor)
686
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
687
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
688
+ with torch.no_grad():
689
+ module.dt_proj.bias.copy_(inv_dt)
690
+ module.dt_proj.bias._no_reinit = True
691
+
692
+ if isinstance(module, nn.Linear):
693
+ if module.bias is not None:
694
+ if not getattr(module.bias, "_no_reinit", False):
695
+ nn.init.zeros_(module.bias)
696
+ elif isinstance(module, nn.Embedding):
697
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
698
+
699
+ if self.config.rescale_prenorm_residual:
700
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
701
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
702
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
703
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
704
+ #
705
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
706
+ for name, p in module.named_parameters():
707
+ if name in ["out_proj.weight"]:
708
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
709
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
710
+ # We need to reinit p since this code could be called multiple times
711
+ # Having just p *= scale would repeatedly scale it down
712
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
713
+ with torch.no_grad():
714
+ p /= math.sqrt(self.config.num_layers)
715
+
716
+
717
+ @dataclass
718
+ class MoSMambaOutput(ModelOutput):
719
+ """
720
+ Class for the MAMBA model outputs.
721
+
722
+ Args:
723
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
724
+ Sequence of hidden-states at the output of the last layer of the model.
725
+ cache_params (`MoSMambaCache`):
726
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
727
+ avoid providing the old `input_ids`.
728
+
729
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states.
730
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
731
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
732
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
733
+
734
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
735
+ """
736
+
737
+ last_hidden_state: Optional[torch.FloatTensor] = None
738
+ cache_params: Optional[MoSMambaCache] = None
739
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
740
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
741
+
742
+
743
+ @dataclass
744
+ class MoSMambaCausalLMOutput(ModelOutput):
745
+ """
746
+ Base class for causal language model (or autoregressive) outputs.
747
+
748
+ Args:
749
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
750
+ Language modeling loss (for next-token prediction).
751
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
752
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
753
+ cache_params (`MoSMambaCache`):
754
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
755
+ avoid providing the old `input_ids`.
756
+
757
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states.
758
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
759
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
760
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
761
+
762
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
763
+ """
764
+
765
+ loss: Optional[torch.FloatTensor] = None
766
+ logits: Optional[torch.FloatTensor] = None
767
+ cache_params: Optional[MoSMambaCache] = None
768
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
769
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
770
+
771
+
772
+ class MoSMambaModel(MoSMambaPreTrainedModel):
773
+ def __init__(self, config):
774
+ super().__init__(config)
775
+
776
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
777
+ self.layers = nn.ModuleList([MoSMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
778
+
779
+ self.gradient_checkpointing = False
780
+ self.norm_f = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
781
+ # Initialize weights and apply final processing
782
+ self._register_load_state_dict_pre_hook(self.load_hook)
783
+ self.post_init()
784
+ self.config.output_router_logits = True
785
+
786
+ def load_hook(self, state_dict, prefix, *args):
787
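+ # Backward compatibility: rename a legacy "embedding." key to "embeddings." when loading older checkpoints.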
+ for k in state_dict:
788
+ if "embedding." in k:
789
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
790
+ break
791
+
792
+ def get_input_embeddings(self):
793
+ return self.embeddings
794
+
795
+ def set_input_embeddings(self, new_embeddings):
796
+ self.embeddings = new_embeddings
797
+
798
+ def forward(
799
+ self,
800
+ input_ids: Optional[torch.LongTensor] = None,
801
+ inputs_embeds: Optional[torch.LongTensor] = None,
802
+ cache_params: Optional[MoSMambaCache] = None,
803
+ use_cache: Optional[bool] = None,
804
+ output_hidden_states: Optional[bool] = None,
805
+ output_router_logits: Optional[bool] = None,
806
+ return_dict: Optional[bool] = None,
807
+ **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
808
+ ) -> Union[Tuple, MoSMambaOutput]:
809
+ output_hidden_states = (
810
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
811
+ )
812
+ output_router_logits = (
813
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
814
+ )
815
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
816
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
817
+
818
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
819
+ raise ValueError(
820
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
821
+ )
822
+
823
+ if inputs_embeds is None:
824
+ inputs_embeds = self.embeddings(input_ids)
825
+
826
+ if self.gradient_checkpointing and self.training and use_cache:
827
+ use_cache = False
828
+
829
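+ # Lazily create the cache (per-layer conv and SSM states) when caching is requested but none was provided.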
+ if cache_params is None and use_cache:
830
+ cache_params = MoSMambaCache(
831
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
832
+ )
833
+
834
+ hidden_states = inputs_embeds
835
+ all_hidden_states = () if output_hidden_states else None
836
+ all_router_logits = () if output_router_logits else None
837
+ for mixer_block in self.layers:
838
+ if self.gradient_checkpointing and self.training:
839
+ layer_outputs = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params, output_router_logits)
840
+ else:
841
+ layer_outputs = mixer_block(hidden_states, cache_params=cache_params,output_router_logits=output_router_logits)
842
+
843
+ hidden_states = layer_outputs[0]
844
+
845
+ if output_hidden_states:
846
+ all_hidden_states = all_hidden_states + (hidden_states,)
847
+
848
+ if output_router_logits:
849
+ all_router_logits += (layer_outputs[-1],)
850
+
851
+ if use_cache:
852
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
853
+
854
+ hidden_states = self.norm_f(hidden_states)
855
+
856
+ if output_hidden_states:
857
+ all_hidden_states = all_hidden_states + (hidden_states,)
858
+
859
+
860
+ if not return_dict:
861
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states, all_router_logits] if v is not None)
862
+
863
+ return MoSMambaOutput(
864
+ last_hidden_state=hidden_states,
865
+ cache_params=cache_params if use_cache else None,
866
+ hidden_states=all_hidden_states,
867
+ router_logits=all_router_logits,
868
+ )
869
+
870
+
871
+ class MoSMambaForCausalLM(MoSMambaPreTrainedModel):
872
+ _tied_weights_keys = ["lm_head.weight"]
873
+
874
+ def __init__(self, config):
875
+ super().__init__(config)
876
+ self.backbone = MoSMambaModel(config)
877
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
878
+ self.num_selectivities = 6
879
+ self.num_selectivities_per_tok = 2
880
+ self.router_aux_loss_coef = 0.02
881
+ # Initialize weights and apply final processing
882
+ self.post_init()
883
+
884
+ def get_output_embeddings(self):
885
+ return self.lm_head
886
+
887
+ def set_output_embeddings(self, new_embeddings):
888
+ self.lm_head = new_embeddings
889
+
890
+ def get_input_embeddings(self):
891
+ return self.backbone.get_input_embeddings()
892
+
893
+ def set_input_embeddings(self, new_embeddings):
894
+ return self.backbone.set_input_embeddings(new_embeddings)
895
+
896
+ def _update_model_kwargs_for_generation(
897
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
898
+ ) -> Dict[str, Any]:
899
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
900
+ return model_kwargs
901
+
902
+ def prepare_inputs_for_generation(
903
+ self, input_ids, cache_params: Optional[MoSMambaCache] = None, inputs_embeds=None, attention_mask=None, output_router_logits=False, **kwargs
904
+ ):
905
+ # only last token for inputs_ids if the state is passed along.
906
+ if cache_params is not None:
907
+ input_ids = input_ids[:, -1].unsqueeze(-1)
908
+
909
+ if inputs_embeds is not None and cache_params is None:
910
+ model_inputs = {"inputs_embeds": inputs_embeds}
911
+ else:
912
+ model_inputs = {"input_ids": input_ids}
913
+
914
+ model_inputs["cache_params"] = cache_params
915
+ model_inputs['output_router_logits'] = output_router_logits
916
+ return model_inputs
917
+
918
+
919
+ def forward(
920
+ self,
921
+ input_ids: Optional[torch.LongTensor] = None,
922
+ inputs_embeds: Optional[torch.FloatTensor] = None,
923
+ cache_params: Optional[MoSMambaCache] = None,
924
+ labels: Optional[torch.LongTensor] = None,
925
+ output_hidden_states: Optional[bool] = None,
926
+ output_router_logits: Optional[bool] = None,
927
+ return_dict: Optional[bool] = None,
928
+ use_cache: Optional[bool] = None,
929
+ **kwargs, # for now we need this for generation
930
+ ) -> Union[Tuple, MoSMambaCausalLMOutput]:
931
+ r"""
932
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
933
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
934
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
935
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
936
+ """
937
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
938
+
939
+ output_router_logits = (
940
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
941
+ )
942
+
943
+ mamba_outputs = self.backbone(
944
+ input_ids,
945
+ cache_params=cache_params,
946
+ inputs_embeds=inputs_embeds,
947
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
948
+ return_dict=return_dict,
949
+ use_cache=use_cache,
950
+ )
951
+ hidden_states = mamba_outputs[0]
952
+
953
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
954
+
955
+ loss = None
956
+ if labels is not None:
957
+ # move labels to correct device to enable model parallelism
958
+ labels = labels.to(logits.device)
959
+ # Shift so that tokens < n predict n
960
+ shift_logits = logits[..., :-1, :].contiguous()
961
+ shift_labels = labels[..., 1:].contiguous()
962
+ # Flatten the tokens
963
+ loss_fct = CrossEntropyLoss()
964
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
965
+
966
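+ # Switch-Transformer-style auxiliary load-balancing loss computed from the per-layer router logits.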
+ aux_loss = None
967
+ if output_router_logits:
968
+ aux_loss = load_balancing_loss_func(
969
+ mamba_outputs.router_logits if return_dict else mamba_outputs[-1],
970
+ self.num_selectivities,
971
+ self.num_selectivities_per_tok,
972
+ # attention_mask,
973
+ )
974
+ if labels is not None:
975
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure the aux loss is on the same device as the main loss
976
+
977
+ # print("AUX LOSS:", aux_loss)
978
+ # print("LOSS:", loss)
979
+
980
+ if not return_dict:
981
+ output = (logits,) + mamba_outputs[1:]
982
+ if output_router_logits:
983
+ output = (aux_loss,) + output
984
+ return (loss,) + output if loss is not None else output
985
+
986
+ # if not return_dict:
987
+ # output = (logits,) + mamba_outputs[1:]
988
+ # return ((loss,) + output) if loss is not None else output
989
+
990
+ return MoSMambaCausalLMOutput(
991
+ loss=loss,
992
+ logits=logits,
993
+ cache_params=mamba_outputs.cache_params,
994
+ hidden_states=mamba_outputs.hidden_states,
995
+ router_logits=mamba_outputs.router_logits,
996
  )
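+
+ # Illustrative usage sketch (kept as comments; not part of the model definition). It assumes a default
+ # `MoSMambaConfig()` whose field values are compatible with this file; adjust names and sizes as needed.
+ #
+ #     config = MoSMambaConfig()
+ #     model = MoSMambaForCausalLM(config)
+ #     input_ids = torch.randint(0, config.vocab_size, (1, 16))
+ #     outputs = model(input_ids=input_ids, labels=input_ids, output_router_logits=True)
+ #     print(outputs.loss, outputs.logits.shape, len(outputs.router_logits))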