jonathanjordan21 commited on
Commit
8863e88
·
verified ·
1 Parent(s): f5cb409

Upload MoSMambaForCausalLM

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "state-spaces/mamba-130m-hf",
3
+ "architectures": [
4
+ "MoSMambaForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_mos_mamba.MoSMambaConfig",
8
+ "AutoModelForCausalLM": "modeling_mos_mamba.MoSMambaForCausalLM"
9
+ },
10
+ "bos_token_id": 0,
11
+ "conv_kernel": 4,
12
+ "d_inner": 1536,
13
+ "d_model": 768,
14
+ "eos_token_id": 0,
15
+ "expand": 2,
16
+ "fused_add_norm": true,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 768,
19
+ "initializer_range": 0.1,
20
+ "intermediate_size": 1536,
21
+ "layer_norm_epsilon": 1e-05,
22
+ "model_type": "MoSMamba",
23
+ "n_layer": 24,
24
+ "num_hidden_layers": 24,
25
+ "num_selectivities": 6,
26
+ "num_selectivities_per_tok": 2,
27
+ "output_router_logits": true,
28
+ "pad_token_id": 0,
29
+ "pad_vocab_size_multiple": 8,
30
+ "rescale_prenorm_residual": false,
31
+ "residual_in_fp32": true,
32
+ "rms_norm": true,
33
+ "ssm_cfg": {},
34
+ "state_size": 16,
35
+ "time_step_floor": 0.0001,
36
+ "time_step_init_scheme": "random",
37
+ "time_step_max": 0.1,
38
+ "time_step_min": 0.001,
39
+ "time_step_rank": 48,
40
+ "time_step_scale": 1.0,
41
+ "torch_dtype": "float32",
42
+ "transformers_version": "4.41.2",
43
+ "use_bias": false,
44
+ "use_cache": true,
45
+ "use_conv_bias": true,
46
+ "vocab_size": 50280
47
+ }
configuration_mos_mamba.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class MoSMambaConfig(PretrainedConfig):
11
+
12
+ model_type = "MoSMamba"
13
+
14
+ def __init__(
15
+ self,
16
+ vocab_size=50280,
17
+ hidden_size=768,
18
+ state_size=16,
19
+ num_selectivities=6,
20
+ num_selectivities_per_tok=2,
21
+ num_hidden_layers=32,
22
+ layer_norm_epsilon=1e-5,
23
+ pad_token_id=0,
24
+ bos_token_id=0,
25
+ eos_token_id=0,
26
+ expand=2,
27
+ conv_kernel=4,
28
+ use_bias=False,
29
+ use_conv_bias=True,
30
+ hidden_act="silu",
31
+ initializer_range=0.1,
32
+ residual_in_fp32=True,
33
+ time_step_rank="auto",
34
+ time_step_scale=1.0,
35
+ time_step_min=0.001,
36
+ time_step_max=0.1,
37
+ time_step_init_scheme="random",
38
+ time_step_floor=1e-4,
39
+ rescale_prenorm_residual=False,
40
+ use_cache=True,
41
+ **kwargs,
42
+ ):
43
+ self.vocab_size = vocab_size
44
+ self.hidden_size = hidden_size
45
+ self.state_size = state_size
46
+ self.num_hidden_layers = num_hidden_layers
47
+ self.layer_norm_epsilon = layer_norm_epsilon
48
+ self.conv_kernel = conv_kernel
49
+ self.expand = expand
50
+ self.intermediate_size = int(expand * self.hidden_size)
51
+ self.bos_token_id = bos_token_id
52
+ self.eos_token_id = eos_token_id
53
+ self.pad_token_id = pad_token_id
54
+ self.use_bias = use_bias
55
+ self.use_conv_bias = use_conv_bias
56
+ self.hidden_act = hidden_act
57
+ self.initializer_range = initializer_range
58
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
59
+ self.time_step_scale = time_step_scale
60
+ self.time_step_min = time_step_min
61
+ self.time_step_max = time_step_max
62
+ self.time_step_init_scheme = time_step_init_scheme
63
+ self.time_step_floor = time_step_floor
64
+ self.rescale_prenorm_residual = rescale_prenorm_residual
65
+ self.residual_in_fp32 = residual_in_fp32
66
+ self.use_cache = use_cache
67
+
68
+ self.num_selectivities = num_selectivities
69
+ self.num_selectivities_per_tok = num_selectivities_per_tok
70
+
71
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.41.2"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed791e9ba38889f46e5b0fbaa3bdbd9243404567176f369073f7ebaf5b5ddba8
3
+ size 576008736
modeling_mos_mamba.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MAMBA model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import CrossEntropyLoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import ModelOutput
29
+ from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
30
+ from .configuration_mos_mamba import MoSMambaConfig
31
+
32
+ import torch.nn.functional as F
33
+
34
+
35
+ if is_mamba_ssm_available():
36
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
37
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
38
+ else:
39
+ selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
40
+
41
+ if is_causal_conv1d_available():
42
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
+ else:
44
+ causal_conv1d_update, causal_conv1d_fn = None, None
45
+
46
+ is_fast_path_available = all(
47
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
48
+ )
49
+
50
+ _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
51
+ _CONFIG_FOR_DOC = "MoSMambaConfig"
52
+
53
+
54
+ def load_balancing_loss_func(
55
+ gate_logits: torch.Tensor, num_selectivities: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
56
+ ) -> float:
57
+ r"""
58
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
59
+
60
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
61
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
62
+ experts is too unbalanced.
63
+
64
+ Args:
65
+ gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
66
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
67
+ shape [batch_size X sequence_length, num_selectivities].
68
+ attention_mask (`torch.Tensor`, None):
69
+ The attention_mask used in forward function
70
+ shape [batch_size X sequence_length] if not None.
71
+ num_selectivities (`int`, *optional*):
72
+ Number of experts
73
+
74
+ Returns:
75
+ The auxiliary loss.
76
+ """
77
+ if gate_logits is None or not isinstance(gate_logits, tuple):
78
+ return 0
79
+
80
+ if isinstance(gate_logits, tuple):
81
+ compute_device = gate_logits[0].device
82
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
83
+
84
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
85
+
86
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
87
+
88
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_selectivities)
89
+
90
+ if attention_mask is None:
91
+ # Compute the percentage of tokens routed to each experts
92
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
93
+
94
+ # Compute the average probability of routing to these experts
95
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
96
+ else:
97
+ batch_size, sequence_length = attention_mask.shape
98
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
99
+
100
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
101
+ expert_attention_mask = (
102
+ attention_mask[None, :, :, None, None]
103
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_selectivities))
104
+ .reshape(-1, top_k, num_selectivities)
105
+ .to(compute_device)
106
+ )
107
+
108
+ # Compute the percentage of tokens routed to each experts
109
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
110
+ expert_attention_mask, dim=0
111
+ )
112
+
113
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
114
+ router_per_expert_attention_mask = (
115
+ attention_mask[None, :, :, None]
116
+ .expand((num_hidden_layers, batch_size, sequence_length, num_selectivities))
117
+ .reshape(-1, num_selectivities)
118
+ .to(compute_device)
119
+ )
120
+
121
+ # Compute the average probability of routing to these experts
122
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
123
+ router_per_expert_attention_mask, dim=0
124
+ )
125
+
126
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
127
+ return overall_loss * num_selectivities
128
+
129
+
130
+ class MixtralBlockSparseTop2MLP(nn.Module):
131
+ def __init__(self, intermediate_size, hidden_size, ssm_size):
132
+ super().__init__()
133
+ self.ffn_dim = intermediate_size
134
+ self.hidden_dim = hidden_size
135
+ self.ssm_dim = ssm_size
136
+
137
+ self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
138
+ self.w2 = nn.Linear(self.ffn_dim, self.ssm_dim, bias=False)
139
+ self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
140
+ self.w4 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
141
+
142
+ self.act_fn = ACT2FN['silu']
143
+
144
+ def forward(self, hidden_states):
145
+ current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
146
+ current_hidden_states = self.w4(current_hidden_states)
147
+
148
+ return current_hidden_states
149
+
150
+ class MixtureOfSelectivity(nn.Module):
151
+ def __init__(self, intermediate_size, ssm_size):
152
+ super().__init__()
153
+ self.intermediate_size = intermediate_size
154
+ self.ssm_dim = ssm_size
155
+
156
+ # self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
157
+ self.w2 = nn.Linear(self.intermediate_size, self.ssm_dim, bias=False)
158
+
159
+
160
+ def forward(self, hidden_states):
161
+ # current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
162
+ return self.w2(hidden_states)
163
+
164
+ class MoSMambaCache:
165
+ """
166
+ Arguments:
167
+ config: MoSMambaConfig
168
+ batch_size: int
169
+ dtype: torch.dtype
170
+ device: torch.device
171
+
172
+ Attributes:
173
+ seqlen_offset: int
174
+ dtype: torch.dtype
175
+ conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
176
+ ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
177
+ """
178
+
179
+ def __init__(
180
+ self, config: MoSMambaConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
181
+ ):
182
+ self.seqlen_offset = 0
183
+ self.dtype = dtype
184
+ intermediate_size = config.intermediate_size
185
+ ssm_state_size = config.state_size
186
+ conv_kernel_size = config.conv_kernel
187
+
188
+ self.conv_states = {
189
+ i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
190
+ for i in range(config.num_hidden_layers)
191
+ }
192
+ self.ssm_states = {
193
+ i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
194
+ for i in range(config.num_hidden_layers)
195
+ }
196
+
197
+
198
+ class MoSMambaMixer(nn.Module):
199
+ """
200
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
201
+ A, D are input independent (see MoSMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
202
+ ∆, B, C are input-dependent (this is a key difference between MoSMamba and the linear time invariant S4,
203
+ and is why MoSMamba is called **selective** state spaces)
204
+ """
205
+
206
+ def __init__(self, config: MoSMambaConfig, layer_idx: int):
207
+ super().__init__()
208
+ self.hidden_size = config.hidden_size
209
+ self.ssm_state_size = config.state_size
210
+ self.conv_kernel_size = config.conv_kernel
211
+ self.intermediate_size = config.intermediate_size
212
+ self.time_step_rank = int(config.time_step_rank)
213
+ self.layer_idx = layer_idx
214
+ self.use_conv_bias = config.use_conv_bias
215
+ self.conv1d = nn.Conv1d(
216
+ in_channels=self.intermediate_size,
217
+ out_channels=self.intermediate_size,
218
+ bias=config.use_conv_bias,
219
+ kernel_size=config.conv_kernel,
220
+ groups=self.intermediate_size,
221
+ padding=config.conv_kernel - 1,
222
+ )
223
+
224
+ self.activation = config.hidden_act
225
+ self.act = ACT2FN[config.hidden_act]
226
+
227
+ # num experts
228
+ self.num_selectivities = config.num_selectivities
229
+
230
+ # num selected experts
231
+ self.top_k = config.num_selectivities_per_tok
232
+
233
+ # projection of the input hidden states
234
+ self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
235
+ # selective projection used to make dt, B and C input dependant
236
+ # self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False
237
+
238
+ # self.x_proj = nn.ModuleList([nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) for _ in range(self.num_selectivities)])
239
+ # for i in range(self.num_selectivities):
240
+ # self.x_proj.add_module("x_proj_"+str(i), nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
241
+
242
+ # self.x_proj_0 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
243
+ # self.x_proj_1 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
244
+ # self.x_proj_2 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
245
+ # self.x_proj_3 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
246
+ # self.x_proj_4 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
247
+ # self.x_proj_5 = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
248
+
249
+
250
+ # self.x_proj2 = nn.ModuleList([MixtralBlockSparseTop2MLP(self.intermediate_size,self.hidden_size, self.time_step_rank + self.ssm_state_size * 2) for _ in range(self.num_selectivities)])
251
+ self.x_proj = nn.ModuleList()
252
+ for i in range(self.num_selectivities):
253
+ self.x_proj.add_module(f"w{i}",nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False))
254
+
255
+ self.gate = nn.Linear(self.hidden_size, self.num_selectivities, bias=False)
256
+
257
+ # time step projection (discretization)
258
+ self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
259
+
260
+ # S4D real initialization. These are not discretized!
261
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
262
+ A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
263
+ A = A.expand(self.intermediate_size, -1).contiguous()
264
+
265
+ self.A_log = nn.Parameter(torch.log(A))
266
+ self.D = nn.Parameter(torch.ones(self.intermediate_size))
267
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
268
+ self.use_bias = config.use_bias
269
+
270
+ self.jitter_noise = 0.001
271
+
272
+ self.register_parameter("A_log", self.A_log)
273
+ self.register_parameter("D", self.D)
274
+
275
+ # for i in enumerate(self.x_proj):
276
+ # self.register_parameter("x_proj_"+str(i), x)
277
+
278
+
279
+ def cuda_kernels_forward(self, hidden_states: torch.Tensor, x_proj, cache_params: Optional[MoSMambaCache] = None):
280
+ # 1. Gated MLP's linear projection
281
+ # router_logits =
282
+ batch_size, seq_len, _ = hidden_states.shape
283
+
284
+ projected_states = self.in_proj(hidden_states).transpose(1, 2)
285
+
286
+ if projected_states.shape[-1] == 0:
287
+ hidden_states, gate = projected_states.chunk(2, dim=1)
288
+ dtype = hidden_states.dtype
289
+
290
+ if cache_params is not None:
291
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
292
+ if cache_params.seqlen_offset > 0:
293
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
294
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
295
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
296
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
297
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
298
+ if self.use_conv_bias:
299
+ hidden_states += self.conv1d.bias
300
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
301
+ else:
302
+ conv_state = nn.functional.pad(
303
+ hidden_states,
304
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
305
+ )
306
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
307
+ if hidden_states.shape[-1] == 0:
308
+ hidden_states = hidden_states.permute(2,1,0)
309
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
310
+ else:
311
+ ssm_state = torch.zeros(
312
+ (batch_size, self.intermediate_size, self.ssm_state_size),
313
+ device=hidden_states.device, dtype=dtype
314
+ )
315
+ # print(hidden_states.shape)
316
+ # print(self.conv1d)
317
+ if hidden_states.shape[-1] == 0:
318
+ hidden_states = hidden_states.permute(2,1,0)
319
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
320
+
321
+ scan_output = (hidden_states * self.D[None, :, None])
322
+ scan_output = (scan_output * self.act(gate))
323
+ if cache_params is not None:
324
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
325
+
326
+ # 4. Final linear projection
327
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
328
+ return contextualized_states
329
+
330
+ elif self.training and cache_params is None: # Doesn't support outputting the states -> used for training
331
+ contextualized_states = mamba_inner_fn(
332
+ projected_states,
333
+ self.conv1d.weight,
334
+ self.conv1d.bias if self.use_conv_bias else None,
335
+ x_proj.weight,
336
+ self.dt_proj.weight,
337
+ self.out_proj.weight,
338
+ self.out_proj.bias.float() if self.use_bias else None,
339
+ -torch.exp(self.A_log.float()),
340
+ None, # input-dependent B
341
+ None, # input-dependent C
342
+ self.D.float(),
343
+ delta_bias=self.dt_proj.bias.float(),
344
+ delta_softplus=True,
345
+ )
346
+
347
+ else:
348
+ hidden_states, gate = projected_states.chunk(2, dim=1)
349
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
350
+
351
+ # print("NON ZERO", hidden_states.shape)
352
+ # 2. Convolution sequence transformation
353
+ if cache_params is not None and cache_params.seqlen_offset > 0:
354
+ hidden_states = causal_conv1d_update(
355
+ hidden_states.squeeze(-1),
356
+ cache_params.conv_states[self.layer_idx],
357
+ conv_weights,
358
+ self.conv1d.bias,
359
+ self.activation,
360
+ )
361
+ hidden_states = hidden_states.unsqueeze(-1)
362
+ else:
363
+ if cache_params is not None:
364
+ conv_states = nn.functional.pad(
365
+ hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
366
+ )
367
+ # print(conv_states)
368
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
369
+
370
+ hidden_states = causal_conv1d_fn(
371
+ hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
372
+ )
373
+ # 3. State Space Model sequence transformation
374
+ # 3.a. input varying initialization of time_step, B and C
375
+ ssm_parameters = x_proj(hidden_states.transpose(1, 2))
376
+ time_step, B, C = torch.split(
377
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
378
+ )
379
+ discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
380
+
381
+ A = -torch.exp(self.A_log.float())
382
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
383
+ time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
384
+
385
+ if cache_params is not None and cache_params.seqlen_offset > 0:
386
+ scan_outputs = selective_state_update(
387
+ cache_params.ssm_states[self.layer_idx],
388
+ hidden_states[..., 0],
389
+ discrete_time_step[..., 0],
390
+ A,
391
+ B[:, 0],
392
+ C[:, 0],
393
+ self.D,
394
+ gate[..., 0],
395
+ time_proj_bias,
396
+ dt_softplus=True,
397
+ ).unsqueeze(-1)
398
+ else:
399
+ # print("A.shape",A.shape)
400
+ # print("hidden_states", hidden_states.shape)
401
+ # print("discrete_time_step", discrete_time_step.shape)
402
+ # print("GATE.SHAOE", gate.shape)
403
+
404
+ scan_outputs, ssm_state = selective_scan_fn(
405
+ hidden_states,
406
+ discrete_time_step,
407
+ A,
408
+ B.transpose(1, 2),
409
+ C.transpose(1, 2),
410
+ self.D.float(),
411
+ gate,
412
+ time_proj_bias,
413
+ delta_softplus=True,
414
+ return_last_state=True,
415
+ )
416
+ # print("SCANOUTPUTS | SSMSTATE", scan_outputs.shape, ssm_state.shape)
417
+ if ssm_state is not None and cache_params is not None:
418
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
419
+
420
+ # 4. Final linear projection
421
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
422
+ return contextualized_states
423
+
424
+ # fmt: off
425
+ def slow_forward(self, input_states, x_proj, cache_params: Optional[MoSMambaCache]=None):
426
+ batch_size, seq_len, _ = input_states.shape
427
+ dtype = input_states.dtype
428
+ # 1. Gated MLP's linear projection
429
+ projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
430
+ hidden_states, gate = projected_states.chunk(2, dim=1)
431
+
432
+ # 2. Convolution sequence transformation
433
+ if cache_params is not None:
434
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
435
+ if cache_params.seqlen_offset > 0:
436
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
437
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
438
+ conv_state[:, :, -1] = hidden_states[:, :, 0]
439
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
440
+ hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
441
+ if self.use_conv_bias:
442
+ hidden_states += self.conv1d.bias
443
+ hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
444
+ else:
445
+ conv_state = nn.functional.pad(
446
+ hidden_states,
447
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
448
+ )
449
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
450
+ if hidden_states.shape[-1] == 0:
451
+ hidden_states = hidden_states.permute(2,1,0)
452
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
453
+ else:
454
+ ssm_state = torch.zeros(
455
+ (batch_size, self.intermediate_size, self.ssm_state_size),
456
+ device=hidden_states.device, dtype=dtype
457
+ )
458
+ # print(hidden_states.shape)
459
+ # print(self.conv1d)
460
+ if hidden_states.shape[-1] == 0:
461
+ hidden_states = hidden_states.permute(2,1,0)
462
+ hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
463
+
464
+ # 3. State Space Model sequence transformation
465
+ # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
466
+ ssm_parameters = x_proj(hidden_states.transpose(1, 2))
467
+ time_step, B, C = torch.split(
468
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
469
+ )
470
+ discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
471
+ discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
472
+
473
+ # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
474
+ A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
475
+ discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
476
+ discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size]
477
+ deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
478
+
479
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
480
+ scan_outputs = []
481
+ for i in range(seq_len):
482
+ ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
483
+ scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
484
+ scan_outputs.append(scan_output[:, :, 0])
485
+ # print(scan_outputs)
486
+ scan_output = torch.stack(scan_outputs, dim=-1) if scan_outputs else torch.tensor(scan_outputs) # [batch, seq_len, intermediade_size]
487
+ scan_output = scan_output + (hidden_states * self.D[None, :, None])
488
+ scan_output = (scan_output * self.act(gate))
489
+
490
+ if cache_params is not None:
491
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
492
+
493
+ # 4. Final linear projection
494
+ contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
495
+ return contextualized_states
496
+
497
+ def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None):
498
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
499
+
500
+ if self.training and self.jitter_noise > 0:
501
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
502
+
503
+ # print('BATCH_SIZE | SEQ LENGTH | HID DIM:',batch_size, sequence_length, hidden_dim)
504
+
505
+ hidden_states = hidden_states.view(-1, hidden_dim)
506
+
507
+ router_logits = self.gate(hidden_states)
508
+
509
+ # print("ROUTER LOGITS:", router_logits, router_logits.size())
510
+
511
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
512
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
513
+ # print("ROUTING WEIGHTS", routing_weights, routing_weights.shape)
514
+ # print("SEL EXPERTS", selected_experts, selected_experts.shape)
515
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
516
+ # we cast back to the input dtype
517
+ routing_weights = routing_weights.to(hidden_states.dtype)
518
+
519
+ # print(routing_weights .shape)
520
+
521
+ final_hidden_states = torch.zeros(
522
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
523
+ )
524
+
525
+ # One hot encode the selected experts to create an expert mask
526
+ # this will be used to easily index which expert is going to be sollicitated
527
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_selectivities).permute(2, 1, 0)
528
+ # print("EXPERT MASK", expert_mask, expert_mask.shape)
529
+
530
+ # Loop over all available experts in the model and perform the computation on each expert
531
+ for expert_idx in range(self.num_selectivities):
532
+ # expert_layer = self.x_proj[expert_idx]
533
+ expert_layer = self.x_proj.get_submodule(f"w{expert_idx}")
534
+ # expert_layer = getattr(self, f'x_proj_{expert_idx}')
535
+ idx, top_x = torch.where(expert_mask[expert_idx])
536
+ # print("expert_mask[expert_idx]:",expert_mask[expert_idx], expert_mask[expert_idx].shape)
537
+
538
+
539
+ # Index the correct hidden states and compute the expert hidden state for
540
+ # the current expert. We need to make sure to multiply the output hidden
541
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
542
+ # print("TOP_x:",top_x)
543
+ # print("TOP X.SHAPE:",top_x.shape)
544
+ # print("HIDDEN STATES.SHAPE:",hidden_states.shape)
545
+ # print("HIDDEN STATES[NONE, TOPX].SHAPE:", hidden_states[None, top_x].shape)
546
+
547
+
548
+ # print("TOP_X | IDX", top_x, idx)
549
+
550
+ current_state = hidden_states[None, top_x]
551
+ # print("TOPX", top_x,top_x.shape)
552
+ # print("CURRENT_STATE",current_state.shape)
553
+ current_state = current_state.reshape(-1, hidden_dim)#.reshape(batch_size, sequence_length, hidden_dim )
554
+
555
+ # if current_state.shape[1] == 0:
556
+ # continue
557
+
558
+
559
+ # print("CURRENT_STATE",current_state)
560
+
561
+ # current_state = hidden_states.reshape(batch_size, sequence_length, hidden_dim )
562
+
563
+ # print(current_state.shape)
564
+ # if current_state.shape[0] < 1:
565
+ # print(current_state)
566
+ # current_state = current_state.reshape(batch_size, 1, hidden_dim)
567
+ # else:
568
+ # current_state = current_state.reshape(batch_size, sequence_length, hidden_dim)
569
+
570
+ # print("current_state.shape", current_state.shape, "ROUTING WEIGHTS",routing_weights[top_x, idx, None].shape)
571
+
572
+ current_state = current_state * routing_weights[top_x, idx, None]
573
+
574
+ # print("current_hidden_states.shape", current_state.shape)
575
+
576
+ current_hidden_states = current_state[None]
577
+
578
+
579
+
580
+
581
+ # print("current_hidden_states[none].shape", current_hidden_states.shape)
582
+
583
+ if current_hidden_states.shape[1] != 0:
584
+
585
+ if is_fast_path_available and "cuda" in expert_layer.weight.device.type:
586
+ # if is_fast_path_available and "cuda" in expert_layer.w2.weight.device.type:
587
+ current_hidden_states = self.cuda_kernels_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
588
+ else:
589
+ current_hidden_states = self.slow_forward(current_hidden_states, expert_layer, cache_params) * routing_weights[top_x, idx, None]
590
+ # else:
591
+ # expert_layer.grad = torch.zeros_like(expert_layer.weight)
592
+ # current_hidden_states = expert_layer(current_state)
593
+
594
+ current_hidden_states = current_hidden_states.reshape(-1, hidden_dim)
595
+ # print(current_hidden_states.shape, final_hidden_states.shape)
596
+
597
+ # However `index_add_` only support torch tensors for indexing so we'll use
598
+ # the `top_x` tensor here.
599
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
600
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
601
+
602
+ return final_hidden_states, router_logits
603
+
604
+
605
+ class MoSMambaRMSNorm(nn.Module):
606
+ def __init__(self, hidden_size, eps=1e-6):
607
+ """
608
+ MoSMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
609
+ """
610
+ super().__init__()
611
+ self.weight = nn.Parameter(torch.ones(hidden_size))
612
+ self.variance_epsilon = eps
613
+
614
+ def forward(self, hidden_states):
615
+ input_dtype = hidden_states.dtype
616
+ hidden_states = hidden_states.to(torch.float32)
617
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
618
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
619
+ return self.weight * hidden_states.to(input_dtype)
620
+
621
+
622
+ class MoSMambaBlock(nn.Module):
623
+ def __init__(self, config, layer_idx):
624
+ super().__init__()
625
+ self.config = config
626
+ self.layer_idx = layer_idx
627
+ self.residual_in_fp32 = config.residual_in_fp32
628
+ self.norm = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
629
+ self.mixer = MoSMambaMixer(config, layer_idx=layer_idx)
630
+
631
+ def forward(self, hidden_states, cache_params: Optional[MoSMambaCache] = None, output_router_logits:Optional[bool] = False):
632
+ residual = hidden_states
633
+ hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
634
+ if self.residual_in_fp32:
635
+ residual = residual.to(torch.float32)
636
+
637
+ hidden_states, router_logits = self.mixer(hidden_states, cache_params=cache_params)
638
+ hidden_states = residual + hidden_states
639
+ outputs = (hidden_states,)
640
+
641
+ if output_router_logits:
642
+ outputs += (router_logits,)
643
+ return outputs
644
+
645
+
646
+ class MoSMambaPreTrainedModel(PreTrainedModel):
647
+ """
648
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
649
+ models.
650
+ """
651
+
652
+ config_class = MoSMambaConfig
653
+ base_model_prefix = "backbone"
654
+ _no_split_modules = ["MoSMambaBlock"]
655
+ supports_gradient_checkpointing = True
656
+
657
+ def _init_weights(self, module):
658
+ """Initialize the weights."""
659
+ if isinstance(module, MoSMambaMixer):
660
+ module.A_log._no_weight_decay = True
661
+ module.D._no_weight_decay = True
662
+
663
+ dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
664
+ if self.config.time_step_init_scheme == "constant":
665
+ nn.init.constant_(module.dt_proj.weight, dt_init_std)
666
+ elif self.config.time_step_init_scheme == "random":
667
+ nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
668
+
669
+ dt = torch.exp(
670
+ torch.rand(self.config.intermediate_size)
671
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
672
+ + math.log(self.config.time_step_min)
673
+ ).clamp(min=self.config.time_step_floor)
674
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
675
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
676
+ with torch.no_grad():
677
+ module.dt_proj.bias.copy_(inv_dt)
678
+ module.dt_proj.bias._no_reinit = True
679
+
680
+ if isinstance(module, nn.Linear):
681
+ if module.bias is not None:
682
+ if not getattr(module.bias, "_no_reinit", False):
683
+ nn.init.zeros_(module.bias)
684
+ elif isinstance(module, nn.Embedding):
685
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
686
+
687
+ if self.config.rescale_prenorm_residual:
688
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
689
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
690
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
691
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
692
+ #
693
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
694
+ for name, p in module.named_parameters():
695
+ if name in ["out_proj.weight"]:
696
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
697
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
698
+ # We need to reinit p since this code could be called multiple times
699
+ # Having just p *= scale would repeatedly scale it down
700
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
701
+ with torch.no_grad():
702
+ p /= math.sqrt(self.config.num_layers)
703
+
704
+
705
+ @dataclass
706
+ class MoSMambaOutput(ModelOutput):
707
+ """
708
+ Class for the MAMBA model outputs.
709
+
710
+ Args:
711
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
712
+ Sequence of hidden-states at the output of the last layer of the model.
713
+ cache_params (`MoSMambaCache`):
714
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
715
+ avoid providing the old `input_ids`.
716
+
717
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
718
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
719
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
720
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
721
+
722
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
723
+ """
724
+
725
+ last_hidden_state: Optional[torch.FloatTensor] = None
726
+ cache_params: Optional[MoSMambaCache] = None
727
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
728
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
729
+
730
+
731
+ @dataclass
732
+ class MoSMambaCausalLMOutput(ModelOutput):
733
+ """
734
+ Base class for causal language model (or autoregressive) outputs.
735
+
736
+ Args:
737
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
738
+ Language modeling loss (for next-token prediction).
739
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
740
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
741
+ cache_params (`MoSMambaCache`):
742
+ The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
743
+ avoid providing the old `input_ids`.
744
+
745
+ Includes both the State space model state matrices after the selective scan, and the Convolutional states
746
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
747
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
748
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
749
+
750
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
751
+ """
752
+
753
+ loss: Optional[torch.FloatTensor] = None
754
+ logits: Optional[torch.FloatTensor] = None
755
+ cache_params: Optional[MoSMambaCache] = None
756
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
757
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
758
+
759
+
760
+ class MoSMambaModel(MoSMambaPreTrainedModel):
761
+ def __init__(self, config):
762
+ super().__init__(config)
763
+
764
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
765
+ self.layers = nn.ModuleList([MoSMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
766
+
767
+ self.gradient_checkpointing = False
768
+ self.norm_f = MoSMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
769
+ # Initialize weights and apply final processing
770
+ self._register_load_state_dict_pre_hook(self.load_hook)
771
+ self.post_init()
772
+ self.config.output_router_logits = True
773
+
774
+ def load_hook(self, state_dict, prefix, *args):
775
+ for k in state_dict:
776
+ if "embedding." in k:
777
+ state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
778
+ break
779
+
780
+ def get_input_embeddings(self):
781
+ return self.embeddings
782
+
783
+ def set_input_embeddings(self, new_embeddings):
784
+ self.embeddings = new_embeddings
785
+
786
+ def forward(
787
+ self,
788
+ input_ids: Optional[torch.LongTensor] = None,
789
+ inputs_embeds: Optional[torch.LongTensor] = None,
790
+ cache_params: Optional[MoSMambaCache] = None,
791
+ use_cache: Optional[bool] = None,
792
+ output_hidden_states: Optional[bool] = None,
793
+ output_router_logits: Optional[bool] = None,
794
+ return_dict: Optional[bool] = None,
795
+ **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
796
+ ) -> Union[Tuple, MoSMambaOutput]:
797
+ output_hidden_states = (
798
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
799
+ )
800
+ output_router_logits = (
801
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
802
+ )
803
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
804
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
805
+
806
+ if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
807
+ raise ValueError(
808
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
809
+ )
810
+
811
+ if inputs_embeds is None:
812
+ inputs_embeds = self.embeddings(input_ids)
813
+
814
+ if self.gradient_checkpointing and self.training and use_cache:
815
+ use_cache = False
816
+
817
+ if cache_params is None and use_cache:
818
+ cache_params = MoSMambaCache(
819
+ self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
820
+ )
821
+
822
+ hidden_states = inputs_embeds
823
+ all_hidden_states = () if output_hidden_states else None
824
+ all_router_logits = () if output_router_logits else None
825
+ for mixer_block in self.layers:
826
+ if self.gradient_checkpointing and self.training:
827
+ layer_outputs = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params, output_router_logits)
828
+ else:
829
+ layer_outputs = mixer_block(hidden_states, cache_params=cache_params,output_router_logits=output_router_logits)
830
+
831
+ hidden_states = layer_outputs[0]
832
+
833
+ if output_hidden_states:
834
+ all_hidden_states = all_hidden_states + (hidden_states,)
835
+
836
+ if output_router_logits:
837
+ all_router_logits += (layer_outputs[-1],)
838
+
839
+ if use_cache:
840
+ cache_params.seqlen_offset += inputs_embeds.shape[1]
841
+
842
+ hidden_states = self.norm_f(hidden_states)
843
+
844
+ if output_hidden_states:
845
+ all_hidden_states = all_hidden_states + (hidden_states,)
846
+
847
+
848
+ if not return_dict:
849
+ return tuple(v for v in [hidden_states, cache_params, all_hidden_states, all_router_logits] if v is not None)
850
+
851
+ return MoSMambaOutput(
852
+ last_hidden_state=hidden_states,
853
+ cache_params=cache_params if use_cache else None,
854
+ hidden_states=all_hidden_states,
855
+ router_logits=all_router_logits,
856
+ )
857
+
858
+
859
+ class MoSMambaForCausalLM(MoSMambaPreTrainedModel):
860
+ _tied_weights_keys = ["lm_head.weight"]
861
+
862
+ def __init__(self, config):
863
+ super().__init__(config)
864
+ self.backbone = MoSMambaModel(config)
865
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
866
+ self.num_selectivities = 6
867
+ self.num_selectivities_per_tok = 2
868
+ self.router_aux_loss_coef = 0.02
869
+ # Initialize weights and apply final processing
870
+ self.post_init()
871
+
872
+ def get_output_embeddings(self):
873
+ return self.lm_head
874
+
875
+ def set_output_embeddings(self, new_embeddings):
876
+ self.lm_head = new_embeddings
877
+
878
+ def get_input_embeddings(self):
879
+ return self.backbone.get_input_embeddings()
880
+
881
+ def set_input_embeddings(self, new_embeddings):
882
+ return self.backbone.set_input_embeddings(new_embeddings)
883
+
884
+ def _update_model_kwargs_for_generation(
885
+ self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
886
+ ) -> Dict[str, Any]:
887
+ model_kwargs["cache_params"] = outputs.get("cache_params", None)
888
+ return model_kwargs
889
+
890
+ def prepare_inputs_for_generation(
891
+ self, input_ids, cache_params: Optional[MoSMambaCache] = None, inputs_embeds=None, attention_mask=None, output_router_logits=False, **kwargs
892
+ ):
893
+ # only last token for inputs_ids if the state is passed along.
894
+ if cache_params is not None:
895
+ input_ids = input_ids[:, -1].unsqueeze(-1)
896
+
897
+ if inputs_embeds is not None and cache_params is None:
898
+ model_inputs = {"inputs_embeds": inputs_embeds}
899
+ else:
900
+ model_inputs = {"input_ids": input_ids}
901
+
902
+ model_inputs["cache_params"] = cache_params
903
+ model_inputs['output_router_logits'] = output_router_logits
904
+ return model_inputs
905
+
906
+
907
+ def forward(
908
+ self,
909
+ input_ids: Optional[torch.LongTensor] = None,
910
+ inputs_embeds: Optional[torch.FloatTensor] = None,
911
+ cache_params: Optional[MoSMambaCache] = None,
912
+ labels: Optional[torch.LongTensor] = None,
913
+ output_hidden_states: Optional[bool] = None,
914
+ output_router_logits: Optional[bool] = None,
915
+ return_dict: Optional[bool] = None,
916
+ use_cache: Optional[bool] = None,
917
+ **kwargs, # for now we need this for generation
918
+ ) -> Union[Tuple, MoSMambaCausalLMOutput]:
919
+ r"""
920
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
921
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
922
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
923
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
924
+ """
925
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
926
+
927
+ output_router_logits = (
928
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
929
+ )
930
+
931
+ mamba_outputs = self.backbone(
932
+ input_ids,
933
+ cache_params=cache_params,
934
+ inputs_embeds=inputs_embeds,
935
+ output_hidden_states=output_hidden_states,
936
+ return_dict=return_dict,
937
+ use_cache=use_cache,
938
+ )
939
+ hidden_states = mamba_outputs[0]
940
+
941
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
942
+
943
+ loss = None
944
+ if labels is not None:
945
+ # move labels to correct device to enable model parallelism
946
+ labels = labels.to(logits.device)
947
+ # Shift so that tokens < n predict n
948
+ shift_logits = logits[..., :-1, :].contiguous()
949
+ shift_labels = labels[..., 1:].contiguous()
950
+ # Flatten the tokens
951
+ loss_fct = CrossEntropyLoss()
952
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
953
+
954
+ aux_loss = None
955
+ if output_router_logits:
956
+ aux_loss = load_balancing_loss_func(
957
+ mamba_outputs.router_logits if return_dict else mamba_outputs[-1],
958
+ self.num_selectivities,
959
+ self.num_selectivities_per_tok,
960
+ # attention_mask,
961
+ )
962
+ if labels is not None:
963
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
964
+
965
+ # print("AUX LOSS:", aux_loss)
966
+ # print("LOSS:", loss)
967
+
968
+ if not return_dict:
969
+ output = (logits,) + mamba_outputs[1:]
970
+ if output_router_logits:
971
+ output = (aux_loss,) + output
972
+ return (loss,) + output if loss is not None else output
973
+
974
+ # if not return_dict:
975
+ # output = (logits,) + mamba_outputs[1:]
976
+ # return ((loss,) + output) if loss is not None else output
977
+
978
+ return MoSMambaCausalLMOutput(
979
+ loss=loss,
980
+ logits=logits,
981
+ cache_params=mamba_outputs.cache_params,
982
+ hidden_states=mamba_outputs.hidden_states,
983
+ router_logits=mamba_outputs.router_logits,
984
+ )