Dan Fu commited on
Commit
03ef821
1 Parent(s): bea4c10

Automodel support

Browse files
bert_layers.py ADDED
@@ -0,0 +1,986 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # Copyright (c) 2023, MosaicML.
5
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
6
+
7
+ import copy
8
+ import logging
9
+ import math
10
+ import os
11
+ import sys
12
+ import warnings
13
+ from typing import List, Optional, Tuple, Union
14
+ from functools import partial
15
+
16
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
17
+ # sys.path.append(os.path.dirname(os.path.realpath(__file__)))
18
+
19
+ from .bert_padding import (index_first_axis,
20
+ index_put_first_axis, pad_input,
21
+ unpad_input, unpad_input_only)
22
+ import torch
23
+ import torch.nn as nn
24
+ from einops import rearrange
25
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import (MaskedLMOutput,
28
+ SequenceClassifierOutput)
29
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
30
+
31
+ from .blockdiag_linear import BlockdiagLinear
32
+ from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ torch.backends.cuda.matmul.allow_tf32 = True
37
+ torch.backends.cudnn.allow_tf32 = True
38
+
39
+ class BertEmbeddings(nn.Module):
40
+ """Construct the embeddings for words, ignoring position.
41
+
42
+ There are no positional embeddings since we use ALiBi and token_type
43
+ embeddings.
44
+
45
+ This module is modeled after the Hugging Face BERT's
46
+ :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
47
+ modified as part of Mosaic BERT's ALiBi implementation. The key change is
48
+ that position embeddings are removed. Position information instead comes
49
+ from attention biases that scale linearly with the position distance
50
+ between query and key tokens.
51
+
52
+ This module ignores the `position_ids` input to the `forward` method.
53
+ """
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size,
58
+ config.hidden_size,
59
+ padding_idx=config.pad_token_id)
60
+ # ALiBi doesn't use position embeddings
61
+ if config.use_positional_encodings:
62
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
63
+ self.use_positional_encodings = config.use_positional_encodings
64
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
65
+ config.hidden_size)
66
+
67
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
68
+ # variable name and be able to load any TensorFlow checkpoint file
69
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
70
+ eps=config.layer_norm_eps)
71
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
72
+ if config.use_positional_encodings:
73
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
74
+ self.register_buffer('token_type_ids',
75
+ torch.zeros(config.max_position_embeddings,
76
+ dtype=torch.long),
77
+ persistent=False)
78
+
79
+ def forward(
80
+ self,
81
+ input_ids: Optional[torch.LongTensor] = None,
82
+ token_type_ids: Optional[torch.LongTensor] = None,
83
+ position_ids: Optional[torch.LongTensor] = None,
84
+ inputs_embeds: Optional[torch.FloatTensor] = None,
85
+ past_key_values_length: int = 0,
86
+ return_position_encodings: bool = False,
87
+ ) -> torch.Tensor:
88
+ if (input_ids is not None) == (inputs_embeds is not None):
89
+ raise ValueError('Must specify either input_ids or input_embeds!')
90
+ if input_ids is not None:
91
+ input_shape = input_ids.size()
92
+ else:
93
+ assert inputs_embeds is not None # just for type checking
94
+ input_shape = inputs_embeds.size()[:-1]
95
+
96
+ seq_length = input_shape[1]
97
+
98
+ if position_ids is None:
99
+ if self.use_positional_encodings:
100
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
101
+
102
+ # Setting the token_type_ids to the registered buffer in constructor
103
+ # where it is all zeros, which usually occurs when it's auto-generated;
104
+ # registered buffer helps users when tracing the model without passing
105
+ # token_type_ids, solves issue #5664
106
+ if token_type_ids is None:
107
+ if hasattr(self, 'token_type_ids'):
108
+ assert isinstance(self.token_type_ids, torch.LongTensor)
109
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
110
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
111
+ input_shape[0], seq_length)
112
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
113
+ else:
114
+ token_type_ids = torch.zeros(input_shape, # type: ignore
115
+ dtype=torch.long,
116
+ device=self.word_embeddings.device) # type: ignore # yapf: disable
117
+
118
+ if inputs_embeds is None:
119
+ inputs_embeds = self.word_embeddings(input_ids)
120
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
121
+
122
+ embeddings = inputs_embeds + token_type_embeddings
123
+ if self.use_positional_encodings:
124
+ position_embeddings = self.position_embeddings(position_ids)
125
+ embeddings += position_embeddings
126
+ embeddings = self.LayerNorm(embeddings)
127
+ embeddings = self.dropout(embeddings)
128
+ if return_position_encodings:
129
+ return embeddings, position_embeddings
130
+ else:
131
+ return embeddings
132
+
133
+ class BertMLP(nn.Module):
134
+ """Applies the FFN at the end of each BERT layer."""
135
+
136
+ def __init__(self, config):
137
+ super().__init__()
138
+ self.config = config
139
+
140
+ if self.config.use_monarch_mlp:
141
+ linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
142
+ else:
143
+ linear_cls = nn.Linear
144
+
145
+ self.gated_layers = linear_cls(config.hidden_size,
146
+ config.intermediate_size,
147
+ bias=False)
148
+ self.act = nn.GELU(approximate='none')
149
+ self.wo = linear_cls(config.intermediate_size, config.hidden_size)
150
+
151
+ self.layernorm = nn.LayerNorm(config.hidden_size,
152
+ eps=config.layer_norm_eps)
153
+
154
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
155
+
156
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
157
+ """Compute new hidden states from current hidden states.
158
+
159
+ Args:
160
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
161
+ the attention layer [nnz, dim].
162
+ """
163
+
164
+ residual_connection = hidden_states
165
+ hidden_states = self.gated_layers(hidden_states)
166
+ hidden_states = self.act(hidden_states)
167
+ hidden_states = self.dropout(hidden_states)
168
+ hidden_states = self.wo(hidden_states)
169
+ hidden_states = self.layernorm(hidden_states + residual_connection)
170
+ return hidden_states
171
+
172
+
173
+ class BertGatedLinearUnitMLP(nn.Module):
174
+ """Applies the FFN at the end of each BERT layer with a Gated Linear Unit"""
175
+
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ self.config = config
179
+
180
+ self.is_padded = True
181
+
182
+ if self.config.use_monarch_mlp:
183
+ linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
184
+ else:
185
+ linear_cls = nn.Linear
186
+ self.gated_layers = linear_cls(
187
+ config.hidden_size,
188
+ config.intermediate_size * 2,
189
+ bias=False
190
+ )
191
+ self.act = nn.GELU(approximate='none')
192
+ self.wo = linear_cls(config.intermediate_size, config.hidden_size)
193
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
194
+ self.layernorm = nn.LayerNorm(config.hidden_size,
195
+ eps=config.layer_norm_eps)
196
+
197
+
198
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
199
+ """Compute new hidden states from current hidden states.
200
+
201
+ Args:
202
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
203
+ the attention layer [nnz, dim].
204
+ """
205
+
206
+ residual_connection = hidden_states
207
+ # compute the activation
208
+ hidden_states = self.gated_layers(hidden_states)
209
+
210
+ if self.is_padded:
211
+ gated = hidden_states[:, :, :self.config.intermediate_size]
212
+ non_gated = hidden_states[:, :, self.config.intermediate_size:]
213
+ else:
214
+ gated = hidden_states[:, :self.config.intermediate_size]
215
+ non_gated = hidden_states[:, self.config.intermediate_size:]
216
+
217
+ hidden_states = self.act(gated) * non_gated
218
+ hidden_states = self.dropout(hidden_states)
219
+ # multiply by the second matrix
220
+ hidden_states = self.wo(hidden_states)
221
+ # add the residual connection and post-LN
222
+ hidden_states = self.layernorm(hidden_states + residual_connection)
223
+
224
+ return hidden_states
225
+
226
+
227
+ class BertLayer(nn.Module):
228
+ """BERT layer, which includes Sequence Mixing (e.g. Hyena) and State Mixing (e.g. MLP)."""
229
+
230
+ def __init__(self, config):
231
+ super(BertLayer, self).__init__()
232
+
233
+ mm_cls = MonarchMixerSequenceMixing
234
+ self.attention = mm_cls(
235
+ config.hidden_size,
236
+ l_max=config.long_conv_l_max,
237
+ hyena_kernel_lr=config.long_conv_kernel_learning_rate,
238
+ bidirectional=config.bidirectional,
239
+
240
+ hyena_lr_pos_emb=config.hyena_lr_pos_emb,
241
+ hyena_w=config.hyena_w,
242
+ hyena_w_mod=config.hyena_w_mod,
243
+ hyena_wd=config.hyena_wd,
244
+ hyena_emb_dim=config.hyena_emb_dim,
245
+ hyena_filter_dropout=config.hyena_filter_dropout,
246
+ hyena_filter_order=config.hyena_filter_order,
247
+ residual_long_conv=config.residual_long_conv,
248
+ hyena_training_additions=config.hyena_training_additions,
249
+ )
250
+
251
+ if config.use_glu_mlp:
252
+ self.mlp = BertGatedLinearUnitMLP(config)
253
+ else:
254
+ self.mlp = BertMLP(config)
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ cu_seqlens: torch.Tensor,
260
+ seqlen: int,
261
+ subset_idx: Optional[torch.Tensor] = None,
262
+ indices: Optional[torch.Tensor] = None,
263
+ attn_mask: Optional[torch.Tensor] = None,
264
+ bias: Optional[torch.Tensor] = None,
265
+ ) -> torch.Tensor:
266
+ """Forward pass for a BERT layer, including both attention and MLP.
267
+
268
+ Args:
269
+ hidden_states: (total_nnz, dim)
270
+ cu_seqlens: (batch + 1,)
271
+ seqlen: int
272
+ subset_idx: () set of indices whose values we care about at the end of the layer
273
+ (e.g., the masked tokens, if this is the final layer).
274
+ indices: None or (total_nnz,)
275
+ attn_mask: None or (batch, max_seqlen_in_batch)
276
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
277
+ """
278
+
279
+ attention_output = self.attention(hidden_states)
280
+ if type(attention_output) == tuple:
281
+ attention_output, _ = attention_output
282
+
283
+ layer_output = self.mlp(attention_output)
284
+
285
+ return layer_output
286
+
287
+
288
+ class BertEncoder(nn.Module):
289
+ """A stack of BERT layers providing the backbone of BERT.
290
+
291
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
292
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
293
+ """
294
+
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ layer = BertLayer(config)
298
+ self.layer = nn.ModuleList(
299
+ [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
300
+
301
+ self.num_attention_heads = config.num_attention_heads
302
+
303
+ def rebuild_alibi_tensor(self,
304
+ size: int,
305
+ device: Optional[Union[torch.device, str]] = None):
306
+ # Alibi
307
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
308
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
309
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
310
+ # will be applied, it is necessary to construct the diagonal mask.
311
+ n_heads = self.num_attention_heads
312
+
313
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
314
+
315
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
316
+ start = (2**(-2**-(math.log2(n_heads) - 3)))
317
+ ratio = start
318
+ return [start * ratio**i for i in range(n_heads)]
319
+
320
+ # In the paper, they only train models that have 2^a heads for some a. This function
321
+ # has some good properties that only occur when the input is a power of 2. To
322
+ # maintain that even when the number of heads is not a power of 2, we use a
323
+ # workaround.
324
+ if math.log2(n_heads).is_integer():
325
+ return get_slopes_power_of_2(n_heads)
326
+
327
+ closest_power_of_2 = 2**math.floor(math.log2(n_heads))
328
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
329
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
330
+ slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
331
+ return slopes_a + slopes_b
332
+
333
+ context_position = torch.arange(size, device=device)[:, None]
334
+ memory_position = torch.arange(size, device=device)[None, :]
335
+ relative_position = torch.abs(memory_position - context_position)
336
+ # [n_heads, max_token_length, max_token_length]
337
+ relative_position = relative_position.unsqueeze(0).expand(
338
+ n_heads, -1, -1)
339
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
340
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
341
+ # [1, n_heads, max_token_length, max_token_length]
342
+ alibi = alibi.unsqueeze(0)
343
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
344
+
345
+ self._current_alibi_size = size
346
+ self.alibi = alibi
347
+
348
+ def forward(
349
+ self,
350
+ hidden_states: torch.Tensor,
351
+ attention_mask: torch.Tensor,
352
+ output_all_encoded_layers: Optional[bool] = True,
353
+ subset_mask: Optional[torch.Tensor] = None,
354
+ position_encodings: Optional[torch.Tensor] = None,
355
+ ) -> List[torch.Tensor]:
356
+
357
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
358
+ extended_attention_mask = extended_attention_mask.to(
359
+ dtype=next(self.parameters()).dtype) # fp16 compatibility
360
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
361
+ attention_mask_bool = attention_mask.bool()
362
+ batch, seqlen = hidden_states.shape[:2]
363
+
364
+ cu_seqlens = None
365
+ indices = None
366
+ alibi_attn_mask = None
367
+
368
+ all_encoder_layers = []
369
+ for layer_module in self.layer:
370
+ hidden_states = layer_module(hidden_states,
371
+ cu_seqlens,
372
+ seqlen,
373
+ None,
374
+ indices,
375
+ attn_mask=attention_mask,
376
+ bias=alibi_attn_mask
377
+ )
378
+ if position_encodings is not None:
379
+ hidden_states = hidden_states + position_encodings
380
+ if output_all_encoded_layers:
381
+ all_encoder_layers.append(hidden_states)
382
+ if subset_mask is not None:
383
+ hidden_states = hidden_states[subset_mask]
384
+
385
+ if not output_all_encoded_layers:
386
+ all_encoder_layers.append(hidden_states)
387
+ return all_encoder_layers
388
+
389
+
390
+ class BertPooler(nn.Module):
391
+
392
+ def __init__(self, config):
393
+ super(BertPooler, self).__init__()
394
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
395
+ self.activation = nn.Tanh()
396
+ self.pool_all = config.pool_all
397
+
398
+ def forward(self,
399
+ hidden_states: torch.Tensor,
400
+ pool: Optional[bool] = True,
401
+ mask= None) -> torch.Tensor:
402
+ # We "pool" the model by simply taking the hidden state corresponding
403
+ # to the first token.
404
+ if not self.pool_all:
405
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
406
+ pooled_output = self.dense(first_token_tensor)
407
+ pooled_output = self.activation(pooled_output)
408
+ else:
409
+ # mean pool everything that isn't masked out
410
+ denom = torch.sum(mask, dim=1, keepdim=True)
411
+ mean_tensor = torch.sum((hidden_states) * mask.unsqueeze(-1), dim = 1) / denom
412
+ pooled_output = self.dense(mean_tensor)
413
+ pooled_output = self.activation(pooled_output)
414
+ return pooled_output
415
+
416
+
417
+ class BertPredictionHeadTransform(nn.Module):
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
422
+ if isinstance(config.hidden_act, str):
423
+ self.transform_act_fn = ACT2FN[config.hidden_act]
424
+ else:
425
+ self.transform_act_fn = config.hidden_act
426
+ self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
427
+
428
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
429
+ hidden_states = self.dense(hidden_states)
430
+ hidden_states = self.transform_act_fn(hidden_states)
431
+ hidden_states = self.LayerNorm(hidden_states)
432
+ return hidden_states
433
+
434
+
435
+ class BertModel(BertPreTrainedModel):
436
+ """Overall BERT model.
437
+
438
+ Args:
439
+ config: a BertConfig class instance with the configuration to build a new model
440
+
441
+ Inputs:
442
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
443
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
444
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
445
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
446
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
447
+ a `sentence B` token (see BERT paper for more details).
448
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
449
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
450
+ input sequence length in the current batch. It's the mask that we typically use for attention when
451
+ a batch has varying length sentences.
452
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
453
+
454
+ Outputs: Tuple of (encoded_layers, pooled_output)
455
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
456
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
457
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
458
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
459
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
460
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
461
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
462
+ classifier pretrained on top of the hidden state associated to the first character of the
463
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
464
+
465
+ Example usage:
466
+ ```python
467
+ # Already been converted into WordPiece token ids
468
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
469
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
470
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
471
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
472
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
473
+ model = BertModel(config=config)
474
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
475
+ ```
476
+ """
477
+
478
+ def __init__(self, config, add_pooling_layer=True):
479
+ super(BertModel, self).__init__(config)
480
+ self.embeddings = BertEmbeddings(config)
481
+ self.encoder = BertEncoder(config)
482
+
483
+ self.pooler = BertPooler(config) if add_pooling_layer else None
484
+ self.post_init()
485
+
486
+
487
+ def get_input_embeddings(self):
488
+ return self.embeddings.word_embeddings
489
+
490
+ def set_input_embeddings(self, value):
491
+ self.embeddings.word_embeddings = value
492
+
493
+ def forward(
494
+ self,
495
+ input_ids: torch.Tensor,
496
+ token_type_ids: Optional[torch.Tensor] = None,
497
+ attention_mask: Optional[torch.Tensor] = None,
498
+ position_ids: Optional[torch.Tensor] = None,
499
+ output_all_encoded_layers: Optional[bool] = False,
500
+ masked_tokens_mask: Optional[torch.Tensor] = None,
501
+ **kwargs
502
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
503
+ if attention_mask is None:
504
+ attention_mask = torch.ones_like(input_ids)
505
+ if token_type_ids is None:
506
+ token_type_ids = torch.zeros_like(input_ids)
507
+
508
+ embedding_output = self.embeddings(
509
+ input_ids,
510
+ token_type_ids,
511
+ position_ids
512
+ )
513
+ position_encodings = None
514
+
515
+ subset_mask = []
516
+ first_col_mask = []
517
+
518
+ if masked_tokens_mask is None:
519
+ subset_mask = None
520
+ else:
521
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
522
+ first_col_mask[:, 0] = True
523
+ subset_mask = masked_tokens_mask | first_col_mask
524
+
525
+ encoder_outputs = self.encoder(
526
+ embedding_output,
527
+ attention_mask,
528
+ output_all_encoded_layers=output_all_encoded_layers,
529
+ subset_mask=subset_mask,
530
+ position_encodings=position_encodings)
531
+ if masked_tokens_mask is None:
532
+ sequence_output = encoder_outputs[-1]
533
+ pooled_output = self.pooler(
534
+ sequence_output, mask = attention_mask) if self.pooler is not None else None
535
+ else:
536
+ # TD [2022-03-01]: the indexing here is very tricky.
537
+ attention_mask_bool = attention_mask.bool()
538
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
539
+ sequence_output = encoder_outputs[-1][
540
+ masked_tokens_mask[attention_mask_bool][subset_idx]]
541
+ if self.pooler is not None:
542
+ pool_input = encoder_outputs[-1][
543
+ first_col_mask[attention_mask_bool][subset_idx]]
544
+ pooled_output = self.pooler(pool_input, pool=False, mask = attention_mask)
545
+ else:
546
+ pooled_output = None
547
+
548
+ if not output_all_encoded_layers:
549
+ encoder_outputs = sequence_output
550
+
551
+ if self.pooler is not None:
552
+ return encoder_outputs, pooled_output
553
+
554
+ return encoder_outputs, None
555
+
556
+
557
+ ###################
558
+ # Bert Heads
559
+ ###################
560
+ class BertLMPredictionHead(nn.Module):
561
+
562
+ def __init__(self, config, bert_model_embedding_weights):
563
+ super().__init__()
564
+ self.transform = BertPredictionHeadTransform(config)
565
+ # The output weights are the same as the input embeddings, but there is
566
+ # an output-only bias for each token.
567
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
568
+ bert_model_embedding_weights.size(0))
569
+ self.decoder.weight = bert_model_embedding_weights
570
+
571
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
572
+ hidden_states = self.transform(hidden_states)
573
+ hidden_states = self.decoder(hidden_states)
574
+ return hidden_states
575
+
576
+
577
+ class BertOnlyMLMHead(nn.Module):
578
+
579
+ def __init__(self, config, bert_model_embedding_weights):
580
+ super().__init__()
581
+ self.predictions = BertLMPredictionHead(config,
582
+ bert_model_embedding_weights)
583
+
584
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
585
+ prediction_scores = self.predictions(sequence_output)
586
+ return prediction_scores
587
+
588
+
589
+ class BertOnlyNSPHead(nn.Module):
590
+
591
+ def __init__(self, config):
592
+ super().__init__()
593
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
594
+
595
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
596
+ seq_relationship_score = self.seq_relationship(pooled_output)
597
+ return seq_relationship_score
598
+
599
+
600
+ #######################
601
+ # Construct Bert model
602
+ #######################
603
+ class BertForMaskedLM(BertPreTrainedModel):
604
+
605
+ def __init__(self, config):
606
+ super().__init__(config)
607
+
608
+ if config.is_decoder:
609
+ warnings.warn(
610
+ 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
611
+ 'bi-directional self-attention.')
612
+
613
+ self.bert = BertModel(config, add_pooling_layer=False)
614
+ self.cls = BertOnlyMLMHead(config,
615
+ self.bert.embeddings.word_embeddings.weight)
616
+
617
+ # Initialize weights and apply final processing
618
+ self.post_init()
619
+
620
+ @classmethod
621
+ def from_composer(cls,
622
+ pretrained_checkpoint,
623
+ state_dict=None,
624
+ cache_dir=None,
625
+ from_tf=False,
626
+ config=None,
627
+ *inputs,
628
+ **kwargs):
629
+ """Load from pre-trained."""
630
+ model = cls(config, *inputs, **kwargs)
631
+ if from_tf:
632
+ raise ValueError(
633
+ 'TensorFlow is not supported.')
634
+
635
+ state_dict = torch.load(pretrained_checkpoint)
636
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
637
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
638
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
639
+ strict=False)
640
+
641
+ if len(missing_keys) > 0:
642
+ logger.warning(
643
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
644
+ )
645
+ if len(unexpected_keys) > 0:
646
+ logger.warning(
647
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
648
+ )
649
+
650
+ return model
651
+
652
+ def get_output_embeddings(self):
653
+ return self.cls.predictions.decoder
654
+
655
+ def set_output_embeddings(self, new_embeddings):
656
+ self.cls.predictions.decoder = new_embeddings
657
+
658
+ def forward(
659
+ self,
660
+ input_ids: Optional[torch.Tensor] = None,
661
+ attention_mask: Optional[torch.Tensor] = None,
662
+ token_type_ids: Optional[torch.Tensor] = None,
663
+ position_ids: Optional[torch.Tensor] = None,
664
+ head_mask: Optional[torch.Tensor] = None,
665
+ inputs_embeds: Optional[torch.Tensor] = None,
666
+ encoder_hidden_states: Optional[torch.Tensor] = None,
667
+ encoder_attention_mask: Optional[torch.Tensor] = None,
668
+ labels: Optional[torch.Tensor] = None,
669
+ output_attentions: Optional[bool] = None,
670
+ output_hidden_states: Optional[bool] = None,
671
+ return_dict: Optional[bool] = None,
672
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
673
+ # labels should be a `torch.LongTensor` of shape
674
+ # `(batch_size, sequence_length)`. These are used for computing the
675
+ # masked language modeling loss.
676
+ #
677
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
678
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
679
+ # (masked), the loss is only computed for the tokens with labels in `[0,
680
+ # ..., config.vocab_size]`
681
+ #
682
+ # Prediction scores are only computed for masked tokens and the (bs,
683
+ # seqlen) dimensions are flattened
684
+ if (input_ids is not None) == (inputs_embeds is not None):
685
+ raise ValueError('Must specify either input_ids or input_embeds!')
686
+
687
+ if labels is None:
688
+ masked_tokens_mask = None
689
+ else:
690
+ masked_tokens_mask = labels > 0
691
+
692
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
693
+
694
+ outputs = self.bert(
695
+ input_ids,
696
+ attention_mask=attention_mask,
697
+ token_type_ids=token_type_ids,
698
+ position_ids=position_ids,
699
+ head_mask=head_mask,
700
+ inputs_embeds=inputs_embeds,
701
+ encoder_hidden_states=encoder_hidden_states,
702
+ encoder_attention_mask=encoder_attention_mask,
703
+ output_attentions=output_attentions,
704
+ output_hidden_states=output_hidden_states,
705
+ return_dict=return_dict,
706
+ masked_tokens_mask=masked_tokens_mask,
707
+ )
708
+
709
+ if torch.isnan(outputs[0]).any():
710
+ print("NaNs in outputs.")
711
+ raise ValueError()
712
+
713
+ #print("MLM Outputs")
714
+ #print(outputs[0].shape)
715
+
716
+ pooled_output = outputs[0]
717
+
718
+ last_hidden_state_formatted = outputs[0][:,0,:].view(-1, self.config.hidden_size)
719
+ return {"sentence_embedding": last_hidden_state_formatted}
720
+
721
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
722
+ attention_mask: torch.Tensor,
723
+ **model_kwargs):
724
+ input_shape = input_ids.shape
725
+ effective_batch_size = input_shape[0]
726
+
727
+ # add a dummy token
728
+ if self.config.pad_token_id is None:
729
+ raise ValueError('The PAD token should be defined for generation')
730
+
731
+ attention_mask = torch.cat([
732
+ attention_mask,
733
+ attention_mask.new_zeros((attention_mask.shape[0], 1))
734
+ ], dim=-1)
735
+ dummy_token = torch.full((effective_batch_size, 1),
736
+ self.config.pad_token_id,
737
+ dtype=torch.long,
738
+ device=input_ids.device)
739
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
740
+
741
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
742
+
743
+
744
+ class BertForSequenceClassification(BertPreTrainedModel):
745
+ """Bert Model transformer with a sequence classification/regression head.
746
+
747
+ This head is just a linear layer on top of the pooled output. Used for,
748
+ e.g., GLUE tasks.
749
+ """
750
+
751
+ def __init__(self, config):
752
+ super().__init__(config)
753
+ self.num_labels = config.num_labels
754
+ self.config = config
755
+
756
+ self.bert = BertModel(config)
757
+ classifier_dropout = (config.classifier_dropout
758
+ if config.classifier_dropout is not None else
759
+ config.hidden_dropout_prob)
760
+ self.dropout = nn.Dropout(classifier_dropout)
761
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
762
+
763
+ # Initialize weights and apply final processing
764
+ self.post_init()
765
+
766
+ @classmethod
767
+ def from_composer(cls,
768
+ pretrained_checkpoint,
769
+ state_dict=None,
770
+ cache_dir=None,
771
+ from_tf=False,
772
+ config=None,
773
+ *inputs,
774
+ **kwargs):
775
+ """Load from pre-trained."""
776
+ model = cls(config, *inputs, **kwargs)
777
+ if from_tf:
778
+ raise ValueError(
779
+ 'TensorFlow is not supported.')
780
+
781
+ state_dict = torch.load(pretrained_checkpoint)
782
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
783
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
784
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
785
+ strict=False)
786
+
787
+ if len(missing_keys) > 0:
788
+ logger.warning(
789
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
790
+ )
791
+ if len(unexpected_keys) > 0:
792
+ logger.warning(
793
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
794
+ )
795
+
796
+ return model
797
+
798
+ def forward(
799
+ self,
800
+ input_ids: Optional[torch.Tensor] = None,
801
+ attention_mask: Optional[torch.Tensor] = None,
802
+ token_type_ids: Optional[torch.Tensor] = None,
803
+ position_ids: Optional[torch.Tensor] = None,
804
+ head_mask: Optional[torch.Tensor] = None,
805
+ inputs_embeds: Optional[torch.Tensor] = None,
806
+ labels: Optional[torch.Tensor] = None,
807
+ output_attentions: Optional[bool] = None,
808
+ output_hidden_states: Optional[bool] = None,
809
+ return_dict: Optional[bool] = None,
810
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
811
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
812
+ # Labels for computing the sequence classification/regression loss.
813
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
814
+ # If `config.num_labels == 1` a regression loss is computed
815
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
816
+ # is computed (cross-entropy).
817
+
818
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
819
+
820
+ outputs = self.bert(
821
+ input_ids,
822
+ attention_mask=attention_mask,
823
+ token_type_ids=token_type_ids,
824
+ position_ids=position_ids,
825
+ head_mask=head_mask,
826
+ inputs_embeds=inputs_embeds,
827
+ output_attentions=output_attentions,
828
+ output_hidden_states=output_hidden_states,
829
+ return_dict=return_dict,
830
+ )
831
+
832
+ pooled_output = outputs[1]
833
+
834
+ pooled_output = self.dropout(pooled_output)
835
+ logits = self.classifier(pooled_output)
836
+
837
+ loss = None
838
+ if labels is not None:
839
+ # Compute loss
840
+ if self.config.problem_type is None:
841
+ if self.num_labels == 1:
842
+ self.config.problem_type = 'regression'
843
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
844
+ labels.dtype == torch.int):
845
+ self.config.problem_type = 'single_label_classification'
846
+ else:
847
+ self.config.problem_type = 'multi_label_classification'
848
+
849
+ if self.config.problem_type == 'regression':
850
+ loss_fct = nn.MSELoss()
851
+ if self.num_labels == 1:
852
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
853
+ else:
854
+ loss = loss_fct(logits, labels)
855
+ elif self.config.problem_type == 'single_label_classification':
856
+ loss_fct = nn.CrossEntropyLoss()
857
+ loss = loss_fct(logits.view(-1, self.num_labels),
858
+ labels.view(-1))
859
+ elif self.config.problem_type == 'multi_label_classification':
860
+ loss_fct = nn.BCEWithLogitsLoss()
861
+ loss = loss_fct(logits, labels)
862
+
863
+ if not return_dict:
864
+ output = (logits,) + outputs[2:]
865
+ return ((loss,) + output) if loss is not None else output
866
+
867
+ return SequenceClassifierOutput(
868
+ loss=loss,
869
+ logits=logits,
870
+ hidden_states=None,
871
+ attentions=None,
872
+ )
873
+
874
+ class BertForTextEncoding(BertPreTrainedModel):
875
+
876
+ def __init__(self, config):
877
+ super().__init__(config)
878
+
879
+ if config.is_decoder:
880
+ warnings.warn(
881
+ 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
882
+ 'bi-directional self-attention.')
883
+
884
+ self.bert = BertModel(config, add_pooling_layer=True)
885
+
886
+ # Initialize weights and apply final processing
887
+ self.post_init()
888
+
889
+ @classmethod
890
+ def from_composer(cls,
891
+ pretrained_checkpoint,
892
+ state_dict=None,
893
+ cache_dir=None,
894
+ from_tf=False,
895
+ config=None,
896
+ *inputs,
897
+ **kwargs):
898
+ """Load from pre-trained."""
899
+ model = cls(config, *inputs, **kwargs)
900
+ if from_tf:
901
+ raise ValueError(
902
+ 'TensorFlow is not supported.')
903
+
904
+ state_dict = torch.load(pretrained_checkpoint)
905
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
906
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
907
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
908
+ strict=False)
909
+
910
+ if len(missing_keys) > 0:
911
+ logger.warning(
912
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
913
+ )
914
+ if len(unexpected_keys) > 0:
915
+ logger.warning(
916
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
917
+ )
918
+
919
+ return model
920
+
921
+ def forward(
922
+ self,
923
+ input_ids: Optional[torch.Tensor] = None,
924
+ attention_mask: Optional[torch.Tensor] = None,
925
+ token_type_ids: Optional[torch.Tensor] = None,
926
+ position_ids: Optional[torch.Tensor] = None,
927
+ head_mask: Optional[torch.Tensor] = None,
928
+ inputs_embeds: Optional[torch.Tensor] = None,
929
+ encoder_hidden_states: Optional[torch.Tensor] = None,
930
+ encoder_attention_mask: Optional[torch.Tensor] = None,
931
+ labels: Optional[torch.Tensor] = None,
932
+ output_attentions: Optional[bool] = None,
933
+ output_hidden_states: Optional[bool] = None,
934
+ return_dict: Optional[bool] = None,
935
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
936
+
937
+ if (input_ids is not None) == (inputs_embeds is not None):
938
+ raise ValueError('Must specify either input_ids or input_embeds!')
939
+
940
+ if labels is None:
941
+ masked_tokens_mask = None
942
+ else:
943
+ masked_tokens_mask = labels > 0
944
+
945
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
946
+
947
+ outputs = self.bert(
948
+ input_ids,
949
+ attention_mask=attention_mask,
950
+ token_type_ids=token_type_ids,
951
+ position_ids=position_ids,
952
+ head_mask=head_mask,
953
+ inputs_embeds=inputs_embeds,
954
+ encoder_hidden_states=encoder_hidden_states,
955
+ encoder_attention_mask=encoder_attention_mask,
956
+ output_attentions=output_attentions,
957
+ output_hidden_states=output_hidden_states,
958
+ return_dict=return_dict,
959
+ masked_tokens_mask=masked_tokens_mask,
960
+ )
961
+
962
+ pooled_output = outputs[1]
963
+
964
+ return {"sentence_embedding": pooled_output}
965
+
966
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
967
+ attention_mask: torch.Tensor,
968
+ **model_kwargs):
969
+ input_shape = input_ids.shape
970
+ effective_batch_size = input_shape[0]
971
+
972
+ # add a dummy token
973
+ if self.config.pad_token_id is None:
974
+ raise ValueError('The PAD token should be defined for generation')
975
+
976
+ attention_mask = torch.cat([
977
+ attention_mask,
978
+ attention_mask.new_zeros((attention_mask.shape[0], 1))
979
+ ], dim=-1)
980
+ dummy_token = torch.full((effective_batch_size, 1),
981
+ self.config.pad_token_id,
982
+ dtype=torch.long,
983
+ device=input_ids.device)
984
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
985
+
986
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
bert_padding.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
2
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
3
+
4
+ """
5
+
6
+ Functions for padding and unpadding
7
+
8
+ """
9
+
10
+ from typing import Tuple, cast
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+
16
+
17
+ class IndexFirstAxis(torch.autograd.Function):
18
+
19
+ @staticmethod
20
+ def forward(ctx, input: torch.Tensor,
21
+ indices: torch.Tensor) -> torch.Tensor:
22
+ """Get just the values of `input` which are at `indices`.
23
+
24
+ Arguments:
25
+ ctx: the autograd context object
26
+ input: (b, ...) 2+ dimensional tensor
27
+ indices: (num_idx) 1D tensor
28
+ """
29
+ ctx.save_for_backward(indices)
30
+ assert input.ndim >= 2
31
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
32
+ 1:]
33
+ second_dim = other_shape.numel(
34
+ ) # product of sizes of all but first dimension
35
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
36
+ return torch.gather(
37
+ rearrange(input, 'b ... -> b (...)'), # (b, ...) -> (b, second_dim)
38
+ 0,
39
+ repeat(indices, 'z -> z d',
40
+ d=second_dim) # (indices,) -> (indices, second_dim)
41
+ ).reshape(-1, *other_shape) # (num_idx, ...)
42
+
43
+ @staticmethod
44
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
45
+ indices, = ctx.saved_tensors
46
+ assert grad_output.ndim >= 2
47
+ other_shape = grad_output.shape[1:]
48
+ grad_output = rearrange(grad_output, 'b ... -> b (...)')
49
+ grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
50
+ device=grad_output.device,
51
+ dtype=grad_output.dtype)
52
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
53
+ # grad_input[indices] = grad_output
54
+ grad_input.scatter_(0,
55
+ repeat(indices, 'z -> z d', d=grad_output.shape[1]),
56
+ grad_output)
57
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
58
+
59
+
60
+ index_first_axis = IndexFirstAxis.apply
61
+
62
+
63
+ class IndexPutFirstAxis(torch.autograd.Function):
64
+
65
+ @staticmethod
66
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
67
+ first_axis_dim) -> torch.Tensor:
68
+ ctx.save_for_backward(indices)
69
+ assert indices.ndim == 1
70
+ assert values.ndim >= 2
71
+ output = torch.zeros(first_axis_dim,
72
+ *values.shape[1:],
73
+ device=values.device,
74
+ dtype=values.dtype)
75
+ output[indices] = values
76
+ return output
77
+
78
+ @staticmethod
79
+ def backward(ctx,
80
+ grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
81
+ indices, = ctx.saved_tensors
82
+ grad_values = grad_output[indices]
83
+ return grad_values, None, None
84
+
85
+
86
+ index_put_first_axis = IndexPutFirstAxis.apply
87
+
88
+
89
+ def unpad_input(
90
+ hidden_states: torch.Tensor,
91
+ attention_mask: torch.Tensor,
92
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
93
+ """Remove padding from input sequences.
94
+
95
+ Arguments:
96
+ hidden_states: (batch, seqlen, ...)
97
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
98
+
99
+ Returns:
100
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
101
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
102
+ max_seqlen_in_batch: int
103
+ """
104
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
105
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
106
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
107
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
108
+ (1, 0))
109
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
110
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
111
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
112
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
113
+ # so we write custom forward and backward to make it a bit faster.
114
+ hidden_states = cast(
115
+ torch.Tensor,
116
+ index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
117
+ indices))
118
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
119
+
120
+
121
+ def unpad_input_only(
122
+ hidden_states: torch.Tensor,
123
+ attention_mask: torch.Tensor,
124
+ ) -> torch.Tensor:
125
+ """Like unpad_input, but only return the unpadded first tensor.
126
+
127
+ Save a small amount of overhead.
128
+
129
+ Arguments:
130
+ hidden_states: (batch, seqlen, ...)
131
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
132
+
133
+ Returns:
134
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
135
+ """
136
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
137
+ rearranged = rearrange(hidden_states, 'b s ... -> (b s) ...')
138
+ return index_first_axis(rearranged, indices) # type: ignore
139
+
140
+
141
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
142
+ seqlen: int) -> torch.Tensor:
143
+ """Add padding to sequences.
144
+
145
+ Arguments:
146
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
147
+ indices: (total_nnz)
148
+
149
+ Returns:
150
+ hidden_states: (batch, seqlen, ...)
151
+ """
152
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
153
+ return rearrange(output, '(b s) ... -> b s ...', b=batch) # type: ignore
blockdiag_linear.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from .structured_linear import StructuredLinear
9
+ from .blockdiag_multiply import blockdiag_multiply
10
+
11
+
12
+ class BlockdiagLinear(StructuredLinear):
13
+
14
+ def __init__(self, *args, nblocks=4, shuffle=False, **kwargs):
15
+ """shuffle: apply channel_shuffle operation before the matmul as in ShuffleNet
16
+ """
17
+ super().__init__(*args, **kwargs)
18
+ in_blksz = int(math.ceil(self.in_features / nblocks))
19
+ out_blksz = int(math.ceil(self.out_features / nblocks))
20
+ self.in_features_extended = in_blksz * nblocks
21
+ self.out_features_extended = out_blksz * nblocks
22
+ self.shuffle = shuffle
23
+ self.weight = nn.Parameter(torch.empty(nblocks, out_blksz, in_blksz))
24
+ self.reset_parameters()
25
+
26
+ def set_weights_from_dense_init(self, dense_init_fn_):
27
+ dense_weight = torch.empty(self.out_features_extended, self.in_features_extended,
28
+ device=self.weight.device, dtype=self.weight.dtype)
29
+ dense_init_fn_(dense_weight)
30
+ # Scale by sqrt because the weight is sparse
31
+ scaling = math.sqrt(dense_weight.numel() / self.weight.numel())
32
+ dense_weight *= scaling
33
+ with torch.no_grad():
34
+ nblocks = self.weight.shape[0]
35
+ self.weight.copy_(rearrange(dense_weight, '(b o) (b1 i) -> b b1 o i',
36
+ b=nblocks, b1=nblocks)[0])
37
+
38
+ @property
39
+ def saving(self):
40
+ return self.weight.numel() / (self.in_features * self.out_features)
41
+
42
+ def forward_matmul(self, x):
43
+ x = self.preprocess(x)
44
+ if self.shuffle:
45
+ x = rearrange(x, '... (group c_per_group) -> ... (c_per_group group)',
46
+ group=self.weight.shape[0]) # group=nblocks
47
+ output = blockdiag_multiply(x, self.weight)
48
+ return self.postprocess(output)
49
+
50
+
51
+ class BlockdiagSparsityConfig:
52
+
53
+ def __init__(self, nblocks, block=32, global_size=0):
54
+ """shuffle: apply channel_shuffle operation before the matmul as in ShuffleNet
55
+ """
56
+ self.nblocks = nblocks
57
+ self.block = block
58
+ self.global_size = global_size
59
+
60
+ def make_layout(self, out_features, in_features):
61
+ assert out_features % self.block == 0 and in_features % self.block == 0
62
+ assert out_features % self.nblocks == 0 and in_features % self.nblocks == 0
63
+ layout = torch.block_diag(*[torch.ones(out_features // self.nblocks,
64
+ in_features // self.nblocks,
65
+ dtype=torch.int32)] * self.nblocks)
66
+ if self.global_size > 0:
67
+ layout[:self.global_size] = 1
68
+ layout[:, :self.global_size] = 1
69
+ # Convert from (out_features, in_features) mask to
70
+ # (out_features // block, in_features // block) mask
71
+ layout = rearrange(layout, '(p blksz) (r blksz1) -> p r (blksz blksz1)',
72
+ blksz=self.block, blksz1=self.block)
73
+ return (layout > 0).any(dim=-1).int()
blockdiag_multiply.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ def blockdiag_weight_to_dense_weight(weight):
10
+ """
11
+ Argumments:
12
+ weight: (nblocks, out / nblocks, in / blocks)
13
+ Return:
14
+ dense_weight: (out / in)
15
+ """
16
+ return torch.block_diag(*torch.unbind(weight, dim=0))
17
+
18
+
19
+ def blockdiag_multiply_reference(x, weight):
20
+ """
21
+ This implementation is slow but more likely to be correct.
22
+ Arguments:
23
+ x: (..., n)
24
+ weight: (nblocks, q, n / nblocks)
25
+ Outputs:
26
+ out: (..., nblocks * q)
27
+ """
28
+ n = x.shape[-1]
29
+ nblocks, q, p = weight.shape
30
+ assert nblocks * p == n
31
+
32
+ x_reshaped = rearrange(x, '... (nblocks p) -> ... nblocks p', nblocks=nblocks)
33
+ return rearrange(torch.einsum('...kp, kqp -> ...kq', x_reshaped, weight),
34
+ '... nblocks q -> ... (nblocks q)')
35
+
36
+
37
+ class BlockdiagMultiply(torch.autograd.Function):
38
+
39
+ """This is a faster implementation, with careful memory copies for the fastest
40
+ bmm performance.
41
+ The backward pass is also written manually with careful memory copies.
42
+ Arguments:
43
+ x: (..., n)
44
+ weight: (nblocks, q, n / nblocks)
45
+ Outputs:
46
+ out: (..., nblocks * q)
47
+ """
48
+
49
+ @staticmethod
50
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
51
+ def forward(ctx, x, weight):
52
+ ctx.save_for_backward(x, weight)
53
+ batch_shape, n = x.shape[:-1], x.shape[-1]
54
+ batch_dim = np.prod(batch_shape)
55
+ nblocks, q, p = weight.shape
56
+ assert nblocks * p == n
57
+ x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
58
+ out = torch.empty(batch_dim, nblocks, q, device=x.device, dtype=x.dtype).transpose(0, 1)
59
+ out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1)
60
+ return out.reshape(*batch_shape, nblocks * q)
61
+
62
+ @staticmethod
63
+ @torch.cuda.amp.custom_bwd
64
+ def backward(ctx, dout):
65
+ x, weight = ctx.saved_tensors
66
+ batch_shape, n = x.shape[:-1], x.shape[-1]
67
+ batch_dim = np.prod(batch_shape)
68
+ nblocks, q, p = weight.shape
69
+ assert nblocks * p == n
70
+ dx, dweight = None, None
71
+ dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1)
72
+ if ctx.needs_input_grad[0]:
73
+ dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype)
74
+ dx = torch.bmm(dout_reshaped, weight.conj(),
75
+ out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
76
+ if ctx.needs_input_grad[1]:
77
+ x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
78
+ dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj())
79
+ return dx, dweight
80
+
81
+
82
+ blockdiag_multiply = BlockdiagMultiply.apply
config.json CHANGED
@@ -1,4 +1,48 @@
1
  {
2
- "model_type": "m2_bert"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  }
4
-
 
1
  {
2
+ "_name_or_path": "togethercomputer/m2-bert-80M-8k-retrieval",
3
+ "alibi_starting_size": 8192,
4
+ "architectures": [
5
+ "BertForSequenceClassification"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.0,
8
+ "bidirectional": true,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_bert.BertConfig",
11
+ "AutoModelForSequenceClassification": "bert_layers.BertForTextEncoding",
12
+ "AutoTokenizer": "bert-base-uncased"
13
+ },
14
+ "classifier_dropout": null,
15
+ "gradient_checkpointing": false,
16
+ "hidden_act": "gelu",
17
+ "hidden_dropout_prob": 0.1,
18
+ "hidden_size": 768,
19
+ "hyena_emb_dim": 5,
20
+ "hyena_filter_dropout": 0.2,
21
+ "hyena_filter_order": 128,
22
+ "hyena_lr_pos_emb": 1e-05,
23
+ "hyena_training_additions": false,
24
+ "hyena_w": 10,
25
+ "hyena_w_mod": 1,
26
+ "hyena_wd": 0.1,
27
+ "initializer_range": 0.02,
28
+ "intermediate_size": 3072,
29
+ "layer_norm_eps": 1e-12,
30
+ "long_conv_kernel_learning_rate": 0.001,
31
+ "long_conv_l_max": 8192,
32
+ "max_position_embeddings": 8192,
33
+ "model_type": "bert",
34
+ "monarch_mlp_nblocks": 4,
35
+ "num_attention_heads": 12,
36
+ "num_hidden_layers": 12,
37
+ "pad_token_id": 0,
38
+ "pool_all": false,
39
+ "position_embedding_type": "absolute",
40
+ "residual_long_conv": true,
41
+ "transformers_version": "4.28.1",
42
+ "type_vocab_size": 2,
43
+ "use_cache": true,
44
+ "use_glu_mlp": true,
45
+ "use_monarch_mlp": true,
46
+ "use_positional_encodings": true,
47
+ "vocab_size": 30528
48
  }
 
configuration_bert.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+
3
+
4
+ class BertConfig(BertConfig):
5
+
6
+ def __init__(
7
+ self,
8
+ alibi_starting_size: int = 512,
9
+ attention_probs_dropout_prob: float = 0.0,
10
+
11
+ # mlp
12
+ use_glu_mlp: bool = True,
13
+ use_monarch_mlp: bool = False,
14
+ monarch_mlp_nblocks: int = 4,
15
+
16
+ # position
17
+ use_positional_encodings: bool = False,
18
+ max_position_embeddings: int = 512,
19
+
20
+ # architecture selection
21
+ residual_long_conv: bool = False,
22
+
23
+ # hyena and long conv hyperparameters
24
+ bidirectional: bool = True,
25
+ hyena_w_mod: int = 1,
26
+ hyena_filter_dropout: float = 0.2,
27
+ hyena_filter_order: int = 64,
28
+ hyena_training_additions: bool = False,
29
+
30
+ # efficiency
31
+ use_flash_mm: bool = False,
32
+
33
+ # average pooling instead of CLS token
34
+ pool_all: bool = False,
35
+
36
+ **kwargs,
37
+ ):
38
+ """Configuration class for MosaicBert.
39
+
40
+ Args:
41
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
42
+ create when initializing the model. You should be able to ignore this parameter in most cases.
43
+ Defaults to 512.
44
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT.
45
+ Defaults to 0.0.
46
+ """
47
+ super().__init__(
48
+ attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
49
+ self.alibi_starting_size = alibi_starting_size
50
+
51
+ # mlp
52
+ self.use_glu_mlp = use_glu_mlp
53
+ self.use_monarch_mlp = use_monarch_mlp
54
+ self.monarch_mlp_nblocks = monarch_mlp_nblocks
55
+
56
+ # positional encodings
57
+ self.use_positional_encodings = use_positional_encodings
58
+ self.max_position_embeddings = max_position_embeddings
59
+
60
+ # architecture
61
+ self.residual_long_conv = residual_long_conv
62
+
63
+ # hyena and long conv hyperparameters
64
+ self.bidirectional = bidirectional
65
+ self.hyena_w_mod = hyena_w_mod
66
+ self.hyena_filter_dropout = hyena_filter_dropout
67
+ self.hyena_filter_order = hyena_filter_order
68
+ self.hyena_training_additions = hyena_training_additions
69
+
70
+ # efficiency
71
+ self.use_flash_mm = use_flash_mm
72
+
73
+ # average pooling instead of CLS token
74
+ self.pool_all = pool_all
75
+
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.28.1",
4
+ "use_cache": false,
5
+ "eos_token_id": [0, 50278]
6
+ }
hyena_utils.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
2
+ # Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+ import opt_einsum as oe
12
+ contract = oe.contract
13
+
14
+ """ Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """
15
+
16
+ class OptimModule(nn.Module):
17
+ """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """
18
+
19
+ def register(self, name, tensor, lr=None, wd=0.0):
20
+ """Register a tensor with a configurable learning rate and 0 weight decay"""
21
+
22
+ if lr == 0.0:
23
+ self.register_buffer(name, tensor)
24
+ else:
25
+ self.register_parameter(name, nn.Parameter(tensor))
26
+
27
+ optim = {}
28
+ if lr is not None: optim["lr"] = lr
29
+ if wd is not None: optim["weight_decay"] = wd
30
+ setattr(getattr(self, name), "_optim", optim)
31
+
32
+
33
+ def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
34
+ # u.shape: B H L
35
+ seqlen = u.shape[-1]
36
+
37
+ fft_size = 2 * seqlen
38
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
39
+ if k_rev is not None:
40
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
41
+ k_f = k_f + k_rev_f.conj()
42
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
43
+
44
+ if len(u.shape) > 3:
45
+ k_f = k_f.unsqueeze(1)
46
+
47
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
48
+
49
+ out = y + u * D
50
+
51
+ if gelu:
52
+ out = F.gelu(out)
53
+ if dropout_mask is not None:
54
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
55
+ else:
56
+ return out.to(dtype=u.dtype)
57
+
58
+
59
+ @torch.jit.script
60
+ def mul_sum(q, y):
61
+ return (q * y).sum(dim=1)
62
+
63
+
64
+ class Sin(nn.Module):
65
+ def __init__(self, dim, w=10, w_mod=1, train_freq=True):
66
+ super().__init__()
67
+
68
+ init_tensor = torch.ones(1, dim)
69
+ self.freq = (
70
+ nn.Parameter(w * init_tensor)
71
+ if train_freq
72
+ else w * torch.ones(1, dim)
73
+ )
74
+ self.w_mod = w_mod
75
+
76
+ def forward(self, x):
77
+ return torch.sin(self.w_mod * self.freq * x)
78
+
79
+
80
+ class PositionalEmbedding(OptimModule):
81
+ def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
82
+ """Complex exponential positional embeddings for Hyena filters."""
83
+ super().__init__()
84
+
85
+ self.seq_len = seq_len
86
+ # The time embedding fed to the filteres is normalized so that t_f = 1
87
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
88
+
89
+ if emb_dim > 1:
90
+ bands = (emb_dim - 1) // 2
91
+ # To compute the right embeddings we use the "proper" linspace
92
+ t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
93
+ w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
94
+
95
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
96
+ z = torch.exp(-1j * f * w)
97
+ z = torch.cat([t, z.real, z.imag], dim=-1)
98
+ self.register("z", z, lr=lr_pos_emb)
99
+ self.register("t", t, lr=0.0)
100
+
101
+ def forward(self, L):
102
+ return self.z[:, :L], self.t[:, :L]
103
+
104
+
105
+ class ExponentialModulation(OptimModule):
106
+ def __init__(
107
+ self,
108
+ d_model,
109
+ fast_decay_pct=0.3,
110
+ slow_decay_pct=1.5,
111
+ target=1e-2,
112
+ modulation_lr=0.0,
113
+ shift: float = 0.0,
114
+ **kwargs,
115
+ ):
116
+ super().__init__()
117
+ self.shift = shift
118
+ max_decay = math.log(target) / fast_decay_pct
119
+ min_decay = math.log(target) / slow_decay_pct
120
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
121
+ self.register("deltas", deltas, lr=modulation_lr)
122
+
123
+ def forward(self, t, x):
124
+ decay = torch.exp(-t * self.deltas.abs())
125
+ x = x * (decay + self.shift)
126
+ return x
127
+
128
+
129
+ class HyenaFilter(OptimModule):
130
+ def __init__(
131
+ self,
132
+ d_model,
133
+ emb_dim=3, # dim of input to MLP, augments with positional encoding
134
+ order=16, # width of the implicit MLP
135
+ seq_len=1024,
136
+ lr=1e-3,
137
+ lr_pos_emb=1e-5,
138
+ dropout=0.0,
139
+ w=1, # frequency of periodic activations
140
+ w_mod=1, # non-learnable modification of w
141
+ wd=0, # weight decay of kernel parameters
142
+ bias=True,
143
+ num_inner_mlps=2,
144
+ linear_mixer=False,
145
+ modulate: bool = True,
146
+ normalized=False,
147
+ bidirectional=False,
148
+ **kwargs,
149
+ ):
150
+ """
151
+ Implicit long filter with modulation.
152
+
153
+ Args:
154
+ d_model: number of channels in the input
155
+ emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands
156
+ order: width of the FFN
157
+ num_inner_mlps: number of inner linear layers inside filter MLP
158
+
159
+ Note:
160
+ filter_dropout is not implemented
161
+ """
162
+ super().__init__()
163
+
164
+ self.d_model=d_model
165
+ self.emb_dim=emb_dim
166
+ self.seq_len=seq_len
167
+ self.modulate=modulate
168
+ self.use_bias = bias
169
+ self.bidirectional = bidirectional
170
+
171
+ self.bias = nn.Parameter(torch.randn(self.d_model))
172
+ self.dropout = nn.Dropout(dropout)
173
+
174
+ act = Sin(dim=order, w=w, w_mod=w_mod)
175
+ assert (
176
+ emb_dim % 2 != 0 and emb_dim >= 3
177
+ ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)"
178
+ self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)
179
+
180
+ # uses a variable number of inner linear layers
181
+ if linear_mixer is False:
182
+ self.implicit_filter = nn.Sequential(
183
+ nn.Linear(emb_dim, order),
184
+ act,
185
+ )
186
+ for i in range(num_inner_mlps):
187
+ self.implicit_filter.append(nn.Linear(order, order))
188
+ self.implicit_filter.append(act)
189
+ self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
190
+ else:
191
+ self.implicit_filter = nn.Sequential(
192
+ nn.Linear(emb_dim, d_model, bias=False),
193
+ )
194
+
195
+ if self.bidirectional:
196
+ self.implicit_filter_rev = nn.Sequential(
197
+ nn.Linear(emb_dim, order),
198
+ act,
199
+ )
200
+ for i in range(num_inner_mlps):
201
+ self.implicit_filter_rev.append(nn.Linear(order, order))
202
+ self.implicit_filter_rev.append(act)
203
+ self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))
204
+
205
+ self.modulation = ExponentialModulation(d_model, **kwargs)
206
+
207
+ self.normalized = normalized
208
+ for c in self.implicit_filter.children():
209
+ for name, v in c.state_dict().items():
210
+ optim = {"weight_decay": wd, "lr": lr}
211
+ setattr(getattr(c, name), "_optim", optim)
212
+
213
+ def filter(self, L, *args, **kwargs):
214
+ z, t = self.pos_emb(L)
215
+ h = self.implicit_filter(z)
216
+ if self.modulate:
217
+ h = self.modulation(t, h)
218
+ if self.normalized:
219
+ h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
220
+ return h
221
+
222
+ def filter_rev(self, L, *args, **kwargs):
223
+ z, t = self.pos_emb(L)
224
+ h = self.implicit_filter_rev(z)
225
+ if self.modulate:
226
+ h = self.modulation(t, h)
227
+ if self.normalized:
228
+ h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
229
+ return h
230
+
231
+ def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
232
+ if k_fwd is None:
233
+ k_fwd = self.filter(L)
234
+ if self.bidirectional and k_rev is None:
235
+ k_rev = self.filter_rev(L)
236
+
237
+ # Ensure compatibility with filters that return a tuple
238
+ k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
239
+ if bias is None:
240
+ bias = self.bias
241
+ bias = bias if self.use_bias else 0 * bias
242
+
243
+ if self.bidirectional:
244
+ k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
245
+ k = F.pad(k_fwd, (0, L)) \
246
+ + F.pad(k_rev.flip(-1), (L, 0))
247
+ else:
248
+ k = k_fwd
249
+
250
+
251
+ y = fftconv_ref(
252
+ x,
253
+ k,
254
+ bias,
255
+ dropout_mask=None,
256
+ gelu=False,
257
+ )
258
+
259
+ return y.to(dtype=x.dtype)
monarch_mixer_sequence_mixer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
2
+ # Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
3
+
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ import opt_einsum as oe
7
+
8
+ contract = oe.contract
9
+ from .hyena_utils import HyenaFilter
10
+
11
+
12
+ class MonarchMixerSequenceMixing(nn.Module):
13
+ def __init__(
14
+ self,
15
+ d_model,
16
+ l_max=128,
17
+ dropout=0.0,
18
+ hyena_kernel_lr=None,
19
+ bidirectional=False,
20
+ hyena_lr_pos_emb=1e-5,
21
+ hyena_w=10,
22
+ hyena_w_mod=1,
23
+ hyena_wd=0.1,
24
+ hyena_emb_dim=3,
25
+ hyena_filter_dropout=0.0,
26
+ hyena_filter_order=16,
27
+ residual_long_conv=False,
28
+ hyena_training_additions=False,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.d_model = d_model
33
+ self.l_max = l_max
34
+ self.kernel_lr = hyena_kernel_lr
35
+ self.channels = 1
36
+ self.bidirectional = bidirectional
37
+ self.residual_long_conv = residual_long_conv
38
+ self.NUM_PROJECTIONS = 3
39
+
40
+ print('-- Bidirectional:', self.bidirectional)
41
+ print("-- Using Long Conv Residual:", self.residual_long_conv)
42
+ print('-- Hyena w:', hyena_w)
43
+ print('-- Hyena w mod:', hyena_w_mod)
44
+ print(f"-- Hyena filter order: {hyena_filter_order}")
45
+ print(f"-- Hyena filter dropout: {hyena_filter_dropout}")
46
+ print(f"-- Hyena filter wd: {hyena_wd}")
47
+ print(f"-- Hyena filter emb dim: {hyena_emb_dim}")
48
+ print(f"-- Hyena filter lr: {hyena_kernel_lr}")
49
+ print(f"-- Hyena filter lr pos emb: {hyena_lr_pos_emb}")
50
+
51
+ self.filter_fn = HyenaFilter(
52
+ self.d_model,
53
+ order=hyena_filter_order,
54
+ seq_len=self.l_max,
55
+ dropout=hyena_filter_dropout,
56
+ bidirectional=self.bidirectional,
57
+ lr=hyena_kernel_lr,
58
+ lr_pos_emb=hyena_lr_pos_emb,
59
+ w=hyena_w, # frequency of periodic activations
60
+ w_mod=hyena_w_mod,
61
+ wd=hyena_wd, # weight decay of kernel parameters
62
+ emb_dim=hyena_emb_dim,
63
+ )
64
+
65
+ if self.residual_long_conv:
66
+ self.filter_fn2 = HyenaFilter(
67
+ self.d_model,
68
+ order=hyena_filter_order,
69
+ seq_len=self.l_max,
70
+ dropout=hyena_filter_dropout,
71
+ bidirectional=self.bidirectional,
72
+ lr=hyena_kernel_lr,
73
+ lr_pos_emb=hyena_lr_pos_emb,
74
+ w=hyena_w, # frequency of periodic activations
75
+ w_mod=hyena_w_mod,
76
+ wd=hyena_wd, # weight decay of kernel parameters
77
+ emb_dim=hyena_emb_dim,
78
+ )
79
+
80
+ # setup projections
81
+ self.in_linear = nn.Linear(d_model, 3 * d_model)
82
+ self.out_linear = nn.Linear(d_model, d_model)
83
+ self.hyena_training_additions = hyena_training_additions
84
+ if self.hyena_training_additions:
85
+ self.act = nn.Identity()
86
+ self.drop = nn.Dropout(dropout)
87
+ self.layernorm = nn.LayerNorm(d_model)
88
+
89
+ # setup short conv
90
+ total_width = self.d_model * self.NUM_PROJECTIONS
91
+ self.short_filter = nn.Conv1d(
92
+ in_channels=total_width,
93
+ out_channels=total_width,
94
+ kernel_size=3,
95
+ groups=total_width,
96
+ padding=2,
97
+ )
98
+
99
+
100
+ def forward(self, u, **kwargs):
101
+ # u is B L H
102
+ if self.hyena_training_additions:
103
+ u = self.layernorm(u)
104
+ L = u.size(-2)
105
+
106
+ # in projection
107
+ u_orig = u
108
+ u = self.in_linear(u)
109
+ u = rearrange(u, "b l d -> b d l")
110
+
111
+ # short filter
112
+ uc = self.short_filter(u)[..., :L]
113
+
114
+ x1, x2, v = uc.split(self.d_model, dim=1)
115
+
116
+ v = v * x1
117
+ if self.hyena_training_additions:
118
+ v = self.drop(v)
119
+
120
+ k = self.filter_fn.filter(L, device=u.device)
121
+ k = rearrange(k, "c l d -> c d l")[0] # `c` is always 1 by default
122
+
123
+ if self.bidirectional:
124
+ k_rev = self.filter_fn.filter_rev(L, device=u.device)
125
+ k_rev = rearrange(k_rev, "c l d -> c d l")[0] # `c` is always 1 by default
126
+ else:
127
+ k_rev = None
128
+
129
+ y = self.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias= self.filter_fn.bias[None, :, None])
130
+
131
+ if self.residual_long_conv:
132
+ k2 = self.filter_fn2.filter(L, device=u.device)
133
+ k2 = rearrange(k2, "c l d -> c d l")[0]
134
+
135
+ if self.bidirectional:
136
+ k2_rev = self.filter_fn2.filter_rev(L, device=u.device)
137
+ k2_rev = rearrange(k2_rev, "c l d -> c d l")[0] # `c` is always 1 by default
138
+ else:
139
+ k2_rev = None
140
+
141
+ yu = self.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias= self.filter_fn2.bias[None, :, None])
142
+
143
+ # post gating
144
+ y = y * x2
145
+
146
+ if self.residual_long_conv:
147
+ y = y + yu
148
+
149
+ y = y.transpose(-1, -2)
150
+ if self.hyena_training_additions:
151
+ y = self.drop(self.act(y))
152
+ y = self.out_linear(y)
153
+
154
+ return y, None
155
+
156
+
structured_linear.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
2
+
3
+ import math
4
+ from functools import partial
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.nn import init
9
+
10
+
11
+ class StructuredLinear(nn.Module):
12
+
13
+ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
14
+ """Subclasses should call reset_parameters
15
+ """
16
+ factory_kwargs = {'device': device, 'dtype': dtype}
17
+ super().__init__()
18
+ self.in_features = in_features
19
+ self.out_features = out_features
20
+ # Subclasses may override {in,out}_features_extended
21
+ if not hasattr(self, 'in_features_extended'):
22
+ self.in_features_extended = in_features
23
+ if not hasattr(self, 'out_features_extended'):
24
+ self.out_features_extended = out_features
25
+ if bias:
26
+ self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
27
+ else:
28
+ self.register_parameter('bias', None)
29
+
30
+ def reset_parameters(self) -> None:
31
+ self.set_weights_from_dense_init(dense_init_fn_=partial(init.kaiming_uniform_, a=math.sqrt(5)))
32
+ self.reset_parameters_bias()
33
+
34
+ def set_weights_from_dense_init(self, dense_init_fn_):
35
+ raise NotImplementedError
36
+
37
+ def reset_parameters_bias(self):
38
+ if self.bias is not None:
39
+ fan_in = self.bias.shape[-1]
40
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
41
+ init.uniform_(self.bias, -bound, bound)
42
+
43
+ @property
44
+ def saving(self):
45
+ raise NotImplementedError
46
+
47
+ def convert_to_dense_weight(self):
48
+ factory_kwargs = {'device': self.weight.device, 'dtype': self.weight.dtype}
49
+ dense_weight = self.forward_matmul(torch.eye(self.in_features, **factory_kwargs)).T
50
+ return dense_weight
51
+
52
+ def preprocess(self, x):
53
+ in_features = x.shape[-1]
54
+ if in_features < self.in_features_extended:
55
+ x = F.pad(x, (0, self.in_features_extended - in_features))
56
+ return x
57
+
58
+ def postprocess(self, output):
59
+ out_features_extended = output.shape[-1]
60
+ if out_features_extended > self.out_features:
61
+ output = output[..., :self.out_features]
62
+ return output
63
+
64
+ def forward_matmul(self, x):
65
+ raise NotImplementedError
66
+
67
+ def forward(self, x):
68
+ output = self.forward_matmul(x)
69
+ # Convert bias to output.dtype in case of AMP, otherwise bias and activation will be in FP32
70
+ return (output + self.bias.to(dtype=output.dtype)) if self.bias is not None else output