winglian commited on
Commit
814aee6
·
unverified ·
1 Parent(s): b715cd5

Phi2 multipack (#1173)

Browse files

* phi2 multipack

* update validation and examples for phi

* more updates to phi examples

* make sure to use the correct collator for phi multipack

* phi needs attention mask now for multipack

* if the special token already exists in the tokenizer, don't require in lora modules to save

* fix qlora yml for phi, fix phi test validation

* test qlora too

* make sure flash attention is enabled for the test

* don't use remote code for phi anymore

* reduce sequence len for sample packing phi

examples/phi/phi-ft.yml CHANGED
@@ -1,8 +1,6 @@
1
  base_model: microsoft/phi-1_5
2
- model_type: PhiForCausalLM
3
  tokenizer_type: AutoTokenizer
4
- is_llama_derived_model: false
5
- trust_remote_code: true
6
 
7
  load_in_8bit: false
8
  load_in_4bit: false
@@ -18,7 +16,7 @@ output_dir: ./phi-sft-out
18
 
19
  sequence_len: 2048
20
  sample_packing: true
21
- pad_to_sequence_len:
22
 
23
  adapter:
24
  lora_model_dir:
@@ -35,7 +33,7 @@ wandb_name:
35
  wandb_log_model:
36
 
37
  gradient_accumulation_steps: 1
38
- micro_batch_size: 1
39
  num_epochs: 4
40
  optimizer: adamw_torch
41
  adam_beta2: 0.95
@@ -45,18 +43,20 @@ lr_scheduler: cosine
45
  learning_rate: 0.000003
46
 
47
  train_on_inputs: false
48
- group_by_length: true
49
  bf16: auto
50
  fp16:
51
  tf32: true
52
 
53
- gradient_checkpointing:
 
 
54
  early_stopping_patience:
55
  resume_from_checkpoint:
56
  local_rank:
57
  logging_steps: 1
58
  xformers_attention:
59
- flash_attention:
60
 
61
  warmup_steps: 100
62
  evals_per_epoch: 4
@@ -68,7 +68,4 @@ fsdp:
68
  fsdp_config:
69
  resize_token_embeddings_to_32x: true
70
  special_tokens:
71
- bos_token: "<|endoftext|>"
72
- eos_token: "<|endoftext|>"
73
- unk_token: "<|endoftext|>"
74
  pad_token: "<|endoftext|>"
 
1
  base_model: microsoft/phi-1_5
2
+ model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
 
 
4
 
5
  load_in_8bit: false
6
  load_in_4bit: false
 
16
 
17
  sequence_len: 2048
18
  sample_packing: true
19
+ pad_to_sequence_len: true
20
 
21
  adapter:
22
  lora_model_dir:
 
33
  wandb_log_model:
34
 
35
  gradient_accumulation_steps: 1
36
+ micro_batch_size: 2
37
  num_epochs: 4
38
  optimizer: adamw_torch
39
  adam_beta2: 0.95
 
43
  learning_rate: 0.000003
44
 
45
  train_on_inputs: false
46
+ group_by_length: false
47
  bf16: auto
48
  fp16:
49
  tf32: true
50
 
51
+ gradient_checkpointing: true
52
+ gradient_checkpointing_kwargs:
53
+ use_reentrant: True
54
  early_stopping_patience:
55
  resume_from_checkpoint:
56
  local_rank:
57
  logging_steps: 1
58
  xformers_attention:
59
+ flash_attention: true
60
 
61
  warmup_steps: 100
62
  evals_per_epoch: 4
 
68
  fsdp_config:
69
  resize_token_embeddings_to_32x: true
70
  special_tokens:
 
 
 
71
  pad_token: "<|endoftext|>"
examples/phi/phi-qlora.yml CHANGED
@@ -1,8 +1,6 @@
1
  base_model: microsoft/phi-1_5
2
  model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
4
- is_llama_derived_model: false
5
- trust_remote_code: true
6
 
7
  load_in_8bit: false
8
  load_in_4bit: true
@@ -16,9 +14,9 @@ dataset_prepared_path:
16
  val_set_size: 0.05
17
  output_dir: ./phi-sft-out
18
 
19
- sequence_len: 1024
20
- sample_packing: false # not CURRENTLY compatible with LoRAs
21
- pad_to_sequence_len:
22
 
23
  adapter: qlora
24
  lora_model_dir:
@@ -35,7 +33,7 @@ wandb_name:
35
  wandb_log_model:
36
 
37
  gradient_accumulation_steps: 1
38
- micro_batch_size: 1
39
  num_epochs: 4
40
  optimizer: adamw_torch
41
  adam_beta2: 0.95
@@ -45,18 +43,20 @@ lr_scheduler: cosine
45
  learning_rate: 0.000003
46
 
47
  train_on_inputs: false
48
- group_by_length: true
49
  bf16: auto
50
  fp16:
51
  tf32: true
52
 
53
- gradient_checkpointing:
 
 
54
  early_stopping_patience:
55
  resume_from_checkpoint:
56
  local_rank:
57
  logging_steps: 1
58
  xformers_attention:
59
- flash_attention:
60
 
61
  warmup_steps: 100
62
  evals_per_epoch: 4
@@ -68,7 +68,4 @@ fsdp:
68
  fsdp_config:
69
  resize_token_embeddings_to_32x: true
70
  special_tokens:
71
- bos_token: "<|endoftext|>"
72
- eos_token: "<|endoftext|>"
73
- unk_token: "<|endoftext|>"
74
  pad_token: "<|endoftext|>"
 
1
  base_model: microsoft/phi-1_5
2
  model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
 
 
4
 
5
  load_in_8bit: false
6
  load_in_4bit: true
 
14
  val_set_size: 0.05
15
  output_dir: ./phi-sft-out
16
 
17
+ sequence_len: 2048
18
+ sample_packing: true
19
+ pad_to_sequence_len: true
20
 
21
  adapter: qlora
22
  lora_model_dir:
 
33
  wandb_log_model:
34
 
35
  gradient_accumulation_steps: 1
36
+ micro_batch_size: 2
37
  num_epochs: 4
38
  optimizer: adamw_torch
39
  adam_beta2: 0.95
 
43
  learning_rate: 0.000003
44
 
45
  train_on_inputs: false
46
+ group_by_length: false
47
  bf16: auto
48
  fp16:
49
  tf32: true
50
 
51
+ gradient_checkpointing: true
52
+ gradient_checkpointing_kwargs:
53
+ use_reentrant: True
54
  early_stopping_patience:
55
  resume_from_checkpoint:
56
  local_rank:
57
  logging_steps: 1
58
  xformers_attention:
59
+ flash_attention: true
60
 
61
  warmup_steps: 100
62
  evals_per_epoch: 4
 
68
  fsdp_config:
69
  resize_token_embeddings_to_32x: true
70
  special_tokens:
 
 
 
71
  pad_token: "<|endoftext|>"
examples/phi/phi2-ft.yml CHANGED
@@ -1,8 +1,6 @@
1
  base_model: microsoft/phi-2
2
- model_revision: 834565c # pin model repo to the previous architecture
3
  model_type: AutoModelForCausalLM
4
  tokenizer_type: AutoTokenizer
5
- trust_remote_code: true
6
 
7
  load_in_8bit: false
8
  load_in_4bit: false
@@ -17,19 +15,16 @@ val_set_size: 0.05
17
  output_dir: ./phi-sft-out
18
 
19
  sequence_len: 2048
20
- sample_packing: false # currently unsupported
21
- pad_to_sequence_len:
22
 
23
  adapter:
24
  lora_model_dir:
25
- lora_r: 16
26
- lora_alpha: 32
27
- lora_dropout: 0.1
28
- lora_target_linear: true
29
  lora_fan_in_fan_out:
30
- lora_modules_to_save:
31
- - embd
32
- - lm_head
33
 
34
  wandb_project:
35
  wandb_entity:
@@ -38,14 +33,14 @@ wandb_name:
38
  wandb_log_model:
39
 
40
  gradient_accumulation_steps: 1
41
- micro_batch_size: 1
42
  num_epochs: 4
43
- optimizer: paged_adamw_8bit
44
  adam_beta2: 0.95
45
  adam_epsilon: 0.00001
46
  max_grad_norm: 1.0
47
  lr_scheduler: cosine
48
- learning_rate: 1e-5
49
 
50
  train_on_inputs: false
51
  group_by_length: false
@@ -54,6 +49,8 @@ fp16:
54
  tf32: true
55
 
56
  gradient_checkpointing: true
 
 
57
  early_stopping_patience:
58
  resume_from_checkpoint:
59
  local_rank:
 
1
  base_model: microsoft/phi-2
 
2
  model_type: AutoModelForCausalLM
3
  tokenizer_type: AutoTokenizer
 
4
 
5
  load_in_8bit: false
6
  load_in_4bit: false
 
15
  output_dir: ./phi-sft-out
16
 
17
  sequence_len: 2048
18
+ sample_packing: true
19
+ pad_to_sequence_len: true
20
 
21
  adapter:
22
  lora_model_dir:
23
+ lora_r:
24
+ lora_alpha:
25
+ lora_dropout:
26
+ lora_target_linear:
27
  lora_fan_in_fan_out:
 
 
 
28
 
29
  wandb_project:
30
  wandb_entity:
 
33
  wandb_log_model:
34
 
35
  gradient_accumulation_steps: 1
36
+ micro_batch_size: 2
37
  num_epochs: 4
38
+ optimizer: adamw_torch
39
  adam_beta2: 0.95
40
  adam_epsilon: 0.00001
41
  max_grad_norm: 1.0
42
  lr_scheduler: cosine
43
+ learning_rate: 0.000003
44
 
45
  train_on_inputs: false
46
  group_by_length: false
 
49
  tf32: true
50
 
51
  gradient_checkpointing: true
52
+ gradient_checkpointing_kwargs:
53
+ use_reentrant: True
54
  early_stopping_patience:
55
  resume_from_checkpoint:
56
  local_rank:
src/axolotl/core/trainer_builder.py CHANGED
@@ -930,7 +930,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
930
  ]
931
  ]
932
  if use_batch_sampler_collator:
933
- if self.cfg.model_config_type in ["mixtral", "qwen2"]:
934
  collator = V2BatchSamplerDataCollatorForSeq2Seq
935
  else:
936
  collator = BatchSamplerDataCollatorForSeq2Seq
 
930
  ]
931
  ]
932
  if use_batch_sampler_collator:
933
+ if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
934
  collator = V2BatchSamplerDataCollatorForSeq2Seq
935
  else:
936
  collator = BatchSamplerDataCollatorForSeq2Seq
src/axolotl/models/phi/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- """
2
- MixFormers model architecture used for phi models
3
- """
4
-
5
- from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
6
- from .configuration_phi import PhiConfig # noqa
7
- from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
8
- from .modeling_phi import PhiForCausalLM # noqa
 
 
 
 
 
 
 
 
 
src/axolotl/models/phi/configuration_mixformer_sequential.py DELETED
@@ -1,63 +0,0 @@
1
- # pylint: skip-file
2
-
3
- # Copyright (c) Microsoft Corporation.
4
- # Licensed under the MIT license.
5
-
6
- import math
7
- from typing import Any, Dict, List, Optional, Union
8
-
9
- from transformers import PretrainedConfig
10
-
11
-
12
- class MixFormerSequentialConfig(PretrainedConfig):
13
- """MixFormer (sequential for DeepSpeed) configuration."""
14
-
15
- model_type = "mixformer-sequential"
16
-
17
- attribute_map = {
18
- "max_position_embeddings": "n_positions",
19
- "hidden_size": "n_embd",
20
- "num_attention_heads": "n_head",
21
- "num_hidden_layers": "n_layer",
22
- "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
23
- "blocks": "architecture", # `blocks` key is for backward compatibility
24
- }
25
-
26
- def __init__(
27
- self,
28
- vocab_size: Optional[int] = 50304,
29
- n_positions: Optional[int] = 2048,
30
- n_embd: Optional[int] = 1024,
31
- n_layer: Optional[int] = 20,
32
- n_inner: Optional[int] = None,
33
- n_head: Optional[int] = 16,
34
- rotary_dim: Optional[int] = 32,
35
- activation_function: Optional[str] = "gelu_new",
36
- embd_layer: Optional[str] = "default",
37
- architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
38
- embd_pdrop: Optional[float] = 0.0,
39
- resid_pdrop: Optional[float] = 0.0,
40
- layer_norm_epsilon: Optional[float] = 1e-5,
41
- initializer_range: Optional[float] = 0.02,
42
- tie_word_embeddings: Optional[bool] = False,
43
- pad_vocab_size_multiple: Optional[int] = 64,
44
- **kwargs
45
- ) -> None:
46
- self.vocab_size = int(
47
- math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
48
- )
49
- self.n_positions = n_positions
50
- self.n_embd = n_embd
51
- self.n_layer = n_layer
52
- self.n_inner = n_inner
53
- self.n_head = n_head
54
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
55
- self.activation_function = activation_function
56
- self.embd_layer = embd_layer
57
- self.architecture = architecture
58
- self.embd_pdrop = embd_pdrop
59
- self.resid_pdrop = resid_pdrop
60
- self.layer_norm_epsilon = layer_norm_epsilon
61
- self.initializer_range = initializer_range
62
-
63
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/axolotl/models/phi/configuration_phi.py DELETED
@@ -1,65 +0,0 @@
1
- # pylint: skip-file
2
- # Copyright (c) Microsoft Corporation.
3
- # Licensed under the MIT license.
4
-
5
- import math
6
- from typing import Optional
7
-
8
- from transformers import PretrainedConfig
9
-
10
-
11
- class PhiConfig(PretrainedConfig):
12
- """Phi configuration."""
13
-
14
- model_type = "phi"
15
- attribute_map = {
16
- "max_position_embeddings": "n_positions",
17
- "hidden_size": "n_embd",
18
- "num_attention_heads": "n_head",
19
- "num_hidden_layers": "n_layer",
20
- }
21
-
22
- def __init__(
23
- self,
24
- vocab_size: int = 50304,
25
- n_positions: int = 2048,
26
- n_embd: int = 1024,
27
- n_layer: int = 20,
28
- n_inner: Optional[int] = None,
29
- n_head: int = 16,
30
- n_head_kv: Optional[int] = None,
31
- rotary_dim: Optional[int] = 32,
32
- activation_function: Optional[str] = "gelu_new",
33
- flash_attn: bool = False,
34
- flash_rotary: bool = False,
35
- fused_dense: bool = False,
36
- attn_pdrop: float = 0.0,
37
- embd_pdrop: float = 0.0,
38
- resid_pdrop: float = 0.0,
39
- layer_norm_epsilon: float = 1e-5,
40
- initializer_range: float = 0.02,
41
- tie_word_embeddings: bool = False,
42
- pad_vocab_size_multiple: int = 64,
43
- **kwargs
44
- ) -> None:
45
- self.vocab_size = int(
46
- math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
47
- )
48
- self.n_positions = n_positions
49
- self.n_embd = n_embd
50
- self.n_layer = n_layer
51
- self.n_inner = n_inner
52
- self.n_head = n_head
53
- self.n_head_kv = n_head_kv
54
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
55
- self.activation_function = activation_function
56
- self.flash_attn = flash_attn
57
- self.flash_rotary = flash_rotary
58
- self.fused_dense = fused_dense
59
- self.attn_pdrop = attn_pdrop
60
- self.embd_pdrop = embd_pdrop
61
- self.resid_pdrop = resid_pdrop
62
- self.layer_norm_epsilon = layer_norm_epsilon
63
- self.initializer_range = initializer_range
64
-
65
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/axolotl/models/phi/modeling_mixformer_sequential.py DELETED
@@ -1,930 +0,0 @@
1
- # pylint: skip-file
2
-
3
- # Copyright (c) Microsoft Corporation.
4
- # Licensed under the MIT license.
5
-
6
- # BSD 3-Clause License
7
- #
8
- # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
9
- # All rights reserved.
10
- #
11
- # Redistribution and use in source and binary forms, with or without
12
- # modification, are permitted provided that the following conditions are met:
13
- #
14
- # * Redistributions of source code must retain the above copyright notice, this
15
- # list of conditions and the following disclaimer.
16
- #
17
- # * Redistributions in binary form must reproduce the above copyright notice,
18
- # this list of conditions and the following disclaimer in the documentation
19
- # and/or other materials provided with the distribution.
20
- #
21
- # * Neither the name of the copyright holder nor the names of its
22
- # contributors may be used to endorse or promote products derived from
23
- # this software without specific prior written permission.
24
- #
25
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
-
36
- from __future__ import annotations
37
-
38
- import copy
39
- import inspect
40
- from dataclasses import dataclass, field
41
- from typing import Any, Dict, Optional, Tuple
42
-
43
- import torch
44
- import torch.nn as nn
45
- from einops import rearrange
46
- from flash_attn.flash_attn_interface import (
47
- flash_attn_kvpacked_func,
48
- flash_attn_qkvpacked_func,
49
- flash_attn_varlen_qkvpacked_func,
50
- )
51
- from transformers import PretrainedConfig, PreTrainedModel
52
- from transformers.activations import ACT2FN
53
- from transformers.modeling_outputs import CausalLMOutputWithPast
54
-
55
- from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
56
- from .configuration_mixformer_sequential import MixFormerSequentialConfig
57
-
58
-
59
- @dataclass
60
- class InferenceParams:
61
- """Inference parameters that are passed to the main model in order
62
- to efficienly calculate and store the context during inference.
63
- Adapted from https://github.com/Dao-AILab/flash-attention."""
64
-
65
- max_sequence_len: int
66
- max_batch_size: int
67
- sequence_len_offset: int = 0
68
- batch_size_offset: int = 0
69
- key_value_memory_dict: dict = field(default_factory=dict)
70
- fused_ft_kernel: bool = False
71
- lengths_per_sample: Optional[torch.Tensor] = None
72
-
73
-
74
- class Embedding(nn.Module):
75
- """Token embedding with dropout."""
76
-
77
- def __init__(self, config: PretrainedConfig) -> None:
78
- super().__init__()
79
-
80
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
81
- self.drop = nn.Dropout(config.embd_pdrop)
82
-
83
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
84
- input_shape = input_ids.size()
85
- input_ids = input_ids.view(-1, input_shape[-1])
86
-
87
- hidden_states = self.wte(input_ids)
88
- hidden_states = self.drop(hidden_states)
89
-
90
- return hidden_states
91
-
92
-
93
- class RotaryEmbedding(nn.Module):
94
- """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
95
- Adapted from https://github.com/Dao-AILab/flash-attention."""
96
-
97
- def __init__(
98
- self,
99
- dim: int,
100
- base: Optional[int] = 10000,
101
- scale_base: Optional[float] = None,
102
- device: Optional[str] = None,
103
- **kwargs,
104
- ) -> None:
105
- super().__init__()
106
-
107
- if scale_base is not None:
108
- raise NotImplementedError
109
-
110
- # Generate and save the inverse frequency buffer (non-trainable)
111
- self.dim = dim
112
- self.base = base
113
- self.scale_base = scale_base
114
- self.device = device
115
-
116
- inv_freq = 1.0 / (
117
- base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
118
- )
119
- self.register_buffer("inv_freq", inv_freq)
120
-
121
- scale = (
122
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
123
- / (1.4 * dim)
124
- if scale_base is not None
125
- else None
126
- )
127
- self.register_buffer("scale", scale)
128
-
129
- self._seq_len_cached = 0
130
- self._cos_cached = None
131
- self._sin_cached = None
132
- self._cos_k_cached = None
133
- self._sin_k_cached = None
134
-
135
- def _update_cos_sin_cache(
136
- self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
137
- ) -> None:
138
- # Reset the tables if the sequence length has changed,
139
- # or if we're on a new device (possibly due to tracing for instance)
140
- seqlen = x.shape[1] + seqlen_offset
141
-
142
- # Re-generate the inverse frequency buffer if it's not fp32
143
- # (for instance if model.half() was called)
144
- if self.inv_freq.dtype != "torch.float32":
145
- self.inv_freq = 1.0 / (
146
- self.base
147
- ** (
148
- torch.arange(
149
- 0, self.dim, 2, device=self.device, dtype=torch.float32
150
- )
151
- / self.dim
152
- )
153
- )
154
-
155
- if (
156
- seqlen > self._seq_len_cached
157
- or self._cos_cached.device != x.device
158
- or self._cos_cached.dtype != x.dtype
159
- ):
160
- self._seq_len_cached = seqlen
161
- t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
162
-
163
- # Don't do einsum, it converts fp32 to fp16
164
- # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
165
- freqs = torch.outer(
166
- t, self.inv_freq.to(device=t.device, dtype=torch.float32)
167
- )
168
- if self.scale is None:
169
- self._cos_cached = torch.cos(freqs).to(x.dtype)
170
- self._sin_cached = torch.sin(freqs).to(x.dtype)
171
- else:
172
- power = (
173
- torch.arange(
174
- seqlen, dtype=self.scale.dtype, device=self.scale.device
175
- )
176
- - seqlen // 2
177
- ) / self.scale_base
178
- scale = self.scale.to(device=power.device) ** rearrange(
179
- power, "s -> s 1"
180
- )
181
-
182
- # We want the multiplication by scale to happen in fp32
183
- self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
184
- self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
185
- self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
186
- self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
187
-
188
- def apply_rotary_emb_qkv(
189
- self,
190
- qkv: torch.FloatTensor,
191
- sin: torch.FloatTensor,
192
- cos: torch.FloatTensor,
193
- sin_k: Optional[torch.FloatTensor] = None,
194
- cos_k: Optional[torch.FloatTensor] = None,
195
- ) -> torch.FloatTensor:
196
- _, seqlen, three, _, headdim = qkv.shape
197
- assert three == 3
198
-
199
- rotary_seqlen, rotary_dim = cos.shape
200
- rotary_dim *= 2
201
- assert rotary_dim <= headdim
202
- assert seqlen <= rotary_seqlen
203
-
204
- cos_k = cos if cos_k is None else cos_k
205
- sin_k = sin if sin_k is None else sin_k
206
- assert (
207
- sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
208
- )
209
-
210
- q_rot = qkv[:, :, 0, :, :rotary_dim]
211
- q_pass = qkv[:, :, 0, :, rotary_dim:]
212
-
213
- k_rot = qkv[:, :, 1, :, :rotary_dim]
214
- k_pass = qkv[:, :, 1, :, rotary_dim:]
215
-
216
- # Splits the queries and keys in half
217
- q1, q2 = q_rot.chunk(2, dim=-1)
218
- k1, k2 = k_rot.chunk(2, dim=-1)
219
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
220
- sin[:seqlen], "s d -> s 1 d"
221
- )
222
-
223
- # Casts to fp32 are necessary to prevent fp16 overflow issues
224
- q1, q2, k1, k2, c, s = [
225
- t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
226
- ]
227
-
228
- # Computes the new keys and queries, recasting to original dtype
229
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
230
-
231
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
232
-
233
- return torch.cat(
234
- [
235
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
236
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
237
- qkv[:, :, 2:3, :, :],
238
- ],
239
- axis=2,
240
- )
241
-
242
- def forward(
243
- self, qkv: torch.Tensor, seqlen_offset: int = 0
244
- ) -> Tuple[torch.Tensor, torch.Tensor]:
245
- """Perform the forward pass.
246
-
247
- Args:
248
- qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
249
- seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
250
-
251
- Returns:
252
- New `qkv` and the cached sinusoids.
253
-
254
- """
255
-
256
- self._update_cos_sin_cache(qkv, seqlen_offset)
257
-
258
- return self.apply_rotary_emb_qkv(
259
- qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
260
- )
261
-
262
-
263
- def _update_kv_cache(kv, inference_params, layer_idx):
264
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
265
- Adapted from https://github.com/Dao-AILab/flash-attention."""
266
- # Pre-allocate memory for key-values for inference.
267
- num_heads, head_dim = kv.shape[-2:]
268
- if layer_idx not in inference_params.key_value_memory_dict:
269
- kv_cache = torch.empty(
270
- inference_params.max_batch_size,
271
- inference_params.max_sequence_len,
272
- 2,
273
- num_heads,
274
- head_dim,
275
- dtype=kv.dtype,
276
- device=kv.device,
277
- )
278
- inference_params.key_value_memory_dict[layer_idx] = kv_cache
279
- else:
280
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
281
-
282
- # Adjust key and value for inference
283
- batch_start = inference_params.batch_size_offset
284
- batch_end = batch_start + kv.shape[0]
285
- sequence_start = inference_params.sequence_len_offset
286
- sequence_end = sequence_start + kv.shape[1]
287
- assert batch_end <= (
288
- kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
289
- )
290
- assert sequence_end <= (
291
- kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
292
- )
293
-
294
- assert kv_cache is not None
295
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
296
- kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
297
- return kv
298
-
299
-
300
- class MLP(nn.Module):
301
- """Multi-Layer Perceptron.
302
-
303
- Reference:
304
- Attention Is All You Need.
305
- https://arxiv.org/pdf/1706.03762.pdf.
306
-
307
- """
308
-
309
- def __init__(
310
- self,
311
- config: PretrainedConfig,
312
- n_inner: Optional[int] = None,
313
- act_fn: Optional[str] = None,
314
- ) -> None:
315
- super().__init__()
316
-
317
- act_fn = config.activation_function if act_fn is None else act_fn
318
- assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
319
-
320
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
321
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
322
-
323
- self.fc1 = nn.Linear(config.n_embd, n_inner)
324
- self.fc2 = nn.Linear(n_inner, config.n_embd)
325
- self.act = ACT2FN[act_fn]
326
-
327
- def _load_from_state_dict(
328
- self,
329
- state_dict,
330
- prefix,
331
- local_metadata,
332
- strict,
333
- missing_keys,
334
- unexpected_keys,
335
- error_msgs,
336
- ):
337
- old_keys = [
338
- prefix + "fc_in.weight",
339
- prefix + "fc_out.weight",
340
- prefix + "fc_in.bias",
341
- prefix + "fc_out.bias",
342
- ]
343
- new_keys = [
344
- prefix + "fc1.weight",
345
- prefix + "fc2.weight",
346
- prefix + "fc1.bias",
347
- prefix + "fc2.bias",
348
- ]
349
-
350
- if all(k in state_dict for k in old_keys) and not all(
351
- k in state_dict for k in new_keys
352
- ):
353
- # Older version of `MLP` saved with different key names.
354
- for old_key, new_key in zip(old_keys, new_keys):
355
- state_dict[new_key] = state_dict.pop(old_key)
356
-
357
- return super()._load_from_state_dict(
358
- state_dict,
359
- prefix,
360
- local_metadata,
361
- strict,
362
- missing_keys,
363
- unexpected_keys,
364
- error_msgs,
365
- )
366
-
367
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
368
- hidden_states = self.fc1(hidden_states)
369
- hidden_states = self.act(hidden_states)
370
- hidden_states = self.fc2(hidden_states)
371
-
372
- return hidden_states
373
-
374
-
375
- class FusedMLP(nn.Module):
376
- """Fused Multi-Layer Perceptron from `flash-attn`.
377
-
378
- Reference:
379
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
380
-
381
- """
382
-
383
- def __init__(
384
- self,
385
- config: PretrainedConfig,
386
- n_inner: Optional[int] = None,
387
- act_fn: Optional[str] = None,
388
- raise_on_missing: bool = False,
389
- ) -> None:
390
- super().__init__()
391
-
392
- act_fn = config.activation_function if act_fn is None else act_fn
393
- assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
394
-
395
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
396
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
397
-
398
- gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
399
- activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
400
-
401
- self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
402
-
403
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
404
- return self.mlp(hidden_states)
405
-
406
-
407
- class SelfAttention(nn.Module):
408
- """Implement the scaled dot product attention with softmax.
409
- Adapted from https://github.com/Dao-AILab/flash-attention.
410
- Arguments
411
- ---------
412
- softmax_scale: The temperature to use for the softmax attention.
413
- (default: 1/sqrt(d_keys) where d_keys is computed at
414
- runtime)
415
- attention_dropout: The dropout rate to apply to the attention
416
- (default: 0.0)
417
- """
418
-
419
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
420
- super().__init__()
421
- self.causal = causal
422
- self.softmax_scale = softmax_scale
423
- self.drop = nn.Dropout(attention_dropout)
424
-
425
- def forward(
426
- self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
427
- ):
428
- """Implements the multihead softmax attention.
429
- Arguments
430
- ---------
431
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
432
- causal: if passed, will override self.causal
433
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
434
- False means to mask out. (B, S)
435
- """
436
- causal = self.causal if causal is None else causal
437
- if cu_seqlens is not None:
438
- return flash_attn_varlen_qkvpacked_func(
439
- qkv.squeeze(0),
440
- cu_seqlens,
441
- max_seqlen,
442
- dropout_p=self.drop.p,
443
- softmax_scale=self.softmax_scale,
444
- causal=causal,
445
- )
446
- else:
447
- return flash_attn_qkvpacked_func(
448
- qkv,
449
- dropout_p=self.drop.p,
450
- softmax_scale=self.softmax_scale,
451
- causal=causal,
452
- )
453
-
454
-
455
- class CrossAttention(nn.Module):
456
- """Implement the scaled dot product attention with softmax.
457
- Adapted from https://github.com/Dao-AILab/flash-attention.
458
- Arguments
459
- ---------
460
- softmax_scale: The temperature to use for the softmax attention.
461
- (default: 1/sqrt(d_keys) where d_keys is computed at
462
- runtime)
463
- attention_dropout: The dropout rate to apply to the attention
464
- (default: 0.0)
465
- """
466
-
467
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
468
- super().__init__()
469
- self.causal = causal
470
- self.softmax_scale = softmax_scale
471
- self.drop = nn.Dropout(attention_dropout)
472
-
473
- def forward(self, q, kv, causal=None, key_padding_mask=None):
474
- """Implements the multihead softmax attention.
475
- Arguments
476
- ---------
477
- q: The tensor containing the query. (B, Sq, H, D)
478
- kv: The tensor containing the key and value. (B, Sk, 2, H, D)
479
- causal: if passed, will override self.causal
480
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
481
- False means to mask out. (B, Sk)
482
- """
483
- causal = self.causal if causal is None else causal
484
- return flash_attn_kvpacked_func(
485
- q,
486
- kv,
487
- dropout_p=self.drop.p,
488
- softmax_scale=self.softmax_scale,
489
- causal=causal,
490
- )
491
-
492
-
493
- def find_mha_dims(
494
- config: PretrainedConfig,
495
- n_head: Optional[int] = None,
496
- head_dim: Optional[int] = None,
497
- ) -> Tuple[int, int]:
498
- """Validate and return the number of heads and head dimension for multi-head attention.
499
-
500
- Args:
501
- config: Model configuration.
502
- n_head: Number of heads.
503
- head_dim: Head dimension.
504
-
505
- Returns:
506
- Number of heads and head dimension.
507
-
508
- """
509
-
510
- assert all(
511
- hasattr(config, attr) for attr in ["n_embd", "n_head"]
512
- ), "`config` must have `n_embd` and `n_head` attributes."
513
-
514
- if head_dim is None:
515
- assert (
516
- config.n_embd % config.n_head == 0
517
- ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
518
-
519
- if n_head is None and head_dim is None:
520
- head_dim = config.n_embd // config.n_head
521
- n_head = config.n_head
522
- elif n_head is None or head_dim is None:
523
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
524
-
525
- return n_head, head_dim
526
-
527
-
528
- class MHA(nn.Module):
529
- """Multi-head attention layer.
530
- Adapted from https://github.com/Dao-AILab/flash-attention."""
531
-
532
- def __init__(
533
- self,
534
- config: PretrainedConfig,
535
- rotary_dim: Optional[int] = None,
536
- n_head: Optional[int] = None,
537
- head_dim: Optional[int] = None,
538
- bias: Optional[bool] = True,
539
- dropout: Optional[float] = 0.0,
540
- softmax_scale: Optional[float] = None,
541
- causal: Optional[bool] = True,
542
- layer_idx: Optional[int] = None,
543
- rotary_emb_scale_base: Optional[float] = None,
544
- return_residual: Optional[bool] = False,
545
- checkpointing: Optional[bool] = False,
546
- device: Optional[str] = None,
547
- dtype: Optional[torch.dtype] = None,
548
- fused_dense: Optional[bool] = True,
549
- flash_attn: Optional[bool] = True,
550
- cutlass_attn: Optional[bool] = False,
551
- flash_rotary: Optional[bool] = True,
552
- raise_on_missing: Optional[bool] = False,
553
- ) -> None:
554
- super().__init__()
555
-
556
- factory_kwargs = {"device": device, "dtype": dtype}
557
- n_head, head_dim = find_mha_dims(config, n_head, head_dim)
558
-
559
- self.hidden_size = config.n_embd
560
- self.n_head = n_head
561
- self.head_dim = head_dim
562
- self.op_size = n_head * head_dim
563
-
564
- self.causal = causal
565
- self.layer_idx = layer_idx
566
- self.rotary_emb_dim = (
567
- rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
568
- )
569
- self.fused_dense = fused_dense
570
- self.flash_attn = flash_attn
571
- self.cutlass_attn = cutlass_attn
572
- self.flash_rotary = flash_rotary
573
- self.return_residual = return_residual
574
- self.checkpointing = checkpointing
575
-
576
- if self.rotary_emb_dim > 0:
577
- rotary_kwargs = {"device": device}
578
- if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
579
- rotary_kwargs["scale_base"] = rotary_emb_scale_base
580
-
581
- self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
582
- else:
583
- pass
584
-
585
- self.Wqkv = nn.Linear(
586
- self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
587
- )
588
- self.out_proj = nn.Linear(
589
- self.op_size, self.hidden_size, bias=bias, **factory_kwargs
590
- )
591
-
592
- self.inner_attn = SelfAttention(
593
- causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
594
- )
595
- self.inner_cross_attn = CrossAttention(
596
- causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
597
- )
598
-
599
- def _update_kv_cache(
600
- self, kv: torch.FloatTensor, inference_params: InferenceParams
601
- ) -> None:
602
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
603
- Adapted from https://github.com/Dao-AILab/flash-attention."""
604
-
605
- assert (
606
- self.layer_idx is not None
607
- ), "Generation requires layer_idx in the constructor"
608
-
609
- return _update_kv_cache(kv, inference_params, self.layer_idx)
610
-
611
- def forward(
612
- self,
613
- x: torch.FloatTensor,
614
- x_kv: Optional[torch.FloatTensor] = None,
615
- key_padding_mask: Optional[torch.BoolTensor] = None,
616
- cu_seqlens: Optional[torch.LongTensor] = None,
617
- max_seqlen: Optional[int] = None,
618
- mixer_subset: Optional[torch.LongTensor] = None,
619
- past_cache: Optional[InferenceParams] = None,
620
- **kwargs,
621
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
622
- """Perform the forward pass.
623
-
624
- Args:
625
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
626
- cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
627
- is the is the sum of the sequence lengths in the batch.
628
- x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
629
- key_padding_mask: boolean mask, True means to keep, False means to mask out.
630
- (batch, seqlen). Only applicable when not using FlashAttention.
631
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
632
- of the sequences in the batch, used to index into x. Only applicable when using
633
- FlashAttention.
634
- max_seqlen: int. Maximum sequence length in the batch.
635
- mixer_subset: for cross-attention only. If not None, will take a subset of x
636
- before applying the query projection. Useful for e.g., ViT where we only care
637
- about the CLS token in the last layer.
638
- past_cache: For generation only.
639
-
640
- Returns:
641
- (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
642
- else (total, hidden_dim) where total is the is the sum of the sequence lengths
643
- in the batch.
644
-
645
- """
646
-
647
- if cu_seqlens is not None:
648
- assert max_seqlen is not None
649
- assert key_padding_mask is None
650
- assert self.flash_attn
651
- # assert self.rotary_emb_dim == 0
652
-
653
- if key_padding_mask is not None:
654
- assert cu_seqlens is None
655
- assert max_seqlen is None
656
- assert not self.flash_attn
657
-
658
- if past_cache is not None:
659
- assert key_padding_mask is None
660
- assert cu_seqlens is None and max_seqlen is None
661
-
662
- attn_kwargs = {"key_padding_mask": key_padding_mask}
663
-
664
- assert x_kv is None and mixer_subset is None
665
-
666
- qkv = self.Wqkv(x)
667
- qkv = rearrange(
668
- qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
669
- )
670
-
671
- if past_cache is None:
672
- if self.rotary_emb_dim > 0:
673
- qkv = self.rotary_emb(qkv)
674
- context = self.inner_attn(
675
- qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
676
- )
677
-
678
- else:
679
- if self.rotary_emb_dim > 0:
680
- qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
681
- q = qkv[:, :, 0]
682
- kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
683
- # If we're processing the prompt, causal=None (use self.causal).
684
- # If we're decoding, then causal=False.
685
- causal = None if past_cache.sequence_len_offset == 0 else False
686
- context = self.inner_cross_attn(q, kv, causal=causal)
687
-
688
- out = rearrange(context, "... h d -> ... (h d)")
689
- out = self.out_proj(out)
690
-
691
- return out if not self.return_residual else (out, x)
692
-
693
-
694
- class ParallelBlock(nn.Module):
695
- """Parallel block.
696
-
697
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
698
-
699
- """
700
-
701
- def __init__(
702
- self,
703
- config: PretrainedConfig,
704
- mixer: Optional[Dict[str, Any]] = None,
705
- mlp: Optional[Dict[str, Any]] = None,
706
- block_idx: Optional[int] = None,
707
- ) -> None:
708
- super().__init__()
709
-
710
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
711
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
712
- self.block_idx = block_idx
713
-
714
- self.mixer = MHA(config, layer_idx=block_idx)
715
- self.mlp = MLP(config)
716
-
717
- def forward(
718
- self,
719
- hidden_states: torch.FloatTensor,
720
- past_cache: Optional[torch.FloatTensor] = None,
721
- cu_seqlens: Optional[torch.LongTensor] = None,
722
- max_seqlen: Optional[int] = None,
723
- ) -> torch.FloatTensor:
724
- residual = hidden_states
725
- hidden_states = self.ln(hidden_states)
726
-
727
- attn_outputs = self.mixer(
728
- hidden_states,
729
- past_cache=past_cache,
730
- cu_seqlens=cu_seqlens,
731
- max_seqlen=max_seqlen,
732
- )
733
- if isinstance(attn_outputs, tuple):
734
- attn_outputs = attn_outputs[0]
735
-
736
- attn_outputs = self.resid_dropout(attn_outputs)
737
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
738
-
739
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
740
-
741
- return hidden_states
742
-
743
-
744
- class CausalLMHead(nn.Module):
745
- """Causal Language Modeling head.
746
-
747
- Reference:
748
- Improving Language Understanding by Generative Pre-Training.
749
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
750
-
751
- """
752
-
753
- def __init__(self, config: PretrainedConfig) -> None:
754
- super().__init__()
755
-
756
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
757
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
758
-
759
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
760
- hidden_states = self.ln(hidden_states)
761
- logits = self.linear(hidden_states).to(torch.float32)
762
-
763
- return logits
764
-
765
-
766
- class CausalLMLoss(nn.Module):
767
- """Causal Language Modeling loss.
768
-
769
- Reference:
770
- Improving Language Understanding by Generative Pre-Training.
771
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
772
-
773
- """
774
-
775
- def __init__(self, shift_labels: Optional[bool] = True) -> None:
776
- super().__init__()
777
-
778
- self.shift_labels = shift_labels
779
- self.loss_fct = nn.CrossEntropyLoss()
780
-
781
- def forward(
782
- self, logits: torch.FloatTensor, labels: torch.LongTensor
783
- ) -> torch.FloatTensor:
784
- if self.shift_labels:
785
- logits = logits[..., :-1, :].contiguous()
786
- labels = labels[..., 1:].contiguous()
787
-
788
- loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
789
-
790
- return loss
791
-
792
-
793
- class MixFormerSequentialPreTrainedModel(PreTrainedModel):
794
- """MixFormer (sequential for DeepSpeed) pre-trained model."""
795
-
796
- config_class = MixFormerSequentialConfig
797
- base_model_prefix = "transformer"
798
- supports_gradient_checkpointing = True
799
-
800
- def __init__(self, *inputs, **kwargs) -> None:
801
- super().__init__(*inputs, **kwargs)
802
-
803
- def prepare_inputs_for_generation(
804
- self, input_ids, past_key_values=None, **kwargs
805
- ) -> Dict[str, Any]:
806
- if "use_cache" in kwargs and not kwargs["use_cache"]:
807
- return {"input_ids": input_ids}
808
-
809
- if past_key_values is None or not (
810
- isinstance(past_key_values, InferenceParams)
811
- ):
812
- past_key_values = InferenceParams(
813
- max_batch_size=input_ids.shape[0],
814
- max_sequence_len=self.config.n_positions,
815
- sequence_len_offset=0,
816
- batch_size_offset=0,
817
- fused_ft_kernel=False,
818
- key_value_memory_dict={},
819
- )
820
- else:
821
- # assume past_key_values has cached all but last token in input_ids
822
- past_key_values.sequence_len_offset = len(input_ids[0]) - 1
823
- input_ids = input_ids[:, -1].unsqueeze(-1)
824
-
825
- return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
826
-
827
-
828
- class PackedSequential(nn.Sequential):
829
- def forward(
830
- self,
831
- input,
832
- cu_seqlens: Optional[torch.LongTensor] = None,
833
- max_seqlen: Optional[int] = None,
834
- ):
835
- for module in self:
836
- sig = inspect.signature(module.forward)
837
- if "cu_seqlens" in sig.parameters:
838
- input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
839
- else:
840
- input = module(input)
841
- return input
842
-
843
-
844
- class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
845
- """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
846
-
847
- _keys_to_ignore_on_load_missing = [""]
848
- _keys_to_ignore_on_load_unexpected = [
849
- r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
850
- ]
851
- _no_split_modules = ["ParallelBlock"]
852
-
853
- def __init__(self, config: MixFormerSequentialConfig) -> None:
854
- super().__init__(config)
855
-
856
- modules = [Embedding(config)]
857
- block_config = config.architecture
858
-
859
- if not isinstance(block_config, list):
860
- block_config = [block_config for _ in range(config.n_layer)]
861
-
862
- if config.n_layer != len(block_config):
863
- config.n_layer = len(block_config)
864
-
865
- for block_idx, block in enumerate(block_config):
866
- # `block_cls` with `legacy` value is for backward compatibility
867
- # `path` key is for backward compatibility
868
- block = copy.deepcopy(block) or {"block_cls": "parallel"}
869
- block.pop("path", None) or block.pop("block_cls", None)
870
-
871
- block["block_idx"] = block_idx
872
- modules.append(ParallelBlock(config, **block))
873
-
874
- modules.append(CausalLMHead(config))
875
-
876
- self.layers = PackedSequential(*modules)
877
- self.loss = CausalLMLoss()
878
-
879
- self.post_init()
880
-
881
- def get_input_embeddings(self) -> nn.Embedding:
882
- return self.layers[0].wte
883
-
884
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
885
- self.layers[0].wte = new_embeddings
886
-
887
- def get_output_embeddings(self) -> nn.Linear:
888
- return self.layers[-1].linear
889
-
890
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
891
- self.layers[-1].linear = new_embeddings
892
-
893
- def forward(
894
- self,
895
- input_ids: torch.LongTensor,
896
- labels: Optional[torch.LongTensor] = None,
897
- past_key_values: Optional[torch.FloatTensor] = None,
898
- position_ids: Optional[torch.LongTensor] = None,
899
- **kwargs,
900
- ) -> CausalLMOutputWithPast:
901
- cu_seqlens: Optional[torch.LongTensor] = None
902
- max_seqlen: Optional[int] = None
903
- if position_ids is not None:
904
- batch_size, seq_length = input_ids.shape
905
- position_ids = position_ids.view(-1, seq_length).long()
906
- cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
907
- cu_seqlens = cu_seqlens.squeeze()
908
-
909
- if not past_key_values:
910
- lm_logits = self.layers(
911
- input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
912
- )
913
- else:
914
- hidden_layer = self.layers[0](input_ids)
915
- for module in self.layers[1:-1]:
916
- hidden_layer = module(
917
- hidden_layer,
918
- past_cache=past_key_values,
919
- cu_seqlens=cu_seqlens,
920
- max_seqlen=max_seqlen,
921
- )
922
- lm_logits = self.layers[-1](hidden_layer)
923
-
924
- loss = None
925
- if labels is not None:
926
- loss = self.loss(lm_logits, labels)
927
-
928
- return CausalLMOutputWithPast(
929
- loss=loss, logits=lm_logits, past_key_values=past_key_values
930
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/axolotl/models/phi/modeling_phi.py DELETED
@@ -1,1092 +0,0 @@
1
- # pylint: skip-file
2
- # Copyright (c) Microsoft Corporation.
3
- # Licensed under the MIT license.
4
- #
5
- # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
6
- # Licensed under the BSD 3-Clause License.
7
-
8
- from __future__ import annotations
9
-
10
- import math
11
- from dataclasses import dataclass, field
12
- from typing import Any, Callable, Dict, Optional, Tuple, Union
13
-
14
- import torch
15
- import torch.nn as nn
16
- from einops import rearrange, repeat
17
- from torch.utils.checkpoint import checkpoint
18
- from transformers import PretrainedConfig, PreTrainedModel
19
- from transformers.activations import ACT2FN
20
- from transformers.modeling_outputs import CausalLMOutputWithPast
21
-
22
- from .configuration_phi import PhiConfig
23
-
24
- try:
25
- from flash_attn.bert_padding import pad_input, unpad_input
26
- from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
27
- from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
28
- except ImportError:
29
- pad_input, unpad_input = None, None
30
- FlashRotaryEmbedding = None
31
- FlashSelfAttention, FlashCrossAttention = None, None
32
-
33
- # this is in a seperate try/except block since sometimes fused_dense isn't available
34
- # and it shouldn't completely disable flash attn when it isn't
35
- try:
36
- from flash_attn.ops.fused_dense import FusedDense
37
- except ImportError:
38
- FusedDense = None
39
-
40
-
41
- @dataclass
42
- class InferenceParams:
43
- """Inference parameters passed to model to efficiently calculate
44
- and store context during inference.
45
-
46
- Reference:
47
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
48
-
49
- Args:
50
- max_seqlen: Maximum sequence length.
51
- max_batch_size: Maximum batch size.
52
- seqlen_offset: Sequence length offset.
53
- batch_size_offset: Batch size offset.
54
- key_value_memory_dict: Key value memory dictionary.
55
- lengths_per_sample: Lengths per sample.
56
-
57
- """
58
-
59
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
60
-
61
- max_batch_size: int = field(metadata={"help": "Maximum batch size."})
62
-
63
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
64
-
65
- batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
66
-
67
- key_value_memory_dict: Dict[str, Any] = field(
68
- default_factory=dict, metadata={"help": "Key value memory dictionary."}
69
- )
70
-
71
- lengths_per_sample: torch.Tensor = field(
72
- default=None, metadata={"help": "Lengths per sample."}
73
- )
74
-
75
-
76
- class Embedding(nn.Module):
77
- """Token embedding with dropout."""
78
-
79
- def __init__(self, config: PretrainedConfig) -> None:
80
- super().__init__()
81
-
82
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
83
- self.drop = nn.Dropout(config.embd_pdrop)
84
-
85
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
86
- input_shape = input_ids.size()
87
- input_ids = input_ids.view(-1, input_shape[-1])
88
-
89
- hidden_states = self.wte(input_ids)
90
- hidden_states = self.drop(hidden_states)
91
-
92
- return hidden_states
93
-
94
-
95
- def _apply_rotary_emb(
96
- x: torch.FloatTensor,
97
- cos: torch.FloatTensor,
98
- sin: torch.FloatTensor,
99
- ) -> torch.FloatTensor:
100
- _, seqlen, _, _ = x.shape
101
- _, rotary_dim = cos.shape
102
- rotary_dim *= 2
103
-
104
- x_rot = x[:, :, :, :rotary_dim]
105
- x_pass = x[:, :, :, rotary_dim:]
106
-
107
- x1, x2 = x_rot.chunk(2, dim=-1)
108
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
109
- sin[:seqlen], "s d -> s 1 d"
110
- )
111
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
112
-
113
- x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
114
-
115
- return torch.cat([x_rot, x_pass], axis=-1)
116
-
117
-
118
- def _apply_rotary_emb_kv(
119
- kv: torch.FloatTensor,
120
- cos: torch.FloatTensor,
121
- sin: torch.FloatTensor,
122
- cos_k: Optional[torch.FloatTensor] = None,
123
- sin_k: Optional[torch.FloatTensor] = None,
124
- ) -> torch.FloatTensor:
125
- _, seqlen, _, _, _ = kv.shape
126
- _, rotary_dim = cos.shape
127
- rotary_dim *= 2
128
-
129
- k_rot = kv[:, :, 0, :, :rotary_dim]
130
- k_pass = kv[:, :, 0, :, rotary_dim:]
131
-
132
- k1, k2 = k_rot.chunk(2, dim=-1)
133
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
134
- sin[:seqlen], "s d -> s 1 d"
135
- )
136
- k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
137
-
138
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
139
-
140
- return torch.cat(
141
- [
142
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
143
- kv[:, :, 1:2, :, :],
144
- ],
145
- axis=2,
146
- )
147
-
148
-
149
- def _apply_rotary_emb_qkv(
150
- qkv: torch.FloatTensor,
151
- cos: torch.FloatTensor,
152
- sin: torch.FloatTensor,
153
- cos_k: Optional[torch.FloatTensor] = None,
154
- sin_k: Optional[torch.FloatTensor] = None,
155
- ) -> torch.FloatTensor:
156
- _, seqlen, _, _, _ = qkv.shape
157
- _, rotary_dim = cos.shape
158
- rotary_dim *= 2
159
-
160
- q_rot = qkv[:, :, 0, :, :rotary_dim]
161
- q_pass = qkv[:, :, 0, :, rotary_dim:]
162
-
163
- k_rot = qkv[:, :, 1, :, :rotary_dim]
164
- k_pass = qkv[:, :, 1, :, rotary_dim:]
165
-
166
- q1, q2 = q_rot.chunk(2, dim=-1)
167
- k1, k2 = k_rot.chunk(2, dim=-1)
168
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
169
- sin[:seqlen], "s d -> s 1 d"
170
- )
171
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
172
-
173
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
174
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
175
-
176
- return torch.cat(
177
- [
178
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
179
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
180
- qkv[:, :, 2:3, :, :],
181
- ],
182
- axis=2,
183
- )
184
-
185
-
186
- class RotaryEmbedding(nn.Module):
187
- """Rotary positional embedding (RoPE).
188
-
189
- Reference:
190
- RoFormer: Enhanced Transformer with Rotary Position Embedding.
191
- https://arxiv.org/pdf/2104.09864.pdf.
192
-
193
- """
194
-
195
- def __init__(
196
- self,
197
- dim: int,
198
- base: int = 10000,
199
- scale_base: Optional[float] = None,
200
- pos_idx_in_fp32: bool = True,
201
- max_position_embeddings: int = 2048,
202
- device: Optional[str] = None,
203
- **kwargs,
204
- ) -> None:
205
- super().__init__()
206
-
207
- if scale_base is not None:
208
- raise NotImplementedError
209
-
210
- self.dim = dim
211
- self.base = float(base)
212
- self.scale_base = scale_base
213
- self.pos_idx_in_fp32 = pos_idx_in_fp32
214
- self.max_position_embeddings = max_position_embeddings
215
- self.device = device
216
-
217
- # Generate and save the inverse frequency buffer (non-trainable)
218
- inv_freq = self._compute_inv_freq(device)
219
- self.register_buffer("inv_freq", inv_freq, persistent=False)
220
-
221
- # Generate and save the scale buffer (non-trainable)
222
- scale = (
223
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
224
- / (1.4 * dim)
225
- if scale_base is not None
226
- else None
227
- )
228
- self.register_buffer("scale", scale, persistent=False)
229
-
230
- # Initialize cached attributes since ONNX can't rely on dynamic initialization
231
- self._update_cos_sin_cache(
232
- max_position_embeddings,
233
- device=device,
234
- dtype=torch.float32,
235
- )
236
-
237
- def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
238
- return 1.0 / (
239
- self.base
240
- ** (
241
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
242
- / self.dim
243
- )
244
- )
245
-
246
- def _update_cos_sin_cache(
247
- self,
248
- seqlen: int,
249
- device: Optional[str] = None,
250
- dtype: Optional[torch.dtype] = None,
251
- ) -> None:
252
- self._seq_len_cached = seqlen
253
-
254
- # fp32 is preferred since the output of `torch.arange` can be quite large
255
- # and bf16 would lose a lot of precision
256
- if self.pos_idx_in_fp32:
257
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
258
- if self.inv_freq.dtype != torch.float32:
259
- inv_freq = self._compute_inv_freq(device=device)
260
- else:
261
- inv_freq = self.inv_freq
262
- else:
263
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
264
- inv_freq = self.inv_freq
265
-
266
- # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
267
- freqs = torch.outer(t, inv_freq)
268
- if self.scale is None:
269
- self._cos_cached = torch.cos(freqs).to(dtype)
270
- self._sin_cached = torch.sin(freqs).to(dtype)
271
- else:
272
- power = (
273
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
274
- - seqlen // 2
275
- ) / self.scale_base
276
- scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
277
-
278
- # Force the scale multiplication to happen in fp32
279
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
280
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
281
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
282
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
283
-
284
- def forward(
285
- self,
286
- qkv: torch.Tensor,
287
- kv: Optional[torch.Tensor] = None,
288
- seqlen_offset: int = 0,
289
- **kwargs,
290
- ) -> Tuple[torch.Tensor, torch.Tensor]:
291
- if (
292
- self._seq_len_cached < qkv.shape[1] + seqlen_offset
293
- or self._cos_cached.device != qkv.device
294
- or self._cos_cached.dtype != qkv.dtype
295
- or (self.training and self._cos_cached.is_inference())
296
- ):
297
- self._update_cos_sin_cache(
298
- qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype
299
- )
300
-
301
- if kv is None:
302
- return _apply_rotary_emb_qkv(
303
- qkv,
304
- self._cos_cached[seqlen_offset:],
305
- self._sin_cached[seqlen_offset:],
306
- )
307
- else:
308
- q = _apply_rotary_emb(
309
- qkv,
310
- self._cos_cached[seqlen_offset:],
311
- self._sin_cached[seqlen_offset:],
312
- )
313
- kv = _apply_rotary_emb_kv(
314
- kv,
315
- self._cos_cached[seqlen_offset:],
316
- self._sin_cached[seqlen_offset:],
317
- )
318
-
319
- return q, kv
320
-
321
-
322
- class MLP(nn.Module):
323
- """Multi-Layer Perceptron.
324
-
325
- Reference:
326
- Attention Is All You Need.
327
- https://arxiv.org/pdf/1706.03762.pdf.
328
-
329
- """
330
-
331
- def __init__(
332
- self,
333
- config: PretrainedConfig,
334
- n_inner: Optional[int] = None,
335
- act_fn: Optional[str] = None,
336
- ) -> None:
337
- super().__init__()
338
-
339
- act_fn = config.activation_function if act_fn is None else act_fn
340
-
341
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
342
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
343
-
344
- self.fc1 = nn.Linear(config.n_embd, n_inner)
345
- self.fc2 = nn.Linear(n_inner, config.n_embd)
346
- self.act = ACT2FN[act_fn]
347
-
348
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
349
- hidden_states = self.fc1(hidden_states)
350
- hidden_states = self.act(hidden_states)
351
- hidden_states = self.fc2(hidden_states)
352
-
353
- return hidden_states
354
-
355
-
356
- class SelfAttention(nn.Module):
357
- """Self-attention layer (compatible with PyTorch).
358
-
359
- Reference:
360
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
361
-
362
- """
363
-
364
- def __init__(
365
- self,
366
- causal: bool = True,
367
- softmax_scale: Optional[float] = None,
368
- attention_dropout: float = 0.0,
369
- ) -> None:
370
- super().__init__()
371
-
372
- self.causal = causal
373
- self.softmax_scale = softmax_scale
374
- self.drop = nn.Dropout(attention_dropout)
375
-
376
- @torch.autocast("cpu", enabled=False)
377
- @torch.autocast("cuda", enabled=False)
378
- def forward(
379
- self,
380
- qkv: torch.FloatTensor,
381
- causal: bool = None,
382
- key_padding_mask: Optional[torch.BoolTensor] = None,
383
- **kwargs,
384
- ) -> torch.FloatTensor:
385
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
386
- q, k, v = qkv.unbind(dim=2)
387
-
388
- q = q.to(torch.float32)
389
- k = k.to(torch.float32)
390
-
391
- causal = self.causal if causal is None else causal
392
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
393
-
394
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
395
- # using float16, which might lead to overflow
396
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
397
-
398
- if key_padding_mask is not None:
399
- padding_mask = torch.full(
400
- (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
401
- )
402
- padding_mask.masked_fill_(key_padding_mask, 0.0)
403
-
404
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
405
-
406
- if causal:
407
- causal_mask = torch.triu(
408
- torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
409
- )
410
- scores = scores + causal_mask.to(dtype=scores.dtype)
411
-
412
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
413
- attention = self.drop(attention)
414
-
415
- output = torch.einsum("bhts,bshd->bthd", attention, v)
416
-
417
- return output
418
-
419
-
420
- class CrossAttention(nn.Module):
421
- """Cross-attention layer (compatible with PyTorch).
422
-
423
- Reference:
424
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
425
-
426
- """
427
-
428
- def __init__(
429
- self,
430
- causal: bool = True,
431
- softmax_scale: Optional[float] = None,
432
- attention_dropout: float = 0.0,
433
- ) -> None:
434
- super().__init__()
435
-
436
- self.causal = causal
437
- self.softmax_scale = softmax_scale
438
- self.drop = nn.Dropout(attention_dropout)
439
-
440
- @torch.autocast("cpu", enabled=False)
441
- @torch.autocast("cuda", enabled=False)
442
- def forward(
443
- self,
444
- q: torch.FloatTensor,
445
- kv: torch.FloatTensor,
446
- causal: bool = None,
447
- key_padding_mask: Optional[torch.BoolTensor] = None,
448
- **kwargs,
449
- ) -> torch.FloatTensor:
450
- batch_size, seqlen_q = q.shape[0], q.shape[1]
451
- seqlen_k = kv.shape[1]
452
-
453
- if kv.shape[3] != q.shape[2]:
454
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
455
- k, v = kv.unbind(dim=2)
456
-
457
- q = q.to(torch.float32)
458
- k = k.to(torch.float32)
459
-
460
- causal = self.causal if causal is None else causal
461
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
462
-
463
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
464
- # using float16, which might lead to overflow
465
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
466
-
467
- if key_padding_mask is not None:
468
- padding_mask = torch.full(
469
- (batch_size, seqlen_k),
470
- -10000.0,
471
- dtype=scores.dtype,
472
- device=scores.device,
473
- )
474
- padding_mask.masked_fill_(key_padding_mask, 0.0)
475
-
476
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
477
-
478
- if causal:
479
- rows = rearrange(
480
- torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
481
- )
482
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
483
- causal_mask = cols > rows + seqlen_k - seqlen_q
484
-
485
- scores = scores.masked_fill(causal_mask, -10000.0)
486
-
487
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
488
- attention = self.drop(attention)
489
-
490
- output = torch.einsum("bhts,bshd->bthd", attention, v)
491
-
492
- return output
493
-
494
-
495
- def _find_mha_dims(
496
- config: PretrainedConfig,
497
- n_head: Optional[int] = None,
498
- n_head_kv: Optional[int] = None,
499
- head_dim: Optional[int] = None,
500
- ) -> Tuple[int, int]:
501
- if n_head is None and head_dim is None:
502
- head_dim = config.n_embd // config.n_head
503
- n_head = config.n_head
504
- elif n_head is None or head_dim is None:
505
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
506
-
507
- if n_head_kv is None:
508
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
509
-
510
- return n_head, n_head_kv, head_dim
511
-
512
-
513
- def _update_kv_cache(
514
- kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int
515
- ) -> torch.FloatTensor:
516
- num_heads, head_dim = kv.shape[-2:]
517
-
518
- if layer_idx not in inference_params.key_value_memory_dict:
519
- inference_params.key_value_memory_dict[layer_idx] = torch.empty(
520
- inference_params.max_batch_size,
521
- inference_params.max_seqlen,
522
- 2,
523
- num_heads,
524
- head_dim,
525
- dtype=kv.dtype,
526
- device=kv.device,
527
- )
528
-
529
- batch_start = inference_params.batch_size_offset
530
- batch_end = batch_start + kv.shape[0]
531
-
532
- sequence_start = inference_params.seqlen_offset
533
- sequence_end = sequence_start + kv.shape[1]
534
-
535
- # When the current sequence length is equal to or larger than the maximum sequence length,
536
- # we need to concatenate the current `kv` with the cached `kv` to expand its length
537
- if sequence_end >= inference_params.max_seqlen:
538
- inference_params.key_value_memory_dict[layer_idx] = torch.concatenate(
539
- (inference_params.key_value_memory_dict[layer_idx], kv), dim=1
540
- )
541
-
542
- inference_params.key_value_memory_dict[layer_idx][
543
- batch_start:batch_end, sequence_start:sequence_end, ...
544
- ] = kv
545
- kv = inference_params.key_value_memory_dict[layer_idx][
546
- batch_start:batch_end, :sequence_end, ...
547
- ]
548
-
549
- return kv
550
-
551
-
552
- class MHA(nn.Module):
553
- """Multi-head attention layer."""
554
-
555
- def __init__(
556
- self,
557
- config: PretrainedConfig,
558
- dtype: Optional[torch.dtype] = None,
559
- device: Optional[str] = None,
560
- rotary_dim: Optional[int] = None,
561
- rotary_base: float = 10000.0,
562
- rotary_scale_base: Optional[float] = None,
563
- n_head: Optional[int] = None,
564
- n_head_kv: Optional[int] = None,
565
- head_dim: Optional[int] = None,
566
- bias: bool = True,
567
- causal: bool = True,
568
- softmax_scale: Optional[float] = None,
569
- layer_idx: Optional[int] = None,
570
- return_residual: bool = False,
571
- checkpointing: bool = False,
572
- ) -> None:
573
- super().__init__()
574
-
575
- # Rotary embedding
576
- self.rotary_dim = (
577
- rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
578
- )
579
- if self.rotary_dim > 0:
580
- rotary_cls = (
581
- FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
582
- )
583
- if rotary_cls is None:
584
- rotary_cls = RotaryEmbedding
585
-
586
- rotary_kwargs = {}
587
- if rotary_cls is RotaryEmbedding:
588
- rotary_kwargs["max_position_embeddings"] = config.n_positions
589
-
590
- self.rotary_emb = rotary_cls(
591
- self.rotary_dim,
592
- base=rotary_base,
593
- scale_base=rotary_scale_base,
594
- device=device,
595
- **rotary_kwargs,
596
- )
597
-
598
- # MLP
599
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
600
- config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
601
- )
602
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
603
- hidden_size = config.n_embd
604
-
605
- linear_cls = FusedDense if config.fused_dense else nn.Linear
606
- if linear_cls is None:
607
- linear_cls = nn.Linear
608
-
609
- self.Wqkv = linear_cls(
610
- hidden_size, op_size, bias=bias, device=device, dtype=dtype
611
- )
612
- self.out_proj = linear_cls(
613
- hidden_size, hidden_size, bias=bias, device=device, dtype=dtype
614
- )
615
-
616
- # Attention
617
- attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
618
- if attn_cls is None:
619
- attn_cls = SelfAttention
620
-
621
- cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
622
- if cross_attn_cls is None:
623
- cross_attn_cls = CrossAttention
624
-
625
- self.inner_attn = attn_cls(
626
- causal=causal,
627
- softmax_scale=softmax_scale,
628
- attention_dropout=config.attn_pdrop,
629
- )
630
- self.inner_cross_attn = cross_attn_cls(
631
- causal=causal,
632
- softmax_scale=softmax_scale,
633
- attention_dropout=config.attn_pdrop,
634
- )
635
-
636
- self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
637
- self.layer_idx = layer_idx
638
- self.return_residual = return_residual
639
- self.checkpointing = checkpointing
640
- self._gradient_checkpointing_func = None
641
-
642
- def _forward_self_attn(
643
- self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
644
- ) -> torch.FloatTensor:
645
- qkv = self.Wqkv(x)
646
- qkv = rearrange(
647
- qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
648
- )
649
-
650
- if self.rotary_dim > 0:
651
- qkv = self.rotary_emb(qkv)
652
-
653
- if self.flash_attn:
654
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
655
-
656
- cu_seqlens, max_seqlen = None, None
657
- if key_padding_mask is not None:
658
- # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
659
- # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
660
- qkv, indices, cu_seqlens, max_seqlen = unpad_input(
661
- qkv, key_padding_mask
662
- )
663
-
664
- if self.checkpointing and self.training:
665
- attn_output = self._gradient_checkpointing_func(
666
- self.inner_attn,
667
- qkv,
668
- cu_seqlens=cu_seqlens,
669
- max_seqlen=max_seqlen,
670
- use_reentrant=False,
671
- )
672
- else:
673
- attn_output = self.inner_attn(
674
- qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
675
- ).to(qkv.device)
676
-
677
- # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
678
- return (
679
- pad_input(attn_output, indices, batch_size, seqlen)
680
- if key_padding_mask is not None
681
- else attn_output
682
- )
683
-
684
- if self.checkpointing and self.training:
685
- return self._gradient_checkpointing_func(
686
- self.inner_attn,
687
- qkv,
688
- key_padding_mask=key_padding_mask,
689
- use_reentrant=False,
690
- )
691
-
692
- return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
693
-
694
- def _forward_cross_attn(
695
- self,
696
- x: torch.FloatTensor,
697
- past_key_values: Optional[InferenceParams],
698
- key_padding_mask: Optional[torch.BoolTensor],
699
- ) -> torch.FloatTensor:
700
- batch_size = x.shape[0]
701
-
702
- qkv = self.Wqkv(x)
703
-
704
- q = qkv[..., : self.n_head * self.head_dim]
705
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
706
-
707
- kv = qkv[..., self.n_head * self.head_dim :]
708
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
709
-
710
- seqlen_offset = (
711
- past_key_values.seqlen_offset if past_key_values is not None else 0
712
- )
713
- causal = None if seqlen_offset == 0 else False
714
- if self.rotary_dim > 0:
715
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
716
-
717
- if past_key_values is not None:
718
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
719
-
720
- if self.flash_attn:
721
- batch_size, seqlen_q = q.shape[0], q.shape[1]
722
- seqlen_k = kv.shape[1]
723
-
724
- cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
725
- None,
726
- None,
727
- None,
728
- None,
729
- )
730
- if key_padding_mask is not None:
731
- kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
732
-
733
- if seqlen_q == 1:
734
- key_padding_mask = torch.ones(batch_size, 1, device=q.device)
735
- elif seqlen_q != seqlen_k:
736
- key_padding_mask = key_padding_mask[:, -seqlen_q:]
737
-
738
- q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
739
- q, key_padding_mask
740
- )
741
-
742
- if self.checkpointing and self.training:
743
- attn_output = self._gradient_checkpointing_func(
744
- self.inner_cross_attn,
745
- q,
746
- kv,
747
- causal=causal,
748
- cu_seqlens=cu_seqlens_q,
749
- max_seqlen=max_seqlen_q,
750
- cu_seqlens_k=cu_seqlens_k,
751
- max_seqlen_k=max_seqlen_k,
752
- use_reentrant=False,
753
- )
754
- else:
755
- attn_output = self.inner_cross_attn(
756
- q,
757
- kv,
758
- causal=causal,
759
- cu_seqlens=cu_seqlens_q,
760
- max_seqlen=max_seqlen_q,
761
- cu_seqlens_k=cu_seqlens_k,
762
- max_seqlen_k=max_seqlen_k,
763
- )
764
-
765
- return (
766
- pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
767
- if key_padding_mask is not None
768
- else attn_output
769
- )
770
-
771
- if self.checkpointing and self.training:
772
- return self._gradient_checkpointing_func(
773
- self.inner_cross_attn,
774
- q,
775
- kv,
776
- key_padding_mask=key_padding_mask,
777
- causal=causal,
778
- use_reentrant=False,
779
- )
780
-
781
- return self.inner_cross_attn(
782
- q, kv, key_padding_mask=key_padding_mask, causal=causal
783
- )
784
-
785
- def forward(
786
- self,
787
- x: torch.FloatTensor,
788
- past_key_values: Optional[InferenceParams] = None,
789
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
790
- **kwargs,
791
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
792
- if attention_mask is not None:
793
- attention_mask = attention_mask.bool()
794
- else:
795
- attention_mask = None
796
-
797
- # MHA
798
- if self.n_head == self.n_head_kv:
799
- if past_key_values is None:
800
- # If `past_key_values` are not supplied, we run self-attention
801
- attn_output = self._forward_self_attn(x, attention_mask)
802
- else:
803
- # If `past_key_values` are supplied, it means that we might have cached values and
804
- # could take advantage of cross-attention
805
- attn_output = self._forward_cross_attn(
806
- x, past_key_values, attention_mask
807
- )
808
- # MQA / GQA
809
- else:
810
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
811
- # because `q` and `kv` lengths might be different
812
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
813
-
814
- output = rearrange(attn_output, "... h d -> ... (h d)")
815
- output = self.out_proj(output)
816
-
817
- return output if not self.return_residual else (output, x)
818
-
819
-
820
- class ParallelBlock(nn.Module):
821
- """Parallel block.
822
-
823
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
824
-
825
- """
826
-
827
- def __init__(
828
- self,
829
- config: PretrainedConfig,
830
- block_idx: Optional[int] = None,
831
- ) -> None:
832
- super().__init__()
833
-
834
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
835
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
836
- self.block_idx = block_idx
837
-
838
- self.mixer = MHA(config, layer_idx=block_idx)
839
- self.mlp = MLP(config)
840
- self.checkpointing = False
841
- self._gradient_checkpointing_func = None
842
-
843
- def forward(
844
- self,
845
- hidden_states: torch.FloatTensor,
846
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
847
- attention_mask: Optional[torch.BoolTensor] = None,
848
- **kwargs,
849
- ) -> torch.FloatTensor:
850
- def _forward(
851
- mixer,
852
- resid_dropout,
853
- mlp,
854
- ln,
855
- hidden_states,
856
- past_key_values,
857
- attention_mask,
858
- ):
859
- residual = hidden_states
860
- hidden_states = ln(hidden_states)
861
-
862
- attn_outputs = mixer(
863
- hidden_states,
864
- past_key_values=past_key_values,
865
- attention_mask=attention_mask,
866
- )
867
- if isinstance(attn_outputs, tuple):
868
- attn_outputs = attn_outputs[0]
869
-
870
- attn_outputs = resid_dropout(attn_outputs)
871
- feed_forward_hidden_states = resid_dropout(mlp(hidden_states))
872
-
873
- return attn_outputs + feed_forward_hidden_states + residual
874
-
875
- if self.training and self.checkpointing:
876
- return self._gradient_checkpointing_func(
877
- _forward,
878
- self.mixer,
879
- self.resid_dropout,
880
- self.mlp,
881
- self.ln,
882
- hidden_states,
883
- past_key_values,
884
- attention_mask,
885
- )
886
-
887
- return _forward(
888
- self.mixer,
889
- self.resid_dropout,
890
- self.mlp,
891
- self.ln,
892
- hidden_states,
893
- past_key_values,
894
- attention_mask,
895
- )
896
-
897
-
898
- class CausalLMHead(nn.Module):
899
- """Causal Language Modeling head.
900
-
901
- Reference:
902
- Improving Language Understanding by Generative Pre-Training.
903
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
904
-
905
- """
906
-
907
- def __init__(self, config: PretrainedConfig) -> None:
908
- super().__init__()
909
-
910
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
911
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
912
-
913
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
914
- hidden_states = self.ln(hidden_states)
915
- logits = self.linear(hidden_states).to(torch.float32)
916
-
917
- return logits
918
-
919
-
920
- class CausalLMLoss(nn.Module):
921
- """Causal Language Modeling loss.
922
-
923
- Reference:
924
- Improving Language Understanding by Generative Pre-Training.
925
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
926
-
927
- """
928
-
929
- def __init__(self, shift_labels: bool = True) -> None:
930
- super().__init__()
931
-
932
- self.shift_labels = shift_labels
933
- self.loss_fct = nn.CrossEntropyLoss()
934
-
935
- def forward(
936
- self, logits: torch.FloatTensor, labels: torch.LongTensor
937
- ) -> torch.FloatTensor:
938
- if self.shift_labels:
939
- logits = logits[..., :-1, :].contiguous()
940
- labels = labels[..., 1:].contiguous()
941
-
942
- loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
943
-
944
- return loss
945
-
946
-
947
- class PhiPreTrainedModel(PreTrainedModel):
948
- """Phi pre-trained model."""
949
-
950
- config_class = PhiConfig
951
- base_model_prefix = "transformer"
952
- supports_gradient_checkpointing = True
953
- _no_split_modules = ["ParallelBlock"]
954
-
955
- def __init__(self, *inputs, **kwargs) -> None:
956
- super().__init__(*inputs, **kwargs)
957
-
958
- def _init_weights(self, module: nn.Module) -> None:
959
- if isinstance(module, (nn.Linear,)):
960
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
961
- if module.bias is not None:
962
- module.bias.data.zero_()
963
- elif isinstance(module, nn.Embedding):
964
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
965
- if module.padding_idx is not None:
966
- module.weight.data[module.padding_idx].zero_()
967
- elif isinstance(module, nn.LayerNorm):
968
- if module.bias is not None:
969
- module.bias.data.zero_()
970
- module.weight.data.fill_(1.0)
971
-
972
- def _set_gradient_checkpointing(
973
- self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint
974
- ):
975
- for module in self.modules():
976
- if hasattr(module, "checkpointing"):
977
- module._gradient_checkpointing_func = gradient_checkpointing_func
978
- module.checkpointing = enable
979
-
980
- def prepare_inputs_for_generation(
981
- self,
982
- input_ids: torch.LongTensor,
983
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
984
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
985
- **kwargs,
986
- ) -> Dict[str, Any]:
987
- if past_key_values is None or not (
988
- isinstance(past_key_values, InferenceParams)
989
- ):
990
- past_key_values = InferenceParams(
991
- max_seqlen=self.config.n_positions,
992
- max_batch_size=input_ids.shape[0],
993
- seqlen_offset=0,
994
- batch_size_offset=0,
995
- key_value_memory_dict={},
996
- lengths_per_sample=None,
997
- )
998
- else:
999
- # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
1000
- past_key_values.seqlen_offset = input_ids.shape[1] - 1
1001
- input_ids = input_ids[:, -1].unsqueeze(-1)
1002
-
1003
- return {
1004
- "input_ids": input_ids,
1005
- "past_key_values": past_key_values,
1006
- "attention_mask": attention_mask,
1007
- }
1008
-
1009
-
1010
- class PhiModel(PhiPreTrainedModel):
1011
- """Phi model."""
1012
-
1013
- _keys_to_ignore_on_load_missing = [""]
1014
- _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
1015
-
1016
- def __init__(self, config: PhiConfig) -> None:
1017
- super().__init__(config)
1018
-
1019
- self.embd = Embedding(config)
1020
- self.h = nn.ModuleList(
1021
- [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
1022
- )
1023
- self.gradient_checkpointing = False
1024
- self.post_init()
1025
-
1026
- def get_input_embeddings(self) -> nn.Embedding:
1027
- return self.embd.wte
1028
-
1029
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
1030
- self.embd.wte = new_embeddings
1031
-
1032
- def forward(
1033
- self,
1034
- input_ids: torch.LongTensor,
1035
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1036
- attention_mask: Optional[torch.BoolTensor] = None,
1037
- ) -> torch.FloatTensor:
1038
- hidden_states = self.embd(input_ids)
1039
-
1040
- for layer in self.h:
1041
- hidden_states = layer(
1042
- hidden_states,
1043
- past_key_values=past_key_values,
1044
- attention_mask=attention_mask,
1045
- )
1046
-
1047
- return hidden_states
1048
-
1049
-
1050
- class PhiForCausalLM(PhiPreTrainedModel):
1051
- """Phi for Causal Language Modeling."""
1052
-
1053
- _keys_to_ignore_on_load_missing = [""]
1054
- _keys_to_ignore_on_load_unexpected = [
1055
- r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
1056
- ]
1057
-
1058
- def __init__(self, config: PhiConfig) -> None:
1059
- super().__init__(config)
1060
-
1061
- self.transformer = PhiModel(config)
1062
- self.lm_head = CausalLMHead(config)
1063
- self.loss = CausalLMLoss()
1064
-
1065
- self.post_init()
1066
-
1067
- def get_output_embeddings(self) -> nn.Linear:
1068
- return self.lm_head.linear
1069
-
1070
- def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1071
- self.lm_head.linear = new_embeddings
1072
-
1073
- def forward(
1074
- self,
1075
- input_ids: torch.LongTensor,
1076
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1077
- attention_mask: Optional[torch.BoolTensor] = None,
1078
- labels: Optional[torch.LongTensor] = None,
1079
- **kwargs,
1080
- ) -> CausalLMOutputWithPast:
1081
- hidden_states = self.transformer(
1082
- input_ids, past_key_values=past_key_values, attention_mask=attention_mask
1083
- )
1084
- lm_logits = self.lm_head(hidden_states)
1085
-
1086
- loss = None
1087
- if labels is not None:
1088
- loss = self.loss(lm_logits, labels)
1089
-
1090
- return CausalLMOutputWithPast(
1091
- loss=loss, logits=lm_logits, past_key_values=past_key_values
1092
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/axolotl/monkeypatch/phi/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Patches to support multipack for phi2
3
+ """
4
+ import transformers
5
+
6
+ from axolotl.monkeypatch.utils import get_unpad_data
7
+
8
+
9
+ def replace_phi_attn_with_multipack_flash_attn():
10
+ transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
11
+ get_unpad_data
12
+ )
src/axolotl/utils/config.py CHANGED
@@ -364,20 +364,6 @@ def validate_config(cfg):
364
  "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
365
  )
366
 
367
- if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None:
368
- LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA")
369
-
370
- if cfg.model_config_type == "mixformer-sequential":
371
- if cfg.sample_packing:
372
- if cfg.adapter is not None:
373
- LOG.warning(
374
- "phi/MixFormer models are not currently compatible with LoRA and sample_packing"
375
- )
376
- if cfg.model_type == "AutoModelForCausalLM":
377
- raise ValueError(
378
- "`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
379
- )
380
-
381
  if cfg.datasets:
382
  for idx, ds_cfg in enumerate(cfg.datasets):
383
  if not ds_cfg.type:
 
364
  "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
365
  )
366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  if cfg.datasets:
368
  for idx, ds_cfg in enumerate(cfg.datasets):
369
  if not ds_cfg.type:
src/axolotl/utils/data.py CHANGED
@@ -397,7 +397,7 @@ def load_tokenized_prepared_datasets(
397
  LOG.info("shuffle merged datasets")
398
  dataset = dataset.shuffle(seed=seed)
399
 
400
- dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)
401
 
402
  if cfg.local_rank == 0:
403
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
 
397
  LOG.info("shuffle merged datasets")
398
  dataset = dataset.shuffle(seed=seed)
399
 
400
+ dataset, _ = process_datasets_for_packing(cfg, dataset, None)
401
 
402
  if cfg.local_rank == 0:
403
  LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
src/axolotl/utils/lora_embeddings.py CHANGED
@@ -7,8 +7,6 @@ def get_linear_embedding_layers(model_type):
7
  """
8
  returns the linear embedding layers needed for loras, dependent on the model arch
9
  """
10
- if model_type == "phi-msft":
11
- return ["embd.wte", "lm_head.linear"]
12
  if model_type == "gpt_neox":
13
  return ["embed_in", "embed_out"]
14
  if model_type == "falcon":
 
7
  """
8
  returns the linear embedding layers needed for loras, dependent on the model arch
9
  """
 
 
10
  if model_type == "gpt_neox":
11
  return ["embed_in", "embed_out"]
12
  if model_type == "falcon":
src/axolotl/utils/models.py CHANGED
@@ -169,6 +169,7 @@ def load_tokenizer(cfg):
169
  # pylint: disable=too-many-boolean-expressions
170
  if (
171
  (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
 
172
  and cfg.adapter
173
  and (
174
  not cfg.lora_modules_to_save
@@ -342,6 +343,12 @@ def load_model(
342
  LOG.info("patching falcon with flash attention")
343
  replace_falcon_attn_with_multipack_flash_attn()
344
 
 
 
 
 
 
 
345
  if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
346
  from axolotl.monkeypatch.qwen2 import (
347
  replace_qwen2_attn_with_multipack_flash_attn,
@@ -448,7 +455,7 @@ def load_model(
448
  "flash_attention_2"
449
  )
450
  else:
451
- if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
452
  model_kwargs["attn_implementation"] = "flash_attention_2"
453
  model_config._attn_implementation = ( # pylint: disable=protected-access
454
  "flash_attention_2"
@@ -458,10 +465,6 @@ def load_model(
458
  model_config._attn_implementation = ( # pylint: disable=protected-access
459
  "eager"
460
  )
461
- if model_config.model_type == "phi-msft":
462
- model_config.flash_attn = True
463
- model_config.flash_rotary = True
464
- model_config.fused_dense = True
465
 
466
  try:
467
  if (
@@ -518,16 +521,6 @@ def load_model(
518
  # device=cfg.device,
519
  # )
520
  # model.train() # sets to train instead of eval mode
521
- elif model_type == "PhiForCausalLM" or model_config.model_type == "phi-msft":
522
- from axolotl.models.phi import PhiForCausalLM
523
-
524
- model = PhiForCausalLM.from_pretrained(
525
- base_model,
526
- config=model_config,
527
- load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
528
- load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
529
- **model_kwargs,
530
- )
531
  elif model_type == "MambaLMHeadModel":
532
  # FIXME this is janky at best and hacked together to make it work
533
  MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
 
169
  # pylint: disable=too-many-boolean-expressions
170
  if (
171
  (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
172
+ and (len(tokenizer.encode(val)) > 1)
173
  and cfg.adapter
174
  and (
175
  not cfg.lora_modules_to_save
 
343
  LOG.info("patching falcon with flash attention")
344
  replace_falcon_attn_with_multipack_flash_attn()
345
 
346
+ if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
347
+ from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
348
+
349
+ LOG.info("patching phi with flash attention")
350
+ replace_phi_attn_with_multipack_flash_attn()
351
+
352
  if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
353
  from axolotl.monkeypatch.qwen2 import (
354
  replace_qwen2_attn_with_multipack_flash_attn,
 
455
  "flash_attention_2"
456
  )
457
  else:
458
+ if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
459
  model_kwargs["attn_implementation"] = "flash_attention_2"
460
  model_config._attn_implementation = ( # pylint: disable=protected-access
461
  "flash_attention_2"
 
465
  model_config._attn_implementation = ( # pylint: disable=protected-access
466
  "eager"
467
  )
 
 
 
 
468
 
469
  try:
470
  if (
 
521
  # device=cfg.device,
522
  # )
523
  # model.train() # sets to train instead of eval mode
 
 
 
 
 
 
 
 
 
 
524
  elif model_type == "MambaLMHeadModel":
525
  # FIXME this is janky at best and hacked together to make it work
526
  MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
src/axolotl/utils/trainer.py CHANGED
@@ -106,19 +106,16 @@ def drop_long_seq(sample, sequence_len=2048):
106
  return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
107
 
108
 
109
- def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
110
  drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
111
  with zero_first(is_main_process()):
112
  if cfg.is_preprocess:
113
  max_input_len = np.max(get_dataset_lengths(train_dataset))
114
  LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
115
 
116
- # Phi doesn't want the attention_mask feature when training
117
  if (
118
- "CodeGenTokenizer" in tokenizer.__class__.__name__
119
- or (cfg.is_mistral_derived_model and cfg.flash_attention)
120
- or cfg.model_config_type == "mamba"
121
- ):
122
  LOG.info("dropping attention_mask column")
123
  train_dataset = train_dataset.remove_columns("attention_mask")
124
  if eval_dataset:
 
106
  return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
107
 
108
 
109
+ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
110
  drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
111
  with zero_first(is_main_process()):
112
  if cfg.is_preprocess:
113
  max_input_len = np.max(get_dataset_lengths(train_dataset))
114
  LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
115
 
 
116
  if (
117
+ cfg.is_mistral_derived_model and cfg.flash_attention
118
+ ) or cfg.model_config_type == "mamba":
 
 
119
  LOG.info("dropping attention_mask column")
120
  train_dataset = train_dataset.remove_columns("attention_mask")
121
  if eval_dataset:
tests/e2e/patched/test_phi_multipack.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ E2E tests for lora llama
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import unittest
8
+ from pathlib import Path
9
+
10
+ from axolotl.cli import load_datasets
11
+ from axolotl.common.cli import TrainerCliArgs
12
+ from axolotl.train import train
13
+ from axolotl.utils.config import normalize_config
14
+ from axolotl.utils.dict import DictDefault
15
+
16
+ from ..utils import with_temp_dir
17
+
18
+ LOG = logging.getLogger("axolotl.tests.e2e")
19
+ os.environ["WANDB_DISABLED"] = "true"
20
+
21
+
22
+ class TestPhiMultipack(unittest.TestCase):
23
+ """
24
+ Test case for Phi2 models
25
+ """
26
+
27
+ @with_temp_dir
28
+ def test_ft_packed(self, temp_dir):
29
+ # pylint: disable=duplicate-code
30
+ cfg = DictDefault(
31
+ {
32
+ "base_model": "microsoft/phi-1_5",
33
+ "model_type": "PhiForCausalLM",
34
+ "tokenizer_type": "AutoTokenizer",
35
+ "sequence_len": 1024,
36
+ "sample_packing": True,
37
+ "flash_attention": True,
38
+ "pad_to_sequence_len": True,
39
+ "load_in_8bit": False,
40
+ "adapter": None,
41
+ "val_set_size": 0.1,
42
+ "special_tokens": {
43
+ "pad_token": "<|endoftext|>",
44
+ },
45
+ "datasets": [
46
+ {
47
+ "path": "mhenrichsen/alpaca_2k_test",
48
+ "type": "alpaca",
49
+ },
50
+ ],
51
+ "dataset_shard_num": 10,
52
+ "dataset_shard_idx": 0,
53
+ "num_epochs": 1,
54
+ "micro_batch_size": 1,
55
+ "gradient_accumulation_steps": 1,
56
+ "output_dir": temp_dir,
57
+ "learning_rate": 0.00001,
58
+ "optimizer": "adamw_bnb_8bit",
59
+ "lr_scheduler": "cosine",
60
+ "max_steps": 20,
61
+ "eval_steps": 10,
62
+ "save_steps": 10,
63
+ "bf16": "auto",
64
+ }
65
+ )
66
+
67
+ normalize_config(cfg)
68
+ cli_args = TrainerCliArgs()
69
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
70
+
71
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
72
+ assert (Path(temp_dir) / "pytorch_model.bin").exists()
73
+
74
+ @with_temp_dir
75
+ def test_qlora_packed(self, temp_dir):
76
+ # pylint: disable=duplicate-code
77
+ cfg = DictDefault(
78
+ {
79
+ "base_model": "microsoft/phi-1_5",
80
+ "model_type": "PhiForCausalLM",
81
+ "tokenizer_type": "AutoTokenizer",
82
+ "sequence_len": 1024,
83
+ "sample_packing": True,
84
+ "flash_attention": True,
85
+ "pad_to_sequence_len": True,
86
+ "load_in_8bit": False,
87
+ "adapter": "qlora",
88
+ "lora_r": 64,
89
+ "lora_alpha": 32,
90
+ "lora_dropout": 0.05,
91
+ "lora_target_linear": True,
92
+ "val_set_size": 0.1,
93
+ "special_tokens": {
94
+ "pad_token": "<|endoftext|>",
95
+ },
96
+ "datasets": [
97
+ {
98
+ "path": "mhenrichsen/alpaca_2k_test",
99
+ "type": "alpaca",
100
+ },
101
+ ],
102
+ "dataset_shard_num": 10,
103
+ "dataset_shard_idx": 0,
104
+ "num_epochs": 1,
105
+ "micro_batch_size": 1,
106
+ "gradient_accumulation_steps": 1,
107
+ "output_dir": temp_dir,
108
+ "learning_rate": 0.00001,
109
+ "optimizer": "adamw_bnb_8bit",
110
+ "lr_scheduler": "cosine",
111
+ "max_steps": 20,
112
+ "eval_steps": 10,
113
+ "save_steps": 10,
114
+ "bf16": "auto",
115
+ }
116
+ )
117
+
118
+ normalize_config(cfg)
119
+ cli_args = TrainerCliArgs()
120
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
121
+
122
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
123
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
tests/e2e/test_phi.py CHANGED
@@ -7,9 +7,6 @@ import os
7
  import unittest
8
  from pathlib import Path
9
 
10
- import pytest
11
- from transformers.utils import is_torch_bf16_gpu_available
12
-
13
  from axolotl.cli import load_datasets
14
  from axolotl.common.cli import TrainerCliArgs
15
  from axolotl.train import train
@@ -27,17 +24,15 @@ class TestPhi(unittest.TestCase):
27
  Test case for Phi2 models
28
  """
29
 
30
- @pytest.mark.skip(reason="fixme later")
31
  @with_temp_dir
32
- def test_phi2_ft(self, temp_dir):
33
  # pylint: disable=duplicate-code
34
  cfg = DictDefault(
35
  {
36
- "base_model": "microsoft/phi-2",
37
- "trust_remote_code": True,
38
  "model_type": "AutoModelForCausalLM",
39
  "tokenizer_type": "AutoTokenizer",
40
- "sequence_len": 512,
41
  "sample_packing": False,
42
  "load_in_8bit": False,
43
  "adapter": None,
@@ -64,13 +59,9 @@ class TestPhi(unittest.TestCase):
64
  "max_steps": 10,
65
  "save_steps": 10,
66
  "eval_steps": 10,
67
- "save_safetensors": True,
68
  }
69
  )
70
- if is_torch_bf16_gpu_available():
71
- cfg.bf16 = True
72
- else:
73
- cfg.fp16 = True
74
  normalize_config(cfg)
75
  cli_args = TrainerCliArgs()
76
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -78,25 +69,24 @@ class TestPhi(unittest.TestCase):
78
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
79
  assert (Path(temp_dir) / "pytorch_model.bin").exists()
80
 
81
- @pytest.mark.skip(reason="multipack no longer supported atm")
82
  @with_temp_dir
83
- def test_ft_packed(self, temp_dir):
84
  # pylint: disable=duplicate-code
85
  cfg = DictDefault(
86
  {
87
- "base_model": "microsoft/phi-2",
88
- "trust_remote_code": True,
89
- "model_type": "PhiForCausalLM",
90
  "tokenizer_type": "AutoTokenizer",
91
- "sequence_len": 512,
92
- "sample_packing": True,
93
  "load_in_8bit": False,
94
- "adapter": None,
 
 
 
 
95
  "val_set_size": 0.1,
96
  "special_tokens": {
97
- "unk_token": "<|endoftext|>",
98
- "bos_token": "<|endoftext|>",
99
- "eos_token": "<|endoftext|>",
100
  "pad_token": "<|endoftext|>",
101
  },
102
  "datasets": [
@@ -112,18 +102,18 @@ class TestPhi(unittest.TestCase):
112
  "gradient_accumulation_steps": 1,
113
  "output_dir": temp_dir,
114
  "learning_rate": 0.00001,
115
- "optimizer": "adamw_bnb_8bit",
116
  "lr_scheduler": "cosine",
 
 
 
 
 
117
  }
118
  )
119
- if is_torch_bf16_gpu_available():
120
- cfg.bf16 = True
121
- else:
122
- cfg.fp16 = True
123
-
124
  normalize_config(cfg)
125
  cli_args = TrainerCliArgs()
126
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
127
 
128
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
129
- assert (Path(temp_dir) / "pytorch_model.bin").exists()
 
7
  import unittest
8
  from pathlib import Path
9
 
 
 
 
10
  from axolotl.cli import load_datasets
11
  from axolotl.common.cli import TrainerCliArgs
12
  from axolotl.train import train
 
24
  Test case for Phi2 models
25
  """
26
 
 
27
  @with_temp_dir
28
+ def test_phi_ft(self, temp_dir):
29
  # pylint: disable=duplicate-code
30
  cfg = DictDefault(
31
  {
32
+ "base_model": "microsoft/phi-1_5",
 
33
  "model_type": "AutoModelForCausalLM",
34
  "tokenizer_type": "AutoTokenizer",
35
+ "sequence_len": 2048,
36
  "sample_packing": False,
37
  "load_in_8bit": False,
38
  "adapter": None,
 
59
  "max_steps": 10,
60
  "save_steps": 10,
61
  "eval_steps": 10,
62
+ "bf16": "auto",
63
  }
64
  )
 
 
 
 
65
  normalize_config(cfg)
66
  cli_args = TrainerCliArgs()
67
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
69
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
70
  assert (Path(temp_dir) / "pytorch_model.bin").exists()
71
 
 
72
  @with_temp_dir
73
+ def test_phi_qlora(self, temp_dir):
74
  # pylint: disable=duplicate-code
75
  cfg = DictDefault(
76
  {
77
+ "base_model": "microsoft/phi-1_5",
78
+ "model_type": "AutoModelForCausalLM",
 
79
  "tokenizer_type": "AutoTokenizer",
80
+ "sequence_len": 2048,
81
+ "sample_packing": False,
82
  "load_in_8bit": False,
83
+ "adapter": "qlora",
84
+ "lora_r": 64,
85
+ "lora_alpha": 32,
86
+ "lora_dropout": 0.05,
87
+ "lora_target_linear": True,
88
  "val_set_size": 0.1,
89
  "special_tokens": {
 
 
 
90
  "pad_token": "<|endoftext|>",
91
  },
92
  "datasets": [
 
102
  "gradient_accumulation_steps": 1,
103
  "output_dir": temp_dir,
104
  "learning_rate": 0.00001,
105
+ "optimizer": "paged_adamw_8bit",
106
  "lr_scheduler": "cosine",
107
+ "flash_attention": True,
108
+ "max_steps": 10,
109
+ "save_steps": 10,
110
+ "eval_steps": 10,
111
+ "bf16": "auto",
112
  }
113
  )
 
 
 
 
 
114
  normalize_config(cfg)
115
  cli_args = TrainerCliArgs()
116
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
117
 
118
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
119
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
tests/test_validation.py CHANGED
@@ -742,11 +742,11 @@ class ValidationCheckModelConfig(BaseValidation):
742
 
743
  check_model_config(cfg, model_config)
744
 
745
- def test_phi2_add_tokens_adapter(self):
746
  cfg = DictDefault(
747
  {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
748
  )
749
- model_config = DictDefault({"model_type": "phi-msft"})
750
 
751
  with pytest.raises(
752
  ValueError,
@@ -759,7 +759,7 @@ class ValidationCheckModelConfig(BaseValidation):
759
  "adapter": "qlora",
760
  "load_in_4bit": True,
761
  "tokens": ["<|imstart|>"],
762
- "lora_modules_to_save": ["embed_tokens", "lm_head"],
763
  }
764
  )
765
 
@@ -774,7 +774,7 @@ class ValidationCheckModelConfig(BaseValidation):
774
  "adapter": "qlora",
775
  "load_in_4bit": True,
776
  "tokens": ["<|imstart|>"],
777
- "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
778
  }
779
  )
780
 
 
742
 
743
  check_model_config(cfg, model_config)
744
 
745
+ def test_phi_add_tokens_adapter(self):
746
  cfg = DictDefault(
747
  {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
748
  )
749
+ model_config = DictDefault({"model_type": "phi"})
750
 
751
  with pytest.raises(
752
  ValueError,
 
759
  "adapter": "qlora",
760
  "load_in_4bit": True,
761
  "tokens": ["<|imstart|>"],
762
+ "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
763
  }
764
  )
765
 
 
774
  "adapter": "qlora",
775
  "load_in_4bit": True,
776
  "tokens": ["<|imstart|>"],
777
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
778
  }
779
  )
780