Add shifted sparse attention (#973) [skip-ci]
Browse files* Add s2_attn to hijack flash code
* Refactor code to account for s2_attn
* Add test for models utils
* Add ``s2_attention`` option to llama configs
* Add ``s2_attention`` option to README config
* Format code to appease linter
* chore: lint
* Remove xpos and llama-landmark [bad merge]
* add e2e smoke tests for shifted sparse attention
* remove stray patch from merge
* update yml with link to paper for s2_attention/longlora
* fix assertion check for full fine tune
* increase sequence len for tests and PR feedback updates
* reduce context len to 16k for tests
* reduce context len to 16k for tests
* reduce batch size for larger context len and udpate test to check message
* fix test for message
---------
Co-authored-by: joecummings <jrcummings@devvm050.nha0.facebook.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- README.md +2 -1
- examples/code-llama/13b/lora.yml +1 -0
- examples/code-llama/34b/lora.yml +1 -0
- examples/code-llama/7b/lora.yml +1 -0
- examples/llama-2/lora.yml +1 -0
- examples/openllama-3b/lora.yml +1 -0
- src/axolotl/monkeypatch/llama_attn_hijack_flash.py +140 -1
- src/axolotl/utils/models.py +44 -17
- tests/e2e/patched/test_llama_s2_attention.py +111 -0
- tests/utils/test_models.py +37 -0
@@ -834,7 +834,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
|
834 |
# Whether to use scaled-dot-product attention
|
835 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
836 |
sdp_attention:
|
837 |
-
|
|
|
838 |
# Resume from a specific checkpoint dir
|
839 |
resume_from_checkpoint:
|
840 |
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
|
|
834 |
# Whether to use scaled-dot-product attention
|
835 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
836 |
sdp_attention:
|
837 |
+
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
838 |
+
s2_attention:
|
839 |
# Resume from a specific checkpoint dir
|
840 |
resume_from_checkpoint:
|
841 |
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
@@ -52,6 +52,7 @@ local_rank:
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
|
|
55 |
|
56 |
warmup_steps: 10
|
57 |
evals_per_epoch: 4
|
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
55 |
+
s2_attention:
|
56 |
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
@@ -52,6 +52,7 @@ local_rank:
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
|
|
55 |
|
56 |
warmup_steps: 10
|
57 |
evals_per_epoch: 4
|
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
55 |
+
s2_attention:
|
56 |
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
@@ -52,6 +52,7 @@ local_rank:
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
|
|
55 |
|
56 |
warmup_steps: 10
|
57 |
evals_per_epoch: 4
|
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
55 |
+
s2_attention:
|
56 |
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
@@ -52,6 +52,7 @@ local_rank:
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
|
|
55 |
|
56 |
warmup_steps: 10
|
57 |
evals_per_epoch: 4
|
|
|
52 |
logging_steps: 1
|
53 |
xformers_attention:
|
54 |
flash_attention: true
|
55 |
+
s2_attention:
|
56 |
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
@@ -52,6 +52,7 @@ logging_steps: 1
|
|
52 |
xformers_attention:
|
53 |
flash_attention: true
|
54 |
gptq_groupsize:
|
|
|
55 |
gptq_model_v1:
|
56 |
warmup_steps: 20
|
57 |
evals_per_epoch: 4
|
|
|
52 |
xformers_attention:
|
53 |
flash_attention: true
|
54 |
gptq_groupsize:
|
55 |
+
s2_attention:
|
56 |
gptq_model_v1:
|
57 |
warmup_steps: 20
|
58 |
evals_per_epoch: 4
|
@@ -70,11 +70,20 @@ def replace_llama_attn_with_flash_attn(
|
|
70 |
packed: Optional[bool] = False,
|
71 |
cross_entropy: Optional[bool] = False,
|
72 |
rms_norm: Optional[bool] = False,
|
|
|
73 |
):
|
74 |
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
75 |
_prepare_decoder_attention_mask
|
76 |
)
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
if packed:
|
79 |
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
80 |
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
@@ -213,6 +222,136 @@ def _prepare_decoder_attention_mask(
|
|
213 |
return attention_mask
|
214 |
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
def flashattn_forward(
|
217 |
self,
|
218 |
hidden_states: torch.Tensor,
|
|
|
70 |
packed: Optional[bool] = False,
|
71 |
cross_entropy: Optional[bool] = False,
|
72 |
rms_norm: Optional[bool] = False,
|
73 |
+
use_shifted_sparse_attn: Optional[bool] = False,
|
74 |
):
|
75 |
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
76 |
_prepare_decoder_attention_mask
|
77 |
)
|
78 |
+
if use_shifted_sparse_attn:
|
79 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
80 |
+
flashattn_forward_with_s2attn
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
84 |
+
flashattn_forward
|
85 |
+
)
|
86 |
+
|
87 |
if packed:
|
88 |
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
89 |
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
|
|
222 |
return attention_mask
|
223 |
|
224 |
|
225 |
+
GROUP_SIZE_RATIO = 1 / 4
|
226 |
+
|
227 |
+
|
228 |
+
def flashattn_forward_with_s2attn(
|
229 |
+
self,
|
230 |
+
hidden_states: torch.Tensor,
|
231 |
+
attention_mask: Optional[torch.Tensor] = None,
|
232 |
+
position_ids: Optional[torch.Tensor] = None,
|
233 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
234 |
+
output_attentions: bool = False,
|
235 |
+
use_cache: bool = False,
|
236 |
+
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
237 |
+
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
238 |
+
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
239 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
240 |
+
"""Input shape: Batch x Time x Channel
|
241 |
+
|
242 |
+
From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
|
243 |
+
|
244 |
+
attention_mask: [bsz, q_len]
|
245 |
+
|
246 |
+
`cu_seqlens` will be ignored if provided
|
247 |
+
`max_seqlen` will be ignored if provided
|
248 |
+
"""
|
249 |
+
if output_attentions:
|
250 |
+
warnings.warn(
|
251 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
252 |
+
)
|
253 |
+
|
254 |
+
bsz, q_len, _ = hidden_states.size()
|
255 |
+
|
256 |
+
query_states = (
|
257 |
+
self.q_proj(hidden_states)
|
258 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
259 |
+
.transpose(1, 2)
|
260 |
+
)
|
261 |
+
key_states = (
|
262 |
+
self.k_proj(hidden_states)
|
263 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
264 |
+
.transpose(1, 2)
|
265 |
+
)
|
266 |
+
value_states = (
|
267 |
+
self.v_proj(hidden_states)
|
268 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
269 |
+
.transpose(1, 2)
|
270 |
+
)
|
271 |
+
# [bsz, q_len, nh, hd]
|
272 |
+
# [bsz, nh, q_len, hd]
|
273 |
+
# pylint: disable=duplicate-code
|
274 |
+
|
275 |
+
kv_seq_len = key_states.shape[-2]
|
276 |
+
if past_key_value is not None:
|
277 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
278 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
279 |
+
query_states, key_states = apply_rotary_pos_emb(
|
280 |
+
query_states, key_states, cos, sin, position_ids
|
281 |
+
)
|
282 |
+
|
283 |
+
# Past Key value support
|
284 |
+
if past_key_value is not None:
|
285 |
+
# reuse k, v, self_attention
|
286 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
287 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
288 |
+
|
289 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
290 |
+
|
291 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
292 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
293 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
294 |
+
|
295 |
+
# Flash attention codes from
|
296 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
297 |
+
|
298 |
+
# transform the data into the format required by flash attention
|
299 |
+
qkv = torch.stack(
|
300 |
+
[query_states, key_states, value_states], dim=2
|
301 |
+
) # [bsz, nh, 3, q_len, hd]
|
302 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
303 |
+
|
304 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
305 |
+
# the attention_mask should be the same as the key_padding_mask
|
306 |
+
|
307 |
+
key_padding_mask = attention_mask.repeat(2, 1)
|
308 |
+
nheads = qkv.shape[-2]
|
309 |
+
# shift
|
310 |
+
|
311 |
+
group_size = int(q_len * GROUP_SIZE_RATIO)
|
312 |
+
if q_len % group_size > 0:
|
313 |
+
raise ValueError(
|
314 |
+
f"q_len {q_len} should be divisible by group size {group_size}."
|
315 |
+
)
|
316 |
+
|
317 |
+
qkv = (
|
318 |
+
qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)
|
319 |
+
.permute(0, 3, 1, 2, 4, 5)
|
320 |
+
.reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
|
321 |
+
)
|
322 |
+
x = rearrange( # pylint: disable=invalid-name
|
323 |
+
qkv, "b s three h d -> b s (three h d)"
|
324 |
+
)
|
325 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
326 |
+
cu_q_len_tmp = torch.arange(
|
327 |
+
0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
|
328 |
+
)
|
329 |
+
cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(
|
330 |
+
bsz, 1
|
331 |
+
) + cu_q_lens[:-1].unsqueeze(-1)
|
332 |
+
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
|
333 |
+
|
334 |
+
x_unpad = rearrange(
|
335 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
|
336 |
+
)
|
337 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
338 |
+
x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
|
339 |
+
)
|
340 |
+
output = rearrange(
|
341 |
+
pad_input(
|
342 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
|
343 |
+
),
|
344 |
+
"b s (h d) -> b s h d",
|
345 |
+
h=nheads // 2,
|
346 |
+
)
|
347 |
+
output = (
|
348 |
+
output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)
|
349 |
+
.transpose(1, 2)
|
350 |
+
.reshape(bsz, q_len, nheads, self.head_dim)
|
351 |
+
)
|
352 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
353 |
+
|
354 |
+
|
355 |
def flashattn_forward(
|
356 |
self,
|
357 |
hidden_states: torch.Tensor,
|
@@ -256,31 +256,55 @@ def load_model(
|
|
256 |
|
257 |
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
258 |
|
259 |
-
if cfg.
|
260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
262 |
replace_llama_attn_with_flash_attn,
|
263 |
)
|
264 |
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
)
|
271 |
-
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
272 |
-
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
273 |
-
hijack_llama_attention,
|
274 |
-
)
|
275 |
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
|
|
|
|
280 |
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
283 |
|
|
|
284 |
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
285 |
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
286 |
replace_mistral_attn_with_flash_attn,
|
@@ -387,9 +411,12 @@ def load_model(
|
|
387 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
388 |
**bnb_config,
|
389 |
)
|
|
|
390 |
# sample packing uses custom FA2 patch
|
391 |
if cfg.flash_attention:
|
392 |
if not cfg.sample_packing:
|
|
|
|
|
393 |
if (
|
394 |
cfg.is_llama_derived_model
|
395 |
or cfg.is_falcon_derived_model
|
|
|
256 |
|
257 |
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
258 |
|
259 |
+
if cfg.sample_packing and cfg.s2_attention:
|
260 |
+
raise ValueError(
|
261 |
+
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
262 |
+
shifted-sparse attention does not currently support sample packing."
|
263 |
+
)
|
264 |
+
|
265 |
+
# Modify all llama derived models in one block
|
266 |
+
if cfg.is_llama_derived_model:
|
267 |
+
if cfg.flash_attention:
|
268 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
269 |
replace_llama_attn_with_flash_attn,
|
270 |
)
|
271 |
|
272 |
+
if cfg.sample_packing:
|
273 |
+
if cfg.device not in ["mps", "cpu"] and not inference:
|
274 |
+
LOG.info("patching with flash attention for sample packing")
|
275 |
+
replace_llama_attn_with_flash_attn(
|
276 |
+
packed=True,
|
277 |
+
cross_entropy=cfg.flash_attn_cross_entropy,
|
278 |
+
rms_norm=cfg.flash_attn_rms_norm,
|
279 |
+
)
|
280 |
+
elif cfg.s2_attention:
|
281 |
+
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
282 |
+
replace_llama_attn_with_flash_attn(
|
283 |
+
packed=False,
|
284 |
+
cross_entropy=cfg.flash_attn_cross_entropy,
|
285 |
+
rms_norm=cfg.flash_attn_rms_norm,
|
286 |
+
use_shifted_sparse_attn=True,
|
287 |
+
)
|
288 |
+
elif cfg.xformers_attention:
|
289 |
+
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
290 |
+
hijack_llama_attention,
|
291 |
)
|
|
|
|
|
|
|
|
|
292 |
|
293 |
+
LOG.info("patching with xformers attention")
|
294 |
+
hijack_llama_attention()
|
295 |
+
elif cfg.sdp_attention:
|
296 |
+
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
|
297 |
+
hijack_llama_sdp_attention,
|
298 |
+
)
|
299 |
|
300 |
+
LOG.info("patching with sdp attention")
|
301 |
+
hijack_llama_sdp_attention()
|
302 |
+
elif cfg.s2_attention:
|
303 |
+
raise NotImplementedError(
|
304 |
+
"Shifted-sparse attention not currently implemented without flash attention."
|
305 |
+
)
|
306 |
|
307 |
+
# Modify mistral derived models
|
308 |
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
309 |
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
310 |
replace_mistral_attn_with_flash_attn,
|
|
|
411 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
412 |
**bnb_config,
|
413 |
)
|
414 |
+
|
415 |
# sample packing uses custom FA2 patch
|
416 |
if cfg.flash_attention:
|
417 |
if not cfg.sample_packing:
|
418 |
+
if cfg.s2_attention:
|
419 |
+
pass
|
420 |
if (
|
421 |
cfg.is_llama_derived_model
|
422 |
or cfg.is_falcon_derived_model
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
E2E tests for llama w/ S2 attn
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import unittest
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
from axolotl.cli import load_datasets
|
11 |
+
from axolotl.common.cli import TrainerCliArgs
|
12 |
+
from axolotl.train import train
|
13 |
+
from axolotl.utils.config import normalize_config
|
14 |
+
from axolotl.utils.dict import DictDefault
|
15 |
+
|
16 |
+
from ..utils import with_temp_dir
|
17 |
+
|
18 |
+
LOG = logging.getLogger("axolotl.tests.e2e")
|
19 |
+
os.environ["WANDB_DISABLED"] = "true"
|
20 |
+
|
21 |
+
|
22 |
+
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
23 |
+
"""
|
24 |
+
Test case for Llama models using S2 Attn
|
25 |
+
"""
|
26 |
+
|
27 |
+
@with_temp_dir
|
28 |
+
def test_lora_s2_attn(self, temp_dir):
|
29 |
+
# pylint: disable=duplicate-code
|
30 |
+
cfg = DictDefault(
|
31 |
+
{
|
32 |
+
"base_model": "JackFram/llama-68m",
|
33 |
+
"tokenizer_type": "LlamaTokenizer",
|
34 |
+
"sequence_len": 16384,
|
35 |
+
"sample_packing": False,
|
36 |
+
"flash_attention": True,
|
37 |
+
"s2_attention": True,
|
38 |
+
"load_in_8bit": True,
|
39 |
+
"adapter": "lora",
|
40 |
+
"lora_r": 32,
|
41 |
+
"lora_alpha": 16,
|
42 |
+
"lora_dropout": 0.05,
|
43 |
+
"lora_target_linear": True,
|
44 |
+
"val_set_size": 0.1,
|
45 |
+
"special_tokens": {},
|
46 |
+
"datasets": [
|
47 |
+
{
|
48 |
+
"path": "Yukang/LongAlpaca-12k",
|
49 |
+
"type": "alpaca",
|
50 |
+
},
|
51 |
+
],
|
52 |
+
"num_epochs": 2,
|
53 |
+
"micro_batch_size": 1,
|
54 |
+
"gradient_accumulation_steps": 1,
|
55 |
+
"output_dir": temp_dir,
|
56 |
+
"learning_rate": 0.00001,
|
57 |
+
"optimizer": "adamw_torch",
|
58 |
+
"lr_scheduler": "cosine",
|
59 |
+
"max_steps": 10,
|
60 |
+
"save_steps": 5,
|
61 |
+
"eval_steps": 5,
|
62 |
+
"bf16": "auto",
|
63 |
+
}
|
64 |
+
)
|
65 |
+
|
66 |
+
normalize_config(cfg)
|
67 |
+
cli_args = TrainerCliArgs()
|
68 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
69 |
+
|
70 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
71 |
+
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
72 |
+
|
73 |
+
@with_temp_dir
|
74 |
+
def test_fft_s2_attn(self, temp_dir):
|
75 |
+
# pylint: disable=duplicate-code
|
76 |
+
cfg = DictDefault(
|
77 |
+
{
|
78 |
+
"base_model": "JackFram/llama-68m",
|
79 |
+
"tokenizer_type": "LlamaTokenizer",
|
80 |
+
"sequence_len": 16384,
|
81 |
+
"sample_packing": False,
|
82 |
+
"flash_attention": True,
|
83 |
+
"s2_attention": True,
|
84 |
+
"val_set_size": 0.1,
|
85 |
+
"special_tokens": {},
|
86 |
+
"datasets": [
|
87 |
+
{
|
88 |
+
"path": "Yukang/LongAlpaca-12k",
|
89 |
+
"type": "alpaca",
|
90 |
+
},
|
91 |
+
],
|
92 |
+
"num_epochs": 2,
|
93 |
+
"micro_batch_size": 1,
|
94 |
+
"gradient_accumulation_steps": 1,
|
95 |
+
"output_dir": temp_dir,
|
96 |
+
"learning_rate": 0.00001,
|
97 |
+
"optimizer": "adamw_torch",
|
98 |
+
"lr_scheduler": "cosine",
|
99 |
+
"max_steps": 10,
|
100 |
+
"save_steps": 5,
|
101 |
+
"eval_steps": 5,
|
102 |
+
"bf16": "auto",
|
103 |
+
}
|
104 |
+
)
|
105 |
+
|
106 |
+
normalize_config(cfg)
|
107 |
+
cli_args = TrainerCliArgs()
|
108 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
109 |
+
|
110 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
111 |
+
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module for testing models utils file."""
|
2 |
+
|
3 |
+
|
4 |
+
import unittest
|
5 |
+
from unittest.mock import patch
|
6 |
+
|
7 |
+
import pytest
|
8 |
+
|
9 |
+
from axolotl.utils.dict import DictDefault
|
10 |
+
from axolotl.utils.models import load_model
|
11 |
+
|
12 |
+
|
13 |
+
class ModelsUtilsTest(unittest.TestCase):
|
14 |
+
"""Testing module for models utils."""
|
15 |
+
|
16 |
+
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
17 |
+
cfg = DictDefault(
|
18 |
+
{
|
19 |
+
"s2_attention": True,
|
20 |
+
"sample_packing": True,
|
21 |
+
"base_model": "",
|
22 |
+
"model_type": "LlamaForCausalLM",
|
23 |
+
}
|
24 |
+
)
|
25 |
+
|
26 |
+
# Mock out call to HF hub
|
27 |
+
with patch(
|
28 |
+
"axolotl.utils.models.load_model_config"
|
29 |
+
) as mocked_load_model_config:
|
30 |
+
mocked_load_model_config.return_value = {}
|
31 |
+
with pytest.raises(ValueError) as exc:
|
32 |
+
# Should error before hitting tokenizer, so we pass in an empty str
|
33 |
+
load_model(cfg, tokenizer="")
|
34 |
+
assert (
|
35 |
+
"shifted-sparse attention does not currently support sample packing"
|
36 |
+
in str(exc.value)
|
37 |
+
)
|