Mixtral multipack (#928)
Browse files* mixtral multipack
* use mixtral model
* sample yml
* calculate cu_seqlens properly
* use updated flash ettention setting
* attn var checks
* force use of flash attention 2 for packing
* lint
* disable future fix for now
* update support table
- .mypy.ini +3 -0
- README.md +15 -13
- docker/Dockerfile-runpod +1 -0
- examples/mistral/mixtral.yml +78 -0
- src/axolotl/models/mixtral/__init__.py +6 -0
- src/axolotl/models/mixtral/configuration_moe_mistral.py +154 -0
- src/axolotl/models/mixtral/modeling_moe_mistral.py +1506 -0
- src/axolotl/utils/models.py +30 -12
.mypy.ini
CHANGED
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
|
8 |
[mypy-axolotl.monkeypatch.*]
|
9 |
ignore_errors = True
|
10 |
|
|
|
|
|
|
|
11 |
[mypy-axolotl.models.phi.*]
|
12 |
ignore_errors = True
|
13 |
|
|
|
8 |
[mypy-axolotl.monkeypatch.*]
|
9 |
ignore_errors = True
|
10 |
|
11 |
+
[mypy-axolotl.models.mixtral.*]
|
12 |
+
ignore_errors = True
|
13 |
+
|
14 |
[mypy-axolotl.models.phi.*]
|
15 |
ignore_errors = True
|
16 |
|
README.md
CHANGED
@@ -65,19 +65,21 @@ Features:
|
|
65 |
|
66 |
## Axolotl supports
|
67 |
|
68 |
-
|
|
69 |
-
|
70 |
-
| llama
|
71 |
-
|
|
72 |
-
|
|
73 |
-
|
|
74 |
-
|
|
75 |
-
|
|
76 |
-
|
|
77 |
-
|
|
78 |
-
|
|
79 |
-
|
|
80 |
-
|
|
|
|
|
|
81 |
|
82 |
|
83 |
## Quickstart ⚡
|
|
|
65 |
|
66 |
## Axolotl supports
|
67 |
|
68 |
+
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
69 |
+
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
70 |
+
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
71 |
+
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
72 |
+
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
73 |
+
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
74 |
+
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
75 |
+
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
76 |
+
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
77 |
+
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
78 |
+
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
79 |
+
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
80 |
+
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
81 |
+
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
82 |
+
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
83 |
|
84 |
|
85 |
## Quickstart ⚡
|
docker/Dockerfile-runpod
CHANGED
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
|
4 |
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
5 |
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
6 |
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
|
|
7 |
|
8 |
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
9 |
|
|
|
4 |
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
5 |
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
6 |
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
7 |
+
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
8 |
|
9 |
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
10 |
|
examples/mistral/mixtral.yml
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: DiscoResearch/mixtral-7b-8expert
|
2 |
+
model_type: MixtralForCausalLM
|
3 |
+
tokenizer_type: LlamaTokenizer
|
4 |
+
|
5 |
+
load_in_8bit: false
|
6 |
+
load_in_4bit: true
|
7 |
+
strict: false
|
8 |
+
|
9 |
+
datasets:
|
10 |
+
- path: tatsu-lab/alpaca
|
11 |
+
type: alpaca
|
12 |
+
dataset_prepared_path: last_run_prepared
|
13 |
+
val_set_size: 0.0
|
14 |
+
output_dir: ./qlora-out
|
15 |
+
|
16 |
+
adapter: qlora
|
17 |
+
lora_model_dir:
|
18 |
+
|
19 |
+
sequence_len: 4096
|
20 |
+
sample_packing: true
|
21 |
+
pad_to_sequence_len: true
|
22 |
+
|
23 |
+
lora_r: 32
|
24 |
+
lora_alpha: 16
|
25 |
+
lora_dropout: 0.05
|
26 |
+
lora_target_linear: true
|
27 |
+
lora_fan_in_fan_out:
|
28 |
+
#lora_target_modules:
|
29 |
+
# - gate
|
30 |
+
# - q_proj
|
31 |
+
# - k_proj
|
32 |
+
# - v_proj
|
33 |
+
# - o_proj
|
34 |
+
# - w1
|
35 |
+
# - w2
|
36 |
+
# - w3
|
37 |
+
|
38 |
+
wandb_project:
|
39 |
+
wandb_entity:
|
40 |
+
wandb_watch:
|
41 |
+
wandb_name:
|
42 |
+
wandb_log_model:
|
43 |
+
|
44 |
+
gradient_accumulation_steps: 2
|
45 |
+
micro_batch_size: 1
|
46 |
+
num_epochs: 1
|
47 |
+
optimizer: adamw_bnb_8bit
|
48 |
+
lr_scheduler: cosine
|
49 |
+
learning_rate: 0.0002
|
50 |
+
|
51 |
+
train_on_inputs: false
|
52 |
+
group_by_length: false
|
53 |
+
bf16: true
|
54 |
+
fp16: false
|
55 |
+
tf32: false
|
56 |
+
|
57 |
+
gradient_checkpointing: true
|
58 |
+
early_stopping_patience:
|
59 |
+
resume_from_checkpoint:
|
60 |
+
local_rank:
|
61 |
+
logging_steps: 1
|
62 |
+
xformers_attention:
|
63 |
+
flash_attention: true
|
64 |
+
|
65 |
+
loss_watchdog_threshold: 5.0
|
66 |
+
loss_watchdog_patience: 3
|
67 |
+
|
68 |
+
warmup_steps: 10
|
69 |
+
eval_steps:
|
70 |
+
eval_table_size:
|
71 |
+
eval_table_max_new_tokens: 128
|
72 |
+
save_steps:
|
73 |
+
debug:
|
74 |
+
deepspeed: deepspeed/zero2.json
|
75 |
+
weight_decay: 0.0
|
76 |
+
fsdp:
|
77 |
+
fsdp_config:
|
78 |
+
special_tokens:
|
src/axolotl/models/mixtral/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Custom modeling code for mixtral
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .configuration_moe_mistral import MixtralConfig # noqa
|
6 |
+
from .modeling_moe_mistral import MixtralForCausalLM # noqa
|
src/axolotl/models/mixtral/configuration_moe_mistral.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Mistral model configuration"""
|
16 |
+
|
17 |
+
from transformers.configuration_utils import PretrainedConfig
|
18 |
+
from transformers.utils import logging
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__)
|
21 |
+
|
22 |
+
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
23 |
+
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
|
24 |
+
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class MixtralConfig(PretrainedConfig):
|
29 |
+
r"""
|
30 |
+
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
|
31 |
+
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
32 |
+
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
|
33 |
+
|
34 |
+
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
35 |
+
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
36 |
+
|
37 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
38 |
+
documentation from [`PretrainedConfig`] for more information.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
43 |
+
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
|
44 |
+
`inputs_ids` passed when calling [`MistralModel`]
|
45 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
46 |
+
Dimension of the hidden representations.
|
47 |
+
intermediate_size (`int`, *optional*, defaults to 14336):
|
48 |
+
Dimension of the MLP representations.
|
49 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
50 |
+
Number of hidden layers in the Transformer encoder.
|
51 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
52 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
53 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
54 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
55 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
56 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
57 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
58 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
59 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
60 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
61 |
+
The non-linear activation function (function or string) in the decoder.
|
62 |
+
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
63 |
+
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
|
64 |
+
allows sequence of up to 4096*32 tokens.
|
65 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
66 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
67 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
68 |
+
The epsilon used by the rms normalization layers.
|
69 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
70 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
71 |
+
relevant if `config.is_decoder=True`.
|
72 |
+
pad_token_id (`int`, *optional*):
|
73 |
+
The id of the padding token.
|
74 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
75 |
+
The id of the "beginning-of-sequence" token.
|
76 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
77 |
+
The id of the "end-of-sequence" token.
|
78 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
79 |
+
Whether the model's input and output word embeddings should be tied.
|
80 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
81 |
+
The base period of the RoPE embeddings.
|
82 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
83 |
+
Sliding window attention window size. If not specified, will default to `4096`.
|
84 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
85 |
+
The dropout ratio for the attention probabilities.
|
86 |
+
|
87 |
+
```python
|
88 |
+
>>> from transformers import MistralModel, MistralConfig
|
89 |
+
|
90 |
+
>>> # Initializing a Mistral 7B style configuration
|
91 |
+
>>> configuration = MixtralConfig()
|
92 |
+
|
93 |
+
>>> # Initializing a model from the Mistral 7B style configuration
|
94 |
+
>>> model = MixtralModel(configuration)
|
95 |
+
|
96 |
+
>>> # Accessing the model configuration
|
97 |
+
>>> configuration = model.config
|
98 |
+
```"""
|
99 |
+
|
100 |
+
model_type = "mistral"
|
101 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
102 |
+
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
vocab_size=32000,
|
106 |
+
hidden_size=4096,
|
107 |
+
intermediate_size=14336,
|
108 |
+
num_hidden_layers=32,
|
109 |
+
num_attention_heads=32,
|
110 |
+
num_key_value_heads=8,
|
111 |
+
hidden_act="silu",
|
112 |
+
max_position_embeddings=4096 * 32,
|
113 |
+
initializer_range=0.02,
|
114 |
+
rms_norm_eps=1e-6,
|
115 |
+
use_cache=True,
|
116 |
+
pad_token_id=None,
|
117 |
+
bos_token_id=1,
|
118 |
+
eos_token_id=2,
|
119 |
+
tie_word_embeddings=False,
|
120 |
+
rope_theta=10000.0,
|
121 |
+
attention_dropout=0.0,
|
122 |
+
num_experts_per_token=2,
|
123 |
+
num_experts=8,
|
124 |
+
**kwargs,
|
125 |
+
):
|
126 |
+
self.vocab_size = vocab_size
|
127 |
+
self.max_position_embeddings = max_position_embeddings
|
128 |
+
self.hidden_size = hidden_size
|
129 |
+
self.intermediate_size = intermediate_size
|
130 |
+
self.num_hidden_layers = num_hidden_layers
|
131 |
+
self.num_attention_heads = num_attention_heads
|
132 |
+
|
133 |
+
# for backward compatibility
|
134 |
+
if num_key_value_heads is None:
|
135 |
+
num_key_value_heads = num_attention_heads
|
136 |
+
|
137 |
+
self.num_key_value_heads = num_key_value_heads
|
138 |
+
self.hidden_act = hidden_act
|
139 |
+
self.initializer_range = initializer_range
|
140 |
+
self.rms_norm_eps = rms_norm_eps
|
141 |
+
self.use_cache = use_cache
|
142 |
+
self.rope_theta = rope_theta
|
143 |
+
self.attention_dropout = attention_dropout
|
144 |
+
self.num_experts = num_experts
|
145 |
+
self.num_experts_per_token = num_experts_per_token
|
146 |
+
|
147 |
+
# pylint: disable=duplicate-code
|
148 |
+
super().__init__(
|
149 |
+
pad_token_id=pad_token_id,
|
150 |
+
bos_token_id=bos_token_id,
|
151 |
+
eos_token_id=eos_token_id,
|
152 |
+
tie_word_embeddings=tie_word_embeddings,
|
153 |
+
**kwargs,
|
154 |
+
)
|
src/axolotl/models/mixtral/modeling_moe_mistral.py
ADDED
@@ -0,0 +1,1506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pylint: skip-file
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
""" PyTorch Mistral model."""
|
22 |
+
import inspect
|
23 |
+
import math
|
24 |
+
import warnings
|
25 |
+
from typing import List, Optional, Tuple, Union
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
from einops import rearrange
|
31 |
+
from torch import nn
|
32 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
33 |
+
from transformers.cache_utils import Cache, DynamicCache
|
34 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
35 |
+
from transformers.modeling_outputs import (
|
36 |
+
BaseModelOutputWithPast,
|
37 |
+
CausalLMOutputWithPast,
|
38 |
+
SequenceClassifierOutputWithPast,
|
39 |
+
)
|
40 |
+
from transformers.modeling_utils import PreTrainedModel
|
41 |
+
from transformers.utils import (
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
is_flash_attn_2_available,
|
45 |
+
is_flash_attn_greater_or_equal_2_10,
|
46 |
+
logging,
|
47 |
+
replace_return_docstrings,
|
48 |
+
)
|
49 |
+
|
50 |
+
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
51 |
+
from .configuration_moe_mistral import MixtralConfig
|
52 |
+
|
53 |
+
if is_flash_attn_2_available():
|
54 |
+
from flash_attn import (
|
55 |
+
flash_attn_func,
|
56 |
+
flash_attn_varlen_func,
|
57 |
+
flash_attn_varlen_qkvpacked_func,
|
58 |
+
)
|
59 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
60 |
+
|
61 |
+
_flash_supports_window_size = "window_size" in list(
|
62 |
+
inspect.signature(flash_attn_func).parameters
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
logger = logging.get_logger(__name__)
|
67 |
+
|
68 |
+
_CONFIG_FOR_DOC = "MixtralConfig"
|
69 |
+
|
70 |
+
|
71 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
72 |
+
def _get_unpad_data(attention_mask):
|
73 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
74 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
75 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
76 |
+
cu_seqlens = F.pad(
|
77 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
78 |
+
)
|
79 |
+
return (
|
80 |
+
indices,
|
81 |
+
cu_seqlens,
|
82 |
+
max_seqlen_in_batch,
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
|
87 |
+
class MistralRMSNorm(nn.Module):
|
88 |
+
def __init__(self, hidden_size, eps=1e-6):
|
89 |
+
"""
|
90 |
+
MistralRMSNorm is equivalent to T5LayerNorm
|
91 |
+
"""
|
92 |
+
super().__init__()
|
93 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
94 |
+
self.variance_epsilon = eps
|
95 |
+
|
96 |
+
def forward(self, hidden_states):
|
97 |
+
input_dtype = hidden_states.dtype
|
98 |
+
hidden_states = hidden_states.to(torch.float32)
|
99 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
100 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
101 |
+
return self.weight * hidden_states.to(input_dtype)
|
102 |
+
|
103 |
+
|
104 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
|
105 |
+
class MistralRotaryEmbedding(nn.Module):
|
106 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
self.dim = dim
|
110 |
+
self.max_position_embeddings = max_position_embeddings
|
111 |
+
self.base = base
|
112 |
+
inv_freq = 1.0 / (
|
113 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
114 |
+
)
|
115 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
116 |
+
|
117 |
+
# Build here to make `torch.jit.trace` work.
|
118 |
+
self._set_cos_sin_cache(
|
119 |
+
seq_len=max_position_embeddings,
|
120 |
+
device=self.inv_freq.device,
|
121 |
+
dtype=torch.get_default_dtype(),
|
122 |
+
)
|
123 |
+
|
124 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
125 |
+
self.max_seq_len_cached = seq_len
|
126 |
+
t = torch.arange(
|
127 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
128 |
+
)
|
129 |
+
|
130 |
+
freqs = torch.outer(t, self.inv_freq)
|
131 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
132 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
133 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
134 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
135 |
+
|
136 |
+
def forward(self, x, seq_len=None):
|
137 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
138 |
+
if seq_len > self.max_seq_len_cached:
|
139 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
140 |
+
|
141 |
+
return (
|
142 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
143 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
148 |
+
def rotate_half(x):
|
149 |
+
"""Rotates half the hidden dims of the input."""
|
150 |
+
x1 = x[..., : x.shape[-1] // 2]
|
151 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
152 |
+
return torch.cat((-x2, x1), dim=-1)
|
153 |
+
|
154 |
+
|
155 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
156 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
157 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
q (`torch.Tensor`): The query tensor.
|
161 |
+
k (`torch.Tensor`): The key tensor.
|
162 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
163 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
164 |
+
position_ids (`torch.Tensor`):
|
165 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
166 |
+
used to pass offsetted position ids when working with a KV-cache.
|
167 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
168 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
169 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
170 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
171 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
172 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
173 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
174 |
+
Returns:
|
175 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
176 |
+
"""
|
177 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
178 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
179 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
180 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
181 |
+
return q_embed, k_embed
|
182 |
+
|
183 |
+
|
184 |
+
class FeedForward(nn.Module):
|
185 |
+
def __init__(self, config):
|
186 |
+
"""
|
187 |
+
Initialize the FeedForward module.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
dim (int): Input dimension.
|
191 |
+
hidden_dim (int): Hidden dimension of the feedforward layer.
|
192 |
+
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
193 |
+
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
194 |
+
|
195 |
+
Attributes:
|
196 |
+
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
197 |
+
w2 (RowParallelLinear): Linear transformation for the second layer.
|
198 |
+
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
199 |
+
|
200 |
+
"""
|
201 |
+
super().__init__()
|
202 |
+
|
203 |
+
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
204 |
+
self.w2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
205 |
+
self.w3 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
209 |
+
|
210 |
+
|
211 |
+
class MoE(nn.Module):
|
212 |
+
def __init__(
|
213 |
+
self,
|
214 |
+
config,
|
215 |
+
):
|
216 |
+
super().__init__()
|
217 |
+
self.config = config
|
218 |
+
num_experts = config.num_experts
|
219 |
+
self.experts = nn.ModuleList([FeedForward(config) for i in range(num_experts)])
|
220 |
+
self.gate = nn.Linear(config.hidden_size, num_experts, bias=False)
|
221 |
+
self.num_experts_per_token = config.num_experts_per_token
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
orig_shape = x.shape
|
225 |
+
x = x.view(-1, x.shape[-1])
|
226 |
+
|
227 |
+
scores = self.gate(x)
|
228 |
+
expert_weights, expert_indices = torch.topk(
|
229 |
+
scores, self.num_experts_per_token, dim=-1
|
230 |
+
)
|
231 |
+
expert_weights = expert_weights.softmax(dim=-1)
|
232 |
+
flat_expert_indices = expert_indices.view(-1)
|
233 |
+
|
234 |
+
x = x.repeat_interleave(self.num_experts_per_token, dim=0)
|
235 |
+
y = torch.empty_like(x)
|
236 |
+
for i, expert in enumerate(self.experts):
|
237 |
+
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
|
238 |
+
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
|
239 |
+
dim=1
|
240 |
+
)
|
241 |
+
return y.view(*orig_shape)
|
242 |
+
|
243 |
+
|
244 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
245 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
246 |
+
"""
|
247 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
248 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
249 |
+
"""
|
250 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
251 |
+
if n_rep == 1:
|
252 |
+
return hidden_states
|
253 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
254 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
255 |
+
)
|
256 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
257 |
+
|
258 |
+
|
259 |
+
class MistralAttention(nn.Module):
|
260 |
+
"""
|
261 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
262 |
+
and "Generating Long Sequences with Sparse Transformers".
|
263 |
+
"""
|
264 |
+
|
265 |
+
def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
|
266 |
+
super().__init__()
|
267 |
+
self.config = config
|
268 |
+
self.layer_idx = layer_idx
|
269 |
+
if layer_idx is None:
|
270 |
+
logger.warning_once(
|
271 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
272 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
273 |
+
"when creating this class."
|
274 |
+
)
|
275 |
+
|
276 |
+
self.hidden_size = config.hidden_size
|
277 |
+
self.num_heads = config.num_attention_heads
|
278 |
+
self.head_dim = self.hidden_size // self.num_heads
|
279 |
+
self.num_key_value_heads = config.num_key_value_heads
|
280 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
281 |
+
self.max_position_embeddings = config.max_position_embeddings
|
282 |
+
self.rope_theta = config.rope_theta
|
283 |
+
self.is_causal = True
|
284 |
+
self.attention_dropout = config.attention_dropout
|
285 |
+
|
286 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
287 |
+
raise ValueError(
|
288 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
289 |
+
f" and `num_heads`: {self.num_heads})."
|
290 |
+
)
|
291 |
+
self.q_proj = nn.Linear(
|
292 |
+
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
293 |
+
)
|
294 |
+
self.k_proj = nn.Linear(
|
295 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
296 |
+
)
|
297 |
+
self.v_proj = nn.Linear(
|
298 |
+
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
299 |
+
)
|
300 |
+
self.o_proj = nn.Linear(
|
301 |
+
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
302 |
+
)
|
303 |
+
|
304 |
+
self.rotary_emb = MistralRotaryEmbedding(
|
305 |
+
self.head_dim,
|
306 |
+
max_position_embeddings=self.max_position_embeddings,
|
307 |
+
base=self.rope_theta,
|
308 |
+
)
|
309 |
+
|
310 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
311 |
+
return (
|
312 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
313 |
+
.transpose(1, 2)
|
314 |
+
.contiguous()
|
315 |
+
)
|
316 |
+
|
317 |
+
def forward(
|
318 |
+
self,
|
319 |
+
hidden_states: torch.Tensor,
|
320 |
+
attention_mask: Optional[torch.Tensor] = None,
|
321 |
+
position_ids: Optional[torch.LongTensor] = None,
|
322 |
+
past_key_value: Optional[Cache] = None,
|
323 |
+
output_attentions: bool = False,
|
324 |
+
use_cache: bool = False,
|
325 |
+
**kwargs,
|
326 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
327 |
+
if "padding_mask" in kwargs:
|
328 |
+
warnings.warn(
|
329 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
330 |
+
)
|
331 |
+
bsz, q_len, _ = hidden_states.size()
|
332 |
+
|
333 |
+
query_states = self.q_proj(hidden_states)
|
334 |
+
key_states = self.k_proj(hidden_states)
|
335 |
+
value_states = self.v_proj(hidden_states)
|
336 |
+
|
337 |
+
query_states = query_states.view(
|
338 |
+
bsz, q_len, self.num_heads, self.head_dim
|
339 |
+
).transpose(1, 2)
|
340 |
+
key_states = key_states.view(
|
341 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
342 |
+
).transpose(1, 2)
|
343 |
+
value_states = value_states.view(
|
344 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
345 |
+
).transpose(1, 2)
|
346 |
+
|
347 |
+
kv_seq_len = key_states.shape[-2]
|
348 |
+
if past_key_value is not None:
|
349 |
+
if self.layer_idx is None:
|
350 |
+
raise ValueError(
|
351 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
352 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
353 |
+
"with a layer index."
|
354 |
+
)
|
355 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
356 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
357 |
+
query_states, key_states = apply_rotary_pos_emb(
|
358 |
+
query_states, key_states, cos, sin, position_ids
|
359 |
+
)
|
360 |
+
|
361 |
+
if past_key_value is not None:
|
362 |
+
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
363 |
+
key_states, value_states = past_key_value.update(
|
364 |
+
key_states, value_states, self.layer_idx, cache_kwargs
|
365 |
+
)
|
366 |
+
|
367 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
368 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
369 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
370 |
+
|
371 |
+
attn_weights = torch.matmul(
|
372 |
+
query_states, key_states.transpose(2, 3)
|
373 |
+
) / math.sqrt(self.head_dim)
|
374 |
+
|
375 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
376 |
+
raise ValueError(
|
377 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
378 |
+
f" {attn_weights.size()}"
|
379 |
+
)
|
380 |
+
|
381 |
+
if attention_mask is not None:
|
382 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
383 |
+
raise ValueError(
|
384 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
385 |
+
)
|
386 |
+
|
387 |
+
attn_weights = attn_weights + attention_mask
|
388 |
+
|
389 |
+
# upcast attention to fp32
|
390 |
+
attn_weights = nn.functional.softmax(
|
391 |
+
attn_weights, dim=-1, dtype=torch.float32
|
392 |
+
).to(query_states.dtype)
|
393 |
+
attn_weights = nn.functional.dropout(
|
394 |
+
attn_weights, p=self.attention_dropout, training=self.training
|
395 |
+
)
|
396 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
397 |
+
|
398 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
399 |
+
raise ValueError(
|
400 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
401 |
+
f" {attn_output.size()}"
|
402 |
+
)
|
403 |
+
|
404 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
405 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
406 |
+
|
407 |
+
attn_output = self.o_proj(attn_output)
|
408 |
+
|
409 |
+
if not output_attentions:
|
410 |
+
attn_weights = None
|
411 |
+
|
412 |
+
return attn_output, attn_weights, past_key_value
|
413 |
+
|
414 |
+
|
415 |
+
class MistralFlashAttention2(MistralAttention):
|
416 |
+
"""
|
417 |
+
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
|
418 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
419 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
420 |
+
"""
|
421 |
+
|
422 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
423 |
+
def __init__(self, *args, **kwargs):
|
424 |
+
super().__init__(*args, **kwargs)
|
425 |
+
|
426 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
427 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
428 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
429 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
430 |
+
|
431 |
+
def forward(
|
432 |
+
self,
|
433 |
+
hidden_states: torch.Tensor,
|
434 |
+
attention_mask: Optional[torch.Tensor] = None,
|
435 |
+
position_ids: Optional[torch.LongTensor] = None,
|
436 |
+
past_key_value: Optional[Cache] = None,
|
437 |
+
output_attentions: bool = False,
|
438 |
+
use_cache: bool = False,
|
439 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
440 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
441 |
+
**kwargs,
|
442 |
+
):
|
443 |
+
if "padding_mask" in kwargs:
|
444 |
+
warnings.warn(
|
445 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
446 |
+
)
|
447 |
+
|
448 |
+
# overwrite attention_mask with padding_mask
|
449 |
+
attention_mask = kwargs.pop("padding_mask")
|
450 |
+
bsz, q_len, _ = hidden_states.size()
|
451 |
+
|
452 |
+
query_states = self.q_proj(hidden_states)
|
453 |
+
key_states = self.k_proj(hidden_states)
|
454 |
+
value_states = self.v_proj(hidden_states)
|
455 |
+
|
456 |
+
query_states = query_states.view(
|
457 |
+
bsz, q_len, self.num_heads, self.head_dim
|
458 |
+
).transpose(1, 2)
|
459 |
+
key_states = key_states.view(
|
460 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
461 |
+
).transpose(1, 2)
|
462 |
+
value_states = value_states.view(
|
463 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
464 |
+
).transpose(1, 2)
|
465 |
+
|
466 |
+
kv_seq_len = key_states.shape[-2]
|
467 |
+
if past_key_value is not None:
|
468 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
469 |
+
|
470 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
471 |
+
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
472 |
+
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
473 |
+
|
474 |
+
query_states, key_states = apply_rotary_pos_emb(
|
475 |
+
query_states, key_states, cos, sin, position_ids
|
476 |
+
)
|
477 |
+
|
478 |
+
use_sliding_windows = (
|
479 |
+
_flash_supports_window_size
|
480 |
+
and getattr(self.config, "sliding_window", None) is not None
|
481 |
+
and kv_seq_len > self.config.sliding_window
|
482 |
+
)
|
483 |
+
|
484 |
+
if not _flash_supports_window_size:
|
485 |
+
logger.warning_once(
|
486 |
+
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
|
487 |
+
" make sure to upgrade flash-attn library."
|
488 |
+
)
|
489 |
+
|
490 |
+
if past_key_value is not None:
|
491 |
+
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
492 |
+
if (
|
493 |
+
getattr(self.config, "sliding_window", None) is not None
|
494 |
+
and kv_seq_len > self.config.sliding_window
|
495 |
+
):
|
496 |
+
slicing_tokens = 1 - self.config.sliding_window
|
497 |
+
|
498 |
+
past_key = past_key_value[0]
|
499 |
+
past_value = past_key_value[1]
|
500 |
+
|
501 |
+
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
502 |
+
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
503 |
+
|
504 |
+
if past_key.shape[-2] != self.config.sliding_window - 1:
|
505 |
+
raise ValueError(
|
506 |
+
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
507 |
+
f" {past_key.shape}"
|
508 |
+
)
|
509 |
+
|
510 |
+
past_key_value = (past_key, past_value)
|
511 |
+
|
512 |
+
if attention_mask is not None:
|
513 |
+
attention_mask = attention_mask[:, slicing_tokens:]
|
514 |
+
attention_mask = torch.cat(
|
515 |
+
[attention_mask, torch.ones_like(attention_mask[:, -1:])],
|
516 |
+
dim=-1,
|
517 |
+
)
|
518 |
+
|
519 |
+
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
520 |
+
key_states, value_states = past_key_value.update(
|
521 |
+
key_states, value_states, self.layer_idx, cache_kwargs
|
522 |
+
)
|
523 |
+
|
524 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
525 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
526 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
527 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
528 |
+
|
529 |
+
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
530 |
+
# special handling using sample packing
|
531 |
+
qkv = torch.stack(
|
532 |
+
[query_states, key_states, value_states], dim=2
|
533 |
+
) # [bsz, nh, 3, q_len, hd]
|
534 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
535 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
536 |
+
|
537 |
+
attn_output = flash_attn_varlen_qkvpacked_func(
|
538 |
+
qkv,
|
539 |
+
cu_seqlens,
|
540 |
+
max_seqlen,
|
541 |
+
dropout_p=dropout_rate,
|
542 |
+
softmax_scale=None,
|
543 |
+
causal=True,
|
544 |
+
)
|
545 |
+
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
546 |
+
else:
|
547 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
548 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
549 |
+
# cast them back in float16 just to be sure everything works as expected.
|
550 |
+
input_dtype = query_states.dtype
|
551 |
+
if input_dtype == torch.float32:
|
552 |
+
# Handle the case where the model is quantized
|
553 |
+
if hasattr(self.config, "_pre_quantization_dtype"):
|
554 |
+
target_dtype = self.config._pre_quantization_dtype
|
555 |
+
else:
|
556 |
+
target_dtype = self.q_proj.weight.dtype
|
557 |
+
|
558 |
+
logger.warning_once(
|
559 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
560 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
561 |
+
f" {target_dtype}."
|
562 |
+
)
|
563 |
+
|
564 |
+
query_states = query_states.to(target_dtype)
|
565 |
+
key_states = key_states.to(target_dtype)
|
566 |
+
value_states = value_states.to(target_dtype)
|
567 |
+
|
568 |
+
# Reashape to the expected shape for Flash Attention
|
569 |
+
query_states = query_states.transpose(1, 2)
|
570 |
+
key_states = key_states.transpose(1, 2)
|
571 |
+
value_states = value_states.transpose(1, 2)
|
572 |
+
|
573 |
+
attn_output = self._flash_attention_forward(
|
574 |
+
query_states,
|
575 |
+
key_states,
|
576 |
+
value_states,
|
577 |
+
attention_mask,
|
578 |
+
q_len,
|
579 |
+
dropout=dropout_rate,
|
580 |
+
use_sliding_windows=use_sliding_windows,
|
581 |
+
)
|
582 |
+
|
583 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
584 |
+
attn_output = self.o_proj(attn_output)
|
585 |
+
|
586 |
+
if not output_attentions:
|
587 |
+
attn_weights = None
|
588 |
+
|
589 |
+
return attn_output, attn_weights, past_key_value
|
590 |
+
|
591 |
+
def _flash_attention_forward(
|
592 |
+
self,
|
593 |
+
query_states,
|
594 |
+
key_states,
|
595 |
+
value_states,
|
596 |
+
attention_mask,
|
597 |
+
query_length,
|
598 |
+
dropout=0.0,
|
599 |
+
softmax_scale=None,
|
600 |
+
use_sliding_windows=False,
|
601 |
+
):
|
602 |
+
"""
|
603 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
604 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
query_states (`torch.Tensor`):
|
608 |
+
Input query states to be passed to Flash Attention API
|
609 |
+
key_states (`torch.Tensor`):
|
610 |
+
Input key states to be passed to Flash Attention API
|
611 |
+
value_states (`torch.Tensor`):
|
612 |
+
Input value states to be passed to Flash Attention API
|
613 |
+
attention_mask (`torch.Tensor`):
|
614 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
615 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
616 |
+
dropout (`int`, *optional*):
|
617 |
+
Attention dropout
|
618 |
+
softmax_scale (`float`, *optional*):
|
619 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
620 |
+
use_sliding_windows (`bool`, *optional*):
|
621 |
+
Whether to activate sliding window attention.
|
622 |
+
"""
|
623 |
+
if not self._flash_attn_uses_top_left_mask:
|
624 |
+
causal = self.is_causal
|
625 |
+
else:
|
626 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
627 |
+
causal = self.is_causal and query_length != 1
|
628 |
+
|
629 |
+
# Contains at least one padding token in the sequence
|
630 |
+
if attention_mask is not None:
|
631 |
+
batch_size = query_states.shape[0]
|
632 |
+
(
|
633 |
+
query_states,
|
634 |
+
key_states,
|
635 |
+
value_states,
|
636 |
+
indices_q,
|
637 |
+
cu_seq_lens,
|
638 |
+
max_seq_lens,
|
639 |
+
) = self._upad_input(
|
640 |
+
query_states, key_states, value_states, attention_mask, query_length
|
641 |
+
)
|
642 |
+
|
643 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
644 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
645 |
+
|
646 |
+
if not use_sliding_windows:
|
647 |
+
attn_output_unpad = flash_attn_varlen_func(
|
648 |
+
query_states,
|
649 |
+
key_states,
|
650 |
+
value_states,
|
651 |
+
cu_seqlens_q=cu_seqlens_q,
|
652 |
+
cu_seqlens_k=cu_seqlens_k,
|
653 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
654 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
655 |
+
dropout_p=dropout,
|
656 |
+
softmax_scale=softmax_scale,
|
657 |
+
causal=causal,
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
attn_output_unpad = flash_attn_varlen_func(
|
661 |
+
query_states,
|
662 |
+
key_states,
|
663 |
+
value_states,
|
664 |
+
cu_seqlens_q=cu_seqlens_q,
|
665 |
+
cu_seqlens_k=cu_seqlens_k,
|
666 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
667 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
668 |
+
dropout_p=dropout,
|
669 |
+
softmax_scale=softmax_scale,
|
670 |
+
causal=causal,
|
671 |
+
window_size=(
|
672 |
+
self.config.sliding_window,
|
673 |
+
self.config.sliding_window,
|
674 |
+
),
|
675 |
+
)
|
676 |
+
|
677 |
+
attn_output = pad_input(
|
678 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
679 |
+
)
|
680 |
+
else:
|
681 |
+
if not use_sliding_windows:
|
682 |
+
attn_output = flash_attn_func(
|
683 |
+
query_states,
|
684 |
+
key_states,
|
685 |
+
value_states,
|
686 |
+
dropout,
|
687 |
+
softmax_scale=softmax_scale,
|
688 |
+
causal=causal,
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
attn_output = flash_attn_func(
|
692 |
+
query_states,
|
693 |
+
key_states,
|
694 |
+
value_states,
|
695 |
+
dropout,
|
696 |
+
softmax_scale=softmax_scale,
|
697 |
+
causal=causal,
|
698 |
+
window_size=(
|
699 |
+
self.config.sliding_window,
|
700 |
+
self.config.sliding_window,
|
701 |
+
),
|
702 |
+
)
|
703 |
+
|
704 |
+
return attn_output
|
705 |
+
|
706 |
+
def _upad_input(
|
707 |
+
self, query_layer, key_layer, value_layer, attention_mask, query_length
|
708 |
+
):
|
709 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
710 |
+
|
711 |
+
# On the first iteration we need to properly re-create the padding mask
|
712 |
+
# by slicing it on the proper place
|
713 |
+
if kv_seq_len != attention_mask.shape[-1]:
|
714 |
+
attention_mask_num_tokens = attention_mask.shape[-1]
|
715 |
+
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
|
716 |
+
|
717 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
718 |
+
|
719 |
+
key_layer = index_first_axis(
|
720 |
+
key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
721 |
+
)
|
722 |
+
value_layer = index_first_axis(
|
723 |
+
value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
724 |
+
)
|
725 |
+
|
726 |
+
if query_length == kv_seq_len:
|
727 |
+
query_layer = index_first_axis(
|
728 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
|
729 |
+
indices_k,
|
730 |
+
)
|
731 |
+
cu_seqlens_q = cu_seqlens_k
|
732 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
733 |
+
indices_q = indices_k
|
734 |
+
elif query_length == 1:
|
735 |
+
max_seqlen_in_batch_q = 1
|
736 |
+
cu_seqlens_q = torch.arange(
|
737 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
738 |
+
) # There is a memcpy here, that is very bad.
|
739 |
+
indices_q = cu_seqlens_q[:-1]
|
740 |
+
query_layer = query_layer.squeeze(1)
|
741 |
+
else:
|
742 |
+
# The -q_len: slice assumes left padding.
|
743 |
+
attention_mask = attention_mask[:, -query_length:]
|
744 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
745 |
+
query_layer, attention_mask
|
746 |
+
)
|
747 |
+
|
748 |
+
return (
|
749 |
+
query_layer,
|
750 |
+
key_layer,
|
751 |
+
value_layer,
|
752 |
+
indices_q,
|
753 |
+
(cu_seqlens_q, cu_seqlens_k),
|
754 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
755 |
+
)
|
756 |
+
|
757 |
+
|
758 |
+
class MixtralDecoderLayer(nn.Module):
|
759 |
+
def __init__(self, config: MixtralConfig, layer_idx: int):
|
760 |
+
super().__init__()
|
761 |
+
self.hidden_size = config.hidden_size
|
762 |
+
self.self_attn = MistralFlashAttention2(config, layer_idx=layer_idx)
|
763 |
+
self.mlp = MoE(config)
|
764 |
+
self.input_layernorm = MistralRMSNorm(
|
765 |
+
config.hidden_size, eps=config.rms_norm_eps
|
766 |
+
)
|
767 |
+
self.post_attention_layernorm = MistralRMSNorm(
|
768 |
+
config.hidden_size, eps=config.rms_norm_eps
|
769 |
+
)
|
770 |
+
|
771 |
+
def forward(
|
772 |
+
self,
|
773 |
+
hidden_states: torch.Tensor,
|
774 |
+
attention_mask: Optional[torch.Tensor] = None,
|
775 |
+
position_ids: Optional[torch.LongTensor] = None,
|
776 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
777 |
+
output_attentions: Optional[bool] = False,
|
778 |
+
use_cache: Optional[bool] = False,
|
779 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
780 |
+
max_seqlen: Optional[torch.Tensor] = None,
|
781 |
+
**kwargs,
|
782 |
+
) -> Tuple[
|
783 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
784 |
+
]:
|
785 |
+
if "padding_mask" in kwargs:
|
786 |
+
warnings.warn(
|
787 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
788 |
+
)
|
789 |
+
"""
|
790 |
+
Args:
|
791 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
792 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
793 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
794 |
+
output_attentions (`bool`, *optional*):
|
795 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
796 |
+
returned tensors for more detail.
|
797 |
+
use_cache (`bool`, *optional*):
|
798 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
799 |
+
(see `past_key_values`).
|
800 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
801 |
+
"""
|
802 |
+
|
803 |
+
residual = hidden_states
|
804 |
+
|
805 |
+
hidden_states = self.input_layernorm(hidden_states)
|
806 |
+
|
807 |
+
# Self Attention
|
808 |
+
# pylint: disable=duplicate-code
|
809 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
810 |
+
hidden_states=hidden_states,
|
811 |
+
attention_mask=attention_mask,
|
812 |
+
position_ids=position_ids,
|
813 |
+
past_key_value=past_key_value,
|
814 |
+
output_attentions=output_attentions,
|
815 |
+
use_cache=use_cache,
|
816 |
+
cu_seqlens=cu_seqlens,
|
817 |
+
max_seqlen=max_seqlen,
|
818 |
+
)
|
819 |
+
hidden_states = residual + hidden_states
|
820 |
+
|
821 |
+
# Fully Connected
|
822 |
+
residual = hidden_states
|
823 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
824 |
+
hidden_states = self.mlp(hidden_states)
|
825 |
+
hidden_states = residual + hidden_states
|
826 |
+
|
827 |
+
outputs = (hidden_states,)
|
828 |
+
|
829 |
+
if output_attentions:
|
830 |
+
outputs += (self_attn_weights,)
|
831 |
+
|
832 |
+
if use_cache:
|
833 |
+
outputs += (present_key_value,)
|
834 |
+
|
835 |
+
return outputs
|
836 |
+
|
837 |
+
|
838 |
+
MISTRAL_START_DOCSTRING = r"""
|
839 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
840 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
841 |
+
etc.)
|
842 |
+
|
843 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
844 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
845 |
+
and behavior.
|
846 |
+
|
847 |
+
Parameters:
|
848 |
+
config ([`MixtralConfig`]):
|
849 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
850 |
+
load the weights associated with the model, only the configuration. Check out the
|
851 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
852 |
+
"""
|
853 |
+
|
854 |
+
|
855 |
+
@add_start_docstrings(
|
856 |
+
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
|
857 |
+
MISTRAL_START_DOCSTRING,
|
858 |
+
)
|
859 |
+
class MixtralPreTrainedModel(PreTrainedModel):
|
860 |
+
config_class = MixtralConfig
|
861 |
+
base_model_prefix = "model"
|
862 |
+
supports_gradient_checkpointing = True
|
863 |
+
_no_split_modules = ["MixtralDecoderLayer"]
|
864 |
+
_skip_keys_device_placement = "past_key_values"
|
865 |
+
_supports_flash_attn_2 = True
|
866 |
+
_supports_cache_class = True
|
867 |
+
|
868 |
+
def _init_weights(self, module):
|
869 |
+
std = self.config.initializer_range
|
870 |
+
if isinstance(module, nn.Linear):
|
871 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
872 |
+
if module.bias is not None:
|
873 |
+
module.bias.data.zero_()
|
874 |
+
elif isinstance(module, nn.Embedding):
|
875 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
876 |
+
if module.padding_idx is not None:
|
877 |
+
module.weight.data[module.padding_idx].zero_()
|
878 |
+
|
879 |
+
|
880 |
+
MISTRAL_INPUTS_DOCSTRING = r"""
|
881 |
+
Args:
|
882 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
883 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
884 |
+
it.
|
885 |
+
|
886 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
887 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
888 |
+
|
889 |
+
[What are input IDs?](../glossary#input-ids)
|
890 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
891 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
892 |
+
|
893 |
+
- 1 for tokens that are **not masked**,
|
894 |
+
- 0 for tokens that are **masked**.
|
895 |
+
|
896 |
+
[What are attention masks?](../glossary#attention-mask)
|
897 |
+
|
898 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
899 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
900 |
+
|
901 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
902 |
+
`past_key_values`).
|
903 |
+
|
904 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
905 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
906 |
+
information on the default strategy.
|
907 |
+
|
908 |
+
- 1 indicates the head is **not masked**,
|
909 |
+
- 0 indicates the head is **masked**.
|
910 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
911 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
912 |
+
config.n_positions - 1]`.
|
913 |
+
|
914 |
+
[What are position IDs?](../glossary#position-ids)
|
915 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
916 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
917 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
918 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
919 |
+
|
920 |
+
Two formats are allowed:
|
921 |
+
- a [`~cache_utils.Cache`] instance;
|
922 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
923 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
924 |
+
cache format.
|
925 |
+
|
926 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
927 |
+
legacy cache format will be returned.
|
928 |
+
|
929 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
930 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
931 |
+
of shape `(batch_size, sequence_length)`.
|
932 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
933 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
934 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
935 |
+
model's internal embedding lookup matrix.
|
936 |
+
use_cache (`bool`, *optional*):
|
937 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
938 |
+
`past_key_values`).
|
939 |
+
output_attentions (`bool`, *optional*):
|
940 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
941 |
+
tensors for more detail.
|
942 |
+
output_hidden_states (`bool`, *optional*):
|
943 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
944 |
+
more detail.
|
945 |
+
return_dict (`bool`, *optional*):
|
946 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
947 |
+
"""
|
948 |
+
|
949 |
+
|
950 |
+
@add_start_docstrings(
|
951 |
+
"The bare Mistral Model outputting raw hidden-states without any specific head on top.",
|
952 |
+
MISTRAL_START_DOCSTRING,
|
953 |
+
)
|
954 |
+
class MistralModel(MixtralPreTrainedModel):
|
955 |
+
"""
|
956 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
|
957 |
+
|
958 |
+
Args:
|
959 |
+
config: MixtralConfig
|
960 |
+
"""
|
961 |
+
|
962 |
+
def __init__(self, config: MixtralConfig):
|
963 |
+
super().__init__(config)
|
964 |
+
self.padding_idx = config.pad_token_id
|
965 |
+
self.vocab_size = config.vocab_size
|
966 |
+
|
967 |
+
self.embed_tokens = nn.Embedding(
|
968 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
969 |
+
)
|
970 |
+
self.layers = nn.ModuleList(
|
971 |
+
[
|
972 |
+
MixtralDecoderLayer(config, layer_idx)
|
973 |
+
for layer_idx in range(config.num_hidden_layers)
|
974 |
+
]
|
975 |
+
)
|
976 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
977 |
+
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
978 |
+
|
979 |
+
self.gradient_checkpointing = False
|
980 |
+
# Initialize weights and apply final processing
|
981 |
+
self.post_init()
|
982 |
+
|
983 |
+
def get_input_embeddings(self):
|
984 |
+
return self.embed_tokens
|
985 |
+
|
986 |
+
def set_input_embeddings(self, value):
|
987 |
+
self.embed_tokens = value
|
988 |
+
|
989 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
990 |
+
def forward(
|
991 |
+
self,
|
992 |
+
input_ids: torch.LongTensor = None,
|
993 |
+
attention_mask: Optional[torch.Tensor] = None,
|
994 |
+
position_ids: Optional[torch.LongTensor] = None,
|
995 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
996 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
997 |
+
use_cache: Optional[bool] = None,
|
998 |
+
output_attentions: Optional[bool] = None,
|
999 |
+
output_hidden_states: Optional[bool] = None,
|
1000 |
+
return_dict: Optional[bool] = None,
|
1001 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1002 |
+
output_attentions = (
|
1003 |
+
output_attentions
|
1004 |
+
if output_attentions is not None
|
1005 |
+
else self.config.output_attentions
|
1006 |
+
)
|
1007 |
+
output_hidden_states = (
|
1008 |
+
output_hidden_states
|
1009 |
+
if output_hidden_states is not None
|
1010 |
+
else self.config.output_hidden_states
|
1011 |
+
)
|
1012 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1013 |
+
|
1014 |
+
return_dict = (
|
1015 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
# retrieve input_ids and inputs_embeds
|
1019 |
+
if input_ids is not None and inputs_embeds is not None:
|
1020 |
+
raise ValueError(
|
1021 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
1022 |
+
)
|
1023 |
+
elif input_ids is not None:
|
1024 |
+
batch_size, seq_length = input_ids.shape
|
1025 |
+
elif inputs_embeds is not None:
|
1026 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
1027 |
+
else:
|
1028 |
+
raise ValueError(
|
1029 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
1030 |
+
)
|
1031 |
+
|
1032 |
+
seq_length_with_past = seq_length
|
1033 |
+
past_key_values_length = 0
|
1034 |
+
|
1035 |
+
if use_cache:
|
1036 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
1037 |
+
if use_legacy_cache:
|
1038 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
1039 |
+
past_key_values_length = past_key_values.get_seq_length()
|
1040 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
1041 |
+
|
1042 |
+
cu_seqlens = None
|
1043 |
+
max_seqlen = None
|
1044 |
+
if position_ids is None:
|
1045 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1046 |
+
position_ids = torch.arange(
|
1047 |
+
past_key_values_length,
|
1048 |
+
seq_length + past_key_values_length,
|
1049 |
+
dtype=torch.long,
|
1050 |
+
device=device,
|
1051 |
+
)
|
1052 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
1053 |
+
else:
|
1054 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
1055 |
+
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
1056 |
+
cu_seqlens = cu_seqlens.squeeze()
|
1057 |
+
|
1058 |
+
if inputs_embeds is None:
|
1059 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1060 |
+
|
1061 |
+
if (
|
1062 |
+
attention_mask is not None
|
1063 |
+
and hasattr(self.config, "_flash_attn_2_enabled")
|
1064 |
+
and self.config._flash_attn_2_enabled
|
1065 |
+
and use_cache
|
1066 |
+
):
|
1067 |
+
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
1068 |
+
if is_padding_right:
|
1069 |
+
raise ValueError(
|
1070 |
+
"You are attempting to perform batched generation with padding_side='right'"
|
1071 |
+
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
1072 |
+
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
1073 |
+
)
|
1074 |
+
|
1075 |
+
if getattr(self.config, "_flash_attn_2_enabled", False):
|
1076 |
+
# 2d mask is passed through the layers
|
1077 |
+
attention_mask = (
|
1078 |
+
attention_mask
|
1079 |
+
if (attention_mask is not None and 0 in attention_mask)
|
1080 |
+
else None
|
1081 |
+
)
|
1082 |
+
else:
|
1083 |
+
# 4d mask is passed through the layers
|
1084 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
1085 |
+
attention_mask,
|
1086 |
+
(batch_size, seq_length),
|
1087 |
+
inputs_embeds,
|
1088 |
+
past_key_values_length,
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
hidden_states = inputs_embeds
|
1092 |
+
|
1093 |
+
if self.gradient_checkpointing and self.training:
|
1094 |
+
if use_cache:
|
1095 |
+
logger.warning_once(
|
1096 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
1097 |
+
)
|
1098 |
+
use_cache = False
|
1099 |
+
|
1100 |
+
# decoder layers
|
1101 |
+
all_hidden_states = () if output_hidden_states else None
|
1102 |
+
all_self_attns = () if output_attentions else None
|
1103 |
+
next_decoder_cache = None
|
1104 |
+
|
1105 |
+
for decoder_layer in self.layers:
|
1106 |
+
if output_hidden_states:
|
1107 |
+
all_hidden_states += (hidden_states,)
|
1108 |
+
|
1109 |
+
if self.gradient_checkpointing and self.training:
|
1110 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1111 |
+
decoder_layer.__call__,
|
1112 |
+
hidden_states,
|
1113 |
+
attention_mask,
|
1114 |
+
position_ids,
|
1115 |
+
past_key_values,
|
1116 |
+
output_attentions,
|
1117 |
+
use_cache,
|
1118 |
+
cu_seqlens,
|
1119 |
+
max_seqlen,
|
1120 |
+
)
|
1121 |
+
else:
|
1122 |
+
layer_outputs = decoder_layer(
|
1123 |
+
hidden_states,
|
1124 |
+
attention_mask=attention_mask,
|
1125 |
+
position_ids=position_ids,
|
1126 |
+
past_key_value=past_key_values,
|
1127 |
+
output_attentions=output_attentions,
|
1128 |
+
use_cache=use_cache,
|
1129 |
+
cu_seqlens=cu_seqlens,
|
1130 |
+
max_seqlen=max_seqlen,
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
hidden_states = layer_outputs[0]
|
1134 |
+
|
1135 |
+
if use_cache:
|
1136 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1137 |
+
|
1138 |
+
if output_attentions:
|
1139 |
+
all_self_attns += (layer_outputs[1],)
|
1140 |
+
|
1141 |
+
hidden_states = self.norm(hidden_states)
|
1142 |
+
|
1143 |
+
# add hidden states from the last decoder layer
|
1144 |
+
if output_hidden_states:
|
1145 |
+
all_hidden_states += (hidden_states,)
|
1146 |
+
|
1147 |
+
next_cache = None
|
1148 |
+
if use_cache:
|
1149 |
+
next_cache = (
|
1150 |
+
next_decoder_cache.to_legacy_cache()
|
1151 |
+
if use_legacy_cache
|
1152 |
+
else next_decoder_cache
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
if not return_dict:
|
1156 |
+
return tuple(
|
1157 |
+
v
|
1158 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
1159 |
+
if v is not None
|
1160 |
+
)
|
1161 |
+
return BaseModelOutputWithPast(
|
1162 |
+
last_hidden_state=hidden_states,
|
1163 |
+
past_key_values=next_cache,
|
1164 |
+
hidden_states=all_hidden_states,
|
1165 |
+
attentions=all_self_attns,
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
|
1169 |
+
class MixtralForCausalLM(MixtralPreTrainedModel):
|
1170 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1171 |
+
|
1172 |
+
def __init__(self, config):
|
1173 |
+
super().__init__(config)
|
1174 |
+
self.model = MistralModel(config)
|
1175 |
+
self.vocab_size = config.vocab_size
|
1176 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1177 |
+
|
1178 |
+
# Initialize weights and apply final processing
|
1179 |
+
self.post_init()
|
1180 |
+
|
1181 |
+
def get_input_embeddings(self):
|
1182 |
+
return self.model.embed_tokens
|
1183 |
+
|
1184 |
+
def set_input_embeddings(self, value):
|
1185 |
+
self.model.embed_tokens = value
|
1186 |
+
|
1187 |
+
def get_output_embeddings(self):
|
1188 |
+
return self.lm_head
|
1189 |
+
|
1190 |
+
def set_output_embeddings(self, new_embeddings):
|
1191 |
+
self.lm_head = new_embeddings
|
1192 |
+
|
1193 |
+
def set_decoder(self, decoder):
|
1194 |
+
self.model = decoder
|
1195 |
+
|
1196 |
+
def get_decoder(self):
|
1197 |
+
return self.model
|
1198 |
+
|
1199 |
+
def _init_weights(self, module):
|
1200 |
+
return
|
1201 |
+
|
1202 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
1203 |
+
@replace_return_docstrings(
|
1204 |
+
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
1205 |
+
)
|
1206 |
+
def forward(
|
1207 |
+
self,
|
1208 |
+
input_ids: torch.LongTensor = None,
|
1209 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1210 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1211 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1212 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1213 |
+
labels: Optional[torch.LongTensor] = None,
|
1214 |
+
use_cache: Optional[bool] = None,
|
1215 |
+
output_attentions: Optional[bool] = None,
|
1216 |
+
output_hidden_states: Optional[bool] = None,
|
1217 |
+
return_dict: Optional[bool] = None,
|
1218 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1219 |
+
r"""
|
1220 |
+
Args:
|
1221 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1222 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1223 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1224 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1225 |
+
|
1226 |
+
Returns:
|
1227 |
+
|
1228 |
+
Example:
|
1229 |
+
|
1230 |
+
```python
|
1231 |
+
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
1232 |
+
|
1233 |
+
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1234 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1235 |
+
|
1236 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
1237 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1238 |
+
|
1239 |
+
>>> # Generate
|
1240 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1241 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1242 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1243 |
+
```"""
|
1244 |
+
|
1245 |
+
output_attentions = (
|
1246 |
+
output_attentions
|
1247 |
+
if output_attentions is not None
|
1248 |
+
else self.config.output_attentions
|
1249 |
+
)
|
1250 |
+
output_hidden_states = (
|
1251 |
+
output_hidden_states
|
1252 |
+
if output_hidden_states is not None
|
1253 |
+
else self.config.output_hidden_states
|
1254 |
+
)
|
1255 |
+
return_dict = (
|
1256 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1257 |
+
)
|
1258 |
+
|
1259 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1260 |
+
outputs = self.model(
|
1261 |
+
input_ids=input_ids,
|
1262 |
+
attention_mask=attention_mask,
|
1263 |
+
position_ids=position_ids,
|
1264 |
+
past_key_values=past_key_values,
|
1265 |
+
inputs_embeds=inputs_embeds,
|
1266 |
+
use_cache=use_cache,
|
1267 |
+
output_attentions=output_attentions,
|
1268 |
+
output_hidden_states=output_hidden_states,
|
1269 |
+
return_dict=return_dict,
|
1270 |
+
)
|
1271 |
+
|
1272 |
+
hidden_states = outputs[0]
|
1273 |
+
logits = self.lm_head(hidden_states)
|
1274 |
+
logits = logits.float()
|
1275 |
+
|
1276 |
+
loss = None
|
1277 |
+
if labels is not None:
|
1278 |
+
# Shift so that tokens < n predict n
|
1279 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1280 |
+
shift_labels = labels[..., 1:].contiguous()
|
1281 |
+
# Flatten the tokens
|
1282 |
+
loss_fct = CrossEntropyLoss()
|
1283 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1284 |
+
shift_labels = shift_labels.view(-1)
|
1285 |
+
# Enable model parallelism
|
1286 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1287 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1288 |
+
|
1289 |
+
if not return_dict:
|
1290 |
+
output = (logits,) + outputs[1:]
|
1291 |
+
return (loss,) + output if loss is not None else output
|
1292 |
+
|
1293 |
+
return CausalLMOutputWithPast(
|
1294 |
+
loss=loss,
|
1295 |
+
logits=logits,
|
1296 |
+
past_key_values=outputs.past_key_values,
|
1297 |
+
hidden_states=outputs.hidden_states,
|
1298 |
+
attentions=outputs.attentions,
|
1299 |
+
)
|
1300 |
+
|
1301 |
+
def prepare_inputs_for_generation(
|
1302 |
+
self,
|
1303 |
+
input_ids,
|
1304 |
+
past_key_values=None,
|
1305 |
+
attention_mask=None,
|
1306 |
+
inputs_embeds=None,
|
1307 |
+
**kwargs,
|
1308 |
+
):
|
1309 |
+
# Omit tokens covered by past_key_values
|
1310 |
+
if past_key_values is not None:
|
1311 |
+
if isinstance(past_key_values, Cache):
|
1312 |
+
cache_length = past_key_values.get_seq_length()
|
1313 |
+
past_length = past_key_values.seen_tokens
|
1314 |
+
else:
|
1315 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1316 |
+
|
1317 |
+
# Keep only the unprocessed tokens:
|
1318 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1319 |
+
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
|
1320 |
+
# input)
|
1321 |
+
if (
|
1322 |
+
attention_mask is not None
|
1323 |
+
and attention_mask.shape[1] > input_ids.shape[1]
|
1324 |
+
):
|
1325 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1326 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1327 |
+
# input_ids based on the past_length.
|
1328 |
+
elif past_length < input_ids.shape[1]:
|
1329 |
+
input_ids = input_ids[:, past_length:]
|
1330 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1331 |
+
|
1332 |
+
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
1333 |
+
# older attention values, as their corresponding values are not part of the input.
|
1334 |
+
if cache_length < past_length and attention_mask is not None:
|
1335 |
+
attention_mask = attention_mask[
|
1336 |
+
:, -(cache_length + input_ids.shape[1]) :
|
1337 |
+
]
|
1338 |
+
|
1339 |
+
position_ids = kwargs.get("position_ids", None)
|
1340 |
+
if attention_mask is not None and position_ids is None:
|
1341 |
+
# create position_ids on the fly for batch generation
|
1342 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1343 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1344 |
+
if past_key_values:
|
1345 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1346 |
+
|
1347 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1348 |
+
if inputs_embeds is not None and past_key_values is None:
|
1349 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1350 |
+
else:
|
1351 |
+
model_inputs = {"input_ids": input_ids}
|
1352 |
+
|
1353 |
+
model_inputs.update(
|
1354 |
+
{
|
1355 |
+
"position_ids": position_ids,
|
1356 |
+
"past_key_values": past_key_values,
|
1357 |
+
"use_cache": kwargs.get("use_cache"),
|
1358 |
+
"attention_mask": attention_mask,
|
1359 |
+
}
|
1360 |
+
)
|
1361 |
+
return model_inputs
|
1362 |
+
|
1363 |
+
@staticmethod
|
1364 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1365 |
+
reordered_past = ()
|
1366 |
+
for layer_past in past_key_values:
|
1367 |
+
reordered_past += (
|
1368 |
+
tuple(
|
1369 |
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
1370 |
+
for past_state in layer_past
|
1371 |
+
),
|
1372 |
+
)
|
1373 |
+
return reordered_past
|
1374 |
+
|
1375 |
+
|
1376 |
+
@add_start_docstrings(
|
1377 |
+
"""
|
1378 |
+
The Mistral Model transformer with a sequence classification head on top (linear layer).
|
1379 |
+
|
1380 |
+
[`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1381 |
+
(e.g. GPT-2) do.
|
1382 |
+
|
1383 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1384 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1385 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1386 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1387 |
+
each row of the batch).
|
1388 |
+
""",
|
1389 |
+
MISTRAL_START_DOCSTRING,
|
1390 |
+
)
|
1391 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
|
1392 |
+
class MistralForSequenceClassification(MixtralPreTrainedModel):
|
1393 |
+
def __init__(self, config):
|
1394 |
+
super().__init__(config)
|
1395 |
+
self.num_labels = config.num_labels
|
1396 |
+
self.model = MistralModel(config)
|
1397 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1398 |
+
|
1399 |
+
# Initialize weights and apply final processing
|
1400 |
+
self.post_init()
|
1401 |
+
|
1402 |
+
def get_input_embeddings(self):
|
1403 |
+
return self.model.embed_tokens
|
1404 |
+
|
1405 |
+
def set_input_embeddings(self, value):
|
1406 |
+
self.model.embed_tokens = value
|
1407 |
+
|
1408 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
1409 |
+
def forward(
|
1410 |
+
self,
|
1411 |
+
input_ids: torch.LongTensor = None,
|
1412 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1413 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1414 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1415 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1416 |
+
labels: Optional[torch.LongTensor] = None,
|
1417 |
+
use_cache: Optional[bool] = None,
|
1418 |
+
output_attentions: Optional[bool] = None,
|
1419 |
+
output_hidden_states: Optional[bool] = None,
|
1420 |
+
return_dict: Optional[bool] = None,
|
1421 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1422 |
+
r"""
|
1423 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1424 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1425 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1426 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1427 |
+
"""
|
1428 |
+
return_dict = (
|
1429 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1430 |
+
)
|
1431 |
+
|
1432 |
+
transformer_outputs = self.model(
|
1433 |
+
input_ids,
|
1434 |
+
attention_mask=attention_mask,
|
1435 |
+
position_ids=position_ids,
|
1436 |
+
past_key_values=past_key_values,
|
1437 |
+
inputs_embeds=inputs_embeds,
|
1438 |
+
use_cache=use_cache,
|
1439 |
+
output_attentions=output_attentions,
|
1440 |
+
output_hidden_states=output_hidden_states,
|
1441 |
+
return_dict=return_dict,
|
1442 |
+
)
|
1443 |
+
hidden_states = transformer_outputs[0]
|
1444 |
+
logits = self.score(hidden_states)
|
1445 |
+
|
1446 |
+
if input_ids is not None:
|
1447 |
+
batch_size = input_ids.shape[0]
|
1448 |
+
else:
|
1449 |
+
batch_size = inputs_embeds.shape[0]
|
1450 |
+
|
1451 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1452 |
+
raise ValueError(
|
1453 |
+
"Cannot handle batch sizes > 1 if no padding token is defined."
|
1454 |
+
)
|
1455 |
+
if self.config.pad_token_id is None:
|
1456 |
+
sequence_lengths = -1
|
1457 |
+
else:
|
1458 |
+
if input_ids is not None:
|
1459 |
+
sequence_lengths = (
|
1460 |
+
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1461 |
+
).to(logits.device)
|
1462 |
+
else:
|
1463 |
+
sequence_lengths = -1
|
1464 |
+
|
1465 |
+
pooled_logits = logits[
|
1466 |
+
torch.arange(batch_size, device=logits.device), sequence_lengths
|
1467 |
+
]
|
1468 |
+
|
1469 |
+
loss = None
|
1470 |
+
if labels is not None:
|
1471 |
+
labels = labels.to(logits.device)
|
1472 |
+
if self.config.problem_type is None:
|
1473 |
+
if self.num_labels == 1:
|
1474 |
+
self.config.problem_type = "regression"
|
1475 |
+
elif self.num_labels > 1 and (
|
1476 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
1477 |
+
):
|
1478 |
+
self.config.problem_type = "single_label_classification"
|
1479 |
+
else:
|
1480 |
+
self.config.problem_type = "multi_label_classification"
|
1481 |
+
|
1482 |
+
if self.config.problem_type == "regression":
|
1483 |
+
loss_fct = MSELoss()
|
1484 |
+
if self.num_labels == 1:
|
1485 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1486 |
+
else:
|
1487 |
+
loss = loss_fct(pooled_logits, labels)
|
1488 |
+
elif self.config.problem_type == "single_label_classification":
|
1489 |
+
loss_fct = CrossEntropyLoss()
|
1490 |
+
loss = loss_fct(
|
1491 |
+
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
1492 |
+
)
|
1493 |
+
elif self.config.problem_type == "multi_label_classification":
|
1494 |
+
loss_fct = BCEWithLogitsLoss()
|
1495 |
+
loss = loss_fct(pooled_logits, labels)
|
1496 |
+
if not return_dict:
|
1497 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1498 |
+
return ((loss,) + output) if loss is not None else output
|
1499 |
+
|
1500 |
+
return SequenceClassifierOutputWithPast(
|
1501 |
+
loss=loss,
|
1502 |
+
logits=pooled_logits,
|
1503 |
+
past_key_values=transformer_outputs.past_key_values,
|
1504 |
+
hidden_states=transformer_outputs.hidden_states,
|
1505 |
+
attentions=transformer_outputs.attentions,
|
1506 |
+
)
|
src/axolotl/utils/models.py
CHANGED
@@ -54,18 +54,25 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
54 |
def load_model_config(cfg):
|
55 |
model_config_name = cfg.base_model_config or cfg.base_model
|
56 |
trust_remote_code = cfg.trust_remote_code is True
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
)
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
if cfg.model_config:
|
71 |
for key, val in cfg.model_config.items():
|
@@ -301,7 +308,9 @@ def load_model(
|
|
301 |
or cfg.is_falcon_derived_model
|
302 |
or cfg.is_mistral_derived_model
|
303 |
):
|
304 |
-
|
|
|
|
|
305 |
|
306 |
try:
|
307 |
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
@@ -363,6 +372,15 @@ def load_model(
|
|
363 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
364 |
**model_kwargs,
|
365 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
elif model_type == "MambaLMHeadModel":
|
367 |
# FIXME this is janky at best and hacked together to make it work
|
368 |
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
|
|
54 |
def load_model_config(cfg):
|
55 |
model_config_name = cfg.base_model_config or cfg.base_model
|
56 |
trust_remote_code = cfg.trust_remote_code is True
|
57 |
+
model_type = cfg.model_type
|
58 |
+
|
59 |
+
if model_type == "MixtralForCausalLM":
|
60 |
+
from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
|
61 |
+
|
62 |
+
model_config = MixtralConfig.from_pretrained(model_config_name)
|
63 |
+
else:
|
64 |
+
try:
|
65 |
+
model_config = AutoConfig.from_pretrained(
|
66 |
+
model_config_name, trust_remote_code=trust_remote_code
|
67 |
)
|
68 |
+
except ValueError as err:
|
69 |
+
if "mamba" in model_config_name:
|
70 |
+
return addict.Dict(
|
71 |
+
{
|
72 |
+
"model_type": "mamba",
|
73 |
+
}
|
74 |
+
)
|
75 |
+
raise err
|
76 |
|
77 |
if cfg.model_config:
|
78 |
for key, val in cfg.model_config.items():
|
|
|
308 |
or cfg.is_falcon_derived_model
|
309 |
or cfg.is_mistral_derived_model
|
310 |
):
|
311 |
+
# TODO enable once properly supported in transformers
|
312 |
+
# model_kwargs["attn_implementation"] = "flash_attention_2"
|
313 |
+
model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
|
314 |
|
315 |
try:
|
316 |
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
|
|
372 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
373 |
**model_kwargs,
|
374 |
)
|
375 |
+
elif model_type == "MixtralForCausalLM":
|
376 |
+
from axolotl.models.mixtral import MixtralForCausalLM
|
377 |
+
|
378 |
+
model = MixtralForCausalLM.from_pretrained(
|
379 |
+
base_model,
|
380 |
+
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
381 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
382 |
+
**model_kwargs,
|
383 |
+
)
|
384 |
elif model_type == "MambaLMHeadModel":
|
385 |
# FIXME this is janky at best and hacked together to make it work
|
386 |
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|