jupyterjazz committed on
Commit
4bfe854
1 Parent(s): 97c58e3

upload files


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (12)
  1. README.md +104 -0
  2. block.py +413 -0
  3. configuration_xlm_roberta.py +128 -0
  4. convert_roberta_weights_to_flash.py +170 -0
  5. embedding.py +76 -0
  6. mha.py +806 -0
  7. mlp.py +219 -0
  8. modeling_lora.py +401 -0
  9. modeling_xlm_roberta.py +1208 -0
  10. rotary.py +658 -0
  11. stochastic_depth.py +97 -0
  12. xlm_padding.py +229 -0
README.md ADDED
@@ -0,0 +1,104 @@
1
+ ---
2
+ tags:
3
+ - transformers
4
+ - xlm-roberta
5
+ library_name: transformers
6
+ license: cc-by-nc-4.0
7
+ language:
8
+ - multilingual
9
+ - af
10
+ - am
11
+ - ar
12
+ - as
13
+ - az
14
+ - be
15
+ - bg
16
+ - bn
17
+ - br
18
+ - bs
19
+ - ca
20
+ - cs
21
+ - cy
22
+ - da
23
+ - de
24
+ - el
25
+ - en
26
+ - eo
27
+ - es
28
+ - et
29
+ - eu
30
+ - fa
31
+ - fi
32
+ - fr
33
+ - fy
34
+ - ga
35
+ - gd
36
+ - gl
37
+ - gu
38
+ - ha
39
+ - he
40
+ - hi
41
+ - hr
42
+ - hu
43
+ - hy
44
+ - id
45
+ - is
46
+ - it
47
+ - ja
48
+ - jv
49
+ - ka
50
+ - kk
51
+ - km
52
+ - kn
53
+ - ko
54
+ - ku
55
+ - ky
56
+ - la
57
+ - lo
58
+ - lt
59
+ - lv
60
+ - mg
61
+ - mk
62
+ - ml
63
+ - mn
64
+ - mr
65
+ - ms
66
+ - my
67
+ - ne
68
+ - nl
69
+ - 'no'
70
+ - om
71
+ - or
72
+ - pa
73
+ - pl
74
+ - ps
75
+ - pt
76
+ - ro
77
+ - ru
78
+ - sa
79
+ - sd
80
+ - si
81
+ - sk
82
+ - sl
83
+ - so
84
+ - sq
85
+ - sr
86
+ - su
87
+ - sv
88
+ - sw
89
+ - ta
90
+ - te
91
+ - th
92
+ - tl
93
+ - tr
94
+ - ug
95
+ - uk
96
+ - ur
97
+ - uz
98
+ - vi
99
+ - xh
100
+ - yi
101
+ - zh
102
+ ---
103
+
104
+ Modified version of https://huggingface.co/jinaai/xlm-roberta-flash-implementation, adapted for ONNX conversion.
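The files below provide the custom modeling code; a minimal, hedged sketch of loading the repository through transformers is shown here. The repository id, the use of trust_remote_code, and borrowing the FacebookAI/xlm-roberta-base tokenizer are assumptions for illustration only.

# Hedged sketch, not part of the commit: loading the custom code with transformers.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = AutoModel.from_pretrained(
    "jinaai/xlm-roberta-flash-implementation",  # assumption: substitute this repository's id
    trust_remote_code=True,                     # needed for the custom modeling files
)
model.eval()

inputs = tokenizer("A small smoke test.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(type(outputs))  # the exact output structure is defined in modeling_xlm_roberta.py (not shown in this excerpt)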
block.py ADDED
@@ -0,0 +1,413 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+ from .mha import MHA
14
+ from .mlp import Mlp
15
+ from .stochastic_depth import StochasticDepth
16
+
17
+ try:
18
+ from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
19
+ except ImportError:
20
+ layer_norm_fn, RMSNorm = None, None
21
+
22
+
23
+ class Block(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim,
27
+ mixer_cls=None,
28
+ mlp_cls=None,
29
+ norm_cls=nn.LayerNorm,
30
+ dropout_cls=nn.Dropout,
31
+ prenorm=True,
32
+ resid_dropout1=0.0,
33
+ resid_dropout2=0.0,
34
+ drop_path1=0.0,
35
+ drop_path2=0.0,
36
+ fused_dropout_add_ln=False,
37
+ return_residual=False,
38
+ residual_in_fp32=False,
39
+ sequence_parallel=False,
40
+ mark_shared_params=False,
41
+ ):
42
+ """
43
+ For prenorm=True, this Block has a slightly different structure compared to a regular
44
+ prenorm Transformer block.
45
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
46
+ [Ref: https://arxiv.org/abs/2002.04745]
47
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
48
+ the hidden_states (output of the MLP) and the residual.
49
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
50
+ The residual needs to be provided (except for the very first block).
51
+
52
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
53
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
54
+
55
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
56
+ This is for performance reasons: for a post-norm architecture, returning the input allows us
57
+ to fuse the backward of nn.Linear with the residual connection.
58
+ """
59
+ super().__init__()
60
+ self.prenorm = prenorm
61
+ self.fused_dropout_add_ln = fused_dropout_add_ln
62
+ self.return_residual = return_residual
63
+ self.residual_in_fp32 = residual_in_fp32
64
+ if self.residual_in_fp32:
65
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
66
+ if mixer_cls is None:
67
+ mixer_cls = partial(MHA, num_heads=dim // 64)
68
+ if mlp_cls is None:
69
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
70
+ self.mixer = mixer_cls(dim)
71
+ self.dropout1 = dropout_cls(resid_dropout1)
72
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
73
+ self.norm1 = norm_cls(dim)
74
+ self.mlp = mlp_cls(dim)
75
+ if not isinstance(self.mlp, nn.Identity):
76
+ self.dropout2 = dropout_cls(resid_dropout2)
77
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
78
+ self.norm2 = norm_cls(dim)
79
+
80
+ if self.fused_dropout_add_ln:
81
+ assert layer_norm_fn is not None, "Triton is not installed"
82
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
83
+ self.dropout1, nn.Dropout
84
+ )
85
+
86
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
87
+ # then the input to each worker in the tensor parallel group will be different.
88
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
89
+ # For now this is not an issue because we always use sequence_parallel=True during training
90
+ # and only use sequence_parallel=False during inference.
91
+
92
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
93
+ if sequence_parallel:
94
+ for p in self.norm1.parameters():
95
+ p._sequence_parallel = True
96
+ if hasattr(self, "norm2"):
97
+ for p in self.norm2.parameters():
98
+ p._sequence_parallel = True
99
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
100
+ if mark_shared_params:
101
+ for p in self.norm1.parameters():
102
+ p._shared_params = True
103
+ if hasattr(self, "norm2"):
104
+ for p in self.norm2.parameters():
105
+ p._shared_params = True
106
+
107
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
108
+ return self.mixer.allocate_inference_cache(
109
+ batch_size, max_seqlen, dtype=dtype, **kwargs
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: Tensor,
115
+ residual: Optional[Tensor] = None,
116
+ mixer_subset=None,
117
+ mixer_kwargs=None,
118
+ ):
119
+ r"""Pass the input through the encoder layer.
120
+
121
+ Args:
122
+ hidden_states: the sequence to the encoder layer (required).
123
+ residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
124
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
125
+ before applying the query projection. Useful for e.g., ViT where we only care
126
+ about the CLS token in the last layer.
127
+ """
128
+ if self.prenorm:
129
+ if not self.fused_dropout_add_ln:
130
+ dropped = self.drop_path1(self.dropout1(hidden_states))
131
+ residual = (dropped + residual) if residual is not None else dropped
132
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
133
+ if self.residual_in_fp32:
134
+ residual = residual.to(torch.float32)
135
+ else:
136
+ if self.drop_path1.p == 0 or not self.training:
137
+ rowscale1 = None
138
+ else:
139
+ rowscale1 = self.drop_path1(
140
+ torch.ones(
141
+ hidden_states.shape[:-1],
142
+ device=hidden_states.device,
143
+ dtype=hidden_states.dtype,
144
+ )
145
+ )
146
+ hidden_states, residual = layer_norm_fn(
147
+ hidden_states,
148
+ self.norm1.weight,
149
+ self.norm1.bias,
150
+ residual=residual,
151
+ eps=self.norm1.eps,
152
+ dropout_p=self.dropout1.p if self.training else 0.0,
153
+ rowscale=rowscale1,
154
+ prenorm=True,
155
+ residual_in_fp32=self.residual_in_fp32,
156
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
157
+ )
158
+ if mixer_kwargs is None:
159
+ mixer_kwargs = {}
160
+ if mixer_subset is not None:
161
+ mixer_kwargs["mixer_subset"] = mixer_subset
162
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
163
+ if mixer_subset is not None:
164
+ residual = residual[:, mixer_subset]
165
+ if not isinstance(self.mlp, nn.Identity):
166
+ if not self.fused_dropout_add_ln:
167
+ dropped = self.drop_path2(self.dropout2(hidden_states))
168
+ residual = (dropped + residual) if residual is not None else dropped
169
+ hidden_states = self.norm2(
170
+ residual.to(dtype=self.norm2.weight.dtype)
171
+ )
172
+ if self.residual_in_fp32:
173
+ residual = residual.to(torch.float32)
174
+ else:
175
+ if self.drop_path2.p == 0 or not self.training:
176
+ rowscale2 = None
177
+ else:
178
+ rowscale2 = self.drop_path2(
179
+ torch.ones(
180
+ hidden_states.shape[:-1],
181
+ device=hidden_states.device,
182
+ dtype=hidden_states.dtype,
183
+ )
184
+ )
185
+ hidden_states, residual = layer_norm_fn(
186
+ hidden_states,
187
+ self.norm2.weight,
188
+ self.norm2.bias,
189
+ residual=residual,
190
+ eps=self.norm2.eps,
191
+ dropout_p=self.dropout2.p if self.training else 0.0,
192
+ rowscale=rowscale2,
193
+ prenorm=True,
194
+ residual_in_fp32=self.residual_in_fp32,
195
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
196
+ )
197
+ hidden_states = self.mlp(hidden_states)
198
+ return hidden_states, residual
199
+ else:
200
+ assert residual is None
201
+ mixer_out = self.mixer(
202
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
203
+ )
204
+ if self.return_residual: # mixer out is actually a pair here
205
+ mixer_out, hidden_states = mixer_out
206
+ if not self.fused_dropout_add_ln:
207
+ hidden_states = self.norm1(
208
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
209
+ dtype=self.norm1.weight.dtype
210
+ )
211
+ )
212
+ else:
213
+ if self.drop_path1.p == 0 or not self.training:
214
+ rowscale1 = None
215
+ else:
216
+ rowscale1 = self.drop_path1(
217
+ torch.ones(
218
+ mixer_out.shape[:-1],
219
+ device=mixer_out.device,
220
+ dtype=mixer_out.dtype,
221
+ )
222
+ )
223
+ hidden_states = layer_norm_fn(
224
+ mixer_out,
225
+ self.norm1.weight,
226
+ self.norm1.bias,
227
+ residual=hidden_states,
228
+ eps=self.norm1.eps,
229
+ dropout_p=self.dropout1.p if self.training else 0.0,
230
+ rowscale=rowscale1,
231
+ prenorm=False,
232
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
233
+ )
234
+ if not isinstance(self.mlp, nn.Identity):
235
+ mlp_out = self.mlp(
236
+ hidden_states, task_id=(mixer_kwargs or {}).get("task_id")
237
+ )
238
+ if self.return_residual: # mlp out is actually a pair here
239
+ mlp_out, hidden_states = mlp_out
240
+ if not self.fused_dropout_add_ln:
241
+ hidden_states = self.norm2(
242
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
243
+ dtype=self.norm2.weight.dtype
244
+ )
245
+ )
246
+ else:
247
+ if self.drop_path2.p == 0 or not self.training:
248
+ rowscale2 = None
249
+ else:
250
+ rowscale2 = self.drop_path2(
251
+ torch.ones(
252
+ mlp_out.shape[:-1],
253
+ device=mlp_out.device,
254
+ dtype=mlp_out.dtype,
255
+ )
256
+ )
257
+ hidden_states = layer_norm_fn(
258
+ mlp_out,
259
+ self.norm2.weight,
260
+ self.norm2.bias,
261
+ residual=hidden_states,
262
+ eps=self.norm2.eps,
263
+ dropout_p=self.dropout2.p if self.training else 0.0,
264
+ rowscale=rowscale2,
265
+ prenorm=False,
266
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
267
+ )
268
+ return hidden_states
269
+
270
+
271
+ class ParallelBlock(nn.Module):
272
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
273
+ and PaLM.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ dim,
279
+ mixer_cls=None,
280
+ mlp_cls=None,
281
+ norm_cls=nn.LayerNorm,
282
+ dropout_cls=nn.Dropout,
283
+ resid_dropout1=0.0,
284
+ resid_dropout2=0.0,
285
+ tied_norm=False,
286
+ fused_dropout_add_ln=False,
287
+ residual_in_fp32=False,
288
+ sequence_parallel=False,
289
+ mark_shared_params=False,
290
+ ):
291
+ """
292
+ This Block has a slightly different structure compared to a regular
293
+ prenorm Transformer block.
294
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
295
+ [Ref: https://arxiv.org/abs/2002.04745]
296
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
297
+ the hidden_states (output1 of the MHA / MLP) and the residual.
298
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
299
+ The residual needs to be provided (except for the very first block).
300
+ """
301
+ super().__init__()
302
+ self.tied_norm = tied_norm
303
+ self.fused_dropout_add_ln = fused_dropout_add_ln
304
+ self.residual_in_fp32 = residual_in_fp32
305
+ if mixer_cls is None:
306
+ mixer_cls = partial(MHA, num_heads=dim // 64)
307
+ if mlp_cls is None:
308
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
309
+ self.mixer = mixer_cls(dim)
310
+ self.dropout1 = dropout_cls(resid_dropout1)
311
+ self.norm1 = norm_cls(dim)
312
+ self.mlp = mlp_cls(dim)
313
+ self.dropout2 = dropout_cls(resid_dropout2)
314
+ if not self.tied_norm:
315
+ self.norm2 = norm_cls(dim)
316
+
317
+ if self.fused_dropout_add_ln:
318
+ assert layer_norm_fn is not None, "Triton is not installed"
319
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
320
+ self.dropout1, nn.Dropout
321
+ )
322
+
323
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
324
+ # then the input to each worker in the tensor parallel group will be different.
325
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
326
+ # For now this is not an issue because we always use sequence_parallel=True during training
327
+ # and only use sequence_parallel=False during inference.
328
+
329
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
330
+ if sequence_parallel:
331
+ for p in self.norm1.parameters():
332
+ p._sequence_parallel = True
333
+ if hasattr(self, "norm2"):
334
+ for p in self.norm2.parameters():
335
+ p._sequence_parallel = True
336
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
337
+ if mark_shared_params:
338
+ for p in self.norm1.parameters():
339
+ p._shared_params = True
340
+ if hasattr(self, "norm2"):
341
+ for p in self.norm2.parameters():
342
+ p._shared_params = True
343
+
344
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
345
+ return self.mixer.allocate_inference_cache(
346
+ batch_size, max_seqlen, dtype=dtype, **kwargs
347
+ )
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states1: Tensor,
352
+ hidden_states2: Optional[Tensor] = None,
353
+ residual: Optional[Tensor] = None,
354
+ mixer_kwargs=None,
355
+ ):
356
+ r"""Pass the input through the encoder layer.
357
+
358
+ Args:
359
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
360
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
361
+ residual: the running residual; needs to be provided except for the very first block.
362
+ """
363
+ # TODO: Ideally we should only do the allgather / allreduce once for
364
+ # the Linear to MLP & Attention
365
+ if not self.fused_dropout_add_ln:
366
+ dropped1 = self.dropout1(hidden_states1)
367
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
368
+ if hidden_states2 is not None:
369
+ dropped2 = self.dropout2(hidden_states2)
370
+ residual = (
371
+ (residual + dropped1 + dropped2)
372
+ if residual is not None
373
+ else dropped1 + dropped2
374
+ )
375
+ else:
376
+ residual = (residual + dropped1) if residual is not None else dropped1
377
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
378
+ hidden_states2 = (
379
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
380
+ if not self.tied_norm
381
+ else hidden_states1
382
+ )
383
+ if self.residual_in_fp32:
384
+ residual = residual.to(torch.float32)
385
+ else:
386
+ weight2, bias2 = (
387
+ (self.norm2.weight, self.norm2.bias)
388
+ if not self.tied_norm
389
+ else (None, None)
390
+ )
391
+ hidden_states1, *rest, residual = layer_norm_fn(
392
+ hidden_states1,
393
+ self.norm1.weight,
394
+ self.norm1.bias,
395
+ residual=residual,
396
+ x1=hidden_states2,
397
+ weight1=weight2,
398
+ bias1=bias2,
399
+ eps=self.norm1.eps,
400
+ dropout_p=self.dropout1.p if self.training else 0.0,
401
+ prenorm=True,
402
+ residual_in_fp32=self.residual_in_fp32,
403
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
404
+ )
405
+ if self.tied_norm:
406
+ hidden_states2 = hidden_states1
407
+ else:
408
+ (hidden_states2,) = rest
409
+ if mixer_kwargs is None:
410
+ mixer_kwargs = {}
411
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
412
+ hidden_states2 = self.mlp(hidden_states2)
413
+ return hidden_states1, hidden_states2, residual
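The Block docstring above explains the prenorm layout in which every block consumes and returns a (hidden_states, residual) pair. Below is a minimal sketch of chaining two such blocks, assuming the non-fused path (fused_dropout_add_ln=False), the pure-PyTorch SelfAttention fallback (use_flash_attn=False), and the Mlp class from mlp.py plus the remainder of mha.py, which are not reproduced in this excerpt.

# Hedged sketch, not part of the commit: chaining two prenorm Blocks on CPU.
from functools import partial
import torch
# from block import Block; from mha import MHA  (import paths depend on how the package is laid out)

dim, num_heads = 64, 4
mixer_cls = partial(MHA, num_heads=num_heads, use_flash_attn=False)
blocks = [Block(dim, mixer_cls=mixer_cls, prenorm=True) for _ in range(2)]

hidden_states = torch.randn(2, 16, dim)  # (batch, seqlen, dim)
residual = None
for blk in blocks:
    # each prenorm Block returns the MLP output and the running residual
    hidden_states, residual = blk(hidden_states, residual)
print(hidden_states.shape, residual.shape)  # both torch.Size([2, 16, 64])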
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,128 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from transformers import PretrainedConfig
5
+
6
+
7
+ class XLMRobertaFlashConfig(PretrainedConfig):
8
+
9
+ model_type = "xlm-roberta"
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size: int = 250002,
14
+ hidden_size: int = 1024,
15
+ num_hidden_layers: int = 24,
16
+ num_attention_heads: int = 16,
17
+ intermediate_size: int = 4096,
18
+ hidden_act: str = "gelu",
19
+ hidden_dropout_prob: float = 0.1,
20
+ attention_probs_dropout_prob: float = 0.1,
21
+ max_position_embeddings: int = 8194,
22
+ type_vocab_size: int = 1,
23
+ initializer_range: float = 0.02,
24
+ layer_norm_eps: float = 1e-05,
25
+ pad_token_id: int = 1,
26
+ bos_token_id: int = 0,
27
+ eos_token_id: int = 2,
28
+ position_embedding_type: str = "rotary",
29
+ rotary_emb_base: float = 10000.0,
30
+ use_cache: bool = True,
31
+ use_reentrant: bool = False,
32
+ classifier_dropout: Optional[float] = None,
33
+ lora_adaptations: Optional[List[str]] = None,
34
+ task_instructions: Optional[Dict[str, str]] = None,
35
+ lora_rank: int = 4,
36
+ lora_dropout_p: float = 0.0,
37
+ lora_alpha: int = 1,
38
+ lora_main_params_trainable: bool = False,
39
+ load_trained_adapters: bool = False,
40
+ use_flash_attn: bool = True,
41
+ torch_dtype: Optional[Union[str, torch.dtype]] = None,
42
+ emb_pooler: Optional[str] = None,
43
+ matryoshka_dimensions: Optional[List[int]] = None,
44
+ truncate_dim: Optional[int] = None,
45
+ **kwargs: Dict[str, Any],
46
+ ):
47
+ """
48
+ Initialize the XLMRobertaFlashConfig configuration.
49
+
50
+ Args:
51
+ vocab_size (int): Size of the vocabulary.
52
+ hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
53
+ num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
54
+ num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
56
+ hidden_act (str): The activation function to use.
57
+ hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
+ attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
59
+ max_position_embeddings (int): The maximum length of the position embeddings.
60
+ type_vocab_size (int): The vocabulary size of the token type ids.
61
+ initializer_range (float): The standard deviation for initializing all weight matrices.
62
+ layer_norm_eps (float): The epsilon used by the layer normalization layers.
63
+ pad_token_id (int): The ID of the padding token.
64
+ bos_token_id (int): The ID of the beginning-of-sequence token.
65
+ eos_token_id (int): The ID of the end-of-sequence token.
66
+ position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
67
+ rotary_emb_base (float): Base for rotary embeddings.
68
+ use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
69
+ use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
70
+ classifier_dropout (Optional[float]): The dropout ratio for the classification head.
71
+ lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
72
+ task_instructions (Optional[Dict[str, str]]): Task-specific instruction prompts configuration.
73
+ lora_rank (int): Rank for LoRA adaptations.
74
+ lora_dropout_p (float): Dropout probability for LoRA adaptations.
75
+ lora_alpha (int): Alpha parameter for LoRA.
76
+ lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
77
+ load_trained_adapters (bool): Whether to load trained adapters.
78
+ use_flash_attn (bool): Whether to use FlashAttention.
79
+ torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
80
+ emb_pooler (Optional[str]): Pooling layer configuration.
81
+ matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
82
+ truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
83
+ **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
84
+ """
85
+
86
+ super().__init__(
87
+ pad_token_id=pad_token_id,
88
+ bos_token_id=bos_token_id,
89
+ eos_token_id=eos_token_id,
90
+ **kwargs,
91
+ )
92
+
93
+ self.vocab_size = vocab_size
94
+ self.hidden_size = hidden_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.hidden_act = hidden_act
98
+ self.intermediate_size = intermediate_size
99
+ self.hidden_dropout_prob = hidden_dropout_prob
100
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.type_vocab_size = type_vocab_size
103
+ self.initializer_range = initializer_range
104
+ self.layer_norm_eps = layer_norm_eps
105
+ self.position_embedding_type = position_embedding_type
106
+ self.rotary_emb_base = rotary_emb_base
107
+ self.use_cache = use_cache
108
+ self.use_reentrant = use_reentrant
109
+ self.classifier_dropout = classifier_dropout
110
+ self.load_trained_adapters = load_trained_adapters
111
+ self.lora_adaptations = lora_adaptations
112
+ self.task_instructions = task_instructions
113
+ self.lora_rank = lora_rank
114
+ self.lora_dropout_p = lora_dropout_p
115
+ self.lora_alpha = lora_alpha
116
+ self.lora_main_params_trainable = lora_main_params_trainable
117
+ self.use_flash_attn = use_flash_attn
118
+ self.emb_pooler = emb_pooler
119
+ self.matryoshka_dimensions = matryoshka_dimensions
120
+ self.truncate_dim = truncate_dim
121
+ if (
122
+ isinstance(torch_dtype, str)  # torch_dtype may already be a torch.dtype
123
+ and hasattr(torch, torch_dtype)
124
+ and type(getattr(torch, torch_dtype)) is torch.dtype
125
+ ):
126
+ self.torch_dtype = getattr(torch, torch_dtype)
127
+ else:
128
+ self.torch_dtype = torch_dtype
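A small, hedged example of constructing the configuration defined above; the sizes are illustrative and the printed values follow directly from the __init__ shown in the diff.

# Hedged sketch, not part of the commit: building a reduced configuration.
import torch
# from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    vocab_size=250002,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    position_embedding_type="rotary",
    use_flash_attn=False,    # allows running without the flash_attn package
    torch_dtype="bfloat16",  # a string here is resolved to a torch.dtype by __init__
)
print(config.torch_dtype)  # torch.bfloat16
print(config.model_type)   # xlm-roberta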
convert_roberta_weights_to_flash.py ADDED
@@ -0,0 +1,170 @@
1
+ import re
2
+ from collections import OrderedDict
3
+ from transformers import PretrainedConfig
4
+ from transformers import XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification
5
+
6
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig as BertConfig
7
+ from .modeling_xlm_roberta import XLMRobertaForMaskedLM as FlashXLMRobertaForMaskedLM
8
+ from .modeling_xlm_roberta import XLMRobertaForSequenceClassification as FlashXLMRobertaForSequenceClassification
9
+ import torch
+ import torch.nn.functional as F  # F.pad is used below when padding the vocabulary
10
+
11
+ import click
12
+
13
+ ## inspired by https://github.com/Dao-AILab/flash-attention/blob/85881f547fd1053a7b4a2c3faad6690cca969279/flash_attn/models/bert.py
14
+
15
+
16
+ def remap_state_dict(state_dict, config: PretrainedConfig):
17
+ """
18
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
19
+ """
20
+
21
+ # LayerNorm
22
+ def key_mapping_ln_gamma_beta(key):
23
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
24
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
25
+ return key
26
+
27
+ state_dict = OrderedDict(
28
+ (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
29
+ )
30
+
31
+ # Layers
32
+ def key_mapping_layers(key):
33
+ return re.sub(r"^roberta.encoder.layer.", "roberta.encoder.layers.", key)
34
+
35
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
36
+
37
+ # LayerNorm
38
+ def key_mapping_ln(key):
39
+ key = re.sub(r"^roberta.embeddings.LayerNorm.", "roberta.emb_ln.", key)
40
+ key = re.sub(
41
+ r"^roberta.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
42
+ r"roberta.encoder.layers.\1.norm1.\2",
43
+ key,
44
+ )
45
+ key = re.sub(
46
+ r"^roberta.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
47
+ r"roberta.encoder.layers.\1.norm2.\2",
48
+ key,
49
+ )
50
+ key = re.sub(
51
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
52
+ r"cls.predictions.transform.layer_norm.\1",
53
+ key,
54
+ )
55
+ return key
56
+
57
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
58
+
59
+ # MLP
60
+ def key_mapping_mlp(key):
61
+ key = re.sub(
62
+ r"^roberta.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
63
+ r"roberta.encoder.layers.\1.mlp.fc1.\2",
64
+ key,
65
+ )
66
+ key = re.sub(
67
+ r"^roberta.encoder.layers.(\d+).output.dense.(weight|bias)",
68
+ r"roberta.encoder.layers.\1.mlp.fc2.\2",
69
+ key,
70
+ )
71
+ return key
72
+
73
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
74
+
75
+ # Attention
76
+ last_layer_subset = getattr(config, "last_layer_subset", False)
77
+ for d in range(config.num_hidden_layers):
78
+ Wq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.weight")
79
+ Wk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.weight")
80
+ Wv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.weight")
81
+ bq = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.query.bias")
82
+ bk = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.key.bias")
83
+ bv = state_dict.pop(f"roberta.encoder.layers.{d}.attention.self.value.bias")
84
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
85
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
86
+ [Wq, Wk, Wv], dim=0
87
+ )
88
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
89
+ [bq, bk, bv], dim=0
90
+ )
91
+ else:
92
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.weight"] = Wq
93
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
94
+ [Wk, Wv], dim=0
95
+ )
96
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wq.bias"] = bq
97
+ state_dict[f"roberta.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
98
+ [bk, bv], dim=0
99
+ )
100
+
101
+ def key_mapping_attn(key):
102
+ return re.sub(
103
+ r"^roberta.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
104
+ r"roberta.encoder.layers.\1.mixer.out_proj.\2",
105
+ key,
106
+ )
107
+
108
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
109
+
110
+ def key_mapping_decoder_bias(key):
111
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
112
+
113
+ state_dict = OrderedDict(
114
+ (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
115
+ )
116
+
117
+ # Word embedding
118
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
119
+ if pad_vocab_size_multiple > 1:
120
+ word_embeddings = state_dict["roberta.embeddings.word_embeddings.weight"]
121
+ state_dict["roberta.embeddings.word_embeddings.weight"] = F.pad(
122
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
123
+ )
124
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
125
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
126
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
127
+ )
128
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
129
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
130
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
131
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
132
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
133
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
134
+ )
135
+
136
+ return state_dict
137
+
138
+
139
+ @click.command()
140
+ @click.option('--model_name', default='FacebookAI/xlm-roberta-base', help='model name')
141
+ @click.option('--revision', default='main', help='revision')
142
+ @click.option('--task', default='masked_lm', help='task')
143
+ @click.option('--output', default='converted_roberta_weights.bin', help='model name')
144
+ def main(model_name, revision, task, output):
145
+
146
+ if task == 'masked_lm':
147
+ roberta_model = XLMRobertaForMaskedLM.from_pretrained(model_name, revision=revision)
148
+ elif task == 'sequence_classification':
149
+ roberta_model = XLMRobertaForSequenceClassification.from_pretrained(model_name, revision=revision,num_labels=1)
150
+ config = BertConfig.from_dict(roberta_model.config.to_dict())
151
+ state_dict = roberta_model.state_dict()
152
+ new_state_dict = remap_state_dict(state_dict, config)
153
+
154
+ if task == 'masked_lm':
155
+ flash_model = FlashXLMRobertaForMaskedLM(config)
156
+ elif task == 'sequence_classification':
157
+ flash_model = FlashXLMRobertaForSequenceClassification(config)
158
+
159
+ for k, v in flash_model.state_dict().items():
160
+ if k not in new_state_dict:
161
+ print(f'Key {k} missing from the converted weights; keeping its initialized value')
162
+ new_state_dict[k] = v
163
+
164
+ flash_model.load_state_dict(new_state_dict)
165
+
166
+ torch.save(new_state_dict, output)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
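The conversion script above is a click CLI; a hedged usage sketch follows. Because the file uses relative imports, it has to be run in package context, and the package name below is a placeholder.

# Hedged usage sketch, not part of the commit.
# Shell (placeholder package name):
#   python -m <package>.convert_roberta_weights_to_flash \
#       --model_name FacebookAI/xlm-roberta-base --task masked_lm \
#       --output converted_roberta_weights.bin
# Programmatically, the click command can be driven with an argument list:
# from <package>.convert_roberta_weights_to_flash import main
main(
    [
        "--model_name", "FacebookAI/xlm-roberta-base",
        "--revision", "main",
        "--task", "masked_lm",
        "--output", "converted_roberta_weights.bin",
    ],
    standalone_mode=False,  # return to the caller instead of calling sys.exit()
)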
embedding.py ADDED
@@ -0,0 +1,76 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
2
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
3
+
4
+ # Copyright (c) 2022, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import \
9
+ create_position_ids_from_input_ids
10
+
11
+
12
+ class XLMRobertaEmbeddings(nn.Module):
13
+ def __init__(
14
+ self,
15
+ embed_dim,
16
+ vocab_size,
17
+ max_position_embeddings,
18
+ type_vocab_size,
19
+ padding_idx=None,
20
+ device=None,
21
+ dtype=None,
22
+ ):
23
+ """
24
+ If max_position_embeddings <= 0, there are no position embeddings.
25
+ If type_vocab_size <= 0, there are no token type embeddings.
26
+ """
27
+ factory_kwargs = {"device": device, "dtype": dtype}
28
+ super().__init__()
29
+ self.word_embeddings = nn.Embedding(
30
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
31
+ )
32
+ self.max_position_embeddings = max_position_embeddings
33
+ self.type_vocab_size = type_vocab_size
34
+ if self.max_position_embeddings > 0:
35
+ self.position_embeddings = nn.Embedding(
36
+ max_position_embeddings, embed_dim, **factory_kwargs
37
+ )
38
+ if self.type_vocab_size > 0:
39
+ self.token_type_embeddings = nn.Embedding(
40
+ type_vocab_size, embed_dim, **factory_kwargs
41
+ )
42
+
43
+ def forward(
44
+ self, input_ids, position_ids=None, token_type_ids=None, task_id=None
45
+ ):
46
+ """
47
+ input_ids: (batch, seqlen)
48
+ position_ids: (batch, seqlen)
49
+ token_type_ids: (batch, seqlen)
50
+ """
51
+ batch_size, seqlen = input_ids.shape
52
+ if task_id is not None:
53
+ embeddings = self.word_embeddings(input_ids, task_id=task_id)
54
+ else:
55
+ embeddings = self.word_embeddings(input_ids)
56
+ if self.max_position_embeddings > 0:
57
+ if position_ids is None:
58
+ position_ids = create_position_ids_from_input_ids(
59
+ input_ids, padding_idx=self.word_embeddings.padding_idx
60
+ ).to(input_ids.device)
61
+ position_embeddings = self.position_embeddings(position_ids)
62
+ embeddings = embeddings + position_embeddings
63
+ if self.type_vocab_size > 0:
64
+ if token_type_ids is None:
65
+ token_type_ids = torch.zeros(
66
+ seqlen, dtype=torch.long, device=input_ids.device
67
+ )
68
+
69
+ if task_id is not None:
70
+ token_type_embeddings = self.token_type_embeddings(
71
+ token_type_ids, task_id=task_id
72
+ )
73
+ else:
74
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
75
+ embeddings = embeddings + token_type_embeddings
76
+ return embeddings
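A hedged sketch of the embedding module above, configured the way the rotary setup implies: max_position_embeddings <= 0 disables the learned position table (per the docstring), so only word and token-type embeddings are summed. The token ids are arbitrary.

# Hedged sketch, not part of the commit.
import torch
# from embedding import XLMRobertaEmbeddings

emb = XLMRobertaEmbeddings(
    embed_dim=768,
    vocab_size=250002,
    max_position_embeddings=-1,  # <= 0: no learned positions (rotary is applied inside MHA)
    type_vocab_size=1,
    padding_idx=1,
)
input_ids = torch.tensor([[0, 581, 9036, 2, 1, 1]])  # (batch=1, seqlen=6); 1 is the pad id
out = emb(input_ids)                                 # token-type ids default to zeros
print(out.shape)                                     # torch.Size([1, 6, 768])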
mha.py ADDED
@@ -0,0 +1,806 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
+ # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
+ # Copyright (c) 2023, Tri Dao.
6
+
7
+ import math
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange, repeat
13
+
14
+ try:
15
+ from flash_attn import (flash_attn_kvpacked_func,
16
+ flash_attn_qkvpacked_func,
17
+ flash_attn_varlen_kvpacked_func,
18
+ flash_attn_varlen_qkvpacked_func,
19
+ flash_attn_with_kvcache)
20
+ except ImportError:
21
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
+ flash_attn_with_kvcache = None
24
+
25
+ try:
26
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
27
+ RowParallelLinear)
28
+ except ImportError:
29
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
+
31
+ from .rotary import RotaryEmbedding
32
+
33
+
34
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
35
+ def get_alibi_slopes(nheads):
36
+ def get_slopes_power_of_2(nheads):
37
+ start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
38
+ ratio = start
39
+ return [start * ratio**i for i in range(nheads)]
40
+
41
+ if math.log2(nheads).is_integer():
42
+ return get_slopes_power_of_2(nheads)
43
+ else:
44
+ closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
+ return (
46
+ get_slopes_power_of_2(closest_power_of_2)
47
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][
48
+ : nheads - closest_power_of_2
49
+ ]
50
+ )
51
+
52
+
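As a quick, hedged illustration of get_alibi_slopes above (not part of the committed file): for a power-of-two head count the slopes form the geometric sequence 2**(-(8 / nheads) * (i + 1)).

# Hedged illustration, not part of the commit.
slopes = get_alibi_slopes(8)
print(slopes)  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
assert slopes == [2 ** (-(i + 1)) for i in range(8)]  # 8 / nheads == 1 here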
53
+ class FlashSelfAttention(nn.Module):
54
+ """Implement the scaled dot product attention with softmax.
55
+ Arguments
56
+ ---------
57
+ softmax_scale: The temperature to use for the softmax attention.
58
+ (default: 1/sqrt(d_keys) where d_keys is computed at
59
+ runtime)
60
+ attention_dropout: The dropout rate to apply to the attention
61
+ (default: 0.0)
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ causal=False,
67
+ softmax_scale=None,
68
+ attention_dropout=0.0,
69
+ window_size=(-1, -1),
70
+ alibi_slopes=None,
71
+ deterministic=False,
72
+ ):
73
+ super().__init__()
74
+ assert (
75
+ flash_attn_varlen_qkvpacked_func is not None
76
+ ), "FlashAttention is not installed"
77
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
78
+ self.causal = causal
79
+ self.softmax_scale = softmax_scale
80
+ self.drop = nn.Dropout(attention_dropout)
81
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
82
+ self.window_size = window_size
83
+ self.deterministic = deterministic
84
+
85
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
86
+ """Implements the multihead softmax attention.
87
+ Arguments
88
+ ---------
89
+ qkv: The tensor containing the query, key, and value.
90
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
91
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
92
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
93
+ causal: if passed, will override self.causal
94
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
95
+ of the sequences in the batch, used to index into qkv.
96
+ max_seqlen: int. Maximum sequence length in the batch.
97
+ Returns:
98
+ --------
99
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
100
+ else (B, S, H, D).
101
+ """
102
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
103
+ assert qkv.is_cuda
104
+ causal = self.causal if causal is None else causal
105
+ unpadded = cu_seqlens is not None
106
+ if self.alibi_slopes is not None:
107
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
108
+ if unpadded:
109
+ assert cu_seqlens.dtype == torch.int32
110
+ assert max_seqlen is not None
111
+ assert isinstance(max_seqlen, int)
112
+ return flash_attn_varlen_qkvpacked_func(
113
+ qkv,
114
+ cu_seqlens,
115
+ max_seqlen,
116
+ self.drop.p if self.training else 0.0,
117
+ softmax_scale=self.softmax_scale,
118
+ causal=causal,
119
+ alibi_slopes=self.alibi_slopes,
120
+ window_size=self.window_size,
121
+ deterministic=self.deterministic,
122
+ )
123
+ else:
124
+ return flash_attn_qkvpacked_func(
125
+ qkv,
126
+ self.drop.p if self.training else 0.0,
127
+ softmax_scale=self.softmax_scale,
128
+ causal=causal,
129
+ alibi_slopes=self.alibi_slopes,
130
+ window_size=self.window_size,
131
+ deterministic=self.deterministic,
132
+ )
133
+
134
+
135
+ class FlashCrossAttention(nn.Module):
136
+ """Implement the scaled dot product attention with softmax.
137
+ Arguments
138
+ ---------
139
+ softmax_scale: The temperature to use for the softmax attention.
140
+ (default: 1/sqrt(d_keys) where d_keys is computed at
141
+ runtime)
142
+ attention_dropout: The dropout rate to apply to the attention
143
+ (default: 0.0)
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ causal=False,
149
+ softmax_scale=None,
150
+ attention_dropout=0.0,
151
+ alibi_slopes=None,
152
+ window_size=(-1, -1),
153
+ deterministic=False,
154
+ ):
155
+ super().__init__()
156
+ assert (
157
+ flash_attn_varlen_kvpacked_func is not None
158
+ ), "FlashAttention is not installed"
159
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
160
+ self.causal = causal
161
+ self.softmax_scale = softmax_scale
162
+ self.drop = nn.Dropout(attention_dropout)
163
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
164
+ self.window_size = window_size
165
+ self.deterministic = deterministic
166
+
167
+ def forward(
168
+ self,
169
+ q,
170
+ kv,
171
+ causal=None,
172
+ cu_seqlens=None,
173
+ max_seqlen=None,
174
+ cu_seqlens_k=None,
175
+ max_seqlen_k=None,
176
+ ):
177
+ """Implements the multihead softmax attention.
178
+ Arguments
179
+ ---------
180
+ q: The tensor containing the query. (B, Sq, H, D)
181
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
182
+ causal: if passed, will override self.causal
183
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
184
+ of the sequences in the batch, used to index into q.
185
+ max_seqlen: int. Maximum sequence length in the batch of q.
186
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
187
+ of the sequences in the batch, used to index into kv.
188
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
189
+ """
190
+ assert q.dtype in [torch.float16, torch.bfloat16]
191
+ assert q.is_cuda and kv.is_cuda
192
+ causal = self.causal if causal is None else causal
193
+ unpadded = cu_seqlens is not None
194
+ if self.alibi_slopes is not None:
195
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
196
+ if unpadded:
197
+ assert cu_seqlens.dtype == torch.int32
198
+ assert max_seqlen is not None
199
+ assert isinstance(max_seqlen, int)
200
+ assert cu_seqlens_k is not None
201
+ assert cu_seqlens_k.dtype == torch.int32
202
+ assert max_seqlen_k is not None
203
+ assert isinstance(max_seqlen_k, int)
204
+ return flash_attn_varlen_kvpacked_func(
205
+ q,
206
+ kv,
207
+ cu_seqlens,
208
+ cu_seqlens_k,
209
+ max_seqlen,
210
+ max_seqlen_k,
211
+ self.drop.p if self.training else 0.0,
212
+ softmax_scale=self.softmax_scale,
213
+ causal=causal,
214
+ alibi_slopes=self.alibi_slopes,
215
+ window_size=self.window_size,
216
+ deterministic=self.deterministic,
217
+ )
218
+ else:
219
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
220
+ seqlen_k = kv.shape[1]
221
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
222
+ return flash_attn_kvpacked_func(
223
+ q,
224
+ kv,
225
+ self.drop.p if self.training else 0.0,
226
+ causal=causal,
227
+ softmax_scale=self.softmax_scale,
228
+ alibi_slopes=self.alibi_slopes,
229
+ window_size=self.window_size,
230
+ deterministic=self.deterministic,
231
+ )
232
+
233
+
234
+ class SelfAttention(nn.Module):
235
+ """Implement the scaled dot product attention with softmax.
236
+ Arguments
237
+ ---------
238
+ softmax_scale: The temperature to use for the softmax attention.
239
+ (default: 1/sqrt(d_keys) where d_keys is computed at
240
+ runtime)
241
+ attention_dropout: The dropout rate to apply to the attention
242
+ (default: 0.0)
243
+ """
244
+
245
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
246
+ super().__init__()
247
+ self.causal = causal
248
+ self.softmax_scale = softmax_scale
249
+ self.drop = nn.Dropout(attention_dropout)
250
+
251
+ def forward(self, qkv, causal=None, key_padding_mask=None):
252
+ """Implements the multihead softmax attention.
253
+ Arguments
254
+ ---------
255
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
256
+ causal: if passed, will override self.causal
257
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
258
+ False means to mask out. (B, S)
259
+ """
260
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
261
+ causal = self.causal if causal is None else causal
262
+ q, k, v = qkv.unbind(dim=2)
263
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
264
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
265
+ if key_padding_mask is not None:
266
+ padding_mask = torch.full(
267
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
268
+ )
269
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
270
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
271
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
272
+ if causal:
273
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
274
+ # So we have to construct the mask in float
275
+ causal_mask = torch.triu(
276
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
277
+ )
278
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
279
+ scores = scores + causal_mask.to(dtype=scores.dtype)
280
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
281
+ attention_drop = self.drop(attention)
282
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
283
+ return output
284
+
285
+
286
+ class CrossAttention(nn.Module):
287
+ """Implement the scaled dot product attention with softmax.
288
+ Arguments
289
+ ---------
290
+ softmax_scale: The temperature to use for the softmax attention.
291
+ (default: 1/sqrt(d_keys) where d_keys is computed at
292
+ runtime)
293
+ attention_dropout: The dropout rate to apply to the attention
294
+ (default: 0.0)
295
+ """
296
+
297
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
298
+ super().__init__()
299
+ self.causal = causal
300
+ self.softmax_scale = softmax_scale
301
+ self.drop = nn.Dropout(attention_dropout)
302
+
303
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
304
+ """Implements the multihead softmax attention.
305
+ Arguments
306
+ ---------
307
+ q: The tensor containing the query. (B, Sq, H, D)
308
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
309
+ causal: if passed, will override self.causal
310
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
311
+ False means to mask out. (B, Sk)
312
+ """
313
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
314
+ causal = self.causal if causal is None else causal
315
+ seqlen_k = kv.shape[1]
316
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
317
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
318
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
319
+ k, v = kv.unbind(dim=2)
320
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
321
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
322
+ if key_padding_mask is not None:
323
+ padding_mask = torch.full(
324
+ (batch_size, seqlen_k),
325
+ -10000.0,
326
+ dtype=scores.dtype,
327
+ device=scores.device,
328
+ )
329
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
330
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
331
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
332
+ if causal:
333
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
334
+ row_idx = rearrange(
335
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
336
+ )
337
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
338
+ sk = (
339
+ seqlen_k
340
+ if key_padding_mask is None
341
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
342
+ )
343
+ causal_mask = col_idx > row_idx + sk - seqlen_q
344
+ scores = scores.masked_fill(causal_mask, -10000.0)
345
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
346
+ attention_drop = self.drop(attention)
347
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
348
+ return output
349
+
350
+
351
+ class LinearResidual(nn.Linear):
352
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
353
+
354
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
355
+ return super().forward(input), input
356
+
357
+
358
+ def _update_kv_cache(kv, inference_params, layer_idx):
359
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
360
+ # Pre-allocate memory for key-values for inference.
361
+ num_heads, head_dim = kv.shape[-2:]
362
+ if layer_idx not in inference_params.key_value_memory_dict:
363
+ kv_cache = torch.empty(
364
+ inference_params.max_batch_size,
365
+ inference_params.max_seqlen,
366
+ 2,
367
+ num_heads,
368
+ head_dim,
369
+ dtype=kv.dtype,
370
+ device=kv.device,
371
+ )
372
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
373
+ else:
374
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
375
+ # Adjust key and value for inference
376
+ batch_start = inference_params.batch_size_offset
377
+ batch_end = batch_start + kv.shape[0]
378
+ sequence_start = inference_params.seqlen_offset
379
+ sequence_end = sequence_start + kv.shape[1]
380
+ assert batch_end <= kv_cache.shape[0]
381
+ assert sequence_end <= kv_cache.shape[1]
382
+ assert kv_cache is not None
383
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
384
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
385
+
386
+
387
+ class MHA(nn.Module):
388
+ """Multi-head self-attention and cross-attention"""
389
+
390
+ def __init__(
391
+ self,
392
+ embed_dim,
393
+ num_heads,
394
+ num_heads_kv=None,
395
+ cross_attn=False,
396
+ qkv_proj_bias=True,
397
+ out_proj_bias=True,
398
+ dropout=0.0,
399
+ softmax_scale=None,
400
+ causal=False,
401
+ layer_idx=None,
402
+ dwconv=False,
403
+ rotary_emb_dim=0,
404
+ rotary_emb_base=10000.0,
405
+ rotary_emb_scale_base=None,
406
+ rotary_emb_interleaved=False,
407
+ use_alibi=False,
408
+ window_size=(-1, -1),
409
+ fused_bias_fc=False,
410
+ use_flash_attn=False,
411
+ return_residual=False,
412
+ checkpointing=False,
413
+ device=None,
414
+ dtype=None,
415
+ ) -> None:
416
+ """
417
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
418
+ return_residual: whether to return the input x along with the output. This is for
419
+ performance reasons: for a post-norm architecture, returning the input allows us
420
+ to fuse the backward of nn.Linear with the residual connection.
421
+ """
422
+ factory_kwargs = {"device": device, "dtype": dtype}
423
+ super().__init__()
424
+ self.embed_dim = embed_dim
425
+ self.cross_attn = cross_attn
426
+ self.causal = causal
427
+ self.layer_idx = layer_idx
428
+ self.dwconv = dwconv
429
+ self.rotary_emb_dim = rotary_emb_dim
430
+ self.use_flash_attn = use_flash_attn
431
+ self.return_residual = return_residual
432
+ self.checkpointing = checkpointing
433
+ if use_alibi:
434
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
435
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
436
+ else:
437
+ alibi_slopes = None
438
+ if window_size != (-1, -1):
439
+ assert (
440
+ use_flash_attn
441
+ ), "Local (sliding window) attention code path requires flash_attn"
442
+
443
+ self.num_heads = num_heads
444
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
+ assert (
446
+ self.num_heads % self.num_heads_kv == 0
447
+ ), "num_heads must be divisible by num_heads_kv"
448
+ assert (
449
+ self.embed_dim % num_heads == 0
450
+ ), "embed_dim must be divisible by num_heads"
451
+ self.head_dim = self.embed_dim // num_heads
452
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
453
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
454
+
455
+ if self.rotary_emb_dim > 0:
456
+ assert (
457
+ not cross_attn
458
+ ), "MHA with rotary embedding does not support cross-attention yet"
459
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
460
+ self.rotary_emb = RotaryEmbedding(
461
+ self.rotary_emb_dim,
462
+ base=rotary_emb_base,
463
+ scale_base=rotary_emb_scale_base,
464
+ interleaved=rotary_emb_interleaved,
465
+ device=device,
466
+ use_flash_attn=use_flash_attn,
467
+ )
468
+
469
+ if fused_bias_fc and FusedDense is None:
470
+ raise ImportError("fused_dense is not installed")
471
+
472
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
473
+ linear_resid_cls = (
474
+ LinearResidual
475
+ if not fused_bias_fc
476
+ else partial(FusedDense, return_residual=True)
477
+ )
478
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
479
+ inner_attn_cls = (
480
+ partial(
481
+ FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
482
+ )
483
+ if use_flash_attn
484
+ else SelfAttention
485
+ )
486
+ inner_cross_attn_cls = (
487
+ partial(
488
+ FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
489
+ )
490
+ if use_flash_attn
491
+ else CrossAttention
492
+ )
493
+ if not self.cross_attn:
494
+ self.Wqkv = wqkv_cls(
495
+ embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
496
+ )
497
+ else:
498
+ self.Wq = linear_cls(
499
+ embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
500
+ )
501
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
502
+ if self.dwconv:
503
+ if self.num_heads_kv == self.num_heads:
504
+ self.dwconv_qkv = nn.Conv1d(
505
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
506
+ )
507
+ else:
508
+ self.dwconv_q = nn.Conv1d(
509
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
510
+ )
511
+ self.dwconv_kv = nn.Conv1d(
512
+ kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
513
+ )
514
+ self.inner_attn = inner_attn_cls(
515
+ causal=causal,
516
+ softmax_scale=softmax_scale,
517
+ attention_dropout=dropout,
518
+ )
519
+ self.inner_cross_attn = inner_cross_attn_cls(
520
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
521
+ )
522
+ self.out_proj = linear_cls(
523
+ embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
524
+ )
525
+
526
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
527
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
528
+ device = self.out_proj.weight.device
529
+ return torch.empty(
530
+ batch_size,
531
+ max_seqlen,
532
+ 2,
533
+ self.num_heads_kv,
534
+ self.head_dim,
535
+ dtype=dtype,
536
+ device=device,
537
+ )
538
+
539
+ def _update_kv_cache(self, kv, inference_params):
540
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
541
+ assert not self.dwconv, "Generation does not support dwconv yet"
542
+ assert (
543
+ self.layer_idx is not None
544
+ ), "Generation requires layer_idx in the constructor"
545
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
546
+
547
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
548
+ """
549
+ Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
550
+ q: (batch_size, seqlen_q, nheads, head_dim)
551
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
552
+ """
553
+ assert inference_params is not None and inference_params.seqlen_offset > 0
554
+ assert self.use_flash_attn
555
+ if self.rotary_emb_dim > 0:
556
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
557
+ self.rotary_emb._update_cos_sin_cache(
558
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
559
+ )
560
+ rotary_cos, rotary_sin = (
561
+ self.rotary_emb._cos_cached,
562
+ self.rotary_emb._sin_cached,
563
+ )
564
+ else:
565
+ rotary_cos, rotary_sin = None, None
566
+ batch = q.shape[0]
567
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
568
+ cache_seqlens = (
569
+ inference_params.lengths_per_sample[:batch]
570
+ if inference_params.lengths_per_sample is not None
571
+ else inference_params.seqlen_offset
572
+ )
573
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
574
+ context = flash_attn_with_kvcache(
575
+ q,
576
+ kv_cache[:, :, 0],
577
+ kv_cache[:, :, 1],
578
+ kv[:, :, 0],
579
+ kv[:, :, 1],
580
+ rotary_cos=rotary_cos,
581
+ rotary_sin=rotary_sin,
582
+ cache_seqlens=cache_seqlens,
583
+ softmax_scale=self.inner_cross_attn.softmax_scale,
584
+ causal=self.inner_cross_attn.causal,
585
+ rotary_interleaved=(
586
+ self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
587
+ ),
588
+ alibi_slopes=alibi_slopes,
589
+ )
590
+ return context
591
+
592
+ def _update_kvcache_attention(self, q, kv, inference_params):
593
+ """Write kv to inference_params, then do attention"""
594
+ if (
595
+ inference_params.seqlen_offset == 0
596
+ or flash_attn_with_kvcache is None
597
+ or not self.use_flash_attn
598
+ ):
599
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
600
+ kv = self._update_kv_cache(kv, inference_params)
601
+ return self.inner_cross_attn(q, kv)
602
+ else:
603
+ batch = q.shape[0]
604
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
605
+ cache_seqlens = (
606
+ inference_params.lengths_per_sample[:batch]
607
+ if inference_params.lengths_per_sample is not None
608
+ else inference_params.seqlen_offset
609
+ )
610
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
611
+ return flash_attn_with_kvcache(
612
+ q,
613
+ kv_cache[:, :, 0],
614
+ kv_cache[:, :, 1],
615
+ kv[:, :, 0],
616
+ kv[:, :, 1],
617
+ cache_seqlens=cache_seqlens,
618
+ softmax_scale=self.inner_cross_attn.softmax_scale,
619
+ causal=self.inner_cross_attn.causal,
620
+ alibi_slopes=alibi_slopes,
621
+ )
622
+
623
+ def forward(
624
+ self,
625
+ x,
626
+ x_kv=None,
627
+ key_padding_mask=None,
628
+ cu_seqlens=None,
629
+ max_seqlen=None,
630
+ mixer_subset=None,
631
+ inference_params=None,
632
+ task_id=None,
633
+ **kwargs,
634
+ ):
635
+ """
636
+ Arguments:
637
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
638
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
639
+ is the sum of the sequence lengths in the batch.
640
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
641
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
642
+ of the sequences in the batch, used to index into x. Only applicable when using
643
+ FlashAttention.
644
+ max_seqlen: int. Maximum sequence length in the batch.
645
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
646
+ (batch, seqlen). Only applicable when not using FlashAttention.
647
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
648
+ before applying the query projection. Useful e.g. for ViT, where we only care
649
+ about the CLS token in the last layer.
650
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
651
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
652
+ """
653
+ if cu_seqlens is not None:
654
+ assert max_seqlen is not None
655
+ assert key_padding_mask is None
656
+ assert self.use_flash_attn
657
+ assert not self.dwconv
658
+ if key_padding_mask is not None:
659
+ assert cu_seqlens is None
660
+ assert max_seqlen is None
661
+ assert not self.use_flash_attn
662
+ if inference_params is not None:
663
+ assert key_padding_mask is None
664
+ assert cu_seqlens is None and max_seqlen is None
665
+ assert not self.dwconv
666
+
667
+ kwargs = (
668
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
669
+ if self.use_flash_attn
670
+ else {"key_padding_mask": key_padding_mask, **kwargs}
671
+ )
672
+ seqlen_offset = (
673
+ 0
674
+ if inference_params is None
675
+ else (
676
+ inference_params.lengths_per_sample
677
+ if inference_params.lengths_per_sample is not None
678
+ else inference_params.seqlen_offset
679
+ )
680
+ )
681
+ rotary_max_seqlen = (
682
+ inference_params.max_sequence_len
683
+ if inference_params is not None
684
+ else max_seqlen
685
+ )
686
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
687
+ assert x_kv is None and mixer_subset is None
688
+
689
+ if task_id is not None:
690
+ if not self.return_residual:
691
+ qkv = self.Wqkv(x, task_id=task_id)
692
+ else:
693
+ qkv, _ = self.Wqkv(
694
+ x, task_id=task_id, residual=True
695
+ )
696
+ else:
697
+ if not self.return_residual:
698
+ qkv = self.Wqkv(x)
699
+ else:
700
+ if hasattr(self.Wqkv, "parametrizations"):
701
+ qkv, x = self.Wqkv(x, residual=True)
702
+ else:
703
+ qkv, x = self.Wqkv(x)
704
+
705
+ if self.dwconv:
706
+ qkv = rearrange(
707
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
708
+ "b d s -> b s d",
709
+ ).contiguous()
710
+ qkv = rearrange(
711
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
712
+ )
713
+ if (
714
+ inference_params is None
715
+ or inference_params.seqlen_offset == 0
716
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
717
+ or not self.use_flash_attn
718
+ ):
719
+ if self.rotary_emb_dim > 0:
720
+ qkv = self.rotary_emb(
721
+ qkv,
722
+ seqlen_offset=seqlen_offset,
723
+ cu_seqlens=cu_seqlens,
724
+ max_seqlen=rotary_max_seqlen,
725
+ )
726
+ if inference_params is None:
727
+ if not self.checkpointing:
728
+ context = self.inner_attn(qkv, **kwargs)
729
+ else:
730
+ context = torch.utils.checkpoint.checkpoint(
731
+ self.inner_attn, qkv, **kwargs
732
+ )
733
+ else:
734
+ context = self._update_kvcache_attention(
735
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
736
+ )
737
+ else:
738
+ context = self._apply_rotary_update_kvcache_attention(
739
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
740
+ )
741
+ else:
742
+ if self.cross_attn:
743
+ if not self.return_residual:
744
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
745
+ kv = self.Wkv(x_kv if x_kv is not None else x)
746
+ else:
747
+ if x_kv is not None:
748
+ kv, x_kv = self.Wkv(x_kv)
749
+ else:
750
+ kv, x = self.Wkv(x)
751
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
752
+ else:
753
+ assert self.num_heads_kv != self.num_heads
754
+ if not self.return_residual:
755
+ qkv = self.Wqkv(x)
756
+ else:
757
+ qkv, x = self.Wqkv(x)
758
+ q = qkv[..., : self.num_heads * self.head_dim]
759
+ kv = qkv[..., self.num_heads * self.head_dim :]
760
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
761
+ kv = rearrange(
762
+ kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
763
+ )
764
+ if self.dwconv:
765
+ q = rearrange(
766
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
767
+ "b d s -> b s d",
768
+ ).contiguous()
769
+ kv = rearrange(
770
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
771
+ "b d s -> b s d",
772
+ ).contiguous()
773
+ if (
774
+ inference_params is None
775
+ or inference_params.seqlen_offset == 0
776
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
777
+ or not self.use_flash_attn
778
+ ):
779
+ if self.rotary_emb_dim > 0:
780
+ q, kv = self.rotary_emb(
781
+ q,
782
+ kv,
783
+ seqlen_offset=seqlen_offset,
784
+ cu_seqlens=cu_seqlens,
785
+ max_seqlen=rotary_max_seqlen,
786
+ )
787
+ if inference_params is None:
788
+ if not self.checkpointing:
789
+ context = self.inner_cross_attn(q, kv, **kwargs)
790
+ else:
791
+ context = torch.utils.checkpoint.checkpoint(
792
+ self.inner_cross_attn, q, kv, **kwargs
793
+ )
794
+ else:
795
+ context = self._update_kvcache_attention(q, kv, inference_params)
796
+ else:
797
+ context = self._apply_rotary_update_kvcache_attention(
798
+ q, kv, inference_params
799
+ )
800
+
801
+ inp = rearrange(context, "... h d -> ... (h d)")
802
+ if task_id is not None:
803
+ out = self.out_proj(inp, task_id=task_id)
804
+ else:
805
+ out = self.out_proj(inp)
806
+ return out if not self.return_residual else (out, x)
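
Note on the two calling conventions in the attention forward above: inputs are either padded, shaped (batch, seqlen, hidden_dim) with a key_padding_mask, or packed into one (total, hidden_dim) tensor described by cu_seqlens and max_seqlen (FlashAttention path only, per the assertions at the top of forward). The following is a minimal illustrative sketch, not part of the uploaded files, of how the packed arguments relate to an ordinary padding mask; it mirrors what the unpad_input helper in xlm_padding.py computes.

import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 2, padded to length 4.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                 # tensor([3, 2])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())
# cu_seqlens == tensor([0, 3, 5]): the token rows of all sequences are concatenated
# into one (total, hidden_dim) tensor, and the kernel is told where each sequence
# starts and ends, so no compute is spent on padding tokens.
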
mlp.py ADDED
@@ -0,0 +1,219 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+ try:
12
+ from flash_attn.ops.activations import swiglu
13
+ except ImportError:
14
+ swiglu = None
15
+
16
+ try:
17
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear,
18
+ RowParallelLinear)
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = (
45
+ hidden_features if hidden_features is not None else in_features * 4
46
+ )
47
+ self.return_residual = return_residual
48
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
49
+ self.activation = activation
50
+ self.fc2 = nn.Linear(
51
+ hidden_features, out_features, bias=bias2, **factory_kwargs
52
+ )
53
+
54
+ def forward(self, x, task_id=None):
55
+ if task_id is not None:
56
+ y = self.fc1(x, task_id=task_id)
57
+ else:
58
+ y = self.fc1(x)
59
+
60
+ y = self.activation(y)
61
+
62
+ if task_id is not None:
63
+ out = self.fc2(y, task_id=task_id)
64
+ else:
65
+ out = self.fc2(y)
66
+
67
+ return out if not self.return_residual else (out, x)
68
+
69
+
70
+ class ParallelMLP(nn.Module):
71
+ def __init__(
72
+ self,
73
+ in_features,
74
+ hidden_features=None,
75
+ out_features=None,
76
+ activation=F.gelu,
77
+ process_group: ProcessGroup = None,
78
+ sequence_parallel=True,
79
+ bias1=True,
80
+ bias2=True,
81
+ device=None,
82
+ dtype=None,
83
+ ):
84
+ factory_kwargs = {"device": device, "dtype": dtype}
85
+ super().__init__()
86
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
87
+ assert RowParallelLinear is not None, "Need to install fused_dense"
88
+ out_features = out_features if out_features is not None else in_features
89
+ hidden_features = (
90
+ hidden_features if hidden_features is not None else in_features * 4
91
+ )
92
+ self.fc1 = ColumnParallelLinear(
93
+ in_features,
94
+ hidden_features,
95
+ process_group,
96
+ bias=bias1,
97
+ sequence_parallel=sequence_parallel,
98
+ **factory_kwargs,
99
+ )
100
+ self.activation = activation
101
+ self.fc2 = RowParallelLinear(
102
+ hidden_features,
103
+ out_features,
104
+ process_group,
105
+ bias=bias2,
106
+ sequence_parallel=sequence_parallel,
107
+ **factory_kwargs,
108
+ )
109
+
110
+ def forward(self, x):
111
+ y = self.fc1(x)
112
+ y = self.activation(y)
113
+ y = self.fc2(y)
114
+ return y
115
+
116
+
117
+ class GatedMlp(nn.Module):
118
+ def __init__(
119
+ self,
120
+ in_features,
121
+ hidden_features=None,
122
+ out_features=None,
123
+ activation=F.sigmoid,
124
+ bias1=True,
125
+ bias2=True,
126
+ multiple_of=128,
127
+ return_residual=False,
128
+ device=None,
129
+ dtype=None,
130
+ ):
131
+ factory_kwargs = {"device": device, "dtype": dtype}
132
+ super().__init__()
133
+ out_features = out_features if out_features is not None else in_features
134
+ hidden_features = (
135
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
136
+ )
137
+ hidden_features = (
138
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
139
+ )
140
+ self.return_residual = return_residual
141
+ self.fc1 = nn.Linear(
142
+ in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
143
+ )
144
+ self.activation = activation
145
+ self.fc2 = nn.Linear(
146
+ hidden_features, out_features, bias=bias2, **factory_kwargs
147
+ )
148
+
149
+ def forward(self, x):
150
+ y = self.fc1(x)
151
+ if self.activation == F.sigmoid: # Special case for GLU
152
+ y = F.glu(y, dim=-1)
153
+ elif (
154
+ self.activation == F.silu and swiglu is not None
155
+ ): # Special case for SwiGLU
156
+ y, gate = y.chunk(2, dim=-1)
157
+ y = swiglu(gate, y)
158
+ else:
159
+ y, gate = y.chunk(2, dim=-1)
160
+ y = y * self.activation(gate)
161
+ y = self.fc2(y)
162
+ return y if not self.return_residual else (y, x)
163
+
164
+
165
+ class ParallelGatedMlp(nn.Module):
166
+ """Parallel GatedMlp"""
167
+
168
+ def __init__(
169
+ self,
170
+ in_features,
171
+ process_group,
172
+ hidden_features=None,
173
+ out_features=None,
174
+ activation=F.sigmoid,
175
+ bias1=True,
176
+ bias2=True,
177
+ multiple_of=128,
178
+ sequence_parallel=True,
179
+ device=None,
180
+ dtype=None,
181
+ ):
182
+ factory_kwargs = {"device": device, "dtype": dtype}
183
+ super().__init__()
184
+ out_features = out_features if out_features is not None else in_features
185
+ hidden_features = (
186
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
187
+ )
188
+ hidden_features = (
189
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
190
+ )
191
+ if ColumnParallelLinear is None or RowParallelLinear is None:
192
+ raise ImportError("fused_dense is not installed")
193
+ self.fc1 = ColumnParallelLinear(
194
+ in_features,
195
+ 2 * hidden_features,
196
+ process_group,
197
+ bias=bias1,
198
+ sequence_parallel=sequence_parallel,
199
+ **factory_kwargs,
200
+ )
201
+ self.activation = activation
202
+ self.fc2 = RowParallelLinear(
203
+ hidden_features,
204
+ out_features,
205
+ process_group,
206
+ bias=bias2,
207
+ sequence_parallel=sequence_parallel,
208
+ **factory_kwargs,
209
+ )
210
+
211
+ def forward(self, x):
212
+ y = self.fc1(x)
213
+ if self.activation == F.sigmoid: # Special case for GLU
214
+ y = F.glu(y, dim=-1)
215
+ else:
216
+ y, gate = y.chunk(2, dim=-1)
217
+ y = y * self.activation(gate)
218
+ y = self.fc2(y)
219
+ return y
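
A quick illustration of the GatedMlp arithmetic above (hypothetical sizes, not part of the uploaded files): the hidden width defaults to 8/3 of the input width, is rounded up to a multiple of `multiple_of`, and the fc1 output is split into a value half and a gate half. The plain Mlp uses the same fc1/activation/fc2 pattern without gating.

import torch
import torch.nn.functional as F

in_features, multiple_of = 1024, 128
hidden = int(8 * in_features / 3)                                  # 2730
hidden = (hidden + multiple_of - 1) // multiple_of * multiple_of   # rounded up to 2816
y = torch.randn(2, 16, 2 * hidden)                                 # stand-in for the fc1 output
value, gate = y.chunk(2, dim=-1)                                   # split into value and gate halves
out = value * F.silu(gate)                                         # SwiGLU-style gating
print(out.shape)                                                   # torch.Size([2, 16, 2816])
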
modeling_lora.py ADDED
@@ -0,0 +1,401 @@
1
+ import math
2
+ import os
3
+ from functools import partial
4
+ from typing import Iterator, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.utils.parametrize as parametrize
9
+ from torch import nn
10
+ from torch.nn import Parameter
11
+ from torch.nn import functional as F
12
+ from transformers import PretrainedConfig
13
+
14
+ from .rotary import RotaryEmbedding
15
+ from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
16
+ XLMRobertaPreTrainedModel)
17
+
18
+
19
+ def initialized_weights(
20
+ shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
21
+ ) -> torch.Tensor:
22
+ weight_data = []
23
+ for _ in range(num_adaptations):
24
+ new_adaption = torch.zeros(shape)
25
+ if init == "kaiming":
26
+ nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
27
+ elif init == "normal":
28
+ nn.init.normal_(new_adaption)
29
+ else:
30
+ raise NotImplementedError
31
+ weight_data.append(new_adaption)
32
+ return torch.stack(weight_data, dim=0)
33
+
34
+
35
+ class LoRAParametrization(nn.Module):
36
+ """
37
+ This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
38
+ The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
39
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software
40
+ and associated documentation files (the "Software"), to deal in the Software without restriction,
41
+ including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
42
+ and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
43
+ subject to the following conditions:
44
+ The above copyright notice and this permission notice shall be included in all copies or substantial
45
+ portions of the Software.
46
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
47
+ LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
48
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
49
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
50
+ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ fan_in: int,
56
+ fan_out: int,
57
+ layer_type: str = "linear",
58
+ num_adaptations: int = 1,
59
+ rank: int = 4,
60
+ dropout_p: float = 0.0,
61
+ alpha: float = 1,
62
+ ):
63
+ super().__init__()
64
+ # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
65
+ # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
66
+ fan_in_fan_out = layer_type == "embedding"
67
+ self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
68
+
69
+ if layer_type == "linear":
70
+ self.lora_A = nn.Parameter(
71
+ initialized_weights((rank, fan_in), num_adaptations, init="kaiming")
72
+ )
73
+ self.lora_B = nn.Parameter(torch.zeros((num_adaptations, fan_out, rank)))
74
+ elif layer_type == "embedding":
75
+ self.lora_A = nn.Parameter(torch.zeros((num_adaptations, fan_in, rank)))
76
+ self.lora_B = nn.Parameter(
77
+ initialized_weights(
78
+ (rank, fan_out), num_adaptations=num_adaptations, init="normal"
79
+ )
80
+ )
81
+ else:
82
+ raise NotImplementedError
83
+
84
+ self.lora_alpha, self.rank = alpha, rank
85
+ self.scaling = alpha / rank
86
+ self.lora_dropout = nn.Dropout(p=dropout_p) if dropout_p > 0 else lambda x: x
87
+ self.dropout_fn = self._dropout if dropout_p > 0 else lambda x: x
88
+ self.register_buffer(
89
+ "lora_dropout_mask",
90
+ torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
91
+ persistent=False,
92
+ )
93
+
94
+ def _dropout(self, A):
95
+ # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
96
+ return A * self.lora_dropout(self.lora_dropout_mask)
97
+
98
+ def lora_forward(self, X, current_task):
99
+ return (
100
+ X
101
+ + torch.matmul(
102
+ *self.swap(
103
+ (
104
+ self.lora_B[current_task],
105
+ self.dropout_fn(self.lora_A[current_task]),
106
+ )
107
+ )
108
+ ).view(X.shape)
109
+ * self.scaling
110
+ )
111
+
112
+ def forward(self, X):
113
+ return X
114
+
115
+ @classmethod
116
+ def from_linear(
117
+ cls,
118
+ layer: nn.Module,
119
+ num_adaptations: int,
120
+ rank: int,
121
+ dropout_p: float,
122
+ alpha: float,
123
+ ):
124
+ assert isinstance(layer, nn.Linear)
125
+ fan_out, fan_in = layer.weight.shape
126
+ return cls(
127
+ fan_in,
128
+ fan_out,
129
+ num_adaptations=num_adaptations,
130
+ layer_type="linear",
131
+ rank=rank,
132
+ dropout_p=dropout_p,
133
+ alpha=alpha,
134
+ )
135
+
136
+ @classmethod
137
+ def from_embedding(
138
+ cls,
139
+ layer: nn.Module,
140
+ num_adaptations: int,
141
+ rank: int,
142
+ dropout_p: float,
143
+ alpha: float,
144
+ ):
145
+ assert isinstance(layer, nn.Embedding)
146
+ fan_in, fan_out = layer.weight.shape
147
+ return cls(
148
+ fan_in,
149
+ fan_out,
150
+ num_adaptations=num_adaptations,
151
+ layer_type="embedding",
152
+ rank=rank,
153
+ dropout_p=dropout_p,
154
+ alpha=alpha,
155
+ )
156
+
157
+ @classmethod
158
+ def add_to_layer(
159
+ cls,
160
+ layer: nn.Module,
161
+ num_adaptations: int,
162
+ rank: int,
163
+ dropout_p: float,
164
+ alpha: float,
165
+ ):
166
+ """
167
+ Registers LoRA adapters on all embedding and linear layers.
168
+ Additionally, we implement a custom forward function for LoRA parametrization.
169
+ This function modifies the layer's forward pass to optionally use task-specific
170
+ parameters. When a `task_id` is provided, it employs a LoRA parametrization
171
+ to modify the original weights according to the specific task. This allows
172
+ the layer to adapt dynamically to different tasks at runtime. If no `task_id`
173
+ is specified, the layer uses its original weights.
174
+ """
175
+ if isinstance(layer, nn.Linear):
176
+ parametrize.register_parametrization(
177
+ layer,
178
+ "weight",
179
+ cls.from_linear(
180
+ layer,
181
+ num_adaptations=num_adaptations,
182
+ rank=rank,
183
+ dropout_p=dropout_p,
184
+ alpha=alpha,
185
+ ),
186
+ )
187
+
188
+ def new_forward(self, input, task_id=None, residual=False):
189
+ if task_id is not None:
190
+ weights = self.parametrizations.weight[0].lora_forward(
191
+ self.weight, current_task=task_id
192
+ )
193
+ else:
194
+ weights = self.weight
195
+
196
+ out = F.linear(input, weights, self.bias)
197
+
198
+ if residual:
199
+ return out, input
200
+ return out
201
+
202
+ layer.forward = new_forward.__get__(layer, layer.__class__)
203
+
204
+ elif isinstance(layer, nn.Embedding):
205
+ parametrize.register_parametrization(
206
+ layer,
207
+ "weight",
208
+ cls.from_embedding(
209
+ layer,
210
+ num_adaptations=num_adaptations,
211
+ rank=rank,
212
+ dropout_p=dropout_p,
213
+ alpha=alpha,
214
+ ),
215
+ )
216
+
217
+ def new_forward(self, input, task_id=None):
218
+ if task_id is not None:
219
+ weights = self.parametrizations.weight[0].lora_forward(
220
+ self.weight, current_task=task_id
221
+ )
222
+ else:
223
+ weights = self.weight
224
+
225
+ out = F.embedding(
226
+ input,
227
+ weights,
228
+ self.padding_idx,
229
+ self.max_norm,
230
+ self.norm_type,
231
+ self.scale_grad_by_freq,
232
+ self.sparse,
233
+ )
234
+
235
+ return out
236
+
237
+ layer.forward = new_forward.__get__(layer, layer.__class__)
238
+
239
+
240
+ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
241
+ """
242
+ A wrapper class around the Jina XLM-RoBERTa model that integrates LoRA (Low-Rank Adaptation) adapters.
243
+ """
244
+
245
+ def __init__(
246
+ self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
247
+ ):
248
+ super().__init__(config)
249
+ if roberta is None:
250
+ self.roberta = XLMRobertaModel(config)
251
+ else:
252
+ self.roberta = roberta
253
+
254
+ self._lora_adaptations = config.lora_adaptations
255
+ if (
256
+ not isinstance(self._lora_adaptations, list)
257
+ or len(self._lora_adaptations) < 1
258
+ ):
259
+ raise ValueError(
260
+ f"`lora_adaptations` must be a list and contain at least one element"
261
+ )
262
+ self._task_instructions = config.task_instructions
263
+ if (
264
+ not isinstance(self._task_instructions, dict)
265
+ or len(self._task_instructions) != len(self._lora_adaptations)
266
+ or not all(
267
+ [v in self._lora_adaptations for v in self._task_instructions.keys()]
268
+ )
269
+ ):
270
+ raise ValueError(
271
+ f"`task_instructions` must be a dict and contain the same number of elements "
272
+ f"as `lora_adaptations` with all keys in `task_instructions` present in `lora_adaptations`."
273
+ )
274
+ self._adaptation_map = {
275
+ name: idx for idx, name in enumerate(self._lora_adaptations)
276
+ }
277
+ self._rank = config.lora_rank
278
+ self._dropout_p = config.lora_dropout_p
279
+ self._alpha = config.lora_alpha
280
+ self._register_lora(
281
+ num_adaptations=len(self._lora_adaptations),
282
+ rank=self._rank,
283
+ dropout_p=self._dropout_p,
284
+ alpha=self._alpha,
285
+ )
286
+ self.main_params_trainable = config.lora_main_params_trainable
287
+
288
+ @property
289
+ def rotary_emb_base(self):
290
+ return self.roberta.rotary_emb_base
291
+
292
+ @rotary_emb_base.setter
293
+ def rotary_emb_base(self, base):
294
+ self.roberta.rotary_emb_base = base
295
+
296
+ @property
297
+ def main_params_trainable(self):
298
+ return self._main_params_trainable
299
+
300
+ @main_params_trainable.setter
301
+ def main_params_trainable(self, val: bool):
302
+ """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
303
+ This method sets the `requires_grad_` attribute of the main weights
304
+ and controls which parameters are returned in `self.parameters()`.
305
+ :param val: Whether or not to make the parameters trainable.
306
+ :return: None
307
+ """
308
+ self._main_params_trainable = val
309
+ for name, param in super().named_parameters():
310
+ if "lora" not in name:
311
+ param.requires_grad_(val)
312
+
313
+ @classmethod
314
+ def from_pretrained(
315
+ cls,
316
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
317
+ *model_args,
318
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
319
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
320
+ ignore_mismatched_sizes: bool = False,
321
+ force_download: bool = False,
322
+ local_files_only: bool = False,
323
+ token: Optional[Union[str, bool]] = None,
324
+ revision: str = "main",
325
+ use_safetensors: bool = None,
326
+ **kwargs,
327
+ ):
328
+ if config.load_trained_adapters: # checkpoint already contains LoRA adapters
329
+ return super().from_pretrained(
330
+ pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
331
+ )
332
+ else: # initializing new adapters
333
+ roberta = XLMRobertaModel.from_pretrained(
334
+ pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
335
+ )
336
+ return cls(config, roberta=roberta)
337
+
338
+ def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
339
+ self.apply(
340
+ partial(
341
+ LoRAParametrization.add_to_layer,
342
+ num_adaptations=num_adaptations,
343
+ rank=rank,
344
+ dropout_p=dropout_p,
345
+ alpha=alpha,
346
+ )
347
+ )
348
+
349
+ def forward(self, *args, **kwargs):
350
+ return self.roberta(*args, **kwargs)
351
+
352
+ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
353
+ for _, param in self.named_parameters(recurse=recurse):
354
+ yield param
355
+
356
+ def named_parameters(
357
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
358
+ ) -> Iterator[Tuple[str, Parameter]]:
359
+ for name, param in super().named_parameters(
360
+ prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
361
+ ):
362
+ if "lora" in name or self.main_params_trainable:
363
+ yield name, param
364
+
365
+ @torch.inference_mode()
366
+ def encode(
367
+ self,
368
+ sentences: Union[str, List[str]],
369
+ *args,
370
+ task_type: Optional[str] = None,
371
+ **kwargs,
372
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
373
+ """
374
+ Computes sentence embeddings.
375
+ sentences(`str` or `List[str]`):
376
+ Sentence or sentences to be encoded
377
+ task_type(`str`, *optional*, defaults to `None`):
378
+ Specifies the task for which the encoding is intended. If `task_type` is not provided,
379
+ all LoRA adapters are disabled, and the model reverts to its original,
380
+ general-purpose weights.
381
+ """
382
+ if task_type and task_type not in self._lora_adaptations:
383
+ raise ValueError(
384
+ f"Unsupported task '{task_type}'. "
385
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
386
+ f"Alternatively, don't pass the `task_type` argument to disable LoRA."
387
+ )
388
+ adapter_mask = None
389
+ if task_type:
390
+ task_id = self._adaptation_map[task_type]
391
+ num_examples = 1 if isinstance(sentences, str) else len(sentences)
392
+ adapter_mask = torch.full(
393
+ (num_examples,), task_id, dtype=torch.int32, device=self.device
394
+ )
395
+ if isinstance(sentences, str):
396
+ sentences = self._task_instructions[task_type] + sentences
397
+ else:
398
+ sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
399
+ return self.roberta.encode(
400
+ sentences, *args, adapter_mask=adapter_mask, **kwargs
401
+ )
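
To make the LoRA parametrization above concrete, here is a minimal sketch (illustrative shapes only, not part of the uploaded files) of the task-specific weight that `LoRAParametrization.lora_forward` produces for a linear layer; `lora_B` is initialized to zero, so every adapter starts out as an exact no-op. At inference time, `XLMRobertaLoRA.encode` maps the chosen `task_type` to an index via `_adaptation_map` and passes it down as `adapter_mask`, which selects the corresponding `lora_A`/`lora_B` slice.

import torch

fan_out, fan_in, rank, alpha = 8, 16, 4, 1.0
W = torch.randn(fan_out, fan_in)       # frozen base weight, stored as (fan_out, fan_in)
A = torch.randn(rank, fan_in)          # lora_A slice for the selected task
B = torch.zeros(fan_out, rank)         # lora_B slice, zero at initialization
W_task = W + (B @ A).view(W.shape) * (alpha / rank)
assert torch.allclose(W_task, W)       # the update only takes effect once B is trained
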
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1208 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ from transformers import AutoTokenizer, PretrainedConfig
25
+ from transformers.modeling_outputs import (MaskedLMOutput,
26
+ SequenceClassifierOutput)
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.models.bert.modeling_bert import (
29
+ BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
30
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import \
31
+ XLMRobertaLMHead
32
+
33
+ from .rotary import RotaryEmbedding
34
+ from .block import Block
35
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
36
+ from .embedding import XLMRobertaEmbeddings
37
+ from .mha import MHA
38
+ from .mlp import FusedMLP, Mlp
39
+ from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
40
+
41
+ try:
42
+ from flash_attn.ops.fused_dense import FusedDense
43
+ except ImportError:
44
+ FusedDense = None
45
+
46
+ try:
47
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
48
+ except ImportError:
49
+ layer_norm_fn = None
50
+
51
+
52
+ try:
53
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
54
+ except ImportError:
55
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
56
+
57
+ try:
58
+ from tqdm.autonotebook import trange
59
+ except ImportError:
60
+ trange = None
61
+
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
67
+ if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
68
+ return False
69
+ if importlib.util.find_spec("flash_attn") is None:
70
+ logger.warning(
71
+ "flash_attn is not installed. Using PyTorch native attention implementation."
72
+ )
73
+ return False
74
+ return True
75
+
76
+
77
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
78
+ use_flash_attn = get_use_flash_attn(config)
79
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
80
+ rotary_kwargs = {}
81
+ if config.position_embedding_type == "rotary":
82
+ rotary_kwargs["rotary_emb_dim"] = getattr(
83
+ config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
84
+ )
85
+ rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
86
+ rotary_kwargs["rotary_emb_scale_base"] = getattr(
87
+ config, "rotary_emb_scale_base", None
88
+ )
89
+ rotary_kwargs["rotary_emb_interleaved"] = getattr(
90
+ config, "rotary_emb_interleaved", False
91
+ )
92
+ mixer_cls = partial(
93
+ MHA,
94
+ num_heads=config.num_attention_heads,
95
+ cross_attn=cross_attn,
96
+ dropout=config.attention_probs_dropout_prob,
97
+ causal=False,
98
+ fused_bias_fc=fused_bias_fc,
99
+ use_flash_attn=use_flash_attn,
100
+ return_residual=return_residual,
101
+ use_alibi=config.position_embedding_type == "alibi",
102
+ **rotary_kwargs,
103
+ )
104
+ return mixer_cls
105
+
106
+
107
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
108
+ inner_dim = config.intermediate_size
109
+ fused_mlp = getattr(config, "fused_mlp", False)
110
+ if fused_mlp:
111
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
112
+ "fused_mlp only " "supports approximate gelu"
113
+ )
114
+ if not fused_mlp:
115
+ approximate = (
116
+ "tanh"
117
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
118
+ else "none"
119
+ )
120
+ mlp_cls = partial(
121
+ Mlp,
122
+ hidden_features=inner_dim,
123
+ activation=partial(F.gelu, approximate=approximate),
124
+ return_residual=return_residual,
125
+ )
126
+ else:
127
+ if FusedMLP is None:
128
+ raise ImportError("fused_dense is not installed")
129
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
130
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
131
+ if isinstance(mlp_checkpoint_lvl, Sequence):
132
+ assert layer_idx is not None
133
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
134
+ mlp_cls = partial(
135
+ FusedMLP,
136
+ hidden_features=inner_dim,
137
+ checkpoint_lvl=mlp_checkpoint_lvl,
138
+ return_residual=return_residual,
139
+ )
140
+ return mlp_cls
141
+
142
+
143
+ def create_block(config, layer_idx=None):
144
+ last_layer_subset = getattr(config, "last_layer_subset", False)
145
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
146
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
147
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
148
+ # one layer) so we just choose not to return residual in this case.
149
+ return_residual = not cross_attn
150
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
151
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
152
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
153
+ block = Block(
154
+ config.hidden_size,
155
+ mixer_cls,
156
+ mlp_cls,
157
+ norm_cls=norm_cls,
158
+ prenorm=False,
159
+ resid_dropout1=config.hidden_dropout_prob,
160
+ resid_dropout2=config.hidden_dropout_prob,
161
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
162
+ return_residual=return_residual,
163
+ )
164
+ return block
165
+
166
+
167
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
168
+ def _init_weights(module, initializer_range=0.02):
169
+ if isinstance(module, nn.Linear):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.bias is not None:
172
+ nn.init.zeros_(module.bias)
173
+ elif isinstance(module, nn.Embedding):
174
+ nn.init.normal_(module.weight, std=initializer_range)
175
+ if module.padding_idx is not None:
176
+ nn.init.zeros_(module.weight[module.padding_idx])
177
+
178
+
179
+ class XLMRobertaEncoder(nn.Module):
180
+ def __init__(self, config: XLMRobertaFlashConfig):
181
+ super().__init__()
182
+ self.use_flash_attn = get_use_flash_attn(config)
183
+ self.use_reentrant = config.use_reentrant
184
+ self.layers = nn.ModuleList(
185
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
186
+ )
187
+ self._grad_checkpointing = False
188
+
189
+ @property
190
+ def gradient_checkpointing(self):
191
+ return self._grad_checkpointing
192
+
193
+ @gradient_checkpointing.setter
194
+ def gradient_checkpointing(self, value):
195
+ self._grad_checkpointing = value
196
+
197
+ def forward(
198
+ self, hidden_states, key_padding_mask=None, subset_mask=None, task_id=None
199
+ ):
200
+ """If subset_mask is not None, we only want output for the subset of the sequence.
201
+ This means that we only compute the last layer output for these tokens.
202
+ subset_mask: (batch, seqlen), dtype=torch.bool
203
+ """
204
+ if key_padding_mask is None or not self.use_flash_attn:
205
+ mixer_kwargs = {"task_id": task_id}
206
+ if key_padding_mask is not None:
207
+ mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
208
+ for layer in self.layers:
209
+ if self._grad_checkpointing:
210
+ hidden_states = torch.utils.checkpoint.checkpoint(
211
+ layer,
212
+ hidden_states,
213
+ use_reentrant=self.use_reentrant,
214
+ mixer_kwargs=mixer_kwargs,
215
+ )
216
+ else:
217
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
218
+ if subset_mask is not None:
219
+ hidden_states = hidden_states[subset_mask]
220
+ else:
221
+ batch, seqlen = hidden_states.shape[:2]
222
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = (
223
+ unpad_input(hidden_states, key_padding_mask)
224
+ )
225
+ mixer_kwargs = {
226
+ "cu_seqlens": cu_seqlens,
227
+ "max_seqlen": max_seqlen_in_batch,
228
+ "task_id": task_id,
229
+ }
230
+
231
+ if subset_mask is None:
232
+ for layer in self.layers:
233
+ if self._grad_checkpointing:
234
+ hidden_states = torch.utils.checkpoint.checkpoint(
235
+ layer,
236
+ hidden_states,
237
+ use_reentrant=self.use_reentrant,
238
+ mixer_kwargs=mixer_kwargs,
239
+ )
240
+ else:
241
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
242
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
243
+ else:
244
+ for layer in self.layers[:-1]:
245
+ if self._grad_checkpointing:
246
+ hidden_states = torch.utils.checkpoint.checkpoint(
247
+ layer,
248
+ hidden_states,
249
+ use_reentrant=self.use_reentrant,
250
+ mixer_kwargs=mixer_kwargs,
251
+ )
252
+ else:
253
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
254
+ if key_padding_mask is not None:
255
+ subset_idx = torch.nonzero(
256
+ subset_mask[key_padding_mask], as_tuple=False
257
+ ).flatten()
258
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
259
+ dim=-1, dtype=torch.int32
260
+ )
261
+ subset_cu_seqlens = F.pad(
262
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
263
+ (1, 0),
264
+ )
265
+ else:
266
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
267
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
268
+ subset_cu_seqlens = F.pad(
269
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
270
+ (1, 0),
271
+ )
272
+ hidden_states_subset, hidden_states = index_first_axis_residual(
273
+ hidden_states, subset_idx
274
+ )
275
+ # It's ok to set max_seqlen_q to be much larger
276
+ mixer_kwargs = {
277
+ "x_kv": hidden_states,
278
+ "cu_seqlens": subset_cu_seqlens,
279
+ "max_seqlen": max_seqlen_in_batch,
280
+ "cu_seqlens_k": cu_seqlens,
281
+ "max_seqlen_k": max_seqlen_in_batch,
282
+ }
283
+ if self._grad_checkpointing:
284
+ hidden_states = torch.utils.checkpoint.checkpoint(
285
+ self.layers[-1],
286
+ hidden_states_subset,
287
+ use_reentrant=self.use_reentrant,
288
+ mixer_kwargs=mixer_kwargs,
289
+ )
290
+ else:
291
+ hidden_states = self.layers[-1](
292
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
293
+ )
294
+ return hidden_states
295
+
296
+
297
+ class XLMRobertaPooler(nn.Module):
298
+ def __init__(self, config):
299
+ super().__init__()
300
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
301
+ if fused_bias_fc and FusedDense is None:
302
+ raise ImportError("fused_dense is not installed")
303
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
304
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
305
+ self.activation = nn.Tanh()
306
+
307
+ def forward(self, hidden_states, pool=True, task_id=None):
308
+ # We "pool" the model by simply taking the hidden state corresponding
309
+ # to the first token.
310
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
311
+ if task_id is not None:
312
+ pooled_output = self.dense(
313
+ first_token_tensor, task_id=task_id
314
+ )
315
+ else:
316
+ pooled_output = self.dense(first_token_tensor)
317
+ pooled_output = self.activation(pooled_output)
318
+ return pooled_output
319
+
320
+
321
+ class XLMRobertaPredictionHeadTransform(nn.Module):
322
+ def __init__(self, config):
323
+ super().__init__()
324
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
325
+ if fused_bias_fc and FusedDense is None:
326
+ raise ImportError("fused_dense is not installed")
327
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
328
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
329
+ raise ImportError("Triton is not installed")
330
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
331
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
332
+ approximate = (
333
+ "tanh"
334
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
335
+ else "none"
336
+ )
337
+ self.transform_act_fn = nn.GELU(approximate=approximate)
338
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
339
+
340
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
341
+ hidden_states = self.dense(hidden_states)
342
+ hidden_states = self.transform_act_fn(hidden_states)
343
+ if not self.fused_dropout_add_ln:
344
+ hidden_states = self.layer_norm(hidden_states)
345
+ else:
346
+ hidden_states = layer_norm_fn(
347
+ hidden_states,
348
+ self.layer_norm.weight,
349
+ self.layer_norm.bias,
350
+ eps=self.layer_norm.eps,
351
+ )
352
+ return hidden_states
353
+
354
+
355
+ class XLMRobertaLMPredictionHead(nn.Module):
356
+ def __init__(self, config):
357
+ super().__init__()
358
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
359
+ if fused_bias_fc and FusedDense is None:
360
+ raise ImportError("fused_dense is not installed")
361
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
362
+
363
+ self.transform = XLMRobertaPredictionHeadTransform(config)
364
+
365
+ # The output weights are the same as the input embeddings, but there is
366
+ # an output-only bias for each token.
367
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
368
+
369
+ def forward(self, hidden_states):
370
+ hidden_states = self.transform(hidden_states)
371
+ hidden_states = self.decoder(hidden_states)
372
+ return hidden_states
373
+
374
+
375
+ class XLMRobertaPreTrainingHeads(nn.Module):
376
+ def __init__(self, config):
377
+ super().__init__()
378
+ self.predictions = XLMRobertaLMPredictionHead(config)
379
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
380
+
381
+ def forward(self, sequence_output, pooled_output):
382
+ prediction_scores = self.predictions(sequence_output)
383
+ seq_relationship_score = self.seq_relationship(pooled_output)
384
+ return prediction_scores, seq_relationship_score
385
+
386
+
387
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
388
+ """An abstract class to handle weights initialization and
389
+ a simple interface for downloading and loading pretrained models.
390
+ """
391
+
392
+ config_class = XLMRobertaFlashConfig
393
+ base_model_prefix = "roberta"
394
+ supports_gradient_checkpointing = True
395
+ _supports_param_buffer_assignment = False
396
+
397
+ def _set_gradient_checkpointing(self, module, value=False):
398
+ if isinstance(module, XLMRobertaEncoder):
399
+ module.gradient_checkpointing = value
400
+
401
+ @classmethod
402
+ def from_pretrained(
403
+ cls,
404
+ *args,
405
+ **kwargs,
406
+ ):
407
+ if not "torch_dtype" in kwargs:
408
+ kwargs["torch_dtype"] = "auto"
409
+ return super().from_pretrained(*args, **kwargs)
410
+
411
+
412
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
413
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
414
+ super().__init__(config)
415
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
416
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
417
+ config.vocab_size += self.pad_vocab_size_multiple - (
418
+ config.vocab_size % self.pad_vocab_size_multiple
419
+ )
420
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
421
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
422
+ raise ImportError("Triton is not installed")
423
+ assert config.hidden_act in [
424
+ "gelu",
425
+ "gelu_new",
426
+ "gelu_fast",
427
+ "gelu_pytorch_tanh",
428
+ ]
429
+ self.embeddings = XLMRobertaEmbeddings(
430
+ config.hidden_size,
431
+ config.vocab_size,
432
+ (
433
+ config.max_position_embeddings
434
+ if config.position_embedding_type == "absolute"
435
+ else -1
436
+ ),
437
+ config.type_vocab_size,
438
+ padding_idx=config.pad_token_id,
439
+ )
440
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
441
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
442
+ self.encoder = XLMRobertaEncoder(config)
443
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
444
+
445
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
446
+ self.tokenizer = AutoTokenizer.from_pretrained(
447
+ self.name_or_path, trust_remote_code=True
448
+ )
449
+ self._rotary_emb_base = config.rotary_emb_base
450
+
451
+ @torch.inference_mode()
452
+ def encode(
453
+ self: "XLMRobertaModel",
454
+ sentences: Union[str, List[str]],
455
+ batch_size: int = 32,
456
+ show_progress_bar: Optional[bool] = None,
457
+ output_value: str = "sentence_embedding",
458
+ convert_to_numpy: bool = True,
459
+ convert_to_tensor: bool = False,
460
+ device: Optional[torch.device] = None,
461
+ normalize_embeddings: bool = False,
462
+ truncate_dim: Optional[int] = None,
463
+ adapter_mask: Optional[torch.Tensor] = None,
464
+ task_type: Optional[str] = None,
465
+ **tokenizer_kwargs,
466
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
467
+ """
468
+ Computes sentence embeddings
469
+ Args:
470
+ sentences(`str` or `List[str]`):
471
+ Sentence or sentences to be encoded
472
+ batch_size(`int`, *optional*, defaults to 32):
473
+ Batch size for the computation
474
+ show_progress_bar(`bool`, *optional*, defaults to None):
475
+ Show a progress bar when encoding sentences.
476
+ If set to None, progress bar is only shown when
477
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
478
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
479
+ Default sentence_embedding, to get sentence embeddings.
480
+ Can be set to token_embeddings to get wordpiece token embeddings.
481
+ Set to None, to get all output values
482
+ convert_to_numpy(`bool`, *optional*, defaults to True):
483
+ If true, the output is a list of numpy vectors.
484
+ Else, it is a list of pytorch tensors.
485
+ convert_to_tensor(`bool`, *optional*, defaults to False):
486
+ If true, you get one large tensor as return.
487
+ Overwrites any setting from convert_to_numpy
488
+ device(`torch.device`, *optional*, defaults to None):
489
+ Which torch.device to use for the computation
490
+ normalize_embeddings(`bool`, *optional*, defaults to False):
491
+ If set to true, returned vectors will have length 1. In that case, the
492
+ faster dot-product (util.dot_score) instead of cosine similarity can
493
+ be used.
494
+ truncate_dim(`int`, *optional*, defaults to None):
495
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
496
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
497
+ Keyword arguments for the tokenizer
498
+ Returns:
499
+ By default, a list of tensors is returned.
500
+ If convert_to_tensor, a stacked tensor is returned.
501
+ If convert_to_numpy, a numpy matrix is returned.
502
+ """
503
+ is_training = self.training
504
+ self.eval()
505
+
506
+ if show_progress_bar is None:
507
+ show_progress_bar = (
508
+ logger.getEffectiveLevel() == logging.INFO
509
+ or logger.getEffectiveLevel() == logging.DEBUG
510
+ )
511
+
512
+ if convert_to_tensor:
513
+ convert_to_numpy = False
514
+
515
+ if output_value != "sentence_embedding":
516
+ convert_to_tensor = False
517
+ convert_to_numpy = False
518
+
519
+ input_was_string = False
520
+ if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
521
+ sentences = [sentences]
522
+ input_was_string = True
523
+
524
+ if device is not None:
525
+ self.to(device)
526
+
527
+ permutation = np.argsort([-len(i) for i in sentences])
528
+ inverse_permutation = np.argsort(permutation)
529
+ sentences = [sentences[idx] for idx in permutation]
530
+
531
+ tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
532
+ tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
533
+ "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
534
+ )
535
+ tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
536
+
537
+ all_embeddings = []
538
+
539
+ if trange is not None:
540
+ range_iter = trange(
541
+ 0,
542
+ len(sentences),
543
+ batch_size,
544
+ desc="Encoding",
545
+ disable=not show_progress_bar,
546
+ )
547
+ else:
548
+ range_iter = range(0, len(sentences), batch_size)
549
+ lora_arguments = (
550
+ {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
551
+ )
552
+ for i in range_iter:
553
+ encoded_input = self.tokenizer(
554
+ sentences[i : i + batch_size],
555
+ return_tensors="pt",
556
+ **tokenizer_kwargs,
557
+ ).to(self.device)
558
+ token_embs = self.forward(**encoded_input, **lora_arguments)[0]
559
+
560
+ # Accumulate in fp32 to avoid overflow
561
+ token_embs = token_embs.float()
562
+
563
+ if output_value == "token_embeddings":
564
+ raise NotImplementedError
565
+ elif output_value is None:
566
+ raise NotImplementedError
567
+ else:
568
+ if self.config.emb_pooler == "cls":
569
+ embeddings = self.cls_pooling(
570
+ token_embs, encoded_input["attention_mask"]
571
+ )
572
+ else:
573
+ embeddings = self.mean_pooling(
574
+ token_embs, encoded_input["attention_mask"]
575
+ )
576
+
577
+ if normalize_embeddings:
578
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
579
+
580
+ if convert_to_numpy:
581
+ embeddings = embeddings.cpu()
582
+ all_embeddings.extend(embeddings)
583
+
584
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
585
+
586
+ truncate_dim = truncate_dim or self.config.truncate_dim
587
+ if truncate_dim:
588
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
589
+
590
+ if convert_to_tensor:
591
+ all_embeddings = torch.stack(all_embeddings)
592
+ elif convert_to_numpy:
593
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
594
+
595
+ if input_was_string:
596
+ all_embeddings = all_embeddings[0]
597
+
598
+ self.train(is_training)
599
+ return all_embeddings
600
+
601
+ def truncate_embeddings(self, embeddings, truncate_dim):
602
+ if not self.config.matryoshka_dimensions:
603
+ logger.warning(
604
+ "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
605
+ )
606
+ return embeddings
607
+ elif truncate_dim in self.config.matryoshka_dimensions:
608
+ return [tensor[:truncate_dim] for tensor in embeddings]
609
+ else:
610
+ raise ValueError(
611
+ f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
612
+ f"Supported dimensions are {self.config.matryoshka_dimensions}."
613
+ )
614
+
615
+ def mean_pooling(
616
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
617
+ ):
618
+ input_mask_expanded = (
619
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
620
+ )
621
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
622
+ input_mask_expanded.sum(1), min=1e-9
623
+ )
624
+
625
+ def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
626
+ return token_embeddings[:, 0]
627
+
628
+ @property
629
+ def rotary_emb_base(self):
630
+ return self._rotary_emb_base
631
+
632
+ @rotary_emb_base.setter
633
+ def rotary_emb_base(self, base):
634
+ if not isinstance(base, (int, float)):
635
+ raise TypeError("Base must be an integer or float")
636
+ logger.info(f"Changing RoPE base value to {base}")
637
+ for layer in self.encoder.layers:
638
+ layer.mixer.rotary_emb.base = base
639
+ self._rotary_emb_base = base
640
+
641
+ def forward(
642
+ self,
643
+ input_ids,
644
+ attention_mask,
645
+ task_id,
646
+ position_ids=None,
647
+ token_type_ids=None,
648
+ masked_tokens_mask=None,
649
+ return_dict=None,
650
+ **kwargs,
651
+ ):
652
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
653
+ we only want the output for the masked tokens. This means that we only compute the last
654
+ layer output for these tokens.
655
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
656
+ """
657
+ if kwargs:
658
+ for key, value in kwargs.items():
659
+ if value is not None:
660
+ logger.warning(
661
+ "Flash attention implementation does not support kwargs: %s",
662
+ key,
663
+ )
664
+
665
+ return_dict = (
666
+ return_dict if return_dict is not None else self.config.use_return_dict
667
+ )
668
+
669
+ hidden_states = self.embeddings(
670
+ input_ids,
671
+ position_ids=position_ids,
672
+ token_type_ids=token_type_ids,
673
+ task_id=task_id,
674
+ )
675
+ # TD [2022-12:18]: Don't need to force residual in fp32
676
+ # BERT puts embedding LayerNorm before embedding dropout.
677
+ if not self.fused_dropout_add_ln:
678
+ hidden_states = self.emb_ln(hidden_states)
679
+ else:
680
+ hidden_states = layer_norm_fn(
681
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
682
+ )
683
+ hidden_states = self.emb_drop(hidden_states)
684
+
685
+ if masked_tokens_mask is not None:
686
+ batch_size, seqlen = input_ids.shape[:2]
687
+ # We also need the first column for the CLS token
688
+ first_col_mask = torch.zeros(
689
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
690
+ )
691
+ first_col_mask[:, 0] = True
692
+ subset_mask = masked_tokens_mask | first_col_mask
693
+ else:
694
+ subset_mask = None
695
+
696
+ sequence_output = self.encoder(
697
+ hidden_states,
698
+ key_padding_mask=attention_mask,
699
+ subset_mask=subset_mask,
700
+ task_id=task_id,
701
+ )
702
+
703
+ if masked_tokens_mask is None:
704
+ pooled_output = (
705
+ self.pooler(sequence_output, task_id=task_id)
706
+ if self.pooler is not None
707
+ else None
708
+ )
709
+ else:
710
+ # TD [2022-03-01]: the indexing here is very tricky.
711
+ if attention_mask is not None:
712
+ subset_idx = subset_mask[attention_mask]
713
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
714
+ sequence_output = sequence_output[
715
+ masked_tokens_mask[attention_mask][subset_idx]
716
+ ]
717
+ else:
718
+ pool_input = sequence_output[first_col_mask[subset_mask]]
719
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
720
+ pooled_output = (
721
+ self.pooler(pool_input, pool=False, task_id=task_id)
722
+ if self.pooler is not None
723
+ else None
724
+ )
725
+
726
+ if not return_dict:
727
+ return sequence_output, pooled_output
728
+
729
+ return BaseModelOutputWithPoolingAndCrossAttentions(
730
+ last_hidden_state=sequence_output,
731
+ pooler_output=pooled_output,
732
+ )
733
+
734
+
735
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
736
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
737
+
738
+ def __init__(self, config):
739
+ super().__init__(config)
740
+
741
+ if config.is_decoder:
742
+ logger.warning(
743
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
744
+ "bi-directional self-attention."
745
+ )
746
+
747
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
748
+ self.lm_head = XLMRobertaLMHead(config)
749
+
750
+ # Initialize weights and apply final processing
751
+ self.post_init()
752
+
753
+ def get_input_embeddings(self):
754
+ return self.roberta.embeddings.word_embeddings
755
+
756
+ def get_output_embeddings(self):
757
+ return self.lm_head.decoder
758
+
759
+ def set_output_embeddings(self, new_embeddings):
760
+ self.lm_head.decoder = new_embeddings
761
+
762
+ def forward(
763
+ self,
764
+ input_ids: Optional[torch.LongTensor] = None,
765
+ attention_mask: Optional[torch.FloatTensor] = None,
766
+ token_type_ids: Optional[torch.LongTensor] = None,
767
+ position_ids: Optional[torch.LongTensor] = None,
768
+ head_mask: Optional[torch.FloatTensor] = None,
769
+ inputs_embeds: Optional[torch.FloatTensor] = None,
770
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
771
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
772
+ labels: Optional[torch.LongTensor] = None,
773
+ output_attentions: Optional[bool] = None,
774
+ output_hidden_states: Optional[bool] = None,
775
+ return_dict: Optional[bool] = None,
776
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
777
+ r"""
778
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
779
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
780
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
781
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
782
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
783
+ Used to hide legacy arguments that have been deprecated.
784
+ """
785
+ return_dict = (
786
+ return_dict if return_dict is not None else self.config.use_return_dict
787
+ )
788
+
789
+ outputs = self.roberta(
790
+ input_ids,
791
+ attention_mask=attention_mask,
792
+ token_type_ids=token_type_ids,
793
+ position_ids=position_ids,
794
+ head_mask=head_mask,
795
+ inputs_embeds=inputs_embeds,
796
+ encoder_hidden_states=encoder_hidden_states,
797
+ encoder_attention_mask=encoder_attention_mask,
798
+ output_attentions=output_attentions,
799
+ output_hidden_states=output_hidden_states,
800
+ return_dict=return_dict,
801
+ )
802
+ sequence_output = outputs[0]
803
+ prediction_scores = self.lm_head(sequence_output)
804
+
805
+ masked_lm_loss = None
806
+ if labels is not None:
807
+ # move labels to correct device to enable model parallelism
808
+ labels = labels.to(prediction_scores.device)
809
+ loss_fct = CrossEntropyLoss()
810
+ masked_lm_loss = loss_fct(
811
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
812
+ )
813
+
814
+ if not return_dict:
815
+ output = (prediction_scores,) + outputs[2:]
816
+ return (
817
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
818
+ )
819
+
820
+ return MaskedLMOutput(
821
+ loss=masked_lm_loss,
822
+ logits=prediction_scores,
823
+ hidden_states=outputs.hidden_states,
824
+ attentions=outputs.attentions,
825
+ )
826
+
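A usage sketch (not part of the uploaded code, shown with the upstream `xlm-roberta-base` checkpoint rather than this repository) of driving a masked-LM head through `transformers`:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                   # (batch, seqlen, vocab_size)
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode(logits[0, mask_pos].argmax().item()))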
827
+
828
+ def remap_state_dict(state_dict, config: PretrainedConfig):
829
+ """
830
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
831
+ """
832
+
833
+ # LayerNorm
834
+ def key_mapping_ln_gamma_beta(key):
835
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
836
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
837
+ return key
838
+
839
+ state_dict = OrderedDict(
840
+ (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
841
+ )
842
+
843
+ # Layers
844
+ def key_mapping_layers(key):
845
+ return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
846
+
847
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
848
+
849
+ # LayerNorm
850
+ def key_mapping_ln(key):
851
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
852
+ key = re.sub(
853
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
854
+ r"bert.encoder.layers.\1.norm1.\2",
855
+ key,
856
+ )
857
+ key = re.sub(
858
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
859
+ r"bert.encoder.layers.\1.norm2.\2",
860
+ key,
861
+ )
862
+ key = re.sub(
863
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
864
+ r"cls.predictions.transform.layer_norm.\1",
865
+ key,
866
+ )
867
+ return key
868
+
869
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
870
+
871
+ # MLP
872
+ def key_mapping_mlp(key):
873
+ key = re.sub(
874
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
875
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
876
+ key,
877
+ )
878
+ key = re.sub(
879
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
880
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
881
+ key,
882
+ )
883
+ return key
884
+
885
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
886
+
887
+ # Attention
888
+ last_layer_subset = getattr(config, "last_layer_subset", False)
889
+ for d in range(config.num_hidden_layers):
890
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
891
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
892
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
893
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
894
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
895
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
896
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
897
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
898
+ [Wq, Wk, Wv], dim=0
899
+ )
900
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
901
+ [bq, bk, bv], dim=0
902
+ )
903
+ else:
904
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
905
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
906
+ [Wk, Wv], dim=0
907
+ )
908
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
909
+ state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
910
+ [bk, bv], dim=0
911
+ )
912
+
913
+ def key_mapping_attn(key):
914
+ return re.sub(
915
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
916
+ r"bert.encoder.layers.\1.mixer.out_proj.\2",
917
+ key,
918
+ )
919
+
920
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
921
+
922
+ def key_mapping_decoder_bias(key):
923
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
924
+
925
+ state_dict = OrderedDict(
926
+ (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
927
+ )
928
+
929
+ # Word embedding
930
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
931
+ if pad_vocab_size_multiple > 1:
932
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
933
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
934
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
935
+ )
936
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
937
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
938
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
939
+ )
940
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
941
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
942
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
943
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
944
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
945
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
946
+ )
947
+
948
+ return state_dict
949
+
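A sketch (not part of the uploaded code) of the idea behind the Wqkv fusion performed above: concatenating the three projection matrices along the output dimension gives a single matmul that is equivalent to the three separate projections:

import torch

hidden = 768
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)     # (3 * hidden, hidden), as in the remapping

x = torch.randn(4, hidden)                # a few token embeddings
qkv = x @ Wqkv.T                          # one fused projection
assert torch.allclose(qkv[:, :hidden], x @ Wq.T, atol=1e-4)
assert torch.allclose(qkv[:, hidden:2 * hidden], x @ Wk.T, atol=1e-4)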
950
+
951
+ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
952
+ """
953
+ Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
954
+
955
+ This function is meant to be the inverse of remap_state_dict.
956
+ """
957
+ # Word embedding
958
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
959
+ if pad_vocab_size_multiple > 1:
960
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
961
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
962
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
963
+ # unpad embeddings
964
+ state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
965
+ : config.orig_vocab_size, :
966
+ ]
967
+ state_dict["cls.predictions.decoder.weight"] = decoder_weight[
968
+ : config.orig_vocab_size, :
969
+ ]
970
+ state_dict["cls.predictions.decoder.bias"] = decoder_bias[
971
+ : config.orig_vocab_size
972
+ ]
973
+
974
+ for d in range(config.num_hidden_layers):
975
+ last_layer_subset = getattr(config, "last_layer_subset", False)
976
+ if not last_layer_subset or d != (config.num_hidden_layers - 1):
977
+ Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
978
+ Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
979
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
980
+ Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
981
+ )
982
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
983
+ Wqkv_weights[
984
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
985
+ ]
986
+ )
987
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
988
+ Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
989
+ )
990
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
991
+ Wqkv_biases[: Wqkv_biases.shape[0] // 3]
992
+ )
993
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
994
+ Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
995
+ )
996
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
997
+ Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
998
+ )
999
+ else:
1000
+ Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1001
+ Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1002
+ Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1003
+ Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1004
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1005
+ Wq_weight
1006
+ )
1007
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1008
+ Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1009
+ )
1010
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1011
+ Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1012
+ )
1013
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1014
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1015
+ : Wkv_biases.shape[0] // 2
1016
+ ]
1017
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1018
+ Wkv_biases[Wkv_biases.shape[0] // 2 :]
1019
+ )
1020
+
1021
+ def inv_key_mapping_ln(key):
1022
+ key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
1023
+ key = re.sub(
1024
+ r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
1025
+ r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
1026
+ key,
1027
+ )
1028
+ key = re.sub(
1029
+ r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
1030
+ r"bert.encoder.layers.\1.output.LayerNorm.\2",
1031
+ key,
1032
+ )
1033
+ key = re.sub(
1034
+ r"cls.predictions.transform.layer_norm.(weight|bias)",
1035
+ r"cls.predictions.transform.LayerNorm.\1",
1036
+ key,
1037
+ )
1038
+ return key
1039
+
1040
+ def inv_key_mapping_ln_gamma_beta(key):
1041
+ key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
1042
+ key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
1043
+ return key
1044
+
1045
+ def inv_key_mapping_layers(key):
1046
+ return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
1047
+
1048
+ def inv_key_mapping_mlp(key):
1049
+ key = re.sub(
1050
+ r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
1051
+ r"bert.encoder.layer.\1.intermediate.dense.\2",
1052
+ key,
1053
+ )
1054
+ key = re.sub(
1055
+ r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
1056
+ r"bert.encoder.layer.\1.output.dense.\2",
1057
+ key,
1058
+ )
1059
+ return key
1060
+
1061
+ def inv_key_mapping_attn(key):
1062
+ return re.sub(
1063
+ r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
1064
+ r"bert.encoder.layer.\1.attention.output.dense.\2",
1065
+ key,
1066
+ )
1067
+
1068
+ def inv_key_mapping_decoder_bias(key):
1069
+ return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
1070
+
1071
+ state_dict = OrderedDict(
1072
+ (inv_key_mapping_ln(key), value) for key, value in state_dict.items()
1073
+ )
1074
+ state_dict = OrderedDict(
1075
+ (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
1076
+ )
1077
+ state_dict = OrderedDict(
1078
+ (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
1079
+ )
1080
+ state_dict = OrderedDict(
1081
+ (inv_key_mapping_mlp(key), value) for key, value in state_dict.items()
1082
+ )
1083
+ state_dict = OrderedDict(
1084
+ (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
1085
+ )
1086
+ state_dict = OrderedDict(
1087
+ (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
1088
+ )
1089
+
1090
+ return state_dict
1091
+
1092
+
1093
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
1094
+ class XLMRobertaClassificationHead(nn.Module):
1095
+ """Head for sentence-level classification tasks."""
1096
+
1097
+ def __init__(self, config):
1098
+ super().__init__()
1099
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
1100
+ if fused_bias_fc and FusedDense is None:
1101
+ raise ImportError("fused_dense is not installed")
1102
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
1103
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
1104
+ classifier_dropout = (
1105
+ config.classifier_dropout
1106
+ if config.classifier_dropout is not None
1107
+ else config.hidden_dropout_prob
1108
+ )
1109
+ self.dropout = nn.Dropout(classifier_dropout)
1110
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
1111
+
1112
+ def forward(self, features, **kwargs):
1113
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1114
+ x = self.dropout(x)
1115
+ x = self.dense(x)
1116
+ x = torch.tanh(x)
1117
+ x = self.dropout(x)
1118
+ x = self.out_proj(x)
1119
+ return x
1120
+
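A toy sketch (not part of the uploaded code) of the computation this head performs on the encoder output, with dropout omitted:

import torch
import torch.nn as nn

hidden_size, num_labels = 768, 3
features = torch.randn(2, 10, hidden_size)    # encoder output: (batch, seqlen, hidden)
dense = nn.Linear(hidden_size, hidden_size)
out_proj = nn.Linear(hidden_size, num_labels)

x = features[:, 0, :]                         # <s> token, equivalent to [CLS]
logits = out_proj(torch.tanh(dense(x)))       # (batch, num_labels)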
1121
+
1122
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
1123
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1124
+ def __init__(self, config):
1125
+ super().__init__(config)
1126
+ self.num_labels = config.num_labels
1127
+ self.config = config
1128
+
1129
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
1130
+ self.classifier = XLMRobertaClassificationHead(config)
1131
+
1132
+ # Initialize weights and apply final processing
1133
+ self.post_init()
1134
+
1135
+ def forward(
1136
+ self,
1137
+ input_ids: Optional[torch.LongTensor] = None,
1138
+ attention_mask: Optional[torch.FloatTensor] = None,
1139
+ token_type_ids: Optional[torch.LongTensor] = None,
1140
+ position_ids: Optional[torch.LongTensor] = None,
1141
+ head_mask: Optional[torch.FloatTensor] = None,
1142
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1143
+ labels: Optional[torch.LongTensor] = None,
1144
+ output_attentions: Optional[bool] = None,
1145
+ output_hidden_states: Optional[bool] = None,
1146
+ return_dict: Optional[bool] = None,
1147
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1148
+ r"""
1149
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1150
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1151
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1152
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1153
+ """
1154
+ return_dict = (
1155
+ return_dict if return_dict is not None else self.config.use_return_dict
1156
+ )
1157
+
1158
+ outputs = self.roberta(
1159
+ input_ids,
1160
+ attention_mask=attention_mask,
1161
+ token_type_ids=token_type_ids,
1162
+ position_ids=position_ids,
1163
+ head_mask=head_mask,
1164
+ inputs_embeds=inputs_embeds,
1165
+ output_attentions=output_attentions,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+ sequence_output = outputs[0]
1170
+ logits = self.classifier(sequence_output)
1171
+
1172
+ loss = None
1173
+ if labels is not None:
1174
+ # move labels to correct device to enable model parallelism
1175
+ labels = labels.to(logits.device)
1176
+ if self.config.problem_type is None:
1177
+ if self.num_labels == 1:
1178
+ self.config.problem_type = "regression"
1179
+ elif self.num_labels > 1 and (
1180
+ labels.dtype == torch.long or labels.dtype == torch.int
1181
+ ):
1182
+ self.config.problem_type = "single_label_classification"
1183
+ else:
1184
+ self.config.problem_type = "multi_label_classification"
1185
+
1186
+ if self.config.problem_type == "regression":
1187
+ loss_fct = MSELoss()
1188
+ if self.num_labels == 1:
1189
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1190
+ else:
1191
+ loss = loss_fct(logits, labels)
1192
+ elif self.config.problem_type == "single_label_classification":
1193
+ loss_fct = CrossEntropyLoss()
1194
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1195
+ elif self.config.problem_type == "multi_label_classification":
1196
+ loss_fct = BCEWithLogitsLoss()
1197
+ loss = loss_fct(logits, labels)
1198
+
1199
+ if not return_dict:
1200
+ output = (logits,) + outputs[2:]
1201
+ return ((loss,) + output) if loss is not None else output
1202
+
1203
+ return SequenceClassifierOutput(
1204
+ loss=loss,
1205
+ logits=logits,
1206
+ hidden_states=outputs.hidden_states,
1207
+ attentions=outputs.attentions,
1208
+ )
rotary.py ADDED
@@ -0,0 +1,658 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
2
+ # Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
+ # Copyright (c) 2023, Tri Dao.
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from einops import rearrange, repeat
11
+
12
+ if torch.cuda.is_available():
13
+ try:
14
+ from flash_attn.ops.triton.rotary import apply_rotary
15
+ except ImportError:
16
+
17
+ def apply_rotary(*args, **kwargs):
18
+ raise RuntimeError(
19
+ "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
20
+ "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
21
+ )
22
+
23
+
24
+ def rotate_half(x, interleaved=False):
25
+ if not interleaved:
26
+ x1, x2 = x.chunk(2, dim=-1)
27
+ return torch.cat((-x2, x1), dim=-1)
28
+ else:
29
+ x1, x2 = x[..., ::2], x[..., 1::2]
30
+ return rearrange(
31
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
32
+ )
33
+
34
+
35
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
36
+ """
37
+ x: (batch_size, seqlen, nheads, headdim)
38
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
39
+ """
40
+ ro_dim = cos.shape[-1] * 2
41
+ assert ro_dim <= x.shape[-1]
42
+ cos, sin = (
43
+ cos[: x.shape[1]],
44
+ sin[: x.shape[1]],
45
+ )
46
+ cos = repeat(
47
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
+ )
49
+ sin = repeat(
50
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
51
+ )
52
+ return torch.cat(
53
+ [
54
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
55
+ x[..., ro_dim:],
56
+ ],
57
+ dim=-1,
58
+ )
59
+
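A small sketch (not part of the uploaded code, assuming it runs inside this module so `apply_rotary_emb_torch` is in scope) of the pure-PyTorch rotary path, with rotary_dim equal to the head dimension:

import torch

batch, seqlen, nheads, headdim = 1, 4, 2, 8
x = torch.randn(batch, seqlen, nheads, headdim)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2).float() / headdim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)   # (seqlen, headdim / 2)

out = apply_rotary_emb_torch(x, torch.cos(freqs), torch.sin(freqs))
print(out.shape)                                              # torch.Size([1, 4, 2, 8])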
60
+
61
+ class ApplyRotaryEmb(torch.autograd.Function):
62
+ @staticmethod
63
+ def forward(
64
+ ctx,
65
+ x,
66
+ cos,
67
+ sin,
68
+ interleaved=False,
69
+ inplace=False,
70
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
71
+ cu_seqlens: Optional[torch.Tensor] = None,
72
+ max_seqlen: Optional[int] = None,
73
+ ):
74
+ out = apply_rotary(
75
+ x,
76
+ cos,
77
+ sin,
78
+ seqlen_offsets=seqlen_offsets,
79
+ cu_seqlens=cu_seqlens,
80
+ max_seqlen=max_seqlen,
81
+ interleaved=interleaved,
82
+ inplace=inplace,
83
+ )
84
+
85
+ if isinstance(seqlen_offsets, int):
86
+ ctx.save_for_backward(
87
+ cos, sin, cu_seqlens
88
+ ) # Can't save int with save_for_backward
89
+ ctx.seqlen_offsets = seqlen_offsets
90
+ else:
91
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
92
+ ctx.seqlen_offsets = None
93
+ ctx.interleaved = interleaved
94
+ ctx.inplace = inplace
95
+ ctx.max_seqlen = max_seqlen
96
+ return out if not inplace else x
97
+
98
+ @staticmethod
99
+ def backward(ctx, do):
100
+ seqlen_offsets = ctx.seqlen_offsets
101
+ if seqlen_offsets is None:
102
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
103
+ else:
104
+ cos, sin, cu_seqlens = ctx.saved_tensors
105
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
106
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
107
+ if not ctx.interleaved and not ctx.inplace:
108
+ do = do.clone()
109
+
110
+ dx = apply_rotary(
111
+ do,
112
+ cos,
113
+ sin,
114
+ seqlen_offsets=seqlen_offsets,
115
+ cu_seqlens=cu_seqlens,
116
+ max_seqlen=ctx.max_seqlen,
117
+ interleaved=ctx.interleaved,
118
+ inplace=ctx.inplace,
119
+ conjugate=True,
120
+ )
121
+ return dx, None, None, None, None, None, None, None
122
+
123
+
124
+ def apply_rotary_emb(
125
+ x,
126
+ cos,
127
+ sin,
128
+ interleaved=False,
129
+ inplace=False,
130
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
131
+ cu_seqlens: Optional[torch.Tensor] = None,
132
+ max_seqlen: Optional[int] = None,
133
+ ):
134
+ """
135
+ Arguments:
136
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
137
+ else (total_seqlen, nheads, headdim)
138
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
139
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
140
+ of 1st half and 2nd half (GPT-NeoX style).
141
+ inplace: if True, apply rotary embedding in-place.
142
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
143
+ Most commonly used in inference when we have KV cache.
144
+ cu_seqlens: (batch + 1,) or None
145
+ max_seqlen: int
146
+ Return:
147
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
148
+ else (total_seqlen, nheads, headdim)
149
+ rotary_dim must be <= headdim
150
+ Apply rotary embedding to the first rotary_dim of x.
151
+ """
152
+ return ApplyRotaryEmb.apply(
153
+ x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
154
+ )
155
+
156
+
157
+ # For backward compatibility
158
+ apply_rotary_emb_func = apply_rotary_emb
159
+
160
+
161
+ class ApplyRotaryEmbQKV_(torch.nn.Module):
162
+ @staticmethod
163
+ def forward(
164
+ qkv,
165
+ cos,
166
+ sin,
167
+ cos_k=None,
168
+ sin_k=None,
169
+ interleaved=False,
170
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
171
+ cu_seqlens: Optional[torch.Tensor] = None,
172
+ max_seqlen: Optional[int] = None,
173
+ use_flash_attn: bool = True,
174
+ ):
175
+ # batch, seqlen, three, nheads, headdim = qkv.shape
176
+ assert qkv.shape[-3] == 3
177
+ if cos_k is None and sin_k is None and qkv.is_contiguous():
178
+
179
+ if use_flash_attn:
180
+ # Call 1 kernel instead of 2 kernels
181
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
182
+ # dimensions, we get the same tensor
183
+ qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
184
+ # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
185
+ apply_rotary(
186
+ qk,
187
+ cos,
188
+ sin,
189
+ seqlen_offsets=seqlen_offsets,
190
+ interleaved=interleaved,
191
+ inplace=True,
192
+ cu_seqlens=cu_seqlens,
193
+ max_seqlen=max_seqlen,
194
+ )
195
+ else:
196
+ q_rot = apply_rotary_emb_torch(
197
+ qkv[:, :, 0],
198
+ cos,
199
+ sin,
200
+ interleaved=interleaved,
201
+ )
202
+ k_rot = apply_rotary_emb_torch(
203
+ qkv[:, :, 1],
204
+ cos,
205
+ sin,
206
+ interleaved=interleaved,
207
+ )
208
+ qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
209
+ else:
210
+ cos_k = cos if cos_k is None else cos_k
211
+ sin_k = sin if sin_k is None else sin_k
212
+ q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
213
+ apply_rotary(
214
+ q,
215
+ cos,
216
+ sin,
217
+ seqlen_offsets,
218
+ interleaved=interleaved,
219
+ inplace=True,
220
+ cu_seqlens=cu_seqlens,
221
+ max_seqlen=max_seqlen,
222
+ )
223
+ apply_rotary(
224
+ k,
225
+ cos_k,
226
+ sin_k,
227
+ seqlen_offsets,
228
+ interleaved=interleaved,
229
+ inplace=True,
230
+ cu_seqlens=cu_seqlens,
231
+ max_seqlen=max_seqlen,
232
+ )
233
+ # ctx.save_for_backward(cos, sin, cos_k, sin_k)  # no ctx in this nn.Module forward; backward is disabled below
234
+ # if isinstance(seqlen_offsets, int):
235
+ # ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
236
+ # ctx.seqlen_offsets = seqlen_offsets
237
+ # else:
238
+ # ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
239
+ # ctx.seqlen_offsets = None
240
+ # ctx.max_seqlen = max_seqlen
241
+ # ctx.interleaved = interleaved
242
+ return qkv
243
+
244
+ # @staticmethod
245
+ # def backward(ctx, dqkv):
246
+ # seqlen_offsets = ctx.seqlen_offsets
247
+ # if seqlen_offsets is None:
248
+ # cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
249
+ # else:
250
+ # cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
251
+ # if cos_k is None and sin_k is None and dqkv.is_contiguous():
252
+ # # Call 1 kernel instead of 2 kernels
253
+ # # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
254
+ # # dimensions, we get the same tensor
255
+ # dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
256
+ # apply_rotary(
257
+ # dqk,
258
+ # cos,
259
+ # sin,
260
+ # seqlen_offsets=seqlen_offsets,
261
+ # interleaved=ctx.interleaved,
262
+ # inplace=True,
263
+ # conjugate=True,
264
+ # cu_seqlens=cu_seqlens,
265
+ # max_seqlen=ctx.max_seqlen,
266
+ # )
267
+ # else:
268
+ # cos_k = cos if cos_k is None else cos_k
269
+ # sin_k = sin if sin_k is None else sin_k
270
+ # dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
271
+ # apply_rotary(
272
+ # dq,
273
+ # cos,
274
+ # sin,
275
+ # seqlen_offsets,
276
+ # interleaved=ctx.interleaved,
277
+ # inplace=True,
278
+ # conjugate=True,
279
+ # cu_seqlens=cu_seqlens,
280
+ # max_seqlen=ctx.max_seqlen,
281
+ # )
282
+ # apply_rotary(
283
+ # dk,
284
+ # cos_k,
285
+ # sin_k,
286
+ # seqlen_offsets,
287
+ # interleaved=ctx.interleaved,
288
+ # inplace=True,
289
+ # conjugate=True,
290
+ # cu_seqlens=cu_seqlens,
291
+ # max_seqlen=ctx.max_seqlen,
292
+ # )
293
+ # return dqkv, None, None, None, None, None, None, None, None, None
294
+
295
+
296
+ def apply_rotary_emb_qkv_(
297
+ qkv,
298
+ cos,
299
+ sin,
300
+ cos_k=None,
301
+ sin_k=None,
302
+ interleaved=False,
303
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
304
+ cu_seqlens: Optional[torch.Tensor] = None,
305
+ max_seqlen: Optional[int] = None,
306
+ use_flash_attn=True,
307
+ ):
308
+ """
309
+ Arguments:
310
+ qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
311
+ else (total_seqlen, 3, nheads, headdim)
312
+ cos, sin: (seqlen, rotary_dim / 2)
313
+ cos_k, sin_k: (seqlen, rotary_dim / 2), optional
314
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
315
+ 1st half and 2nd half (GPT-NeoX style).
316
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
317
+ Most commonly used in inference when we have KV cache.
318
+ cu_seqlens: (batch + 1,) or None
319
+ max_seqlen: int
320
+ Return:
321
+ qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
322
+ else (total_seqlen, 3, nheads, headdim)
323
+ rotary_dim must be <= headdim
324
+ Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
325
+ """
326
+ return ApplyRotaryEmbQKV_.forward(
327
+ qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
328
+ )
329
+
330
+
331
+ class ApplyRotaryEmbKV_(torch.autograd.Function):
332
+ @staticmethod
333
+ def forward(
334
+ ctx,
335
+ kv,
336
+ cos,
337
+ sin,
338
+ interleaved=False,
339
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
340
+ cu_seqlens: Optional[torch.Tensor] = None,
341
+ max_seqlen: Optional[int] = None,
342
+ ):
343
+ # batch, seqlen, two, nheads, headdim = kv.shape
344
+ assert kv.shape[-3] == 2
345
+ k = kv[..., 0, :, :]
346
+ apply_rotary(
347
+ k,
348
+ cos,
349
+ sin,
350
+ seqlen_offsets=seqlen_offsets,
351
+ interleaved=interleaved,
352
+ inplace=True,
353
+ cu_seqlens=cu_seqlens,
354
+ max_seqlen=max_seqlen,
355
+ )
356
+ if isinstance(seqlen_offsets, int):
357
+ ctx.save_for_backward(
358
+ cos, sin, cu_seqlens
359
+ ) # Can't save int with save_for_backward
360
+ ctx.seqlen_offsets = seqlen_offsets
361
+ else:
362
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
363
+ ctx.seqlen_offsets = None
364
+ ctx.max_seqlen = max_seqlen
365
+ ctx.interleaved = interleaved
366
+ return kv
367
+
368
+ @staticmethod
369
+ def backward(ctx, dkv):
370
+ seqlen_offsets = ctx.seqlen_offsets
371
+ if seqlen_offsets is None:
372
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
373
+ else:
374
+ cos, sin, cu_seqlens = ctx.saved_tensors
375
+ apply_rotary(
376
+ dkv[..., 0, :, :],
377
+ cos,
378
+ sin,
379
+ seqlen_offsets=seqlen_offsets,
380
+ interleaved=ctx.interleaved,
381
+ inplace=True,
382
+ conjugate=True,
383
+ cu_seqlens=cu_seqlens,
384
+ max_seqlen=ctx.max_seqlen,
385
+ )
386
+ return dkv, None, None, None, None, None, None
387
+
388
+
389
+ apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
390
+
391
+
392
+ def apply_rotary_emb_kv_(
393
+ kv,
394
+ cos,
395
+ sin,
396
+ interleaved=False,
397
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
398
+ cu_seqlens: Optional[torch.Tensor] = None,
399
+ max_seqlen: Optional[int] = None,
400
+ ):
401
+ """
402
+ Arguments:
403
+ kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
404
+ else (total_seqlen, 2, nheads, headdim)
405
+ cos, sin: (seqlen, rotary_dim / 2)
406
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
407
+ 1st half and 2nd half (GPT-NeoX style).
408
+ seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
409
+ Most commonly used in inference when we have KV cache.
410
+ cu_seqlens: (batch + 1,) or None
411
+ max_seqlen: int
412
+ Return:
413
+ kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
414
+ else (total_seqlen, 2, nheads, headdim)
415
+ rotary_dim must be <= headdim
416
+ Apply rotary embedding *inplace* to the first rotary_dim of K.
417
+ """
418
+ return ApplyRotaryEmbKV_.apply(
419
+ kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
420
+ )
421
+
422
+
423
+ class RotaryEmbedding(torch.nn.Module):
424
+ """
425
+ The rotary position embeddings from RoFormer_ (Su et. al).
426
+ A crucial insight from the method is that the query and keys are
427
+ transformed by rotation matrices which depend on the relative positions.
428
+ Other implementations are available in the Rotary Transformer repo_ and in
429
+ GPT-NeoX_, GPT-NeoX was an inspiration
430
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
431
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
432
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
433
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
434
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
435
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
436
+ """
437
+
438
+ def __init__(
439
+ self,
440
+ dim: int,
441
+ base=10000.0,
442
+ interleaved=False,
443
+ scale_base=None,
444
+ pos_idx_in_fp32=True,
445
+ device=None,
446
+ use_flash_attn=True,
447
+ ):
448
+ """
449
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
450
+ of 1st half and 2nd half (GPT-NeoX style).
451
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
452
+ otherwise they might be in lower precision.
453
+ This option was added because previously (before 2023-07-02), when we construct
454
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
455
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
456
+ self.inv_freq would be bf16, and the position indices are also in bf16.
457
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
458
+ embeddings for some positions will coincide.
459
+ To maintain compatibility with models previously trained in pure bf16,
460
+ we add this option.
461
+ """
462
+ super().__init__()
463
+ self.dim = dim
464
+ self._base = float(base)
465
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
466
+ self.use_flash_attn = use_flash_attn
467
+ # Generate and save the inverse frequency buffer (non trainable)
468
+ inv_freq = self._compute_inv_freq(device)
469
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
470
+ self.interleaved = interleaved
471
+ self.scale_base = scale_base
472
+ scale = (
473
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
474
+ / (1.4 * dim)
475
+ if scale_base is not None
476
+ else None
477
+ )
478
+ self.register_buffer("scale", scale, persistent=False)
479
+
480
+ self._seq_len_cached = 8194
481
+ self._cos_cached = None
482
+ self._sin_cached = None
483
+ # self._cos_k_cached = None
484
+ # self._sin_k_cached = None
485
+ self._update_cos_sin_cache(seqlen=self._seq_len_cached, device=device)
486
+
487
+
488
+ @property
489
+ def base(self):
490
+ return self._base
491
+
492
+ @base.setter
493
+ def base(self, new_base):
494
+ new_base = float(new_base)
495
+ if new_base > 0:
496
+ if self._base != new_base: # only update if the base value has changed
497
+ self._base = new_base
498
+ self._update_cos_sin_cache(
499
+ self._seq_len_cached,
500
+ device=self.inv_freq.device,
501
+ dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
502
+ rotary_base_changed=True,
503
+ )
504
+ else:
505
+ raise ValueError("Rotary base value must be positive")
506
+
507
+ def _compute_inv_freq(self, device=None):
508
+ return 1.0 / (
509
+ self.base
510
+ ** (
511
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
512
+ / self.dim
513
+ )
514
+ )
515
+
516
+ def _update_cos_sin_cache(
517
+ self, seqlen, device=None, dtype=None, rotary_base_changed=False
518
+ ):
519
+ # Reset the tables if the sequence length has changed,
520
+ # if we're on a new device (possibly due to tracing for instance),
521
+ # or if we're switching from inference mode to training
522
+ # or if the rotary base value was changed
523
+ if (
524
+ seqlen > self._seq_len_cached
525
+ or self._cos_cached is None
526
+ or self._cos_cached.device != device
527
+ or self._cos_cached.dtype != dtype
528
+ or (self.training and self._cos_cached.is_inference())
529
+ or rotary_base_changed
530
+ ):
531
+ if seqlen != self._seq_len_cached:
532
+ self._seq_len_cached = seqlen
533
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
534
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
535
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
536
+ if rotary_base_changed:
537
+ self.inv_freq = self._compute_inv_freq(device=device)
538
+ if self.pos_idx_in_fp32:
539
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
540
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
541
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
542
+ # cos & sin output to change significantly.
543
+ # We want to recompute self.inv_freq if it was not loaded in fp32
544
+ if self.inv_freq.dtype != torch.float32:
545
+ inv_freq = self._compute_inv_freq(device=device)
546
+ else:
547
+ inv_freq = self.inv_freq
548
+ else:
549
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
550
+ inv_freq = self.inv_freq
551
+
552
+ # Don't do einsum, it converts fp32 to fp16 under AMP
553
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
554
+ freqs = torch.outer(t, inv_freq)
555
+ if self.scale is None:
556
+ self._cos_cached = torch.cos(freqs).to(dtype)
557
+ self._sin_cached = torch.sin(freqs).to(dtype)
558
+ else:
559
+ power = (
560
+ torch.arange(
561
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
562
+ )
563
+ - seqlen // 2
564
+ ) / self.scale_base
565
+ scale = self.scale.to(device=power.device) ** rearrange(
566
+ power, "s -> s 1"
567
+ )
568
+ # We want the multiplication by scale to happen in fp32
569
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
570
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
571
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
572
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
573
+
574
+ def forward(
575
+ self,
576
+ qkv: torch.Tensor,
577
+ kv: Optional[torch.Tensor] = None,
578
+ seqlen_offset: Union[int, torch.Tensor] = 0,
579
+ cu_seqlens: Optional[torch.Tensor] = None,
580
+ max_seqlen: Optional[int] = None,
581
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
582
+ """
583
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is None,
584
+ else it's just q of shape (batch, seqlen, nheads, headdim)
585
+ kv: (batch, seqlen, 2, nheads, headdim)
586
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
587
+ Most commonly used in inference when we have KV cache.
588
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
589
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
590
+ Apply rotary embedding *inplace* to qkv and / or kv.
591
+ """
592
+ if cu_seqlens is not None:
593
+ assert max_seqlen is not None
594
+ seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
595
+ # if max_seqlen is not None:
596
+ # self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
597
+ # elif isinstance(seqlen_offset, int):
598
+ # self._update_cos_sin_cache(
599
+ # seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
600
+ # )
601
+ if kv is None:
602
+ if self.scale is None:
603
+ return apply_rotary_emb_qkv_(
604
+ qkv,
605
+ self._cos_cached,
606
+ self._sin_cached,
607
+ interleaved=self.interleaved,
608
+ seqlen_offsets=seqlen_offset,
609
+ cu_seqlens=cu_seqlens,
610
+ max_seqlen=max_seqlen,
611
+ use_flash_attn=self.use_flash_attn,
612
+ )
613
+ else:
614
+ return apply_rotary_emb_qkv_(
615
+ qkv,
616
+ self._cos_cached,
617
+ self._sin_cached,
618
+ self._cos_k_cached,
619
+ self._sin_k_cached,
620
+ interleaved=self.interleaved,
621
+ seqlen_offsets=seqlen_offset,
622
+ cu_seqlens=cu_seqlens,
623
+ max_seqlen=max_seqlen,
624
+ use_flash_attn=self.use_flash_attn,
625
+ )
626
+ else:
627
+ q = qkv
628
+ q = apply_rotary_emb_func(
629
+ q,
630
+ self._cos_cached,
631
+ self._sin_cached,
632
+ interleaved=self.interleaved,
633
+ inplace=True,
634
+ seqlen_offsets=seqlen_offset,
635
+ cu_seqlens=cu_seqlens,
636
+ max_seqlen=max_seqlen,
637
+ )
638
+ if self.scale is None:
639
+ kv = apply_rotary_emb_kv_(
640
+ kv,
641
+ self._cos_cached,
642
+ self._sin_cached,
643
+ interleaved=self.interleaved,
644
+ seqlen_offsets=seqlen_offset,
645
+ cu_seqlens=cu_seqlens,
646
+ max_seqlen=max_seqlen,
647
+ )
648
+ else:
649
+ kv = apply_rotary_emb_kv_(
650
+ kv,
651
+ self._cos_k_cached,
652
+ self._sin_k_cached,
653
+ interleaved=self.interleaved,
654
+ seqlen_offsets=seqlen_offset,
655
+ cu_seqlens=cu_seqlens,
656
+ max_seqlen=max_seqlen,
657
+ )
658
+ return q, kv
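A short sketch (not part of the uploaded code) of the relationship between the rotary base and the inverse frequencies computed in `_compute_inv_freq`; raising the base, as the `rotary_emb_base` setter allows, stretches the rotation periods:

import torch

dim = 64
for base in (10_000.0, 20_000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    print(base, inv_freq[-1].item())   # the smallest frequency shrinks as the base grows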
stochastic_depth.py ADDED
@@ -0,0 +1,97 @@
1
+ # Implementation modified from torchvision:
2
+ # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
3
+ #
4
+ # License:
5
+ # BSD 3-Clause License
6
+ #
7
+ # Copyright (c) Soumith Chintala 2016,
8
+ # All rights reserved.
9
+ #
10
+ # Redistribution and use in source and binary forms, with or without
11
+ # modification, are permitted provided that the following conditions are met:
12
+ #
13
+ # * Redistributions of source code must retain the above copyright notice, this
14
+ # list of conditions and the following disclaimer.
15
+ #
16
+ # * Redistributions in binary form must reproduce the above copyright notice,
17
+ # this list of conditions and the following disclaimer in the documentation
18
+ # and/or other materials provided with the distribution.
19
+ #
20
+ # * Neither the name of the copyright holder nor the names of its
21
+ # contributors may be used to endorse or promote products derived from
22
+ # this software without specific prior written permission.
23
+ #
24
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
+
35
+ import torch
36
+ import torch.fx
37
+ from torch import Tensor, nn
38
+
39
+
40
+ def stochastic_depth(
41
+ input: Tensor, p: float, mode: str, training: bool = True
42
+ ) -> Tensor:
43
+ """
44
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
45
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
46
+ branches of residual architectures.
47
+
48
+ Args:
49
+ input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
50
+ being its batch i.e. a batch with ``N`` rows.
51
+ p (float): probability of the input to be zeroed.
52
+ mode (str): ``"batch"`` or ``"row"``.
53
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
54
+ randomly selected rows from the batch.
55
+ training: apply stochastic depth if ``True``. Default: ``True``
56
+
57
+ Returns:
58
+ Tensor[N, ...]: The randomly zeroed tensor.
59
+ """
60
+ if p < 0.0 or p > 1.0:
61
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
62
+ if mode not in ["batch", "row"]:
63
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
64
+ if not training or p == 0.0:
65
+ return input
66
+
67
+ survival_rate = 1.0 - p
68
+ if mode == "row":
69
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
70
+ else:
71
+ size = [1] * input.ndim
72
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
73
+ noise = noise.bernoulli_(survival_rate)
74
+ if survival_rate > 0.0:
75
+ noise.div_(survival_rate)
76
+ return input * noise
77
+
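A quick sketch (not part of the uploaded code, assuming it runs in this module) of what row-mode stochastic depth does to a batch during training:

import torch

torch.manual_seed(0)
x = torch.ones(4, 3)
out = stochastic_depth(x, p=0.5, mode="row", training=True)
print(out)   # some rows are zeroed, surviving rows are scaled by 1 / (1 - p) = 2.0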
78
+
79
+ torch.fx.wrap("stochastic_depth")
80
+
81
+
82
+ class StochasticDepth(nn.Module):
83
+ """
84
+ See :func:`stochastic_depth`.
85
+ """
86
+
87
+ def __init__(self, p: float, mode: str) -> None:
88
+ super().__init__()
89
+ self.p = p
90
+ self.mode = mode
91
+
92
+ def forward(self, input: Tensor) -> Tensor:
93
+ return stochastic_depth(input, self.p, self.mode, self.training)
94
+
95
+ def __repr__(self) -> str:
96
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
97
+ return s
xlm_padding.py ADDED
@@ -0,0 +1,229 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"),
22
+ 0,
23
+ repeat(indices, "z -> z d", d=second_dim),
24
+ ).reshape(-1, *other_shape)
25
+
26
+ @staticmethod
27
+ def backward(ctx, grad_output):
28
+ (indices,) = ctx.saved_tensors
29
+ assert grad_output.ndim >= 2
30
+ other_shape = grad_output.shape[1:]
31
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
32
+ grad_input = torch.zeros(
33
+ [ctx.first_axis_dim, grad_output.shape[1]],
34
+ device=grad_output.device,
35
+ dtype=grad_output.dtype,
36
+ )
37
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
38
+ # grad_input[indices] = grad_output
39
+ grad_input.scatter_(
40
+ 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
41
+ )
42
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
43
+
44
+
45
+ index_first_axis = IndexFirstAxis.apply
46
+
47
+
48
+ class IndexPutFirstAxis(torch.autograd.Function):
49
+ @staticmethod
50
+ def forward(ctx, values, indices, first_axis_dim):
51
+ ctx.save_for_backward(indices)
52
+ assert indices.ndim == 1
53
+ assert values.ndim >= 2
54
+ output = torch.zeros(
55
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
56
+ )
57
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
58
+ output[indices] = values
59
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
60
+ return output
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ (indices,) = ctx.saved_tensors
65
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
66
+ grad_values = grad_output[indices]
67
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
68
+ return grad_values, None, None
69
+
70
+
71
+ index_put_first_axis = IndexPutFirstAxis.apply
72
+
73
+
74
+ class IndexFirstAxisResidual(torch.autograd.Function):
75
+ @staticmethod
76
+ def forward(ctx, input, indices):
77
+ ctx.save_for_backward(indices)
78
+ assert input.ndim >= 2
79
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
80
+ second_dim = other_shape.numel()
81
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
82
+ output = input[indices]
83
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
84
+ # memory format to channel_first. In other words, input might not be contiguous.
85
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
86
+ return output, input.detach()
87
+
88
+ @staticmethod
89
+ def backward(ctx, grad_output, grad_residual):
90
+ (indices,) = ctx.saved_tensors
91
+ assert grad_output.ndim >= 2
92
+ other_shape = grad_output.shape[1:]
93
+ assert grad_residual.shape[1:] == other_shape
94
+ grad_input = grad_residual
95
+ # grad_input[indices] += grad_output
96
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
97
+ indices = indices.expand_as(grad_output)
98
+ grad_input.scatter_add_(0, indices, grad_output)
99
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
100
+
101
+
102
+ index_first_axis_residual = IndexFirstAxisResidual.apply
103
+
104
+
105
+ def unpad_input(hidden_states, attention_mask):
106
+ """
107
+ Arguments:
108
+ hidden_states: (batch, seqlen, ...)
109
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
110
+ Return:
111
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
112
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
113
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
114
+ max_seqlen_in_batch: int
115
+ """
116
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
117
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
118
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
119
+ cu_seqlens = F.pad(
120
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
121
+ )
122
+
123
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
124
+ # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
125
+ # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
126
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
127
+ # so we write custom forward and backward to make it a bit faster.
128
+ return (
129
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
130
+ indices,
131
+ cu_seqlens,
132
+ max_seqlen_in_batch,
133
+ )
134
+
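A self-contained sketch (not part of the uploaded code) of how the cumulative sequence lengths returned by `unpad_input` are built from an attention mask containing sequences of lengths 3 and 5:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                   # tensor([3, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                                                         # tensor([0, 3, 8])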
135
+
136
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
137
+ """
138
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
139
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
140
+
141
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
142
+ ```
143
+ [
144
+ [2, 3, 0, 0, 0, 0],
145
+ [3, 2, 0, 0, 0, 0],
146
+ [6, 0, 0, 0, 0, 0]
147
+ ]
148
+ ```
149
+ , which refers to the 3D-attention mask:
150
+ ```
151
+ [
152
+ [
153
+ [1, 0, 0, 0, 0, 0],
154
+ [1, 1, 0, 0, 0, 0],
155
+ [0, 0, 1, 0, 0, 0],
156
+ [0, 0, 1, 1, 0, 0],
157
+ [0, 0, 1, 1, 1, 0],
158
+ [0, 0, 0, 0, 0, 1]
159
+ ],
160
+ [
161
+ [1, 0, 0, 0, 0, 0],
162
+ [1, 1, 0, 0, 0, 0],
163
+ [1, 1, 1, 0, 0, 0],
164
+ [0, 0, 0, 1, 0, 0],
165
+ [0, 0, 0, 1, 1, 0],
166
+ [0, 0, 0, 0, 0, 1]
167
+ ],
168
+ [
169
+ [1, 0, 0, 0, 0, 0],
170
+ [1, 1, 0, 0, 0, 0],
171
+ [1, 1, 1, 0, 0, 0],
172
+ [1, 1, 1, 1, 0, 0],
173
+ [1, 1, 1, 1, 1, 0],
174
+ [1, 1, 1, 1, 1, 1]
175
+ ]
176
+ ]
177
+ ```.
178
+
179
+ Arguments:
180
+ hidden_states: (batch, seqlen, ...)
181
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
182
+ Return:
183
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
184
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
185
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
186
+ max_seqlen_in_batch: int
187
+ """
188
+ length = attention_mask_in_length.sum(dim=-1)
189
+ seqlen = attention_mask_in_length.size(-1)
190
+ attention_mask_2d = torch.arange(
191
+ seqlen, device=length.device, dtype=length.dtype
192
+ ).expand(len(length), seqlen) < length.unsqueeze(1)
193
+ real_indices_idx = torch.nonzero(
194
+ attention_mask_in_length.flatten(), as_tuple=False
195
+ ).flatten()
196
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
197
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
198
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
199
+ cu_seqlens = F.pad(
200
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
201
+ )
202
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
203
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
204
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
205
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
206
+ # so we write custom forward and backward to make it a bit faster.
207
+ return (
208
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
209
+ indices,
210
+ cu_seqlens,
211
+ max_seqlen_in_batch,
212
+ )
213
+
214
+
215
+ def pad_input(hidden_states, indices, batch, seqlen):
216
+ """
217
+ Arguments:
218
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
219
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
220
+ batch: int, batch size for the padded sequence.
221
+ seqlen: int, maximum sequence length for the padded sequence.
222
+ Return:
223
+ hidden_states: (batch, seqlen, ...)
224
+ """
225
+ dim = hidden_states.shape[-1]
226
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
227
+ # output[indices] = hidden_states
228
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
229
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
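A round-trip sketch (not part of the uploaded code, assuming it runs in this module so `unpad_input` and `pad_input` are in scope): padding the unpadded tokens back recovers the original values at every non-masked position:

import torch

hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])

unpadded, indices, cu_seqlens, max_len = unpad_input(hidden, mask)
repadded = pad_input(unpadded, indices, batch=2, seqlen=5)
assert torch.equal(repadded[mask.bool()], hidden[mask.bool()])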