flash_attention + sample packing for stablelm 3b (#671)
Browse files* stablelm epoch fa patch
* is causal for fa
* working stablelm fa w packing
* chore: pre-commit linting
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
CHANGED
@@ -7,6 +7,7 @@ import logging
|
|
7 |
from typing import Optional, Tuple
|
8 |
|
9 |
import torch
|
|
|
10 |
from flash_attn.flash_attn_interface import flash_attn_func
|
11 |
from transformers import AutoConfig, AutoModelForCausalLM
|
12 |
|
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
|
17 |
# this is a wonky hack to get the remotely loaded module
|
18 |
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
19 |
# we need to load the model here in order for modeling_btlm to be available
|
20 |
-
|
|
|
21 |
module_name = model_config.__class__.__module__.replace(
|
22 |
".configuration_btlm", ".modeling_btlm"
|
23 |
)
|
|
|
7 |
from typing import Optional, Tuple
|
8 |
|
9 |
import torch
|
10 |
+
from accelerate import init_empty_weights
|
11 |
from flash_attn.flash_attn_interface import flash_attn_func
|
12 |
from transformers import AutoConfig, AutoModelForCausalLM
|
13 |
|
|
|
18 |
# this is a wonky hack to get the remotely loaded module
|
19 |
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
20 |
# we need to load the model here in order for modeling_btlm to be available
|
21 |
+
with init_empty_weights():
|
22 |
+
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
23 |
module_name = model_config.__class__.__module__.replace(
|
24 |
".configuration_btlm", ".modeling_btlm"
|
25 |
)
|
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# This code is based off the following work:
|
17 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
18 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
19 |
+
""" PyTorch StableLM Epoch model. """
|
20 |
+
import importlib
|
21 |
+
import math
|
22 |
+
from typing import Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.utils.checkpoint
|
26 |
+
from accelerate import init_empty_weights
|
27 |
+
from einops import rearrange
|
28 |
+
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
29 |
+
flash_attn_varlen_qkvpacked_func,
|
30 |
+
)
|
31 |
+
from torch import nn
|
32 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
33 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
34 |
+
from transformers.utils import logging
|
35 |
+
|
36 |
+
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__)
|
39 |
+
|
40 |
+
|
41 |
+
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
42 |
+
# this is a wonky hack to get the remotely loaded module
|
43 |
+
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
44 |
+
# we need to load the model here in order for modeling_stablelm_epoch to be available
|
45 |
+
with init_empty_weights():
|
46 |
+
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
47 |
+
module_name = model_config.__class__.__module__.replace(
|
48 |
+
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
|
49 |
+
)
|
50 |
+
modeling_stablelm = importlib.import_module(module_name)
|
51 |
+
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
|
52 |
+
flashattn_attn
|
53 |
+
)
|
54 |
+
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
|
55 |
+
stablelm_model_forward
|
56 |
+
)
|
57 |
+
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
|
58 |
+
decoder_layer_forward
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def rotate_half(x: torch.Tensor):
|
63 |
+
"""Rotates half the hidden dims of the input."""
|
64 |
+
# pylint: disable=invalid-name
|
65 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
66 |
+
return torch.cat((-x2, x1), dim=-1)
|
67 |
+
|
68 |
+
|
69 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
70 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
71 |
+
# pylint: disable=invalid-name
|
72 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
73 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
74 |
+
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
75 |
+
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
76 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
77 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
78 |
+
return q_embed, k_embed
|
79 |
+
|
80 |
+
|
81 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
82 |
+
"""
|
83 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
84 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
85 |
+
"""
|
86 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
87 |
+
if n_rep == 1:
|
88 |
+
return hidden_states
|
89 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
90 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
91 |
+
)
|
92 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
93 |
+
|
94 |
+
|
95 |
+
def flashattn_attn(
|
96 |
+
self,
|
97 |
+
hidden_states: torch.FloatTensor,
|
98 |
+
attention_mask: torch.FloatTensor,
|
99 |
+
position_ids: torch.LongTensor,
|
100 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
101 |
+
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
|
102 |
+
use_cache: Optional[bool] = False,
|
103 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
104 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
105 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
106 |
+
bsz, q_len, _ = hidden_states.size()
|
107 |
+
|
108 |
+
query_states = self.q_proj(hidden_states)
|
109 |
+
key_states = self.k_proj(hidden_states)
|
110 |
+
value_states = self.v_proj(hidden_states)
|
111 |
+
|
112 |
+
query_states = query_states.view(
|
113 |
+
bsz, q_len, self.num_heads, self.head_dim
|
114 |
+
).transpose(1, 2)
|
115 |
+
key_states = key_states.view(
|
116 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
117 |
+
).transpose(1, 2)
|
118 |
+
value_states = value_states.view(
|
119 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
120 |
+
).transpose(1, 2)
|
121 |
+
|
122 |
+
query_rot = query_states[..., : self.rotary_ndims]
|
123 |
+
query_pass = query_states[..., self.rotary_ndims :]
|
124 |
+
key_rot = key_states[..., : self.rotary_ndims]
|
125 |
+
key_pass = key_states[..., self.rotary_ndims :]
|
126 |
+
|
127 |
+
kv_seq_len = key_states.shape[-2]
|
128 |
+
if past_key_value is not None:
|
129 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
130 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
131 |
+
query_states, key_states = apply_rotary_pos_emb(
|
132 |
+
query_rot, key_rot, cos, sin, position_ids
|
133 |
+
)
|
134 |
+
|
135 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
136 |
+
query_states = torch.cat((query_states, query_pass), dim=-1)
|
137 |
+
key_states = torch.cat((key_states, key_pass), dim=-1)
|
138 |
+
|
139 |
+
if past_key_value is not None:
|
140 |
+
# Reuse k, v, self_attention
|
141 |
+
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
142 |
+
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
143 |
+
|
144 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
145 |
+
|
146 |
+
# Repeat k/v heads if n_kv_heads < n_heads
|
147 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
148 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
149 |
+
|
150 |
+
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
151 |
+
# special handling using sample packing
|
152 |
+
qkv = torch.stack(
|
153 |
+
[query_states, key_states, value_states], dim=2
|
154 |
+
) # [bsz, nh, 3, q_len, hd]
|
155 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
156 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
157 |
+
softmax_scale = None
|
158 |
+
|
159 |
+
output = flash_attn_varlen_qkvpacked_func(
|
160 |
+
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
|
161 |
+
)
|
162 |
+
|
163 |
+
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
164 |
+
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
165 |
+
else:
|
166 |
+
attn_weights = torch.matmul(
|
167 |
+
query_states, key_states.transpose(2, 3)
|
168 |
+
) / math.sqrt(self.head_dim)
|
169 |
+
|
170 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
171 |
+
raise ValueError(
|
172 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
173 |
+
f" {attn_weights.size()}"
|
174 |
+
)
|
175 |
+
|
176 |
+
if attention_mask is not None:
|
177 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
178 |
+
raise ValueError(
|
179 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
180 |
+
)
|
181 |
+
attn_weights = attn_weights + attention_mask
|
182 |
+
|
183 |
+
# Upcast attention to fp32
|
184 |
+
attn_weights = nn.functional.softmax(
|
185 |
+
attn_weights, dim=-1, dtype=torch.float32
|
186 |
+
).to(query_states.dtype)
|
187 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
188 |
+
|
189 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
190 |
+
raise ValueError(
|
191 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
192 |
+
f" {attn_output.size()}"
|
193 |
+
)
|
194 |
+
|
195 |
+
# Merge heads
|
196 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
197 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
198 |
+
|
199 |
+
# Final linear projection
|
200 |
+
attn_output = self.o_proj(attn_output)
|
201 |
+
|
202 |
+
return attn_output, None, past_key_value
|
203 |
+
|
204 |
+
|
205 |
+
def decoder_layer_forward(
|
206 |
+
self,
|
207 |
+
hidden_states: Optional[torch.FloatTensor],
|
208 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
209 |
+
position_ids: Optional[torch.LongTensor] = None,
|
210 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
211 |
+
output_attentions: Optional[bool] = False,
|
212 |
+
use_cache: Optional[bool] = False,
|
213 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
214 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
215 |
+
) -> Union[
|
216 |
+
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
|
217 |
+
]:
|
218 |
+
# pylint: disable=duplicate-code
|
219 |
+
residual = hidden_states
|
220 |
+
|
221 |
+
hidden_states = self.input_layernorm(hidden_states)
|
222 |
+
|
223 |
+
# Self Attention
|
224 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
225 |
+
hidden_states=hidden_states,
|
226 |
+
attention_mask=attention_mask,
|
227 |
+
position_ids=position_ids,
|
228 |
+
past_key_value=past_key_value,
|
229 |
+
output_attentions=output_attentions,
|
230 |
+
use_cache=use_cache,
|
231 |
+
cu_seqlens=cu_seqlens,
|
232 |
+
max_seqlen=max_seqlen,
|
233 |
+
)
|
234 |
+
hidden_states = residual + hidden_states
|
235 |
+
|
236 |
+
# Fully Connected
|
237 |
+
residual = hidden_states
|
238 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
239 |
+
hidden_states = self.mlp(hidden_states)
|
240 |
+
hidden_states = residual + hidden_states
|
241 |
+
|
242 |
+
outputs = (hidden_states,)
|
243 |
+
|
244 |
+
if output_attentions:
|
245 |
+
outputs += (self_attn_weights,)
|
246 |
+
|
247 |
+
if use_cache:
|
248 |
+
outputs += (present_key_value,)
|
249 |
+
|
250 |
+
return outputs
|
251 |
+
|
252 |
+
|
253 |
+
def stablelm_model_forward(
|
254 |
+
self,
|
255 |
+
input_ids: Optional[torch.LongTensor] = None,
|
256 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
257 |
+
position_ids: Optional[torch.LongTensor] = None,
|
258 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
259 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
260 |
+
use_cache: Optional[bool] = None,
|
261 |
+
output_attentions: Optional[bool] = None,
|
262 |
+
output_hidden_states: Optional[bool] = None,
|
263 |
+
return_dict: Optional[bool] = None,
|
264 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
265 |
+
# pylint: disable=duplicate-code
|
266 |
+
output_attentions = (
|
267 |
+
output_attentions
|
268 |
+
if output_attentions is not None
|
269 |
+
else self.config.output_attentions
|
270 |
+
)
|
271 |
+
output_hidden_states = (
|
272 |
+
output_hidden_states
|
273 |
+
if output_hidden_states is not None
|
274 |
+
else self.config.output_hidden_states
|
275 |
+
)
|
276 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
277 |
+
|
278 |
+
return_dict = (
|
279 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
280 |
+
)
|
281 |
+
|
282 |
+
# Retrieve input_ids and inputs_embeds
|
283 |
+
if input_ids is not None and inputs_embeds is not None:
|
284 |
+
raise ValueError(
|
285 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
286 |
+
)
|
287 |
+
if input_ids is not None:
|
288 |
+
batch_size, seq_length = input_ids.shape
|
289 |
+
elif inputs_embeds is not None:
|
290 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
291 |
+
else:
|
292 |
+
raise ValueError(
|
293 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
294 |
+
)
|
295 |
+
|
296 |
+
seq_length_with_past = seq_length
|
297 |
+
past_key_values_length = 0
|
298 |
+
|
299 |
+
if past_key_values is not None:
|
300 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
301 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
302 |
+
|
303 |
+
cu_seqlens = None
|
304 |
+
max_seqlen = None
|
305 |
+
if position_ids is None:
|
306 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
307 |
+
position_ids = torch.arange(
|
308 |
+
past_key_values_length,
|
309 |
+
seq_length + past_key_values_length,
|
310 |
+
dtype=torch.long,
|
311 |
+
device=device,
|
312 |
+
)
|
313 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
314 |
+
else:
|
315 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
316 |
+
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
317 |
+
cu_seqlens = cu_seqlens.squeeze()
|
318 |
+
|
319 |
+
if inputs_embeds is None:
|
320 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
321 |
+
# Embed positions
|
322 |
+
if attention_mask is None:
|
323 |
+
attention_mask = torch.ones(
|
324 |
+
(batch_size, seq_length_with_past),
|
325 |
+
dtype=torch.bool,
|
326 |
+
device=inputs_embeds.device,
|
327 |
+
)
|
328 |
+
attention_mask = (
|
329 |
+
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
330 |
+
attention_mask,
|
331 |
+
(batch_size, seq_length),
|
332 |
+
inputs_embeds,
|
333 |
+
past_key_values_length,
|
334 |
+
)
|
335 |
+
)
|
336 |
+
|
337 |
+
hidden_states = inputs_embeds
|
338 |
+
|
339 |
+
if self.gradient_checkpointing and self.training:
|
340 |
+
if use_cache:
|
341 |
+
logger.warning(
|
342 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
343 |
+
)
|
344 |
+
use_cache = False
|
345 |
+
|
346 |
+
# Decoder layers
|
347 |
+
all_hidden_states = () if output_hidden_states else None
|
348 |
+
all_self_attns = () if output_attentions else None
|
349 |
+
next_decoder_cache = () if use_cache else None
|
350 |
+
|
351 |
+
for idx, decoder_layer in enumerate(self.layers):
|
352 |
+
if output_hidden_states:
|
353 |
+
all_hidden_states += (hidden_states,)
|
354 |
+
|
355 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
356 |
+
|
357 |
+
if self.gradient_checkpointing and self.training:
|
358 |
+
|
359 |
+
def create_custom_forward(module):
|
360 |
+
def custom_forward(*inputs):
|
361 |
+
# None for past_key_value
|
362 |
+
return module(*inputs)
|
363 |
+
|
364 |
+
return custom_forward
|
365 |
+
|
366 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
367 |
+
create_custom_forward(decoder_layer),
|
368 |
+
hidden_states,
|
369 |
+
attention_mask,
|
370 |
+
position_ids,
|
371 |
+
past_key_value,
|
372 |
+
output_attentions,
|
373 |
+
None,
|
374 |
+
cu_seqlens,
|
375 |
+
max_seqlen,
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
layer_outputs = decoder_layer(
|
379 |
+
hidden_states,
|
380 |
+
attention_mask=attention_mask,
|
381 |
+
position_ids=position_ids,
|
382 |
+
past_key_value=past_key_value,
|
383 |
+
output_attentions=output_attentions,
|
384 |
+
use_cache=use_cache,
|
385 |
+
cu_seqlens=cu_seqlens,
|
386 |
+
max_seqlen=max_seqlen,
|
387 |
+
)
|
388 |
+
|
389 |
+
hidden_states = layer_outputs[0]
|
390 |
+
|
391 |
+
if use_cache:
|
392 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
393 |
+
|
394 |
+
if output_attentions:
|
395 |
+
all_self_attns += (layer_outputs[1],)
|
396 |
+
|
397 |
+
hidden_states = self.norm(hidden_states)
|
398 |
+
|
399 |
+
# Add hidden states from the last decoder layer
|
400 |
+
if output_hidden_states:
|
401 |
+
all_hidden_states += (hidden_states,)
|
402 |
+
|
403 |
+
next_cache = next_decoder_cache if use_cache else None
|
404 |
+
if not return_dict:
|
405 |
+
return tuple(
|
406 |
+
v
|
407 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
408 |
+
if v is not None
|
409 |
+
)
|
410 |
+
return BaseModelOutputWithPast(
|
411 |
+
last_hidden_state=hidden_states,
|
412 |
+
past_key_values=next_cache,
|
413 |
+
hidden_states=all_hidden_states,
|
414 |
+
attentions=all_self_attns,
|
415 |
+
)
|
src/axolotl/utils/models.py
CHANGED
@@ -124,6 +124,17 @@ def load_model(
|
|
124 |
|
125 |
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
128 |
if cfg.device not in ["mps", "cpu"] and not inference:
|
129 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
|
|
124 |
|
125 |
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
126 |
|
127 |
+
if (
|
128 |
+
hasattr(model_config, "model_type")
|
129 |
+
and model_config.model_type == "stablelm_epoch"
|
130 |
+
):
|
131 |
+
if cfg.flash_attention and cfg.sample_packing:
|
132 |
+
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
|
133 |
+
replace_stablelm_attn_with_flash_attn,
|
134 |
+
)
|
135 |
+
|
136 |
+
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
137 |
+
|
138 |
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
139 |
if cfg.device not in ["mps", "cpu"] and not inference:
|
140 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|