yairschiff committed
Commit
9bc99b5
1 Parent(s): aacd2e0

Upload BiMambaForMaskedLM

Files changed (5)
  1. README.md +201 -0
  2. config.json +44 -0
  3. configuration_bimamba.py +51 -0
  4. model.safetensors +3 -0
  5. modeling_bimamba.py +417 -0
README.md ADDED
@@ -0,0 +1,201 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+
+
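The README above is the stock auto-generated model-card template, and its "How to Get Started with the Model" section is left as a stub. Based on the `auto_map` in the `config.json` below, the checkpoint is meant to be loaded through `AutoModelForMaskedLM` with `trust_remote_code=True`. A minimal sketch follows; the repo id `<user>/<repo>` is a placeholder, no tokenizer ships with this commit, and the custom code needs the `mamba_ssm` package and a CUDA device to import and run.

```python
# Minimal quick-start sketch (assumptions: placeholder repo id, mamba_ssm installed, CUDA available).
import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("<user>/<repo>", trust_remote_code=True).to("cuda")

# No tokenizer is included, so pass raw token ids directly (vocab_size is 8 for this checkpoint).
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # (1, 16, vocab_size)
```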
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "architectures": [
+     "BiMambaForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_bimamba.BiMambaConfig",
+     "AutoModel": "modeling_bimamba.BiMamba",
+     "AutoModelForMaskedLM": "modeling_bimamba.BiMambaForMaskedLM"
+   },
+   "bidirectional": true,
+   "bidirectional_strategy": "add",
+   "bidirectional_weight_tie": true,
+   "d_model": 2,
+   "fused_add_norm": true,
+   "initializer_cfg": {
+     "initializer_range": 0.02,
+     "n_residuals_per_layer": 1,
+     "rescale_prenorm_residual": true
+   },
+   "model_type": "bimamba",
+   "n_layer": 1,
+   "norm_epsilon": 1e-05,
+   "pad_token_id": -100,
+   "pad_vocab_size_multiple": 8,
+   "residual_in_fp32": true,
+   "rms_norm": true,
+   "ssm_cfg": {
+     "bias": false,
+     "conv_bias": true,
+     "d_conv": 4,
+     "d_state": 16,
+     "dt_init": "random",
+     "dt_init_floor": 0.0001,
+     "dt_max": 0.1,
+     "dt_min": 0.001,
+     "dt_rank": "auto",
+     "dt_scale": 1.0,
+     "expand": 2,
+     "use_fast_path": true
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.1",
+   "vocab_size": 8
+ }
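The `auto_map` block is what lets the `Auto*` classes resolve to the custom code shipped in this repo, and the very small sizes (`d_model` 2, `n_layer` 1, `vocab_size` 8, ~4 KB of weights) suggest a minimal test checkpoint. A short sketch of inspecting the configuration; `<user>/<repo>` is again a placeholder, and note that `pad_token_id` (-100) is reused as the `ignore_index` of the masked-LM loss in `modeling_bimamba.py` below.

```python
# Sketch: loading this configuration through AutoConfig (trust_remote_code is needed because
# auto_map points AutoConfig at configuration_bimamba.BiMambaConfig).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
print(config.model_type)                                  # "bimamba"
print(config.d_model, config.n_layer, config.vocab_size)  # 2, 1, 8 for this checkpoint
print(config.pad_token_id)                                # -100, reused as ignore_index in the LM loss
```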
configuration_bimamba.py ADDED
@@ -0,0 +1,51 @@
+ """BiMamba config for Hugging Face.
+
+ """
+
+ from typing import Optional, Union
+
+ from transformers import PretrainedConfig
+
+
+ class BiMambaConfig(PretrainedConfig):
+     """Config that extends the original MambaConfig with params relevant to bi-directionality."""
+     model_type = "bimamba"
+
+     def __init__(
+             self,
+             # From original MambaConfig
+             d_model: int = 2560,
+             n_layer: int = 64,
+             vocab_size: int = 50277,
+             ssm_cfg: Optional[dict] = None,
+             rms_norm: bool = True,
+             residual_in_fp32: bool = True,
+             fused_add_norm: bool = True,
+             pad_vocab_size_multiple: int = 8,
+
+             # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
+             norm_epsilon: float = 1e-5,
+
+             # Used in init_weights
+             initializer_cfg: Optional[dict] = None,
+
+             # Caduceus-specific params
+             bidirectional: bool = True,
+             bidirectional_strategy: Union[str, None] = "add",
+             bidirectional_weight_tie: bool = True,
+             **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.d_model = d_model
+         self.n_layer = n_layer
+         self.vocab_size = vocab_size
+         self.ssm_cfg = ssm_cfg
+         self.rms_norm = rms_norm
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+         self.norm_epsilon = norm_epsilon
+         self.initializer_cfg = initializer_cfg
+         self.bidirectional = bidirectional
+         self.bidirectional_strategy = bidirectional_strategy
+         self.bidirectional_weight_tie = bidirectional_weight_tie
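A sketch of building the config in code with the same hyperparameters as `config.json` above; unspecified arguments keep the Mamba-style defaults in the signature, `pad_token_id` is forwarded through `**kwargs` to `PretrainedConfig`, and the `ssm_cfg` dict shown is just a subset of the keys stored in the JSON. This file only depends on `transformers`, so the snippet runs on CPU (assuming the file is importable as a module).

```python
# Sketch: constructing BiMambaConfig directly, mirroring config.json.
from configuration_bimamba import BiMambaConfig  # assumes this file is on the Python path

config = BiMambaConfig(
    d_model=2,
    n_layer=1,
    vocab_size=8,
    ssm_cfg={"d_state": 16, "d_conv": 4, "expand": 2},  # subset of the keys in config.json
    bidirectional=True,
    bidirectional_strategy="add",
    bidirectional_weight_tie=True,
    pad_token_id=-100,  # forwarded to PretrainedConfig via **kwargs
)
print(config.to_json_string())
```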
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e27efb874440812a910cf613e84d73caa4d133cd1e4dc33d8dbae5eae60de1f
+ size 4112
modeling_bimamba.py ADDED
@@ -0,0 +1,417 @@
+ """BiMamba model for Hugging Face.
+
+ """
+
+ import math
+ from functools import partial
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput
+
+ try:
+     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+ except ImportError:
+     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+ from .configuration_bimamba import BiMambaConfig
+
+
+ def create_block(
+         d_model,
+         ssm_cfg=None,
+         norm_epsilon=1e-5,
+         rms_norm=False,
+         residual_in_fp32=False,
+         fused_add_norm=False,
+         layer_idx=None,
+         bidirectional=True,
+         bidirectional_strategy="add",
+         bidirectional_weight_tie=True,
+         device=None,
+         dtype=None,
+ ):
+     """Create BiMamba block.
+
+     Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
+     """
+     if ssm_cfg is None:
+         ssm_cfg = {}
+     factory_kwargs = {"device": device, "dtype": dtype}
+     bidirectional_kwargs = {
+         "bidirectional": bidirectional,
+         "bidirectional_strategy": bidirectional_strategy,
+         "bidirectional_weight_tie": bidirectional_weight_tie,
+     }
+     mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
+     norm_cls = partial(
+         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+     )
+     block_cls = Block
+     block = block_cls(
+         d_model,
+         mixer_cls,
+         norm_cls=norm_cls,
+         fused_add_norm=fused_add_norm,
+         residual_in_fp32=residual_in_fp32,
+     )
+     block.layer_idx = layer_idx
+     return block
+
+
+ class BiMambaWrapper(nn.Module):
+     """Thin wrapper around Mamba to support bi-directionality."""
+
+     def __init__(
+             self,
+             d_model: int,
+             bidirectional: bool = True,
+             bidirectional_strategy: Optional[str] = "add",
+             bidirectional_weight_tie: bool = True,
+             **mamba_kwargs,
+     ):
+         super().__init__()
+         if bidirectional and bidirectional_strategy is None:
+             bidirectional_strategy = "add"  # Default strategy: `add`
+         if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
+             raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
+         self.bidirectional = bidirectional
+         self.bidirectional_strategy = bidirectional_strategy
+         self.mamba_fwd = Mamba(
+             d_model=d_model,
+             **mamba_kwargs
+         )
+         if bidirectional:
+             self.mamba_rev = Mamba(
+                 d_model=d_model,
+                 **mamba_kwargs
+             )
+             if bidirectional_weight_tie:  # Tie in and out projections (where most of param count lies)
+                 self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
+                 self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
+                 self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
+                 self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
+         else:
+             self.mamba_rev = None
+
+     def forward(self, hidden_states, inference_params=None):
+         """Bidirectional-enabled forward pass
+
+         hidden_states: (B, L, D)
+         Returns: same shape as hidden_states
+         """
+         out = self.mamba_fwd(hidden_states, inference_params=inference_params)
+         if self.bidirectional:
+             out_rev = self.mamba_rev(
+                 hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
+                 inference_params=inference_params
+             ).flip(dims=(1,))  # Flip back for combining with forward hidden states
+             if self.bidirectional_strategy == "add":
+                 out = out + out_rev
+             elif self.bidirectional_strategy == "ew_multiply":
+                 out = out * out_rev
+             else:
+                 raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
+         return out
+
+
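The flip-and-combine logic in `BiMambaWrapper.forward` is the core of the bi-directionality: a forward Mamba pass plus a second pass over the reversed sequence, flipped back and merged. The toy sketch below reproduces that pattern with a cumulative sum standing in for the causal Mamba mixer, purely to illustrate the `add` and `ew_multiply` strategies; it is not the actual Mamba block and runs on CPU.

```python
# Illustration only: the same flip-and-combine pattern as BiMambaWrapper.forward.
import torch

def toy_mixer(x):                    # causal stand-in: position i only sees positions <= i
    return torch.cumsum(x, dim=1)

x = torch.randn(2, 5, 3)             # (B, L, D)
out_fwd = toy_mixer(x)                                   # forward pass
out_rev = toy_mixer(x.flip(dims=(1,))).flip(dims=(1,))   # reverse pass, flipped back
out_add = out_fwd + out_rev          # "add" strategy
out_mul = out_fwd * out_rev          # "ew_multiply" strategy
print(out_add.shape, out_mul.shape)  # both (2, 5, 3): same shape as the input
```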
+ class BiMambaEmbeddings(nn.Module):
+     def __init__(
+             self,
+             config: BiMambaConfig,
+             device=None,
+             dtype=None,
+     ):
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
+
+     def forward(self, input_ids):
+         """
+         input_ids: (batch, seqlen)
+         """
+         return self.word_embeddings(input_ids)
+
+
+ class BiMambaMixerModel(nn.Module):
+     def __init__(
+             self,
+             config: BiMambaConfig,
+             device=None,
+             dtype=None,
+     ) -> None:
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         self.fused_add_norm = config.fused_add_norm
+         self.residual_in_fp32 = config.residual_in_fp32
+
+         self.embeddings = BiMambaEmbeddings(config, **factory_kwargs)
+
+         # Mamba changes the order of residual and layer norm:
+         # Instead of LN -> Attn / MLP -> Add, we do:
+         # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
+         # the main branch (output of MLP / Mixer). The model definition is unchanged.
+         # This is for performance reason: we can fuse add + layer_norm.
+         if config.fused_add_norm:
+             if layer_norm_fn is None or rms_norm_fn is None:
+                 raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
+
+         self.layers = nn.ModuleList(
+             [
+                 create_block(
+                     config.d_model,
+                     ssm_cfg=config.ssm_cfg,
+                     norm_epsilon=config.norm_epsilon,
+                     rms_norm=config.rms_norm,
+                     residual_in_fp32=config.residual_in_fp32,
+                     fused_add_norm=config.fused_add_norm,
+                     layer_idx=i,
+                     bidirectional=config.bidirectional,
+                     bidirectional_strategy=config.bidirectional_strategy,
+                     bidirectional_weight_tie=config.bidirectional_weight_tie,
+                     **factory_kwargs,
+                 )
+                 for i in range(config.n_layer)
+             ]
+         )
+
+         norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
+             config.d_model, eps=config.norm_epsilon, **factory_kwargs
+         )
+         self.norm_f = norm_f
+
+     def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
+         """Mixer forward."""
+         all_hidden_states = []
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+         else:
+             hidden_states = self.embeddings(input_ids)
+
+         residual = None
+         for layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states.append(hidden_states)
+             # TODO: Add support for gradient checkpointing
+             hidden_states, residual = layer(
+                 hidden_states, residual, inference_params=None
+             )
+
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
+             # Set prenorm=False here since we don't need the residual
+             hidden_states = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm_f.weight,
+                 self.norm_f.bias,
+                 eps=self.norm_f.eps,
+                 residual=residual,
+                 prenorm=False,
+                 residual_in_fp32=self.residual_in_fp32,
+             )
+         if output_hidden_states:
+             all_hidden_states.append(hidden_states)
+         return hidden_states, all_hidden_states
+
+
+ def cross_entropy(logits, y, ignore_index=-100):
+     """Cross entropy loss."""
+     logits = logits.view(-1, logits.shape[-1])
+     y = y.view(-1)
+     return F.cross_entropy(logits, y, ignore_index=ignore_index)
+
+
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
+     """Weighted cross entropy loss (discounts certain tokens)."""
+     logits = logits.view(-1, logits.shape[-1])
+     y = y.view(-1)
+     ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
+     loss_weights = loss_weights.view(-1)
+     loss_weights[y == ignore_index] = 0.0
+     # TODO: Follows GPN implementation, but should we remove weight normalization?
+     return (ce * (loss_weights / loss_weights.sum())).sum()
+
+
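The two loss helpers above are what `BiMambaForMaskedLM`, defined later in this file, calls with `ignore_index=config.pad_token_id` (-100 in this repo's `config.json`). The sketch below reproduces the same arithmetic inline with plain PyTorch, so it runs without importing this module (which would pull in `mamba_ssm`).

```python
# Sketch of what cross_entropy and weighted_cross_entropy compute, written inline.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 8)        # (batch, seqlen, vocab_size)
labels = torch.randint(0, 8, (2, 4))
labels[:, 0] = -100                  # ignored positions (ignore_index / pad_token_id)

# cross_entropy: flatten, let ignore_index drop the masked positions.
plain = F.cross_entropy(logits.view(-1, 8), labels.view(-1), ignore_index=-100)

# weighted_cross_entropy: per-token CE, zero the weights of ignored tokens,
# normalize the weights to sum to 1, then take the weighted sum.
weights = torch.ones(2, 4).view(-1)
weights[labels.view(-1) == -100] = 0.0
ce = F.cross_entropy(logits.view(-1, 8), labels.view(-1), ignore_index=-100, reduction="none")
weighted = (ce * (weights / weights.sum())).sum()
print(plain.item(), weighted.item())
```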
+ class BiMambaPreTrainedModel(PreTrainedModel):
+     """PreTrainedModel wrapper for BiMamba backbone."""
+     config_class = BiMambaConfig
+     base_model_prefix = "bimamba"
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["BiMambaWrapper"]
+
+     def _init_weights(
+             self,
+             module,
+             initializer_range=0.02,  # Now only used for embedding layer.
+             **kwargs,
+     ):
+         """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
+
+         n_layer = self.config.n_layer
+         initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
+         rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
+         initializer_range = initialized_cfg.get("initializer_range", initializer_range)
+         n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
+
+         if isinstance(module, nn.Linear):
+             if module.bias is not None:
+                 if not getattr(module.bias, "_no_reinit", False):
+                     nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, std=initializer_range)
+
+         if rescale_prenorm_residual:
+             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+             #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
+             #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
+             #   residual layers.
+             #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
+             #
+             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+             for name, p in module.named_parameters():
+                 if name in ["out_proj.weight", "fc2.weight"]:
+                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                     # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                     # We need to reinit p since this code could be called multiple times
+                     # Having just p *= scale would repeatedly scale it down
+                     nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                     with torch.no_grad():
+                         p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+
+ class BiMamba(BiMambaPreTrainedModel):
+     """BiMamba model that can be instantiated using HF patterns."""
+     def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs):
+         super().__init__(config)
+
+         # Adjust vocab size if vocab padding is set.
+         if config.vocab_size % config.pad_vocab_size_multiple != 0:
+             config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
+
+         self.config = config
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.backbone = BiMambaMixerModel(config, **factory_kwargs, **kwargs)
+
+     def forward(
+             self,
+             input_ids: torch.LongTensor = None,
+             inputs_embeds: Optional[torch.FloatTensor] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+     ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
+         """HF-compatible forward method."""
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states, all_hidden_states = self.backbone(
+             input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=output_hidden_states
+         )
+         if return_dict:
+             return BaseModelOutputWithNoAttention(
+                 last_hidden_state=hidden_states,
+                 hidden_states=all_hidden_states if output_hidden_states else None
+             )
+         elif output_hidden_states:
+             return hidden_states, all_hidden_states
+         else:
+             return hidden_states
+
+
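`BiMamba` is the class that `config.json` maps to `AutoModel`, so the backbone can also be used on its own as a feature extractor. A hedged sketch, under the same assumptions as the earlier quick start (placeholder repo id, `mamba_ssm` installed, CUDA available):

```python
# Sketch: extracting hidden states with the BiMamba backbone via AutoModel.
import torch
from transformers import AutoModel

backbone = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True).to("cuda")
input_ids = torch.randint(0, backbone.config.vocab_size, (1, 32), device="cuda")
with torch.no_grad():
    out = backbone(input_ids=input_ids, output_hidden_states=True)
print(out.last_hidden_state.shape)  # (1, 32, d_model)
print(len(out.hidden_states))       # n_layer + 1: inputs to each block plus the final normed output
```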
+ class BiMambaForMaskedLM(BiMambaPreTrainedModel):
+     """HF-compatible BiMamba model for masked language modeling."""
+
+     def __init__(self, config: BiMambaConfig, device=None, dtype=None, **kwargs):
+         super().__init__(config, **kwargs)
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.bimamba = BiMamba(config, **factory_kwargs, **kwargs)
+         self.lm_head = nn.Linear(
+             config.d_model,
+             self.config.vocab_size,  # Use BiMamba config as it might have been updated
+             bias=False,
+             **factory_kwargs
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.bimamba.backbone.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.bimamba.backbone.embeddings.word_embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """Overrides output embeddings."""
+         self.lm_head = new_embeddings
+
+     def tie_weights(self):
+         """Tie weights."""
+         super().tie_weights()
+
+     def get_decoder(self):
+         """Get decoder (backbone) for the model."""
+         return self.bimamba
+
+     def set_decoder(self, decoder):
+         """Set decoder (backbone) for the model."""
+         self.bimamba = decoder
+
+     def forward(
+             self,
+             input_ids: torch.LongTensor = None,
+             inputs_embeds: Optional[torch.FloatTensor] = None,
+             labels: Optional[torch.LongTensor] = None,
+             loss_weights: Optional[torch.FloatTensor] = None,
+             output_hidden_states: Optional[bool] = None,
+             return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         """HF-compatible forward method."""
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.bimamba(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             if loss_weights is not None:
+                 loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
+             else:
+                 loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+         )
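Continuing the quick-start sketch from after the README: computing the masked-LM loss with the `forward` signature above. `model` is the `BiMambaForMaskedLM` loaded there, the shapes are illustrative, and -100 matches `pad_token_id`, which doubles as the loss `ignore_index`.

```python
# Sketch: masked-LM loss with forward() ("model" as loaded in the earlier quick start).
import torch

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
labels = input_ids.clone()
labels[:, ::2] = -100                 # positions set to -100 are ignored by the loss

out = model(input_ids=input_ids, labels=labels)
print(out.loss, out.logits.shape)

# Optionally down-weight individual positions via loss_weights (uses weighted_cross_entropy).
loss_weights = torch.ones_like(input_ids, dtype=torch.float)
out = model(input_ids=input_ids, labels=labels, loss_weights=loss_weights)
```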