winglian committed on
Commit 6910e6a
1 Parent(s): 1d70f24

Multipack simplify for Mixtral (#1142)

src/axolotl/core/trainer_builder.py CHANGED
@@ -12,7 +12,7 @@ from abc import abstractmethod
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Type, Union
 
 import torch
 import transformers
@@ -37,6 +37,7 @@ from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     MambaDataCollator,
+    V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -896,14 +897,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if is_eval and training_args.eval_sample_packing:
             use_batch_sampler_collator = True
 
+        collator: Type[
+            Union[
+                V2BatchSamplerDataCollatorForSeq2Seq,
+                BatchSamplerDataCollatorForSeq2Seq,
+                DataCollatorForSeq2Seq,
+            ]
+        ]
         if use_batch_sampler_collator:
-            return BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **kwargs,
-            )
+            if self.cfg.model_config_type == "mixtral":
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            else:
+                collator = BatchSamplerDataCollatorForSeq2Seq
+        else:
+            collator = DataCollatorForSeq2Seq
 
-        return DataCollatorForSeq2Seq(
+        return collator(
             self.tokenizer,
             return_tensors="pt",
             **kwargs,
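
Note: the collator is now chosen as a class and instantiated once at the end of the method, with Mixtral routed to the new V2 collator. A minimal sketch of just the selection logic in isolation (the helper name `pick_collator` is only for illustration; the real code lives inside the trainer builder shown above):

    from axolotl.utils.collators import (
        BatchSamplerDataCollatorForSeq2Seq,
        DataCollatorForSeq2Seq,
        V2BatchSamplerDataCollatorForSeq2Seq,
    )

    def pick_collator(model_config_type: str, use_batch_sampler_collator: bool):
        # Mixtral uses the V2 collator so the packed attention_mask carries
        # per-sequence ids that the patched _get_unpad_data can split on.
        if use_batch_sampler_collator:
            if model_config_type == "mixtral":
                return V2BatchSamplerDataCollatorForSeq2Seq
            return BatchSamplerDataCollatorForSeq2Seq
        return DataCollatorForSeq2Seq

    # e.g. pick_collator("mixtral", True)(tokenizer, return_tensors="pt")
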
src/axolotl/monkeypatch/mixtral/__init__.py CHANGED
@@ -3,20 +3,10 @@ Patches to support multipack for mixtral
 """
 import transformers
 
+from axolotl.monkeypatch.utils import get_unpad_data
 
-def replace_mixtral_attn_with_multipack_flash_attn():
-    from .modeling_mixtral import (
-        MixtralMultipackFlashAttention2,
-        mixtral_decoder_layer_forward,
-        mixtral_model_forward,
-    )
 
-    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
-        mixtral_decoder_layer_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
-        mixtral_model_forward
+def replace_mixtral_attn_with_multipack_flash_attn():
+    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
     )
-    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MixtralMultipackFlashAttention2
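
The simplification leans on the stock MixtralFlashAttention2 in transformers unpadding its inputs through the module-level `_get_unpad_data(attention_mask)` helper, so swapping in that one function is enough to make flash attention respect multipack boundaries; that is why the custom attention class below can be deleted. A small sketch of applying and sanity-checking the patch (assumes a transformers version, 4.36+, that still exposes `_get_unpad_data` in `modeling_mixtral`):

    import transformers

    from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn
    from axolotl.monkeypatch.utils import get_unpad_data

    replace_mixtral_attn_with_multipack_flash_attn()
    assert (
        transformers.models.mixtral.modeling_mixtral._get_unpad_data  # pylint: disable=protected-access
        is get_unpad_data
    )
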
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py DELETED
@@ -1,383 +0,0 @@
-"""
-Mixtral modeling for multipack
-"""
-# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
-import logging
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import torch
-from einops import rearrange
-from flash_attn import flash_attn_varlen_qkvpacked_func
-from transformers import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import MoeModelOutputWithPast
-from transformers.models.mixtral.modeling_mixtral import (
-    MixtralFlashAttention2,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
-
-LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
-
-
-class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
-    """
-    Custom multipack implementation w flash attention 2
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._flash_attn_uses_top_left_mask = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids
-        )
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
-            # special handling using sample packing
-            qkv = torch.stack(
-                [query_states, key_states, value_states], dim=2
-            )  # [bsz, nh, 3, q_len, hd]
-            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-            qkv = rearrange(qkv, "b s ... -> (b s) ...")
-
-            attn_output = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens,
-                max_seqlen,
-                dropout_p=self.attention_dropout,
-                softmax_scale=None,
-                causal=True,
-            )
-            attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-def mixtral_decoder_layer_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: Optional[bool] = False,
-    output_router_logits: Optional[bool] = False,
-    use_cache: Optional[bool] = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-    **kwargs,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-        )
-    """
-    Args:
-        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-            `(batch, sequence_length)` where padding elements are indicated by 0.
-        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-            returned tensors for more detail.
-        output_router_logits (`bool`, *optional*):
-            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-            should not be returned during inference.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-            (see `past_key_values`).
-    """
-
-    residual = hidden_states
-
-    hidden_states = self.input_layernorm(hidden_states)
-
-    # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cu_seqlens=cu_seqlens,
-        max_seqlen=max_seqlen,
-    )
-    hidden_states = residual + hidden_states
-
-    # Fully Connected
-    residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states, router_logits = self.block_sparse_moe(hidden_states)
-    hidden_states = residual + hidden_states
-
-    outputs = (hidden_states,)
-
-    if output_attentions:
-        outputs += (self_attn_weights,)
-
-    if use_cache:
-        outputs += (present_key_value,)
-
-    if output_router_logits:
-        outputs += (router_logits,)
-
-    return outputs
-
-
-def mixtral_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-) -> Union[Tuple, MoeModelOutputWithPast]:
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_router_logits = (
-        output_router_logits
-        if output_router_logits is not None
-        else self.config.output_router_logits
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-        )
-    if input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError(
-            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
-        )
-
-    past_key_values_length = 0
-
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    cu_seqlens = None
-    max_seqlen = None
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
-        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_seqlens = cu_seqlens.squeeze()
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    if (
-        attention_mask is not None
-        and self._attn_implementation == "flash_attention_2"
-        and use_cache
-    ):
-        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-        if is_padding_right:
-            raise ValueError(
-                "You are attempting to perform batched generation with padding_side='right'"
-                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
-                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-            )
-
-    if self._attn_implementation == "flash_attention_2":
-        # 2d mask is passed through the layers
-        attention_mask = (
-            attention_mask
-            if (attention_mask is not None and 0 in attention_mask)
-            else None
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-
-    hidden_states = inputs_embeds
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            LOG.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    all_router_logits = () if output_router_logits else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                output_router_logits,
-                use_cache,
-                cu_seqlens,
-                max_seqlen,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                output_router_logits=output_router_logits,
-                use_cache=use_cache,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-        if output_router_logits:
-            all_router_logits += (layer_outputs[-1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache()
-            if use_legacy_cache
-            else next_decoder_cache
-        )
-
-    if not return_dict:
-        return tuple(
-            v
-            for v in [
-                hidden_states,
-                next_cache,
-                all_hidden_states,
-                all_self_attns,
-                all_router_logits,
-            ]
-            if v is not None
-        )
-
-    return MoeModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-        router_logits=all_router_logits,
-    )
src/axolotl/monkeypatch/utils.py CHANGED
@@ -2,6 +2,40 @@
 Shared utils for the monkeypatches
 """
 import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
+    max_num = int(torch.max(attention_mask).item())
+    batch_size, _ = attention_mask.shape
+    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
+
+    for i in range(1, max_num + 1):
+        mask = attention_mask == i
+        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
+
+    result = counts.flatten()
+    nonzero_indices = torch.nonzero(result).squeeze(-1)
+    return result[nonzero_indices]
+
+
+@torch.jit.script
+def get_unpad_data(attention_mask: torch.Tensor):
+    device = attention_mask.device
+    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
+    indices = torch.nonzero(attention_mask.flatten()).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = (
+        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+        .to(device=device)
+        .detach()
+    )
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
 
 
 def get_cu_seqlens(attn_mask):
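
For reference, the attention_mask these helpers expect is not a plain 0/1 mask: after packing, each token carries the 1-based index of the sub-sequence it belongs to within its row, with 0 still marking padding. A small worked example whose values follow directly from the definitions above:

    import torch

    from axolotl.monkeypatch.utils import get_unpad_data

    # one packed row holding a 3-token and a 2-token sequence plus one pad token
    mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
    indices, cu_seqlens, max_seqlen = get_unpad_data(mask)
    # indices    -> tensor([0, 1, 2, 3, 4])               non-pad token positions
    # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)  cumulative sequence boundaries
    # max_seqlen -> 3                                     longest sub-sequence in the batch
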
src/axolotl/utils/collators.py CHANGED
@@ -152,6 +152,33 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         return super().__call__(features, return_tensors=return_tensors)
 
 
+@dataclass
+class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to the using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features[0].keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [
+                    (i + 1) * np.array(item[feature])
+                    for i, item in enumerate(features)
+                    if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [
+                    np.array(item[feature]) for item in features if feature in item
+                ]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
+
+
 @dataclass
 class MambaDataCollator:
     """
src/axolotl/utils/config.py CHANGED
@@ -1,12 +1,14 @@
 """Module for working with config dicts"""
-
+import json
 import logging
 import os
+from pathlib import Path
 
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
 
 LOG = logging.getLogger("axolotl")
@@ -135,7 +137,7 @@ def normalize_config(cfg):
             ]
         )
         or cfg.is_mistral_derived_model
-        or "mistral" in cfg.base_model.lower()
+        or "mistral" in cfg.base_model.lower().split("/")[-1]
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )
 
@@ -484,6 +486,40 @@ def validate_config(cfg):
             "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
         )
 
+    if (
+        cfg.unfrozen_parameters
+        and cfg.gradient_checkpointing_kwargs
+        and cfg.gradient_checkpointing_kwargs.use_reentrant is True
+    ):
+        # https://github.com/huggingface/transformers/issues/21381
+        raise ValueError(
+            "`use_reentrant` must be false when used with partially frozen model."
+        )
+
+    if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file():
+        with open(cfg.deepspeed, encoding="utf-8") as file:
+            contents = file.read()
+            deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
+            if (
+                deepspeed_cfg.zero_optimization
+                and deepspeed_cfg.zero_optimization.stage == 3
+            ):
+                if not (
+                    (
+                        deepspeed_cfg.bf16
+                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                    or (
+                        deepspeed_cfg.fp16
+                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
+                        is True
+                    )
+                ):
+                    raise ValueError(
+                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
+                    )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
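
The new DeepSpeed check only inspects two fields of the JSON: `zero_optimization.stage` and whether `bf16.enabled` or `fp16.enabled` is true. A minimal sketch of that condition in isolation (the real code reads the file at `cfg.deepspeed` and raises a ValueError; DictDefault attribute access returning a falsy value for missing keys is assumed here):

    import json

    from axolotl.utils.dict import DictDefault

    ds_cfg = DictDefault(
        json.loads('{"zero_optimization": {"stage": 3}, "bf16": {"enabled": true}}')
    )
    is_zero3 = bool(ds_cfg.zero_optimization and ds_cfg.zero_optimization.stage == 3)
    mixed_precision = bool(
        (ds_cfg.bf16 and ds_cfg.bf16.enabled is True)
        or (ds_cfg.fp16 and ds_cfg.fp16.enabled is True)
    )
    ok_with_flash_attention = (not is_zero3) or mixed_precision  # True for this config
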
src/axolotl/utils/models.py CHANGED
@@ -305,12 +305,16 @@ def load_model(
         )
 
     # Modify mistral derived models
-    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
+    if (
+        cfg.model_config_type == "mistral"
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )
 
-        LOG.info("patching with flash attention")
+        LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
     if (
@@ -322,7 +326,7 @@
             replace_mixtral_attn_with_multipack_flash_attn,
         )
 
-        LOG.info("patching with flash attention")
+        LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()
 
     if (
src/axolotl/utils/trainer.py CHANGED
@@ -152,6 +152,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         or (cfg.is_mistral_derived_model and cfg.flash_attention)
         or cfg.model_config_type == "mamba"
     ):
+        LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
tests/e2e/patched/test_mixtral_samplepack.py CHANGED
@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path
 
-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "sample_packing": True,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
-            "axolotl.monkeypatch.mixtral.modeling_mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
         assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/patched/test_model_patches.py CHANGED
@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
         model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
 
         assert (
-            "axolotl.monkeypatch.mixtral.modeling_mixtral"
-            in model.model.layers[0].self_attn.__class__.__module__
-        )
-        assert (
-            "MixtralMultipackFlashAttention2"
+            "MixtralFlashAttention2"
             in model.model.layers[0].self_attn.__class__.__name__
         )
 
tests/monkeypatch/test_llama_attn_hijack_flash.py CHANGED
@@ -5,7 +5,12 @@ import unittest
 
 import torch
 
-from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import (
+    get_cu_seqlens,
+    get_cu_seqlens_from_pos_ids,
+    get_max_seqlen_in_batch,
+    get_unpad_data,
+)
 
 
 class TestMonkeyPatchUtils(unittest.TestCase):
@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
             torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
         )
 
+    def test_get_max_seqlen_in_batch(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
+
+    def test_get_unpad_data(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
+        target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+
+        attn_mask = torch.tensor(
+            [
+                [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
+                [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
+            ]
+        )
+        target_indices = torch.tensor(
+            [
+                0,
+                1,
+                2,
+                3,
+                4,
+                5,
+                6,
+                7,
+                8,
+                9,
+                10,
+                11,
+                12,
+                13,
+                16,
+                17,
+                18,
+                19,
+                20,
+                21,
+                22,
+                23,
+                24,
+                25,
+                26,
+                27,
+                28,
+                29,
+                30,
+                31,
+            ]
+        )
+        target_cu_seqlen = torch.tensor(
+            [0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
+        )
+        target_max_seqlen_in_batch = 5
+        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
+        self.assertTrue(torch.allclose(target_indices, indices))
+        self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
+        self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
+
 
 if __name__ == "__main__":
     unittest.main()