typeof commited on
Commit
c86e829
1 Parent(s): 32207bc

Upload folder using huggingface_hub (#1)

Browse files

- Upload folder using huggingface_hub (2152b0d52733c0928d65baed86c92ec19f43793d)

README.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ ---
7
+ ## Model Summary
8
+
9
+ The language model phi-1.5 is a Transformer with 1.3 billion parameters. It was trained using the same data sources as [phi-1](https://huggingface.co/microsoft/phi-1), augmented with a new data source that consists of various NLP synthetic texts. When assessed against benchmarks testing common sense, language understanding, and logical reasoning, phi-1.5 demonstrates a nearly state-of-the-art performance among models with less than 10 billion parameters.
10
+
11
+ We did not fine-tune phi-1.5 either for instruction following or through reinforcement learning from human feedback. The intention behind crafting this open-source model is to provide the research community with a non-restricted small model to explore vital safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more.
12
+
13
+ For a safer model release, we exclude generic web-crawl data sources such as common-crawl from the training. This strategy prevents direct exposure to potentially harmful online content, enhancing the model's safety without RLHF. However, the model is still vulnerable to generating harmful content. We hope the model can help the research community to further study the safety of language models.
14
+
15
+ ## Intended Uses
16
+ Given the nature of the training data, phi-1.5 is best suited for prompts using the QA format, the chat format, and the code format. Note that phi-1.5, being a base model, often produces irrelevant text following the main answer. In the following example, we've truncated the answer for illustrative purposes only.
17
+
18
+ #### QA format:
19
+
20
+ ```markdown
21
+ Write a detailed analogy between mathematics and a lighthouse.
22
+
23
+ Answer: Mathematics is like a lighthouse, guiding us through the vast ocean of numbers and calculations. Just as a lighthouse illuminates the darkness, mathematics provides us with a clear path to navigate through complex problems. It helps us make sense of the world around us, just like a lighthouse helps ships find their way home.
24
+ ```
25
+ where the model generates the text after "Answer:".
26
+
27
+ #### Chat format:
28
+
29
+ ```markdown
30
+ Alice: Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?
31
+
32
+ Bob: Have you tried using a timer? It can help you stay on track and avoid distractions.
33
+
34
+ Alice: That's a good idea. I'll give it a try.
35
+
36
+ Charlie: Another thing that can help is to break up your study sessions into smaller chunks. It's easier to concentrate on one thing at a time.
37
+
38
+ Alice: That makes sense. I'll try that too.
39
+
40
+ Bob: And don't forget to take breaks! It's important to give your brain a rest so you can come back to your studies with a fresh perspective.
41
+
42
+ Alice: Thanks for the advice, guys. I feel more motivated now.
43
+
44
+ Charlie: No problem, Alice. We're all in this together.
45
+
46
+ Bob: Yeah, and remember that it's okay to ask for help if you need it. We're here to support each other.
47
+ ```
48
+ where the model generates the text after the first "Bob:".
49
+
50
+ #### Code format:
51
+ ```python
52
+ def print_prime(n):
53
+ """
54
+ Print all primes between 1 and n
55
+ """
56
+ primes = []
57
+ for num in range(2, n+1):
58
+ is_prime = True
59
+ for i in range(2, int(math.sqrt(num))+1):
60
+ if num % i == 0:
61
+ is_prime = False
62
+ break
63
+ if is_prime:
64
+ primes.append(num)
65
+ print(primes)
66
+ ```
67
+ where the model generates the text after the comments.
68
+
69
+ **Notes**
70
+ * phi-1.5 is intended for research purposes. The model-generated text/code should be treated as a starting point rather than a definitive solution for potential use cases. Users should be cautious when employing these models in their applications.
71
+ * Direct adoption for production tasks is out of the scope of this research project. As a result, phi-1.5 has not been tested to ensure that it performs adequately for any production-level application. Please refer to the limitation sections of this document for more details.
72
+
73
+ ## Limitations of phi-1.5
74
+
75
+ * Generate Inaccurate Code and Facts: The model often produces incorrect code snippets and statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate solutions.
76
+ * Limited Scope for code: If the model generates Python scripts that utilize uncommon packages or scripts in other languages, we strongly recommend users manually verify all API uses.
77
+ * Unreliable Responses to Instruction: The model has not undergone instruction fine-tuning. As a result, it may struggle or fail to adhere to intricate or nuanced instructions provided by users.
78
+ * Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other language outside of English might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.
79
+ * Potential Societal Biases: Regardless of the safe data used for its training, the model is not entirely free from societal biases. There's a possibility it may generate content that mirrors these societal biases, particularly if prompted or instructed to do so. We urge users to be aware of this and to exercise caution and critical thinking when interpreting model outputs.
80
+ * Toxicity: Despite that the model is trained with carefully selected data, the model can still produce harmful content if explicitly prompted or instructed to do so. We chose to release the model for research purposes only -- We hope to help the open-source community develop the most effective ways to reduce the toxicity of a model directly after pretraining.
81
+
82
+ ## Training
83
+
84
+ ### Model
85
+ * Architecture: a Transformer-based model with next-word prediction objective
86
+ * Dataset size: 30B tokens
87
+ * Training tokens: 150B tokens
88
+ * Precision: fp16
89
+ * GPUs: 32xA100-40G
90
+ * Training time: 8 days
91
+
92
+ ### Software
93
+ * [PyTorch](https://github.com/pytorch/pytorch)
94
+ * [DeepSpeed](https://github.com/microsoft/DeepSpeed)
95
+ * [flash-attention](https://github.com/HazyResearch/flash-attention)
96
+
97
+ ### License
98
+ The model is licensed under the [Research License](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx).
99
+
100
+ ### Sample Code
101
+ ```python
102
+ import torch
103
+ from transformers import AutoModelForCausalLM, AutoTokenizer
104
+
105
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto")
106
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype="auto")
107
+ inputs = tokenizer('''```python
108
+ def print_prime(n):
109
+ """
110
+ Print all primes between 1 and n
111
+ """''', return_tensors="pt", return_attention_mask=False)
112
+
113
+ outputs = model.generate(**inputs, max_length=200)
114
+ text = tokenizer.batch_decode(outputs)[0]
115
+ print(text)
116
+ ```
117
+
118
+ **Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1) and `attention_mask' parameters.
119
+ Furthermore, in the forward pass of the model, we currently do not support outputing hidden states or attention values, or using custom input embeddings (instead of the model's).
120
+
121
+ ### Citation
122
+ ```bib
123
+ @article{textbooks2,
124
+ title={Textbooks Are All You Need II: \textbf{phi-1.5} technical report},
125
+ author={Li, Yuanzhi and Bubeck, S{\'e}bastien and Eldan, Ronen and Del Giorno, Allie and Gunasekar, Suriya and Lee, Yin Tat},
126
+ journal={arXiv preprint arXiv:2309.05463},
127
+ year={2023}
128
+ }
129
+ ```
Research License.docx ADDED
Binary file (38.9 kB). View file
 
added_tokens.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "\t\t": 50294,
3
+ "\t\t\t": 50293,
4
+ "\t\t\t\t": 50292,
5
+ "\t\t\t\t\t": 50291,
6
+ "\t\t\t\t\t\t": 50290,
7
+ "\t\t\t\t\t\t\t": 50289,
8
+ "\t\t\t\t\t\t\t\t": 50288,
9
+ "\t\t\t\t\t\t\t\t\t": 50287,
10
+ " ": 50286,
11
+ " ": 50285,
12
+ " ": 50284,
13
+ " ": 50283,
14
+ " ": 50282,
15
+ " ": 50281,
16
+ " ": 50280,
17
+ " ": 50279,
18
+ " ": 50278,
19
+ " ": 50277,
20
+ " ": 50276,
21
+ " ": 50275,
22
+ " ": 50274,
23
+ " ": 50273,
24
+ " ": 50272,
25
+ " ": 50271,
26
+ " ": 50270,
27
+ " ": 50269,
28
+ " ": 50268,
29
+ " ": 50267,
30
+ " ": 50266,
31
+ " ": 50265,
32
+ " ": 50264,
33
+ " ": 50263,
34
+ " ": 50262,
35
+ " ": 50261,
36
+ " ": 50260,
37
+ " ": 50259,
38
+ " ": 50258,
39
+ " ": 50257
40
+ }
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "phi-1.5-half",
3
+ "activation_function": "gelu_new",
4
+ "architecture": {
5
+ "block_cls": "parallel",
6
+ "mixer": {},
7
+ "mlp": {
8
+ "mlp_cls": "mlp"
9
+ }
10
+ },
11
+ "architectures": [
12
+ "MixFormerSequentialForCausalLM"
13
+ ],
14
+ "auto_map": {
15
+ "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
16
+ "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
17
+ },
18
+ "embd_layer": "default",
19
+ "embd_pdrop": 0.0,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_epsilon": 1e-05,
22
+ "model_type": "mixformer-sequential",
23
+ "n_embd": 2048,
24
+ "n_head": 32,
25
+ "n_inner": null,
26
+ "n_layer": 24,
27
+ "n_positions": 2048,
28
+ "phyagi_version": "0.0.4.dev",
29
+ "resid_pdrop": 0.0,
30
+ "rotary_dim": 32,
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "float16",
33
+ "transformers_version": "4.32.1",
34
+ "vocab_size": 51200
35
+ }
configuration_mixformer_sequential.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class MixFormerSequentialConfig(PretrainedConfig):
11
+ """MixFormer (sequential for DeepSpeed) configuration."""
12
+
13
+ model_type = "mixformer-sequential"
14
+
15
+ attribute_map = {
16
+ "max_position_embeddings": "n_positions",
17
+ "hidden_size": "n_embd",
18
+ "num_attention_heads": "n_head",
19
+ "num_hidden_layers": "n_layer",
20
+ "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
21
+ "blocks": "architecture", # `blocks` key is for backward compatibility
22
+ }
23
+
24
+ def __init__(
25
+ self,
26
+ vocab_size: Optional[int] = 50304,
27
+ n_positions: Optional[int] = 2048,
28
+ n_embd: Optional[int] = 1024,
29
+ n_layer: Optional[int] = 20,
30
+ n_inner: Optional[int] = None,
31
+ n_head: Optional[int] = 16,
32
+ rotary_dim: Optional[int] = 32,
33
+ activation_function: Optional[str] = "gelu_new",
34
+ embd_layer: Optional[str] = "default",
35
+ architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
36
+ embd_pdrop: Optional[float] = 0.0,
37
+ resid_pdrop: Optional[float] = 0.0,
38
+ layer_norm_epsilon: Optional[float] = 1e-5,
39
+ initializer_range: Optional[float] = 0.02,
40
+ tie_word_embeddings: Optional[bool] = False,
41
+ pad_vocab_size_multiple: Optional[int] = 64,
42
+ **kwargs
43
+ ) -> None:
44
+ self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
45
+ self.n_positions = n_positions
46
+ self.n_embd = n_embd
47
+ self.n_layer = n_layer
48
+ self.n_inner = n_inner
49
+ self.n_head = n_head
50
+ self.rotary_dim = min(rotary_dim, n_embd // n_head)
51
+ self.activation_function = activation_function
52
+ self.embd_layer = embd_layer
53
+ self.architecture = architecture
54
+ self.embd_pdrop = embd_pdrop
55
+ self.resid_pdrop = resid_pdrop
56
+ self.layer_norm_epsilon = layer_norm_epsilon
57
+ self.initializer_range = initializer_range
58
+
59
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.32.1"
4
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_mixformer_sequential.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ # BSD 3-Clause License
5
+ #
6
+ # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
7
+ # All rights reserved.
8
+ #
9
+ # Redistribution and use in source and binary forms, with or without
10
+ # modification, are permitted provided that the following conditions are met:
11
+ #
12
+ # * Redistributions of source code must retain the above copyright notice, this
13
+ # list of conditions and the following disclaimer.
14
+ #
15
+ # * Redistributions in binary form must reproduce the above copyright notice,
16
+ # this list of conditions and the following disclaimer in the documentation
17
+ # and/or other materials provided with the distribution.
18
+ #
19
+ # * Neither the name of the copyright holder nor the names of its
20
+ # contributors may be used to endorse or promote products derived from
21
+ # this software without specific prior written permission.
22
+ #
23
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
+
34
+ from __future__ import annotations
35
+
36
+ import math
37
+ import copy
38
+ from typing import Any, Dict, Optional, Tuple
39
+ from dataclasses import dataclass, field
40
+
41
+ import torch
42
+ import torch.nn as nn
43
+
44
+ from einops import rearrange
45
+ from transformers.activations import ACT2FN
46
+ from transformers import PretrainedConfig, PreTrainedModel
47
+ from transformers.modeling_outputs import CausalLMOutputWithPast
48
+
49
+ from .configuration_mixformer_sequential import MixFormerSequentialConfig
50
+
51
+ class LayerNorm(nn.LayerNorm):
52
+ def forward(self, x):
53
+ t = x.dtype
54
+ x = super().forward(x.type(torch.float32))
55
+ return x.type(t)
56
+
57
+ @dataclass
58
+ class InferenceParams:
59
+ """Inference parameters that are passed to the main model in order
60
+ to efficienly calculate and store the context during inference.
61
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
62
+ max_sequence_len: int
63
+ max_batch_size: int
64
+ sequence_len_offset: int = 0
65
+ batch_size_offset: int = 0
66
+ key_value_memory_dict: dict = field(default_factory=dict)
67
+ fused_ft_kernel: bool = False
68
+ lengths_per_sample: Optional[torch.Tensor] = None
69
+
70
+
71
+ class Embedding(nn.Module):
72
+ """Token embedding with dropout."""
73
+
74
+ def __init__(self, config: PretrainedConfig) -> None:
75
+ super().__init__()
76
+
77
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
78
+ self.drop = nn.Dropout(config.embd_pdrop)
79
+
80
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
81
+ input_shape = input_ids.size()
82
+ input_ids = input_ids.view(-1, input_shape[-1])
83
+
84
+ hidden_states = self.wte(input_ids)
85
+ hidden_states = self.drop(hidden_states)
86
+
87
+ return hidden_states
88
+
89
+ class RotaryEmbedding(nn.Module):
90
+ """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
91
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
92
+
93
+ def __init__(
94
+ self,
95
+ dim: int,
96
+ base: Optional[int] = 10000,
97
+ scale_base: Optional[float] = None,
98
+ device: Optional[str] = None,
99
+ **kwargs,
100
+ ) -> None:
101
+ super().__init__()
102
+
103
+ if scale_base is not None:
104
+ raise NotImplementedError
105
+
106
+ # Generate and save the inverse frequency buffer (non-trainable)
107
+ self.dim = dim
108
+ self.base = base
109
+ self.scale_base = scale_base
110
+ self.device = device
111
+
112
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
113
+ self.register_buffer("inv_freq", inv_freq)
114
+
115
+ scale = (
116
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
117
+ if scale_base is not None
118
+ else None
119
+ )
120
+ self.register_buffer("scale", scale)
121
+
122
+ self._seq_len_cached = 0
123
+ self._cos_cached = None
124
+ self._sin_cached = None
125
+ self._cos_k_cached = None
126
+ self._sin_k_cached = None
127
+
128
+ def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
129
+ # Reset the tables if the sequence length has changed,
130
+ # or if we're on a new device (possibly due to tracing for instance)
131
+ seqlen = x.shape[1] + seqlen_offset
132
+
133
+ # Re-generate the inverse frequency buffer if it's not fp32
134
+ # (for instance if model.half() was called)
135
+ if self.inv_freq.dtype != "torch.float32":
136
+ self.inv_freq = 1.0 / (
137
+ self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
138
+ )
139
+
140
+ if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
141
+ self._seq_len_cached = seqlen
142
+ t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
143
+
144
+ # Don't do einsum, it converts fp32 to fp16
145
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
146
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
147
+ if self.scale is None:
148
+ self._cos_cached = torch.cos(freqs).to(x.dtype)
149
+ self._sin_cached = torch.sin(freqs).to(x.dtype)
150
+ else:
151
+ power = (
152
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
153
+ ) / self.scale_base
154
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
155
+
156
+ # We want the multiplication by scale to happen in fp32
157
+ self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
158
+ self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
159
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
160
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
161
+
162
+ def apply_rotary_emb_qkv(
163
+ self,
164
+ qkv: torch.FloatTensor,
165
+ sin: torch.FloatTensor,
166
+ cos: torch.FloatTensor,
167
+ sin_k: Optional[torch.FloatTensor] = None,
168
+ cos_k: Optional[torch.FloatTensor] = None,
169
+ ) -> torch.FloatTensor:
170
+ _, seqlen, three, _, headdim = qkv.shape
171
+ assert three == 3
172
+
173
+ rotary_seqlen, rotary_dim = cos.shape
174
+ rotary_dim *= 2
175
+ assert rotary_dim <= headdim
176
+ assert seqlen <= rotary_seqlen
177
+
178
+ cos_k = cos if cos_k is None else cos_k
179
+ sin_k = sin if sin_k is None else sin_k
180
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
181
+
182
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
183
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
184
+
185
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
186
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
187
+
188
+ # Splits the queries and keys in half
189
+ q1, q2 = q_rot.chunk(2, dim=-1)
190
+ k1, k2 = k_rot.chunk(2, dim=-1)
191
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
192
+
193
+ # Casts to fp32 are necessary to prevent fp16 overflow issues
194
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
195
+
196
+ # Computes the new keys and queries, recasting to original dtype
197
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
198
+
199
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
200
+
201
+ return torch.cat(
202
+ [
203
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
204
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
205
+ qkv[:, :, 2:3, :, :],
206
+ ],
207
+ axis=2,
208
+ )
209
+
210
+ def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
211
+ """Perform the forward pass.
212
+
213
+ Args:
214
+ qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
215
+ seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
216
+
217
+ Returns:
218
+ New `qkv` and the cached sinusoids.
219
+
220
+ """
221
+
222
+ self._update_cos_sin_cache(qkv, seqlen_offset)
223
+
224
+ return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
225
+
226
+ def _update_kv_cache(kv, inference_params, layer_idx):
227
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
228
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
229
+ # Pre-allocate memory for key-values for inference.
230
+ num_heads, head_dim = kv.shape[-2:]
231
+ if layer_idx not in inference_params.key_value_memory_dict:
232
+ kv_cache = torch.empty(
233
+ inference_params.max_batch_size, inference_params.max_sequence_len, 2,
234
+ num_heads, head_dim, dtype=kv.dtype, device=kv.device
235
+ )
236
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
237
+ else:
238
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
239
+
240
+ # Adjust key and value for inference
241
+ batch_start = inference_params.batch_size_offset
242
+ batch_end = batch_start + kv.shape[0]
243
+ sequence_start = inference_params.sequence_len_offset
244
+ sequence_end = sequence_start + kv.shape[1]
245
+ assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
246
+ assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
247
+
248
+ assert kv_cache is not None
249
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
250
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
251
+ return kv
252
+
253
+
254
+ class MLP(nn.Module):
255
+ """Multi-Layer Perceptron.
256
+
257
+ Reference:
258
+ Attention Is All You Need.
259
+ https://arxiv.org/pdf/1706.03762.pdf.
260
+
261
+ """
262
+
263
+ def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
264
+ super().__init__()
265
+
266
+ act_fn = config.activation_function if act_fn is None else act_fn
267
+ assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
268
+
269
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
270
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
271
+
272
+ self.fc1 = nn.Linear(config.n_embd, n_inner)
273
+ self.fc2 = nn.Linear(n_inner, config.n_embd)
274
+ self.act = ACT2FN[act_fn]
275
+
276
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
277
+ old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
278
+ new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]
279
+
280
+ if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
281
+ # Older version of `MLP` saved with different key names.
282
+ for old_key, new_key in zip(old_keys, new_keys):
283
+ state_dict[new_key] = state_dict.pop(old_key)
284
+
285
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
286
+
287
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
288
+ hidden_states = self.fc1(hidden_states)
289
+ hidden_states = self.act(hidden_states)
290
+ hidden_states = self.fc2(hidden_states)
291
+
292
+ return hidden_states
293
+
294
+
295
+ class FusedMLP(nn.Module):
296
+ """Fused Multi-Layer Perceptron from `flash-attn`.
297
+
298
+ Reference:
299
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
300
+
301
+ """
302
+ def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
303
+ raise_on_missing: bool = False) -> None:
304
+ super().__init__()
305
+
306
+ act_fn = config.activation_function if act_fn is None else act_fn
307
+ assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
308
+
309
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
310
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
311
+
312
+ gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
313
+ activation = "gelu_approx" if act_fn in gelu_activations else "relu"
314
+
315
+ self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
316
+
317
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
318
+ return self.mlp(hidden_states)
319
+
320
+ class SelfAttention(nn.Module):
321
+ """Implement the scaled dot product attention with softmax.
322
+ Adapted from https://github.com/Dao-AILab/flash-attention.
323
+ Arguments
324
+ ---------
325
+ softmax_scale: The temperature to use for the softmax attention.
326
+ (default: 1/sqrt(d_keys) where d_keys is computed at
327
+ runtime)
328
+ attention_dropout: The dropout rate to apply to the attention
329
+ (default: 0.0)
330
+ """
331
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
332
+ super().__init__()
333
+ self.causal = causal
334
+ self.softmax_scale = softmax_scale
335
+ self.drop = nn.Dropout(attention_dropout)
336
+
337
+ def forward(self, qkv, causal=None, key_padding_mask=None):
338
+ """Implements the multihead softmax attention.
339
+ Arguments
340
+ ---------
341
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
342
+ causal: if passed, will override self.causal
343
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
344
+ False means to mask out. (B, S)
345
+ """
346
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
347
+ causal = self.causal if causal is None else causal
348
+ q, k, v = qkv.unbind(dim=2)
349
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
350
+ scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
351
+ if key_padding_mask is not None:
352
+ padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
353
+ device=scores.device)
354
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
355
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
356
+ scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
357
+ if causal:
358
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
359
+ # So we have to construct the mask in float
360
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
361
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
362
+ scores = scores + causal_mask.to(dtype=scores.dtype)
363
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
364
+ attention_drop = self.drop(attention)
365
+ output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
366
+ return output
367
+
368
+
369
+ class CrossAttention(nn.Module):
370
+ """Implement the scaled dot product attention with softmax.
371
+ Adapted from https://github.com/Dao-AILab/flash-attention.
372
+ Arguments
373
+ ---------
374
+ softmax_scale: The temperature to use for the softmax attention.
375
+ (default: 1/sqrt(d_keys) where d_keys is computed at
376
+ runtime)
377
+ attention_dropout: The dropout rate to apply to the attention
378
+ (default: 0.0)
379
+ """
380
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
381
+ super().__init__()
382
+ self.causal = causal
383
+ self.softmax_scale = softmax_scale
384
+ self.drop = nn.Dropout(attention_dropout)
385
+
386
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
387
+ """Implements the multihead softmax attention.
388
+ Arguments
389
+ ---------
390
+ q: The tensor containing the query. (B, Sq, H, D)
391
+ kv: The tensor containing the key and value. (B, Sk, 2, H, D)
392
+ causal: if passed, will override self.causal
393
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
394
+ False means to mask out. (B, Sk)
395
+ """
396
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
397
+ causal = self.causal if causal is None else causal
398
+ seqlen_k = kv.shape[1]
399
+ assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
400
+ k, v = kv.unbind(dim=2)
401
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
402
+ scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
403
+ if key_padding_mask is not None:
404
+ padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
405
+ device=scores.device)
406
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
407
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
408
+ scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
409
+ if causal:
410
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
411
+ # So we have to construct the mask in float
412
+ causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
413
+ device=scores.device), 1)
414
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
415
+ scores = scores + causal_mask.to(dtype=scores.dtype)
416
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
417
+ attention_drop = self.drop(attention)
418
+ output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
419
+ return output
420
+
421
+ def find_mha_dims(
422
+ config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
423
+ ) -> Tuple[int, int]:
424
+ """Validate and return the number of heads and head dimension for multi-head attention.
425
+
426
+ Args:
427
+ config: Model configuration.
428
+ n_head: Number of heads.
429
+ head_dim: Head dimension.
430
+
431
+ Returns:
432
+ Number of heads and head dimension.
433
+
434
+ """
435
+
436
+ assert all(
437
+ hasattr(config, attr) for attr in ["n_embd", "n_head"]
438
+ ), "`config` must have `n_embd` and `n_head` attributes."
439
+
440
+ if head_dim is None:
441
+ assert (
442
+ config.n_embd % config.n_head == 0
443
+ ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
444
+
445
+ if n_head is None and head_dim is None:
446
+ head_dim = config.n_embd // config.n_head
447
+ n_head = config.n_head
448
+ elif n_head is None or head_dim is None:
449
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
450
+
451
+ return n_head, head_dim
452
+
453
+
454
+ class MHA(nn.Module):
455
+ """Multi-head attention layer.
456
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
457
+
458
+ def __init__(
459
+ self,
460
+ config: PretrainedConfig,
461
+ rotary_dim: Optional[int] = None,
462
+ n_head: Optional[int] = None,
463
+ head_dim: Optional[int] = None,
464
+ bias: Optional[bool] = True,
465
+ dropout: Optional[float] = 0.0,
466
+ softmax_scale: Optional[float] = None,
467
+ causal: Optional[bool] = True,
468
+ layer_idx: Optional[int] = None,
469
+ rotary_emb_scale_base: Optional[float] = None,
470
+ return_residual: Optional[bool] = False,
471
+ checkpointing: Optional[bool] = False,
472
+ device: Optional[str] = None,
473
+ dtype: Optional[torch.dtype] = None,
474
+ fused_dense: Optional[bool] = True,
475
+ flash_attn: Optional[bool] = True,
476
+ cutlass_attn: Optional[bool] = False,
477
+ flash_rotary: Optional[bool] = True,
478
+ raise_on_missing: Optional[bool] = False
479
+ ) -> None:
480
+ super().__init__()
481
+
482
+ factory_kwargs = {"device": device, "dtype": dtype}
483
+ n_head, head_dim = find_mha_dims(config, n_head, head_dim)
484
+
485
+ self.hidden_size = config.n_embd
486
+ self.n_head = n_head
487
+ self.head_dim = head_dim
488
+ self.op_size = n_head * head_dim
489
+
490
+ self.causal = causal
491
+ self.layer_idx = layer_idx
492
+ self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
493
+ self.fused_dense = fused_dense
494
+ self.flash_attn = flash_attn
495
+ self.cutlass_attn = cutlass_attn
496
+ self.flash_rotary = flash_rotary
497
+ self.return_residual = return_residual
498
+ self.checkpointing = checkpointing
499
+
500
+ if self.rotary_emb_dim > 0:
501
+ rotary_kwargs = {"device": device}
502
+ if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
503
+ rotary_kwargs["scale_base"] = rotary_emb_scale_base
504
+
505
+ self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
506
+ else:
507
+ pass
508
+
509
+ self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
510
+ self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)
511
+
512
+ self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
513
+ self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
514
+
515
+ def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
516
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
517
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
518
+
519
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
520
+
521
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
522
+
523
+ def forward(
524
+ self,
525
+ x: torch.FloatTensor,
526
+ x_kv: Optional[torch.FloatTensor] = None,
527
+ key_padding_mask: Optional[torch.BoolTensor] = None,
528
+ cu_seqlens: Optional[torch.LongTensor] = None,
529
+ max_seqlen: Optional[int] = None,
530
+ mixer_subset: Optional[torch.LongTensor] = None,
531
+ past_cache: Optional[InferenceParams] = None,
532
+ **kwargs
533
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
534
+ """Perform the forward pass.
535
+
536
+ Args:
537
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
538
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
539
+ is the is the sum of the sequence lengths in the batch.
540
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
541
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
542
+ (batch, seqlen). Only applicable when not using FlashAttention.
543
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
544
+ of the sequences in the batch, used to index into x. Only applicable when using
545
+ FlashAttention.
546
+ max_seqlen: int. Maximum sequence length in the batch.
547
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
548
+ before applying the query projection. Useful for e.g., ViT where we only care
549
+ about the CLS token in the last layer.
550
+ past_cache: For generation only.
551
+
552
+ Returns:
553
+ (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
554
+ else (total, hidden_dim) where total is the is the sum of the sequence lengths
555
+ in the batch.
556
+
557
+ """
558
+
559
+ if cu_seqlens is not None:
560
+ assert max_seqlen is not None
561
+ assert key_padding_mask is None
562
+ assert self.flash_attn
563
+ assert self.rotary_emb_dim == 0
564
+
565
+ if key_padding_mask is not None:
566
+ assert cu_seqlens is None
567
+ assert max_seqlen is None
568
+ assert not self.flash_attn
569
+
570
+ if past_cache is not None:
571
+ assert key_padding_mask is None
572
+ assert cu_seqlens is None and max_seqlen is None
573
+
574
+ attn_kwargs = {"key_padding_mask": key_padding_mask}
575
+
576
+ assert x_kv is None and mixer_subset is None
577
+
578
+ qkv = self.Wqkv(x)
579
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
580
+
581
+ if past_cache is None:
582
+ if self.rotary_emb_dim > 0:
583
+ qkv = self.rotary_emb(qkv)
584
+ context = self.inner_attn(qkv, **attn_kwargs)
585
+
586
+ else:
587
+ if self.rotary_emb_dim > 0:
588
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
589
+ q = qkv[:, :, 0]
590
+ kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
591
+ # If we're processing the prompt, causal=None (use self.causal).
592
+ # If we're decoding, then causal=False.
593
+ causal = None if past_cache.sequence_len_offset == 0 else False
594
+ context = self.inner_cross_attn(q, kv, causal=causal)
595
+
596
+ out = rearrange(context, "... h d -> ... (h d)")
597
+ out = self.out_proj(out)
598
+
599
+ return out if not self.return_residual else (out, x)
600
+
601
+ class ParallelBlock(nn.Module):
602
+ """Parallel block.
603
+
604
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
605
+
606
+ """
607
+
608
+ def __init__(
609
+ self,
610
+ config: PretrainedConfig,
611
+ mixer: Optional[Dict[str, Any]] = None,
612
+ mlp: Optional[Dict[str, Any]] = None,
613
+ block_idx: Optional[int] = None,
614
+ ) -> None:
615
+ super().__init__()
616
+
617
+ self.ln = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
618
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
619
+ self.block_idx = block_idx
620
+
621
+ self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
622
+ mlp_cls = mlp.pop('mlp_cls')
623
+ if mlp_cls == 'fused_mlp':
624
+ self.mlp = FusedMLP(config=config, **mlp)
625
+ else:
626
+ self.mlp = MLP(config=config, **mlp)
627
+
628
+ def forward(self, hidden_states: torch.FloatTensor,
629
+ past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
630
+ residual = hidden_states
631
+ hidden_states = self.ln(hidden_states)
632
+
633
+ attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
634
+ if isinstance(attn_outputs, tuple):
635
+ attn_outputs = attn_outputs[0]
636
+
637
+ attn_outputs = self.resid_dropout(attn_outputs)
638
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
639
+
640
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
641
+
642
+ return hidden_states
643
+
644
+ class CausalLMHead(nn.Module):
645
+ """Causal Language Modeling head.
646
+
647
+ Reference:
648
+ Improving Language Understanding by Generative Pre-Training.
649
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
650
+
651
+ """
652
+
653
+ def __init__(self, config: PretrainedConfig) -> None:
654
+ super().__init__()
655
+
656
+ self.ln = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
657
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
658
+
659
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
660
+ hidden_states = self.ln(hidden_states)
661
+ logits = self.linear(hidden_states).to(torch.float32)
662
+
663
+ return logits
664
+
665
+
666
+ class CausalLMLoss(nn.Module):
667
+ """Causal Language Modeling loss.
668
+
669
+ Reference:
670
+ Improving Language Understanding by Generative Pre-Training.
671
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
672
+
673
+ """
674
+
675
+ def __init__(self, shift_labels: Optional[bool] = True) -> None:
676
+ super().__init__()
677
+
678
+ self.shift_labels = shift_labels
679
+ self.loss_fct = nn.CrossEntropyLoss()
680
+
681
+ def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
682
+ if self.shift_labels:
683
+ logits = logits[..., :-1, :].contiguous()
684
+ labels = labels[..., 1:].contiguous()
685
+
686
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
687
+
688
+ return loss
689
+
690
+ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
691
+ """MixFormer (sequential for DeepSpeed) pre-trained model."""
692
+
693
+ config_class = MixFormerSequentialConfig
694
+ base_model_prefix = "transformer"
695
+ supports_gradient_checkpointing = True
696
+
697
+ def __init__(self, *inputs, **kwargs) -> None:
698
+ super().__init__(*inputs, **kwargs)
699
+
700
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
701
+ if "use_cache" in kwargs and not kwargs["use_cache"]:
702
+ return {"input_ids": input_ids}
703
+
704
+ if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
705
+ past_key_values = InferenceParams(
706
+ max_batch_size=input_ids.shape[0],
707
+ max_sequence_len=self.config.n_positions,
708
+ sequence_len_offset=0,
709
+ batch_size_offset=0,
710
+ fused_ft_kernel=False,
711
+ key_value_memory_dict={},
712
+ )
713
+ else:
714
+ # assume past_key_values has cached all but last token in input_ids
715
+ past_key_values.sequence_len_offset = len(input_ids[0]) - 1
716
+ input_ids = input_ids[:, -1].unsqueeze(-1)
717
+
718
+ return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
719
+
720
+
721
+ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
722
+ """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
723
+
724
+ _keys_to_ignore_on_load_missing = [""]
725
+ _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
726
+
727
+ def __init__(self, config: MixFormerSequentialConfig) -> None:
728
+ super().__init__(config)
729
+
730
+ modules = [Embedding(config)]
731
+ block_config = config.architecture
732
+
733
+ if not isinstance(block_config, list):
734
+ block_config = [block_config for _ in range(config.n_layer)]
735
+
736
+ if config.n_layer != len(block_config):
737
+ config.n_layer = len(block_config)
738
+
739
+ for block_idx, block in enumerate(block_config):
740
+ # `block_cls` with `legacy` value is for backward compatibility
741
+ # `path` key is for backward compatibility
742
+ block = copy.deepcopy(block) or {"block_cls": "parallel"}
743
+ block_cls = block.pop("path", None) or block.pop("block_cls", None)
744
+
745
+ block["block_idx"] = block_idx
746
+ modules.append(ParallelBlock(config, **block))
747
+
748
+ modules.append(CausalLMHead(config))
749
+
750
+ self.layers = nn.Sequential(*modules)
751
+ self.loss = CausalLMLoss()
752
+
753
+ self.post_init()
754
+
755
+ def get_input_embeddings(self) -> nn.Embedding:
756
+ return self.layers[0].wte
757
+
758
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
759
+ self.layers[0].wte = new_embeddings
760
+
761
+ def get_output_embeddings(self) -> nn.Linear:
762
+ return self.layers[-1].linear
763
+
764
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
765
+ self.layers[-1].linear = new_embeddings
766
+
767
+ def forward(
768
+ self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[torch.FloatTensor] = None, **kwargs
770
+ ) -> CausalLMOutputWithPast:
771
+
772
+ if not past_key_values:
773
+ lm_logits = self.layers(input_ids)
774
+ else:
775
+ hidden_layer = self.layers[0](input_ids)
776
+ for module in self.layers[1:-1]:
777
+ hidden_layer = module(hidden_layer, past_cache=past_key_values)
778
+ lm_logits = self.layers[-1](hidden_layer)
779
+
780
+ loss = None
781
+ if labels is not None:
782
+ loss = self.loss(lm_logits, labels)
783
+
784
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eab6a12a9a2b78cac8f8975aea9f3a5e89ddadcb9e0dad27e40965e57e235a4a
3
+ size 2836623617
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<|endoftext|>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|endoftext|>",
6
+ "model_max_length": 2048,
7
+ "tokenizer_class": "CodeGenTokenizer",
8
+ "unk_token": "<|endoftext|>"
9
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff