Charlie committed on
Commit
fbed214
1 Parent(s): a226da4

First commit from MLP Lab

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pytorch_model-00001-of-00004.bin filter=lfs diff=lfs merge=lfs -text
+ pytorch_model-00002-of-00004.bin filter=lfs diff=lfs merge=lfs -text
+ pytorch_model-00003-of-00004.bin filter=lfs diff=lfs merge=lfs -text
+ pytorch_model-00004-of-00004.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,125 @@
+ ---
+ language:
+ - zh
+ - en
+ tags:
+ - BatGPT
+ - MLP
+ pipeline_tag: text-generation
+ inference: false
+ ---
+ # BatGPT-15B-sirius
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+ ## 介绍 (Introduction)
+
+ BatGPT-15B-sirius 是上海交通大学与武汉大学<font size=1>(或武汉大学与上海交通大学,排名不分先后)</font>联合自然语言处理团队设计、预训练、对齐的系列大型语言模型 [BatGPT](https://github.com/zcli-charlie/BatGPT) 中的一个开源可商用版本。
+ BatGPT系列模型中还包括BatGPT-30B-orion,BatGPT-70B-alhena,以及BatGPT-140B-menkalinan。
+
+ BatGPT-15B-sirius 包含 150 亿参数,在中英文 1T 语料上进行了预训练,在权威的中文和英文 benchmark 上均取得了不错的效果。BatGPT-15B-sirius 有如下几个特点:
+
+ 1. **支持长达32K的上下文**:BatGPT-15B-sirius 采用旋转位置编码RoPE,在预训练阶段采用 2048 序列长度,并且在指令微调阶段逐步扩展到了 32K 上下文。
+ 2. **高效的预训练目标与模型架构**:BatGPT-15B-sirius 采用双向自回归预训练目标,以提高对于训练数据的运用程度,并且基于 [Multi-Query Attention](http://arxiv.org/abs/1911.02150) 技术,在保证参数规模的前提下尽可能地减少推理显存的占用,提高推理速度。
+ 3. **商业友好的开放协议**:BatGPT-15B-sirius 的源码以及权重不仅支持自由的学术研究使用,也允许免费开源商用,助推大模型进一步帮助人类的日常生活。
+
+ BatGPT-15B-sirius is an open-source, commercially usable release in the [BatGPT](https://github.com/zcli-charlie/BatGPT) series of large language models, designed, pretrained, and aligned by the joint natural language processing team of Shanghai Jiao Tong University and Wuhan University <font size=1>(or Wuhan University and Shanghai Jiao Tong University, in no particular order)</font>.
+
+ The BatGPT series also includes BatGPT-30B-orion, BatGPT-70B-alhena, and BatGPT-140B-menkalinan.
+
+ BatGPT-15B-sirius contains 15 billion parameters and was pretrained on a 1T-token Chinese and English corpus. It achieves solid results on authoritative Chinese and English benchmarks. BatGPT-15B-sirius has the following characteristics:
+
+ 1. **Supports Contexts Up to 32K Tokens**: BatGPT-15B-sirius uses rotary position embedding (RoPE) and is pretrained with a sequence length of 2,048 tokens; during instruction fine-tuning, the context is gradually extended to 32K tokens.
+ 2. **Efficient Pre-training Objective and Model Architecture**: BatGPT-15B-sirius employs a bidirectional autoregressive pretraining objective to make better use of the training data, and uses the [Multi-Query Attention](http://arxiv.org/abs/1911.02150) technique to reduce inference memory consumption and improve inference speed while keeping the parameter count unchanged (see the sketch below).
+ 3. **Business-friendly Open License**: The source code and weights of BatGPT-15B-sirius are available not only for free academic research use but also for free, open-source commercial use, further helping large language models benefit everyday life.
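+
+ To make the effect of Multi-Query Attention on inference memory concrete, here is a minimal back-of-the-envelope sketch (illustrative only, not part of the released code). It uses the hyperparameters published in this repository's `config.json` (48 layers, 44 query heads, 2 key/value heads, head dimension 128) and compares the fp16 KV-cache footprint of standard multi-head attention with the multi-query layout actually used:
+
+ ```python
+ # Hypothetical helper: fp16 KV-cache bytes stored per generated token.
+ # The constants come from config.json; the function itself is only for illustration.
+ def kv_cache_bytes_per_token(n_layer, n_kv_heads, head_dim, bytes_per_value=2):
+     # each layer caches one key tensor and one value tensor of shape [n_kv_heads, head_dim]
+     return n_layer * 2 * n_kv_heads * head_dim * bytes_per_value
+
+ head_dim = 5632 // 44                             # 128
+ mha = kv_cache_bytes_per_token(48, 44, head_dim)  # every head keeps its own K/V
+ mqa = kv_cache_bytes_per_token(48, 2, head_dim)   # BatGPT-15B-sirius: 2 shared K/V heads
+ print(mha, mqa, mha / mqa)                        # 1081344 49152 22.0
+ ```
+
+ Over a full 32K-token context this works out to roughly 35 GB of fp16 KV cache per sequence with standard multi-head attention versus roughly 1.6 GB with the 2-head multi-query layout, which is where most of the inference memory saving comes from.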
+
+
+ ## 软件依赖 (Dependencies)
+
+ ```shell
+ pip install protobuf transformers cpm_kernels "torch>=2.0" streamlit sentencepiece accelerate deepspeed
+ ```
+
+ ## 简易使用 (Quick Start)
+
+ 如下是一个使用 BatGPT-15B-sirius 进行对话的示例:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda()
+ model = model.eval()
+ history = []
+ system_prompt = None # 你也可以指定系统提示
+ response, history = model.chat(tokenizer, "你好", history=history, system_prompt=system_prompt)
+ print(response)
+ response, history = model.chat(tokenizer, "介绍一下你自己", history=history, system_prompt=system_prompt)
+ print(response)
+ ```
+
+ Here is an example of a conversation using BatGPT-15B-sirius:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda()
+ model = model.eval()
+ history = []
+ system_prompt = None # You can give a system prompt here.
+ response, history = model.chat(tokenizer, "Hello", history=history, system_prompt=system_prompt)
+ print(response)
+ response, history = model.chat(tokenizer, "Please introduce yourself", history=history, system_prompt=system_prompt)
+ print(response)
+ ```
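+
+ The repository's `modeling_batgpt.py` also implements a `stream_chat` generator for incremental (streaming) output. The following is a minimal sketch based on that method's signature; each yield carries the full response generated so far plus the updated history, so the display logic shown here is only illustrative:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("MLP-lab/BatGPT-15B-sirius", torch_dtype=torch.float16, trust_remote_code=True).cuda()
+ model = model.eval()
+
+ history = []
+ # Each iteration yields (partial_response, updated_history); the partial
+ # response is the full text generated so far, not just the newest token.
+ for response, history in model.stream_chat(tokenizer, "Hello", history=history, system_prompt=None):
+     pass
+ print(response)  # final response once the generator is exhausted
+ ```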
+
+
+ ## 模型详情 (Model Details)
+
+
+ BatGPT-15B-sirius 具体参数见下表:
+
+ | 模型名称 | 隐含层维度 | 层数 | Query头数 | Key/Value头数 | 词表大小 | 总参数量 | 训练数据(tokens) | 位置编码 | 最大长度 |
+ |-------------------------|-------|------------|------------|------------|-----------------|--------|--------|----------------|---------|
+ | BatGPT-15B-sirius | 5,632 | 48 | 44 | 2 | 65,536 | 15,030,081,024 | 1 万亿 | [RoPE](https://arxiv.org/abs/2104.09864) | 32K |
+
+
+
+ The specific parameters of BatGPT-15B-sirius are as follows:
+
+ | Model Name | Hidden Size | Num Layers | Query Heads | Key/Value Heads | Vocab Size | Total Params | Training Data (tokens) | Position Embedding | Max Length |
+ |-------------------------|-------|------------|------------|------------|-----------------|--------|--------|----------------|---------|
+ | BatGPT-15B-sirius | 5,632 | 48 | 44 | 2 | 65,536 | 15,030,081,024 | 1 trillion | [RoPE](https://arxiv.org/abs/2104.09864) | 32K |
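+
+ As a sanity check, the total parameter count in the table can be re-derived from the hyperparameters in this repository's `config.json` (hidden size 5,632, 48 layers, 44 query heads, 2 key/value heads with QKV bias, SwiGLU FFN of width 13,696, vocabulary 65,536, untied input/output embeddings). The short calculation below is purely illustrative and not part of the released code:
+
+ ```python
+ hidden, layers, kv_heads, ffn, vocab = 5632, 48, 2, 13696, 65536
+ head_dim = hidden // 44                  # 128
+ kv_dim = kv_heads * head_dim             # 256: width of the shared key/value projections
+
+ attn = hidden * hidden + hidden          # query_proj weight + bias
+ attn += 2 * (hidden * kv_dim + kv_dim)   # key_proj and value_proj weights + biases
+ attn += hidden * hidden                  # output dense (no bias)
+ mlp = hidden * (2 * ffn) + ffn * hidden  # SwiGLU up-projection (doubled width) + down-projection
+ norms = 2 * hidden                       # two RMSNorm weight vectors per layer
+
+ per_layer = attn + mlp + norms
+ total = layers * per_layer + 2 * vocab * hidden + hidden  # + word embeddings, lm_head, final norm
+ print(total)                             # 15030081024
+ ```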
+
+
+
+ - **Developed by:** MLP Lab of Wuhan University and Shanghai Jiao Tong University
+ - **Email**: zcli-charlie@whu.edu.cn, zhaohai@cs.sjtu.edu.cn
+ - **Language(s) (NLP):** Chinese/English
+ - **License:** The code in this project is licensed under the Apache 2.0 license, and the model weights are licensed under the GNU AGPL 3.0 license. If you intend to use the models included in this project for commercial purposes or public deployment, please email us to obtain authorization. Commercial usage information is collected for record-keeping purposes only, and no fees will be charged.
+
+
+ ## 免责声明 (Disclaimers)
+
+ BatGPT-15B-sirius 模型的使用应当遵循社会的公序良俗,不能被用于任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 BatGPT-15B-sirius 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
+
+ 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。如使用本项目所含模型及其修改版本提供服务产生误导性或有害性言论,造成不良影响,由服务提供方负责,与本项目无关。
+
+ Use of the BatGPT-15B-sirius model should comply with public order and good morals, and the model must not be used for any activity that endangers national or social security or violates the law. We also ask users not to deploy the BatGPT-15B-sirius model in internet services that have not undergone appropriate security review and regulatory filing. We hope that all users will abide by this principle so that technological development proceeds in a regulated and lawful environment.
+
+ We have done our best to ensure the compliance of the data used during model training. However, despite these significant efforts, unforeseen issues may still arise due to the complexity of the model and data. If misleading or harmful statements are generated, and adverse effects are caused, when services are provided with the models included in this project or their modified versions, the responsibility lies with the service provider, not with this project.
+
+ ## 引用 (Citation)
+
+ 如果你觉得我们的工作有帮助的话,请考虑引用我们的BatGPT论文:
+
+ If you find our work helpful, please consider citing our BatGPT paper:
+
+ ```
+ @article{li2023batgpt,
+   title={BatGPT: A Bidirectional Autoregessive Talker from Generative Pre-trained Transformer},
+   author={Li, Zuchao and Zhang, Shitou and Zhao, Hai and Yang, Yifei and Yang, Dongjie},
+   journal={arXiv preprint arXiv:2307.00360},
+   year={2023}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_name_or_path": "/data/workspace/bat-gpt/checkpoints/batgpt-15b-sirius",
+   "alibi": false,
+   "architectures": [
+     "BatGPTForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_batgpt.BatGPTConfig",
+     "AutoModel": "modeling_batgpt.BatGPTForCausalLM",
+     "AutoModelForCausalLM": "modeling_batgpt.BatGPTForCausalLM",
+     "AutoModelForSeq2SeqLM": "modeling_batgpt.BatGPTForCausalLM"
+   },
+   "emb_dim": 5632,
+   "empty_init": false,
+   "eos_token_id": 2,
+   "ffn_hidden_size": 13696,
+   "hidden_dropout": 0.0,
+   "hidden_size": 5632,
+   "layer_norm_epsilon": 1e-05,
+   "max_seq_len": 32768,
+   "mlp_activation": "swiglu",
+   "model_type": "batgpt",
+   "n_head": 44,
+   "n_layer": 48,
+   "num_heads_per_kv": 2,
+   "pad_token_id": 0,
+   "pos_emb_impl": "rope",
+   "prefix_proj": false,
+   "prefix_size": null,
+   "qkv_bias": true,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.30.2",
+   "use_cache": true,
+   "use_emb_factorization": false,
+   "use_multi_query_attn": true,
+   "use_native_attn_impl": true,
+   "vocab_size": 65536
+ }
configuration_batgpt.py ADDED
@@ -0,0 +1,50 @@
+ from transformers import PretrainedConfig
+
+
+ class BatGPTConfig(PretrainedConfig):
+
+     model_type = "batgpt"
+
+     def __init__(
+         self,
+         vocab_size=65024,
+         emb_dim=5632,
+         hidden_size=5632,
+         n_layer=48,
+         n_head=44,
+         layer_norm_epsilon=1e-5,
+         use_multi_query_attn=True,
+         num_heads_per_kv=2,
+         qkv_bias=True,
+         use_native_attn_impl=True,
+         mlp_activation="swiglu",
+         hidden_dropout=0.0,
+         ffn_hidden_size=13696,
+         prefix_size=None,
+         prefix_proj=False,
+         max_seq_len=32768,
+         pos_emb_impl="rope",
+         use_emb_factorization=False,
+         empty_init=True,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.emb_dim = emb_dim
+         self.hidden_size = hidden_size
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.use_multi_query_attn = use_multi_query_attn
+         self.num_heads_per_kv = num_heads_per_kv
+         self.qkv_bias = qkv_bias
+         self.use_native_attn_impl = use_native_attn_impl
+         self.mlp_activation = mlp_activation
+         self.hidden_dropout = hidden_dropout
+         self.ffn_hidden_size = ffn_hidden_size
+         self.prefix_size = prefix_size
+         self.prefix_proj = prefix_proj
+         self.max_seq_len = max_seq_len
+         self.pos_emb_impl = pos_emb_impl
+         self.use_emb_factorization = use_emb_factorization
+         self.empty_init = empty_init
+         super().__init__(**kwargs)
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.30.2"
+ }
modeling_batgpt.py ADDED
@@ -0,0 +1,1166 @@
1
+ # This code serves as a port of the models described in BatGPT.
2
+ # It is based on the bloom codebase, which provides the initial framework for our model implementation.
3
+ # To understand how to use these models, please refer to the documentation and usage instructions provided in the bloom models repository.
4
+ # Additionally, we draw inspiration from the ChatGLM and Baichuan codebase, which includes implementations for prefix encoder, chat, and stream_chat functionalities. These components are utilized in our ported models.
5
+ # Feel free to explore the ChatGLM and Baichuan codebase for further insights on how these components can be utilized effectively.
6
+
7
+ import math
8
+ import warnings
9
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
10
+
11
+ import torch
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import skip_init
17
+
18
+ import copy
19
+ import re
20
+ import sys
21
+
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+ from transformers.generation.logits_process import LogitsProcessor
29
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
30
+
31
+ from .configuration_batgpt import BatGPTConfig
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ # flags required to enable jit fusion kernels
37
+
38
+ if sys.platform != 'darwin':
39
+ torch._C._jit_set_profiling_mode(False)
40
+ torch._C._jit_set_profiling_executor(False)
41
+ torch._C._jit_override_can_fuse_on_cpu(True)
42
+ torch._C._jit_override_can_fuse_on_gpu(True)
43
+
44
+
45
+ # For faster LLM model initialization
46
+ def module_init(cls, empty_init, *args, **kwargs):
47
+ if empty_init:
48
+ return skip_init(cls, *args, **kwargs)
49
+ else:
50
+ return cls(*args, **kwargs)
51
+
52
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
53
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
54
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
55
+ scores.zero_()
56
+ scores[..., 5] = 5e4
57
+ return scores
58
+
59
+
60
+ class PrefixEncoder(torch.nn.Module):
61
+ """
62
+ The torch.nn model to encode the prefix
63
+ Input shape: (batch-size, prefix-length)
64
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
65
+ """
66
+
67
+ def __init__(self, config: BatGPTConfig):
68
+ super().__init__()
69
+ self.prefix_proj = config.prefix_proj
70
+ self.head_dim = config.hidden_size // config.n_head
71
+ if self.prefix_proj:
72
+ # Use a two-layer MLP to encode the prefix
73
+ kv_size = config.n_layer * self.head_dim * config.num_heads_per_kv * 2
74
+ self.embedding = torch.nn.Embedding(config.prefix_size, kv_size)
75
+ self.trans = torch.nn.Sequential(
76
+ torch.nn.Linear(kv_size, config.hidden_size),
77
+ torch.nn.Tanh(),
78
+ torch.nn.Linear(config.hidden_size, kv_size)
79
+ )
80
+ else:
81
+ self.embedding = torch.nn.Embedding(config.prefix_size,
82
+ config.n_layer * self.head_dim * config.num_heads_per_kv * 2)
83
+
84
+ def forward(self, prefix: torch.Tensor):
85
+ if self.prefix_proj:
86
+ prefix_tokens = self.embedding(prefix)
87
+ past_key_values = self.trans(prefix_tokens)
88
+ else:
89
+ past_key_values = self.embedding(prefix)
90
+ return past_key_values
91
+
92
+
93
+ def _get_interleave(n):
94
+ def _get_interleave_power_of_2(n):
95
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
96
+ ratio = start
97
+ return [start * ratio ** i for i in range(n)]
98
+
99
+ if math.log2(n).is_integer():
100
+ return _get_interleave_power_of_2(n)
101
+ else:
102
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
103
+ return _get_interleave_power_of_2(closest_power_of_2) + \
104
+ _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
105
+
106
+ def _fill_with_neg_inf(t):
107
+ """FP16-compatible function that fills a tensor with -inf."""
108
+ return t.float().fill_(float("-inf")).type_as(t)
109
+
110
+ def _gen_alibi_mask(n_head, max_pos):
111
+ """used in inference only"""
112
+ slopes = torch.Tensor(_get_interleave(n_head))
113
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
114
+ n_head, -1, -1)
115
+ alibi = alibi.view(n_head, 1, max_pos)
116
+ alibi_mask = torch.triu(
117
+ _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
118
+ )
119
+ alibi_mask = alibi_mask.unsqueeze(0) + alibi
120
+ return alibi_mask
121
+
122
+ def _build_position_ids(input_ids, device):
123
+ batch_size, seq_length = input_ids.shape
124
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
125
+ return position_ids
126
+
127
+ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
128
+ """used in training only"""
129
+ dim = tensor.size(0)
130
+ _future_mask = torch.triu(
131
+ _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
132
+ )
133
+ _future_mask = _future_mask.unsqueeze(0) + alibi
134
+ _future_mask = _future_mask.to(tensor)
135
+ return _future_mask[:tensor.shape[1] * attn_heads, :maxpos, :maxpos]
136
+
137
+ @torch.jit.script
138
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
139
+ # x: [sq, b, np, hn]
140
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
141
+ rot_dim = rope_cache.shape[-2] * 2
142
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
143
+ # truncate to support variable sizes
144
+ rope_cache = rope_cache[:sq]
145
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
146
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
147
+ x_out2 = torch.stack(
148
+ [
149
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
150
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
151
+ ],
152
+ -1,
153
+ )
154
+ x_out2 = x_out2.flatten(3)
155
+ return torch.cat((x_out2, x_pass), dim=-1)
156
+
157
+
158
+
159
+
160
+
161
+ class RMSNorm(torch.nn.Module):
162
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
163
+ super().__init__()
164
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
165
+ self.eps = eps
166
+
167
+ def forward(self, hidden_states: torch.Tensor):
168
+ input_dtype = hidden_states.dtype
169
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
170
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
171
+
172
+ return (self.weight * hidden_states).to(input_dtype)
173
+
174
+
175
+ class SelfAttention(torch.nn.Module):
176
+ def __init__(self, config: BatGPTConfig, device=None):
177
+ super(SelfAttention, self).__init__()
178
+
179
+ self.num_heads = config.n_head
180
+ self.use_multi_query_attn = config.use_multi_query_attn
181
+ self.num_heads_per_kv = config.num_heads_per_kv
182
+ self.qkv_bias = config.qkv_bias
183
+ self.use_native_attn_impl = config.use_native_attn_impl
184
+ if not self.use_multi_query_attn:
185
+ assert self.num_heads_per_kv == self.num_heads, "num_heads_per_kv must equal to num_heads when not use_multi_query_attn"
186
+
187
+ self.head_dim = config.hidden_size // config.n_head
188
+
189
+ self.query_proj = nn.Linear(
190
+ config.hidden_size, config.hidden_size, bias=self.qkv_bias,
191
+ device=device, **_config_to_kwargs(config)
192
+ )
193
+
194
+ self.key_proj = nn.Linear(
195
+ config.hidden_size, self.head_dim * self.num_heads_per_kv, bias=self.qkv_bias,
196
+ device=device, **_config_to_kwargs(config)
197
+ )
198
+ self.value_proj = nn.Linear(
199
+ config.hidden_size, self.head_dim * self.num_heads_per_kv, bias=self.qkv_bias,
200
+ device=device, **_config_to_kwargs(config)
201
+ )
202
+
203
+ # Output.
204
+ self.dense = nn.Linear(
205
+ config.hidden_size, config.hidden_size, bias=False,
206
+ device=device, **_config_to_kwargs(config)
207
+ )
208
+
209
+ def forward(
210
+ self,
211
+ hidden_states,
212
+ attention_mask,
213
+ rotary_pos_emb,
214
+ kv_cache=None,
215
+ use_cache=True
216
+ ):
217
+ # 1. query/key/value mapping
218
+ # hidden_states: [seq_len, batch_size, hidden_size]
219
+ seq_len, batch_size, hidden_size = hidden_states.shape
220
+ query_layer = self.query_proj(hidden_states)
221
+ key_layer = self.key_proj(hidden_states)
222
+ value_layer = self.value_proj(hidden_states)
223
+
224
+ query_layer = query_layer.view(seq_len, batch_size, self.num_heads, self.head_dim)
225
+
226
+ key_layer = key_layer.view(seq_len, batch_size, self.num_heads_per_kv, self.head_dim)
227
+
228
+ value_layer = value_layer.view(seq_len, batch_size, self.num_heads_per_kv, self.head_dim)
229
+
230
+ # 2. apply the rotary position embedding
231
+ if rotary_pos_emb is not None:
232
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
233
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
234
+
235
+ # 3. adjust key and value for inference
236
+ if kv_cache is not None:
237
+ cache_k, cache_v = kv_cache
238
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
239
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
240
+ if use_cache:
241
+ kv_cache = (key_layer, value_layer)
242
+ else:
243
+ kv_cache = None
244
+
245
+ # 4. repeat the key and value for attention
246
+ if self.num_heads_per_kv != self.num_heads:
247
+ key_layer = key_layer.unsqueeze(-2)
248
+ key_layer = key_layer.expand(
249
+ -1, -1, -1, self.num_heads // self.num_heads_per_kv, -1
250
+ )
251
+ key_layer = key_layer.contiguous().view(
252
+ key_layer.size()[:2] + (self.num_heads, self.head_dim)
253
+ )
254
+ value_layer = value_layer.unsqueeze(-2)
255
+ value_layer = value_layer.expand(
256
+ -1, -1, -1, self.num_heads // self.num_heads_per_kv, -1
257
+ )
258
+ value_layer = value_layer.contiguous().view(
259
+ value_layer.size()[:2] + (self.num_heads, self.head_dim)
260
+ )
261
+
262
+ # 5. attention [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
263
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
264
+
265
+ pytorch_version = int(torch.__version__.split('.')[0])
266
+ if self.use_native_attn_impl and pytorch_version >= 2:
267
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
268
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
269
+ is_causal=True)
270
+ else:
271
+ if attention_mask is not None:
272
+ attention_mask = ~attention_mask
273
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
274
+ attention_mask)
275
+ else:
276
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(2, 3)) / math.sqrt(self.head_dim)
277
+
278
+ if attention_mask is not None:
279
+ if seq_len == 1: # inference with cache
280
+ if len(attention_mask.size()) == 4:
281
+ attention_mask = attention_mask[:, :, -1:, :]
282
+ else:
283
+ attention_mask = attention_mask[:, -1:, :]
284
+ attention_scores = attention_scores + attention_mask
285
+ attention_scores = torch.max(attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min))
286
+
287
+ attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
288
+
289
+ context_layer = torch.matmul(attention_probs, value_layer)
290
+
291
+ # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim]
292
+ context_layer = context_layer.permute(2, 0, 1, 3)
293
+
294
+ # [seq_len, batch_size, hidden_size]
295
+ context_layer = context_layer.reshape(seq_len, batch_size, hidden_size)
296
+
297
+ #
298
+ output = self.dense(context_layer)
299
+
300
+ return output, kv_cache
301
+
302
+
303
+ def _config_to_kwargs(args):
304
+ common_kwargs = {
305
+ "dtype": args.torch_dtype,
306
+ }
307
+ return common_kwargs
308
+
309
+
310
+ class MLP(torch.nn.Module):
311
+ def __init__(self, config: BatGPTConfig, device=None):
312
+ super(MLP, self).__init__()
313
+ self.mlp_activation = config.mlp_activation
314
+
315
+ def swiglu(x):
316
+ x = torch.chunk(x, 2, dim=-1)
317
+ return F.silu(x[0]) * x[1]
318
+
319
+ def silu(x):
320
+ return F.silu(x)
321
+
322
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
323
+ if self.mlp_activation == "swiglu":
324
+ self.activation_func = swiglu
325
+
326
+ self.gate_proj = None
327
+
328
+ self.dense_h_to_4h = nn.Linear(
329
+ config.hidden_size,
330
+ config.ffn_hidden_size * 2,
331
+ bias=False,
332
+ device=device,
333
+ **_config_to_kwargs(config)
334
+ )
335
+ elif self.mlp_activation == "silu":
336
+ self.activation_func = silu
337
+
338
+ self.gate_proj = nn.Linear(
339
+ config.hidden_size,
340
+ config.ffn_hidden_size,
341
+ bias=False,
342
+ device=device,
343
+ **_config_to_kwargs(config)
344
+ )
345
+
346
+ self.dense_h_to_4h = nn.Linear(
347
+ config.hidden_size,
348
+ config.ffn_hidden_size,
349
+ bias=False,
350
+ device=device,
351
+ **_config_to_kwargs(config)
352
+ )
353
+ else:
354
+ raise NotImplementedError("mlp_activation {} not supported".format(self.mlp_activation))
355
+
356
+ # Project back to h.
357
+ self.dense_4h_to_h = nn.Linear(
358
+ config.ffn_hidden_size,
359
+ config.hidden_size,
360
+ bias=False,
361
+ device=device,
362
+ **_config_to_kwargs(config)
363
+ )
364
+
365
+ def forward(self, hidden_states):
366
+
367
+ # [s, b, 4hp]
368
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
369
+
370
+ if self.mlp_activation == "swiglu":
371
+ intermediate_parallel = self.activation_func(intermediate_parallel)
372
+ elif self.mlp_activation == "silu":
373
+ gated_weight = self.activation_func(self.gate_proj(hidden_states))
374
+ intermediate_parallel = gated_weight * intermediate_parallel
375
+ else:
376
+ raise NotImplementedError("mlp_activation {} not supported".format(self.mlp_activation))
377
+
378
+ # [s, b, h]
379
+ output = self.dense_4h_to_h(intermediate_parallel)
380
+
381
+ return output
382
+
383
+
384
+ class BatGPTLayer(torch.nn.Module):
385
+ """A single transformer layer.
386
+
387
+ Transformer layer takes input with size [s, b, h] and returns an
388
+ output of the same size.
389
+ """
390
+
391
+ def __init__(self, config: BatGPTConfig, device=None):
392
+ super(BatGPTLayer, self).__init__()
393
+
394
+ # Layernorm on the input data.
395
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device,
396
+ dtype=config.torch_dtype)
397
+
398
+ # Self attention.
399
+ self.self_attention = SelfAttention(config, device=device)
400
+
401
+ self.hidden_dropout = config.hidden_dropout
402
+
403
+ # Layernorm on the attention output
404
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device,
405
+ dtype=config.torch_dtype)
406
+
407
+ # MLP
408
+ self.mlp = MLP(config, device=device)
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states,
413
+ attention_mask,
414
+ rotary_pos_emb,
415
+ kv_cache=None,
416
+ use_cache=True,
417
+ ):
418
+ # hidden_states: [s, b, h]
419
+ residual = hidden_states
420
+
421
+ # Layer norm at the beginning of the transformer layer.
422
+ layernorm_output = self.input_layernorm(hidden_states)
423
+
424
+ # Self attention.
425
+ attention_output, kv_cache = self.self_attention(
426
+ layernorm_output,
427
+ attention_mask,
428
+ rotary_pos_emb,
429
+ kv_cache=kv_cache,
430
+ use_cache=use_cache
431
+ )
432
+
433
+ # Residual connection.
434
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
435
+
436
+ layernorm_input = residual + layernorm_input
437
+
438
+ # Layer norm post the self attention.
439
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
440
+
441
+ # MLP.
442
+ mlp_output = self.mlp(layernorm_output)
443
+
444
+ # Second residual connection.
445
+ residual = layernorm_input
446
+
447
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
448
+
449
+ output = residual + output
450
+
451
+ return output, kv_cache
452
+
453
+
454
+ class BatGPTTransformer(torch.nn.Module):
455
+ """Transformer class."""
456
+
457
+ def __init__(self, config: BatGPTConfig, device=None):
458
+ super(BatGPTTransformer, self).__init__()
459
+
460
+ # Number of layers.
461
+ self.num_layers = config.n_layer
462
+
463
+ # Transformer layers.
464
+ def build_layer():
465
+ return BatGPTLayer(config, device=device)
466
+
467
+ self.layers = torch.nn.ModuleList([build_layer() for i in range(self.num_layers)])
468
+
469
+ # final layer norm before output.
470
+ self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon, device=device,
471
+ dtype=config.torch_dtype)
472
+
473
+ self.gradient_checkpointing = False
474
+
475
+ def _get_layer(self, layer_number):
476
+ return self.layers[layer_number]
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states,
481
+ attention_mask,
482
+ rotary_pos_emb,
483
+ kv_caches=None,
484
+ use_cache: Optional[bool] = True,
485
+ output_hidden_states: Optional[bool] = False,
486
+ ):
487
+ if not kv_caches:
488
+ kv_caches = [None for _ in range(self.num_layers)]
489
+ presents = () if use_cache else None
490
+ if self.gradient_checkpointing and self.training:
491
+ if use_cache:
492
+ logger.warning_once(
493
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
494
+ )
495
+ use_cache = False
496
+
497
+ all_self_attentions = None
498
+ all_hidden_states = () if output_hidden_states else None
499
+ for index in range(self.num_layers):
500
+ if output_hidden_states:
501
+ all_hidden_states = all_hidden_states + (hidden_states,)
502
+
503
+ layer = self._get_layer(index)
504
+ if self.gradient_checkpointing and self.training:
505
+ layer_ret = torch.utils.checkpoint.checkpoint(
506
+ layer,
507
+ hidden_states,
508
+ attention_mask,
509
+ rotary_pos_emb,
510
+ kv_caches[index],
511
+ use_cache
512
+ )
513
+ else:
514
+ layer_ret = layer(
515
+ hidden_states,
516
+ attention_mask,
517
+ rotary_pos_emb,
518
+ kv_cache=kv_caches[index],
519
+ use_cache=use_cache
520
+ )
521
+ hidden_states, kv_cache = layer_ret
522
+ if use_cache:
523
+ presents = presents + (kv_cache,)
524
+
525
+ if output_hidden_states:
526
+ all_hidden_states = all_hidden_states + (hidden_states,)
527
+
528
+ hidden_states = self.ln_f(hidden_states)
529
+
530
+ return hidden_states, presents, all_hidden_states, all_self_attentions
531
+
532
+
533
+ class BatGPTPreTrainedModel(PreTrainedModel):
534
+ """
535
+ An abstract class to handle weights initialization and
536
+ a simple interface for downloading and loading pretrained models.
537
+ """
538
+
539
+ is_parallelizable = False
540
+ supports_gradient_checkpointing = True
541
+ config_class = BatGPTConfig
542
+ base_model_prefix = "transformer"
543
+ _no_split_modules = ["BatGPTLayer"]
544
+
545
+ def _init_weights(self, module: nn.Module):
546
+ """Initialize the weights."""
547
+ return
548
+
549
+
550
+
551
+ def _set_gradient_checkpointing(self, module, value=False):
552
+ if isinstance(module, BatGPTTransformer):
553
+ module.gradient_checkpointing = value
554
+
555
+
556
+
557
+ class BatGPTModel(BatGPTPreTrainedModel):
558
+ def __init__(self, config: BatGPTConfig, device=None):
559
+ super().__init__(config)
560
+
561
+ self.num_layers = config.n_layer
562
+ self.num_heads = config.n_head
563
+ self.head_dim = config.hidden_size // config.n_head
564
+ self.max_seq_len = config.max_seq_len
565
+ self.pos_emb_impl = config.pos_emb_impl
566
+ self.model_cache_seq_len = 1024
567
+
568
+ # word embedding
569
+ self.word_embeddings = module_init(nn.Embedding,
570
+ config.empty_init,
571
+ config.vocab_size,
572
+ config.emb_dim,
573
+ dtype=config.torch_dtype,
574
+ device=device
575
+ )
576
+
577
+ self.emb_fact = None
578
+ if config.use_emb_factorization or config.emb_dim != config.hidden_size:
579
+ self.emb_fact = nn.Linear(config.emb_dim, config.hidden_size, bias=False,
580
+ dtype=config.torch_dtype, device=device)
581
+
582
+ init_kwargs = {}
583
+ if device is not None:
584
+ init_kwargs["device"] = device
585
+
586
+ self.encoder = module_init(BatGPTTransformer, config.empty_init, config, **init_kwargs)
587
+
588
+ self.first_run = True
589
+ self.alibi_mask = None
590
+
591
+ self.prefix_size = config.prefix_size
592
+ self.prefix_proj = config.prefix_proj
593
+ if self.prefix_size is not None:
594
+ for param in self.parameters():
595
+ param.requires_grad = False
596
+ self.prefix_tokens = torch.arange(self.prefix_size).long()
597
+ self.prefix_encoder = PrefixEncoder(config)
598
+ self.dropout = torch.nn.Dropout(0.1)
599
+
600
+ def get_input_embeddings(self):
601
+ return self.word_embeddings
602
+
603
+ def get_prompt(self, batch_size, device, dtype=torch.half):
604
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
605
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
606
+ past_key_values = past_key_values.view(
607
+ batch_size,
608
+ self.prefix_size,
609
+ self.num_layers * 2,
610
+ self.config.num_heads_per_kv,
611
+ self.head_dim
612
+ )
613
+ # seq_len, b, nh, hidden_size
614
+ past_key_values = self.dropout(past_key_values)
615
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
616
+ return past_key_values
617
+
618
+ def get_rotary_tensor(self, seq_len: int, head_dim: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
619
+
620
+ n_elem = head_dim // 2
621
+
622
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
623
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
624
+
625
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
626
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
627
+
628
+ # Calculate the product of position index and $\theta_i$
629
+ idx_theta = torch.outer(seq_idx, theta).float()
630
+
631
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
632
+
633
+ # this is to mimic the behaviour of complex32, else we will get different results
634
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
635
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
636
+
637
+ return cache
638
+
639
+ def get_causal_mask(self, input_ids, past_key_values, attention_mask=None) -> torch.BoolTensor:
640
+
641
+ batch_size, seq_length = input_ids.shape
642
+
643
+ # B x L x L
644
+ causal_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
645
+ causal_mask.tril_()
646
+
647
+ past_length = 0
648
+ if past_key_values:
649
+ past_length = past_key_values[0][0].shape[0]
650
+
651
+ if past_length:
652
+ causal_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
653
+ device=input_ids.device), causal_mask), dim=-1)
654
+
655
+ if attention_mask is not None:
656
+ causal_mask = causal_mask * attention_mask.unsqueeze(1)
657
+
658
+ if not past_length and attention_mask is not None:
659
+ causal_mask -= attention_mask.unsqueeze(-1) - 1
660
+
661
+ causal_mask = (causal_mask < 0.5).bool()
662
+ causal_mask.unsqueeze_(1)
663
+
664
+ return causal_mask
665
+
666
+ def get_alibi_mask(self, tensor, seq_length_with_past):
667
+ if self.training:
668
+ slopes = torch.Tensor(_get_interleave(self.num_heads))
669
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
670
+ self.num_heads,
671
+ -1, -1)
672
+ alibi = alibi.view(self.num_heads, 1, seq_length_with_past)
673
+ mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.num_heads)
674
+ else:
675
+ if self.first_run:
676
+ self.first_run = False
677
+ self.register_buffer("future_mask", _gen_alibi_mask(self.num_heads, self.model_cache_seq_len).to(tensor), persistent=False)
678
+ if seq_length_with_past > self.model_cache_seq_len:
679
+ self.model_cache_seq_len = seq_length_with_past
680
+ self.register_buffer("future_mask", _gen_alibi_mask(self.num_heads, self.model_cache_seq_len).to(tensor), persistent=False)
681
+ mask = self.future_mask[:self.num_heads, :seq_length_with_past, :seq_length_with_past]
682
+ return mask
683
+
684
+
685
+ def forward(
686
+ self,
687
+ input_ids,
688
+ position_ids: Optional[torch.Tensor] = None,
689
+ attention_mask: Optional[torch.BoolTensor] = None,
690
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
691
+ inputs_embeds: Optional[torch.Tensor] = None,
692
+ use_cache: Optional[bool] = None,
693
+ output_hidden_states: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ ):
696
+ output_hidden_states = (
697
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
698
+ )
699
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
700
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
701
+
702
+ batch_size, seq_length = input_ids.shape
703
+
704
+ seq_length_with_past = seq_length
705
+
706
+ # -> word embedding
707
+ if inputs_embeds is None:
708
+ inputs_embeds = self.word_embeddings(input_ids)
709
+ # [b s h] --> [s b h].
710
+ inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
711
+
712
+ if self.prefix_size is not None:
713
+ if past_key_values is None:
714
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
715
+ dtype=inputs_embeds.dtype)
716
+ if attention_mask is not None:
717
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.prefix_size)),
718
+ attention_mask], dim=-1)
719
+
720
+ if past_key_values is not None:
721
+ past_key_values_length = past_key_values[0][0].shape[0]
722
+ seq_length_with_past = seq_length_with_past + past_key_values_length
723
+
724
+
725
+ full_attention_mask = None
726
+ rotary_pos_emb=None
727
+ if self.pos_emb_impl == "alibi":
728
+ if self.training:
729
+ if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
730
+ self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
731
+ alibi_mask = self.alibi_mask
732
+ else:
733
+ alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
734
+
735
+
736
+ if attention_mask is not None:
737
+
738
+ if len(attention_mask.shape) == 2:
739
+ expanded_mask = attention_mask.to(alibi_mask.dtype)
740
+ expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
741
+ ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
742
+ else:
743
+ expanded_mask = attention_mask
744
+ src_len, tgt_len = alibi_mask.size()[-2:]
745
+ expanded_mask = expanded_mask.unsqueeze(1).expand(batch_size, 1, src_len, tgt_len).to(alibi_mask.dtype)
746
+ # Target sizes: [1, 1, 41, 41]. Tensor sizes: [1, 1, 8, 8]
747
+ inverted_mask = 1.0 - expanded_mask
748
+ inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
749
+ full_attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
750
+ else:
751
+ full_attention_mask = alibi_mask
752
+ elif self.pos_emb_impl == "rope":
753
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
754
+ # B x 1 x L x L
755
+ full_attention_mask = self.get_causal_mask(input_ids, past_key_values, attention_mask)
756
+
757
+ # Rotary positional embeddings
758
+ rotary_pos_emb = self.get_rotary_tensor(self.max_seq_len, self.head_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device)
759
+ if position_ids is not None:
760
+ rotary_pos_emb = rotary_pos_emb[position_ids]
761
+ else:
762
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
763
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
764
+ else:
765
+ raise NotImplementedError("position embedding type: {} not supported!".format(self.pos_emb_impl))
766
+
767
+
768
+ # Run encoder.
769
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
770
+ inputs_embeds,
771
+ full_attention_mask,
772
+ rotary_pos_emb=rotary_pos_emb,
773
+ kv_caches=past_key_values,
774
+ use_cache=use_cache,
775
+ output_hidden_states=output_hidden_states
776
+ )
777
+
778
+ if not return_dict:
779
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
780
+
781
+ return BaseModelOutputWithPast(
782
+ last_hidden_state=hidden_states,
783
+ past_key_values=presents,
784
+ hidden_states=all_hidden_states,
785
+ attentions=all_self_attentions,
786
+ )
787
+
788
+
789
+ class BatGPTForCausalLM(BatGPTPreTrainedModel):
790
+ def __init__(self, config: BatGPTConfig, device=None):
791
+ super().__init__(config)
792
+
793
+ self.max_sequence_length = config.max_length
794
+
795
+ self.model = BatGPTModel(config, device=device)
796
+
797
+ self.lm_head = module_init(nn.Linear, config.empty_init, config.hidden_size, config.vocab_size, bias=False,
798
+ dtype=config.torch_dtype, device=device)
799
+
800
+ self.config = config
801
+
802
+ def get_input_embeddings(self):
803
+ return self.model.get_input_embeddings()
804
+
805
+ def _update_model_kwargs_for_generation(
806
+ self,
807
+ outputs: ModelOutput,
808
+ model_kwargs: Dict[str, Any],
809
+ is_encoder_decoder: bool = False,
810
+ standardize_cache_format: bool = False,
811
+ ) -> Dict[str, Any]:
812
+ # update past_key_values
813
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
814
+ outputs, standardize_cache_format=standardize_cache_format
815
+ )
816
+
817
+ # update attention mask
818
+ if "attention_mask" in model_kwargs:
819
+ attention_mask = model_kwargs["attention_mask"]
820
+ model_kwargs["attention_mask"] = torch.cat(
821
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
822
+ )
823
+
824
+ # update position ids
825
+ if "position_ids" in model_kwargs:
826
+ position_ids = model_kwargs["position_ids"]
827
+ new_position_id = position_ids[..., -1:].clone()
828
+ new_position_id += 1
829
+ model_kwargs["position_ids"] = torch.cat(
830
+ [position_ids, new_position_id], dim=-1
831
+ )
832
+
833
+ model_kwargs["is_first_forward"] = False
834
+ return model_kwargs
835
+
836
+ def prepare_inputs_for_generation(
837
+ self,
838
+ input_ids: torch.LongTensor,
839
+ past_key_values: Optional[torch.Tensor] = None,
840
+ attention_mask: Optional[torch.Tensor] = None,
841
+ position_ids: Optional[torch.Tensor] = None,
842
+ is_first_forward: bool = True,
843
+ **kwargs
844
+ ) -> dict:
845
+
846
+ # only last token for input_ids if past is not None
847
+ if position_ids is None:
848
+ position_ids = _build_position_ids(input_ids, device=input_ids.device)
849
+
850
+ if not is_first_forward:
851
+ position_ids = position_ids[..., -1:]
852
+ input_ids = input_ids[:, -1:]
853
+
854
+ return {
855
+ "input_ids": input_ids,
856
+ "past_key_values": past_key_values,
857
+ "position_ids": position_ids,
858
+ "attention_mask": attention_mask,
859
+ "return_last_logit": True
860
+ }
861
+
862
+ def forward(
863
+ self,
864
+ input_ids: Optional[torch.Tensor] = None,
865
+ position_ids: Optional[torch.Tensor] = None,
866
+ attention_mask: Optional[torch.Tensor] = None,
867
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
868
+ inputs_embeds: Optional[torch.Tensor] = None,
869
+ labels: Optional[torch.Tensor] = None,
870
+ use_cache: Optional[bool] = None,
871
+ output_attentions: Optional[bool] = None,
872
+ output_hidden_states: Optional[bool] = None,
873
+ return_dict: Optional[bool] = None,
874
+ return_last_logit: Optional[bool] = False,
875
+ ):
876
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
877
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
878
+
879
+ encodings = self.model(
880
+ input_ids=input_ids,
881
+ position_ids=position_ids,
882
+ attention_mask=attention_mask,
883
+ past_key_values=past_key_values,
884
+ inputs_embeds=inputs_embeds,
885
+ use_cache=use_cache,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ )
889
+
890
+ hidden_states = encodings[0]
891
+ if return_last_logit:
892
+ hidden_states = hidden_states[-1:]
893
+
894
+ lm_logits = self.lm_head(hidden_states)
895
+
896
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
897
+
898
+ loss = None
899
+ if labels is not None:
900
+ lm_logits = lm_logits.to(torch.float32)
901
+
902
+ # Shift so that tokens < n predict n
903
+ shift_logits = lm_logits[..., :-1, :].contiguous()
904
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
905
+ # Flatten the tokens
906
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
907
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
908
+
909
+ lm_logits = lm_logits.to(hidden_states.dtype)
910
+ loss = loss.to(hidden_states.dtype)
911
+
912
+ if not return_dict:
913
+ output = (lm_logits,) + encodings[1:]
914
+ return ((loss,) + output) if loss is not None else output
915
+
916
+ return CausalLMOutputWithPast(
917
+ loss=loss,
918
+ logits=lm_logits,
919
+ past_key_values=encodings.past_key_values,
920
+ hidden_states=encodings.hidden_states,
921
+ attentions=encodings.attentions,
922
+ )
923
+
924
+ @staticmethod
925
+ def _reorder_cache(
926
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
927
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
928
+ """
929
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
930
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
931
+ beam_idx at every generation step.
932
+
933
+ Output shares the same memory storage as `past`.
934
+ """
935
+ return tuple(
936
+ (
937
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
938
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
939
+ )
940
+ for layer_past in past
941
+ )
942
+
943
+
944
+ def quantize(self, bits: int):
945
+ try:
946
+ # from .quantizer import QLinear
947
+ from quantizer import QLinear
948
+ except ImportError:
949
+ raise ImportError(
950
+ f"Needs QLinear to run quantize."
951
+ )
952
+
953
+ for layer in self.model.encoder.layers:
954
+ layer.self_attention.query_proj = QLinear(
955
+ bits=bits,
956
+ weight=layer.self_attention.query_proj.weight,
957
+ bias = layer.self_attention.query_proj.bias if self.config.qkv_bias else None,
958
+ )
959
+ layer.self_attention.key_proj = QLinear(
960
+ bits=bits,
961
+ weight=layer.self_attention.key_proj.weight,
962
+ bias = layer.self_attention.key_proj.bias if self.config.qkv_bias else None,
963
+ )
964
+ layer.self_attention.value_proj = QLinear(
965
+ bits=bits,
966
+ weight=layer.self_attention.value_proj.weight,
967
+ bias = layer.self_attention.value_proj.bias if self.config.qkv_bias else None,
968
+ )
969
+ layer.self_attention.dense = QLinear(
970
+ bits=bits,
971
+ weight=layer.self_attention.dense.weight,
972
+ bias = None,
973
+ )
974
+ layer.mlp.dense_h_to_4h = QLinear(
975
+ bits=bits,
976
+ weight=layer.mlp.dense_h_to_4h.weight,
977
+ bias = None,
978
+ )
979
+ layer.mlp.dense_4h_to_h = QLinear(
980
+ bits=bits,
981
+ weight=layer.mlp.dense_4h_to_h.weight,
982
+ bias = None,
983
+ )
984
+ if self.config.mlp_activation == "silu":
985
+ layer.mlp.gate_proj = QLinear(
986
+ bits=bits,
987
+ weight=layer.mlp.gate_proj.weight,
988
+ bias = None,
989
+ )
990
+ return self
991
+
992
+
993
+ def process_response(self, response):
994
+ response = response.strip()
995
+ return response
996
+
997
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt = None):
998
+ inputs = tokenizer.build_inputs(query, history=history, system_prompt=system_prompt)
999
+ inputs = inputs.to(self.device)
1000
+ return inputs
1001
+
1002
+ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt = None):
1003
+ inputs = tokenizer.build_stream_inputs(query, history=history, system_prompt=system_prompt)
1004
+ inputs = inputs.to(self.device)
1005
+ return inputs
1006
+
1007
+ @torch.no_grad()
1008
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt=None, max_length: int = 8192, num_beams=1,
1009
+ do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
1010
+ if history is None:
1011
+ history = []
1012
+ if logits_processor is None:
1013
+ logits_processor = LogitsProcessorList()
1014
+ logits_processor.append(InvalidScoreLogitsProcessor())
1015
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1016
+ "temperature": temperature, **kwargs} #, "logits_processor": logits_processor
1017
+ inputs = self.build_inputs(tokenizer, query, history=history, system_prompt=system_prompt)
1018
+ outputs = self.generate(**inputs, **gen_kwargs)
1019
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1020
+ response = tokenizer.decode(outputs, skip_special_tokens=True) #
1021
+ response = self.process_response(response)
1022
+ history = history + [(query, response)]
1023
+ return response, history
1024
+
1025
+ @torch.no_grad()
1026
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, system_prompt=None, past_key_values=None,
1027
+ max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1028
+ return_past_key_values=False, **kwargs):
1029
+ if history is None:
1030
+ history = []
1031
+ if logits_processor is None:
1032
+ logits_processor = LogitsProcessorList()
1033
+ logits_processor.append(InvalidScoreLogitsProcessor())
1034
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1035
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1036
+ if past_key_values is None and not return_past_key_values:
1037
+ inputs = self.build_inputs(tokenizer, query, history=history, system_prompt=system_prompt)
1038
+ else:
1039
+ inputs = self.build_stream_inputs(tokenizer, query, history=history, system_prompt=system_prompt)
1040
+ if past_key_values is not None:
1041
+ past_length = past_key_values[0][0].shape[0]
1042
+ if self.model.prefix_size is not None:
1043
+ past_length -= self.model.prefix_size
1044
+ inputs.position_ids += past_length
1045
+ attention_mask = inputs.attention_mask
1046
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1047
+ inputs['attention_mask'] = attention_mask
1048
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1049
+ return_past_key_values=return_past_key_values, **gen_kwargs):
1050
+ if return_past_key_values:
1051
+                 outputs, past_key_values = outputs
+             outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+             response = tokenizer.decode(outputs)
+             if response and response[-1] != "�":
+                 response = self.process_response(response)
+                 new_history = history + [(query, response)]
+                 if return_past_key_values:
+                     yield response, new_history, past_key_values
+                 else:
+                     yield response, new_history
+ 
+     @torch.no_grad()
+     def stream_generate(
+             self,
+             input_ids,
+             generation_config: Optional[GenerationConfig] = None,
+             logits_processor: Optional[LogitsProcessorList] = None,
+             stopping_criteria: Optional[StoppingCriteriaList] = None,
+             prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+             return_past_key_values=False,
+             **kwargs,
+     ):
+         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+ 
+         if generation_config is None:
+             generation_config = self.generation_config
+         generation_config = copy.deepcopy(generation_config)
+         model_kwargs = generation_config.update(**kwargs)
+         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+ 
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+ 
+         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+         if has_default_max_length and generation_config.max_new_tokens is None:
+             warnings.warn(
+                 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                 "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                 UserWarning,
+             )
+         elif generation_config.max_new_tokens is not None:
+             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+             if not has_default_max_length:
+                 logger.warning(
+                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                     "Please refer to the documentation for more information. "
+                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                 )
+ 
+         if input_ids_seq_length >= generation_config.max_length:
+             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+             logger.warning(
+                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                 " increasing `max_new_tokens`."
+             )
+ 
+         # 2. Set generation parameters if not already defined
+         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+ 
+         logits_processor = self._get_logits_processor(
+             generation_config=generation_config,
+             input_ids_seq_length=input_ids_seq_length,
+             encoder_input_ids=input_ids,
+             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+             logits_processor=logits_processor,
+         )
+ 
+         stopping_criteria = self._get_stopping_criteria(
+             generation_config=generation_config, stopping_criteria=stopping_criteria
+         )
+         logits_warper = self._get_logits_warper(generation_config)
+ 
+         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+         scores = None
+         while True:
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             # forward pass to get the next-token logits
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=False,
+                 output_hidden_states=False,
+             )
+ 
+             next_token_logits = outputs.logits[:, -1, :]
+ 
+             # pre-process the distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+             next_token_scores = logits_warper(input_ids, next_token_scores)
+ 
+             # sample (or take the argmax when do_sample is False)
+             probs = nn.functional.softmax(next_token_scores, dim=-1)
+             if generation_config.do_sample:
+                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+             else:
+                 next_tokens = torch.argmax(probs, dim=-1)
+ 
+             # update generated ids, model inputs, and length for the next step
+             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+             unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+             if return_past_key_values:
+                 yield input_ids, outputs.past_key_values
+             else:
+                 yield input_ids
+             # stop when every sequence is finished, or when the stopping criteria (e.g. max length) are met
+             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                 break
+ 
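For orientation, a minimal usage sketch of the streaming interface added above. This is a sketch under assumptions, not the official API of this commit: the repository id `MLP-lab/BatGPT-15B-sirius`, the `Auto*` loading classes, and a ChatGLM-style `stream_chat(tokenizer, query, history=...)` wrapper around `stream_generate` are inferred from the yields above rather than confirmed here.

```python
# Hedged sketch: repo id, Auto classes and the stream_chat signature are assumptions
# inferred from the streaming code above, not confirmed by this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "MLP-lab/BatGPT-15B-sirius", trust_remote_code=True
).half().cuda().eval()

history = []
for response, history in model.stream_chat(tokenizer, "Hello, please introduce yourself.", history=history):
    # Each iteration yields the partial response decoded so far plus the updated history.
    print(response)
```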
pytorch_model-00001-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f0c35d490a8038ff1798cba9649679f94892be8e49b444bcc3d737aa3fc0741
+ size 9803365668
pytorch_model-00002-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:462f69b6b67955a7427fc1d5c50ec9f637e2ab41bece82b52d14d744137ba9db
+ size 9990811432
pytorch_model-00003-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12eef94f37b4431662a71f0b680c9c767ee347c98324354c273281e1e02f54e4
+ size 9527983699
pytorch_model-00004-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db34e02ed0155656c0b998fb6d4e56914a1bcec17b0ccc580931b4019cc65b9c
+ size 738198442
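The four `.bin` entries above are Git LFS pointer files: the repository stores only the object hash and byte size, and the actual weight shards (roughly 30 GB in total) are resolved when the repository is downloaded. A minimal sketch of fetching them programmatically, assuming the repository id `MLP-lab/BatGPT-15B-sirius`:

```python
# Hedged sketch: resolves the LFS pointers above into the real weight shards.
# The repo id is an assumption.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="MLP-lab/BatGPT-15B-sirius")
print(local_dir)  # local directory containing the four pytorch_model-*.bin shards
```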
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,538 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 30060162048
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00004-of-00004.bin",
7
+ "model.encoder.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
8
+ "model.encoder.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
9
+ "model.encoder.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
10
+ "model.encoder.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
11
+ "model.encoder.layers.0.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
12
+ "model.encoder.layers.0.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
13
+ "model.encoder.layers.0.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
14
+ "model.encoder.layers.0.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
15
+ "model.encoder.layers.0.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
16
+ "model.encoder.layers.0.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
17
+ "model.encoder.layers.0.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
18
+ "model.encoder.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
19
+ "model.encoder.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
20
+ "model.encoder.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
21
+ "model.encoder.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
22
+ "model.encoder.layers.1.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
23
+ "model.encoder.layers.1.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
24
+ "model.encoder.layers.1.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
25
+ "model.encoder.layers.1.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
26
+ "model.encoder.layers.1.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
27
+ "model.encoder.layers.1.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
28
+ "model.encoder.layers.1.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
29
+ "model.encoder.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
30
+ "model.encoder.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
31
+ "model.encoder.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
32
+ "model.encoder.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
33
+ "model.encoder.layers.10.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
34
+ "model.encoder.layers.10.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
35
+ "model.encoder.layers.10.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
36
+ "model.encoder.layers.10.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
37
+ "model.encoder.layers.10.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
38
+ "model.encoder.layers.10.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
39
+ "model.encoder.layers.10.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
40
+ "model.encoder.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
41
+ "model.encoder.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
42
+ "model.encoder.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
43
+ "model.encoder.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
44
+ "model.encoder.layers.11.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
45
+ "model.encoder.layers.11.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
46
+ "model.encoder.layers.11.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
47
+ "model.encoder.layers.11.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
48
+ "model.encoder.layers.11.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
49
+ "model.encoder.layers.11.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
50
+ "model.encoder.layers.11.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
51
+ "model.encoder.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
52
+ "model.encoder.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
53
+ "model.encoder.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
54
+ "model.encoder.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
55
+ "model.encoder.layers.12.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
56
+ "model.encoder.layers.12.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
57
+ "model.encoder.layers.12.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
58
+ "model.encoder.layers.12.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
59
+ "model.encoder.layers.12.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
60
+ "model.encoder.layers.12.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
61
+ "model.encoder.layers.12.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
62
+ "model.encoder.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
63
+ "model.encoder.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
64
+ "model.encoder.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
65
+ "model.encoder.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
66
+ "model.encoder.layers.13.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
67
+ "model.encoder.layers.13.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
68
+ "model.encoder.layers.13.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
69
+ "model.encoder.layers.13.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
70
+ "model.encoder.layers.13.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
71
+ "model.encoder.layers.13.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
72
+ "model.encoder.layers.13.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
73
+ "model.encoder.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
74
+ "model.encoder.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
75
+ "model.encoder.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
76
+ "model.encoder.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
77
+ "model.encoder.layers.14.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
78
+ "model.encoder.layers.14.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
79
+ "model.encoder.layers.14.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
80
+ "model.encoder.layers.14.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
81
+ "model.encoder.layers.14.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
82
+ "model.encoder.layers.14.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
83
+ "model.encoder.layers.14.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
84
+ "model.encoder.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
85
+ "model.encoder.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
86
+ "model.encoder.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
87
+ "model.encoder.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
88
+ "model.encoder.layers.15.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
89
+ "model.encoder.layers.15.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
90
+ "model.encoder.layers.15.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
91
+ "model.encoder.layers.15.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
92
+ "model.encoder.layers.15.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
93
+ "model.encoder.layers.15.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
94
+ "model.encoder.layers.15.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
95
+ "model.encoder.layers.16.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
96
+ "model.encoder.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
97
+ "model.encoder.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
98
+ "model.encoder.layers.16.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
99
+ "model.encoder.layers.16.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
100
+ "model.encoder.layers.16.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
101
+ "model.encoder.layers.16.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
102
+ "model.encoder.layers.16.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
103
+ "model.encoder.layers.16.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
104
+ "model.encoder.layers.16.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
105
+ "model.encoder.layers.16.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
106
+ "model.encoder.layers.17.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
107
+ "model.encoder.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
108
+ "model.encoder.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
109
+ "model.encoder.layers.17.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
110
+ "model.encoder.layers.17.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
111
+ "model.encoder.layers.17.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
112
+ "model.encoder.layers.17.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
113
+ "model.encoder.layers.17.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
114
+ "model.encoder.layers.17.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
115
+ "model.encoder.layers.17.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
116
+ "model.encoder.layers.17.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
117
+ "model.encoder.layers.18.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
118
+ "model.encoder.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
119
+ "model.encoder.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
120
+ "model.encoder.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
121
+ "model.encoder.layers.18.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
122
+ "model.encoder.layers.18.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
123
+ "model.encoder.layers.18.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
124
+ "model.encoder.layers.18.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
125
+ "model.encoder.layers.18.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
126
+ "model.encoder.layers.18.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
127
+ "model.encoder.layers.18.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
128
+ "model.encoder.layers.19.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
129
+ "model.encoder.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
130
+ "model.encoder.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
131
+ "model.encoder.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
132
+ "model.encoder.layers.19.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
133
+ "model.encoder.layers.19.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
134
+ "model.encoder.layers.19.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
135
+ "model.encoder.layers.19.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
136
+ "model.encoder.layers.19.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
137
+ "model.encoder.layers.19.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
138
+ "model.encoder.layers.19.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
139
+ "model.encoder.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
140
+ "model.encoder.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
141
+ "model.encoder.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
142
+ "model.encoder.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
143
+ "model.encoder.layers.2.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
144
+ "model.encoder.layers.2.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
145
+ "model.encoder.layers.2.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
146
+ "model.encoder.layers.2.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
147
+ "model.encoder.layers.2.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
148
+ "model.encoder.layers.2.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
149
+ "model.encoder.layers.2.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
150
+ "model.encoder.layers.20.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
151
+ "model.encoder.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
152
+ "model.encoder.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
153
+ "model.encoder.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
154
+ "model.encoder.layers.20.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
155
+ "model.encoder.layers.20.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
156
+ "model.encoder.layers.20.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
157
+ "model.encoder.layers.20.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
158
+ "model.encoder.layers.20.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
159
+ "model.encoder.layers.20.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
160
+ "model.encoder.layers.20.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
161
+ "model.encoder.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
162
+ "model.encoder.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
163
+ "model.encoder.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
164
+ "model.encoder.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
165
+ "model.encoder.layers.21.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
166
+ "model.encoder.layers.21.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
167
+ "model.encoder.layers.21.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
168
+ "model.encoder.layers.21.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
169
+ "model.encoder.layers.21.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
170
+ "model.encoder.layers.21.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
171
+ "model.encoder.layers.21.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
172
+ "model.encoder.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
173
+ "model.encoder.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
174
+ "model.encoder.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
175
+ "model.encoder.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
176
+ "model.encoder.layers.22.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
177
+ "model.encoder.layers.22.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
178
+ "model.encoder.layers.22.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
179
+ "model.encoder.layers.22.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
180
+ "model.encoder.layers.22.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
181
+ "model.encoder.layers.22.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
182
+ "model.encoder.layers.22.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
183
+ "model.encoder.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
184
+ "model.encoder.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
185
+ "model.encoder.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
186
+ "model.encoder.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
187
+ "model.encoder.layers.23.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
188
+ "model.encoder.layers.23.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
189
+ "model.encoder.layers.23.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
190
+ "model.encoder.layers.23.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
191
+ "model.encoder.layers.23.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
192
+ "model.encoder.layers.23.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
193
+ "model.encoder.layers.23.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
194
+ "model.encoder.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
195
+ "model.encoder.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
196
+ "model.encoder.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
197
+ "model.encoder.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
198
+ "model.encoder.layers.24.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
199
+ "model.encoder.layers.24.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
200
+ "model.encoder.layers.24.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
201
+ "model.encoder.layers.24.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
202
+ "model.encoder.layers.24.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
203
+ "model.encoder.layers.24.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
204
+ "model.encoder.layers.24.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
205
+ "model.encoder.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
206
+ "model.encoder.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
207
+ "model.encoder.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
208
+ "model.encoder.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
209
+ "model.encoder.layers.25.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
210
+ "model.encoder.layers.25.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
211
+ "model.encoder.layers.25.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
212
+ "model.encoder.layers.25.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
213
+ "model.encoder.layers.25.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
214
+ "model.encoder.layers.25.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
215
+ "model.encoder.layers.25.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
216
+ "model.encoder.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
217
+ "model.encoder.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
218
+ "model.encoder.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
219
+ "model.encoder.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
220
+ "model.encoder.layers.26.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
221
+ "model.encoder.layers.26.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
222
+ "model.encoder.layers.26.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
223
+ "model.encoder.layers.26.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
224
+ "model.encoder.layers.26.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
225
+ "model.encoder.layers.26.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
226
+ "model.encoder.layers.26.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
227
+ "model.encoder.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
228
+ "model.encoder.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
229
+ "model.encoder.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
230
+ "model.encoder.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
231
+ "model.encoder.layers.27.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
232
+ "model.encoder.layers.27.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
233
+ "model.encoder.layers.27.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
234
+ "model.encoder.layers.27.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
235
+ "model.encoder.layers.27.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
236
+ "model.encoder.layers.27.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
237
+ "model.encoder.layers.27.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
238
+ "model.encoder.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
239
+ "model.encoder.layers.28.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
240
+ "model.encoder.layers.28.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
241
+ "model.encoder.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
242
+ "model.encoder.layers.28.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
243
+ "model.encoder.layers.28.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
244
+ "model.encoder.layers.28.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
245
+ "model.encoder.layers.28.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
246
+ "model.encoder.layers.28.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
247
+ "model.encoder.layers.28.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
248
+ "model.encoder.layers.28.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
249
+ "model.encoder.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
250
+ "model.encoder.layers.29.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
251
+ "model.encoder.layers.29.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
252
+ "model.encoder.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
253
+ "model.encoder.layers.29.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
254
+ "model.encoder.layers.29.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
255
+ "model.encoder.layers.29.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
256
+ "model.encoder.layers.29.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
257
+ "model.encoder.layers.29.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
258
+ "model.encoder.layers.29.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
259
+ "model.encoder.layers.29.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
260
+ "model.encoder.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
261
+ "model.encoder.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
262
+ "model.encoder.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
263
+ "model.encoder.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
264
+ "model.encoder.layers.3.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
265
+ "model.encoder.layers.3.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
266
+ "model.encoder.layers.3.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
267
+ "model.encoder.layers.3.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
268
+ "model.encoder.layers.3.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
269
+ "model.encoder.layers.3.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
270
+ "model.encoder.layers.3.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
271
+ "model.encoder.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
272
+ "model.encoder.layers.30.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
273
+ "model.encoder.layers.30.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
274
+ "model.encoder.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
275
+ "model.encoder.layers.30.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
276
+ "model.encoder.layers.30.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
277
+ "model.encoder.layers.30.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
278
+ "model.encoder.layers.30.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
279
+ "model.encoder.layers.30.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
280
+ "model.encoder.layers.30.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
281
+ "model.encoder.layers.30.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
282
+ "model.encoder.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
283
+ "model.encoder.layers.31.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00004.bin",
284
+ "model.encoder.layers.31.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00004.bin",
285
+ "model.encoder.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
286
+ "model.encoder.layers.31.self_attention.dense.weight": "pytorch_model-00002-of-00004.bin",
287
+ "model.encoder.layers.31.self_attention.key_proj.bias": "pytorch_model-00002-of-00004.bin",
288
+ "model.encoder.layers.31.self_attention.key_proj.weight": "pytorch_model-00002-of-00004.bin",
289
+ "model.encoder.layers.31.self_attention.query_proj.bias": "pytorch_model-00002-of-00004.bin",
290
+ "model.encoder.layers.31.self_attention.query_proj.weight": "pytorch_model-00002-of-00004.bin",
291
+ "model.encoder.layers.31.self_attention.value_proj.bias": "pytorch_model-00002-of-00004.bin",
292
+ "model.encoder.layers.31.self_attention.value_proj.weight": "pytorch_model-00002-of-00004.bin",
293
+ "model.encoder.layers.32.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
294
+ "model.encoder.layers.32.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
295
+ "model.encoder.layers.32.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
296
+ "model.encoder.layers.32.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
297
+ "model.encoder.layers.32.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
298
+ "model.encoder.layers.32.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
299
+ "model.encoder.layers.32.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
300
+ "model.encoder.layers.32.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
301
+ "model.encoder.layers.32.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
302
+ "model.encoder.layers.32.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
303
+ "model.encoder.layers.32.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
304
+ "model.encoder.layers.33.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
305
+ "model.encoder.layers.33.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
306
+ "model.encoder.layers.33.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
307
+ "model.encoder.layers.33.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
308
+ "model.encoder.layers.33.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
309
+ "model.encoder.layers.33.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
310
+ "model.encoder.layers.33.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
311
+ "model.encoder.layers.33.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
312
+ "model.encoder.layers.33.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
313
+ "model.encoder.layers.33.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
314
+ "model.encoder.layers.33.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
315
+ "model.encoder.layers.34.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
316
+ "model.encoder.layers.34.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
317
+ "model.encoder.layers.34.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
318
+ "model.encoder.layers.34.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
319
+ "model.encoder.layers.34.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
320
+ "model.encoder.layers.34.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
321
+ "model.encoder.layers.34.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
322
+ "model.encoder.layers.34.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
323
+ "model.encoder.layers.34.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
324
+ "model.encoder.layers.34.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
325
+ "model.encoder.layers.34.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
326
+ "model.encoder.layers.35.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
327
+ "model.encoder.layers.35.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
328
+ "model.encoder.layers.35.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
329
+ "model.encoder.layers.35.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
330
+ "model.encoder.layers.35.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
331
+ "model.encoder.layers.35.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
332
+ "model.encoder.layers.35.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
333
+ "model.encoder.layers.35.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
334
+ "model.encoder.layers.35.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
335
+ "model.encoder.layers.35.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
336
+ "model.encoder.layers.35.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
337
+ "model.encoder.layers.36.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
338
+ "model.encoder.layers.36.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
339
+ "model.encoder.layers.36.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
340
+ "model.encoder.layers.36.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
341
+ "model.encoder.layers.36.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
342
+ "model.encoder.layers.36.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
343
+ "model.encoder.layers.36.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
344
+ "model.encoder.layers.36.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
345
+ "model.encoder.layers.36.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
346
+ "model.encoder.layers.36.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
347
+ "model.encoder.layers.36.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
348
+ "model.encoder.layers.37.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
349
+ "model.encoder.layers.37.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
350
+ "model.encoder.layers.37.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
351
+ "model.encoder.layers.37.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
352
+ "model.encoder.layers.37.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
353
+ "model.encoder.layers.37.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
354
+ "model.encoder.layers.37.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
355
+ "model.encoder.layers.37.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
356
+ "model.encoder.layers.37.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
357
+ "model.encoder.layers.37.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
358
+ "model.encoder.layers.37.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
359
+ "model.encoder.layers.38.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
360
+ "model.encoder.layers.38.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
361
+ "model.encoder.layers.38.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
362
+ "model.encoder.layers.38.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
363
+ "model.encoder.layers.38.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
364
+ "model.encoder.layers.38.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
365
+ "model.encoder.layers.38.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
366
+ "model.encoder.layers.38.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
367
+ "model.encoder.layers.38.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
368
+ "model.encoder.layers.38.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
369
+ "model.encoder.layers.38.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
370
+ "model.encoder.layers.39.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
371
+ "model.encoder.layers.39.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
372
+ "model.encoder.layers.39.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
373
+ "model.encoder.layers.39.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
374
+ "model.encoder.layers.39.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
375
+ "model.encoder.layers.39.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
376
+ "model.encoder.layers.39.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
377
+ "model.encoder.layers.39.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
378
+ "model.encoder.layers.39.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
379
+ "model.encoder.layers.39.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
380
+ "model.encoder.layers.39.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
381
+ "model.encoder.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
382
+ "model.encoder.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
383
+ "model.encoder.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
384
+ "model.encoder.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
385
+ "model.encoder.layers.4.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
386
+ "model.encoder.layers.4.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
387
+ "model.encoder.layers.4.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
388
+ "model.encoder.layers.4.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
389
+ "model.encoder.layers.4.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
390
+ "model.encoder.layers.4.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
391
+ "model.encoder.layers.4.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
392
+ "model.encoder.layers.40.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
393
+ "model.encoder.layers.40.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
394
+ "model.encoder.layers.40.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
395
+ "model.encoder.layers.40.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
396
+ "model.encoder.layers.40.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
397
+ "model.encoder.layers.40.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
398
+ "model.encoder.layers.40.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
399
+ "model.encoder.layers.40.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
400
+ "model.encoder.layers.40.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
401
+ "model.encoder.layers.40.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
402
+ "model.encoder.layers.40.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
403
+ "model.encoder.layers.41.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
404
+ "model.encoder.layers.41.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
405
+ "model.encoder.layers.41.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
406
+ "model.encoder.layers.41.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
407
+ "model.encoder.layers.41.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
408
+ "model.encoder.layers.41.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
409
+ "model.encoder.layers.41.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
410
+ "model.encoder.layers.41.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
411
+ "model.encoder.layers.41.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
412
+ "model.encoder.layers.41.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
413
+ "model.encoder.layers.41.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
414
+ "model.encoder.layers.42.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
415
+ "model.encoder.layers.42.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
416
+ "model.encoder.layers.42.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
417
+ "model.encoder.layers.42.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
418
+ "model.encoder.layers.42.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
419
+ "model.encoder.layers.42.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
420
+ "model.encoder.layers.42.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
421
+ "model.encoder.layers.42.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
422
+ "model.encoder.layers.42.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
423
+ "model.encoder.layers.42.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
424
+ "model.encoder.layers.42.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
425
+ "model.encoder.layers.43.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
426
+ "model.encoder.layers.43.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
427
+ "model.encoder.layers.43.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
428
+ "model.encoder.layers.43.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
429
+ "model.encoder.layers.43.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
430
+ "model.encoder.layers.43.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
431
+ "model.encoder.layers.43.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
432
+ "model.encoder.layers.43.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
433
+ "model.encoder.layers.43.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
434
+ "model.encoder.layers.43.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
435
+ "model.encoder.layers.43.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
436
+ "model.encoder.layers.44.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
437
+ "model.encoder.layers.44.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
438
+ "model.encoder.layers.44.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
439
+ "model.encoder.layers.44.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
440
+ "model.encoder.layers.44.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
441
+ "model.encoder.layers.44.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
442
+ "model.encoder.layers.44.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
443
+ "model.encoder.layers.44.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
444
+ "model.encoder.layers.44.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
445
+ "model.encoder.layers.44.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
446
+ "model.encoder.layers.44.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
447
+ "model.encoder.layers.45.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
448
+ "model.encoder.layers.45.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
449
+ "model.encoder.layers.45.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
450
+ "model.encoder.layers.45.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
451
+ "model.encoder.layers.45.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
452
+ "model.encoder.layers.45.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
453
+ "model.encoder.layers.45.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
454
+ "model.encoder.layers.45.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
455
+ "model.encoder.layers.45.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
456
+ "model.encoder.layers.45.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
457
+ "model.encoder.layers.45.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
458
+ "model.encoder.layers.46.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
459
+ "model.encoder.layers.46.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
460
+ "model.encoder.layers.46.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
461
+ "model.encoder.layers.46.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
462
+ "model.encoder.layers.46.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
463
+ "model.encoder.layers.46.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
464
+ "model.encoder.layers.46.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
465
+ "model.encoder.layers.46.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
466
+ "model.encoder.layers.46.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
467
+ "model.encoder.layers.46.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
468
+ "model.encoder.layers.46.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
469
+ "model.encoder.layers.47.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
470
+ "model.encoder.layers.47.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00004.bin",
471
+ "model.encoder.layers.47.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00004.bin",
472
+ "model.encoder.layers.47.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
473
+ "model.encoder.layers.47.self_attention.dense.weight": "pytorch_model-00003-of-00004.bin",
474
+ "model.encoder.layers.47.self_attention.key_proj.bias": "pytorch_model-00003-of-00004.bin",
475
+ "model.encoder.layers.47.self_attention.key_proj.weight": "pytorch_model-00003-of-00004.bin",
476
+ "model.encoder.layers.47.self_attention.query_proj.bias": "pytorch_model-00003-of-00004.bin",
477
+ "model.encoder.layers.47.self_attention.query_proj.weight": "pytorch_model-00003-of-00004.bin",
478
+ "model.encoder.layers.47.self_attention.value_proj.bias": "pytorch_model-00003-of-00004.bin",
479
+ "model.encoder.layers.47.self_attention.value_proj.weight": "pytorch_model-00003-of-00004.bin",
480
+ "model.encoder.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
481
+ "model.encoder.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
482
+ "model.encoder.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
483
+ "model.encoder.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
484
+ "model.encoder.layers.5.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
485
+ "model.encoder.layers.5.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
486
+ "model.encoder.layers.5.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
487
+ "model.encoder.layers.5.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
488
+ "model.encoder.layers.5.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
489
+ "model.encoder.layers.5.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
490
+ "model.encoder.layers.5.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
491
+ "model.encoder.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
492
+ "model.encoder.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
493
+ "model.encoder.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
494
+ "model.encoder.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
495
+ "model.encoder.layers.6.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
496
+ "model.encoder.layers.6.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
497
+ "model.encoder.layers.6.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
498
+ "model.encoder.layers.6.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
499
+ "model.encoder.layers.6.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
500
+ "model.encoder.layers.6.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
501
+ "model.encoder.layers.6.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
502
+ "model.encoder.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
503
+ "model.encoder.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
504
+ "model.encoder.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
505
+ "model.encoder.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
506
+ "model.encoder.layers.7.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
507
+ "model.encoder.layers.7.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
508
+ "model.encoder.layers.7.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
509
+ "model.encoder.layers.7.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
510
+ "model.encoder.layers.7.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
511
+ "model.encoder.layers.7.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
512
+ "model.encoder.layers.7.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
513
+ "model.encoder.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
514
+ "model.encoder.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
515
+ "model.encoder.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
516
+ "model.encoder.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
517
+ "model.encoder.layers.8.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
518
+ "model.encoder.layers.8.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
519
+ "model.encoder.layers.8.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
520
+ "model.encoder.layers.8.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
521
+ "model.encoder.layers.8.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
522
+ "model.encoder.layers.8.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
523
+ "model.encoder.layers.8.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
524
+ "model.encoder.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
525
+ "model.encoder.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00004.bin",
526
+ "model.encoder.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00004.bin",
527
+ "model.encoder.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
528
+ "model.encoder.layers.9.self_attention.dense.weight": "pytorch_model-00001-of-00004.bin",
529
+ "model.encoder.layers.9.self_attention.key_proj.bias": "pytorch_model-00001-of-00004.bin",
530
+ "model.encoder.layers.9.self_attention.key_proj.weight": "pytorch_model-00001-of-00004.bin",
531
+ "model.encoder.layers.9.self_attention.query_proj.bias": "pytorch_model-00001-of-00004.bin",
532
+ "model.encoder.layers.9.self_attention.query_proj.weight": "pytorch_model-00001-of-00004.bin",
533
+ "model.encoder.layers.9.self_attention.value_proj.bias": "pytorch_model-00001-of-00004.bin",
534
+ "model.encoder.layers.9.self_attention.value_proj.weight": "pytorch_model-00001-of-00004.bin",
535
+ "model.encoder.ln_f.weight": "pytorch_model-00003-of-00004.bin",
536
+ "model.word_embeddings.weight": "pytorch_model-00001-of-00004.bin"
537
+ }
538
+ }
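The index file above is what allows the sharded checkpoint to be loaded as one model: `metadata.total_size` records the combined byte size of all shards, and `weight_map` maps every parameter name to the shard file that stores it. A small sketch of inspecting it directly:

```python
# Sketch: look up which shard holds a given parameter, using the index shown above.
import json

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])        # 30060162048 bytes across the four shards
print(index["weight_map"]["lm_head.weight"])  # pytorch_model-00004-of-00004.bin
```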
special_tokens_map.json ADDED
@@ -0,0 +1,115 @@
+ {
+   "additional_special_tokens": [
+     "<mask>",
+     "<doc>",
+     "<title>",
+     "<para>",
+     "<eop>",
+     "<eot>",
+     "<eod>",
+     "[User]",
+     "[Assistant]",
+     "[System]",
+     "[Turn 1]",
+     "[Turn 2]",
+     "[Turn 3]",
+     "[Turn 4]",
+     "[Turn 5]",
+     "[Turn 6]",
+     "[Turn 7]",
+     "[Turn 8]",
+     "[Turn 9]",
+     "[Turn 10]",
+     "[Turn 11]",
+     "[Turn 12]",
+     "[Turn 13]",
+     "[Turn 14]",
+     "[Turn 15]",
+     "[Turn 16]",
+     "[Turn 17]",
+     "[Turn 18]",
+     "[Turn 19]",
+     "[Turn 20]",
+     "[Turn 21]",
+     "[Turn 22]",
+     "[Turn 23]",
+     "[Turn 24]",
+     "[Turn 25]",
+     "[Turn 26]",
+     "[Turn 27]",
+     "[Turn 28]",
+     "[Turn 29]",
+     "[Turn 30]",
+     "[Turn 31]",
+     "[Turn 32]",
+     "[Turn 33]",
+     "[Turn 34]",
+     "[Turn 35]",
+     "[Turn 36]",
+     "[Turn 37]",
+     "[Turn 38]",
+     "[Turn 39]",
+     "[Turn 40]",
+     "[Turn 41]",
+     "[Turn 42]",
+     "[Turn 43]",
+     "[Turn 44]",
+     "[Turn 45]",
+     "[Turn 46]",
+     "[Turn 47]",
+     "[Turn 48]",
+     "[Turn 49]",
+     "[Turn 50]",
+     "[Turn 51]",
+     "[Turn 52]",
+     "[Turn 53]",
+     "[Turn 54]",
+     "[Turn 55]",
+     "[Turn 56]",
+     "[Turn 57]",
+     "[Turn 58]",
+     "[Turn 59]",
+     "[Turn 60]",
+     "[Turn 61]",
+     "[Turn 62]",
+     "[Turn 63]",
+     "[Turn 64]",
+     "[Turn 65]",
+     "[Turn 66]",
+     "[Turn 67]",
+     "[Turn 68]",
+     "[Turn 69]",
+     "[Turn 70]",
+     "[Turn 71]",
+     "[Turn 72]",
+     "[Turn 73]",
+     "[Turn 74]",
+     "[Turn 75]",
+     "[Turn 76]",
+     "[Turn 77]",
+     "[Turn 78]",
+     "[Turn 79]",
+     "[Turn 80]",
+     "[Turn 81]",
+     "[Turn 82]",
+     "[Turn 83]",
+     "[Turn 84]",
+     "[Turn 85]",
+     "[Turn 86]",
+     "[Turn 87]",
+     "[Turn 88]",
+     "[Turn 89]",
+     "[Turn 90]",
+     "[Turn 91]",
+     "[Turn 92]",
+     "[Turn 93]",
+     "[Turn 94]",
+     "[Turn 95]",
+     "[Turn 96]",
+     "[Turn 97]",
+     "[Turn 98]",
+     "[Turn 99]",
+     "[Turn 100]"
+   ],
+   "unk_token": "<unk>"
+ }
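For reference, the map above registers 110 extra markers (7 structural tags such as `<eop>` and `<eod>`, 3 role tags, and 100 turn tags) in addition to `<unk>`. A quick sanity check, assuming the file has been saved locally as `special_tokens_map.json`:

```python
import json

# Quick sanity check of the special-token map above (local path is an assumption).
with open("special_tokens_map.json") as f:
    special = json.load(f)

print(len(special["additional_special_tokens"]))  # 110 markers
print(special["unk_token"])                       # "<unk>"
```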
tokenization_batgpt.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import torch
+ from typing import List, Optional, Union, Dict, Tuple
+ from sentencepiece import SentencePieceProcessor
+ from transformers import PreTrainedTokenizer
+ from transformers.utils import logging, PaddingStrategy
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+
+ # Structural markers, role markers, and 100 turn markers appended to the vocabulary.
+ SPECIAL_TOKENS = ["<mask>", "<doc>", "<title>", "<para>", "<eop>", "<eot>", "<eod>"] + ["[User]", "[Assistant]", "[System]"] + ["[Turn {}]".format(i+1) for i in range(100)]
+
+ class SPTokenizer:
+     def __init__(self, model_path: str):
+         # reload tokenizer
+         assert os.path.isfile(model_path), model_path
+         self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+         # BOS / EOS token IDs
+         self.n_words: int = self.sp_model.vocab_size()
+         self.bos_id: int = self.sp_model.bos_id()
+         self.eos_id: int = self.sp_model.eos_id()
+         self.pad_id: int = self.sp_model.unk_id()
+         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+         # Assign the special tokens ids directly above the SentencePiece vocabulary.
+         self.special_tokens = {}
+         self.index_special_tokens = {}
+         for token in SPECIAL_TOKENS:
+             self.special_tokens[token] = self.n_words
+             self.index_special_tokens[self.n_words] = token
+             self.n_words += 1
+
+     def tokenize(self, s: str):
+         return self.sp_model.EncodeAsPieces(s)
+
+     def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+         assert type(s) is str
+         t = self.sp_model.encode(s)
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+         return t
+
+     def decode(self, t: List[int]) -> str:
+         return self.sp_model.decode(t)
+
+     def decode_tokens(self, tokens: List[str]) -> str:
+         text = self.sp_model.DecodePieces(tokens)
+         return text
+
+     def convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         if token in self.special_tokens:
+             return self.special_tokens[token]
+         return self.sp_model.PieceToId(token)
+
+     def convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+             return ""
+         return self.sp_model.IdToPiece(index)
+
+
+ class BatGPTTokenizer(PreTrainedTokenizer):
+     vocab_files_names = {"vocab_file": "tokenizer.model"}
+
+     model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+     def __init__(self, vocab_file, padding_side="left", **kwargs):
+         super().__init__(padding_side=padding_side, **kwargs)
+         self.name = "BatGPTTokenizer"
+
+         self.vocab_file = vocab_file
+         self.tokenizer = SPTokenizer(vocab_file)
+         self.special_tokens = {
+             "<bos>": self.tokenizer.bos_id,
+             "<eos>": self.tokenizer.eos_id,
+             "<pad>": self.tokenizer.pad_id
+         }
+
+         # Register the additional markers so they are never split by SentencePiece.
+         self.unk_token = "<unk>"
+         self.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS})
+
+     def get_command(self, token):
+         if token in self.special_tokens:
+             return self.special_tokens[token]
+         assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
+         return self.tokenizer.special_tokens[token]
+
+     @property
+     def pad_token(self) -> str:
+         return "<unk>"
+
+     @property
+     def pad_token_id(self):
+         return self.get_command("<pad>")
+
+     @property
+     def eos_token(self) -> str:
+         return "</s>"
+
+     @property
+     def eos_token_id(self):
+         return self.get_command("<eos>")
+
+     @property
+     def vocab_size(self):
+         return self.tokenizer.n_words
+
+     def get_vocab(self):
+         """Returns vocab as a dict."""
+         vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text, **kwargs):
+         return self.tokenizer.tokenize(text)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         return self.tokenizer.convert_token_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         return self.tokenizer.convert_id_to_token(index)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         return self.tokenizer.decode_tokens(tokens)
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory, self.vocab_files_names["vocab_file"]
+             )
+         else:
+             vocab_file = save_directory
+
+         # Copy the SentencePiece model file verbatim into the save directory.
+         with open(self.vocab_file, 'rb') as fin:
+             proto_str = fin.read()
+
+         with open(vocab_file, "wb") as writer:
+             writer.write(proto_str)
+
+         return (vocab_file,)
+
+     def get_prefix_tokens(self):
+         prefix_tokens = [self.get_command("<doc>"), self.get_command("<para>")]
+         return prefix_tokens
+
+     def build_inputs(self, query, history=None, system_prompt=None):
+         if history is None:
+             history = []
+         role_user = "[User]"
+         role_assistant = "[Assistant]"
+         if system_prompt:
+             prompt = "[System]\n\n {}\n\n<eot>".format(system_prompt)
+         else:
+             prompt = ""
+         for i, (old_query, response) in enumerate(history):
+             prompt += "[Turn {}]\n\n{} {}\n\n{} {}\n\n<eop>".format(i + 1, role_user, old_query, role_assistant, response)
+         prompt += "[Turn {}]\n\n{} {}\n\n{}".format(len(history) + 1, role_user, query, role_assistant)
+         inputs = self([prompt], return_tensors="pt")
+         return inputs
+
+     def build_stream_inputs(self, query: str, history: List[Tuple[str, str]] = None, system_prompt=None):
+         if history is None:
+             # Guard against a missing history; len(None) below would raise otherwise.
+             history = []
+         role_user = "[User]"
+         role_assistant = "[Assistant]"
+         if history:
+             prompt = "\n\n[Turn {}]\n\n{} {}\n\n{}".format(len(history) + 1, role_user, query, role_assistant)
+             input_ids = self.encode(prompt, add_special_tokens=False)
+             input_ids = input_ids[1:]
+             inputs = self.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
+         else:
+             if system_prompt:
+                 prompt = "[System]\n\n {}\n\n[Turn {}]\n\n{} {}\n\n{} ".format(system_prompt, len(history) + 1, role_user, query, role_assistant)
+             else:
+                 prompt = "[Turn {}]\n\n{} {}\n\n{} ".format(len(history) + 1, role_user, query, role_assistant)
+             inputs = self([prompt], return_tensors="pt")
+         return inputs
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         prefix_tokens = self.get_prefix_tokens()
+         token_ids_0 = prefix_tokens + token_ids_0
+         if token_ids_1 is not None:
+             token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
+         return token_ids_0
+
+     def _pad(
+         self,
+         encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+         max_length: Optional[int] = None,
+         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+         pad_to_multiple_of: Optional[int] = None,
+         return_attention_mask: Optional[bool] = None,
+     ) -> dict:
+         # Load from model defaults
+         assert self.padding_side == "left"
+
+         required_input = encoded_inputs[self.model_input_names[0]]
+         seq_length = len(required_input)
+
+         if padding_strategy == PaddingStrategy.LONGEST:
+             max_length = len(required_input)
+
+         if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+             max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+         # Initialize attention mask if not present.
+         if "attention_mask" not in encoded_inputs:
+             encoded_inputs["attention_mask"] = [1] * seq_length
+
+         if "position_ids" not in encoded_inputs:
+             encoded_inputs["position_ids"] = list(range(seq_length))
+
+         # Left-pad input ids, attention mask, and position ids up to max_length.
+         if needs_to_be_padded:
+             difference = max_length - len(required_input)
+
+             if "attention_mask" in encoded_inputs:
+                 encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+             if "position_ids" in encoded_inputs:
+                 encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+             encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+         return encoded_inputs
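To make the prompt format of `build_inputs` above easier to inspect, here is a minimal sketch that reproduces the same string assembly in plain Python, without instantiating the tokenizer; the example system prompt, history, and query are illustrative only.

```python
# Minimal sketch (illustrative values): the chat prompt layout assembled by
# BatGPTTokenizer.build_inputs above, reproduced with plain string formatting.
system_prompt = "You are a helpful assistant."
history = [("你好", "你好!很高兴见到你。")]
query = "请介绍一下 BatGPT。"

prompt = "[System]\n\n {}\n\n<eot>".format(system_prompt)
for i, (old_query, response) in enumerate(history):
    prompt += "[Turn {}]\n\n[User] {}\n\n[Assistant] {}\n\n<eop>".format(i + 1, old_query, response)
prompt += "[Turn {}]\n\n[User] {}\n\n[Assistant]".format(len(history) + 1, query)
print(prompt)
```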
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
+ size 1018370
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_batgpt.BatGPTTokenizer",
+       null
+     ]
+   },
+   "clean_up_tokenization_spaces": true,
+   "do_lower_case": false,
+   "model_max_length": 1000000000000000019884624838656,
+   "padding_side": "left",
+   "remove_space": false,
+   "tokenizer_class": "BatGPTTokenizer"
+ }
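The `auto_map` entry above is what lets `AutoTokenizer` resolve `tokenization_batgpt.BatGPTTokenizer` from the repository itself, which is why `trust_remote_code=True` is required when loading. A minimal sketch (the repository id below is an assumption and may differ):

```python
from transformers import AutoTokenizer

# Minimal sketch: load the custom tokenizer via the auto_map entry above.
# The repo id is an assumption; replace it with the actual model path.
tokenizer = AutoTokenizer.from_pretrained("MLP-lab/BatGPT-15B-sirius", trust_remote_code=True)

# build_inputs is defined in tokenization_batgpt.py above.
inputs = tokenizer.build_inputs("你好,请介绍一下你自己。")
print(inputs["input_ids"].shape)
```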