RaymondAISG
commited on
Upload sea-lion-7b-gptq
Browse files- adapt_tokenizer.py +43 -0
- attention.py +735 -0
- blocks.py +147 -0
- config.json +74 -0
- configuration_mpt.py +322 -0
- custom_embedding.py +11 -0
- fc.py +9 -0
- ffn.py +173 -0
- flash_attn_triton.py +1085 -0
- gptq_model-4bit-128g.safetensors +3 -0
- hf_prefixlm_converter.py +257 -0
- meta_init_context.py +121 -0
- modeling_mpt.py +907 -0
- norm.py +122 -0
- param_init_fns.py +380 -0
- quantize_config.json +11 -0
- tokenization_SEA_BPE.py +197 -0
- tokenizer.model +3 -0
- tokenizer_config.json +53 -0
- warnings.py +20 -0
adapt_tokenizer.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
3 |
+
|
4 |
+
NUM_SENTINEL_TOKENS: int = 100
|
5 |
+
|
6 |
+
|
7 |
+
def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None:
|
8 |
+
"""Adds sentinel tokens and padding token (if missing).
|
9 |
+
|
10 |
+
Expands the tokenizer vocabulary to include sentinel tokens
|
11 |
+
used in mixture-of-denoiser tasks as well as a padding token.
|
12 |
+
|
13 |
+
All added tokens are added as special tokens. No tokens are
|
14 |
+
added if sentinel tokens and padding token already exist.
|
15 |
+
"""
|
16 |
+
sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
|
17 |
+
tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
|
18 |
+
if tokenizer.pad_token is None:
|
19 |
+
tokenizer.add_tokens("<pad>", special_tokens=True)
|
20 |
+
tokenizer.pad_token = "<pad>"
|
21 |
+
assert tokenizer.pad_token_id is not None
|
22 |
+
sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
|
23 |
+
_sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
|
24 |
+
tokenizer.sentinel_token_ids = _sentinel_token_ids
|
25 |
+
|
26 |
+
|
27 |
+
class AutoTokenizerForMOD(AutoTokenizer):
|
28 |
+
"""AutoTokenizer + Adaptation for MOD.
|
29 |
+
|
30 |
+
A simple wrapper around AutoTokenizer to make instantiating
|
31 |
+
an MOD-adapted tokenizer a bit easier.
|
32 |
+
|
33 |
+
MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
|
34 |
+
a padding token, and a property to get the token ids of the
|
35 |
+
sentinel tokens.
|
36 |
+
"""
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def from_pretrained(cls, *args: Any, **kwargs: Any) -> PreTrainedTokenizerBase:
|
40 |
+
"""See `AutoTokenizer.from_pretrained` docstring."""
|
41 |
+
tokenizer = super().from_pretrained(*args, **kwargs)
|
42 |
+
adapt_tokenizer_for_denoising(tokenizer)
|
43 |
+
return tokenizer
|
attention.py
ADDED
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Attention layers."""
|
2 |
+
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Optional
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import transformers
|
9 |
+
from einops import rearrange
|
10 |
+
from packaging import version
|
11 |
+
from torch import nn
|
12 |
+
from .fc import FC_CLASS_REGISTRY
|
13 |
+
from .norm import NORM_CLASS_REGISTRY
|
14 |
+
|
15 |
+
|
16 |
+
def is_flash_v2_installed(v2_version: str = "2.0.0"):
|
17 |
+
assert version.parse(v2_version) >= version.parse("2.0.0")
|
18 |
+
try:
|
19 |
+
import flash_attn as flash_attn
|
20 |
+
except:
|
21 |
+
return False
|
22 |
+
return version.parse(flash_attn.__version__) >= version.parse(v2_version)
|
23 |
+
|
24 |
+
|
25 |
+
def is_flash_v1_installed():
|
26 |
+
try:
|
27 |
+
import flash_attn as flash_attn
|
28 |
+
except:
|
29 |
+
return False
|
30 |
+
return version.parse(flash_attn.__version__) < version.parse("2.0.0")
|
31 |
+
|
32 |
+
|
33 |
+
def is_transformers_version_gte(hf_version: str) -> bool:
|
34 |
+
return version.parse(transformers.__version__) >= version.parse(hf_version)
|
35 |
+
|
36 |
+
|
37 |
+
def check_alibi_support(attention_impl: str) -> bool:
|
38 |
+
return attention_impl != "flash" or is_flash_v2_installed(v2_version="v2.4.2")
|
39 |
+
|
40 |
+
|
41 |
+
if is_flash_v1_installed():
|
42 |
+
import transformers
|
43 |
+
|
44 |
+
transformers.utils.is_flash_attn_available = lambda: False
|
45 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
46 |
+
|
47 |
+
|
48 |
+
def _reset_is_causal(
|
49 |
+
num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
|
50 |
+
) -> bool:
|
51 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
52 |
+
if num_query_tokens != 1:
|
53 |
+
raise NotImplementedError(
|
54 |
+
"MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
|
55 |
+
)
|
56 |
+
else:
|
57 |
+
return False
|
58 |
+
return original_is_causal
|
59 |
+
|
60 |
+
|
61 |
+
def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
|
62 |
+
"""Perform repeat of kv heads along a particular dimension.
|
63 |
+
|
64 |
+
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
|
65 |
+
n_rep: amount of repetitions of kv_n_heads
|
66 |
+
Unlike torch.repeat_interleave, this function avoids allocating new memory.
|
67 |
+
"""
|
68 |
+
if n_rep == 1:
|
69 |
+
return hidden
|
70 |
+
(b, s, kv_n_heads, d) = hidden.shape
|
71 |
+
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
72 |
+
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
73 |
+
|
74 |
+
|
75 |
+
def scaled_multihead_dot_product_attention(
|
76 |
+
query: torch.Tensor,
|
77 |
+
key: torch.Tensor,
|
78 |
+
value: torch.Tensor,
|
79 |
+
n_heads: int,
|
80 |
+
kv_n_heads: int,
|
81 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
82 |
+
softmax_scale: Optional[float] = None,
|
83 |
+
attn_bias: Optional[torch.Tensor] = None,
|
84 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
85 |
+
is_causal: bool = False,
|
86 |
+
dropout_p: float = 0.0,
|
87 |
+
training: bool = False,
|
88 |
+
needs_weights: bool = False,
|
89 |
+
) -> tuple[
|
90 |
+
torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
|
91 |
+
]:
|
92 |
+
q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
|
93 |
+
k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
|
94 |
+
v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
|
95 |
+
if past_key_value is not None:
|
96 |
+
if len(past_key_value) != 0:
|
97 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
98 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
99 |
+
past_key_value = (k, v)
|
100 |
+
(b, _, s_q, d) = q.shape
|
101 |
+
s_k = k.size(-1)
|
102 |
+
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
103 |
+
k = repeat_kv_for_gqa(k.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
104 |
+
v = repeat_kv_for_gqa(v.transpose(1, 2), n_heads // kv_n_heads).transpose(1, 2)
|
105 |
+
if softmax_scale is None:
|
106 |
+
softmax_scale = 1 / math.sqrt(d)
|
107 |
+
attn_weight = q.matmul(k) * softmax_scale
|
108 |
+
if attn_bias is not None:
|
109 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
110 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
111 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
112 |
+
if (
|
113 |
+
attn_bias.size(-1) != 1
|
114 |
+
and attn_bias.size(-1) != s_k
|
115 |
+
or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
|
116 |
+
):
|
117 |
+
raise RuntimeError(
|
118 |
+
f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
|
119 |
+
)
|
120 |
+
attn_weight = attn_weight + attn_bias
|
121 |
+
min_val = torch.finfo(q.dtype).min
|
122 |
+
if key_padding_mask is not None:
|
123 |
+
if attn_bias is not None:
|
124 |
+
warnings.warn(
|
125 |
+
"Propagating key_padding_mask to the attention module "
|
126 |
+
+ "and applying it within the attention module can cause "
|
127 |
+
+ "unnecessary computation/memory usage. Consider integrating "
|
128 |
+
+ "into attn_bias once and passing that to each attention "
|
129 |
+
+ "module instead."
|
130 |
+
)
|
131 |
+
attn_weight = attn_weight.masked_fill(
|
132 |
+
~key_padding_mask.view((b, 1, 1, s_k)), min_val
|
133 |
+
)
|
134 |
+
if is_causal and (not q.size(2) == 1):
|
135 |
+
s = max(s_q, s_k)
|
136 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
|
137 |
+
causal_mask = causal_mask.tril()
|
138 |
+
causal_mask = causal_mask.to(torch.bool)
|
139 |
+
causal_mask = ~causal_mask
|
140 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
141 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
142 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
143 |
+
if dropout_p:
|
144 |
+
attn_weight = torch.nn.functional.dropout(
|
145 |
+
attn_weight, p=dropout_p, training=training, inplace=True
|
146 |
+
)
|
147 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
148 |
+
out = rearrange(out, "b h s d -> b s (h d)")
|
149 |
+
if needs_weights:
|
150 |
+
return (out, attn_weight, past_key_value)
|
151 |
+
return (out, None, past_key_value)
|
152 |
+
|
153 |
+
|
154 |
+
def check_valid_inputs(
|
155 |
+
*tensors: torch.Tensor, valid_dtypes: Optional[list[torch.dtype]] = None
|
156 |
+
):
|
157 |
+
if valid_dtypes is None:
|
158 |
+
valid_dtypes = [torch.float16, torch.bfloat16]
|
159 |
+
for tensor in tensors:
|
160 |
+
if tensor.dtype not in valid_dtypes:
|
161 |
+
raise TypeError(
|
162 |
+
f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
|
163 |
+
)
|
164 |
+
if not tensor.is_cuda:
|
165 |
+
raise TypeError(
|
166 |
+
f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
def flash_attn_fn(
|
171 |
+
query: torch.Tensor,
|
172 |
+
key: torch.Tensor,
|
173 |
+
value: torch.Tensor,
|
174 |
+
n_heads: int,
|
175 |
+
kv_n_heads: int,
|
176 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
177 |
+
softmax_scale: Optional[float] = None,
|
178 |
+
attn_bias: Optional[torch.Tensor] = None,
|
179 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
180 |
+
is_causal: bool = False,
|
181 |
+
dropout_p: float = 0.0,
|
182 |
+
training: bool = False,
|
183 |
+
needs_weights: bool = False,
|
184 |
+
multiquery: bool = False,
|
185 |
+
should_repeat_kv_for_gqa: Optional[bool] = True,
|
186 |
+
sliding_window_size: int = -1,
|
187 |
+
alibi_slopes: Optional[torch.Tensor] = None,
|
188 |
+
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
|
189 |
+
) -> tuple[
|
190 |
+
torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
|
191 |
+
]:
|
192 |
+
if key_padding_mask is not None:
|
193 |
+
raise ValueError("key_padding_mask should be None for flash attn.")
|
194 |
+
del key_padding_mask
|
195 |
+
if flash_attn_padding_info is None:
|
196 |
+
raise ValueError("flash_attn_padding_info is required for flash attn.")
|
197 |
+
try:
|
198 |
+
from flash_attn import bert_padding, flash_attn_interface
|
199 |
+
except:
|
200 |
+
raise RuntimeError("Please install flash-attn==1.0.9 or flash-attn==2.3.6")
|
201 |
+
check_valid_inputs(query, key, value)
|
202 |
+
if past_key_value is not None:
|
203 |
+
if len(past_key_value) != 0:
|
204 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
205 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
206 |
+
past_key_value = (key, value)
|
207 |
+
if attn_bias is not None:
|
208 |
+
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
|
209 |
+
(batch_size, seqlen) = query.shape[:2]
|
210 |
+
indices_q = flash_attn_padding_info["indices_q"]
|
211 |
+
indices_k = flash_attn_padding_info["indices_k"]
|
212 |
+
indices_v = flash_attn_padding_info["indices_v"]
|
213 |
+
cu_seqlens_q = flash_attn_padding_info["cu_seqlens_q"]
|
214 |
+
cu_seqlens_k = flash_attn_padding_info["cu_seqlens_k"]
|
215 |
+
max_seqlen_q = flash_attn_padding_info["max_seqlen_q"]
|
216 |
+
max_seqlen_k = flash_attn_padding_info["max_seqlen_k"]
|
217 |
+
query_unpad = bert_padding.index_first_axis(
|
218 |
+
rearrange(query, "b s ... -> (b s) ..."), indices_q
|
219 |
+
)
|
220 |
+
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
|
221 |
+
key_unpad = bert_padding.index_first_axis(
|
222 |
+
rearrange(key, "b s ... -> (b s) ..."), indices_k
|
223 |
+
)
|
224 |
+
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)
|
225 |
+
value_unpad = bert_padding.index_first_axis(
|
226 |
+
rearrange(value, "b s ... -> (b s) ..."), indices_v
|
227 |
+
)
|
228 |
+
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)
|
229 |
+
if (
|
230 |
+
kv_n_heads < n_heads
|
231 |
+
and (not is_flash_v2_installed())
|
232 |
+
and (not should_repeat_kv_for_gqa)
|
233 |
+
):
|
234 |
+
raise ValueError(
|
235 |
+
"For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2."
|
236 |
+
)
|
237 |
+
if should_repeat_kv_for_gqa:
|
238 |
+
if kv_n_heads == 1:
|
239 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
240 |
+
value_unpad = value_unpad.expand(
|
241 |
+
value_unpad.size(0), n_heads, value_unpad.size(-1)
|
242 |
+
)
|
243 |
+
elif kv_n_heads < n_heads:
|
244 |
+
key_unpad = repeat_kv_for_gqa(
|
245 |
+
key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
|
246 |
+
n_heads // kv_n_heads,
|
247 |
+
).view(key_unpad.size(0), n_heads, -1)
|
248 |
+
value_unpad = repeat_kv_for_gqa(
|
249 |
+
value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
|
250 |
+
n_heads // kv_n_heads,
|
251 |
+
).view(value_unpad.size(0), n_heads, -1)
|
252 |
+
dropout_p = dropout_p if training else 0.0
|
253 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
254 |
+
if is_flash_v1_installed():
|
255 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
|
256 |
+
q=query_unpad,
|
257 |
+
k=key_unpad,
|
258 |
+
v=value_unpad,
|
259 |
+
cu_seqlens_q=cu_seqlens_q,
|
260 |
+
cu_seqlens_k=cu_seqlens_k,
|
261 |
+
max_seqlen_q=max_seqlen_q,
|
262 |
+
max_seqlen_k=max_seqlen_k,
|
263 |
+
dropout_p=dropout_p,
|
264 |
+
softmax_scale=softmax_scale,
|
265 |
+
causal=reset_is_causal,
|
266 |
+
return_attn_probs=needs_weights,
|
267 |
+
)
|
268 |
+
elif is_flash_v2_installed():
|
269 |
+
alibi_kwargs = {}
|
270 |
+
if check_alibi_support("flash"):
|
271 |
+
alibi_kwargs = {"alibi_slopes": alibi_slopes}
|
272 |
+
elif alibi_slopes is not None:
|
273 |
+
raise ValueError("alibi_slopes is only supported for flash-attn>=2.4.2")
|
274 |
+
output_unpad = flash_attn_interface.flash_attn_varlen_func(
|
275 |
+
q=query_unpad,
|
276 |
+
k=key_unpad,
|
277 |
+
v=value_unpad,
|
278 |
+
cu_seqlens_q=cu_seqlens_q,
|
279 |
+
cu_seqlens_k=cu_seqlens_k,
|
280 |
+
max_seqlen_q=max_seqlen_q,
|
281 |
+
max_seqlen_k=max_seqlen_k,
|
282 |
+
dropout_p=dropout_p,
|
283 |
+
softmax_scale=softmax_scale,
|
284 |
+
causal=reset_is_causal,
|
285 |
+
return_attn_probs=needs_weights,
|
286 |
+
window_size=(sliding_window_size, sliding_window_size),
|
287 |
+
**alibi_kwargs,
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
raise RuntimeError("flash-attn==1.0.9 or flash-attn==2.4.2 is required.")
|
291 |
+
output = bert_padding.pad_input(
|
292 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
|
293 |
+
)
|
294 |
+
return (output, None, past_key_value)
|
295 |
+
|
296 |
+
|
297 |
+
def triton_flash_attn_fn(
|
298 |
+
query: torch.Tensor,
|
299 |
+
key: torch.Tensor,
|
300 |
+
value: torch.Tensor,
|
301 |
+
n_heads: int,
|
302 |
+
kv_n_heads: int,
|
303 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
304 |
+
softmax_scale: Optional[float] = None,
|
305 |
+
attn_bias: Optional[torch.Tensor] = None,
|
306 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
307 |
+
is_causal: bool = False,
|
308 |
+
dropout_p: float = 0.0,
|
309 |
+
training: bool = False,
|
310 |
+
needs_weights: bool = False,
|
311 |
+
) -> tuple[
|
312 |
+
torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]
|
313 |
+
]:
|
314 |
+
try:
|
315 |
+
from .flash_attn_triton import flash_attn_func
|
316 |
+
except:
|
317 |
+
_installed = False
|
318 |
+
if version.parse(torch.__version__) < version.parse("2.0.0"):
|
319 |
+
_installed = True
|
320 |
+
try:
|
321 |
+
from flash_attn.flash_attn_triton import flash_attn_func
|
322 |
+
except:
|
323 |
+
_installed = False
|
324 |
+
if not _installed:
|
325 |
+
raise RuntimeError(
|
326 |
+
"Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU "
|
327 |
+
+ "and `pip install .[gpu]` if installing from llm-foundry source or "
|
328 |
+
+ "`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` "
|
329 |
+
+ "if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). "
|
330 |
+
+ "Note: (1) requires you have CMake and PyTorch already installed."
|
331 |
+
)
|
332 |
+
check_valid_inputs(query, key, value)
|
333 |
+
if past_key_value is not None:
|
334 |
+
if len(past_key_value) != 0:
|
335 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
336 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
337 |
+
past_key_value = (key, value)
|
338 |
+
if attn_bias is not None:
|
339 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
340 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
341 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
342 |
+
if dropout_p:
|
343 |
+
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
|
344 |
+
dropout_p = dropout_p if training else 0.0
|
345 |
+
if needs_weights:
|
346 |
+
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
|
347 |
+
if key_padding_mask is not None:
|
348 |
+
warnings.warn(
|
349 |
+
"Propagating key_padding_mask to the attention module "
|
350 |
+
+ "and applying it within the attention module can cause "
|
351 |
+
+ "unnecessary computation/memory usage. Consider integrating "
|
352 |
+
+ "into attn_bias once and passing that to each attention "
|
353 |
+
+ "module instead."
|
354 |
+
)
|
355 |
+
(b_size, s_k) = key_padding_mask.shape[:2]
|
356 |
+
if attn_bias is None:
|
357 |
+
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
358 |
+
attn_bias = attn_bias.masked_fill(
|
359 |
+
~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
|
360 |
+
)
|
361 |
+
query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
|
362 |
+
key = rearrange(key, "b s (h d) -> b s h d", h=kv_n_heads)
|
363 |
+
value = rearrange(value, "b s (h d) -> b s h d", h=kv_n_heads)
|
364 |
+
if kv_n_heads == 1:
|
365 |
+
key = key.repeat(1, 1, n_heads, 1)
|
366 |
+
value = value.repeat(1, 1, n_heads, 1)
|
367 |
+
elif kv_n_heads < n_heads:
|
368 |
+
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
|
369 |
+
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
|
370 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
371 |
+
attn_output = flash_attn_func(
|
372 |
+
query, key, value, attn_bias, reset_is_causal, softmax_scale
|
373 |
+
)
|
374 |
+
output = attn_output.view(*attn_output.shape[:2], -1)
|
375 |
+
return (output, None, past_key_value)
|
376 |
+
|
377 |
+
|
378 |
+
class GroupedQueryAttention(nn.Module):
|
379 |
+
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
380 |
+
|
381 |
+
and Multi-query attention (MQA).
|
382 |
+
|
383 |
+
This allows the user to set a variable of number of kv_n_heads, rather than
|
384 |
+
just n_heads or 1, as in MHA and MQA. Using torch or triton attention
|
385 |
+
implementation enables user to also use additive bias.
|
386 |
+
"""
|
387 |
+
|
388 |
+
def __init__(
|
389 |
+
self,
|
390 |
+
d_model: int,
|
391 |
+
n_heads: int,
|
392 |
+
kv_n_heads: int,
|
393 |
+
attn_impl: str = "triton",
|
394 |
+
clip_qkv: Optional[float] = None,
|
395 |
+
qk_ln: bool = False,
|
396 |
+
qk_gn: bool = False,
|
397 |
+
softmax_scale: Optional[float] = None,
|
398 |
+
attn_pdrop: float = 0.0,
|
399 |
+
norm_type: str = "low_precision_layernorm",
|
400 |
+
fc_type: str = "torch",
|
401 |
+
device: Optional[str] = None,
|
402 |
+
bias: bool = True,
|
403 |
+
sliding_window_size: int = -1,
|
404 |
+
):
|
405 |
+
super().__init__()
|
406 |
+
self.attn_impl = attn_impl
|
407 |
+
self.clip_qkv = clip_qkv
|
408 |
+
self.qk_ln = qk_ln
|
409 |
+
self.qk_gn = qk_gn
|
410 |
+
self.d_model = d_model
|
411 |
+
self.n_heads = n_heads
|
412 |
+
self.kv_n_heads = kv_n_heads
|
413 |
+
self.sliding_window_size = sliding_window_size
|
414 |
+
self.head_dim = d_model // n_heads
|
415 |
+
if self.kv_n_heads <= 0:
|
416 |
+
raise ValueError("kv_n_heads should be greater than zero.")
|
417 |
+
if self.kv_n_heads > self.n_heads:
|
418 |
+
raise ValueError(
|
419 |
+
"The number of KV heads should be less than or equal to Q heads."
|
420 |
+
)
|
421 |
+
if self.n_heads % self.kv_n_heads != 0:
|
422 |
+
raise ValueError(
|
423 |
+
"Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads."
|
424 |
+
)
|
425 |
+
if qk_ln and qk_gn:
|
426 |
+
raise ValueError("Only one of qk_ln and qk_gn can be set to True.")
|
427 |
+
self.softmax_scale = softmax_scale
|
428 |
+
if self.softmax_scale is None:
|
429 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
430 |
+
self.attn_dropout_p = attn_pdrop
|
431 |
+
fc_kwargs: dict[str, Any] = {"bias": bias}
|
432 |
+
if fc_type != "te":
|
433 |
+
fc_kwargs["device"] = device
|
434 |
+
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
|
435 |
+
self.d_model,
|
436 |
+
self.d_model + 2 * self.kv_n_heads * self.head_dim,
|
437 |
+
**fc_kwargs,
|
438 |
+
)
|
439 |
+
fuse_splits = [
|
440 |
+
i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads)
|
441 |
+
]
|
442 |
+
self.Wqkv._fused = (0, fuse_splits)
|
443 |
+
if self.qk_ln or self.qk_gn:
|
444 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
445 |
+
norm_size = self.head_dim if qk_gn else d_model
|
446 |
+
self.q_ln = norm_class(norm_size, device=device)
|
447 |
+
if qk_ln:
|
448 |
+
norm_size = self.head_dim * kv_n_heads
|
449 |
+
self.k_ln = norm_class(norm_size, device=device)
|
450 |
+
if self.attn_impl == "flash":
|
451 |
+
self.attn_fn = flash_attn_fn
|
452 |
+
elif self.attn_impl == "triton":
|
453 |
+
self.attn_fn = triton_flash_attn_fn
|
454 |
+
elif self.attn_impl == "torch":
|
455 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
456 |
+
else:
|
457 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
458 |
+
self.out_proj = FC_CLASS_REGISTRY[fc_type](
|
459 |
+
self.d_model, self.d_model, **fc_kwargs
|
460 |
+
)
|
461 |
+
self.out_proj._is_residual = True
|
462 |
+
|
463 |
+
def forward(
|
464 |
+
self,
|
465 |
+
x: torch.Tensor,
|
466 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
467 |
+
attn_bias: Optional[torch.Tensor] = None,
|
468 |
+
attention_mask: Optional[torch.Tensor] = None,
|
469 |
+
rotary_emb_w_meta_info: Optional[dict] = None,
|
470 |
+
is_causal: bool = True,
|
471 |
+
needs_weights: bool = False,
|
472 |
+
alibi_slopes: Optional[torch.Tensor] = None,
|
473 |
+
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
|
474 |
+
) -> tuple[
|
475 |
+
torch.Tensor,
|
476 |
+
Optional[torch.Tensor],
|
477 |
+
Optional[tuple[torch.Tensor, torch.Tensor]],
|
478 |
+
]:
|
479 |
+
qkv = self.Wqkv(x)
|
480 |
+
if self.clip_qkv:
|
481 |
+
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
482 |
+
(query, key, value) = qkv.split(
|
483 |
+
[
|
484 |
+
self.d_model,
|
485 |
+
self.kv_n_heads * self.head_dim,
|
486 |
+
self.kv_n_heads * self.head_dim,
|
487 |
+
],
|
488 |
+
dim=2,
|
489 |
+
)
|
490 |
+
key_padding_mask = attention_mask
|
491 |
+
if self.qk_ln or self.qk_gn:
|
492 |
+
(q_shape, k_shape) = (query.shape, key.shape)
|
493 |
+
if self.qk_gn:
|
494 |
+
(b, s) = query.shape[:2]
|
495 |
+
query = query.view(b, s, self.n_heads, -1)
|
496 |
+
key = key.view(b, s, self.kv_n_heads, -1)
|
497 |
+
dtype = query.dtype
|
498 |
+
query = self.q_ln(query).to(dtype).view(q_shape)
|
499 |
+
key = self.k_ln(key).to(dtype).view(k_shape)
|
500 |
+
if rotary_emb_w_meta_info is not None:
|
501 |
+
rotary_emb = rotary_emb_w_meta_info["rotary_emb"]
|
502 |
+
seq_len = rotary_emb_w_meta_info["seq_len"]
|
503 |
+
offset_info = rotary_emb_w_meta_info["offset_info"]
|
504 |
+
(bsz, seqlen) = query.shape[:2]
|
505 |
+
query = query.view(bsz, seqlen, -1, self.head_dim)
|
506 |
+
key = key.view(bsz, seqlen, -1, self.head_dim)
|
507 |
+
if rotary_emb_w_meta_info["impl"] == "dail":
|
508 |
+
value = value.view(bsz, seqlen, -1, self.head_dim)
|
509 |
+
kv = torch.stack([key, value], dim=2)
|
510 |
+
(query, kv) = rotary_emb(
|
511 |
+
query, kv, seqlen_offset=offset_info, max_seqlen=seq_len
|
512 |
+
)
|
513 |
+
[key, value] = torch.unbind(kv, dim=2)
|
514 |
+
value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
515 |
+
elif rotary_emb_w_meta_info["impl"] == "hf":
|
516 |
+
(cos, sin) = rotary_emb(value, seq_len)
|
517 |
+
if is_transformers_version_gte("4.36"):
|
518 |
+
(query, key) = apply_rotary_pos_emb(
|
519 |
+
query, key, cos, sin, offset_info, unsqueeze_dim=2
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
query = query.transpose(1, 2)
|
523 |
+
key = key.transpose(1, 2)
|
524 |
+
(query, key) = apply_rotary_pos_emb(
|
525 |
+
query, key, cos, sin, offset_info
|
526 |
+
)
|
527 |
+
query = query.transpose(1, 2)
|
528 |
+
key = key.transpose(1, 2)
|
529 |
+
query = query.view(bsz, seqlen, self.d_model)
|
530 |
+
key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
|
531 |
+
extra_attn_kwargs = {}
|
532 |
+
if self.attn_impl == "flash":
|
533 |
+
key_padding_mask = None
|
534 |
+
extra_attn_kwargs = {
|
535 |
+
"should_repeat_kv_for_gqa": not is_flash_v2_installed(),
|
536 |
+
"sliding_window_size": self.sliding_window_size,
|
537 |
+
"alibi_slopes": alibi_slopes,
|
538 |
+
"flash_attn_padding_info": flash_attn_padding_info,
|
539 |
+
}
|
540 |
+
(context, attn_weights, past_key_value) = self.attn_fn(
|
541 |
+
query,
|
542 |
+
key,
|
543 |
+
value,
|
544 |
+
self.n_heads,
|
545 |
+
self.kv_n_heads,
|
546 |
+
past_key_value=past_key_value,
|
547 |
+
softmax_scale=self.softmax_scale,
|
548 |
+
attn_bias=attn_bias,
|
549 |
+
key_padding_mask=key_padding_mask,
|
550 |
+
is_causal=is_causal,
|
551 |
+
dropout_p=self.attn_dropout_p,
|
552 |
+
training=self.training,
|
553 |
+
needs_weights=needs_weights,
|
554 |
+
**extra_attn_kwargs,
|
555 |
+
)
|
556 |
+
return (self.out_proj(context), attn_weights, past_key_value)
|
557 |
+
|
558 |
+
|
559 |
+
class MultiheadAttention(GroupedQueryAttention):
|
560 |
+
"""Multi-head self attention.
|
561 |
+
|
562 |
+
Using torch or triton attention implementation enables user to also use
|
563 |
+
additive bias.
|
564 |
+
"""
|
565 |
+
|
566 |
+
def __init__(
|
567 |
+
self,
|
568 |
+
d_model: int,
|
569 |
+
n_heads: int,
|
570 |
+
attn_impl: str = "triton",
|
571 |
+
clip_qkv: Optional[float] = None,
|
572 |
+
qk_ln: bool = False,
|
573 |
+
qk_gn: bool = False,
|
574 |
+
softmax_scale: Optional[float] = None,
|
575 |
+
attn_pdrop: float = 0.0,
|
576 |
+
norm_type: str = "low_precision_layernorm",
|
577 |
+
fc_type: str = "torch",
|
578 |
+
device: Optional[str] = None,
|
579 |
+
bias: bool = True,
|
580 |
+
sliding_window_size: int = -1,
|
581 |
+
):
|
582 |
+
super().__init__(
|
583 |
+
d_model=d_model,
|
584 |
+
n_heads=n_heads,
|
585 |
+
kv_n_heads=n_heads,
|
586 |
+
attn_impl=attn_impl,
|
587 |
+
clip_qkv=clip_qkv,
|
588 |
+
qk_ln=qk_ln,
|
589 |
+
qk_gn=qk_gn,
|
590 |
+
softmax_scale=softmax_scale,
|
591 |
+
attn_pdrop=attn_pdrop,
|
592 |
+
norm_type=norm_type,
|
593 |
+
fc_type=fc_type,
|
594 |
+
device=device,
|
595 |
+
bias=bias,
|
596 |
+
sliding_window_size=sliding_window_size,
|
597 |
+
)
|
598 |
+
|
599 |
+
|
600 |
+
class MultiQueryAttention(GroupedQueryAttention):
|
601 |
+
"""Multi-Query self attention.
|
602 |
+
|
603 |
+
Using torch or triton attention implementation enables user to also use
|
604 |
+
additive bias.
|
605 |
+
"""
|
606 |
+
|
607 |
+
def __init__(
|
608 |
+
self,
|
609 |
+
d_model: int,
|
610 |
+
n_heads: int,
|
611 |
+
attn_impl: str = "triton",
|
612 |
+
clip_qkv: Optional[float] = None,
|
613 |
+
qk_ln: bool = False,
|
614 |
+
qk_gn: bool = False,
|
615 |
+
softmax_scale: Optional[float] = None,
|
616 |
+
attn_pdrop: float = 0.0,
|
617 |
+
norm_type: str = "low_precision_layernorm",
|
618 |
+
fc_type: str = "torch",
|
619 |
+
device: Optional[str] = None,
|
620 |
+
bias: bool = True,
|
621 |
+
sliding_window_size: int = -1,
|
622 |
+
):
|
623 |
+
super().__init__(
|
624 |
+
d_model=d_model,
|
625 |
+
n_heads=n_heads,
|
626 |
+
kv_n_heads=1,
|
627 |
+
attn_impl=attn_impl,
|
628 |
+
clip_qkv=clip_qkv,
|
629 |
+
qk_ln=qk_ln,
|
630 |
+
qk_gn=qk_gn,
|
631 |
+
softmax_scale=softmax_scale,
|
632 |
+
attn_pdrop=attn_pdrop,
|
633 |
+
norm_type=norm_type,
|
634 |
+
fc_type=fc_type,
|
635 |
+
device=device,
|
636 |
+
bias=bias,
|
637 |
+
sliding_window_size=sliding_window_size,
|
638 |
+
)
|
639 |
+
|
640 |
+
|
641 |
+
def attn_bias_shape(
|
642 |
+
attn_impl: str,
|
643 |
+
n_heads: int,
|
644 |
+
seq_len: int,
|
645 |
+
alibi: bool,
|
646 |
+
prefix_lm: bool,
|
647 |
+
causal: bool,
|
648 |
+
use_sequence_id: bool,
|
649 |
+
) -> Optional[tuple[int, int, int, int]]:
|
650 |
+
if attn_impl == "flash":
|
651 |
+
return None
|
652 |
+
elif attn_impl in ["torch", "triton"]:
|
653 |
+
if alibi:
|
654 |
+
if (prefix_lm or not causal) or use_sequence_id:
|
655 |
+
return (1, n_heads, seq_len, seq_len)
|
656 |
+
return (1, n_heads, 1, seq_len)
|
657 |
+
elif prefix_lm or use_sequence_id:
|
658 |
+
return (1, 1, seq_len, seq_len)
|
659 |
+
return None
|
660 |
+
else:
|
661 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
662 |
+
|
663 |
+
|
664 |
+
def build_attn_bias(
|
665 |
+
attn_impl: str,
|
666 |
+
attn_bias: torch.Tensor,
|
667 |
+
n_heads: int,
|
668 |
+
seq_len: int,
|
669 |
+
causal: bool = False,
|
670 |
+
alibi: bool = False,
|
671 |
+
alibi_bias_max: int = 8,
|
672 |
+
) -> Optional[torch.Tensor]:
|
673 |
+
if attn_impl == "flash":
|
674 |
+
return None
|
675 |
+
elif attn_impl in ["torch", "triton"]:
|
676 |
+
if alibi:
|
677 |
+
(device, dtype) = (attn_bias.device, attn_bias.dtype)
|
678 |
+
attn_bias = attn_bias.add(
|
679 |
+
build_alibi_bias(
|
680 |
+
n_heads,
|
681 |
+
seq_len,
|
682 |
+
full=not causal,
|
683 |
+
alibi_bias_max=alibi_bias_max,
|
684 |
+
device=device,
|
685 |
+
dtype=dtype,
|
686 |
+
)
|
687 |
+
)
|
688 |
+
return attn_bias
|
689 |
+
else:
|
690 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
691 |
+
|
692 |
+
|
693 |
+
def gen_slopes(
|
694 |
+
n_heads: int,
|
695 |
+
alibi_bias_max: int = 8,
|
696 |
+
device: Optional[torch.device] = None,
|
697 |
+
return_1d: bool = False,
|
698 |
+
) -> torch.Tensor:
|
699 |
+
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
700 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
701 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
702 |
+
slopes = 1.0 / torch.pow(2, m)
|
703 |
+
if _n_heads != n_heads:
|
704 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
705 |
+
if return_1d:
|
706 |
+
return slopes
|
707 |
+
return slopes.view(1, n_heads, 1, 1)
|
708 |
+
|
709 |
+
|
710 |
+
def build_alibi_bias(
|
711 |
+
n_heads: int,
|
712 |
+
seq_len: int,
|
713 |
+
full: bool = False,
|
714 |
+
alibi_bias_max: int = 8,
|
715 |
+
device: Optional[torch.device] = None,
|
716 |
+
dtype: Optional[torch.dtype] = None,
|
717 |
+
) -> torch.Tensor:
|
718 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
|
719 |
+
1, 1, 1, seq_len
|
720 |
+
)
|
721 |
+
if full:
|
722 |
+
alibi_bias = alibi_bias - torch.arange(
|
723 |
+
1 - seq_len, 1, dtype=torch.int32, device=device
|
724 |
+
).view(1, 1, seq_len, 1)
|
725 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
726 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
727 |
+
alibi_bias = alibi_bias * slopes
|
728 |
+
return alibi_bias.to(dtype=dtype)
|
729 |
+
|
730 |
+
|
731 |
+
ATTN_CLASS_REGISTRY = {
|
732 |
+
"multihead_attention": MultiheadAttention,
|
733 |
+
"multiquery_attention": MultiQueryAttention,
|
734 |
+
"grouped_query_attention": GroupedQueryAttention,
|
735 |
+
}
|
blocks.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""GPT Blocks used for the GPT Model."""
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional, Tuple
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from .attention import ATTN_CLASS_REGISTRY
|
7 |
+
from .ffn import FFN_CLASS_REGISTRY, build_ffn
|
8 |
+
from .norm import NORM_CLASS_REGISTRY
|
9 |
+
|
10 |
+
try:
|
11 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
12 |
+
except:
|
13 |
+
(unpad_input, pad_input) = (None, None)
|
14 |
+
attn_config_defaults: Dict = {
|
15 |
+
"attn_type": "multihead_attention",
|
16 |
+
"attn_pdrop": 0.0,
|
17 |
+
"attn_impl": "flash",
|
18 |
+
"qk_ln": True,
|
19 |
+
"qk_gn": False,
|
20 |
+
"clip_qkv": None,
|
21 |
+
"softmax_scale": None,
|
22 |
+
"prefix_lm": False,
|
23 |
+
"attn_uses_sequence_id": False,
|
24 |
+
"sliding_window_size": -1,
|
25 |
+
"alibi": False,
|
26 |
+
"alibi_bias_max": 8,
|
27 |
+
"rope": False,
|
28 |
+
"rope_theta": 10000,
|
29 |
+
"rope_impl": "dail",
|
30 |
+
"rope_dail_config": {
|
31 |
+
"type": "original",
|
32 |
+
"pos_idx_in_fp32": True,
|
33 |
+
"xpos_scale_base": 512,
|
34 |
+
},
|
35 |
+
"rope_hf_config": {"type": "no_scaling", "factor": 1.0},
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
class MPTBlock(nn.Module):
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
d_model: int,
|
44 |
+
n_heads: int,
|
45 |
+
expansion_ratio: int,
|
46 |
+
attn_config: Optional[Dict] = None,
|
47 |
+
ffn_config: Optional[Dict] = None,
|
48 |
+
resid_pdrop: float = 0.0,
|
49 |
+
norm_type: str = "low_precision_layernorm",
|
50 |
+
fc_type: str = "torch",
|
51 |
+
device: Optional[str] = None,
|
52 |
+
no_bias: bool = False,
|
53 |
+
use_pad_tok_in_ffn: bool = True,
|
54 |
+
**kwargs: Any
|
55 |
+
):
|
56 |
+
if attn_config is None:
|
57 |
+
attn_config = attn_config_defaults
|
58 |
+
if ffn_config is None:
|
59 |
+
ffn_config = {"ffn_type": "mptmlp"}
|
60 |
+
del kwargs
|
61 |
+
super().__init__()
|
62 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
63 |
+
assert isinstance(attn_config["attn_type"], str)
|
64 |
+
attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
|
65 |
+
args_to_exclude_in_attn_class = {
|
66 |
+
"attn_type",
|
67 |
+
"prefix_lm",
|
68 |
+
"alibi",
|
69 |
+
"attn_uses_sequence_id",
|
70 |
+
"alibi_bias_max",
|
71 |
+
"rope",
|
72 |
+
"rope_theta",
|
73 |
+
"rope_impl",
|
74 |
+
"rope_dail_config",
|
75 |
+
"rope_hf_config",
|
76 |
+
}
|
77 |
+
attn_config_subset_for_attn_class = {
|
78 |
+
k: v
|
79 |
+
for (k, v) in attn_config.items()
|
80 |
+
if k not in args_to_exclude_in_attn_class
|
81 |
+
}
|
82 |
+
self.norm_1 = norm_class(d_model, device=device)
|
83 |
+
self.attn = attn_class(
|
84 |
+
d_model=d_model,
|
85 |
+
n_heads=n_heads,
|
86 |
+
fc_type=fc_type,
|
87 |
+
device=device,
|
88 |
+
**attn_config_subset_for_attn_class,
|
89 |
+
bias=not no_bias
|
90 |
+
)
|
91 |
+
self.norm_2 = None
|
92 |
+
if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
|
93 |
+
self.norm_2 = norm_class(d_model, device=device)
|
94 |
+
self.ffn = build_ffn(
|
95 |
+
d_model=d_model,
|
96 |
+
expansion_ratio=expansion_ratio,
|
97 |
+
device=device,
|
98 |
+
bias=not no_bias,
|
99 |
+
**ffn_config
|
100 |
+
)
|
101 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
102 |
+
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
103 |
+
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
|
104 |
+
|
105 |
+
def forward(
|
106 |
+
self,
|
107 |
+
x: torch.Tensor,
|
108 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
109 |
+
attn_bias: Optional[torch.Tensor] = None,
|
110 |
+
rotary_emb_w_meta_info: Optional[Dict] = None,
|
111 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
112 |
+
is_causal: bool = True,
|
113 |
+
output_attentions: bool = False,
|
114 |
+
alibi_slopes: Optional[torch.Tensor] = None,
|
115 |
+
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
|
116 |
+
) -> Tuple[
|
117 |
+
torch.Tensor,
|
118 |
+
Optional[torch.Tensor],
|
119 |
+
Optional[Tuple[torch.Tensor, torch.Tensor]],
|
120 |
+
]:
|
121 |
+
a = self.norm_1(x)
|
122 |
+
(b, attn_weights, past_key_value) = self.attn(
|
123 |
+
a,
|
124 |
+
past_key_value=past_key_value,
|
125 |
+
attn_bias=attn_bias,
|
126 |
+
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
|
127 |
+
attention_mask=attention_mask,
|
128 |
+
is_causal=is_causal,
|
129 |
+
needs_weights=output_attentions,
|
130 |
+
alibi_slopes=alibi_slopes,
|
131 |
+
flash_attn_padding_info=flash_attn_padding_info,
|
132 |
+
)
|
133 |
+
x = x + self.resid_attn_dropout(b)
|
134 |
+
m = x
|
135 |
+
if self.norm_2 is not None:
|
136 |
+
m = self.norm_2(x)
|
137 |
+
(batch_size, seq_len) = m.size()[:2]
|
138 |
+
indices = None
|
139 |
+
if not self.use_pad_tok_in_ffn:
|
140 |
+
assert unpad_input is not None
|
141 |
+
(m, indices, _, _) = unpad_input(m, attention_mask)
|
142 |
+
n = self.ffn(m)
|
143 |
+
if not self.use_pad_tok_in_ffn:
|
144 |
+
assert pad_input is not None
|
145 |
+
n = pad_input(n, indices, batch_size, seq_len)
|
146 |
+
x = x + self.resid_ffn_dropout(n)
|
147 |
+
return (x, attn_weights, past_key_value)
|
config.json
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/home/users/nus/e0538503/scratch/models/sea-lion-7b-instruct",
|
3 |
+
"architectures": [
|
4 |
+
"MPTForCausalLM"
|
5 |
+
],
|
6 |
+
"attn_config": {
|
7 |
+
"alibi": false,
|
8 |
+
"alibi_bias_max": 8,
|
9 |
+
"attn_impl": "torch",
|
10 |
+
"attn_pdrop": 0.0,
|
11 |
+
"attn_type": "multihead_attention",
|
12 |
+
"attn_uses_sequence_id": false,
|
13 |
+
"clip_qkv": null,
|
14 |
+
"prefix_lm": false,
|
15 |
+
"qk_gn": false,
|
16 |
+
"qk_ln": true,
|
17 |
+
"rope": false,
|
18 |
+
"rope_dail_config": {
|
19 |
+
"pos_idx_in_fp32": true,
|
20 |
+
"type": "original",
|
21 |
+
"xpos_scale_base": 512
|
22 |
+
},
|
23 |
+
"rope_hf_config": {
|
24 |
+
"factor": 1.0,
|
25 |
+
"type": "no_scaling"
|
26 |
+
},
|
27 |
+
"rope_impl": "dail",
|
28 |
+
"rope_theta": 10000,
|
29 |
+
"sliding_window_size": -1,
|
30 |
+
"softmax_scale": null
|
31 |
+
},
|
32 |
+
"auto_map": {
|
33 |
+
"AutoConfig": "configuration_mpt.MPTConfig",
|
34 |
+
"AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
|
35 |
+
},
|
36 |
+
"d_model": 4096,
|
37 |
+
"emb_pdrop": 0.0,
|
38 |
+
"embedding_fraction": 0.1,
|
39 |
+
"expansion_ratio": 4,
|
40 |
+
"fc_type": "torch",
|
41 |
+
"ffn_config": {
|
42 |
+
"fc_type": "torch",
|
43 |
+
"ffn_type": "mptmlp"
|
44 |
+
},
|
45 |
+
"init_config": {
|
46 |
+
"emb_init_std": null,
|
47 |
+
"emb_init_uniform_lim": null,
|
48 |
+
"fan_mode": "fan_in",
|
49 |
+
"init_div_is_residual": true,
|
50 |
+
"init_gain": 0.0,
|
51 |
+
"init_nonlinearity": "relu",
|
52 |
+
"init_std": null,
|
53 |
+
"name": "kaiming_normal_",
|
54 |
+
"verbose": 0
|
55 |
+
},
|
56 |
+
"init_config_defaults": {
|
57 |
+
"init_std": 0.02
|
58 |
+
},
|
59 |
+
"init_device": "cpu",
|
60 |
+
"learned_pos_emb": true,
|
61 |
+
"logit_scale": "inv_sqrt_d_model",
|
62 |
+
"max_seq_len": 2048,
|
63 |
+
"model_type": "mpt",
|
64 |
+
"n_heads": 32,
|
65 |
+
"n_layers": 32,
|
66 |
+
"no_bias": false,
|
67 |
+
"norm_type": "low_precision_layernorm",
|
68 |
+
"resid_pdrop": 0.0,
|
69 |
+
"torch_dtype": "float16",
|
70 |
+
"transformers_version": "4.38.2",
|
71 |
+
"use_cache": false,
|
72 |
+
"use_pad_tok_in_ffn": true,
|
73 |
+
"vocab_size": 256000
|
74 |
+
}
|
configuration_mpt.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A HuggingFace-style model configuration."""
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Dict, Optional, Union
|
5 |
+
from transformers import PretrainedConfig
|
6 |
+
from .attention import check_alibi_support, is_flash_v1_installed, is_flash_v2_installed
|
7 |
+
from .blocks import attn_config_defaults
|
8 |
+
from .fc import FC_CLASS_REGISTRY
|
9 |
+
from .norm import LPLayerNorm
|
10 |
+
from .ffn import FFN_CLASS_REGISTRY
|
11 |
+
from .warnings import VersionedDeprecationWarning
|
12 |
+
|
13 |
+
ffn_config_defaults: Dict = {"ffn_type": "mptmlp"}
|
14 |
+
init_config_defaults: Dict = {
|
15 |
+
"name": "kaiming_normal_",
|
16 |
+
"fan_mode": "fan_in",
|
17 |
+
"init_nonlinearity": "relu",
|
18 |
+
"init_div_is_residual": True,
|
19 |
+
"emb_init_std": None,
|
20 |
+
"emb_init_uniform_lim": None,
|
21 |
+
"init_std": None,
|
22 |
+
"init_gain": 0.0,
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
class MPTConfig(PretrainedConfig):
|
27 |
+
model_type = "mpt"
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
d_model: int = 2048,
|
32 |
+
n_heads: int = 16,
|
33 |
+
n_layers: int = 24,
|
34 |
+
expansion_ratio: Union[int, float] = 4,
|
35 |
+
max_seq_len: int = 2048,
|
36 |
+
vocab_size: int = 50368,
|
37 |
+
resid_pdrop: float = 0.0,
|
38 |
+
emb_pdrop: float = 0.0,
|
39 |
+
learned_pos_emb: bool = True,
|
40 |
+
attn_config: Dict = attn_config_defaults,
|
41 |
+
ffn_config: Dict = ffn_config_defaults,
|
42 |
+
init_device: str = "cpu",
|
43 |
+
logit_scale: Optional[Union[float, str]] = None,
|
44 |
+
no_bias: bool = False,
|
45 |
+
embedding_fraction: float = 1.0,
|
46 |
+
norm_type: str = "low_precision_layernorm",
|
47 |
+
use_cache: bool = False,
|
48 |
+
init_config: Dict = init_config_defaults,
|
49 |
+
fc_type: str = "torch",
|
50 |
+
tie_word_embeddings: bool = True,
|
51 |
+
use_pad_tok_in_ffn: bool = True,
|
52 |
+
**kwargs: Any,
|
53 |
+
):
|
54 |
+
"""The MPT configuration class.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
d_model (int): The size of the embedding dimension of the model.
|
58 |
+
n_heads (int): The number of attention heads.
|
59 |
+
n_layers (int): The number of layers in the model.
|
60 |
+
expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
|
61 |
+
max_seq_len (int): The maximum sequence length of the model.
|
62 |
+
vocab_size (int): The size of the vocabulary.
|
63 |
+
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
|
64 |
+
emb_pdrop (float): The dropout probability for the embedding layer.
|
65 |
+
learned_pos_emb (bool): Whether to use learned positional embeddings
|
66 |
+
attn_config (Dict): A dictionary used to configure the model's attention module:
|
67 |
+
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
|
68 |
+
attn_pdrop (float): The dropout probability for the attention layers.
|
69 |
+
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
|
70 |
+
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
71 |
+
qk_gn (bool): Whether to apply group normalization to the queries and keys in the attention layer.
|
72 |
+
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
73 |
+
this value.
|
74 |
+
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
75 |
+
use the default scale of ``1/sqrt(d_keys)``.
|
76 |
+
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
|
77 |
+
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
|
78 |
+
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
|
79 |
+
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
80 |
+
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
81 |
+
which sub-sequence each token belongs to.
|
82 |
+
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
|
83 |
+
sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher.
|
84 |
+
alibi (bool): Whether to use the alibi bias instead of position embeddings.
|
85 |
+
alibi_bias_max (int): The maximum value of the alibi bias.
|
86 |
+
rope (bool): Whether to use rotary positional embeddings.
|
87 |
+
rope_theta (int): The base frequency for rope.
|
88 |
+
rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
|
89 |
+
rope_dail_config (Dict): The configuration for the dail implementation of rope.
|
90 |
+
type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf).
|
91 |
+
pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding.
|
92 |
+
xpos_scale_base (float): The scale base for XPos (if using XPos).
|
93 |
+
rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length).
|
94 |
+
type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
|
95 |
+
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
|
96 |
+
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
|
97 |
+
ffn_config (Dict): A dictionary used to configure the model's ffn module:
|
98 |
+
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
|
99 |
+
init_device (str): The device to use for parameter initialization.
|
100 |
+
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
|
101 |
+
no_bias (bool): Whether to use bias in all layers.
|
102 |
+
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
|
103 |
+
norm_type (str): choose type of norm to use
|
104 |
+
use_cache (bool): Whether or not the model should return the last key/values attentions
|
105 |
+
init_config (Dict): A dictionary used to configure the model initialization:
|
106 |
+
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
|
107 |
+
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
|
108 |
+
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
|
109 |
+
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
|
110 |
+
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
|
111 |
+
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
|
112 |
+
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
|
113 |
+
init_std (float): The standard deviation of the normal distribution used to initialize the model,
|
114 |
+
if using the baseline_ parameter initialization scheme.
|
115 |
+
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
|
116 |
+
                fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
                init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
                ---
                See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
            fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
            tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
            use_pad_tok_in_ffn (bool): Whether to forward the pad token in the feedforward networks.
        """
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        self.attn_config = attn_config
        self.ffn_config = ffn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        self.init_config = init_config
        self.fc_type = fc_type
        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        if self.attn_config.get("alibi", False) or self.attn_config.get("rope", False):
            self.learned_pos_emb = False
            warnings.warn(
                f"alibi or rope is turned on, setting `learned_pos_emb` to `False.`"
            )
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self._validate_config()

    def _set_config_defaults(
        self, config: Dict[str, Any], config_defaults: Dict[str, Any]
    ) -> Dict[str, Any]:
        for k, v in config_defaults.items():
            if k not in config:
                config[k] = v
            elif isinstance(v, dict):
                config[k] = self._set_config_defaults(
                    config[k] if config[k] is not None else {}, v
                )
        return config

    def _validate_config(self) -> None:
        self.attn_config = self._set_config_defaults(
            self.attn_config, attn_config_defaults
        )
        self.ffn_config = self._set_config_defaults(
            self.ffn_config, ffn_config_defaults
        )
        self.init_config = self._set_config_defaults(
            self.init_config, init_config_defaults
        )
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if any(
            (
                prob < 0 or prob > 1
                for prob in [
                    self.attn_config["attn_pdrop"],
                    self.resid_pdrop,
                    self.emb_pdrop,
                ]
            )
        ):
            raise ValueError(
                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
            )
        if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
            "torch",
            "triton",
        ]:
            raise NotImplementedError(
                "prefix_lm only implemented with torch and triton attention."
            )
        if self.attn_config["attn_impl"] == "flash" and is_flash_v1_installed():
            warnings.warn(
                VersionedDeprecationWarning(
                    'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
                    remove_version="0.6.0",
                )
            )
        if self.attn_config["attn_impl"] == "triton" and (
            not self.attn_config["prefix_lm"]
        ):
            warnings.warn(
                UserWarning(
                    'If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton".'
                )
            )
        if self.attn_config["alibi"] and (
            not check_alibi_support(self.attn_config["attn_impl"])
        ):
            raise NotImplementedError(
                "alibi only implemented with torch, triton, and flash (v2.4.2 or higher) attention."
            )
        if self.attn_config["attn_uses_sequence_id"] and (
            not (
                self.attn_config["attn_impl"] in ["torch", "triton"]
                or (
                    self.attn_config["attn_impl"] == "flash"
                    and is_flash_v2_installed(v2_version="v2.1.2")
                )
            )
        ):
            raise NotImplementedError(
                "attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention."
            )
        if self.attn_config["rope"] and self.attn_config["rope_impl"] not in [
            "dail",
            "hf",
        ]:
            raise ValueError(
                'If rope is being used then rope_impl should be either "dail", or "hf".'
            )
        if (
            self.attn_config["rope"]
            and self.attn_config["rope_impl"] == "hf"
            and (
                self.attn_config["rope_hf_config"]["type"]
                not in ["no_scaling", "linear", "dynamic"]
            )
        ):
            raise ValueError(
                'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
            )
        if self.attn_config["rope"] and self.attn_config["rope_impl"] == "dail":
            if self.attn_config["rope_dail_config"]["type"] not in ["original", "xpos"]:
                raise ValueError(
                    'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
                )
            if not is_flash_v2_installed(v2_version="2.0.1"):
                raise ImportError(
                    "If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support"
                )
        if self.attn_config["sliding_window_size"] != -1 and (
            not (
                self.attn_config["attn_impl"] == "flash"
                and is_flash_v2_installed(v2_version="v2.3.0")
            )
        ):
            raise NotImplementedError(
                "sliding window only implemented with flash attention v2.3.0 or higher."
            )
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError(
                "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
            )
        if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
            raise ValueError(
                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
            )
        if self.init_config.get("name", None) is None:
            raise ValueError(
                f"self.init_config={self.init_config!r} 'name' needs to be set."
            )
        if not (
            self.learned_pos_emb
            or self.attn_config["alibi"]
            or self.attn_config["rope"]
        ):
            warnings.warn(
                f"Positional information not being provided to the model using either learned_pos_emb or alibi or rope."
            )
        if self.fc_type == "te" or self.ffn_config["ffn_type"] == "te_ln_mlp":
            try:
                import transformer_engine.pytorch as te

                del te
            except:
                raise ImportError(
                    "TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. "
                    + "The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n"
                    + "pip install flash-attn==1.0.6 --no-build-isolation \n"
                    + "pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156"
                )
        if self.ffn_config["ffn_type"] == "mptgeglu":
            raise ValueError(
                'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
                + "See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details."
            )
        elif self.ffn_config["ffn_type"] in ["mptmlp", "mptglu"]:
            self.ffn_config["fc_type"] = self.fc_type
        elif self.ffn_config["ffn_type"] == "te_ln_mlp":
            self.ffn_config["bias"] = not self.no_bias
            if "ffn_act_fn" in self.ffn_config.keys():
                raise ValueError(
                    f"Transformer Engine block does not support custom activation functions."
                )
        if not self.use_pad_tok_in_ffn:
            try:
                from flash_attn.bert_padding import unpad_input, pad_input
            except:
                raise ImportError(
                    "In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6"
                )
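For orientation, a minimal sketch of exercising this configuration class; the sizes below are illustrative assumptions, not the values shipped in this repository's config.json, and the import path is assumed to be the file above:

from configuration_mpt import MPTConfig  # assumed local import of the file above

cfg = MPTConfig(
    d_model=4096,                      # illustrative sizes, not this checkpoint's values
    n_heads=32,
    n_layers=32,
    max_seq_len=2048,
    vocab_size=256000,
    attn_config={"attn_impl": "torch", "alibi": True},
)
# __init__ merges partial dicts with the defaults via _set_config_defaults and then
# calls _validate_config(), so invalid combinations raise immediately.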
custom_embedding.py
ADDED
@@ -0,0 +1,11 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class SharedEmbedding(nn.Embedding):

    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        if unembed:
            return F.linear(input, self.weight)
        return super().forward(input)
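The `unembed` flag lets one weight matrix serve both as the token embedding and as the output projection (weight tying). A small usage sketch, assuming only torch is available:

import torch

emb = SharedEmbedding(num_embeddings=32, embedding_dim=8)
tokens = torch.randint(0, 32, (2, 5))
hidden = emb(tokens)                  # (2, 5, 8): ordinary embedding lookup
logits = emb(hidden, unembed=True)    # (2, 5, 32): projection onto the same weight matrix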
fc.py
ADDED
@@ -0,0 +1,9 @@
from torch import nn

FC_CLASS_REGISTRY = {"torch": nn.Linear}
try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except:
    pass
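The registry is what lets the rest of the code swap linear-layer implementations by name; the "te" entry only exists when TransformerEngine imports successfully. A minimal sketch:

fc_cls = FC_CLASS_REGISTRY["torch"]   # nn.Linear; "te" maps to te.Linear on supported GPUs
proj = fc_cls(1024, 4096, bias=True)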
ffn.py
ADDED
@@ -0,0 +1,173 @@
"""MPT Blocks used for the MPT Model."""

import logging
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
from .fc import FC_CLASS_REGISTRY

try:
    import transformer_engine.pytorch as te
except:
    te = None
log = logging.getLogger(__name__)
_FFN_ACT_FN_DEFAULT = {"name": "gelu", "approximate": "none"}


def resolve_ffn_act_fn(
    config: Optional[dict] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Resolve the activation function for the feed-forward network.

    Args:
        config (Optional[dict]): The configuration dictionary for the activation function.
            The dict config must specify the 'name' of a torch.nn.functional activation
            function. All of other key values pairs are bound to the function as a partial.

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: The activation function.
    """
    if config is None:
        config = _FFN_ACT_FN_DEFAULT
    config = deepcopy(config)
    name = config.pop("name")
    if not hasattr(torch.nn.functional, name):
        raise ValueError(f"Unrecognised activation function name ({name}).")
    act = getattr(torch.nn.functional, name)
    return partial(act, **config)


_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)


def resolve_ffn_hidden_size(
    d_model: int,
    expansion_ratio: Union[int, float],
    ffn_hidden_size: Optional[int] = None,
) -> int:
    """Resolve the hidden size of the feed-forward network.

    Args:
        d_model (int): The dimension of the input and output of the feed-forward network.
        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.

    Returns:
        int: The hidden size of the feed-forward network.
    """
    if ffn_hidden_size is not None:
        log.info(
            f"`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified."
        )
    else:
        ffn_hidden_size = int(d_model * expansion_ratio)
        if ffn_hidden_size != d_model * expansion_ratio:
            raise ValueError(
                f"`d_model * expansion_ratio` must be an integer (d_model={d_model!r}; expansion_ratio={expansion_ratio!r}; d_model * expansion_ratio={d_model * expansion_ratio!r})."
            )
    return ffn_hidden_size


class MPTMLP(nn.Module):

    def __init__(
        self,
        d_model: int,
        expansion_ratio: Union[int, float],
        fc_type: str = "torch",
        ffn_hidden_size: Optional[int] = None,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
        device: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__()
        ffn_hidden_size = resolve_ffn_hidden_size(
            d_model, expansion_ratio, ffn_hidden_size
        )
        self.fc_kwargs: dict[str, Any] = {"bias": bias}
        if fc_type != "te":
            self.fc_kwargs["device"] = device
        self.up_proj = FC_CLASS_REGISTRY[fc_type](
            d_model, ffn_hidden_size, **self.fc_kwargs
        )
        self.act = act_fn
        self.down_proj = FC_CLASS_REGISTRY[fc_type](
            ffn_hidden_size, d_model, **self.fc_kwargs
        )
        self.down_proj._is_residual = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


class MPTGLU(MPTMLP):

    def __init__(
        self,
        d_model: int,
        expansion_ratio: Union[int, float],
        fc_type: str = "torch",
        ffn_hidden_size: Optional[int] = None,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
        device: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            fc_type=fc_type,
            ffn_hidden_size=ffn_hidden_size,
            act_fn=act_fn,
            device=device,
            bias=bias,
        )
        self.gate_proj = FC_CLASS_REGISTRY[fc_type](
            d_model, self.up_proj.out_features, **self.fc_kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


FFN_CLASS_REGISTRY = {"mptmlp": MPTMLP, "mptglu": MPTGLU}
if te is not None:
    te.LayerNormMLP._has_norm = True
    FFN_CLASS_REGISTRY["te_ln_mlp"] = te.LayerNormMLP


def build_ffn(
    d_model: int,
    expansion_ratio: Union[int, float],
    fc_type: str = "torch",
    ffn_hidden_size: Optional[int] = None,
    ffn_act_fn: Optional[dict] = None,
    device: Optional[str] = None,
    bias: bool = True,
    **kwargs: Any,
) -> nn.Module:
    ffn_type = kwargs.pop("ffn_type")
    if ffn_type in ["mptmlp", "mptglu"]:
        if len(kwargs) > 0:
            raise ValueError(
                f"MPTMLP (or MPTGLU) got an unexpected keyword argument: {kwargs}"
            )
        return FFN_CLASS_REGISTRY[ffn_type](
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            fc_type=fc_type,
            act_fn=resolve_ffn_act_fn(ffn_act_fn),
            ffn_hidden_size=ffn_hidden_size,
            device=device,
            bias=bias,
        )
    elif ffn_type == "te_ln_mlp":
        assert te is not None
        ffn_hidden_size = resolve_ffn_hidden_size(
            d_model, expansion_ratio, ffn_hidden_size
        )
        if ffn_act_fn is not None:
            raise ValueError(
                f"Transformer Engine block does not support custom activation functions."
            )
        return te.LayerNormMLP(
            hidden_size=d_model, ffn_hidden_size=ffn_hidden_size, bias=bias, **kwargs
        )
    raise ValueError(f"ffn_type={ffn_type!r} not recognized.")
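A minimal sketch of how build_ffn is typically called (sizes are illustrative; ffn_type is consumed from the keyword arguments to pick the class out of FFN_CLASS_REGISTRY):

import torch

ffn = build_ffn(
    d_model=512,
    expansion_ratio=4,
    fc_type="torch",
    ffn_act_fn={"name": "gelu", "approximate": "none"},
    ffn_type="mptglu",          # selects MPTGLU; "mptmlp" gives the plain two-layer MLP
)
x = torch.randn(2, 16, 512)
y = ffn(x)                      # same shape as x: down_proj(act(gate_proj(x)) * up_proj(x))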
flash_attn_triton.py
ADDED
@@ -0,0 +1,1085 @@
"""
Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
update imports to use 'triton_pre_mlir'
*Experimental* implementation of FlashAttention in Triton.
Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.
We use the FlashAttention implementation from Phil Tillet a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
  small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
  I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
  (due to the Triton compiler), as done in tests/test_flash_attn.py
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
  that there are none left for other head dimensions.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward, while Triton backward is
  generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
  than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
"""

import math
import torch
import triton_pre_mlir as triton
import triton_pre_mlir.language as tl

@triton.heuristics(
|
47 |
+
{
|
48 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
49 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
50 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
51 |
+
}
|
52 |
+
)
|
53 |
+
@triton.jit
|
54 |
+
def _fwd_kernel(
|
55 |
+
Q,
|
56 |
+
K,
|
57 |
+
V,
|
58 |
+
Bias,
|
59 |
+
Out,
|
60 |
+
Lse,
|
61 |
+
TMP,
|
62 |
+
softmax_scale,
|
63 |
+
stride_qb,
|
64 |
+
stride_qh,
|
65 |
+
stride_qm,
|
66 |
+
stride_kb,
|
67 |
+
stride_kh,
|
68 |
+
stride_kn,
|
69 |
+
stride_vb,
|
70 |
+
stride_vh,
|
71 |
+
stride_vn,
|
72 |
+
stride_bb,
|
73 |
+
stride_bh,
|
74 |
+
stride_bm,
|
75 |
+
stride_ob,
|
76 |
+
stride_oh,
|
77 |
+
stride_om,
|
78 |
+
nheads,
|
79 |
+
seqlen_q,
|
80 |
+
seqlen_k,
|
81 |
+
seqlen_q_rounded,
|
82 |
+
headdim,
|
83 |
+
CACHE_KEY_SEQLEN_Q,
|
84 |
+
CACHE_KEY_SEQLEN_K,
|
85 |
+
BIAS_TYPE: tl.constexpr,
|
86 |
+
IS_CAUSAL: tl.constexpr,
|
87 |
+
BLOCK_HEADDIM: tl.constexpr,
|
88 |
+
EVEN_M: tl.constexpr,
|
89 |
+
EVEN_N: tl.constexpr,
|
90 |
+
EVEN_HEADDIM: tl.constexpr,
|
91 |
+
BLOCK_M: tl.constexpr,
|
92 |
+
BLOCK_N: tl.constexpr,
|
93 |
+
):
|
94 |
+
start_m = tl.program_id(0)
|
95 |
+
off_hb = tl.program_id(1)
|
96 |
+
off_b = off_hb // nheads
|
97 |
+
off_h = off_hb % nheads
|
98 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
99 |
+
offs_n = tl.arange(0, BLOCK_N)
|
100 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
101 |
+
q_ptrs = (
|
102 |
+
Q
|
103 |
+
+ off_b * stride_qb
|
104 |
+
+ off_h * stride_qh
|
105 |
+
+ (offs_m[:, None] * stride_qm + offs_d[None, :])
|
106 |
+
)
|
107 |
+
k_ptrs = (
|
108 |
+
K
|
109 |
+
+ off_b * stride_kb
|
110 |
+
+ off_h * stride_kh
|
111 |
+
+ (offs_n[:, None] * stride_kn + offs_d[None, :])
|
112 |
+
)
|
113 |
+
v_ptrs = (
|
114 |
+
V
|
115 |
+
+ off_b * stride_vb
|
116 |
+
+ off_h * stride_vh
|
117 |
+
+ (offs_n[:, None] * stride_vn + offs_d[None, :])
|
118 |
+
)
|
119 |
+
if BIAS_TYPE == "vector":
|
120 |
+
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
|
121 |
+
elif BIAS_TYPE == "matrix":
|
122 |
+
b_ptrs = (
|
123 |
+
Bias
|
124 |
+
+ off_b * stride_bb
|
125 |
+
+ off_h * stride_bh
|
126 |
+
+ (offs_m[:, None] * stride_bm + offs_n[None, :])
|
127 |
+
)
|
128 |
+
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
129 |
+
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
130 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
131 |
+
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
132 |
+
if EVEN_M & EVEN_N:
|
133 |
+
if EVEN_HEADDIM:
|
134 |
+
q = tl.load(q_ptrs)
|
135 |
+
else:
|
136 |
+
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
137 |
+
elif EVEN_HEADDIM:
|
138 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
139 |
+
else:
|
140 |
+
q = tl.load(
|
141 |
+
q_ptrs,
|
142 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
143 |
+
other=0.0,
|
144 |
+
)
|
145 |
+
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
146 |
+
for start_n in range(0, end_n, BLOCK_N):
|
147 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
148 |
+
if EVEN_N & EVEN_M:
|
149 |
+
if EVEN_HEADDIM:
|
150 |
+
k = tl.load(k_ptrs + start_n * stride_kn)
|
151 |
+
else:
|
152 |
+
k = tl.load(
|
153 |
+
k_ptrs + start_n * stride_kn,
|
154 |
+
mask=offs_d[None, :] < headdim,
|
155 |
+
other=0.0,
|
156 |
+
)
|
157 |
+
elif EVEN_HEADDIM:
|
158 |
+
k = tl.load(
|
159 |
+
k_ptrs + start_n * stride_kn,
|
160 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
161 |
+
other=0.0,
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
k = tl.load(
|
165 |
+
k_ptrs + start_n * stride_kn,
|
166 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k)
|
167 |
+
& (offs_d[None, :] < headdim),
|
168 |
+
other=0.0,
|
169 |
+
)
|
170 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
171 |
+
qk += tl.dot(q, k, trans_b=True)
|
172 |
+
if not EVEN_N:
|
173 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
174 |
+
if IS_CAUSAL:
|
175 |
+
qk += tl.where(
|
176 |
+
offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")
|
177 |
+
)
|
178 |
+
if BIAS_TYPE != "none":
|
179 |
+
if BIAS_TYPE == "vector":
|
180 |
+
if EVEN_N:
|
181 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
182 |
+
else:
|
183 |
+
bias = tl.load(
|
184 |
+
b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0
|
185 |
+
).to(tl.float32)
|
186 |
+
bias = bias[None, :]
|
187 |
+
elif BIAS_TYPE == "matrix":
|
188 |
+
if EVEN_M & EVEN_N:
|
189 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
190 |
+
else:
|
191 |
+
bias = tl.load(
|
192 |
+
b_ptrs + start_n,
|
193 |
+
mask=(offs_m[:, None] < seqlen_q)
|
194 |
+
& ((start_n + offs_n)[None, :] < seqlen_k),
|
195 |
+
other=0.0,
|
196 |
+
).to(tl.float32)
|
197 |
+
qk = qk * softmax_scale + bias
|
198 |
+
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
199 |
+
p = tl.exp(qk - m_ij[:, None])
|
200 |
+
else:
|
201 |
+
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
202 |
+
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
203 |
+
l_ij = tl.sum(p, 1)
|
204 |
+
acc_o_scale = tl.exp(m_i - m_ij)
|
205 |
+
tl.store(t_ptrs, acc_o_scale)
|
206 |
+
acc_o_scale = tl.load(t_ptrs)
|
207 |
+
acc_o = acc_o * acc_o_scale[:, None]
|
208 |
+
if EVEN_N & EVEN_M:
|
209 |
+
if EVEN_HEADDIM:
|
210 |
+
v = tl.load(v_ptrs + start_n * stride_vn)
|
211 |
+
else:
|
212 |
+
v = tl.load(
|
213 |
+
v_ptrs + start_n * stride_vn,
|
214 |
+
mask=offs_d[None, :] < headdim,
|
215 |
+
other=0.0,
|
216 |
+
)
|
217 |
+
elif EVEN_HEADDIM:
|
218 |
+
v = tl.load(
|
219 |
+
v_ptrs + start_n * stride_vn,
|
220 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
221 |
+
other=0.0,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
v = tl.load(
|
225 |
+
v_ptrs + start_n * stride_vn,
|
226 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k)
|
227 |
+
& (offs_d[None, :] < headdim),
|
228 |
+
other=0.0,
|
229 |
+
)
|
230 |
+
p = p.to(v.dtype)
|
231 |
+
acc_o += tl.dot(p, v)
|
232 |
+
m_i = m_ij
|
233 |
+
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
234 |
+
lse_i = m_ij + tl.log(l_i_new)
|
235 |
+
o_scale = tl.exp(m_i - lse_i)
|
236 |
+
tl.store(t_ptrs, o_scale)
|
237 |
+
o_scale = tl.load(t_ptrs)
|
238 |
+
acc_o = acc_o * o_scale[:, None]
|
239 |
+
start_m = tl.program_id(0)
|
240 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
241 |
+
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
242 |
+
tl.store(lse_ptrs, lse_i)
|
243 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
244 |
+
out_ptrs = (
|
245 |
+
Out
|
246 |
+
+ off_b * stride_ob
|
247 |
+
+ off_h * stride_oh
|
248 |
+
+ (offs_m[:, None] * stride_om + offs_d[None, :])
|
249 |
+
)
|
250 |
+
if EVEN_M:
|
251 |
+
if EVEN_HEADDIM:
|
252 |
+
tl.store(out_ptrs, acc_o)
|
253 |
+
else:
|
254 |
+
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
255 |
+
elif EVEN_HEADDIM:
|
256 |
+
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
257 |
+
else:
|
258 |
+
tl.store(
|
259 |
+
out_ptrs,
|
260 |
+
acc_o,
|
261 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
+
@triton.jit
|
266 |
+
def _bwd_preprocess_do_o_dot(
|
267 |
+
Out,
|
268 |
+
DO,
|
269 |
+
Delta,
|
270 |
+
stride_ob,
|
271 |
+
stride_oh,
|
272 |
+
stride_om,
|
273 |
+
stride_dob,
|
274 |
+
stride_doh,
|
275 |
+
stride_dom,
|
276 |
+
nheads,
|
277 |
+
seqlen_q,
|
278 |
+
seqlen_q_rounded,
|
279 |
+
headdim,
|
280 |
+
BLOCK_M: tl.constexpr,
|
281 |
+
BLOCK_HEADDIM: tl.constexpr,
|
282 |
+
):
|
283 |
+
start_m = tl.program_id(0)
|
284 |
+
off_hb = tl.program_id(1)
|
285 |
+
off_b = off_hb // nheads
|
286 |
+
off_h = off_hb % nheads
|
287 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
288 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
289 |
+
o = tl.load(
|
290 |
+
Out
|
291 |
+
+ off_b * stride_ob
|
292 |
+
+ off_h * stride_oh
|
293 |
+
+ offs_m[:, None] * stride_om
|
294 |
+
+ offs_d[None, :],
|
295 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
296 |
+
other=0.0,
|
297 |
+
).to(tl.float32)
|
298 |
+
do = tl.load(
|
299 |
+
DO
|
300 |
+
+ off_b * stride_dob
|
301 |
+
+ off_h * stride_doh
|
302 |
+
+ offs_m[:, None] * stride_dom
|
303 |
+
+ offs_d[None, :],
|
304 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
305 |
+
other=0.0,
|
306 |
+
).to(tl.float32)
|
307 |
+
delta = tl.sum(o * do, axis=1)
|
308 |
+
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
309 |
+
|
310 |
+
|
311 |
+
@triton.jit
|
312 |
+
def _bwd_store_dk_dv(
|
313 |
+
dk_ptrs,
|
314 |
+
dv_ptrs,
|
315 |
+
dk,
|
316 |
+
dv,
|
317 |
+
offs_n,
|
318 |
+
offs_d,
|
319 |
+
seqlen_k,
|
320 |
+
headdim,
|
321 |
+
EVEN_M: tl.constexpr,
|
322 |
+
EVEN_N: tl.constexpr,
|
323 |
+
EVEN_HEADDIM: tl.constexpr,
|
324 |
+
):
|
325 |
+
if EVEN_N & EVEN_M:
|
326 |
+
if EVEN_HEADDIM:
|
327 |
+
tl.store(dv_ptrs, dv)
|
328 |
+
tl.store(dk_ptrs, dk)
|
329 |
+
else:
|
330 |
+
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
331 |
+
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
332 |
+
elif EVEN_HEADDIM:
|
333 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
334 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
335 |
+
else:
|
336 |
+
tl.store(
|
337 |
+
dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
|
338 |
+
)
|
339 |
+
tl.store(
|
340 |
+
dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
|
341 |
+
)
|
342 |
+
|
343 |
+
|
344 |
+
@triton.jit
|
345 |
+
def _bwd_kernel_one_col_block(
|
346 |
+
start_n,
|
347 |
+
Q,
|
348 |
+
K,
|
349 |
+
V,
|
350 |
+
Bias,
|
351 |
+
DO,
|
352 |
+
DQ,
|
353 |
+
DK,
|
354 |
+
DV,
|
355 |
+
LSE,
|
356 |
+
D,
|
357 |
+
softmax_scale,
|
358 |
+
stride_qm,
|
359 |
+
stride_kn,
|
360 |
+
stride_vn,
|
361 |
+
stride_bm,
|
362 |
+
stride_dom,
|
363 |
+
stride_dqm,
|
364 |
+
stride_dkn,
|
365 |
+
stride_dvn,
|
366 |
+
seqlen_q,
|
367 |
+
seqlen_k,
|
368 |
+
headdim,
|
369 |
+
ATOMIC_ADD: tl.constexpr,
|
370 |
+
BIAS_TYPE: tl.constexpr,
|
371 |
+
IS_CAUSAL: tl.constexpr,
|
372 |
+
BLOCK_HEADDIM: tl.constexpr,
|
373 |
+
EVEN_M: tl.constexpr,
|
374 |
+
EVEN_N: tl.constexpr,
|
375 |
+
EVEN_HEADDIM: tl.constexpr,
|
376 |
+
BLOCK_M: tl.constexpr,
|
377 |
+
BLOCK_N: tl.constexpr,
|
378 |
+
):
|
379 |
+
begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
|
380 |
+
offs_qm = begin_m + tl.arange(0, BLOCK_M)
|
381 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
382 |
+
offs_m = tl.arange(0, BLOCK_M)
|
383 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
384 |
+
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
|
385 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
386 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
387 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
|
388 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
|
389 |
+
if BIAS_TYPE == "vector":
|
390 |
+
b_ptrs = Bias + offs_n
|
391 |
+
elif BIAS_TYPE == "matrix":
|
392 |
+
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
|
393 |
+
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
394 |
+
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
395 |
+
if begin_m >= seqlen_q:
|
396 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
397 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
398 |
+
_bwd_store_dk_dv(
|
399 |
+
dk_ptrs,
|
400 |
+
dv_ptrs,
|
401 |
+
dk,
|
402 |
+
dv,
|
403 |
+
offs_n,
|
404 |
+
offs_d,
|
405 |
+
seqlen_k,
|
406 |
+
headdim,
|
407 |
+
EVEN_M=EVEN_M,
|
408 |
+
EVEN_N=EVEN_N,
|
409 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
410 |
+
)
|
411 |
+
return
|
412 |
+
if EVEN_N & EVEN_M:
|
413 |
+
if EVEN_HEADDIM:
|
414 |
+
k = tl.load(k_ptrs)
|
415 |
+
v = tl.load(v_ptrs)
|
416 |
+
else:
|
417 |
+
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
418 |
+
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
419 |
+
elif EVEN_HEADDIM:
|
420 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
421 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
422 |
+
else:
|
423 |
+
k = tl.load(
|
424 |
+
k_ptrs,
|
425 |
+
mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
426 |
+
other=0.0,
|
427 |
+
)
|
428 |
+
v = tl.load(
|
429 |
+
v_ptrs,
|
430 |
+
mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
431 |
+
other=0.0,
|
432 |
+
)
|
433 |
+
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
|
434 |
+
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
|
435 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
436 |
+
offs_m_curr = start_m + offs_m
|
437 |
+
if EVEN_M & EVEN_HEADDIM:
|
438 |
+
q = tl.load(q_ptrs)
|
439 |
+
elif EVEN_HEADDIM:
|
440 |
+
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
|
441 |
+
else:
|
442 |
+
q = tl.load(
|
443 |
+
q_ptrs,
|
444 |
+
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
445 |
+
other=0.0,
|
446 |
+
)
|
447 |
+
qk = tl.dot(q, k, trans_b=True)
|
448 |
+
if not EVEN_N:
|
449 |
+
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
|
450 |
+
if IS_CAUSAL:
|
451 |
+
qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
|
452 |
+
if BIAS_TYPE != "none":
|
453 |
+
tl.debug_barrier()
|
454 |
+
if BIAS_TYPE == "vector":
|
455 |
+
if EVEN_N:
|
456 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
457 |
+
else:
|
458 |
+
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(
|
459 |
+
tl.float32
|
460 |
+
)
|
461 |
+
bias = bias[None, :]
|
462 |
+
elif BIAS_TYPE == "matrix":
|
463 |
+
if EVEN_M & EVEN_N:
|
464 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
465 |
+
else:
|
466 |
+
bias = tl.load(
|
467 |
+
b_ptrs,
|
468 |
+
mask=(offs_m_curr[:, None] < seqlen_q)
|
469 |
+
& (offs_n[None, :] < seqlen_k),
|
470 |
+
other=0.0,
|
471 |
+
).to(tl.float32)
|
472 |
+
qk = qk * softmax_scale + bias
|
473 |
+
if not EVEN_M & EVEN_HEADDIM:
|
474 |
+
tl.debug_barrier()
|
475 |
+
lse_i = tl.load(LSE + offs_m_curr)
|
476 |
+
if BIAS_TYPE == "none":
|
477 |
+
p = tl.exp(qk * softmax_scale - lse_i[:, None])
|
478 |
+
else:
|
479 |
+
p = tl.exp(qk - lse_i[:, None])
|
480 |
+
if EVEN_M & EVEN_HEADDIM:
|
481 |
+
do = tl.load(do_ptrs)
|
482 |
+
else:
|
483 |
+
do = tl.load(
|
484 |
+
do_ptrs,
|
485 |
+
mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
486 |
+
other=0.0,
|
487 |
+
)
|
488 |
+
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
489 |
+
if not EVEN_M & EVEN_HEADDIM:
|
490 |
+
tl.debug_barrier()
|
491 |
+
dp = tl.dot(do, v, trans_b=True)
|
492 |
+
if not EVEN_HEADDIM:
|
493 |
+
tl.debug_barrier()
|
494 |
+
Di = tl.load(D + offs_m_curr)
|
495 |
+
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
496 |
+
dk += tl.dot(ds, q, trans_a=True)
|
497 |
+
if not EVEN_M & EVEN_HEADDIM:
|
498 |
+
tl.debug_barrier()
|
499 |
+
if not ATOMIC_ADD:
|
500 |
+
if EVEN_M & EVEN_HEADDIM:
|
501 |
+
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
502 |
+
dq += tl.dot(ds, k)
|
503 |
+
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
504 |
+
elif EVEN_HEADDIM:
|
505 |
+
dq = tl.load(
|
506 |
+
dq_ptrs,
|
507 |
+
mask=offs_m_curr[:, None] < seqlen_q,
|
508 |
+
other=0.0,
|
509 |
+
eviction_policy="evict_last",
|
510 |
+
)
|
511 |
+
dq += tl.dot(ds, k)
|
512 |
+
tl.store(
|
513 |
+
dq_ptrs,
|
514 |
+
dq,
|
515 |
+
mask=offs_m_curr[:, None] < seqlen_q,
|
516 |
+
eviction_policy="evict_last",
|
517 |
+
)
|
518 |
+
else:
|
519 |
+
dq = tl.load(
|
520 |
+
dq_ptrs,
|
521 |
+
mask=(offs_m_curr[:, None] < seqlen_q)
|
522 |
+
& (offs_d[None, :] < headdim),
|
523 |
+
other=0.0,
|
524 |
+
eviction_policy="evict_last",
|
525 |
+
)
|
526 |
+
dq += tl.dot(ds, k)
|
527 |
+
tl.store(
|
528 |
+
dq_ptrs,
|
529 |
+
dq,
|
530 |
+
mask=(offs_m_curr[:, None] < seqlen_q)
|
531 |
+
& (offs_d[None, :] < headdim),
|
532 |
+
eviction_policy="evict_last",
|
533 |
+
)
|
534 |
+
else:
|
535 |
+
dq = tl.dot(ds, k)
|
536 |
+
if EVEN_M & EVEN_HEADDIM:
|
537 |
+
tl.atomic_add(dq_ptrs, dq)
|
538 |
+
elif EVEN_HEADDIM:
|
539 |
+
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
|
540 |
+
else:
|
541 |
+
tl.atomic_add(
|
542 |
+
dq_ptrs,
|
543 |
+
dq,
|
544 |
+
mask=(offs_m_curr[:, None] < seqlen_q)
|
545 |
+
& (offs_d[None, :] < headdim),
|
546 |
+
)
|
547 |
+
dq_ptrs += BLOCK_M * stride_dqm
|
548 |
+
q_ptrs += BLOCK_M * stride_qm
|
549 |
+
do_ptrs += BLOCK_M * stride_dom
|
550 |
+
if BIAS_TYPE == "matrix":
|
551 |
+
b_ptrs += BLOCK_M * stride_bm
|
552 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
553 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
554 |
+
_bwd_store_dk_dv(
|
555 |
+
dk_ptrs,
|
556 |
+
dv_ptrs,
|
557 |
+
dk,
|
558 |
+
dv,
|
559 |
+
offs_n,
|
560 |
+
offs_d,
|
561 |
+
seqlen_k,
|
562 |
+
headdim,
|
563 |
+
EVEN_M=EVEN_M,
|
564 |
+
EVEN_N=EVEN_N,
|
565 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
566 |
+
)
|
567 |
+
|
568 |
+
|
569 |
+
def init_to_zero(name):
|
570 |
+
return lambda nargs: nargs[name].zero_()
|
571 |
+
|
572 |
+
|
573 |
+
@triton.autotune(
|
574 |
+
configs=[
|
575 |
+
triton.Config(
|
576 |
+
{"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
|
577 |
+
num_warps=8,
|
578 |
+
num_stages=1,
|
579 |
+
pre_hook=init_to_zero("DQ"),
|
580 |
+
),
|
581 |
+
triton.Config(
|
582 |
+
{"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
|
583 |
+
num_warps=8,
|
584 |
+
num_stages=1,
|
585 |
+
pre_hook=init_to_zero("DQ"),
|
586 |
+
),
|
587 |
+
],
|
588 |
+
key=[
|
589 |
+
"CACHE_KEY_SEQLEN_Q",
|
590 |
+
"CACHE_KEY_SEQLEN_K",
|
591 |
+
"BIAS_TYPE",
|
592 |
+
"IS_CAUSAL",
|
593 |
+
"BLOCK_HEADDIM",
|
594 |
+
],
|
595 |
+
)
|
596 |
+
@triton.heuristics(
|
597 |
+
{
|
598 |
+
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
599 |
+
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
600 |
+
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
601 |
+
}
|
602 |
+
)
|
603 |
+
@triton.jit
|
604 |
+
def _bwd_kernel(
|
605 |
+
Q,
|
606 |
+
K,
|
607 |
+
V,
|
608 |
+
Bias,
|
609 |
+
DO,
|
610 |
+
DQ,
|
611 |
+
DK,
|
612 |
+
DV,
|
613 |
+
LSE,
|
614 |
+
D,
|
615 |
+
softmax_scale,
|
616 |
+
stride_qb,
|
617 |
+
stride_qh,
|
618 |
+
stride_qm,
|
619 |
+
stride_kb,
|
620 |
+
stride_kh,
|
621 |
+
stride_kn,
|
622 |
+
stride_vb,
|
623 |
+
stride_vh,
|
624 |
+
stride_vn,
|
625 |
+
stride_bb,
|
626 |
+
stride_bh,
|
627 |
+
stride_bm,
|
628 |
+
stride_dob,
|
629 |
+
stride_doh,
|
630 |
+
stride_dom,
|
631 |
+
stride_dqb,
|
632 |
+
stride_dqh,
|
633 |
+
stride_dqm,
|
634 |
+
stride_dkb,
|
635 |
+
stride_dkh,
|
636 |
+
stride_dkn,
|
637 |
+
stride_dvb,
|
638 |
+
stride_dvh,
|
639 |
+
stride_dvn,
|
640 |
+
nheads,
|
641 |
+
seqlen_q,
|
642 |
+
seqlen_k,
|
643 |
+
seqlen_q_rounded,
|
644 |
+
headdim,
|
645 |
+
CACHE_KEY_SEQLEN_Q,
|
646 |
+
CACHE_KEY_SEQLEN_K,
|
647 |
+
BIAS_TYPE: tl.constexpr,
|
648 |
+
IS_CAUSAL: tl.constexpr,
|
649 |
+
BLOCK_HEADDIM: tl.constexpr,
|
650 |
+
SEQUENCE_PARALLEL: tl.constexpr,
|
651 |
+
EVEN_M: tl.constexpr,
|
652 |
+
EVEN_N: tl.constexpr,
|
653 |
+
EVEN_HEADDIM: tl.constexpr,
|
654 |
+
BLOCK_M: tl.constexpr,
|
655 |
+
BLOCK_N: tl.constexpr,
|
656 |
+
):
|
657 |
+
off_hb = tl.program_id(1)
|
658 |
+
off_b = off_hb // nheads
|
659 |
+
off_h = off_hb % nheads
|
660 |
+
Q += off_b * stride_qb + off_h * stride_qh
|
661 |
+
K += off_b * stride_kb + off_h * stride_kh
|
662 |
+
V += off_b * stride_vb + off_h * stride_vh
|
663 |
+
DO += off_b * stride_dob + off_h * stride_doh
|
664 |
+
DQ += off_b * stride_dqb + off_h * stride_dqh
|
665 |
+
DK += off_b * stride_dkb + off_h * stride_dkh
|
666 |
+
DV += off_b * stride_dvb + off_h * stride_dvh
|
667 |
+
if BIAS_TYPE != "none":
|
668 |
+
Bias += off_b * stride_bb + off_h * stride_bh
|
669 |
+
D += off_hb * seqlen_q_rounded
|
670 |
+
LSE += off_hb * seqlen_q_rounded
|
671 |
+
if not SEQUENCE_PARALLEL:
|
672 |
+
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
|
673 |
+
for start_n in range(0, num_block_n):
|
674 |
+
_bwd_kernel_one_col_block(
|
675 |
+
start_n,
|
676 |
+
Q,
|
677 |
+
K,
|
678 |
+
V,
|
679 |
+
Bias,
|
680 |
+
DO,
|
681 |
+
DQ,
|
682 |
+
DK,
|
683 |
+
DV,
|
684 |
+
LSE,
|
685 |
+
D,
|
686 |
+
softmax_scale,
|
687 |
+
stride_qm,
|
688 |
+
stride_kn,
|
689 |
+
stride_vn,
|
690 |
+
stride_bm,
|
691 |
+
stride_dom,
|
692 |
+
stride_dqm,
|
693 |
+
stride_dkn,
|
694 |
+
stride_dvn,
|
695 |
+
seqlen_q,
|
696 |
+
seqlen_k,
|
697 |
+
headdim,
|
698 |
+
ATOMIC_ADD=False,
|
699 |
+
BIAS_TYPE=BIAS_TYPE,
|
700 |
+
IS_CAUSAL=IS_CAUSAL,
|
701 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
702 |
+
EVEN_M=EVEN_M,
|
703 |
+
EVEN_N=EVEN_N,
|
704 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
705 |
+
BLOCK_M=BLOCK_M,
|
706 |
+
BLOCK_N=BLOCK_N,
|
707 |
+
)
|
708 |
+
else:
|
709 |
+
start_n = tl.program_id(0)
|
710 |
+
_bwd_kernel_one_col_block(
|
711 |
+
start_n,
|
712 |
+
Q,
|
713 |
+
K,
|
714 |
+
V,
|
715 |
+
Bias,
|
716 |
+
DO,
|
717 |
+
DQ,
|
718 |
+
DK,
|
719 |
+
DV,
|
720 |
+
LSE,
|
721 |
+
D,
|
722 |
+
softmax_scale,
|
723 |
+
stride_qm,
|
724 |
+
stride_kn,
|
725 |
+
stride_vn,
|
726 |
+
stride_bm,
|
727 |
+
stride_dom,
|
728 |
+
stride_dqm,
|
729 |
+
stride_dkn,
|
730 |
+
stride_dvn,
|
731 |
+
seqlen_q,
|
732 |
+
seqlen_k,
|
733 |
+
headdim,
|
734 |
+
ATOMIC_ADD=True,
|
735 |
+
BIAS_TYPE=BIAS_TYPE,
|
736 |
+
IS_CAUSAL=IS_CAUSAL,
|
737 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
738 |
+
EVEN_M=EVEN_M,
|
739 |
+
EVEN_N=EVEN_N,
|
740 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
741 |
+
BLOCK_M=BLOCK_M,
|
742 |
+
BLOCK_N=BLOCK_N,
|
743 |
+
)
|
744 |
+
|
745 |
+
|
746 |
+
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
|
747 |
+
(batch, seqlen_q, nheads, d) = q.shape
|
748 |
+
(_, seqlen_k, _, _) = k.shape
|
749 |
+
assert k.shape == (batch, seqlen_k, nheads, d)
|
750 |
+
assert v.shape == (batch, seqlen_k, nheads, d)
|
751 |
+
assert d <= 128, "FlashAttention only support head dimensions up to 128"
|
752 |
+
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
753 |
+
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
754 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
755 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
756 |
+
has_bias = bias is not None
|
757 |
+
bias_type = "none"
|
758 |
+
if has_bias:
|
759 |
+
assert bias.dtype in [q.dtype, torch.float]
|
760 |
+
assert bias.is_cuda
|
761 |
+
assert bias.dim() == 4
|
762 |
+
if bias.stride(-1) != 1:
|
763 |
+
bias = bias.contiguous()
|
764 |
+
if bias.shape[2:] == (1, seqlen_k):
|
765 |
+
bias_type = "vector"
|
766 |
+
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
767 |
+
bias_type = "matrix"
|
768 |
+
else:
|
769 |
+
raise RuntimeError(
|
770 |
+
"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
|
771 |
+
)
|
772 |
+
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
|
773 |
+
bias_strides = (
|
774 |
+
(bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
|
775 |
+
)
|
776 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
777 |
+
lse = torch.empty(
|
778 |
+
(batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
|
779 |
+
)
|
780 |
+
tmp = torch.empty(
|
781 |
+
(batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
|
782 |
+
)
|
783 |
+
o = torch.empty_like(q)
|
784 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
785 |
+
BLOCK = 128
|
786 |
+
num_warps = 4 if d <= 64 else 8
|
787 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
788 |
+
_fwd_kernel[grid](
|
789 |
+
q,
|
790 |
+
k,
|
791 |
+
v,
|
792 |
+
bias,
|
793 |
+
o,
|
794 |
+
lse,
|
795 |
+
tmp,
|
796 |
+
softmax_scale,
|
797 |
+
q.stride(0),
|
798 |
+
q.stride(2),
|
799 |
+
q.stride(1),
|
800 |
+
k.stride(0),
|
801 |
+
k.stride(2),
|
802 |
+
k.stride(1),
|
803 |
+
v.stride(0),
|
804 |
+
v.stride(2),
|
805 |
+
v.stride(1),
|
806 |
+
*bias_strides,
|
807 |
+
o.stride(0),
|
808 |
+
o.stride(2),
|
809 |
+
o.stride(1),
|
810 |
+
nheads,
|
811 |
+
seqlen_q,
|
812 |
+
seqlen_k,
|
813 |
+
seqlen_q_rounded,
|
814 |
+
d,
|
815 |
+
seqlen_q // 32,
|
816 |
+
seqlen_k // 32,
|
817 |
+
bias_type,
|
818 |
+
causal,
|
819 |
+
BLOCK_HEADDIM,
|
820 |
+
BLOCK_M=BLOCK,
|
821 |
+
BLOCK_N=BLOCK,
|
822 |
+
num_warps=num_warps,
|
823 |
+
num_stages=1
|
824 |
+
)
|
825 |
+
return (o, lse, softmax_scale)
|
826 |
+
|
827 |
+
|
828 |
+
def _flash_attn_backward(
|
829 |
+
do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
|
830 |
+
):
|
831 |
+
if do.stride(-1) != 1:
|
832 |
+
do = do.contiguous()
|
833 |
+
(batch, seqlen_q, nheads, d) = q.shape
|
834 |
+
(_, seqlen_k, _, _) = k.shape
|
835 |
+
assert d <= 128
|
836 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
837 |
+
assert lse.shape == (batch, nheads, seqlen_q_rounded)
|
838 |
+
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
|
839 |
+
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
840 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
841 |
+
dq_accum = torch.empty_like(q, dtype=torch.float32)
|
842 |
+
delta = torch.empty_like(lse)
|
843 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
844 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
845 |
+
_bwd_preprocess_do_o_dot[grid](
|
846 |
+
o,
|
847 |
+
do,
|
848 |
+
delta,
|
849 |
+
o.stride(0),
|
850 |
+
o.stride(2),
|
851 |
+
o.stride(1),
|
852 |
+
do.stride(0),
|
853 |
+
do.stride(2),
|
854 |
+
do.stride(1),
|
855 |
+
nheads,
|
856 |
+
seqlen_q,
|
857 |
+
seqlen_q_rounded,
|
858 |
+
d,
|
859 |
+
BLOCK_M=128,
|
860 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
861 |
+
)
|
862 |
+
has_bias = bias is not None
|
863 |
+
bias_type = "none"
|
864 |
+
if has_bias:
|
865 |
+
assert bias.dtype in [q.dtype, torch.float]
|
866 |
+
assert bias.is_cuda
|
867 |
+
assert bias.dim() == 4
|
868 |
+
assert bias.stride(-1) == 1
|
869 |
+
if bias.shape[2:] == (1, seqlen_k):
|
870 |
+
bias_type = "vector"
|
871 |
+
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
872 |
+
bias_type = "matrix"
|
873 |
+
else:
|
874 |
+
raise RuntimeError(
|
875 |
+
"Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
|
876 |
+
)
|
877 |
+
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
|
878 |
+
bias_strides = (
|
879 |
+
(bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
|
880 |
+
)
|
881 |
+
grid = lambda META: (
|
882 |
+
triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
|
883 |
+
batch * nheads,
|
884 |
+
)
|
885 |
+
_bwd_kernel[grid](
|
886 |
+
q,
|
887 |
+
k,
|
888 |
+
v,
|
889 |
+
bias,
|
890 |
+
do,
|
891 |
+
dq_accum,
|
892 |
+
dk,
|
893 |
+
dv,
|
894 |
+
lse,
|
895 |
+
delta,
|
896 |
+
softmax_scale,
|
897 |
+
q.stride(0),
|
898 |
+
q.stride(2),
|
899 |
+
q.stride(1),
|
900 |
+
k.stride(0),
|
901 |
+
k.stride(2),
|
902 |
+
k.stride(1),
|
903 |
+
v.stride(0),
|
904 |
+
v.stride(2),
|
905 |
+
v.stride(1),
|
906 |
+
*bias_strides,
|
907 |
+
do.stride(0),
|
908 |
+
do.stride(2),
|
909 |
+
do.stride(1),
|
910 |
+
dq_accum.stride(0),
|
911 |
+
dq_accum.stride(2),
|
912 |
+
dq_accum.stride(1),
|
913 |
+
dk.stride(0),
|
914 |
+
dk.stride(2),
|
915 |
+
dk.stride(1),
|
916 |
+
dv.stride(0),
|
917 |
+
dv.stride(2),
|
918 |
+
dv.stride(1),
|
919 |
+
nheads,
|
920 |
+
seqlen_q,
|
921 |
+
seqlen_k,
|
922 |
+
seqlen_q_rounded,
|
923 |
+
d,
|
924 |
+
seqlen_q // 32,
|
925 |
+
seqlen_k // 32,
|
926 |
+
bias_type,
|
927 |
+
causal,
|
928 |
+
BLOCK_HEADDIM
|
929 |
+
)
|
930 |
+
dq.copy_(dq_accum)
|
931 |
+
|
932 |
+
|
933 |
+
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
934 |
+
|
935 |
+
@staticmethod
|
936 |
+
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
|
937 |
+
"""
|
938 |
+
qkv: (batch, seqlen, 3, nheads, headdim)
|
939 |
+
bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
|
940 |
+
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
|
941 |
+
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
|
942 |
+
"""
|
943 |
+
if qkv.stride(-1) != 1:
|
944 |
+
qkv = qkv.contiguous()
|
945 |
+
(o, lse, ctx.softmax_scale) = _flash_attn_forward(
|
946 |
+
qkv[:, :, 0],
|
947 |
+
qkv[:, :, 1],
|
948 |
+
qkv[:, :, 2],
|
949 |
+
bias=bias,
|
950 |
+
causal=causal,
|
951 |
+
softmax_scale=softmax_scale,
|
952 |
+
)
|
953 |
+
ctx.save_for_backward(qkv, o, lse, bias)
|
954 |
+
ctx.causal = causal
|
955 |
+
return o
|
956 |
+
|
957 |
+
@staticmethod
|
958 |
+
def backward(ctx, do):
|
959 |
+
(qkv, o, lse, bias) = ctx.saved_tensors
|
960 |
+
assert not ctx.needs_input_grad[
|
961 |
+
1
|
962 |
+
], "FlashAttention does not support bias gradient yet"
|
963 |
+
with torch.inference_mode():
|
964 |
+
dqkv = torch.empty_like(qkv)
|
965 |
+
_flash_attn_backward(
|
966 |
+
do,
|
967 |
+
qkv[:, :, 0],
|
968 |
+
qkv[:, :, 1],
|
969 |
+
qkv[:, :, 2],
|
970 |
+
o,
|
971 |
+
lse,
|
972 |
+
dqkv[:, :, 0],
|
973 |
+
dqkv[:, :, 1],
|
974 |
+
dqkv[:, :, 2],
|
975 |
+
bias=bias,
|
976 |
+
causal=ctx.causal,
|
977 |
+
softmax_scale=ctx.softmax_scale,
|
978 |
+
)
|
979 |
+
return (dqkv, None, None, None)
|
980 |
+
|
981 |
+
|
982 |
+
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
|
983 |
+
|
984 |
+
|
985 |
+
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
986 |
+
|
987 |
+
@staticmethod
|
988 |
+
def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
|
989 |
+
"""
|
990 |
+
q: (batch, seqlen_q, nheads, headdim)
|
991 |
+
kv: (batch, seqlen_k, 2, nheads, headdim)
|
992 |
+
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
|
993 |
+
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
|
994 |
+
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
|
995 |
+
"""
|
996 |
+
(q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
|
997 |
+
(o, lse, ctx.softmax_scale) = _flash_attn_forward(
|
998 |
+
q,
|
999 |
+
kv[:, :, 0],
|
1000 |
+
kv[:, :, 1],
|
1001 |
+
bias=bias,
|
1002 |
+
causal=causal,
|
1003 |
+
softmax_scale=softmax_scale,
|
1004 |
+
)
|
1005 |
+
ctx.save_for_backward(q, kv, o, lse, bias)
|
1006 |
+
ctx.causal = causal
|
1007 |
+
return o
|
1008 |
+
|
1009 |
+
@staticmethod
|
1010 |
+
def backward(ctx, do):
|
1011 |
+
(q, kv, o, lse, bias) = ctx.saved_tensors
|
1012 |
+
if len(ctx.needs_input_grad) >= 3:
|
1013 |
+
assert not ctx.needs_input_grad[
|
1014 |
+
2
|
1015 |
+
], "FlashAttention does not support bias gradient yet"
|
1016 |
+
with torch.inference_mode():
|
1017 |
+
dq = torch.empty_like(q)
|
1018 |
+
dkv = torch.empty_like(kv)
|
1019 |
+
_flash_attn_backward(
|
1020 |
+
do,
|
1021 |
+
q,
|
1022 |
+
kv[:, :, 0],
|
1023 |
+
kv[:, :, 1],
|
1024 |
+
o,
|
1025 |
+
lse,
|
1026 |
+
dq,
|
1027 |
+
dkv[:, :, 0],
|
1028 |
+
dkv[:, :, 1],
|
1029 |
+
bias=bias,
|
1030 |
+
causal=ctx.causal,
|
1031 |
+
softmax_scale=ctx.softmax_scale,
|
1032 |
+
)
|
1033 |
+
return (dq, dkv, None, None, None)
|
1034 |
+
|
1035 |
+
|
1036 |
+
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
|
1037 |
+
|
1038 |
+
|
1039 |
+
class FlashAttnFunc(torch.autograd.Function):
|
1040 |
+
|
1041 |
+
@staticmethod
|
1042 |
+
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
|
1043 |
+
"""
|
1044 |
+
q: (batch_size, seqlen_q, nheads, headdim)
|
1045 |
+
k, v: (batch_size, seqlen_k, nheads, headdim)
|
1046 |
+
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
|
1047 |
+
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
|
1048 |
+
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
|
1049 |
+
"""
|
1050 |
+
(q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
|
1051 |
+
(o, lse, ctx.softmax_scale) = _flash_attn_forward(
|
1052 |
+
q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
|
1053 |
+
)
|
1054 |
+
ctx.save_for_backward(q, k, v, o, lse, bias)
|
1055 |
+
ctx.causal = causal
|
1056 |
+
return o
|
1057 |
+
|
1058 |
+
@staticmethod
|
1059 |
+
def backward(ctx, do):
|
1060 |
+
(q, k, v, o, lse, bias) = ctx.saved_tensors
|
1061 |
+
assert not ctx.needs_input_grad[
|
1062 |
+
3
|
1063 |
+
], "FlashAttention does not support bias gradient yet"
|
1064 |
+
with torch.inference_mode():
|
1065 |
+
dq = torch.empty_like(q)
|
1066 |
+
dk = torch.empty_like(k)
|
1067 |
+
dv = torch.empty_like(v)
|
1068 |
+
_flash_attn_backward(
|
1069 |
+
do,
|
1070 |
+
q,
|
1071 |
+
k,
|
1072 |
+
v,
|
1073 |
+
o,
|
1074 |
+
lse,
|
1075 |
+
dq,
|
1076 |
+
dk,
|
1077 |
+
dv,
|
1078 |
+
bias=bias,
|
1079 |
+
causal=ctx.causal,
|
1080 |
+
softmax_scale=ctx.softmax_scale,
|
1081 |
+
)
|
1082 |
+
return (dq, dk, dv, None, None, None)
|
1083 |
+
|
1084 |
+
|
1085 |
+
flash_attn_func = FlashAttnFunc.apply
|
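For context, a hedged usage sketch of the flash_attn_func entry point exported at the end of this file; it assumes a CUDA device and fp16/bf16 inputs (the asserts in _flash_attn_forward require both), and it passes arguments positionally because FlashAttnFunc.apply does not accept keyword arguments:

import torch

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)   # (batch, seqlen, nheads, headdim)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, None, True, None)   # no bias, causal=True, default 1/sqrt(headdim) scale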
gptq_model-4bit-128g.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:681d688b0189880dda3e905b139279d64c1f65326ee92fd5a73596bfbf47f589
size 5469228368
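The entry above is the Git LFS pointer to the 4-bit, group-size-128 GPTQ weights themselves. A hedged sketch of loading this upload with transformers follows; the repository id is an assumption inferred from the upload name, and a GPTQ backend (e.g. auto-gptq/optimum) must be installed for the quantized weights to load:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "aisingapore/sea-lion-7b-gptq"   # assumed id; replace with the actual repository path
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, device_map="auto")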
hf_prefixlm_converter.py
ADDED
@@ -0,0 +1,257 @@
"""Converts Huggingface Causal LM to Prefix LM.

Conversion does lightweight surgery on a HuggingFace
Causal LM to convert it to a Prefix LM.

Prefix LMs accepts a `bidirectional_mask` input in `forward`
and treat the input prompt as the prefix in `generate`.
"""

from types import MethodType
from typing import Any, List, MutableMapping, Optional, Tuple, Union
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

_SUPPORTED_GPT_MODELS = (
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoForCausalLM,
    GPTNeoXForCausalLM,
)
CAUSAL_GPT_TYPES = Union[
    GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
]


def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
    """Converts a GPT-style Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`

    See `convert_hf_causal_lm_to_prefix_lm` for more details.
    """
    if hasattr(model, "_prefix_lm_converted"):
        return model
    assert isinstance(model, _SUPPORTED_GPT_MODELS)
    assert (
        model.config.add_cross_attention == False
    ), "Only supports GPT-style decoder-only models"

    def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
        """Helper that gets a list of the model's attention modules.

        Each module has a `bias` buffer used for causal masking. The Prefix LM
        conversion adds logic to dynamically manipulate these biases to support
        Prefix LM attention masking.
        """
        attn_modules = []
        if isinstance(model, GPTNeoXForCausalLM):
            blocks = model.gpt_neox.layers
        else:
            blocks = model.transformer.h
        for block in blocks:
            if isinstance(model, GPTNeoForCausalLM):
                if block.attn.attention_type != "global":
                    continue
                attn_module = block.attn.attention
            elif isinstance(model, GPTNeoXForCausalLM):
                attn_module = block.attention
            else:
                attn_module = block.attn
            attn_modules.append(attn_module)
        return attn_modules

    setattr(model, "_original_forward", getattr(model, "forward"))
    setattr(model, "_original_generate", getattr(model, "generate"))

    def forward(
        self: CAUSAL_GPT_TYPES,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        bidirectional_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Wraps original forward to enable PrefixLM attention."""

        def call_og_forward():
            if isinstance(self, GPTNeoXForCausalLM):
                return self._original_forward(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            else:
                return self._original_forward(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    position_ids=position_ids,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

        if bidirectional_mask is None:
            return call_og_forward()
        assert isinstance(bidirectional_mask, torch.Tensor)
        attn_modules = _get_attn_modules(model)
        (b, s) = bidirectional_mask.shape
        max_length = attn_modules[0].bias.shape[-1]
        if s > max_length:
            raise ValueError(
                f"bidirectional_mask sequence length (={s}) exceeds the "
                + f"max length allowed by the model ({max_length})."
            )
        assert s <= max_length
        if s < max_length:
            pad = torch.zeros(
                (int(b), int(max_length - s)),
                dtype=bidirectional_mask.dtype,
                device=bidirectional_mask.device,
            )
            bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
        for attn_module in attn_modules:
            assert isinstance(attn_module.bias, torch.Tensor)
            attn_module.bias.data = torch.logical_or(
                attn_module.bias.data, bidirectional
            )
        output = call_og_forward()
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
        return output

    def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
        """Wraps original generate to enable PrefixLM attention."""
        attn_modules = _get_attn_modules(model)
        for attn_module in attn_modules:
            attn_module.bias.data[:] = 1
        output = self._original_generate(*args, **kwargs)
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
        return output

    setattr(model, "forward", MethodType(forward, model))
    setattr(model, "generate", MethodType(generate, model))
    setattr(model, "_prefix_lm_converted", True)
    return model


_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
CAUSAL_LM_TYPES = Union[
    GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
]


def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
    """Converts a HuggingFace Causal LM to a Prefix LM.

    Supported HuggingFace model classes:
        - `GPT2LMHeadModel`
        - `GPTNeoForCausalLM`
        - `GPTNeoXForCausalLM`
        - `GPTJForCausalLM`

    Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
    `generate` method and/or select underlying methods depending on the model class.

    These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".

    Notes on training:
        To actually train the converted model as a Prefix LM, training batches will need to indicate
        the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.

        **This is not a standard input and requires custom layers either within or after your dataloader.**

        In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
        such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
        That is, the prefix portion of the sequence should not generate any loss. Loss should only be
        generated by the target portion of the sequence.

    Notes on `GPTNeoForCausalLM`:
        To simplify the implementation, "global" and "local" attention layers are handled differently.
        For "global" layers, we handle conversion as described above. For "local" layers, which use a
        causal attention mask within a restricted local window, we do not alter the masking.

    Notes on `forward` method conversion:
        After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
        which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
        belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
        0 indicates token positions belonging to the target.

        The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
        causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
        the causal masks before returning the result.

    Notes on `generate` method conversion:
        After conversion, the `generate` method will have the same signature but will internally
        convert all causal masks to be purely bidirectional, call the original `generate` method, and
        (where appropriate) reset the causal masks before returning the result.

        This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
        "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
        each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
        another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
        previously-generated tokens (also as expected in a Prefix LM).

    To preserve the API, the original methods are renamed to `_original_forward` and
    `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
    them, respectively. Implementation details vary by model class.
    """
    if isinstance(model, _SUPPORTED_GPT_MODELS):
        return _convert_gpt_causal_lm_to_prefix_lm(model)
    else:
        raise TypeError(
            f"Cannot convert model to Prefix LM. "
            + f"Model does not belong to set of supported HF models:"
            + f"\n{_SUPPORTED_HF_MODELS}"
        )


def add_bidirectional_mask_if_missing(batch: MutableMapping):
    """Attempts to add bidirectional_mask to batch if missing.

    Raises:
        KeyError if bidirectional_mask is missing and can't be inferred
    """
    if "bidirectional_mask" not in batch:
        if batch.get("mode", None) == "icl_task":
            batch["bidirectional_mask"] = batch["attention_mask"].clone()
            for i, continuation_indices in enumerate(batch["continuation_indices"]):
                batch["bidirectional_mask"][i, continuation_indices] = 0
        elif "labels" in batch and "attention_mask" in batch:
            batch["bidirectional_mask"] = torch.logical_and(
                torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
            ).type_as(batch["attention_mask"])
        else:
            raise KeyError(
                "No bidirectional_mask in batch and not sure how to construct one."
            )
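As a rough illustration of the training notes above (the model name, prefix length, and prompt are placeholders, not part of this repo): convert a supported causal LM, mark the prefix with `bidirectional_mask`, and mask the prefix positions out of the loss.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

enc = tokenizer("Translate to French: Hello world. Bonjour le monde.", return_tensors="pt")
input_ids = enc["input_ids"]
labels = input_ids.clone()

# Assume the first 8 tokens form the prefix: they attend bidirectionally and contribute no loss.
prefix_len = 8
bidirectional_mask = torch.zeros_like(input_ids)
bidirectional_mask[:, :prefix_len] = 1
labels[bidirectional_mask == 1] = -100

out = model(input_ids=input_ids, labels=labels, bidirectional_mask=bidirectional_mask)
print(out.loss)
```

`add_bidirectional_mask_if_missing` performs the same bookkeeping automatically when the batch already carries labels masked to -100 over the prefix.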
meta_init_context.py
ADDED
@@ -0,0 +1,121 @@
from contextlib import contextmanager
from typing import Any, Callable, Optional
import torch
import torch.nn as nn


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """Meta initialization context manager.

    A context manager under which models are initialized with all parameters
    on the meta device, therefore creating an empty model. Useful when just
    initializing the model would blow the available RAM.

    Args:
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the meta device while initializing.

    Example:
    ```python
    import torch.nn as nn

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    <Tip warning={true}>

    Any model created under this context manager has no weights. As such you can't do something like
    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].

    </Tip>
    """
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: torch.device, include_buffers: bool = False):
    """Device initialization context manager.

    A context manager under which models are initialized with all parameters
    on the specified device.

    Args:
        device (`torch.device`): Device to initialize all parameters on.
        include_buffers (`bool`, *optional*, defaults to `False`): Whether or
            not to also put all buffers on the meta device while initializing.

    Example:
    ```python
    import torch.nn as nn

    with init_on_device(device=torch.device("cuda")):
        tst = nn.Linear(100, 100)  # on `cuda` device
    ```
    """
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(
        self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]
    ):
        old_register_parameter(self, name, param)
        if param is not None:
            parameter = self._parameters[name]
            assert parameter is not None
            param_cls = type(parameter)
            kwargs = parameter.__dict__
            self._parameters[name] = param_cls(parameter.to(device), **kwargs)

    def register_empty_buffer(
        self: torch.nn.Module,
        name: str,
        tensor: Optional[torch.Tensor],
        persistent: bool = True,
    ):
        old_register_buffer(self, name, tensor, persistent=persistent)
        if tensor is not None:
            named_buffer = self._buffers[name]
            assert named_buffer is not None
            self._buffers[name] = named_buffer.to(device)

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn: Callable):

        def wrapper(*args: Any, **kwargs: Any):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
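A small sketch of how these context managers are typically used ahead of a separate weight-loading step (the module and its sizes are arbitrary placeholders): the skeleton is built on the meta device, so no host memory is allocated for weights.

```python
import torch.nn as nn

with init_empty_weights():
    skeleton = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6
    )

# Parameters carry only shape/dtype metadata on the meta device; real weights must be
# loaded (e.g. with a checkpoint-dispatch utility) before the module can run.
print(next(skeleton.parameters()).device)  # meta
```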
modeling_mpt.py
ADDED
@@ -0,0 +1,907 @@
1 |
+
"""A simple, flexible implementation of a GPT model.
|
2 |
+
|
3 |
+
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
import math
|
8 |
+
import warnings
|
9 |
+
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from .attention import is_flash_v1_installed, is_flash_v2_installed
|
14 |
+
|
15 |
+
if is_flash_v2_installed():
|
16 |
+
try:
|
17 |
+
from flash_attn import bert_padding
|
18 |
+
from flash_attn.layers.rotary import RotaryEmbedding as DAILRotaryEmbedding
|
19 |
+
except Exception as e:
|
20 |
+
raise e
|
21 |
+
if is_flash_v1_installed():
|
22 |
+
try:
|
23 |
+
from flash_attn import bert_padding
|
24 |
+
except Exception as e:
|
25 |
+
raise e
|
26 |
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPast,
|
29 |
+
CausalLMOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.models.llama.modeling_llama import (
|
32 |
+
LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding,
|
33 |
+
)
|
34 |
+
from transformers.models.llama.modeling_llama import (
|
35 |
+
LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding,
|
36 |
+
)
|
37 |
+
from transformers.models.llama.modeling_llama import (
|
38 |
+
LlamaRotaryEmbedding as HFRotaryEmbedding,
|
39 |
+
)
|
40 |
+
from .attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
|
41 |
+
from .blocks import MPTBlock
|
42 |
+
from .custom_embedding import SharedEmbedding
|
43 |
+
from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
|
44 |
+
from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
|
45 |
+
from .ffn import MPTMLP as MPTMLP
|
46 |
+
from .ffn import build_ffn as build_ffn
|
47 |
+
from .norm import NORM_CLASS_REGISTRY
|
48 |
+
from .configuration_mpt import MPTConfig
|
49 |
+
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
|
50 |
+
from .hf_prefixlm_converter import (
|
51 |
+
add_bidirectional_mask_if_missing,
|
52 |
+
convert_hf_causal_lm_to_prefix_lm,
|
53 |
+
)
|
54 |
+
from .meta_init_context import init_empty_weights
|
55 |
+
from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
|
56 |
+
|
57 |
+
try:
|
58 |
+
from .flash_attn_triton import flash_attn_func as flash_attn_func
|
59 |
+
except:
|
60 |
+
pass
|
61 |
+
import logging
|
62 |
+
|
63 |
+
log = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
|
66 |
+
def gen_rotary_embedding(
|
67 |
+
rope_head_dim: int,
|
68 |
+
rope_impl: str,
|
69 |
+
rope_theta: int,
|
70 |
+
rope_dail_config: dict,
|
71 |
+
rope_hf_config: dict,
|
72 |
+
max_seq_len: int,
|
73 |
+
):
|
74 |
+
if rope_impl == "dail":
|
75 |
+
return DAILRotaryEmbedding(
|
76 |
+
dim=rope_head_dim,
|
77 |
+
base=rope_theta,
|
78 |
+
interleaved=False,
|
79 |
+
scale_base=(
|
80 |
+
rope_dail_config["xpos_scale_base"]
|
81 |
+
if rope_dail_config["type"] == "xpos"
|
82 |
+
else None
|
83 |
+
),
|
84 |
+
pos_idx_in_fp32=rope_dail_config["pos_idx_in_fp32"],
|
85 |
+
device="cpu",
|
86 |
+
)
|
87 |
+
elif rope_impl == "hf":
|
88 |
+
if rope_hf_config["type"] == "no_scaling":
|
89 |
+
return HFRotaryEmbedding(
|
90 |
+
rope_head_dim,
|
91 |
+
max_position_embeddings=max_seq_len,
|
92 |
+
base=rope_theta,
|
93 |
+
device="cpu",
|
94 |
+
)
|
95 |
+
elif rope_hf_config["type"] == "linear":
|
96 |
+
return HFLinearScalingRotaryEmbedding(
|
97 |
+
rope_head_dim,
|
98 |
+
max_position_embeddings=max_seq_len,
|
99 |
+
base=rope_theta,
|
100 |
+
scaling_factor=rope_hf_config["factor"],
|
101 |
+
device="cpu",
|
102 |
+
)
|
103 |
+
elif rope_hf_config["type"] == "dynamic":
|
104 |
+
return HFDynamicNTKScalingRotaryEmbedding(
|
105 |
+
rope_head_dim,
|
106 |
+
max_position_embeddings=max_seq_len,
|
107 |
+
base=rope_theta,
|
108 |
+
scaling_factor=rope_hf_config["factor"],
|
109 |
+
device="cpu",
|
110 |
+
)
|
111 |
+
raise ValueError("rope_impl needs to be either dail or hf")
|
112 |
+
|
113 |
+
|
114 |
+
def gen_attention_mask_in_length(
|
115 |
+
sequence_id: Union[None, torch.Tensor],
|
116 |
+
S: int,
|
117 |
+
attn_uses_sequence_id: bool,
|
118 |
+
attn_impl: str,
|
119 |
+
attention_mask: Union[torch.Tensor, None],
|
120 |
+
):
|
121 |
+
"""Generates the attention mask used for sequence masking in FA v2.
|
122 |
+
|
123 |
+
Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
|
124 |
+
In case of left padding:
|
125 |
+
1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
|
126 |
+
2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
|
130 |
+
S (int): Sequence length
|
131 |
+
attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
|
132 |
+
attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention.
|
133 |
+
attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
137 |
+
```
|
138 |
+
[
|
139 |
+
[2, 3, 0, 0, 0, 0],
|
140 |
+
[3, 2, 0, 0, 0, 0],
|
141 |
+
[6, 0, 0, 0, 0, 0]
|
142 |
+
]
|
143 |
+
```
|
144 |
+
, which refers to the 3D-attention mask:
|
145 |
+
```
|
146 |
+
[
|
147 |
+
[
|
148 |
+
[1, 0, 0, 0, 0, 0],
|
149 |
+
[1, 1, 0, 0, 0, 0],
|
150 |
+
[0, 0, 1, 0, 0, 0],
|
151 |
+
[0, 0, 1, 1, 0, 0],
|
152 |
+
[0, 0, 1, 1, 1, 0],
|
153 |
+
[0, 0, 0, 0, 0, 1]
|
154 |
+
],
|
155 |
+
[
|
156 |
+
[1, 0, 0, 0, 0, 0],
|
157 |
+
[1, 1, 0, 0, 0, 0],
|
158 |
+
[1, 1, 1, 0, 0, 0],
|
159 |
+
[0, 0, 0, 1, 0, 0],
|
160 |
+
[0, 0, 0, 1, 1, 0],
|
161 |
+
[0, 0, 0, 0, 0, 1]
|
162 |
+
],
|
163 |
+
[
|
164 |
+
[1, 0, 0, 0, 0, 0],
|
165 |
+
[1, 1, 0, 0, 0, 0],
|
166 |
+
[1, 1, 1, 0, 0, 0],
|
167 |
+
[1, 1, 1, 1, 0, 0],
|
168 |
+
[1, 1, 1, 1, 1, 0],
|
169 |
+
[1, 1, 1, 1, 1, 1]
|
170 |
+
]
|
171 |
+
]
|
172 |
+
```.
|
173 |
+
(The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .)
|
174 |
+
"""
|
175 |
+
attention_mask_in_length = None
|
176 |
+
if sequence_id is not None and attn_uses_sequence_id and (attn_impl == "flash"):
|
177 |
+
if (
|
178 |
+
attention_mask is not None
|
179 |
+
and attention_mask[:, 0].sum() != attention_mask.shape[0]
|
180 |
+
):
|
181 |
+
raise NotImplementedError(
|
182 |
+
"Left padding is not supported with flash attention when attn_uses_sequence_id is set to True."
|
183 |
+
)
|
184 |
+
if S != sequence_id.shape[-1]:
|
185 |
+
raise ValueError(
|
186 |
+
f"Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]})."
|
187 |
+
)
|
188 |
+
if attention_mask is not None:
|
189 |
+
sequence_id = sequence_id.masked_fill(~attention_mask, 0)
|
190 |
+
attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
|
191 |
+
if attention_mask is not None:
|
192 |
+
attention_mask_in_length = attention_mask_in_length.masked_fill(
|
193 |
+
~attention_mask.unsqueeze(-1), 0
|
194 |
+
)
|
195 |
+
attention_mask_in_length = attention_mask_in_length.sum(dim=1)
|
196 |
+
attention_mask_in_length = torch.nn.functional.pad(
|
197 |
+
attention_mask_in_length,
|
198 |
+
(0, S - attention_mask_in_length.shape[-1]),
|
199 |
+
mode="constant",
|
200 |
+
value=0,
|
201 |
+
)
|
202 |
+
return attention_mask_in_length
|
203 |
+
|
204 |
+
|
205 |
+
def gen_flash_attn_padding_info(
|
206 |
+
bsz: int,
|
207 |
+
S: int,
|
208 |
+
past_key_len: int,
|
209 |
+
device: torch.device,
|
210 |
+
attention_mask_in_length: Optional[torch.Tensor] = None,
|
211 |
+
attention_mask: Optional[torch.Tensor] = None,
|
212 |
+
):
|
213 |
+
flash_attn_padding_info = {}
|
214 |
+
if attention_mask_in_length is None:
|
215 |
+
key_padding_mask = attention_mask
|
216 |
+
if key_padding_mask is None:
|
217 |
+
key_padding_mask = torch.ones(
|
218 |
+
(bsz, past_key_len + S), dtype=torch.bool, device=device
|
219 |
+
)
|
220 |
+
query_padding_mask = key_padding_mask[:, -S:]
|
221 |
+
unpadding_function = bert_padding.unpad_input
|
222 |
+
else:
|
223 |
+
key_padding_mask = attention_mask_in_length
|
224 |
+
query_padding_mask = attention_mask_in_length
|
225 |
+
unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
|
226 |
+
(_, indices_q, cu_seqlens_q, max_seqlen_q) = unpadding_function(
|
227 |
+
torch.empty(bsz, S, 1, device=device), query_padding_mask
|
228 |
+
)
|
229 |
+
(_, indices_k, cu_seqlens_k, max_seqlen_k) = unpadding_function(
|
230 |
+
torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask
|
231 |
+
)
|
232 |
+
(_, indices_v, _, _) = unpadding_function(
|
233 |
+
torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask
|
234 |
+
)
|
235 |
+
flash_attn_padding_info["indices_q"] = indices_q
|
236 |
+
flash_attn_padding_info["indices_k"] = indices_k
|
237 |
+
flash_attn_padding_info["indices_v"] = indices_v
|
238 |
+
flash_attn_padding_info["cu_seqlens_q"] = cu_seqlens_q
|
239 |
+
flash_attn_padding_info["cu_seqlens_k"] = cu_seqlens_k
|
240 |
+
flash_attn_padding_info["max_seqlen_q"] = max_seqlen_q
|
241 |
+
flash_attn_padding_info["max_seqlen_k"] = max_seqlen_k
|
242 |
+
return flash_attn_padding_info
|
243 |
+
|
244 |
+
|
245 |
+
def apply_sequence_id(
|
246 |
+
attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int
|
247 |
+
) -> torch.Tensor:
|
248 |
+
seq_len = sequence_id.shape[-1]
|
249 |
+
if seq_len > max_seq_len:
|
250 |
+
raise ValueError(
|
251 |
+
f"sequence_id sequence length cannot exceed max_seq_len={max_seq_len}"
|
252 |
+
)
|
253 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
254 |
+
cannot_attend = torch.logical_not(
|
255 |
+
torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
|
256 |
+
).unsqueeze(1)
|
257 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
258 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
259 |
+
return attn_bias
|
260 |
+
|
261 |
+
|
262 |
+
class MPTPreTrainedModel(PreTrainedModel):
|
263 |
+
config_class = MPTConfig
|
264 |
+
base_model_prefix = "model"
|
265 |
+
_no_split_modules = ["MPTBlock"]
|
266 |
+
_supports_flash_attn_2 = True
|
267 |
+
supports_gradient_checkpointing = True
|
268 |
+
|
269 |
+
|
270 |
+
def _fsdp_wrap_fn(self: Union[MPTModel, MPTForCausalLM], module: nn.Module) -> bool:
|
271 |
+
return isinstance(module, MPTBlock)
|
272 |
+
|
273 |
+
|
274 |
+
class MPTModel(MPTPreTrainedModel):
|
275 |
+
|
276 |
+
def __init__(self, config: MPTConfig):
|
277 |
+
config._validate_config()
|
278 |
+
super().__init__(config)
|
279 |
+
self.gradient_checkpointing = False
|
280 |
+
self.attn_impl = config.attn_config["attn_impl"]
|
281 |
+
self.prefix_lm = config.attn_config["prefix_lm"]
|
282 |
+
self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
|
283 |
+
self.alibi = config.attn_config["alibi"]
|
284 |
+
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
|
285 |
+
self.learned_pos_emb = config.learned_pos_emb
|
286 |
+
if config.init_device == "mixed":
|
287 |
+
if dist.get_local_rank() == 0:
|
288 |
+
config.init_device = "cpu"
|
289 |
+
else:
|
290 |
+
config.init_device = "meta"
|
291 |
+
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
|
292 |
+
norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
|
293 |
+
raise NotImplementedError(
|
294 |
+
f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
|
295 |
+
)
|
296 |
+
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
|
297 |
+
self.embedding_fraction = config.embedding_fraction
|
298 |
+
self.wte = SharedEmbedding(
|
299 |
+
config.vocab_size, config.d_model, device=config.init_device
|
300 |
+
)
|
301 |
+
if self.learned_pos_emb:
|
302 |
+
self.wpe = torch.nn.Embedding(
|
303 |
+
config.max_seq_len, config.d_model, device=config.init_device
|
304 |
+
)
|
305 |
+
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
306 |
+
self.blocks = nn.ModuleList(
|
307 |
+
[
|
308 |
+
MPTBlock(device=config.init_device, **config.to_dict())
|
309 |
+
for _ in range(config.n_layers)
|
310 |
+
]
|
311 |
+
)
|
312 |
+
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
313 |
+
self.rope = config.attn_config["rope"]
|
314 |
+
self.rope_impl = None
|
315 |
+
if self.rope:
|
316 |
+
self.rope_impl = config.attn_config["rope_impl"]
|
317 |
+
self.rotary_embedding = gen_rotary_embedding(
|
318 |
+
rope_head_dim=config.d_model // config.n_heads,
|
319 |
+
rope_impl=self.rope_impl,
|
320 |
+
rope_theta=config.attn_config["rope_theta"],
|
321 |
+
rope_dail_config=config.attn_config["rope_dail_config"],
|
322 |
+
rope_hf_config=config.attn_config["rope_hf_config"],
|
323 |
+
max_seq_len=self.config.max_seq_len,
|
324 |
+
)
|
325 |
+
if config.init_device != "meta":
|
326 |
+
log.info(
|
327 |
+
f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
|
328 |
+
)
|
329 |
+
self.apply(self.param_init_fn)
|
330 |
+
self.is_causal = not self.prefix_lm
|
331 |
+
self._attn_bias_initialized = False
|
332 |
+
self.attn_bias = None
|
333 |
+
self.attn_bias_shape = attn_bias_shape(
|
334 |
+
self.attn_impl,
|
335 |
+
config.n_heads,
|
336 |
+
config.max_seq_len,
|
337 |
+
self.alibi,
|
338 |
+
prefix_lm=self.prefix_lm,
|
339 |
+
causal=self.is_causal,
|
340 |
+
use_sequence_id=self.attn_uses_sequence_id,
|
341 |
+
)
|
342 |
+
if config.no_bias:
|
343 |
+
for module in self.modules():
|
344 |
+
if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
|
345 |
+
log.info(f"Removing bias from module={module!r}.")
|
346 |
+
module.register_parameter("bias", None)
|
347 |
+
if hasattr(module, "use_bias"):
|
348 |
+
log.info(f"Setting use_bias=False for module={module!r}.")
|
349 |
+
module.use_bias = False
|
350 |
+
log.debug(self)
|
351 |
+
log.debug(f"Using {self.config.init_config['name']} initialization.")
|
352 |
+
|
353 |
+
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
|
354 |
+
return self.wte
|
355 |
+
|
356 |
+
def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
|
357 |
+
self.wte = value
|
358 |
+
|
359 |
+
@torch.no_grad()
|
360 |
+
def _attn_bias(
|
361 |
+
self,
|
362 |
+
device: torch.device,
|
363 |
+
dtype: torch.dtype,
|
364 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
365 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
366 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
367 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
|
368 |
+
if not self._attn_bias_initialized:
|
369 |
+
if self.attn_bias_shape:
|
370 |
+
self.attn_bias = torch.zeros(
|
371 |
+
self.attn_bias_shape, device=device, dtype=dtype
|
372 |
+
)
|
373 |
+
self.attn_bias = build_attn_bias(
|
374 |
+
self.attn_impl,
|
375 |
+
self.attn_bias,
|
376 |
+
self.config.n_heads,
|
377 |
+
self.config.max_seq_len,
|
378 |
+
causal=self.is_causal,
|
379 |
+
alibi=self.alibi,
|
380 |
+
alibi_bias_max=self.alibi_bias_max,
|
381 |
+
)
|
382 |
+
self._attn_bias_initialized = True
|
383 |
+
if self.attn_impl == "flash":
|
384 |
+
return (self.attn_bias, attention_mask)
|
385 |
+
if self.attn_bias is not None:
|
386 |
+
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
387 |
+
attn_bias = self.attn_bias
|
388 |
+
if self.prefix_lm:
|
389 |
+
assert isinstance(attn_bias, torch.Tensor)
|
390 |
+
assert isinstance(prefix_mask, torch.Tensor)
|
391 |
+
attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
|
392 |
+
if self.attn_uses_sequence_id and sequence_id is not None:
|
393 |
+
assert isinstance(attn_bias, torch.Tensor)
|
394 |
+
attn_bias = apply_sequence_id(
|
395 |
+
attn_bias, sequence_id, self.config.max_seq_len
|
396 |
+
)
|
397 |
+
if attention_mask is not None:
|
398 |
+
s_k = attention_mask.shape[-1]
|
399 |
+
if attn_bias is None:
|
400 |
+
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
401 |
+
else:
|
402 |
+
_s_k = max(0, attn_bias.size(-1) - s_k)
|
403 |
+
attn_bias = attn_bias[:, :, :, _s_k:]
|
404 |
+
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
405 |
+
raise ValueError(
|
406 |
+
f"attention_mask shape={attention_mask.shape} "
|
407 |
+
+ f"and prefix_mask shape={prefix_mask.shape} are not equal."
|
408 |
+
)
|
409 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
410 |
+
attn_bias = attn_bias.masked_fill(
|
411 |
+
~attention_mask.view(-1, 1, 1, s_k), min_val
|
412 |
+
)
|
413 |
+
return (attn_bias, attention_mask)
|
414 |
+
|
415 |
+
def _apply_prefix_mask(
|
416 |
+
self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor
|
417 |
+
) -> torch.Tensor:
|
418 |
+
(s_k, s_q) = attn_bias.shape[-2:]
|
419 |
+
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
|
420 |
+
raise ValueError(
|
421 |
+
"attn_bias does not match the expected shape. "
|
422 |
+
+ f"The last two dimensions should both be {self.config.max_length} "
|
423 |
+
+ f"but are {s_k} and {s_q}."
|
424 |
+
)
|
425 |
+
seq_len = prefix_mask.shape[-1]
|
426 |
+
if seq_len > self.config.max_seq_len:
|
427 |
+
raise ValueError(
|
428 |
+
f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
|
429 |
+
)
|
430 |
+
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
431 |
+
causal = torch.tril(
|
432 |
+
torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
|
433 |
+
).view(1, 1, seq_len, seq_len)
|
434 |
+
prefix = prefix_mask.view(-1, 1, 1, seq_len)
|
435 |
+
cannot_attend = ~torch.logical_or(causal, prefix.bool())
|
436 |
+
min_val = torch.finfo(attn_bias.dtype).min
|
437 |
+
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
438 |
+
return attn_bias
|
439 |
+
|
440 |
+
def forward(
|
441 |
+
self,
|
442 |
+
input_ids: Optional[torch.LongTensor] = None,
|
443 |
+
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
444 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
445 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
446 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
447 |
+
return_dict: Optional[bool] = None,
|
448 |
+
output_attentions: Optional[bool] = None,
|
449 |
+
output_hidden_states: Optional[bool] = None,
|
450 |
+
use_cache: Optional[bool] = None,
|
451 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
452 |
+
) -> BaseModelOutputWithPast:
|
453 |
+
return_dict = (
|
454 |
+
return_dict if return_dict is not None else self.config.return_dict
|
455 |
+
)
|
456 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
457 |
+
if attention_mask is not None:
|
458 |
+
attention_mask = attention_mask.bool()
|
459 |
+
if prefix_mask is not None:
|
460 |
+
prefix_mask = prefix_mask.bool()
|
461 |
+
if not return_dict:
|
462 |
+
raise NotImplementedError(
|
463 |
+
"return_dict False is not implemented yet for MPT"
|
464 |
+
)
|
465 |
+
if output_attentions:
|
466 |
+
if self.attn_impl != "torch":
|
467 |
+
raise NotImplementedError(
|
468 |
+
"output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
|
469 |
+
)
|
470 |
+
if (
|
471 |
+
self.training
|
472 |
+
and attention_mask is not None
|
473 |
+
and (attention_mask[:, 0].sum() != attention_mask.shape[0])
|
474 |
+
):
|
475 |
+
raise NotImplementedError(
|
476 |
+
"MPT does not support training with left padding."
|
477 |
+
)
|
478 |
+
if self.prefix_lm and prefix_mask is None:
|
479 |
+
raise ValueError(
|
480 |
+
"prefix_mask is a required argument when MPT is configured with prefix_lm=True."
|
481 |
+
)
|
482 |
+
if self.training:
|
483 |
+
if self.attn_uses_sequence_id and sequence_id is None:
|
484 |
+
raise ValueError(
|
485 |
+
"sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
|
486 |
+
+ "and the model is in train mode."
|
487 |
+
)
|
488 |
+
elif self.attn_uses_sequence_id is False and sequence_id is not None:
|
489 |
+
warnings.warn(
|
490 |
+
"MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
|
491 |
+
+ "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
|
492 |
+
)
|
493 |
+
|
494 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
495 |
+
warnings.warn(
|
496 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
497 |
+
)
|
498 |
+
use_cache = False
|
499 |
+
|
500 |
+
if input_ids is not None and inputs_embeds is not None:
|
501 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds.")
|
502 |
+
elif input_ids is not None:
|
503 |
+
bsz = input_ids.size(0)
|
504 |
+
S = input_ids.size(1)
|
505 |
+
x = self.wte(input_ids)
|
506 |
+
input_device = input_ids.device
|
507 |
+
elif inputs_embeds is not None:
|
508 |
+
bsz = inputs_embeds.size(0)
|
509 |
+
S = inputs_embeds.size(1)
|
510 |
+
x = inputs_embeds
|
511 |
+
input_device = inputs_embeds.device
|
512 |
+
else:
|
513 |
+
raise ValueError("You must specify input_ids or inputs_embeds")
|
514 |
+
assert (
|
515 |
+
S <= self.config.max_seq_len
|
516 |
+
), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
|
517 |
+
rotary_emb_w_meta_info = None
|
518 |
+
past_position = 0
|
519 |
+
if past_key_values is not None:
|
520 |
+
if len(past_key_values) != self.config.n_layers:
|
521 |
+
raise ValueError(
|
522 |
+
f"past_key_values must provide a past_key_value for each attention "
|
523 |
+
+ f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
|
524 |
+
)
|
525 |
+
past_position = past_key_values[0][0].size(1)
|
526 |
+
if self.attn_impl == "torch":
|
527 |
+
past_position = past_key_values[0][0].size(3)
|
528 |
+
if self.learned_pos_emb or self.rope:
|
529 |
+
if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
|
530 |
+
raise ValueError(
|
531 |
+
f"Cannot forward input with past sequence length {past_position} and current sequence length "
|
532 |
+
+ f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
|
533 |
+
)
|
534 |
+
if self.learned_pos_emb or (self.rope and self.rope_impl == "hf"):
|
535 |
+
pos = torch.arange(
|
536 |
+
past_position,
|
537 |
+
S + past_position,
|
538 |
+
dtype=torch.long,
|
539 |
+
device=input_device,
|
540 |
+
).unsqueeze(0)
|
541 |
+
if attention_mask is not None:
|
542 |
+
pos = torch.clamp(
|
543 |
+
pos
|
544 |
+
- torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
|
545 |
+
:, past_position:
|
546 |
+
],
|
547 |
+
min=0,
|
548 |
+
)
|
549 |
+
if self.learned_pos_emb:
|
550 |
+
x = x + self.wpe(pos)
|
551 |
+
elif self.rope and self.rope_impl == "hf":
|
552 |
+
rotary_emb_w_meta_info = {
|
553 |
+
"impl": self.rope_impl,
|
554 |
+
"rotary_emb": self.rotary_embedding,
|
555 |
+
"offset_info": pos,
|
556 |
+
"seq_len": S + past_position,
|
557 |
+
}
|
558 |
+
elif self.rope and self.rope_impl == "dail":
|
559 |
+
rotary_emb_w_meta_info = {
|
560 |
+
"impl": self.rope_impl,
|
561 |
+
"rotary_emb": self.rotary_embedding,
|
562 |
+
"offset_info": past_position,
|
563 |
+
"seq_len": S + past_position,
|
564 |
+
}
|
565 |
+
if self.embedding_fraction == 1:
|
566 |
+
x = self.emb_drop(x)
|
567 |
+
else:
|
568 |
+
x_shrunk = x * self.embedding_fraction + x.detach() * (
|
569 |
+
1 - self.embedding_fraction
|
570 |
+
)
|
571 |
+
assert isinstance(self.emb_drop, nn.Module)
|
572 |
+
x = self.emb_drop(x_shrunk)
|
573 |
+
(attn_bias, attention_mask) = self._attn_bias(
|
574 |
+
device=x.device,
|
575 |
+
dtype=torch.float32,
|
576 |
+
attention_mask=attention_mask,
|
577 |
+
prefix_mask=prefix_mask,
|
578 |
+
sequence_id=sequence_id,
|
579 |
+
)
|
580 |
+
attention_mask_in_length = gen_attention_mask_in_length(
|
581 |
+
sequence_id=sequence_id,
|
582 |
+
S=S,
|
583 |
+
attn_uses_sequence_id=self.attn_uses_sequence_id,
|
584 |
+
attn_impl=self.attn_impl,
|
585 |
+
attention_mask=attention_mask,
|
586 |
+
)
|
587 |
+
alibi_slopes = None
|
588 |
+
if self.alibi and self.attn_impl == "flash":
|
589 |
+
alibi_slopes = gen_slopes(
|
590 |
+
n_heads=self.config.n_heads,
|
591 |
+
alibi_bias_max=self.alibi_bias_max,
|
592 |
+
device=x.device,
|
593 |
+
return_1d=True,
|
594 |
+
)
|
595 |
+
presents = () if use_cache else None
|
596 |
+
if use_cache and past_key_values is None:
|
597 |
+
past_key_values = [() for _ in range(self.config.n_layers)]
|
598 |
+
all_hidden_states = () if output_hidden_states else None
|
599 |
+
all_self_attns = () if output_attentions else None
|
600 |
+
flash_attn_padding_info = {}
|
601 |
+
if self.attn_impl == "flash":
|
602 |
+
flash_attn_padding_info = gen_flash_attn_padding_info(
|
603 |
+
bsz,
|
604 |
+
S,
|
605 |
+
past_position,
|
606 |
+
x.device,
|
607 |
+
attention_mask_in_length,
|
608 |
+
attention_mask,
|
609 |
+
)
|
610 |
+
for b_idx, block in enumerate(self.blocks):
|
611 |
+
if output_hidden_states:
|
612 |
+
assert all_hidden_states is not None
|
613 |
+
all_hidden_states = all_hidden_states + (x,)
|
614 |
+
past_key_value = (
|
615 |
+
past_key_values[b_idx] if past_key_values is not None else None
|
616 |
+
)
|
617 |
+
if self.gradient_checkpointing and self.training:
|
618 |
+
(x, attn_weights, present) = self._gradient_checkpointing_func(
|
619 |
+
block.__call__,
|
620 |
+
x,
|
621 |
+
past_key_value,
|
622 |
+
attn_bias,
|
623 |
+
rotary_emb_w_meta_info,
|
624 |
+
attention_mask,
|
625 |
+
self.is_causal,
|
626 |
+
bool(output_attentions),
|
627 |
+
alibi_slopes,
|
628 |
+
flash_attn_padding_info,
|
629 |
+
)
|
630 |
+
else:
|
631 |
+
(x, attn_weights, present) = block(
|
632 |
+
x,
|
633 |
+
past_key_value=past_key_value,
|
634 |
+
attn_bias=attn_bias,
|
635 |
+
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
|
636 |
+
attention_mask=attention_mask,
|
637 |
+
is_causal=self.is_causal,
|
638 |
+
output_attentions=bool(output_attentions),
|
639 |
+
alibi_slopes=alibi_slopes,
|
640 |
+
flash_attn_padding_info=flash_attn_padding_info,
|
641 |
+
)
|
642 |
+
if presents is not None:
|
643 |
+
presents += (present,)
|
644 |
+
if output_attentions:
|
645 |
+
assert all_self_attns is not None
|
646 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
647 |
+
x = self.norm_f(x)
|
648 |
+
if output_hidden_states:
|
649 |
+
assert all_hidden_states is not None
|
650 |
+
all_hidden_states = all_hidden_states + (x,)
|
651 |
+
return BaseModelOutputWithPast(
|
652 |
+
last_hidden_state=x,
|
653 |
+
past_key_values=presents,
|
654 |
+
hidden_states=all_hidden_states,
|
655 |
+
attentions=all_self_attns,
|
656 |
+
)
|
657 |
+
|
658 |
+
def param_init_fn(self, module: nn.Module) -> None:
|
659 |
+
init_fn_name = self.config.init_config["name"]
|
660 |
+
MODEL_INIT_REGISTRY[init_fn_name](
|
661 |
+
module=module,
|
662 |
+
n_layers=self.config.n_layers,
|
663 |
+
d_model=self.config.d_model,
|
664 |
+
**self.config.init_config,
|
665 |
+
)
|
666 |
+
|
667 |
+
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
|
668 |
+
return _fsdp_wrap_fn(self, module)
|
669 |
+
|
670 |
+
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
671 |
+
return isinstance(module, MPTBlock)
|
672 |
+
|
673 |
+
|
674 |
+
class MPTForCausalLM(MPTPreTrainedModel):
|
675 |
+
|
676 |
+
def __init__(self, config: MPTConfig):
|
677 |
+
super().__init__(config)
|
678 |
+
log.info(f"Instantiating an MPTForCausalLM model from {__file__}")
|
679 |
+
self.transformer: MPTModel = MPTModel(config)
|
680 |
+
self.lm_head = None
|
681 |
+
if not config.tie_word_embeddings:
|
682 |
+
self.lm_head = nn.Linear(
|
683 |
+
config.d_model, config.vocab_size, bias=False, device=config.init_device
|
684 |
+
)
|
685 |
+
self.lm_head._fsdp_wrap = True
|
686 |
+
for child in self.transformer.children():
|
687 |
+
if isinstance(child, torch.nn.ModuleList):
|
688 |
+
continue
|
689 |
+
if isinstance(child, torch.nn.Module):
|
690 |
+
child._fsdp_wrap = True
|
691 |
+
self.logit_scale = None
|
692 |
+
if config.logit_scale is not None:
|
693 |
+
logit_scale = config.logit_scale
|
694 |
+
if isinstance(logit_scale, str):
|
695 |
+
if logit_scale == "inv_sqrt_d_model":
|
696 |
+
logit_scale = 1 / math.sqrt(config.d_model)
|
697 |
+
else:
|
698 |
+
raise ValueError(
|
699 |
+
f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
|
700 |
+
)
|
701 |
+
self.logit_scale = logit_scale
|
702 |
+
|
703 |
+
def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
|
704 |
+
return self.transformer.get_input_embeddings()
|
705 |
+
|
706 |
+
def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
|
707 |
+
self.transformer.set_input_embeddings(value)
|
708 |
+
|
709 |
+
def get_output_embeddings(self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
|
710 |
+
if self.lm_head is not None:
|
711 |
+
return self.lm_head
|
712 |
+
return self.transformer.get_input_embeddings()
|
713 |
+
|
714 |
+
def set_output_embeddings(
|
715 |
+
self, new_embeddings: Union[SharedEmbedding, nn.Embedding, nn.Linear]
|
716 |
+
) -> None:
|
717 |
+
if self.lm_head is not None:
|
718 |
+
self.lm_head = new_embeddings
|
719 |
+
else:
|
720 |
+
if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
|
721 |
+
raise ValueError(
|
722 |
+
"new_embeddings must be an instance of SharedEmbedding "
|
723 |
+
+ f"or nn.Embedding, but got {type(new_embeddings)}."
|
724 |
+
)
|
725 |
+
warnings.warn(
|
726 |
+
"Using `set_output_embeddings` to set the embedding layer of "
|
727 |
+
+ "MPTForCausalLM with tied weights. Given weights are tied, "
|
728 |
+
+ "using `set_input_embeddings` is recommended over using "
|
729 |
+
+ "`set_output_embeddings`."
|
730 |
+
)
|
731 |
+
self.transformer.set_input_embeddings(new_embeddings)
|
732 |
+
|
733 |
+
def tie_weights(self) -> None:
|
734 |
+
self.lm_head = None
|
735 |
+
|
736 |
+
def set_decoder(self, decoder: MPTModel) -> None:
|
737 |
+
self.transformer = decoder
|
738 |
+
|
739 |
+
def get_decoder(self) -> MPTModel:
|
740 |
+
return self.transformer
|
741 |
+
|
742 |
+
def forward(
|
743 |
+
self,
|
744 |
+
input_ids: Optional[torch.LongTensor] = None,
|
745 |
+
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
746 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
747 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
748 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
749 |
+
labels: Optional[torch.LongTensor] = None,
|
750 |
+
return_dict: Optional[bool] = None,
|
751 |
+
output_attentions: Optional[bool] = None,
|
752 |
+
output_hidden_states: Optional[bool] = None,
|
753 |
+
use_cache: Optional[bool] = None,
|
754 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
755 |
+
) -> CausalLMOutputWithPast:
|
756 |
+
return_dict = (
|
757 |
+
return_dict if return_dict is not None else self.config.return_dict
|
758 |
+
)
|
759 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
760 |
+
outputs = self.transformer(
|
761 |
+
input_ids=input_ids,
|
762 |
+
past_key_values=past_key_values,
|
763 |
+
attention_mask=attention_mask,
|
764 |
+
prefix_mask=prefix_mask,
|
765 |
+
sequence_id=sequence_id,
|
766 |
+
return_dict=return_dict,
|
767 |
+
output_attentions=output_attentions,
|
768 |
+
output_hidden_states=output_hidden_states,
|
769 |
+
use_cache=use_cache,
|
770 |
+
inputs_embeds=inputs_embeds,
|
771 |
+
)
|
772 |
+
if self.lm_head is not None:
|
773 |
+
logits = self.lm_head(outputs.last_hidden_state)
|
774 |
+
else:
|
775 |
+
out = outputs.last_hidden_state
|
776 |
+
out = out.to(self.transformer.wte.weight.device)
|
777 |
+
logits = self.transformer.wte(out, True)
|
778 |
+
if self.logit_scale is not None:
|
779 |
+
if self.logit_scale == 0:
|
780 |
+
warnings.warn(
|
781 |
+
f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
|
782 |
+
)
|
783 |
+
logits *= self.logit_scale
|
784 |
+
loss = None
|
785 |
+
if labels is not None:
|
786 |
+
_labels = torch.roll(labels, shifts=-1)
|
787 |
+
_labels[:, -1] = -100
|
788 |
+
loss = F.cross_entropy(
|
789 |
+
logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1)
|
790 |
+
)
|
791 |
+
return CausalLMOutputWithPast(
|
792 |
+
loss=loss,
|
793 |
+
logits=logits,
|
794 |
+
past_key_values=outputs.past_key_values,
|
795 |
+
hidden_states=outputs.hidden_states,
|
796 |
+
attentions=outputs.attentions,
|
797 |
+
)
|
798 |
+
|
799 |
+
def param_init_fn(self, module: nn.Module) -> None:
|
800 |
+
init_fn_name = self.config.init_config["name"]
|
801 |
+
MODEL_INIT_REGISTRY[init_fn_name](
|
802 |
+
module=module,
|
803 |
+
n_layers=self.config.n_layers,
|
804 |
+
d_model=self.config.d_model,
|
805 |
+
**self.config.init_config,
|
806 |
+
)
|
807 |
+
|
808 |
+
def fsdp_wrap_fn(self, module: nn.Module) -> bool:
|
809 |
+
return _fsdp_wrap_fn(self, module)
|
810 |
+
|
811 |
+
def activation_checkpointing_fn(self, module: nn.Module) -> bool:
|
812 |
+
act_ckpt_list = getattr(
|
813 |
+
self.config, "activation_checkpointing_target", None
|
814 |
+
) or ["MPTBlock"]
|
815 |
+
if isinstance(act_ckpt_list, str):
|
816 |
+
act_ckpt_list = [act_ckpt_list]
|
817 |
+
elif not isinstance(act_ckpt_list, list):
|
818 |
+
raise ValueError(
|
819 |
+
f"activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}"
|
820 |
+
)
|
821 |
+
if "MPTBlock" in act_ckpt_list or "mptblock" in act_ckpt_list:
|
822 |
+
if len(act_ckpt_list) > 1:
|
823 |
+
log.info(
|
824 |
+
"Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target)."
|
825 |
+
)
|
826 |
+
return isinstance(module, MPTBlock)
|
827 |
+
mod_types = ()
|
828 |
+
for mod_name in act_ckpt_list:
|
829 |
+
if mod_name.lower() == "mptblock":
|
830 |
+
mod_types += (MPTBlock,)
|
831 |
+
elif mod_name in ATTN_CLASS_REGISTRY:
|
832 |
+
mod_types += (ATTN_CLASS_REGISTRY[mod_name],)
|
833 |
+
elif mod_name in FFN_CLASS_REGISTRY:
|
834 |
+
mod_types += (FFN_CLASS_REGISTRY[mod_name],)
|
835 |
+
elif mod_name in NORM_CLASS_REGISTRY:
|
836 |
+
mod_types += (NORM_CLASS_REGISTRY[mod_name],)
|
837 |
+
else:
|
838 |
+
msg = ", ".join(
|
839 |
+
list(ATTN_CLASS_REGISTRY.keys())
|
840 |
+
+ list(FFN_CLASS_REGISTRY.keys())
|
841 |
+
+ list(NORM_CLASS_REGISTRY.keys())
|
842 |
+
+ ["MPTBlock"]
|
843 |
+
)
|
844 |
+
raise ValueError(
|
845 |
+
f"{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}."
|
846 |
+
)
|
847 |
+
return isinstance(module, mod_types)
|
848 |
+
|
849 |
+
def prepare_inputs_for_generation(
|
850 |
+
self,
|
851 |
+
input_ids: torch.Tensor,
|
852 |
+
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
853 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
854 |
+
**kwargs: Any,
|
855 |
+
) -> Dict[str, Any]:
|
856 |
+
attention_mask = kwargs["attention_mask"].bool()
|
857 |
+
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
858 |
+
raise NotImplementedError(
|
859 |
+
"MPT does not support generation with right padding."
|
860 |
+
)
|
861 |
+
if self.transformer.attn_uses_sequence_id and self.training:
|
862 |
+
sequence_id = torch.zeros_like(input_ids[:1])
|
863 |
+
else:
|
864 |
+
sequence_id = None
|
865 |
+
if past_key_values is not None:
|
866 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
867 |
+
if self.transformer.prefix_lm:
|
868 |
+
prefix_mask = torch.ones_like(attention_mask)
|
869 |
+
if kwargs.get("use_cache") == False:
|
870 |
+
raise NotImplementedError(
|
871 |
+
"MPT with prefix_lm=True does not support use_cache=False."
|
872 |
+
)
|
873 |
+
else:
|
874 |
+
prefix_mask = None
|
875 |
+
if inputs_embeds is not None and past_key_values is None:
|
876 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
877 |
+
else:
|
878 |
+
model_inputs = {"input_ids": input_ids}
|
879 |
+
model_inputs.update(
|
880 |
+
{
|
881 |
+
"attention_mask": attention_mask,
|
882 |
+
"prefix_mask": prefix_mask,
|
883 |
+
"sequence_id": sequence_id,
|
884 |
+
"past_key_values": past_key_values,
|
885 |
+
"use_cache": kwargs.get("use_cache", True),
|
886 |
+
}
|
887 |
+
)
|
888 |
+
return model_inputs
|
889 |
+
|
890 |
+
@staticmethod
|
891 |
+
def _reorder_cache(
|
892 |
+
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
|
893 |
+
beam_idx: torch.LongTensor,
|
894 |
+
) -> List[Tuple[torch.Tensor, ...]]:
|
895 |
+
"""Used by HuggingFace generate when using beam search with kv-caching.
|
896 |
+
|
897 |
+
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
|
898 |
+
for an example in transformers.
|
899 |
+
"""
|
900 |
+
reordered_past = []
|
901 |
+
for layer_past in past_key_values:
|
902 |
+
reordered_past += [
|
903 |
+
tuple(
|
904 |
+
(past_state.index_select(0, beam_idx) for past_state in layer_past)
|
905 |
+
)
|
906 |
+
]
|
907 |
+
return reordered_past
|
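A self-contained sketch of what `_reorder_cache` does during beam search: each layer's cached key/value tensors are re-indexed along the batch/beam dimension using the beams chosen at the current decoding step. The tensor shapes and beam indices below are illustrative only.

```python
# Illustrative sketch of the kv-cache reordering performed by _reorder_cache.
import torch

batch_beams, seq, d = 4, 5, 8
layer_past = (torch.randn(batch_beams, seq, d), torch.randn(batch_beams, seq, d))
past_key_values = [layer_past, layer_past]  # one (key, value) tuple per layer

beam_idx = torch.tensor([2, 2, 0, 1])  # beams selected at this decoding step
reordered = [
    tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
    for layer_past in past_key_values
]
print(reordered[0][0].shape)  # torch.Size([4, 5, 8]) -- same shape, rows permuted
```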
norm.py
ADDED
@@ -0,0 +1,122 @@
from typing import Dict, List, Optional, Type, Union
import torch


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(torch.nn.LayerNorm):

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        elementwise_affine: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


def rms_norm(
    x: torch.Tensor, weight: Optional[torch.Tensor] = None, eps: float = 1e-05
) -> torch.Tensor:
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        weight: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(
                torch.ones(normalized_shape, dtype=dtype, device=device)
            )
        else:
            self.register_parameter("weight", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-05,
        weight: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            weight=weight,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}
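A minimal sketch of how NORM_CLASS_REGISTRY is typically consumed: a string key from the model config selects the norm class, which is then instantiated like any other module. The registry key and hidden size below are illustrative, and the sketch assumes this file is importable as a plain module named `norm`.

```python
# Sketch: pick a norm class from NORM_CLASS_REGISTRY by its config key.
import torch
from norm import NORM_CLASS_REGISTRY  # assumes the checkpoint folder is on sys.path

norm_cls = NORM_CLASS_REGISTRY["low_precision_rmsnorm"]  # key taken from the model config
layer = norm_cls(normalized_shape=4096)                  # hidden size is illustrative

out = layer(torch.randn(2, 8, 4096))
print(out.shape, out.dtype)
```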
param_init_fns.py
ADDED
@@ -0,0 +1,380 @@
import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch import nn
from .fc import FC_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY

try:
    import transformer_engine.pytorch as te
except:
    te = None


def torch_default_param_init_fn_(module: nn.Module, **kwargs: Any) -> None:
    del kwargs
    if hasattr(module, "reset_parameters") and isinstance(
        module.reset_parameters, Callable
    ):
        module.reset_parameters()


def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
    _fused = getattr(module, "_fused", None)
    if _fused is None:
        raise RuntimeError("Internal logic error")
    assert isinstance(module.weight, torch.Tensor)
    (dim, splits) = _fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        init_fn_(module.weight[slice_indices])


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_: Callable,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    init_div_is_residual = init_div_is_residual
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, float) or isinstance(
        init_div_is_residual, int
    ):
        div_is_residual = init_div_is_residual
    elif init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        div_is_residual = 1.0
        raise ValueError(
            f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
        )
    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
        if hasattr(module, "_fused"):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            assert isinstance(module.bias, torch.Tensor)
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn(f"Embedding layer initialized to 0.")
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(
                        f"Uniform init requires a min and a max limit. User input: {lim}."
                    )
                if lim[0] == lim[1]:
                    warnings.warn(f"Embedding layer initialized to {lim[0]}.")
            else:
                if lim == 0:
                    warnings.warn(f"Embedding layer initialized to 0.")
                lim = [-lim, lim]
            (a, b) = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor):
            torch.nn.init.ones_(module.weight)
        if hasattr(module, "bias") and isinstance(module.bias, torch.Tensor):
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert (
                module.q_proj_weight is None
                and module.k_proj_weight is None
                and (module.v_proj_weight is None)
            )
            assert d_model is not None
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert (
                module.q_proj_weight is not None
                and module.k_proj_weight is not None
                and (module.v_proj_weight is not None)
            )
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(
            module.out_proj, "_is_residual", False
        ):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    elif te is not None and isinstance(module, te.LayerNormMLP):
        if isinstance(module.layer_norm_weight, torch.Tensor):
            torch.nn.init.ones_(module.layer_norm_weight)
        if isinstance(module.layer_norm_bias, torch.Tensor):
            torch.nn.init.zeros_(module.layer_norm_bias)
        init_fn_(module.fc1_weight)
        if module.fc1_bias is not None:
            assert isinstance(module.fc1_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc1_bias)
        init_fn_(module.fc2_weight)
        if module.fc2_bias is not None:
            assert isinstance(module.fc2_bias, torch.Tensor)
            torch.nn.init.zeros_(module.fc2_bias)
        with torch.no_grad():
            module.fc2_weight.div_(div_is_residual)
    else:
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(
                f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
            )


def _normal_init_(std: float, mean: float = 0.0) -> Callable:
    return partial(torch.nn.init.normal_, mean=mean, std=std)


def _normal_param_init_fn_(
    module: nn.Module,
    std: float,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    init_fn_ = _normal_init_(std=std)
    generic_param_init_fn_(
        module=module,
        init_fn_=init_fn_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def baseline_param_init_fn_(
    module: nn.Module,
    init_std: Optional[float],
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    if init_std is None:
        raise ValueError(
            "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
        )
    _normal_param_init_fn_(
        module=module,
        std=init_std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def small_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    del kwargs
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(
        module=module,
        std=std,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def neox_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: int,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    **kwargs: Any,
) -> None:
    """From section 2.3.1 of GPT-NeoX-20B:

    An Open-Source Autoregressive Language Model — Black et al. (2022)
    see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    small_param_init_fn_(
        module=module,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=residual_div,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def kaiming_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    kaiming_uniform_ = partial(
        nn.init.kaiming_uniform_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def kaiming_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    fan_mode: str = "fan_in",
    init_nonlinearity: str = "leaky_relu",
    **kwargs: Any,
) -> None:
    del kwargs
    kaiming_normal_ = partial(
        torch.nn.init.kaiming_normal_,
        a=init_gain,
        mode=fan_mode,
        nonlinearity=init_nonlinearity,
    )
    generic_param_init_fn_(
        module=module,
        init_fn_=kaiming_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def xavier_uniform_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_uniform_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


def xavier_normal_param_init_fn_(
    module: nn.Module,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    init_gain: float = 0,
    **kwargs: Any,
) -> None:
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    generic_param_init_fn_(
        module=module,
        init_fn_=xavier_normal_,
        d_model=d_model,
        n_layers=n_layers,
        init_div_is_residual=init_div_is_residual,
        emb_init_std=emb_init_std,
        emb_init_uniform_lim=emb_init_uniform_lim,
    )


MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
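A hedged sketch of how an entry from MODEL_INIT_REGISTRY is typically applied: the init function is partially bound with values from the model config and then handed to `nn.Module.apply`, which calls it once per sub-module. The package name `sea_lion_code`, the registry key, and the layer sizes below are illustrative assumptions, not values taken from this commit.

```python
# Sketch of consuming MODEL_INIT_REGISTRY (package name and sizes are hypothetical).
from functools import partial

import torch.nn as nn
from sea_lion_code.param_init_fns import MODEL_INIT_REGISTRY  # assumed package layout

init_fn = partial(
    MODEL_INIT_REGISTRY["kaiming_normal_"],  # key would come from init_config["name"]
    n_layers=32,
    d_model=4096,
)

toy = nn.Sequential(nn.Embedding(1000, 4096), nn.Linear(4096, 4096), nn.LayerNorm(4096))
toy.apply(init_fn)  # init_fn is invoked once per sub-module
```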
quantize_config.json
ADDED
@@ -0,0 +1,11 @@
{
  "bits": 4,
  "group_size": 128,
  "damp_percent": 0.01,
  "desc_act": false,
  "static_groups": false,
  "sym": true,
  "true_sequential": true,
  "model_name_or_path": null,
  "model_file_base_name": null
}
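This config records the GPTQ settings used for the checkpoint (4-bit weights, group size 128, symmetric quantization). A minimal loading sketch with AutoGPTQ follows; the local path and device are assumptions, and `trust_remote_code=True` is needed because the MPT modeling code is bundled in this repository rather than in transformers.

```python
# Sketch: load the 4-bit GPTQ checkpoint described above (path/device are assumptions).
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

path = "./sea-lion-7b-gptq"  # hypothetical local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoGPTQForCausalLM.from_quantized(path, device="cuda:0", trust_remote_code=True)

inputs = tokenizer("Sea lions are", return_tensors="pt").to("cuda:0")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```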
tokenization_SEA_BPE.py
ADDED
@@ -0,0 +1,197 @@
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from tokenizers import processors
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
SPIECE_UNDERLINE = "▁"


class SEABPETokenizer(PreTrainedTokenizer):
    """
    Construct the SEA BPE Tokenizer tailored for SEA languages. Based on Byte-Pair Encoding with an expanded vocabulary size.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        legacy (`bool`, *optional*, defaults to `True`):
            Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622,
            which includes fixes to properly handle tokens that appear after special tokens.
            legacy means we are not modifying existing tokenizers without knowing. (And we need to manually update those core tokenizers.)

            A simple example:

            - `legacy=True`:
            ```python
            >>> from transformers import T5Tokenizer

            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
            >>> tokenizer.encode("Hello <extra_id_0>.")
            [8774, 32099, 3, 5, 1]
            ```
            - `legacy=False`:
            ```python
            >>> from transformers import T5Tokenizer

            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
            [8774, 32099, 5, 1]
            ```
            Check out the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for
            more details.

    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=False,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        legacy=None,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            legacy=legacy,
            **kwargs,
        )
        if legacy is None:
            logger.warning_once(
                f"You are using the default legacy behaviour of the {self.__class__}. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565, and set the legacy attribute accordingly."
            )
            legacy = True
        self.legacy = legacy
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(self, text, **kwargs) -> List[str]:
        if not self.legacy:
            text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " ")
        return super().tokenize(text, **kwargs)

    def _tokenize(self, text):
        """
        Returns a tokenized string.

        Since the sentencepiece internal model always adds a SPIECE_UNDERLINE at the beginning of the provided text,
        we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize`
        function is called with special tokens: the input is split on the special tokens, and each subsequence is
        passed to `_tokenize`. Thus if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove
        the extra `SPIECE_UNDERLINE` prepended.
        """
        if not self.legacy:
            is_first = text.startswith(SPIECE_UNDERLINE)
            if is_first:
                text = text[1:]
        tokens = self.sp_model.encode(text, out_type=str)
        if (
            not self.legacy
            and (not is_first)
            and (not text.startswith(" "))
            and tokens[0].startswith(SPIECE_UNDERLINE)
        ):
            tokens = ([tokens[0][1:]] if len(tokens[0]) > 1 else []) + tokens[1:]
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            if token in self.all_special_tokens:
                if not prev_is_special and i != 0:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)
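A short usage sketch: the `auto_map` entry in tokenizer_config.json (next file in this commit) resolves AutoTokenizer to this SEABPETokenizer, so the class is normally loaded indirectly with `trust_remote_code=True`. The checkpoint path below is hypothetical.

```python
# Sketch: load the bundled SEABPETokenizer through AutoTokenizer (path is hypothetical).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./sea-lion-7b-gptq", trust_remote_code=True)

ids = tok("Selamat datang ke Singapura").input_ids
print(ids)              # SentencePiece BPE token ids
print(tok.decode(ids))  # round-trips back to the original text
```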
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0c576972c98fa150efff77f61a30b46afbc1247ff4697f39e51e90d0a8b2190
size 4569957
tokenizer_config.json
ADDED
@@ -0,0 +1,53 @@
{
  "add_bos_token": false,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<|endofline|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<|padding|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_SEA_BPE.SEABPETokenizer",
      null
    ]
  },
  "bos_token": null,
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|endoftext|>",
  "legacy": true,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<|padding|>",
  "sp_model_kwargs": {},
  "tokenizer_class": "SEABPETokenizer",
  "unk_token": "<unk>"
}
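A brief sketch of what these settings imply at runtime: `added_tokens_decoder` pins ids 0-3 to the special tokens, and `pad_token` makes left-padded batched generation possible (the MPT model rejects right padding). The checkpoint path is hypothetical.

```python
# Sketch: inspect the special tokens defined in this config (path is hypothetical).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./sea-lion-7b-gptq", trust_remote_code=True)
for t in ("<unk>", "<|endoftext|>", "<|endofline|>", "<|padding|>"):
    print(t, tok.convert_tokens_to_ids(t))  # expected ids per added_tokens_decoder: 0, 1, 2, 3
tok.padding_side = "left"                   # required for batched generation with this model
```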
warnings.py
ADDED
@@ -0,0 +1,20 @@
class VersionedDeprecationWarning(DeprecationWarning):
    """A custom deprecation warning class that includes version information.

    Attributes:
        message (str): The deprecation message describing why the feature is deprecated.
        remove_version (str): The version in which the feature will be removed.

    Example:
        >>> def deprecated_function():
        ...     warnings.warn(
        ...         VersionedDeprecationWarning(
        ...             "Function XYZ is deprecated.",
        ...             remove_version="2.0.0"
        ...         )
        ...     )
        ...
        >>> deprecated_function()
        DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
    """

    def __init__(self, message: str, remove_version: str) -> None:
        super().__init__(message + f" It will be removed in version {remove_version}.")
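A small sketch of why a dedicated subclass is useful: because VersionedDeprecationWarning derives from DeprecationWarning, it can be surfaced or silenced independently of other deprecations via the standard warnings filters. The function name below is illustrative, and the class is assumed to be imported from this module.

```python
# Sketch: filter the custom category independently (assumes the class above is in scope).
import warnings as std_warnings


def old_api() -> None:
    std_warnings.warn(
        VersionedDeprecationWarning("old_api is deprecated.", remove_version="2.0.0")
    )


std_warnings.simplefilter("always", VersionedDeprecationWarning)
old_api()  # -> old_api is deprecated. It will be removed in version 2.0.0
```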