Multipack simplify for Mixtral (#1142)
Files changed:

- src/axolotl/core/trainer_builder.py +16 -7
- src/axolotl/monkeypatch/mixtral/__init__.py +4 -14
- src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +0 -383
- src/axolotl/monkeypatch/utils.py +34 -0
- src/axolotl/utils/collators.py +27 -0
- src/axolotl/utils/config.py +38 -2
- src/axolotl/utils/models.py +7 -3
- src/axolotl/utils/trainer.py +1 -0
- tests/e2e/patched/test_mixtral_samplepack.py +3 -15
- tests/e2e/patched/test_model_patches.py +1 -5
- tests/monkeypatch/test_llama_attn_hijack_flash.py +70 -1
src/axolotl/core/trainer_builder.py (CHANGED)

@@ -12,7 +12,7 @@ from abc import abstractmethod
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Type, Union

 import torch
 import transformers

@@ -37,6 +37,7 @@ from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     MambaDataCollator,
+    V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (

@@ -896,14 +897,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if is_eval and training_args.eval_sample_packing:
             use_batch_sampler_collator = True

+        collator: Type[
+            Union[
+                V2BatchSamplerDataCollatorForSeq2Seq,
+                BatchSamplerDataCollatorForSeq2Seq,
+                DataCollatorForSeq2Seq,
+            ]
+        ]
         if use_batch_sampler_collator:
-            return BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **kwargs,
-            )
+            if self.cfg.model_config_type == "mixtral":
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            else:
+                collator = BatchSamplerDataCollatorForSeq2Seq
+        else:
+            collator = DataCollatorForSeq2Seq

-        return DataCollatorForSeq2Seq(
+        return collator(
             self.tokenizer,
             return_tensors="pt",
             **kwargs,
src/axolotl/monkeypatch/mixtral/__init__.py (CHANGED)

@@ -3,20 +3,10 @@ Patches to support multipack for mixtral
 """
 import transformers

+from axolotl.monkeypatch.utils import get_unpad_data
+

-def replace_mixtral_attn_with_multipack_flash_attn():
-    from .modeling_mixtral import (
-        MixtralMultipackFlashAttention2,
-        mixtral_decoder_layer_forward,
-        mixtral_model_forward,
-    )
-
-    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
-        mixtral_decoder_layer_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
-        mixtral_model_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MixtralMultipackFlashAttention2
+def replace_mixtral_attn_with_multipack_flash_attn():
+    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
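The simplified patch leans on the stock `MixtralFlashAttention2` class: transformers unpads flash-attention inputs through the module-level `_get_unpad_data` helper, so replacing that one function is enough to make packed sequences attend only to themselves. A minimal, illustrative sketch of applying it by hand (axolotl normally does this inside `load_model`; the checkpoint name and loading kwargs below are assumptions, not part of this change):

```python
import transformers

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# Patch before instantiating the model so the flash-attention-2 code path
# picks up the multipack-aware get_unpad_data.
replace_mixtral_attn_with_multipack_flash_attn()

model = transformers.AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",            # illustrative checkpoint
    attn_implementation="flash_attention_2",  # the patch only matters on this path
    torch_dtype="auto",
)
```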
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py (DELETED)

@@ -1,383 +0,0 @@
-"""
-Mixtral modeling for multipack
-"""
-# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
-import logging
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import torch
-from einops import rearrange
-from flash_attn import flash_attn_varlen_qkvpacked_func
-from transformers import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import MoeModelOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    MixtralFlashAttention2,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
-
-LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
-
-
-class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
-    """
-    Custom multipack implementation w flash attention 2
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._flash_attn_uses_top_left_mask = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids
-        )
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
-            # special handling using sample packing
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            qkv = rearrange(qkv, "b s ... -> (b s) ...")
-
-            attn_output = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens,
-                max_seqlen,
-                dropout_p=self.attention_dropout,
-                softmax_scale=None,
-                causal=True,
-            )
-            attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-def mixtral_decoder_layer_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: Optional[bool] = False,
-    output_router_logits: Optional[bool] = False,
-    use_cache: Optional[bool] = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-    **kwargs,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-        )
-    """
-    Args:
-        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-            `(batch, sequence_length)` where padding elements are indicated by 0.
-        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-            returned tensors for more detail.
-        output_router_logits (`bool`, *optional*):
-            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-            should not be returned during inference.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-            (see `past_key_values`).
-    """
-
-    residual = hidden_states
-
-    hidden_states = self.input_layernorm(hidden_states)
-
-    # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cu_seqlens=cu_seqlens,
-        max_seqlen=max_seqlen,
-    )
-    hidden_states = residual + hidden_states
-
-    # Fully Connected
-    residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-    hidden_states = residual + hidden_states
-
-    outputs = (hidden_states,)
-
-    if output_attentions:
-        outputs += (self_attn_weights,)
-
-    if use_cache:
-        outputs += (present_key_value,)
-
-    if output_router_logits:
-        outputs += (router_logits,)
-
-    return outputs
-
-
-def mixtral_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-) -> Union[Tuple, MoeModelOutputWithPast]:
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-        )
-    if input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError(
-            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
-        )
-
-    past_key_values_length = 0
-
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    cu_seqlens = None
-    max_seqlen = None
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
-        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_seqlens = cu_seqlens.squeeze()
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    if (
-        attention_mask is not None
-        and self._attn_implementation == "flash_attention_2"
-        and use_cache
-    ):
-        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-        if is_padding_right:
-            raise ValueError(
-                "You are attempting to perform batched generation with padding_side='right'"
-                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
-                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-            )
-
-    if self._attn_implementation == "flash_attention_2":
-        # 2d mask is passed through the layers
-        attention_mask = (
-            attention_mask
-            if (attention_mask is not None and 0 in attention_mask)
-            else None
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-
-    hidden_states = inputs_embeds
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            LOG.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    all_router_logits = () if output_router_logits else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                output_router_logits,
-                use_cache,
-                cu_seqlens,
-                max_seqlen,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                output_router_logits=output_router_logits,
-                use_cache=use_cache,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-        if output_router_logits:
-            all_router_logits += (layer_outputs[-1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache()
-            if use_legacy_cache
-            else next_decoder_cache
-        )
-
-    if not return_dict:
-        return tuple(
-            v
-            for v in [
-                hidden_states,
-                next_cache,
-                all_hidden_states,
-                all_self_attns,
-                all_router_logits,
-            ]
-            if v is not None
-        )
-
-    return MoeModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-        router_logits=all_router_logits,
-    )
src/axolotl/monkeypatch/utils.py (CHANGED)

@@ -2,6 +2,40 @@
 Shared utils for the monkeypatches
 """
 import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
+    max_num = int(torch.max(attention_mask).item())
+    batch_size, _ = attention_mask.shape
+    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
+
+    for i in range(1, max_num + 1):
+        mask = attention_mask == i
+        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
+
+    result = counts.flatten()
+    nonzero_indices = torch.nonzero(result).squeeze(-1)
+    return result[nonzero_indices]
+
+
+@torch.jit.script
+def get_unpad_data(attention_mask: torch.Tensor):
+    device = attention_mask.device
+    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
+    indices = torch.nonzero(attention_mask.flatten()).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = (
+        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        .to(device=device)
+        .detach()
+    )
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )


 def get_cu_seqlens(attn_mask):
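Unlike the stock helper in transformers, which treats every nonzero mask position in a row as part of a single sequence, `get_unpad_data` here treats each distinct positive value (1, 2, 3, ...) as its own packed sub-sequence. A small worked example using the same mask the new unit tests exercise:

```python
import torch

from axolotl.monkeypatch.utils import get_max_seqlen_in_batch, get_unpad_data

# One packed row holding sub-sequences of lengths 4, 3, 5 and 2, plus two pad positions.
mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])

print(get_max_seqlen_in_batch(mask))  # tensor([4, 3, 5, 2], dtype=torch.int32)

indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
print(cu_seqlens)  # tensor([ 0,  4,  7, 12, 14], dtype=torch.int32)
print(max_seqlen)  # 5 -- the longest packed sub-sequence, not the padded row length
```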
src/axolotl/utils/collators.py (CHANGED)

@@ -152,6 +152,33 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         return super().__call__(features, return_tensors=return_tensors)


+@dataclass
+class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to the using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (i + 1) * np.array(item[feature])
+                    for i, item in enumerate(features)
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
+
+
 @dataclass
 class MambaDataCollator:
     """
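A minimal sketch of the numbering trick that pairs the V2 collator with the patched `get_unpad_data`: each sample's attention mask is scaled by its 1-based position before everything is concatenated into one packed row, so sequence boundaries survive the merge (the toy features below are made up for illustration):

```python
import numpy as np

features = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [5, 6, 7]},
    {"input_ids": [8, 9], "attention_mask": [1, 1], "labels": [8, 9]},
]

packed_mask = np.concatenate(
    [(i + 1) * np.array(item["attention_mask"]) for i, item in enumerate(features)]
)
print(packed_mask)  # [1 1 1 2 2] -- per-sequence ids instead of a flat 0/1 mask
```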
src/axolotl/utils/config.py (CHANGED)

@@ -1,12 +1,14 @@
 """Module for working with config dicts"""
-
+import json
 import logging
 import os
+from pathlib import Path

 import torch
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config

 LOG = logging.getLogger("axolotl")

@@ -135,7 +137,7 @@ def normalize_config(cfg):
             ]
         )
         or cfg.is_mistral_derived_model
-        or "mistral" in cfg.base_model.lower()
+        or "mistral" in cfg.base_model.lower().split("/")[-1]
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )

@@ -484,6 +486,40 @@ def validate_config(cfg):
            "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
         )

+    if (
+        cfg.unfrozen_parameters
+        and cfg.gradient_checkpointing_kwargs
+        and cfg.gradient_checkpointing_kwargs.use_reentrant is True
+    ):
+        # https://github.com/huggingface/transformers/issues/21381
+        raise ValueError(
+            "`use_reentrant` must be false when used with partially frozen model."
+        )
+
+    if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
+        with open(cfg.deepspeed, encoding="utf-8") as file:
+            contents = file.read()
+            deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
+            if (
+                deepspeed_cfg.zero_optimization
+                and deepspeed_cfg.zero_optimization.stage == 3
+            ):
+                if not (
+                    (
+                        deepspeed_cfg.bf16
+                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                    or (
+                        deepspeed_cfg.fp16
+                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                ):
+                    raise ValueError(
+                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
+                    )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
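The new ZeRO-3 check only inspects the deepspeed JSON referenced by `cfg.deepspeed`. A minimal fragment that satisfies it when flash attention is enabled (illustrative, not a complete deepspeed config):

```python
import json

from axolotl.utils.dict import DictDefault

deepspeed_cfg = DictDefault(
    json.loads('{"zero_optimization": {"stage": 3}, "bf16": {"enabled": true}}')
)

# Mirrors the condition added to validate_config: ZeRO stage 3 needs bf16 or fp16 enabled.
assert deepspeed_cfg.zero_optimization.stage == 3
assert deepspeed_cfg.bf16.enabled is True
```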
src/axolotl/utils/models.py (CHANGED)

@@ -305,12 +305,16 @@ def load_model(
         )

     # Modify mistral derived models
-    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
+    if (
+        cfg.model_config_type == "mistral"
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )

-        LOG.info("patching with flash attention")
+        LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

     if (

@@ -322,7 +326,7 @@ def load_model(
             replace_mixtral_attn_with_multipack_flash_attn,
         )

-        LOG.info("patching with flash attention")
+        LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

     if (
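In `load_model`, the mistral branch now keys off `cfg.model_config_type` together with `cfg.flash_attention` and `cfg.sample_packing`, and the mixtral branch gets its own log line. An illustrative cfg fragment in the style the e2e tests build theirs (the base model name is an example, and `model_config_type` is derived by axolotl from the checkpoint's config rather than set by hand):

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "mistralai/Mixtral-8x7B-v0.1",  # example checkpoint
        "flash_attention": True,
        "sample_packing": True,
    }
)
```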
src/axolotl/utils/trainer.py (CHANGED)

@@ -152,6 +152,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         or (cfg.is_mistral_derived_model and cfg.flash_attention)
         or cfg.model_config_type == "mamba"
     ):
+        LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
tests/e2e/patched/test_mixtral_samplepack.py (CHANGED)

@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path

-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train

@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
-            "axolotl.monkeypatch.mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
         assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/patched/test_model_patches.py (CHANGED)

@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
         model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)

         assert (
-            "axolotl.monkeypatch.mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )

tests/monkeypatch/test_llama_attn_hijack_flash.py (CHANGED)

@@ -5,7 +5,12 @@ import unittest

 import torch

-from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import (
+    get_cu_seqlens,
+    get_cu_seqlens_from_pos_ids,
+    get_max_seqlen_in_batch,
+    get_unpad_data,
+)


 class TestMonkeyPatchUtils(unittest.TestCase):

@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
             torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
         )

+    def test_get_max_seqlen_in_batch(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
+
+    def test_get_unpad_data(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+        target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+
+        attn_mask = torch.tensor(
+            [
+                [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
+                [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
+            ]
+        )
+        target_indices = torch.tensor(
+            [
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                13,
+                16,
+                17,
+                18,
+                19,
+                20,
+                21,
+                22,
+                23,
+                24,
+                25,
+                26,
+                27,
+                28,
+                29,
+                30,
+                31,
+            ]
+        )
+        target_cu_seqlen = torch.tensor(
+            [0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
+        )
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+

 if __name__ == "__main__":
     unittest.main()