amazingvince committed on
Commit d572741
1 Parent(s): f015c86

Update modeling_custom_seq2seq_llm.py

Files changed (1)
  1. modeling_custom_seq2seq_llm.py +1090 -16
modeling_custom_seq2seq_llm.py CHANGED
@@ -3,13 +3,1033 @@ import torch.nn as nn
3
  from torch.nn import CrossEntropyLoss
4
  from transformers.modeling_outputs import Seq2SeqLMOutput
5
  from transformers.activations import ACT2FN
6
- from flash_atten import MHA # Import the MHA class from the provided implementation
7
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
8
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
9
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
10
  from transformers import PreTrainedModel, PretrainedConfig
11
 
15
  class RMSNorm(nn.Module):
@@ -106,6 +1126,24 @@ class CustomSeq2SeqLLM(PreTrainedModel):
106
 
107
  def get_output_embeddings(self):
108
  return self.lm_head
109
 
110
  def forward(
111
  self,
@@ -166,6 +1204,57 @@ class CustomSeq2SeqLLM(PreTrainedModel):
166
  shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
167
  shifted_input_ids[..., 0] = self.config.pad_token_id
168
  return shifted_input_ids
169
 
170
  class CustomEncoder(nn.Module):
171
  def __init__(self, config):
@@ -260,18 +1349,3 @@ class DecoderLayer(nn.Module):
260
 
261
  return hidden_states
262
 
263
- class FeedForward(nn.Module):
264
- def __init__(self, config):
265
- super().__init__()
266
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
267
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
268
- self.act = ACT2FN[config.hidden_act]
269
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
270
-
271
- def forward(self, x):
272
- x = self.fc1(x)
273
- x = self.act(x)
274
- x = self.dropout(x)
275
- x = self.fc2(x)
276
- x = self.dropout(x)
277
- return x
 
3
  from torch.nn import CrossEntropyLoss
4
  from transformers.modeling_outputs import Seq2SeqLMOutput
5
  from transformers.activations import ACT2FN
 
6
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
7
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
8
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
9
  from transformers import PreTrainedModel, PretrainedConfig
10
 
11
 
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+ import math
15
+ import os
16
+ from functools import partial
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from einops import rearrange, repeat
21
+
22
+ from flash_attn.utils.distributed import get_dim_for_local_rank
23
+
24
+ try:
25
+ from flash_attn import (
26
+ flash_attn_kvpacked_func,
27
+ flash_attn_qkvpacked_func,
28
+ flash_attn_varlen_kvpacked_func,
29
+ flash_attn_varlen_qkvpacked_func,
30
+ flash_attn_with_kvcache,
31
+ )
32
+ except ImportError:
33
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
34
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
35
+ flash_attn_with_kvcache = None
36
+
37
+ try:
38
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
39
+ except ImportError:
40
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
41
+
42
+ try:
43
+ from flash_attn.layers.rotary import RotaryEmbedding
44
+ except ImportError:
45
+ RotaryEmbedding = None
46
+
47
+
48
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
49
+ def get_alibi_slopes(nheads):
50
+ def get_slopes_power_of_2(nheads):
51
+ start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
52
+ ratio = start
53
+ return [start * ratio**i for i in range(nheads)]
54
+
55
+ if math.log2(nheads).is_integer():
56
+ return get_slopes_power_of_2(nheads)
57
+ else:
58
+ closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
59
+ return (
60
+ get_slopes_power_of_2(closest_power_of_2)
61
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
62
+ )
63
+
64
+
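Note (illustrative, not part of the diff): get_alibi_slopes returns one slope per head, and for a power-of-two head count the slopes form the geometric series 2^-1, 2^-2, ... A quick sanity check:

slopes = get_alibi_slopes(8)
assert len(slopes) == 8
# power-of-two head count: slopes run from 2**-1 down to 2**-8
assert abs(slopes[0] - 0.5) < 1e-9 and abs(slopes[-1] - 2 ** -8) < 1e-9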
65
+ class FlashSelfAttention(nn.Module):
66
+ """Implement the scaled dot product attention with softmax.
67
+ Arguments
68
+ ---------
69
+ softmax_scale: The temperature to use for the softmax attention.
70
+ (default: 1/sqrt(d_keys) where d_keys is computed at
71
+ runtime)
72
+ attention_dropout: The dropout rate to apply to the attention
73
+ (default: 0.0)
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ causal=False,
79
+ softmax_scale=None,
80
+ attention_dropout=0.0,
81
+ window_size=(-1, -1),
82
+ alibi_slopes=None,
83
+ deterministic=False,
84
+ ):
85
+ super().__init__()
86
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
87
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
88
+ self.causal = causal
89
+ self.softmax_scale = softmax_scale
90
+ self.drop = nn.Dropout(attention_dropout)
91
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
92
+ self.window_size = window_size
93
+ self.deterministic = deterministic
94
+
95
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
96
+ """Implements the multihead softmax attention.
97
+ Arguments
98
+ ---------
99
+ qkv: The tensor containing the query, key, and value.
100
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
101
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
102
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
103
+ causal: if passed, will override self.causal
104
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
105
+ of the sequences in the batch, used to index into qkv.
106
+ max_seqlen: int. Maximum sequence length in the batch.
107
+ Returns:
108
+ --------
109
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
110
+ else (B, S, H, D).
111
+ """
112
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
113
+ assert qkv.is_cuda
114
+ causal = self.causal if causal is None else causal
115
+ unpadded = cu_seqlens is not None
116
+ if self.alibi_slopes is not None:
117
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
118
+ if unpadded:
119
+ assert cu_seqlens.dtype == torch.int32
120
+ assert max_seqlen is not None
121
+ assert isinstance(max_seqlen, int)
122
+ return flash_attn_varlen_qkvpacked_func(
123
+ qkv,
124
+ cu_seqlens,
125
+ max_seqlen,
126
+ self.drop.p if self.training else 0.0,
127
+ softmax_scale=self.softmax_scale,
128
+ causal=causal,
129
+ alibi_slopes=self.alibi_slopes,
130
+ window_size=self.window_size,
131
+ deterministic=self.deterministic,
132
+ )
133
+ else:
134
+ return flash_attn_qkvpacked_func(
135
+ qkv,
136
+ self.drop.p if self.training else 0.0,
137
+ softmax_scale=self.softmax_scale,
138
+ causal=causal,
139
+ alibi_slopes=self.alibi_slopes,
140
+ window_size=self.window_size,
141
+ deterministic=self.deterministic,
142
+ )
143
+
144
+
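Usage sketch for FlashSelfAttention (illustrative only, not part of the commit; assumes flash-attn is installed and a CUDA device is available), following the packed (B, S, 3, H, D) layout described in the docstring:

import torch
attn = FlashSelfAttention(causal=True, attention_dropout=0.0)
qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.float16, device="cuda")  # (B, S, 3, H, D)
out = attn(qkv)  # (B, S, H, D)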
145
+ class FlashCrossAttention(nn.Module):
146
+ """Implement the scaled dot product attention with softmax.
147
+ Arguments
148
+ ---------
149
+ softmax_scale: The temperature to use for the softmax attention.
150
+ (default: 1/sqrt(d_keys) where d_keys is computed at
151
+ runtime)
152
+ attention_dropout: The dropout rate to apply to the attention
153
+ (default: 0.0)
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ causal=False,
159
+ softmax_scale=None,
160
+ attention_dropout=0.0,
161
+ alibi_slopes=None,
162
+ window_size=(-1, -1),
163
+ deterministic=False,
164
+ ):
165
+ super().__init__()
166
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
167
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
168
+ self.causal = causal
169
+ self.softmax_scale = softmax_scale
170
+ self.drop = nn.Dropout(attention_dropout)
171
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
172
+ self.window_size = window_size
173
+ self.deterministic = deterministic
174
+
175
+ def forward(
176
+ self,
177
+ q,
178
+ kv,
179
+ causal=None,
180
+ cu_seqlens=None,
181
+ max_seqlen=None,
182
+ cu_seqlens_k=None,
183
+ max_seqlen_k=None,
184
+ ):
185
+ """Implements the multihead softmax attention.
186
+ Arguments
187
+ ---------
188
+ q: The tensor containing the query. (B, Sq, H, D)
189
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
190
+ causal: if passed, will override self.causal
191
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
192
+ of the sequences in the batch, used to index into q.
193
+ max_seqlen: int. Maximum sequence length in the batch of q.
194
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
195
+ of the sequences in the batch, used to index into kv.
196
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
197
+ """
198
+ assert q.dtype in [torch.float16, torch.bfloat16]
199
+ assert q.is_cuda and kv.is_cuda
200
+ causal = self.causal if causal is None else causal
201
+ unpadded = cu_seqlens is not None
202
+ if self.alibi_slopes is not None:
203
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
204
+ if unpadded:
205
+ assert cu_seqlens.dtype == torch.int32
206
+ assert max_seqlen is not None
207
+ assert isinstance(max_seqlen, int)
208
+ assert cu_seqlens_k is not None
209
+ assert cu_seqlens_k.dtype == torch.int32
210
+ assert max_seqlen_k is not None
211
+ assert isinstance(max_seqlen_k, int)
212
+ return flash_attn_varlen_kvpacked_func(
213
+ q,
214
+ kv,
215
+ cu_seqlens,
216
+ cu_seqlens_k,
217
+ max_seqlen,
218
+ max_seqlen_k,
219
+ self.drop.p if self.training else 0.0,
220
+ softmax_scale=self.softmax_scale,
221
+ causal=causal,
222
+ alibi_slopes=self.alibi_slopes,
223
+ window_size=self.window_size,
224
+ deterministic=self.deterministic,
225
+ )
226
+ else:
227
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
228
+ seqlen_k = kv.shape[1]
229
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
230
+ return flash_attn_kvpacked_func(
231
+ q,
232
+ kv,
233
+ self.drop.p if self.training else 0.0,
234
+ causal=causal,
235
+ softmax_scale=self.softmax_scale,
236
+ alibi_slopes=self.alibi_slopes,
237
+ window_size=self.window_size,
238
+ deterministic=self.deterministic,
239
+ )
240
+
241
+
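The same idea for FlashCrossAttention (illustrative only; assumes flash-attn and CUDA): q is (B, Sq, H, D) and kv is packed as (B, Sk, 2, H_k, D):

import torch
xattn = FlashCrossAttention()
q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")      # (B, Sq, H, D)
kv = torch.randn(1, 32, 2, 8, 64, dtype=torch.float16, device="cuda")  # (B, Sk, 2, H_k, D)
out = xattn(q, kv)  # (B, Sq, H, D)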
242
+ class SelfAttention(nn.Module):
243
+ """Implement the scaled dot product attention with softmax.
244
+ Arguments
245
+ ---------
246
+ softmax_scale: The temperature to use for the softmax attention.
247
+ (default: 1/sqrt(d_keys) where d_keys is computed at
248
+ runtime)
249
+ attention_dropout: The dropout rate to apply to the attention
250
+ (default: 0.0)
251
+ """
252
+
253
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
254
+ super().__init__()
255
+ self.causal = causal
256
+ self.softmax_scale = softmax_scale
257
+ self.drop = nn.Dropout(attention_dropout)
258
+
259
+ def forward(self, qkv, causal=None, key_padding_mask=None):
260
+ """Implements the multihead softmax attention.
261
+ Arguments
262
+ ---------
263
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
264
+ causal: if passed, will override self.causal
265
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
266
+ False means to mask out. (B, S)
267
+ """
268
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
269
+ causal = self.causal if causal is None else causal
270
+ q, k, v = qkv.unbind(dim=2)
271
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
272
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
273
+ if key_padding_mask is not None:
274
+ padding_mask = torch.full(
275
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
276
+ )
277
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
278
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
279
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
280
+ if causal:
281
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
282
+ # So we have to construct the mask in float
283
+ causal_mask = torch.triu(
284
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
285
+ )
286
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
287
+ scores = scores + causal_mask.to(dtype=scores.dtype)
288
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
289
+ attention_drop = self.drop(attention)
290
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
291
+ return output
292
+
293
+
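The plain SelfAttention fallback keeps the same packed-QKV contract but runs in standard PyTorch, so it also works on CPU and in float32. A short sketch (illustrative only):

import torch
attn = SelfAttention(causal=False, attention_dropout=0.0)
qkv = torch.randn(2, 16, 3, 4, 32)          # (B, S, 3, H, D)
keep = torch.ones(2, 16, dtype=torch.bool)  # key_padding_mask: True = keep
out = attn(qkv, key_padding_mask=keep)      # (B, S, H, D)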
294
+ class CrossAttention(nn.Module):
295
+ """Implement the scaled dot product attention with softmax.
296
+ Arguments
297
+ ---------
298
+ softmax_scale: The temperature to use for the softmax attention.
299
+ (default: 1/sqrt(d_keys) where d_keys is computed at
300
+ runtime)
301
+ attention_dropout: The dropout rate to apply to the attention
302
+ (default: 0.0)
303
+ """
304
+
305
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
306
+ super().__init__()
307
+ self.causal = causal
308
+ self.softmax_scale = softmax_scale
309
+ self.drop = nn.Dropout(attention_dropout)
310
+
311
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
312
+ """Implements the multihead softmax attention.
313
+ Arguments
314
+ ---------
315
+ q: The tensor containing the query. (B, Sq, H, D)
316
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
317
+ causal: if passed, will override self.causal
318
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
319
+ False means to mask out. (B, Sk)
320
+ """
321
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
322
+ causal = self.causal if causal is None else causal
323
+ seqlen_k = kv.shape[1]
324
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
325
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
326
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
327
+ k, v = kv.unbind(dim=2)
328
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
329
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
330
+ if key_padding_mask is not None:
331
+ padding_mask = torch.full(
332
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
333
+ )
334
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
335
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
336
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
337
+ if causal:
338
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
339
+ row_idx = rearrange(
340
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
341
+ )
342
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
343
+ sk = (
344
+ seqlen_k
345
+ if key_padding_mask is None
346
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
347
+ )
348
+ causal_mask = col_idx > row_idx + sk - seqlen_q
349
+ scores = scores.masked_fill(causal_mask, -10000.0)
350
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
351
+ attention_drop = self.drop(attention)
352
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
353
+ return output
354
+
355
+
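The reference CrossAttention also covers MQA/GQA by repeating kv heads when H_k divides H; an illustrative call (not part of the commit):

import torch
xattn = CrossAttention()
q = torch.randn(2, 10, 8, 64)      # (B, Sq, H=8, D)
kv = torch.randn(2, 20, 2, 2, 64)  # (B, Sk, 2, H_k=2, D), repeated to 8 heads internally
out = xattn(q, kv)                 # (2, 10, 8, 64)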
356
+ class LinearResidual(nn.Linear):
357
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
358
+
359
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
360
+ return super().forward(input), input
361
+
362
+
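LinearResidual simply returns (output, input) so a caller can fuse the residual add later, for example:

import torch
proj = LinearResidual(16, 16)
y, residual = proj(torch.randn(4, 16))  # y = Wx + b, residual = the untouched input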
363
+ def _update_kv_cache(kv, inference_params, layer_idx):
364
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
365
+ # Pre-allocate memory for key-values for inference.
366
+ num_heads, head_dim = kv.shape[-2:]
367
+ if layer_idx not in inference_params.key_value_memory_dict:
368
+ kv_cache = torch.empty(
369
+ inference_params.max_batch_size,
370
+ inference_params.max_seqlen,
371
+ 2,
372
+ num_heads,
373
+ head_dim,
374
+ dtype=kv.dtype,
375
+ device=kv.device,
376
+ )
377
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
378
+ else:
379
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
380
+ # Adjust key and value for inference
381
+ batch_start = inference_params.batch_size_offset
382
+ batch_end = batch_start + kv.shape[0]
383
+ sequence_start = inference_params.seqlen_offset
384
+ sequence_end = sequence_start + kv.shape[1]
385
+ assert batch_end <= kv_cache.shape[0]
386
+ assert sequence_end <= kv_cache.shape[1]
387
+ assert kv_cache is not None
388
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
389
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
390
+
391
+
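_update_kv_cache only needs an object exposing max_batch_size, max_seqlen, key_value_memory_dict, batch_size_offset and seqlen_offset. A minimal sketch with a stand-in container (flash-attn ships its own InferenceParams; this is purely illustrative):

import torch
from types import SimpleNamespace

params = SimpleNamespace(
    max_batch_size=2, max_seqlen=64,
    key_value_memory_dict={}, batch_size_offset=0, seqlen_offset=0,
)
kv_step = torch.zeros(2, 1, 2, 8, 64)  # one new token: (B, 1, 2, H_kv, D)
cache_view = _update_kv_cache(kv_step, params, layer_idx=0)
assert cache_view.shape == (2, 1, 2, 8, 64)  # view covering the tokens written so far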
392
+ class MHA(nn.Module):
393
+ """Multi-head self-attention and cross-attention"""
394
+
395
+ def __init__(
396
+ self,
397
+ embed_dim,
398
+ num_heads,
399
+ num_heads_kv=None,
400
+ cross_attn=False,
401
+ qkv_proj_bias=True,
402
+ out_proj_bias=True,
403
+ dropout=0.0,
404
+ softmax_scale=None,
405
+ causal=False,
406
+ layer_idx=None,
407
+ dwconv=False,
408
+ rotary_emb_dim=0,
409
+ rotary_emb_base=10000.0,
410
+ rotary_emb_scale_base=None,
411
+ rotary_emb_interleaved=False,
412
+ use_alibi=False,
413
+ window_size=(-1, -1),
414
+ fused_bias_fc=False,
415
+ use_flash_attn=False,
416
+ return_residual=False,
417
+ checkpointing=False,
418
+ device=None,
419
+ dtype=None,
420
+ ) -> None:
421
+ """
422
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
423
+ return_residual: whether to return the input x along with the output. This is for
424
+ performance reasons: for post-norm architectures, returning the input allows us
425
+ to fuse the backward of nn.Linear with the residual connection.
426
+ """
427
+ factory_kwargs = {"device": device, "dtype": dtype}
428
+ super().__init__()
429
+ self.embed_dim = embed_dim
430
+ self.cross_attn = cross_attn
431
+ self.causal = causal
432
+ self.layer_idx = layer_idx
433
+ self.dwconv = dwconv
434
+ self.rotary_emb_dim = rotary_emb_dim
435
+ self.use_flash_attn = use_flash_attn
436
+ self.return_residual = return_residual
437
+ self.checkpointing = checkpointing
438
+ if use_alibi:
439
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
440
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
441
+ else:
442
+ alibi_slopes = None
443
+ if window_size != (-1, -1):
444
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
445
+
446
+ self.num_heads = num_heads
447
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
448
+ assert (
449
+ self.num_heads % self.num_heads_kv == 0
450
+ ), "num_heads must be divisible by num_heads_kv"
451
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
452
+ self.head_dim = self.embed_dim // num_heads
453
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
454
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
455
+
456
+ if self.rotary_emb_dim > 0:
457
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
458
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
459
+ self.rotary_emb = RotaryEmbedding(
460
+ self.rotary_emb_dim,
461
+ base=rotary_emb_base,
462
+ scale_base=rotary_emb_scale_base,
463
+ interleaved=rotary_emb_interleaved,
464
+ device=device,
465
+ )
466
+
467
+ if fused_bias_fc and FusedDense is None:
468
+ raise ImportError("fused_dense is not installed")
469
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
470
+ linear_resid_cls = (
471
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
472
+ )
473
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
474
+ inner_attn_cls = (
475
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
476
+ if use_flash_attn
477
+ else SelfAttention
478
+ )
479
+ inner_cross_attn_cls = (
480
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
481
+ if use_flash_attn
482
+ else CrossAttention
483
+ )
484
+ if not self.cross_attn:
485
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
486
+ else:
487
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
488
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
489
+ if self.dwconv:
490
+ if self.num_heads_kv == self.num_heads:
491
+ self.dwconv_qkv = nn.Conv1d(
492
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
493
+ )
494
+ else:
495
+ self.dwconv_q = nn.Conv1d(
496
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
497
+ )
498
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
499
+ self.inner_attn = inner_attn_cls(
500
+ causal=causal,
501
+ softmax_scale=softmax_scale,
502
+ attention_dropout=dropout,
503
+ )
504
+ self.inner_cross_attn = inner_cross_attn_cls(
505
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
506
+ )
507
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
508
+
509
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
510
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
511
+ device = self.out_proj.weight.device
512
+ return torch.empty(
513
+ batch_size,
514
+ max_seqlen,
515
+ 2,
516
+ self.num_heads_kv,
517
+ self.head_dim,
518
+ dtype=dtype,
519
+ device=device,
520
+ )
521
+
522
+ def _update_kv_cache(self, kv, inference_params):
523
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
524
+ assert not self.dwconv, "Generation does not support dwconv yet"
525
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
526
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
527
+
528
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
529
+ """
530
+ Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
531
+ q: (batch_size, seqlen_q, nheads, head_dim)
532
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
533
+ """
534
+ assert inference_params is not None and inference_params.seqlen_offset > 0
535
+ assert self.use_flash_attn
536
+ if self.rotary_emb_dim > 0:
537
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
538
+ self.rotary_emb._update_cos_sin_cache(
539
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
540
+ )
541
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
542
+ else:
543
+ rotary_cos, rotary_sin = None, None
544
+ batch = q.shape[0]
545
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
546
+ cache_seqlens = (
547
+ inference_params.lengths_per_sample[:batch]
548
+ if inference_params.lengths_per_sample is not None
549
+ else inference_params.seqlen_offset
550
+ )
551
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
552
+ context = flash_attn_with_kvcache(
553
+ q,
554
+ kv_cache[:, :, 0],
555
+ kv_cache[:, :, 1],
556
+ kv[:, :, 0],
557
+ kv[:, :, 1],
558
+ rotary_cos=rotary_cos,
559
+ rotary_sin=rotary_sin,
560
+ cache_seqlens=cache_seqlens,
561
+ softmax_scale=self.inner_cross_attn.softmax_scale,
562
+ causal=self.inner_cross_attn.causal,
563
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
564
+ alibi_slopes=alibi_slopes,
565
+ )
566
+ return context
567
+
568
+ def _update_kvcache_attention(self, q, kv, inference_params):
569
+ """Write kv to inference_params, then do attention"""
570
+ if (
571
+ inference_params.seqlen_offset == 0
572
+ or flash_attn_with_kvcache is None
573
+ or not self.use_flash_attn
574
+ ):
575
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
576
+ kv = self._update_kv_cache(kv, inference_params)
577
+ return self.inner_cross_attn(q, kv)
578
+ else:
579
+ batch = q.shape[0]
580
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
581
+ cache_seqlens = (
582
+ inference_params.lengths_per_sample[:batch]
583
+ if inference_params.lengths_per_sample is not None
584
+ else inference_params.seqlen_offset
585
+ )
586
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
587
+ return flash_attn_with_kvcache(
588
+ q,
589
+ kv_cache[:, :, 0],
590
+ kv_cache[:, :, 1],
591
+ kv[:, :, 0],
592
+ kv[:, :, 1],
593
+ cache_seqlens=cache_seqlens,
594
+ softmax_scale=self.inner_cross_attn.softmax_scale,
595
+ causal=self.inner_cross_attn.causal,
596
+ alibi_slopes=alibi_slopes,
597
+ )
598
+
599
+ def forward(
600
+ self,
601
+ x,
602
+ x_kv=None,
603
+ key_padding_mask=None,
604
+ cu_seqlens=None,
605
+ max_seqlen=None,
606
+ mixer_subset=None,
607
+ inference_params=None,
608
+ **kwargs,
609
+ ):
610
+ """
611
+ Arguments:
612
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
613
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
614
+ is the sum of the sequence lengths in the batch.
615
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
616
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
617
+ of the sequences in the batch, used to index into x. Only applicable when using
618
+ FlashAttention.
619
+ max_seqlen: int. Maximum sequence length in the batch.
620
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
621
+ (batch, seqlen). Only applicable when not using FlashAttention.
622
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
623
+ before applying the query projection. Useful for e.g., ViT where we only care
624
+ about the CLS token in the last layer.
625
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
626
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
627
+ """
628
+ if cu_seqlens is not None:
629
+ assert max_seqlen is not None
630
+ assert key_padding_mask is None
631
+ assert self.use_flash_attn
632
+ assert not self.dwconv
633
+ assert self.rotary_emb_dim == 0
634
+ if key_padding_mask is not None:
635
+ assert cu_seqlens is None
636
+ assert max_seqlen is None
637
+ assert not self.use_flash_attn
638
+ if inference_params is not None:
639
+ assert key_padding_mask is None
640
+ assert cu_seqlens is None and max_seqlen is None
641
+ assert not self.dwconv
642
+
643
+ kwargs = (
644
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
645
+ if self.use_flash_attn
646
+ else {"key_padding_mask": key_padding_mask, **kwargs}
647
+ )
648
+ seqlen_offset = (
649
+ 0
650
+ if inference_params is None
651
+ else (
652
+ inference_params.lengths_per_sample
653
+ if inference_params.lengths_per_sample is not None
654
+ else inference_params.seqlen_offset
655
+ )
656
+ )
657
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
658
+ batch, seqlen = x.shape[:2]
659
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
660
+ assert x_kv is None and mixer_subset is None
661
+ if not self.return_residual:
662
+ qkv = self.Wqkv(x)
663
+ else:
664
+ qkv, x = self.Wqkv(x)
665
+ if self.dwconv:
666
+ qkv = rearrange(
667
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
668
+ ).contiguous()
669
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
670
+ if (
671
+ inference_params is None
672
+ or inference_params.seqlen_offset == 0
673
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
674
+ or not self.use_flash_attn
675
+ ):
676
+ if self.rotary_emb_dim > 0:
677
+ qkv = self.rotary_emb(
678
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
679
+ )
680
+ if inference_params is None:
681
+ if not self.checkpointing:
682
+ context = self.inner_attn(qkv, **kwargs)
683
+ else:
684
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
685
+ else:
686
+ context = self._update_kvcache_attention(
687
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
688
+ )
689
+ else:
690
+ context = self._apply_rotary_update_kvcache_attention(
691
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
692
+ )
693
+ else:
694
+ if self.cross_attn:
695
+ if not self.return_residual:
696
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
697
+ kv = self.Wkv(x_kv if x_kv is not None else x)
698
+ else:
699
+ if x_kv is not None:
700
+ kv, x_kv = self.Wkv(x_kv)
701
+ else:
702
+ kv, x = self.Wkv(x)
703
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
704
+ else:
705
+ assert self.num_heads_kv != self.num_heads
706
+ if not self.return_residual:
707
+ qkv = self.Wqkv(x)
708
+ else:
709
+ qkv, x = self.Wqkv(x)
710
+ q = qkv[..., : self.num_heads * self.head_dim]
711
+ kv = qkv[..., self.num_heads * self.head_dim :]
712
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
713
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
714
+ if self.dwconv:
715
+ q = rearrange(
716
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
717
+ ).contiguous()
718
+ kv = rearrange(
719
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
720
+ ).contiguous()
721
+ if (
722
+ inference_params is None
723
+ or inference_params.seqlen_offset == 0
724
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
725
+ or not self.use_flash_attn
726
+ ):
727
+ if self.rotary_emb_dim > 0:
728
+ q, kv = self.rotary_emb(
729
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
730
+ )
731
+ if inference_params is None:
732
+ if not self.checkpointing:
733
+ context = self.inner_cross_attn(q, kv, **kwargs)
734
+ else:
735
+ context = torch.utils.checkpoint.checkpoint(
736
+ self.inner_cross_attn, q, kv, **kwargs
737
+ )
738
+ else:
739
+ context = self._update_kvcache_attention(q, kv, inference_params)
740
+ else:
741
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
742
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
743
+ return out if not self.return_residual else (out, x)
744
+
745
+
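An end-to-end sketch for MHA in its simplest configuration (self-attention, no flash kernel, no rotary; illustrative only, not part of the diff):

import torch
mha = MHA(embed_dim=512, num_heads=8, causal=True, use_flash_attn=False)
x = torch.randn(2, 32, 512)  # (batch, seqlen, hidden_dim)
y = mha(x)                   # (2, 32, 512)

Passing num_heads_kv smaller than num_heads switches forward to the split q/kv (MQA/GQA) branch instead of the packed-QKV branch above.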
746
+ class ParallelMHA(nn.Module):
747
+ """Multi-head self-attention and cross-attention"""
748
+
749
+ def __init__(
750
+ self,
751
+ embed_dim,
752
+ num_heads,
753
+ process_group,
754
+ num_heads_kv=None,
755
+ qkv_proj_bias=True,
756
+ out_proj_bias=True,
757
+ dropout=0.0,
758
+ softmax_scale=None,
759
+ causal=False,
760
+ layer_idx=None,
761
+ rotary_emb_dim=0,
762
+ rotary_emb_base=10000.0,
763
+ rotary_emb_scale_base=None,
764
+ rotary_emb_interleaved=False,
765
+ use_alibi=False,
766
+ window_size=(-1, -1),
767
+ use_flash_attn=False,
768
+ checkpointing=False,
769
+ sequence_parallel=True,
770
+ device=None,
771
+ dtype=None,
772
+ ) -> None:
773
+ factory_kwargs = {"device": device, "dtype": dtype}
774
+ super().__init__()
775
+ self.embed_dim = embed_dim
776
+ self.causal = causal
777
+ self.layer_idx = layer_idx
778
+ self.rotary_emb_dim = rotary_emb_dim
779
+ self.use_flash_attn = use_flash_attn
780
+ self.checkpointing = checkpointing
781
+ self.process_group = process_group
782
+ self.world_size = process_group.size()
783
+ self.local_rank = torch.distributed.get_rank(process_group)
784
+
785
+ self.num_heads = num_heads
786
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
787
+
788
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
789
+ assert (
790
+ self.num_heads % self.num_heads_kv == 0
791
+ ), "num_heads must be divisible by num_heads_kv"
792
+
793
+ self.num_heads_per_rank = get_dim_for_local_rank(
794
+ self.num_heads, self.world_size, self.local_rank
795
+ )
796
+ self.num_heads_kv_per_rank = get_dim_for_local_rank(
797
+ self.num_heads_kv, self.world_size, self.local_rank
798
+ )
799
+ self.head_dim = self.embed_dim // num_heads
800
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
801
+
802
+ if use_alibi:
803
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
804
+ num_heads_local = math.ceil(self.num_heads / self.world_size)
805
+ alibi_slopes = torch.tensor(
806
+ get_alibi_slopes(num_heads)[
807
+ self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
808
+ ],
809
+ device=device,
810
+ )
811
+ else:
812
+ alibi_slopes = None
813
+ if window_size != (-1, -1):
814
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
815
+
816
+ if self.rotary_emb_dim > 0:
817
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
818
+ self.rotary_emb = RotaryEmbedding(
819
+ self.rotary_emb_dim,
820
+ base=rotary_emb_base,
821
+ scale_base=rotary_emb_scale_base,
822
+ interleaved=rotary_emb_interleaved,
823
+ device=device,
824
+ )
825
+
826
+ if ColumnParallelLinear is None or RowParallelLinear is None:
827
+ raise ImportError("fused_dense is not installed")
828
+ self.Wqkv = ColumnParallelLinear(
829
+ embed_dim,
830
+ qkv_dim,
831
+ process_group,
832
+ bias=qkv_proj_bias,
833
+ sequence_parallel=sequence_parallel,
834
+ multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
835
+ **factory_kwargs,
836
+ )
837
+ inner_attn_cls = (
838
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
839
+ if use_flash_attn
840
+ else SelfAttention
841
+ )
842
+ inner_cross_attn_cls = (
843
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
844
+ if use_flash_attn
845
+ else CrossAttention
846
+ )
847
+ self.inner_attn = inner_attn_cls(
848
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
849
+ )
850
+ self.inner_cross_attn = inner_cross_attn_cls(
851
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
852
+ )
853
+ self.out_proj = RowParallelLinear(
854
+ embed_dim,
855
+ embed_dim,
856
+ process_group,
857
+ bias=out_proj_bias,
858
+ sequence_parallel=sequence_parallel,
859
+ multiple_of=self.head_dim,
860
+ **factory_kwargs,
861
+ )
862
+
863
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
864
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
865
+ device = self.out_proj.weight.device
866
+ return torch.empty(
867
+ batch_size,
868
+ max_seqlen,
869
+ 2,
870
+ self.num_heads_kv_per_rank,
871
+ self.head_dim,
872
+ dtype=dtype,
873
+ device=device,
874
+ )
875
+
876
+ def _update_kv_cache(self, kv, inference_params):
877
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
878
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
879
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
880
+
881
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
882
+ """
883
+ Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
884
+ q: (batch_size, seqlen_q, nheads, head_dim)
885
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
886
+ """
887
+ assert inference_params is not None and inference_params.seqlen_offset > 0
888
+ assert self.use_flash_attn
889
+ if self.rotary_emb_dim > 0:
890
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
891
+ self.rotary_emb._update_cos_sin_cache(
892
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
893
+ )
894
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
895
+ else:
896
+ rotary_cos, rotary_sin = None, None
897
+ batch = q.shape[0]
898
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
899
+ cache_seqlens = (
900
+ inference_params.lengths_per_sample[:batch]
901
+ if inference_params.lengths_per_sample is not None
902
+ else inference_params.seqlen_offset
903
+ )
904
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
905
+ context = flash_attn_with_kvcache(
906
+ q,
907
+ kv_cache[:, :, 0],
908
+ kv_cache[:, :, 1],
909
+ kv[:, :, 0],
910
+ kv[:, :, 1],
911
+ rotary_cos=rotary_cos,
912
+ rotary_sin=rotary_sin,
913
+ cache_seqlens=cache_seqlens,
914
+ softmax_scale=self.inner_cross_attn.softmax_scale,
915
+ causal=self.inner_cross_attn.causal,
916
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
917
+ alibi_slopes=alibi_slopes,
918
+ )
919
+ return context
920
+
921
+ def _update_kvcache_attention(self, q, kv, inference_params):
922
+ """Write kv to inference_params, then do attention"""
923
+ if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
924
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
925
+ kv = self._update_kv_cache(kv, inference_params)
926
+ return self.inner_cross_attn(q, kv)
927
+ else:
928
+ batch = q.shape[0]
929
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
930
+ cache_seqlens = (
931
+ inference_params.lengths_per_sample[:batch]
932
+ if inference_params.lengths_per_sample is not None
933
+ else inference_params.seqlen_offset
934
+ )
935
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
936
+ context = flash_attn_with_kvcache(
937
+ q,
938
+ kv_cache[:, :, 0],
939
+ kv_cache[:, :, 1],
940
+ kv[:, :, 0],
941
+ kv[:, :, 1],
942
+ cache_seqlens=cache_seqlens,
943
+ softmax_scale=self.inner_cross_attn.softmax_scale,
944
+ causal=self.inner_cross_attn.causal,
945
+ alibi_slopes=alibi_slopes,
946
+ )
947
+ return context
948
+
949
+ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
950
+ """
951
+ Arguments:
952
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
953
+ If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
954
+ split x during sequence parallel, we split the batch * seqlen dimension
955
+ (in case batch is small).
956
+ """
957
+ qkv = self.Wqkv(x)
958
+ if seqlen is not None:
959
+ qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
960
+ seqlen_offset = (
961
+ 0
962
+ if inference_params is None
963
+ else (
964
+ inference_params.lengths_per_sample
965
+ if inference_params.lengths_per_sample is not None
966
+ else inference_params.seqlen_offset
967
+ )
968
+ )
969
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
970
+ if self.num_heads_kv == self.num_heads:
971
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
972
+ if (
973
+ inference_params is None
974
+ or inference_params.seqlen_offset == 0
975
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
976
+ or not self.use_flash_attn
977
+ ):
978
+ if self.rotary_emb_dim > 0:
979
+ qkv = self.rotary_emb(
980
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
981
+ )
982
+ if inference_params is None:
983
+ if not self.checkpointing:
984
+ context = self.inner_attn(qkv, **kwargs)
985
+ else:
986
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
987
+ else:
988
+ context = self._update_kvcache_attention(
989
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
990
+ )
991
+ else:
992
+ context = self._apply_rotary_update_kvcache_attention(
993
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
994
+ )
995
+ else:
996
+ q = rearrange(
997
+ qkv[..., : self.num_heads_per_rank * self.head_dim],
998
+ "... (h d) -> ... h d",
999
+ d=self.head_dim,
1000
+ )
1001
+ kv = rearrange(
1002
+ qkv[..., self.num_heads_per_rank * self.head_dim :],
1003
+ "... (two hkv d) -> ... two hkv d",
1004
+ two=2,
1005
+ d=self.head_dim,
1006
+ )
1007
+ if (
1008
+ inference_params is None
1009
+ or inference_params.seqlen_offset == 0
1010
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
1011
+ or not self.use_flash_attn
1012
+ ):
1013
+ if self.rotary_emb_dim > 0:
1014
+ q, kv = self.rotary_emb(
1015
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
1016
+ )
1017
+ if inference_params is None:
1018
+ if not self.checkpointing:
1019
+ context = self.inner_cross_attn(q, kv, **kwargs)
1020
+ else:
1021
+ context = torch.utils.checkpoint.checkpoint(
1022
+ self.inner_cross_attn, q, kv, **kwargs
1023
+ )
1024
+ else:
1025
+ context = self._update_kvcache_attention(q, kv, inference_params)
1026
+ else:
1027
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
1028
+ context = rearrange(context, "b s h d -> b s (h d)")
1029
+ if seqlen is not None:
1030
+ context = rearrange(context, "b s d -> (b s) d")
1031
+ out = self.out_proj(context)
1032
+ return out
1033
 
1034
 
1035
  class RMSNorm(nn.Module):
 
1126
 
1127
  def get_output_embeddings(self):
1128
  return self.lm_head
1129
+
1130
+ def set_output_embeddings(self, new_embeddings):
1131
+ self.lm_head = new_embeddings
1132
+
1133
+ def prepare_inputs_for_generation(
1134
+ self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
1135
+ ):
1136
+ # Cut decoder_input_ids if past is used
1137
+ if past is not None:
1138
+ input_ids = input_ids[:, -1:]
1139
+
1140
+ return {
1141
+ "decoder_input_ids": input_ids,
1142
+ "past_key_values": past,
1143
+ "encoder_outputs": encoder_outputs,
1144
+ "attention_mask": attention_mask,
1145
+ "use_cache": use_cache,
1146
+ }
1147
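When a past cache is supplied, prepare_inputs_for_generation keeps only the most recent decoder token; the slicing it performs is equivalent to:

import torch
decoder_input_ids = torch.tensor([[0, 42, 7]])
past = object()  # stands in for real past_key_values
if past is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]
assert decoder_input_ids.tolist() == [[7]]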
 
1148
  def forward(
1149
  self,
 
1204
  shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1205
  shifted_input_ids[..., 0] = self.config.pad_token_id
1206
  return shifted_input_ids
1207
+
1208
+ def save_pretrained(self, save_directory, safe_serialization=True):
1209
+ # Save the config
1210
+ self.config.save_pretrained(save_directory)
1211
+
1212
+ # Prepare state dict
1213
+ state_dict = self.state_dict()
1214
+
1215
+ # Handle shared weights
1216
+ if self.config.tie_word_embeddings:
1217
+ state_dict["lm_head.weight"] = state_dict["shared.weight"]
1218
+
1219
+ # Convert state_dict to CPU tensors
1220
+ cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
1221
+
1222
+ if safe_serialization:
1223
+ # Save using safetensors
+ from safetensors.torch import save_file  # not imported at module level above
1224
+ safe_filepath = os.path.join(save_directory, "model.safetensors")
1225
+ save_file(cpu_state_dict, safe_filepath)
1226
+ else:
1227
+ # Save using PyTorch
1228
+ torch_filepath = os.path.join(save_directory, "pytorch_model.bin")
1229
+ torch.save(cpu_state_dict, torch_filepath)
1230
+
1231
+ @classmethod
1232
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1233
+ config = kwargs.pop("config", None)
1234
+ state_dict = kwargs.pop("state_dict", None)
1235
+
1236
+ if config is None:
1237
+ config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
1238
+
1239
+ model = cls(config)
1240
+
1241
+ if state_dict is None:
1242
+ # Try loading safetensors first
1243
+ safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
1244
+ if os.path.exists(safe_filepath):
1245
+ from safetensors.torch import load_file
1246
+ state_dict = load_file(safe_filepath)
1247
+ else:
1248
+ # Fall back to PyTorch format
1249
+ torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
1250
+ state_dict = torch.load(torch_filepath, map_location="cpu")
1251
+
1252
+ # Handle shared weights
1253
+ if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
1254
+ state_dict["lm_head.weight"] = state_dict["shared.weight"]
1255
+
1256
+ model.load_state_dict(state_dict)
1257
+ return model
1258
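The serialization added here boils down to safetensors with a pytorch_model.bin fallback; stripped to its essentials (illustrative sketch, assumes safetensors is installed):

import os
import torch
from safetensors.torch import save_file, load_file

state = {"w": torch.randn(2, 2)}
save_file(state, "model.safetensors")  # safe_serialization=True path
if os.path.exists("model.safetensors"):
    reloaded = load_file("model.safetensors")
else:
    reloaded = torch.load("pytorch_model.bin", map_location="cpu")

On both save and load, the methods above additionally copy shared.weight into lm_head.weight when tie_word_embeddings is set.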
 
1259
  class CustomEncoder(nn.Module):
1260
  def __init__(self, config):
 
1349
 
1350
  return hidden_states
1351