Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- MODEL_LICENSE +93 -0
- README.md +90 -3
- imgs/head_final3.png +3 -0
- model_index.json +25 -0
- scheduler/scheduler_config.json +22 -0
- text_encoder/__pycache__/configuration_chatglm.cpython-311.pyc +0 -0
- text_encoder/__pycache__/configuration_chatglm.cpython-37.pyc +0 -0
- text_encoder/__pycache__/configuration_chatglm.cpython-38.pyc +0 -0
- text_encoder/__pycache__/configuration_chatglm.cpython-39.pyc +0 -0
- text_encoder/__pycache__/modeling_chatglm.cpython-38.pyc +0 -0
- text_encoder/__pycache__/modeling_chatglm.cpython-39.pyc +0 -0
- text_encoder/__pycache__/tokenization_chatglm.cpython-38.pyc +0 -0
- text_encoder/__pycache__/tokenization_chatglm.cpython-39.pyc +0 -0
- text_encoder/config.json +42 -0
- text_encoder/configuration_chatglm.py +61 -0
- text_encoder/modeling_chatglm.py +1298 -0
- text_encoder/pytorch_model-00001-of-00007.bin +3 -0
- text_encoder/pytorch_model-00002-of-00007.bin +3 -0
- text_encoder/pytorch_model-00003-of-00007.bin +3 -0
- text_encoder/pytorch_model-00004-of-00007.bin +3 -0
- text_encoder/pytorch_model-00005-of-00007.bin +3 -0
- text_encoder/pytorch_model-00006-of-00007.bin +3 -0
- text_encoder/pytorch_model-00007-of-00007.bin +3 -0
- text_encoder/pytorch_model.bin.index.json +207 -0
- text_encoder/quantization.py +188 -0
- text_encoder/tokenization_chatglm.py +300 -0
- text_encoder/tokenizer.model +3 -0
- text_encoder/tokenizer_config.json +12 -0
- text_encoder/vocab.txt +3 -0
- tokenizer/tokenization_chatglm.py +300 -0
- tokenizer/tokenizer.model +3 -0
- tokenizer/tokenizer_config.json +12 -0
- tokenizer/vocab.txt +3 -0
- unet/config.json +73 -0
- unet/diffusion_pytorch_model.safetensors +3 -0
- vae/config.json +31 -0
- vae/diffusion_pytorch_model.bin +3 -0
- vae/diffusion_pytorch_model.fp16.bin +3 -0
- vae/diffusion_pytorch_model.fp16.safetensors +3 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
text_encoder/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
37 |
+
imgs/head_final3.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
MODEL_LICENSE
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
模型许可协议
|
2 |
+
模型发布日期:2024/7/6
|
3 |
+
|
4 |
+
通过点击同意或使用、复制、修改、分发、表演或展示模型作品的任何部分或元素,您将被视为已承认并接受本协议的内容,本协议立即生效。
|
5 |
+
|
6 |
+
1.定义。
|
7 |
+
a. “协议”指本协议中所规定的使用、复制、分发、修改、表演和展示模型作品或其任何部分或元素的条款和条件。
|
8 |
+
b. “材料”是指根据本协议提供的专有的模型和文档(及其任何部分)的统称。
|
9 |
+
c. “模型”指大型语言模型、图像/视频/音频/3D 生成模型、多模态大型语言模型及其软件和算法,包括训练后的模型权重、参数(包括优化器状态)、机器学习模型代码、推理支持代码、训练支持代码、微调支持代码以及我们公开提供的前述其他元素。
|
10 |
+
d. “输出”是指通过操作或以其他方式使用模型或模型衍生品而产生的模型或模型衍生品的信息和/或内容输出。
|
11 |
+
e. “模型衍生品”包括:(i)对模型或任何模型衍生物的修改;(ii)基于模型的任何模型衍生物的作品;或(iii)通过将模型或模型的任何模型衍生物的权重、参数、操作或输出的模式转移到该模型而创建的任何其他机器学习模型,以使该模型的性能类似于模型或模型衍生物。为清楚起见,输出本身不被视为模型衍生物。
|
12 |
+
f. “模型作品”包括:(i)材料;(ii)模型衍生品;及(iii)其所有衍生作品。
|
13 |
+
g. “许可人”或“我们”指作品所有者或作品所有者授权的授予许可的实体,包括可能对模型和/或分发模型拥有权利的个人或实体。
|
14 |
+
h.“被许可人”、“您”或“您的”是指行使本协议授予的权利和/或为任何目的和在任何使用领域使用模型作品的自然人或法人实体。
|
15 |
+
i.“第三方”是指不受我们或您共同控制的个人或法人实体。
|
16 |
+
|
17 |
+
2. 许可内容。
|
18 |
+
a.我们授予您非排他性的、全球性的、不可转让的、免版税的许可(在我们的知识产权或我们拥有的体现在材料中或利用材料的其他权利的范围内),允许您仅根据本协议的条款使用、复制、分发、创作衍生作品(包括模型衍生品)和对材料进行修改,并且您不得违反(或鼓励、或允许任何其他人违反)本协议的任何条款。
|
19 |
+
b.在遵守本协议的前提下,您可以分发或向第三方提供模型作品,您须满足以下条件:
|
20 |
+
(i)您必须向所有该模型作品或使用该作品的产品或服务的任何第三方接收者提供模型作品的来源和本协议的副本;
|
21 |
+
(ii)您必须在任何修改过的文档上附加明显的声明,说明您更改了这些文档;
|
22 |
+
(iii)您可以在您的修改中添加您自己的版权声明,并且,在您对该作品的使用、复制、修改、分发、表演和展示符合本协议的条款和条件的前提下,您可以为您的修改或任何此类模型衍生品的使用、复制或分发提供额外或不同的许可条款和条件。
|
23 |
+
c. 附加商业条款: 若您希望将模型及模型衍生品用作商业用途,则您必须向许可人申请许可,许可人可自行决定向您授予许可。除非许可人另行明确授予您该等权利,否则您无权行使本协议项下的任何权利。
|
24 |
+
|
25 |
+
3.使用限制。
|
26 |
+
a. 您对本模型作品的使用必须遵守适用法律法规(包括贸易合规法律法规),并遵守《服务协议》(https://kolors.kuaishou.com/agreement)。您必须将本第 3(a) 和 3(b) 条中提及的使用限制作为可执行条款纳入任何规范本模型作品使用和/或分发的协议(例如许可协议、使用条款等),并且您必须向您分发的后续用户发出通知,告知其本模型作品受本第 3(a) 和 3(b) 条中的使用限制约束。
|
27 |
+
b. 您不得使用本模型作品或本模型作品的任何输出或成果来改进任何其他模型(本模型或其模型衍生品除外)。
|
28 |
+
|
29 |
+
4.知识产权。
|
30 |
+
a. 我们保留模型的所有权及其相关知识产权。在遵守本协议条款和条件的前提下,对于您制作的材料的任何衍生作品和修改,您是且将是此类衍生作品和修改的所有者。
|
31 |
+
b. 本协议不授予任何商标、商号、服务标记或产品名称的标识许可,除非出于描述和分发本模型作品的合理和惯常用途。
|
32 |
+
c. 如果您对我们或任何个人或实体提起诉讼或其他程序(包括诉讼中的交叉索赔或反索赔),声称材料或任何输出或任何上述内容的任何部分侵犯您拥有或可许可的任何知识产权或其他权利,则根据本协议授予您的所有许可应于提起此类诉讼或其他程序之日起终止。
|
33 |
+
|
34 |
+
5. 免责声明和责任限制。
|
35 |
+
a. 本模型作品及其任何输出和结果按“原样”提供,不作任何明示或暗示的保证,包括适销性、非侵权性或适用于特定用途的保证。我们不对材料及其任何输出的安全性或稳定性作任��保证,也不承担任何责任。
|
36 |
+
b. 在任何情况下,我们均不对您承担任何损害赔偿责任,包括但不限于因您使用或无法使用材料或其任何输出而造成的任何直接、间接、特殊或后果性损害赔偿责任,无论该损害赔偿责任是如何造成的。
|
37 |
+
|
38 |
+
6. 存续和终止。
|
39 |
+
a. 本协议期限自您接受本协议或访问材料之日起开始,并将持续完全有效,直至根据本协议条款和条件终止。
|
40 |
+
b. 如果您违反本协议的任何条款或条件,我们可终止本协议。本协议终止后,您必须立即删除并停止使用本模型作品。第 4(a)、4(c)、5和 7 条在本协议终止后仍然有效。
|
41 |
+
|
42 |
+
7. 适用法律和管辖权。
|
43 |
+
a. 本协议及由本协议引起的或与本协议有关的任何争议均受中华人民共和国大陆地区(仅为本协议目的,不包括香港、澳门和台湾)法律管辖,并排除冲突法的适用,且《联合国国际货物销售合同公约》不适用于本协议。
|
44 |
+
b. 因本协议引起或与本协议有关的任何争议,由许可人住所地人民法院管辖。
|
45 |
+
|
46 |
+
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 kwai-kolors@kuaishou.com 与我们联系。
|
47 |
+
|
48 |
+
|
49 |
+
英文版
|
50 |
+
|
51 |
+
MODEL LICENSE AGREEMENT
|
52 |
+
Release Date: 2024/7/6
|
53 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model Works, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
54 |
+
1. DEFINITIONS.
|
55 |
+
a. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of the Model Works or any portion or element thereof set forth herein.
|
56 |
+
b. “Materials” shall mean, collectively, Us proprietary the Model and Documentation (and any portion thereof) as made available by Us under this Agreement.
|
57 |
+
c. “Model” shall mean the large language models, image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us .
|
58 |
+
d. “Output” shall mean the information and/or content output of Model or a Model Derivative that results from operating or otherwise using Model or a Model Derivative.
|
59 |
+
e. “Model Derivatives” shall mean all: (i) modifications to the Model or any Model Derivative; (ii) works based on the Model or any Model Derivative; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of the Model or any Model Derivative, to that model in order to cause that model to perform similarly to the Model or a Model Derivative, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs or a Model Derivative for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
60 |
+
f. “Model Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
61 |
+
g. “Licensor” , “We” or “Us” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
62 |
+
h. “Licensee”, “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Model Works for any purpose and in any field of use.
|
63 |
+
i. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
64 |
+
|
65 |
+
2. LICENSE CONTENT.
|
66 |
+
a. We grant You a non-exclusive, worldwide, non-transferable and royalty-free limited license under the intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
67 |
+
b. You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Model Works, provided that You meet all of the following conditions:
|
68 |
+
(i) You must provide all such Third Party recipients of the Model Works or products or services using them the source of the Model and a copy of this Agreement;
|
69 |
+
(ii) You must cause any modified documents to carry prominent notices stating that You changed the documents;
|
70 |
+
(iii) You may add Your own copyright statement to Your modifications and, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement.
|
71 |
+
c. additional commercial terms: If, on the Model version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, or, Licensee is a cloud computing platform vendor, You must request a license from licensor, which the licensor may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until We otherwise expressly grants You such rights.
|
72 |
+
|
73 |
+
|
74 |
+
3. LICENSE RESTRICITIONS.
|
75 |
+
a. Your use of the Model Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Service Agreement. You must include the use restrictions referenced in these Sections 3(a) and 3(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Model Works and You must provide notice to subsequent users to whom You distribute that Model Works are subject to the use restrictions in these Sections 3(a) and 3(b).
|
76 |
+
b. You must not use the Model Works or any Output or results of the Model Works to improve any other large model (other than Model or Model Derivatives thereof).
|
77 |
+
4. INTELLECTUAL PROPERTY.
|
78 |
+
a. We retain ownership of all intellectual property rights in and to the Model and derivatives. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
|
79 |
+
b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
|
80 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed.
|
81 |
+
5. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
82 |
+
a. THE MODEL WORKS AND ANY OUTPUT AND RESULTS THERE FROM ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
|
83 |
+
b. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
|
84 |
+
c. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
|
85 |
+
|
86 |
+
6. SURVIVAL AND TERMINATION.
|
87 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
88 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Model Works. Sections 4(a), 4(c), 5 and 7 shall survive the termination of this Agreement.
|
89 |
+
7. GOVERNING LAW AND JURISDICTION.
|
90 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China (for the purpose of this agreement only, excluding Hong Kong, Macau, and Taiwan), without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
91 |
+
b. Any disputes arising from or related to this Agreement shall be under the jurisdiction of the People's Court where the Licensor is located.
|
92 |
+
|
93 |
+
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at kwai-kolors@kuaishou.com.
|
README.md
CHANGED
@@ -1,3 +1,90 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- zh
|
5 |
+
- en
|
6 |
+
tags:
|
7 |
+
- text-to-image
|
8 |
+
- stable-diffusion
|
9 |
+
- kolors
|
10 |
+
---
|
11 |
+
# Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis
|
12 |
+
<div align="center" style="display: flex; justify-content: center; flex-wrap: wrap;">
|
13 |
+
<a href="https://github.com/Kwai-Kolors/Kolors"><img src="https://img.shields.io/static/v1?label=Kolors Code&message=Github&color=blue&logo=github-pages"></a>  
|
14 |
+
<a href="https://kwai-kolors.github.io/"><img src="https://img.shields.io/static/v1?label=Team%20Page&message=Page&color=green"></a>  
|
15 |
+
<a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv:Kolors&color=red&logo=arxiv"></a>  
|
16 |
+
<a href="https://kolors.kuaishou.com/"><img src="https://img.shields.io/static/v1?label=Official Website&message=Page&color=green"></a>
|
17 |
+
</div>
|
18 |
+
<figure>
|
19 |
+
<img src="imgs/head_final3.png">
|
20 |
+
</figure>
|
21 |
+
<br>
|
22 |
+
|
23 |
+
## 📖 Introduction
|
24 |
+
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by the Kuaishou Kolors team. Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and proprietary models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this <a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf">technical report</a></b>.
|
25 |
+
|
26 |
+
|
27 |
+
## 🚀 Quick Start
|
28 |
+
### Requirements
|
29 |
+
|
30 |
+
* Python 3.8 or later
|
31 |
+
* PyTorch 1.13.1 or later
|
32 |
+
* Transformers 4.26.1 or later
|
33 |
+
* Recommended: CUDA 11.7 or later
|
34 |
+
<br>
|
35 |
+
|
36 |
+
1. Repository cloning and dependency installation
|
37 |
+
|
38 |
+
```bash
|
39 |
+
apt-get install git-lfs
|
40 |
+
git clone https://github.com/Kwai-Kolors/Kolors
|
41 |
+
cd Kolors
|
42 |
+
conda create --name kolors python=3.8
|
43 |
+
conda activate kolors
|
44 |
+
pip install -r requirements.txt
|
45 |
+
python3 setup.py install
|
46 |
+
```
|
47 |
+
2. Weights download([link](https://huggingface.co/Kwai-Kolors/Kolors)):
|
48 |
+
```bash
|
49 |
+
huggingface-cli download --resume-download Kwai-Kolors/Kolors --local-dir weights/Kolors
|
50 |
+
```
|
51 |
+
or
|
52 |
+
```bash
|
53 |
+
git lfs clone https://huggingface.co/Kwai-Kolors/Kolors weights/Kolors
|
54 |
+
```
|
55 |
+
3. Inference:
|
56 |
+
```bash
|
57 |
+
python3 scripts/sample.py "一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着“可图”"
|
58 |
+
# The image will be saved to "scripts/outputs/sample_test.jpg"
|
59 |
+
```
|
60 |
+
|
61 |
+
### Using with Diffusers
|
62 |
+
Please refer to https://huggingface.co/Kwai-Kolors/Kolors-diffusers.
|
63 |
+
|
64 |
+
## 📜 License&Citation
|
65 |
+
### License
|
66 |
+
Kolors are fully open-sourced for academic research. For commercial use, please fill out this [questionnaire](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/可图KOLORS模型商业授权申请书.docx) and sent it to kwai-kolors@kuaishou.com for registration.
|
67 |
+
|
68 |
+
We open-source Kolors to promote the development of large text-to-image models in collaboration with the open-source community. The code of this project is open-sourced under the Apache-2.0 license. We sincerely urge all developers and users to strictly adhere to the [open-source license](MODEL_LICENSE), avoiding the use of the open-source model, code, and its derivatives for any purposes that may harm the country and society or for any services not evaluated and registered for safety. Note that despite our best efforts to ensure the compliance, accuracy, and safety of the data during training, due to the diversity and combinability of generated content and the probabilistic randomness affecting the model, we cannot guarantee the accuracy and safety of the output content, and the model is susceptible to misleading. This project does not assume any legal responsibility for any data security issues, public opinion risks, or risks and liabilities arising from the model being misled, abused, misused, or improperly utilized due to the use of the open-source model and code.
|
69 |
+
|
70 |
+
|
71 |
+
### Citation
|
72 |
+
If you find our work helpful, please cite it!
|
73 |
+
|
74 |
+
```
|
75 |
+
@article{kolors,
|
76 |
+
title={Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis},
|
77 |
+
author={Kolors Team},
|
78 |
+
journal={arXiv preprint},
|
79 |
+
year={2024}
|
80 |
+
}
|
81 |
+
```
|
82 |
+
|
83 |
+
### Acknowledgments
|
84 |
+
- Thanks to [Diffusers](https://github.com/huggingface/diffusers) for providing the codebase.
|
85 |
+
- Thanks to [ChatGLM3](https://github.com/THUDM/ChatGLM3) for providing the powerful Chinese language model.
|
86 |
+
<br>
|
87 |
+
|
88 |
+
### Contact Us
|
89 |
+
|
90 |
+
If you want to leave a message for our R&D team and product team, feel free to join our [WeChat group](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/wechat.png). You can also contact us via email (kwai-kolors@kuaishou.com).
|
imgs/head_final3.png
ADDED
Git LFS Details
|
model_index.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "StableDiffusionXLPipeline",
|
3 |
+
"_diffusers_version": "0.18.0.dev0",
|
4 |
+
"force_zeros_for_empty_prompt": true,
|
5 |
+
"scheduler": [
|
6 |
+
"diffusers",
|
7 |
+
"EulerDiscreteScheduler"
|
8 |
+
],
|
9 |
+
"text_encoder": [
|
10 |
+
"kolors",
|
11 |
+
"ChatGLMModel"
|
12 |
+
],
|
13 |
+
"tokenizer": [
|
14 |
+
"kolors",
|
15 |
+
"ChatGLMTokenizer"
|
16 |
+
],
|
17 |
+
"unet": [
|
18 |
+
"diffusers",
|
19 |
+
"UNet2DConditionModel"
|
20 |
+
],
|
21 |
+
"vae": [
|
22 |
+
"diffusers",
|
23 |
+
"AutoencoderKL"
|
24 |
+
]
|
25 |
+
}
|
scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "EulerDiscreteScheduler",
|
3 |
+
"_diffusers_version": "0.18.0.dev0",
|
4 |
+
"beta_schedule": "scaled_linear",
|
5 |
+
"beta_start": 0.00085,
|
6 |
+
"beta_end": 0.014,
|
7 |
+
"clip_sample": false,
|
8 |
+
"clip_sample_range": 1.0,
|
9 |
+
"dynamic_thresholding_ratio": 0.995,
|
10 |
+
"interpolation_type": "linear",
|
11 |
+
"num_train_timesteps": 1100,
|
12 |
+
"prediction_type": "epsilon",
|
13 |
+
"rescale_betas_zero_snr": false,
|
14 |
+
"sample_max_value": 1.0,
|
15 |
+
"set_alpha_to_one": false,
|
16 |
+
"skip_prk_steps": true,
|
17 |
+
"steps_offset": 1,
|
18 |
+
"thresholding": false,
|
19 |
+
"timestep_spacing": "leading",
|
20 |
+
"trained_betas": null,
|
21 |
+
"use_karras_sigmas": false
|
22 |
+
}
|
text_encoder/__pycache__/configuration_chatglm.cpython-311.pyc
ADDED
Binary file (2.38 kB). View file
|
|
text_encoder/__pycache__/configuration_chatglm.cpython-37.pyc
ADDED
Binary file (1.67 kB). View file
|
|
text_encoder/__pycache__/configuration_chatglm.cpython-38.pyc
ADDED
Binary file (1.68 kB). View file
|
|
text_encoder/__pycache__/configuration_chatglm.cpython-39.pyc
ADDED
Binary file (1.68 kB). View file
|
|
text_encoder/__pycache__/modeling_chatglm.cpython-38.pyc
ADDED
Binary file (33.6 kB). View file
|
|
text_encoder/__pycache__/modeling_chatglm.cpython-39.pyc
ADDED
Binary file (33.6 kB). View file
|
|
text_encoder/__pycache__/tokenization_chatglm.cpython-38.pyc
ADDED
Binary file (11.6 kB). View file
|
|
text_encoder/__pycache__/tokenization_chatglm.cpython-39.pyc
ADDED
Binary file (11.6 kB). View file
|
|
text_encoder/config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "THUDM/chatglm3-6b-base",
|
3 |
+
"model_type": "chatglm",
|
4 |
+
"architectures": [
|
5 |
+
"ChatGLMModel"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
|
9 |
+
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
10 |
+
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
11 |
+
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
12 |
+
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
|
13 |
+
},
|
14 |
+
"add_bias_linear": false,
|
15 |
+
"add_qkv_bias": true,
|
16 |
+
"apply_query_key_layer_scaling": true,
|
17 |
+
"apply_residual_connection_post_layernorm": false,
|
18 |
+
"attention_dropout": 0.0,
|
19 |
+
"attention_softmax_in_fp32": true,
|
20 |
+
"bias_dropout_fusion": true,
|
21 |
+
"ffn_hidden_size": 13696,
|
22 |
+
"fp32_residual_connection": false,
|
23 |
+
"hidden_dropout": 0.0,
|
24 |
+
"hidden_size": 4096,
|
25 |
+
"kv_channels": 128,
|
26 |
+
"layernorm_epsilon": 1e-05,
|
27 |
+
"multi_query_attention": true,
|
28 |
+
"multi_query_group_num": 2,
|
29 |
+
"num_attention_heads": 32,
|
30 |
+
"num_layers": 28,
|
31 |
+
"original_rope": true,
|
32 |
+
"padded_vocab_size": 65024,
|
33 |
+
"post_layer_norm": true,
|
34 |
+
"rmsnorm": true,
|
35 |
+
"seq_length": 32768,
|
36 |
+
"use_cache": true,
|
37 |
+
"torch_dtype": "float16",
|
38 |
+
"transformers_version": "4.30.2",
|
39 |
+
"tie_word_embeddings": false,
|
40 |
+
"eos_token_id": 2,
|
41 |
+
"pad_token_id": 0
|
42 |
+
}
|
text_encoder/configuration_chatglm.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
class ChatGLMConfig(PretrainedConfig):
|
5 |
+
model_type = "chatglm"
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
num_layers=28,
|
9 |
+
padded_vocab_size=65024,
|
10 |
+
hidden_size=4096,
|
11 |
+
ffn_hidden_size=13696,
|
12 |
+
kv_channels=128,
|
13 |
+
num_attention_heads=32,
|
14 |
+
seq_length=2048,
|
15 |
+
hidden_dropout=0.0,
|
16 |
+
classifier_dropout=None,
|
17 |
+
attention_dropout=0.0,
|
18 |
+
layernorm_epsilon=1e-5,
|
19 |
+
rmsnorm=True,
|
20 |
+
apply_residual_connection_post_layernorm=False,
|
21 |
+
post_layer_norm=True,
|
22 |
+
add_bias_linear=False,
|
23 |
+
add_qkv_bias=False,
|
24 |
+
bias_dropout_fusion=True,
|
25 |
+
multi_query_attention=False,
|
26 |
+
multi_query_group_num=1,
|
27 |
+
apply_query_key_layer_scaling=True,
|
28 |
+
attention_softmax_in_fp32=True,
|
29 |
+
fp32_residual_connection=False,
|
30 |
+
quantization_bit=0,
|
31 |
+
pre_seq_len=None,
|
32 |
+
prefix_projection=False,
|
33 |
+
**kwargs
|
34 |
+
):
|
35 |
+
self.num_layers = num_layers
|
36 |
+
self.vocab_size = padded_vocab_size
|
37 |
+
self.padded_vocab_size = padded_vocab_size
|
38 |
+
self.hidden_size = hidden_size
|
39 |
+
self.ffn_hidden_size = ffn_hidden_size
|
40 |
+
self.kv_channels = kv_channels
|
41 |
+
self.num_attention_heads = num_attention_heads
|
42 |
+
self.seq_length = seq_length
|
43 |
+
self.hidden_dropout = hidden_dropout
|
44 |
+
self.classifier_dropout = classifier_dropout
|
45 |
+
self.attention_dropout = attention_dropout
|
46 |
+
self.layernorm_epsilon = layernorm_epsilon
|
47 |
+
self.rmsnorm = rmsnorm
|
48 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
49 |
+
self.post_layer_norm = post_layer_norm
|
50 |
+
self.add_bias_linear = add_bias_linear
|
51 |
+
self.add_qkv_bias = add_qkv_bias
|
52 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
53 |
+
self.multi_query_attention = multi_query_attention
|
54 |
+
self.multi_query_group_num = multi_query_group_num
|
55 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
56 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
57 |
+
self.fp32_residual_connection = fp32_residual_connection
|
58 |
+
self.quantization_bit = quantization_bit
|
59 |
+
self.pre_seq_len = pre_seq_len
|
60 |
+
self.prefix_projection = prefix_projection
|
61 |
+
super().__init__(**kwargs)
|
text_encoder/modeling_chatglm.py
ADDED
@@ -0,0 +1,1298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" PyTorch ChatGLM model. """
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import warnings
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
14 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
15 |
+
from torch.nn.utils import skip_init
|
16 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
17 |
+
from copy import deepcopy
|
18 |
+
|
19 |
+
from transformers.modeling_outputs import (
|
20 |
+
BaseModelOutputWithPast,
|
21 |
+
CausalLMOutputWithPast,
|
22 |
+
SequenceClassifierOutputWithPast,
|
23 |
+
)
|
24 |
+
from transformers.modeling_utils import PreTrainedModel
|
25 |
+
from transformers.utils import logging
|
26 |
+
from transformers.generation.logits_process import LogitsProcessor
|
27 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
28 |
+
|
29 |
+
try:
|
30 |
+
from .configuration_chatglm import ChatGLMConfig
|
31 |
+
except:
|
32 |
+
from configuration_chatglm import ChatGLMConfig
|
33 |
+
|
34 |
+
|
35 |
+
# flags required to enable jit fusion kernels
|
36 |
+
|
37 |
+
if sys.platform != 'darwin':
|
38 |
+
torch._C._jit_set_profiling_mode(False)
|
39 |
+
torch._C._jit_set_profiling_executor(False)
|
40 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
41 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
46 |
+
_CONFIG_FOR_DOC = "ChatGLM6BConfig"
|
47 |
+
|
48 |
+
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
49 |
+
"THUDM/chatglm3-6b-base",
|
50 |
+
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
def default_init(cls, *args, **kwargs):
|
55 |
+
return cls(*args, **kwargs)
|
56 |
+
|
57 |
+
|
58 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
59 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
60 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
61 |
+
scores.zero_()
|
62 |
+
scores[..., 5] = 5e4
|
63 |
+
return scores
|
64 |
+
|
65 |
+
|
66 |
+
class PrefixEncoder(torch.nn.Module):
|
67 |
+
"""
|
68 |
+
The torch.nn model to encode the prefix
|
69 |
+
Input shape: (batch-size, prefix-length)
|
70 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, config: ChatGLMConfig):
|
74 |
+
super().__init__()
|
75 |
+
self.prefix_projection = config.prefix_projection
|
76 |
+
if self.prefix_projection:
|
77 |
+
# Use a two-layer MLP to encode the prefix
|
78 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
79 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
80 |
+
self.trans = torch.nn.Sequential(
|
81 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
82 |
+
torch.nn.Tanh(),
|
83 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
87 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
88 |
+
|
89 |
+
def forward(self, prefix: torch.Tensor):
|
90 |
+
if self.prefix_projection:
|
91 |
+
prefix_tokens = self.embedding(prefix)
|
92 |
+
past_key_values = self.trans(prefix_tokens)
|
93 |
+
else:
|
94 |
+
past_key_values = self.embedding(prefix)
|
95 |
+
return past_key_values
|
96 |
+
|
97 |
+
|
98 |
+
def split_tensor_along_last_dim(
|
99 |
+
tensor: torch.Tensor,
|
100 |
+
num_partitions: int,
|
101 |
+
contiguous_split_chunks: bool = False,
|
102 |
+
) -> List[torch.Tensor]:
|
103 |
+
"""Split a tensor along its last dimension.
|
104 |
+
|
105 |
+
Arguments:
|
106 |
+
tensor: input tensor.
|
107 |
+
num_partitions: number of partitions to split the tensor
|
108 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
109 |
+
in memory.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
A list of Tensors
|
113 |
+
"""
|
114 |
+
# Get the size and dimension.
|
115 |
+
last_dim = tensor.dim() - 1
|
116 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
117 |
+
# Split.
|
118 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
119 |
+
# Note: torch.split does not create contiguous tensors by default.
|
120 |
+
if contiguous_split_chunks:
|
121 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
122 |
+
|
123 |
+
return tensor_list
|
124 |
+
|
125 |
+
|
126 |
+
class RotaryEmbedding(nn.Module):
|
127 |
+
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
128 |
+
super().__init__()
|
129 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
130 |
+
self.register_buffer("inv_freq", inv_freq)
|
131 |
+
self.dim = dim
|
132 |
+
self.original_impl = original_impl
|
133 |
+
|
134 |
+
def forward_impl(
|
135 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
136 |
+
):
|
137 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
138 |
+
|
139 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
140 |
+
transformers/rope/__init__.py. MIT License:
|
141 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
142 |
+
"""
|
143 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
144 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
145 |
+
|
146 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
147 |
+
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
148 |
+
|
149 |
+
# Calculate the product of position index and $\theta_i$
|
150 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
151 |
+
|
152 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
153 |
+
|
154 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
155 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
156 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
157 |
+
return cache
|
158 |
+
|
159 |
+
def forward(self, max_seq_len, offset=0):
|
160 |
+
return self.forward_impl(
|
161 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
@torch.jit.script
|
166 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
167 |
+
# x: [sq, b, np, hn]
|
168 |
+
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
169 |
+
rot_dim = rope_cache.shape[-2] * 2
|
170 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
171 |
+
# truncate to support variable sizes
|
172 |
+
rope_cache = rope_cache[:sq]
|
173 |
+
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
174 |
+
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
175 |
+
x_out2 = torch.stack(
|
176 |
+
[
|
177 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
178 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
179 |
+
],
|
180 |
+
-1,
|
181 |
+
)
|
182 |
+
x_out2 = x_out2.flatten(3)
|
183 |
+
return torch.cat((x_out2, x_pass), dim=-1)
|
184 |
+
|
185 |
+
|
186 |
+
class RMSNorm(torch.nn.Module):
|
187 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
188 |
+
super().__init__()
|
189 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
190 |
+
self.eps = eps
|
191 |
+
|
192 |
+
def forward(self, hidden_states: torch.Tensor):
|
193 |
+
input_dtype = hidden_states.dtype
|
194 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
195 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
196 |
+
|
197 |
+
return (self.weight * hidden_states).to(input_dtype)
|
198 |
+
|
199 |
+
|
200 |
+
class CoreAttention(torch.nn.Module):
|
201 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
202 |
+
super(CoreAttention, self).__init__()
|
203 |
+
|
204 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
205 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
206 |
+
if self.apply_query_key_layer_scaling:
|
207 |
+
self.attention_softmax_in_fp32 = True
|
208 |
+
self.layer_number = max(1, layer_number)
|
209 |
+
|
210 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
211 |
+
|
212 |
+
# Per attention head and per partition values.
|
213 |
+
self.hidden_size_per_partition = projection_size
|
214 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
215 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
216 |
+
|
217 |
+
coeff = None
|
218 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
219 |
+
if self.apply_query_key_layer_scaling:
|
220 |
+
coeff = self.layer_number
|
221 |
+
self.norm_factor *= coeff
|
222 |
+
self.coeff = coeff
|
223 |
+
|
224 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
225 |
+
|
226 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
227 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
228 |
+
if pytorch_major_version >= 2:
|
229 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
230 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
231 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
232 |
+
is_causal=True)
|
233 |
+
else:
|
234 |
+
if attention_mask is not None:
|
235 |
+
attention_mask = ~attention_mask
|
236 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
237 |
+
attention_mask)
|
238 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
239 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
240 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
241 |
+
else:
|
242 |
+
# Raw attention scores
|
243 |
+
|
244 |
+
# [b, np, sq, sk]
|
245 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
246 |
+
|
247 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
248 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
249 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
250 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
251 |
+
|
252 |
+
# preallocting input tensor: [b * np, sq, sk]
|
253 |
+
matmul_input_buffer = torch.empty(
|
254 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
255 |
+
device=query_layer.device
|
256 |
+
)
|
257 |
+
|
258 |
+
# Raw attention scores. [b * np, sq, sk]
|
259 |
+
matmul_result = torch.baddbmm(
|
260 |
+
matmul_input_buffer,
|
261 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
262 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
263 |
+
beta=0.0,
|
264 |
+
alpha=(1.0 / self.norm_factor),
|
265 |
+
)
|
266 |
+
|
267 |
+
# change view to [b, np, sq, sk]
|
268 |
+
attention_scores = matmul_result.view(*output_size)
|
269 |
+
|
270 |
+
# ===========================
|
271 |
+
# Attention probs and dropout
|
272 |
+
# ===========================
|
273 |
+
|
274 |
+
# attention scores and attention mask [b, np, sq, sk]
|
275 |
+
if self.attention_softmax_in_fp32:
|
276 |
+
attention_scores = attention_scores.float()
|
277 |
+
if self.coeff is not None:
|
278 |
+
attention_scores = attention_scores * self.coeff
|
279 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
280 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
281 |
+
device=attention_scores.device, dtype=torch.bool)
|
282 |
+
attention_mask.tril_()
|
283 |
+
attention_mask = ~attention_mask
|
284 |
+
if attention_mask is not None:
|
285 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
286 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
287 |
+
attention_probs = attention_probs.type_as(value_layer)
|
288 |
+
|
289 |
+
# This is actually dropping out entire tokens to attend to, which might
|
290 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
291 |
+
attention_probs = self.attention_dropout(attention_probs)
|
292 |
+
# =========================
|
293 |
+
# Context layer. [sq, b, hp]
|
294 |
+
# =========================
|
295 |
+
|
296 |
+
# value_layer -> context layer.
|
297 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
298 |
+
|
299 |
+
# context layer shape: [b, np, sq, hn]
|
300 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
301 |
+
# change view [sk, b * np, hn]
|
302 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
303 |
+
# change view [b * np, sq, sk]
|
304 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
305 |
+
# matmul: [b * np, sq, hn]
|
306 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
307 |
+
# change view [b, np, sq, hn]
|
308 |
+
context_layer = context_layer.view(*output_size)
|
309 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
310 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
311 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
312 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
313 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
314 |
+
|
315 |
+
return context_layer
|
316 |
+
|
317 |
+
|
318 |
+
class SelfAttention(torch.nn.Module):
|
319 |
+
"""Parallel self-attention layer abstract class.
|
320 |
+
|
321 |
+
Self-attention layer takes input with size [s, b, h]
|
322 |
+
and returns output of the same size.
|
323 |
+
"""
|
324 |
+
|
325 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
326 |
+
super(SelfAttention, self).__init__()
|
327 |
+
self.layer_number = max(1, layer_number)
|
328 |
+
|
329 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
330 |
+
|
331 |
+
# Per attention head and per partition values.
|
332 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
333 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
334 |
+
|
335 |
+
self.multi_query_attention = config.multi_query_attention
|
336 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
337 |
+
if self.multi_query_attention:
|
338 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
339 |
+
self.qkv_hidden_size = (
|
340 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
341 |
+
)
|
342 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
343 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
344 |
+
device=device, **_config_to_kwargs(config)
|
345 |
+
)
|
346 |
+
|
347 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
348 |
+
|
349 |
+
# Output.
|
350 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
351 |
+
device=device, **_config_to_kwargs(config)
|
352 |
+
)
|
353 |
+
|
354 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
355 |
+
if self.multi_query_attention:
|
356 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
357 |
+
else:
|
358 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
359 |
+
return torch.empty(
|
360 |
+
inference_max_sequence_len,
|
361 |
+
batch_size,
|
362 |
+
num_attention_heads,
|
363 |
+
self.hidden_size_per_attention_head,
|
364 |
+
dtype=dtype,
|
365 |
+
device=device,
|
366 |
+
)
|
367 |
+
|
368 |
+
def forward(
|
369 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
370 |
+
):
|
371 |
+
# hidden_states: [sq, b, h]
|
372 |
+
|
373 |
+
# =================================================
|
374 |
+
# Pre-allocate memory for key-values for inference.
|
375 |
+
# =================================================
|
376 |
+
# =====================
|
377 |
+
# Query, Key, and Value
|
378 |
+
# =====================
|
379 |
+
|
380 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
381 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
382 |
+
|
383 |
+
if self.multi_query_attention:
|
384 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
385 |
+
[
|
386 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
387 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
388 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
389 |
+
],
|
390 |
+
dim=-1,
|
391 |
+
)
|
392 |
+
query_layer = query_layer.view(
|
393 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
394 |
+
)
|
395 |
+
key_layer = key_layer.view(
|
396 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
397 |
+
)
|
398 |
+
value_layer = value_layer.view(
|
399 |
+
value_layer.size()[:-1]
|
400 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
404 |
+
(self.num_attention_heads_per_partition,
|
405 |
+
3 * self.hidden_size_per_attention_head)
|
406 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
407 |
+
|
408 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
409 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
410 |
+
|
411 |
+
# apply relative positional encoding (rotary embedding)
|
412 |
+
if rotary_pos_emb is not None:
|
413 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
414 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
415 |
+
|
416 |
+
# adjust key and value for inference
|
417 |
+
if kv_cache is not None:
|
418 |
+
cache_k, cache_v = kv_cache
|
419 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
420 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
421 |
+
if use_cache:
|
422 |
+
kv_cache = (key_layer, value_layer)
|
423 |
+
else:
|
424 |
+
kv_cache = None
|
425 |
+
|
426 |
+
if self.multi_query_attention:
|
427 |
+
key_layer = key_layer.unsqueeze(-2)
|
428 |
+
key_layer = key_layer.expand(
|
429 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
430 |
+
)
|
431 |
+
key_layer = key_layer.contiguous().view(
|
432 |
+
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
433 |
+
)
|
434 |
+
value_layer = value_layer.unsqueeze(-2)
|
435 |
+
value_layer = value_layer.expand(
|
436 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
437 |
+
)
|
438 |
+
value_layer = value_layer.contiguous().view(
|
439 |
+
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
440 |
+
)
|
441 |
+
|
442 |
+
# ==================================
|
443 |
+
# core attention computation
|
444 |
+
# ==================================
|
445 |
+
|
446 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
447 |
+
|
448 |
+
# =================
|
449 |
+
# Output. [sq, b, h]
|
450 |
+
# =================
|
451 |
+
|
452 |
+
output = self.dense(context_layer)
|
453 |
+
|
454 |
+
return output, kv_cache
|
455 |
+
|
456 |
+
|
457 |
+
def _config_to_kwargs(args):
|
458 |
+
common_kwargs = {
|
459 |
+
"dtype": args.torch_dtype,
|
460 |
+
}
|
461 |
+
return common_kwargs
|
462 |
+
|
463 |
+
|
464 |
+
class MLP(torch.nn.Module):
|
465 |
+
"""MLP.
|
466 |
+
|
467 |
+
MLP will take the input with h hidden state, project it to 4*h
|
468 |
+
hidden dimension, perform nonlinear transformation, and project the
|
469 |
+
state back into h hidden dimension.
|
470 |
+
"""
|
471 |
+
|
472 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
473 |
+
super(MLP, self).__init__()
|
474 |
+
|
475 |
+
self.add_bias = config.add_bias_linear
|
476 |
+
|
477 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
478 |
+
self.dense_h_to_4h = nn.Linear(
|
479 |
+
config.hidden_size,
|
480 |
+
config.ffn_hidden_size * 2,
|
481 |
+
bias=self.add_bias,
|
482 |
+
device=device,
|
483 |
+
**_config_to_kwargs(config)
|
484 |
+
)
|
485 |
+
|
486 |
+
def swiglu(x):
|
487 |
+
x = torch.chunk(x, 2, dim=-1)
|
488 |
+
return F.silu(x[0]) * x[1]
|
489 |
+
|
490 |
+
self.activation_func = swiglu
|
491 |
+
|
492 |
+
# Project back to h.
|
493 |
+
self.dense_4h_to_h = nn.Linear(
|
494 |
+
config.ffn_hidden_size,
|
495 |
+
config.hidden_size,
|
496 |
+
bias=self.add_bias,
|
497 |
+
device=device,
|
498 |
+
**_config_to_kwargs(config)
|
499 |
+
)
|
500 |
+
|
501 |
+
def forward(self, hidden_states):
|
502 |
+
# [s, b, 4hp]
|
503 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
504 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
505 |
+
# [s, b, h]
|
506 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
507 |
+
return output
|
508 |
+
|
509 |
+
|
510 |
+
class GLMBlock(torch.nn.Module):
|
511 |
+
"""A single transformer layer.
|
512 |
+
|
513 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
514 |
+
output of the same size.
|
515 |
+
"""
|
516 |
+
|
517 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
518 |
+
super(GLMBlock, self).__init__()
|
519 |
+
self.layer_number = layer_number
|
520 |
+
|
521 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
522 |
+
|
523 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
524 |
+
|
525 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
526 |
+
# Layernorm on the input data.
|
527 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
528 |
+
dtype=config.torch_dtype)
|
529 |
+
|
530 |
+
# Self attention.
|
531 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
532 |
+
self.hidden_dropout = config.hidden_dropout
|
533 |
+
|
534 |
+
# Layernorm on the attention output
|
535 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
536 |
+
dtype=config.torch_dtype)
|
537 |
+
|
538 |
+
# MLP
|
539 |
+
self.mlp = MLP(config, device=device)
|
540 |
+
|
541 |
+
def forward(
|
542 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
543 |
+
):
|
544 |
+
# hidden_states: [s, b, h]
|
545 |
+
|
546 |
+
# Layer norm at the beginning of the transformer layer.
|
547 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
548 |
+
# Self attention.
|
549 |
+
attention_output, kv_cache = self.self_attention(
|
550 |
+
layernorm_output,
|
551 |
+
attention_mask,
|
552 |
+
rotary_pos_emb,
|
553 |
+
kv_cache=kv_cache,
|
554 |
+
use_cache=use_cache
|
555 |
+
)
|
556 |
+
|
557 |
+
# Residual connection.
|
558 |
+
if self.apply_residual_connection_post_layernorm:
|
559 |
+
residual = layernorm_output
|
560 |
+
else:
|
561 |
+
residual = hidden_states
|
562 |
+
|
563 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
564 |
+
layernorm_input = residual + layernorm_input
|
565 |
+
|
566 |
+
# Layer norm post the self attention.
|
567 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
568 |
+
|
569 |
+
# MLP.
|
570 |
+
mlp_output = self.mlp(layernorm_output)
|
571 |
+
|
572 |
+
# Second residual connection.
|
573 |
+
if self.apply_residual_connection_post_layernorm:
|
574 |
+
residual = layernorm_output
|
575 |
+
else:
|
576 |
+
residual = layernorm_input
|
577 |
+
|
578 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
579 |
+
output = residual + output
|
580 |
+
|
581 |
+
return output, kv_cache
|
582 |
+
|
583 |
+
|
584 |
+
class GLMTransformer(torch.nn.Module):
|
585 |
+
"""Transformer class."""
|
586 |
+
|
587 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
588 |
+
super(GLMTransformer, self).__init__()
|
589 |
+
|
590 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
591 |
+
self.post_layer_norm = config.post_layer_norm
|
592 |
+
|
593 |
+
# Number of layers.
|
594 |
+
self.num_layers = config.num_layers
|
595 |
+
|
596 |
+
# Transformer layers.
|
597 |
+
def build_layer(layer_number):
|
598 |
+
return GLMBlock(config, layer_number, device=device)
|
599 |
+
|
600 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
601 |
+
|
602 |
+
if self.post_layer_norm:
|
603 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
604 |
+
# Final layer norm before output.
|
605 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
606 |
+
dtype=config.torch_dtype)
|
607 |
+
|
608 |
+
self.gradient_checkpointing = False
|
609 |
+
|
610 |
+
def _get_layer(self, layer_number):
|
611 |
+
return self.layers[layer_number]
|
612 |
+
|
613 |
+
def forward(
|
614 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
615 |
+
use_cache: Optional[bool] = True,
|
616 |
+
output_hidden_states: Optional[bool] = False,
|
617 |
+
):
|
618 |
+
if not kv_caches:
|
619 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
620 |
+
presents = () if use_cache else None
|
621 |
+
if self.gradient_checkpointing and self.training:
|
622 |
+
if use_cache:
|
623 |
+
logger.warning_once(
|
624 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
625 |
+
)
|
626 |
+
use_cache = False
|
627 |
+
|
628 |
+
all_self_attentions = None
|
629 |
+
all_hidden_states = () if output_hidden_states else None
|
630 |
+
for index in range(self.num_layers):
|
631 |
+
if output_hidden_states:
|
632 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
633 |
+
|
634 |
+
layer = self._get_layer(index)
|
635 |
+
if self.gradient_checkpointing and self.training:
|
636 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
637 |
+
layer,
|
638 |
+
hidden_states,
|
639 |
+
attention_mask,
|
640 |
+
rotary_pos_emb,
|
641 |
+
kv_caches[index],
|
642 |
+
use_cache
|
643 |
+
)
|
644 |
+
else:
|
645 |
+
layer_ret = layer(
|
646 |
+
hidden_states,
|
647 |
+
attention_mask,
|
648 |
+
rotary_pos_emb,
|
649 |
+
kv_cache=kv_caches[index],
|
650 |
+
use_cache=use_cache
|
651 |
+
)
|
652 |
+
hidden_states, kv_cache = layer_ret
|
653 |
+
if use_cache:
|
654 |
+
presents = presents + (kv_cache,)
|
655 |
+
|
656 |
+
if output_hidden_states:
|
657 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
658 |
+
|
659 |
+
# Final layer norm.
|
660 |
+
if self.post_layer_norm:
|
661 |
+
hidden_states = self.final_layernorm(hidden_states)
|
662 |
+
|
663 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
664 |
+
|
665 |
+
|
666 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
667 |
+
"""
|
668 |
+
An abstract class to handle weights initialization and
|
669 |
+
a simple interface for downloading and loading pretrained models.
|
670 |
+
"""
|
671 |
+
|
672 |
+
is_parallelizable = False
|
673 |
+
supports_gradient_checkpointing = True
|
674 |
+
config_class = ChatGLMConfig
|
675 |
+
base_model_prefix = "transformer"
|
676 |
+
_no_split_modules = ["GLMBlock"]
|
677 |
+
|
678 |
+
def _init_weights(self, module: nn.Module):
|
679 |
+
"""Initialize the weights."""
|
680 |
+
return
|
681 |
+
|
682 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
683 |
+
batch_size, seq_length = input_ids.shape
|
684 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
685 |
+
full_attention_mask.tril_()
|
686 |
+
past_length = 0
|
687 |
+
if past_key_values:
|
688 |
+
past_length = past_key_values[0][0].shape[0]
|
689 |
+
if past_length:
|
690 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
691 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
692 |
+
if padding_mask is not None:
|
693 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
694 |
+
if not past_length and padding_mask is not None:
|
695 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
696 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
697 |
+
full_attention_mask.unsqueeze_(1)
|
698 |
+
return full_attention_mask
|
699 |
+
|
700 |
+
def get_position_ids(self, input_ids, device):
|
701 |
+
batch_size, seq_length = input_ids.shape
|
702 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
703 |
+
return position_ids
|
704 |
+
|
705 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
706 |
+
if isinstance(module, GLMTransformer):
|
707 |
+
module.gradient_checkpointing = value
|
708 |
+
|
709 |
+
|
710 |
+
class Embedding(torch.nn.Module):
|
711 |
+
"""Language model embeddings."""
|
712 |
+
|
713 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
714 |
+
super(Embedding, self).__init__()
|
715 |
+
|
716 |
+
self.hidden_size = config.hidden_size
|
717 |
+
# Word embeddings (parallel).
|
718 |
+
self.word_embeddings = nn.Embedding(
|
719 |
+
config.padded_vocab_size,
|
720 |
+
self.hidden_size,
|
721 |
+
dtype=config.torch_dtype,
|
722 |
+
device=device
|
723 |
+
)
|
724 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
725 |
+
|
726 |
+
def forward(self, input_ids):
|
727 |
+
# Embeddings.
|
728 |
+
words_embeddings = self.word_embeddings(input_ids)
|
729 |
+
embeddings = words_embeddings
|
730 |
+
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
731 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
732 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
733 |
+
if self.fp32_residual_connection:
|
734 |
+
embeddings = embeddings.float()
|
735 |
+
return embeddings
|
736 |
+
|
737 |
+
|
738 |
+
class ChatGLMModel(ChatGLMPreTrainedModel):
|
739 |
+
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
740 |
+
super().__init__(config)
|
741 |
+
if empty_init:
|
742 |
+
init_method = skip_init
|
743 |
+
else:
|
744 |
+
init_method = default_init
|
745 |
+
init_kwargs = {}
|
746 |
+
if device is not None:
|
747 |
+
init_kwargs["device"] = device
|
748 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
749 |
+
self.num_layers = config.num_layers
|
750 |
+
self.multi_query_group_num = config.multi_query_group_num
|
751 |
+
self.kv_channels = config.kv_channels
|
752 |
+
|
753 |
+
# Rotary positional embeddings
|
754 |
+
self.seq_length = config.seq_length
|
755 |
+
rotary_dim = (
|
756 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
757 |
+
)
|
758 |
+
|
759 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
760 |
+
dtype=config.torch_dtype)
|
761 |
+
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
762 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
763 |
+
dtype=config.torch_dtype, **init_kwargs)
|
764 |
+
self.pre_seq_len = config.pre_seq_len
|
765 |
+
self.prefix_projection = config.prefix_projection
|
766 |
+
if self.pre_seq_len is not None:
|
767 |
+
for param in self.parameters():
|
768 |
+
param.requires_grad = False
|
769 |
+
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
770 |
+
self.prefix_encoder = PrefixEncoder(config)
|
771 |
+
self.dropout = torch.nn.Dropout(0.1)
|
772 |
+
|
773 |
+
def get_input_embeddings(self):
|
774 |
+
return self.embedding.word_embeddings
|
775 |
+
|
776 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
777 |
+
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
778 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
779 |
+
past_key_values = past_key_values.view(
|
780 |
+
batch_size,
|
781 |
+
self.pre_seq_len,
|
782 |
+
self.num_layers * 2,
|
783 |
+
self.multi_query_group_num,
|
784 |
+
self.kv_channels
|
785 |
+
)
|
786 |
+
# seq_len, b, nh, hidden_size
|
787 |
+
past_key_values = self.dropout(past_key_values)
|
788 |
+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
789 |
+
return past_key_values
|
790 |
+
|
791 |
+
def forward(
|
792 |
+
self,
|
793 |
+
input_ids,
|
794 |
+
position_ids: Optional[torch.Tensor] = None,
|
795 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
796 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
797 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
798 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
799 |
+
use_cache: Optional[bool] = None,
|
800 |
+
output_hidden_states: Optional[bool] = None,
|
801 |
+
return_dict: Optional[bool] = None,
|
802 |
+
):
|
803 |
+
output_hidden_states = (
|
804 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
805 |
+
)
|
806 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
807 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
808 |
+
|
809 |
+
batch_size, seq_length = input_ids.shape
|
810 |
+
|
811 |
+
if inputs_embeds is None:
|
812 |
+
inputs_embeds = self.embedding(input_ids)
|
813 |
+
|
814 |
+
if self.pre_seq_len is not None:
|
815 |
+
if past_key_values is None:
|
816 |
+
past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
|
817 |
+
dtype=inputs_embeds.dtype)
|
818 |
+
if attention_mask is not None:
|
819 |
+
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
820 |
+
attention_mask], dim=-1)
|
821 |
+
|
822 |
+
if full_attention_mask is None:
|
823 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
824 |
+
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
825 |
+
|
826 |
+
# Rotary positional embeddings
|
827 |
+
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
828 |
+
if position_ids is not None:
|
829 |
+
rotary_pos_emb = rotary_pos_emb[position_ids]
|
830 |
+
else:
|
831 |
+
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
832 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
833 |
+
|
834 |
+
# Run encoder.
|
835 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
836 |
+
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
837 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
838 |
+
)
|
839 |
+
|
840 |
+
if not return_dict:
|
841 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
842 |
+
|
843 |
+
return BaseModelOutputWithPast(
|
844 |
+
last_hidden_state=hidden_states,
|
845 |
+
past_key_values=presents,
|
846 |
+
hidden_states=all_hidden_states,
|
847 |
+
attentions=all_self_attentions,
|
848 |
+
)
|
849 |
+
|
850 |
+
def quantize(self, weight_bit_width: int):
|
851 |
+
from .quantization import quantize
|
852 |
+
quantize(self.encoder, weight_bit_width)
|
853 |
+
return self
|
854 |
+
|
855 |
+
|
856 |
+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
857 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
858 |
+
super().__init__(config)
|
859 |
+
|
860 |
+
self.max_sequence_length = config.max_length
|
861 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
862 |
+
self.config = config
|
863 |
+
self.quantized = False
|
864 |
+
|
865 |
+
if self.config.quantization_bit:
|
866 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
867 |
+
|
868 |
+
def _update_model_kwargs_for_generation(
|
869 |
+
self,
|
870 |
+
outputs: ModelOutput,
|
871 |
+
model_kwargs: Dict[str, Any],
|
872 |
+
is_encoder_decoder: bool = False,
|
873 |
+
standardize_cache_format: bool = False,
|
874 |
+
) -> Dict[str, Any]:
|
875 |
+
# update past_key_values
|
876 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
877 |
+
outputs, standardize_cache_format=standardize_cache_format
|
878 |
+
)
|
879 |
+
|
880 |
+
# update attention mask
|
881 |
+
if "attention_mask" in model_kwargs:
|
882 |
+
attention_mask = model_kwargs["attention_mask"]
|
883 |
+
model_kwargs["attention_mask"] = torch.cat(
|
884 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
885 |
+
)
|
886 |
+
|
887 |
+
# update position ids
|
888 |
+
if "position_ids" in model_kwargs:
|
889 |
+
position_ids = model_kwargs["position_ids"]
|
890 |
+
new_position_id = position_ids[..., -1:].clone()
|
891 |
+
new_position_id += 1
|
892 |
+
model_kwargs["position_ids"] = torch.cat(
|
893 |
+
[position_ids, new_position_id], dim=-1
|
894 |
+
)
|
895 |
+
|
896 |
+
model_kwargs["is_first_forward"] = False
|
897 |
+
return model_kwargs
|
898 |
+
|
899 |
+
def prepare_inputs_for_generation(
|
900 |
+
self,
|
901 |
+
input_ids: torch.LongTensor,
|
902 |
+
past_key_values: Optional[torch.Tensor] = None,
|
903 |
+
attention_mask: Optional[torch.Tensor] = None,
|
904 |
+
position_ids: Optional[torch.Tensor] = None,
|
905 |
+
use_cache: Optional[bool] = None,
|
906 |
+
is_first_forward: bool = True,
|
907 |
+
**kwargs
|
908 |
+
) -> dict:
|
909 |
+
# only last token for input_ids if past is not None
|
910 |
+
if position_ids is None:
|
911 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
912 |
+
if not is_first_forward:
|
913 |
+
if past_key_values is not None:
|
914 |
+
position_ids = position_ids[..., -1:]
|
915 |
+
input_ids = input_ids[:, -1:]
|
916 |
+
return {
|
917 |
+
"input_ids": input_ids,
|
918 |
+
"past_key_values": past_key_values,
|
919 |
+
"position_ids": position_ids,
|
920 |
+
"attention_mask": attention_mask,
|
921 |
+
"return_last_logit": True,
|
922 |
+
"use_cache": use_cache
|
923 |
+
}
|
924 |
+
|
925 |
+
def forward(
|
926 |
+
self,
|
927 |
+
input_ids: Optional[torch.Tensor] = None,
|
928 |
+
position_ids: Optional[torch.Tensor] = None,
|
929 |
+
attention_mask: Optional[torch.Tensor] = None,
|
930 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
931 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
932 |
+
labels: Optional[torch.Tensor] = None,
|
933 |
+
use_cache: Optional[bool] = None,
|
934 |
+
output_attentions: Optional[bool] = None,
|
935 |
+
output_hidden_states: Optional[bool] = None,
|
936 |
+
return_dict: Optional[bool] = None,
|
937 |
+
return_last_logit: Optional[bool] = False,
|
938 |
+
):
|
939 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
940 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
941 |
+
|
942 |
+
transformer_outputs = self.transformer(
|
943 |
+
input_ids=input_ids,
|
944 |
+
position_ids=position_ids,
|
945 |
+
attention_mask=attention_mask,
|
946 |
+
past_key_values=past_key_values,
|
947 |
+
inputs_embeds=inputs_embeds,
|
948 |
+
use_cache=use_cache,
|
949 |
+
output_hidden_states=output_hidden_states,
|
950 |
+
return_dict=return_dict,
|
951 |
+
)
|
952 |
+
|
953 |
+
hidden_states = transformer_outputs[0]
|
954 |
+
if return_last_logit:
|
955 |
+
hidden_states = hidden_states[-1:]
|
956 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
957 |
+
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
958 |
+
|
959 |
+
loss = None
|
960 |
+
if labels is not None:
|
961 |
+
lm_logits = lm_logits.to(torch.float32)
|
962 |
+
|
963 |
+
# Shift so that tokens < n predict n
|
964 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
965 |
+
shift_labels = labels[..., 1:].contiguous()
|
966 |
+
# Flatten the tokens
|
967 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
968 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
969 |
+
|
970 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
971 |
+
loss = loss.to(hidden_states.dtype)
|
972 |
+
|
973 |
+
if not return_dict:
|
974 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
975 |
+
return ((loss,) + output) if loss is not None else output
|
976 |
+
|
977 |
+
return CausalLMOutputWithPast(
|
978 |
+
loss=loss,
|
979 |
+
logits=lm_logits,
|
980 |
+
past_key_values=transformer_outputs.past_key_values,
|
981 |
+
hidden_states=transformer_outputs.hidden_states,
|
982 |
+
attentions=transformer_outputs.attentions,
|
983 |
+
)
|
984 |
+
|
985 |
+
@staticmethod
|
986 |
+
def _reorder_cache(
|
987 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
988 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
989 |
+
"""
|
990 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
991 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
992 |
+
beam_idx at every generation step.
|
993 |
+
|
994 |
+
Output shares the same memory storage as `past`.
|
995 |
+
"""
|
996 |
+
return tuple(
|
997 |
+
(
|
998 |
+
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
|
999 |
+
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
|
1000 |
+
)
|
1001 |
+
for layer_past in past
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
def process_response(self, output, history):
|
1005 |
+
content = ""
|
1006 |
+
history = deepcopy(history)
|
1007 |
+
for response in output.split("<|assistant|>"):
|
1008 |
+
metadata, content = response.split("\n", maxsplit=1)
|
1009 |
+
if not metadata.strip():
|
1010 |
+
content = content.strip()
|
1011 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1012 |
+
content = content.replace("[[训练时间]]", "2023年")
|
1013 |
+
else:
|
1014 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1015 |
+
if history[0]["role"] == "system" and "tools" in history[0]:
|
1016 |
+
content = "\n".join(content.split("\n")[1:-1])
|
1017 |
+
def tool_call(**kwargs):
|
1018 |
+
return kwargs
|
1019 |
+
parameters = eval(content)
|
1020 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
1021 |
+
else:
|
1022 |
+
content = {"name": metadata.strip(), "content": content}
|
1023 |
+
return content, history
|
1024 |
+
|
1025 |
+
@torch.inference_mode()
|
1026 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1027 |
+
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1028 |
+
**kwargs):
|
1029 |
+
if history is None:
|
1030 |
+
history = []
|
1031 |
+
if logits_processor is None:
|
1032 |
+
logits_processor = LogitsProcessorList()
|
1033 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1034 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1035 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1036 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1037 |
+
inputs = inputs.to(self.device)
|
1038 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1039 |
+
tokenizer.get_command("<|observation|>")]
|
1040 |
+
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1041 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1042 |
+
response = tokenizer.decode(outputs)
|
1043 |
+
history.append({"role": role, "content": query})
|
1044 |
+
response, history = self.process_response(response, history)
|
1045 |
+
return response, history
|
1046 |
+
|
1047 |
+
@torch.inference_mode()
|
1048 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1049 |
+
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1050 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
1051 |
+
if history is None:
|
1052 |
+
history = []
|
1053 |
+
if logits_processor is None:
|
1054 |
+
logits_processor = LogitsProcessorList()
|
1055 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1056 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1057 |
+
tokenizer.get_command("<|observation|>")]
|
1058 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1059 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1060 |
+
if past_key_values is None:
|
1061 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1062 |
+
else:
|
1063 |
+
inputs = tokenizer.build_chat_input(query, role=role)
|
1064 |
+
inputs = inputs.to(self.device)
|
1065 |
+
if past_key_values is not None:
|
1066 |
+
past_length = past_key_values[0][0].shape[0]
|
1067 |
+
if self.transformer.pre_seq_len is not None:
|
1068 |
+
past_length -= self.transformer.pre_seq_len
|
1069 |
+
inputs.position_ids += past_length
|
1070 |
+
attention_mask = inputs.attention_mask
|
1071 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1072 |
+
inputs['attention_mask'] = attention_mask
|
1073 |
+
history.append({"role": role, "content": query})
|
1074 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1075 |
+
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
1076 |
+
**gen_kwargs):
|
1077 |
+
if return_past_key_values:
|
1078 |
+
outputs, past_key_values = outputs
|
1079 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1080 |
+
response = tokenizer.decode(outputs)
|
1081 |
+
if response and response[-1] != "�":
|
1082 |
+
response, new_history = self.process_response(response, history)
|
1083 |
+
if return_past_key_values:
|
1084 |
+
yield response, new_history, past_key_values
|
1085 |
+
else:
|
1086 |
+
yield response, new_history
|
1087 |
+
|
1088 |
+
@torch.inference_mode()
|
1089 |
+
def stream_generate(
|
1090 |
+
self,
|
1091 |
+
input_ids,
|
1092 |
+
generation_config: Optional[GenerationConfig] = None,
|
1093 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
1094 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
1095 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
1096 |
+
return_past_key_values=False,
|
1097 |
+
**kwargs,
|
1098 |
+
):
|
1099 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
1100 |
+
|
1101 |
+
if generation_config is None:
|
1102 |
+
generation_config = self.generation_config
|
1103 |
+
generation_config = copy.deepcopy(generation_config)
|
1104 |
+
model_kwargs = generation_config.update(**kwargs)
|
1105 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
1106 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
1107 |
+
|
1108 |
+
if isinstance(eos_token_id, int):
|
1109 |
+
eos_token_id = [eos_token_id]
|
1110 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
1111 |
+
|
1112 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
1113 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
1114 |
+
warnings.warn(
|
1115 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
1116 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
1117 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
1118 |
+
UserWarning,
|
1119 |
+
)
|
1120 |
+
elif generation_config.max_new_tokens is not None:
|
1121 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
1122 |
+
if not has_default_max_length:
|
1123 |
+
logger.warn(
|
1124 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
1125 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
1126 |
+
"Please refer to the documentation for more information. "
|
1127 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
1128 |
+
UserWarning,
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
if input_ids_seq_length >= generation_config.max_length:
|
1132 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
1133 |
+
logger.warning(
|
1134 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
1135 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
1136 |
+
" increasing `max_new_tokens`."
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
# 2. Set generation parameters if not already defined
|
1140 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
1141 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
1142 |
+
|
1143 |
+
logits_processor = self._get_logits_processor(
|
1144 |
+
generation_config=generation_config,
|
1145 |
+
input_ids_seq_length=input_ids_seq_length,
|
1146 |
+
encoder_input_ids=input_ids,
|
1147 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
1148 |
+
logits_processor=logits_processor,
|
1149 |
+
)
|
1150 |
+
|
1151 |
+
stopping_criteria = self._get_stopping_criteria(
|
1152 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
1153 |
+
)
|
1154 |
+
logits_warper = self._get_logits_warper(generation_config)
|
1155 |
+
|
1156 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
1157 |
+
scores = None
|
1158 |
+
while True:
|
1159 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
1160 |
+
# forward pass to get next token
|
1161 |
+
outputs = self(
|
1162 |
+
**model_inputs,
|
1163 |
+
return_dict=True,
|
1164 |
+
output_attentions=False,
|
1165 |
+
output_hidden_states=False,
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
next_token_logits = outputs.logits[:, -1, :]
|
1169 |
+
|
1170 |
+
# pre-process distribution
|
1171 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
1172 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
1173 |
+
|
1174 |
+
# sample
|
1175 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
1176 |
+
if generation_config.do_sample:
|
1177 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
1178 |
+
else:
|
1179 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
1180 |
+
# update generated ids, model inputs, and length for next step
|
1181 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
1182 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
1183 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
1184 |
+
)
|
1185 |
+
unfinished_sequences = unfinished_sequences.mul(
|
1186 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
1187 |
+
)
|
1188 |
+
if return_past_key_values:
|
1189 |
+
yield input_ids, outputs.past_key_values
|
1190 |
+
else:
|
1191 |
+
yield input_ids
|
1192 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
1193 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
1194 |
+
break
|
1195 |
+
|
1196 |
+
def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
|
1197 |
+
if bits == 0:
|
1198 |
+
return
|
1199 |
+
|
1200 |
+
from .quantization import quantize
|
1201 |
+
|
1202 |
+
if self.quantized:
|
1203 |
+
logger.info("Already quantized.")
|
1204 |
+
return self
|
1205 |
+
|
1206 |
+
self.quantized = True
|
1207 |
+
|
1208 |
+
self.config.quantization_bit = bits
|
1209 |
+
|
1210 |
+
self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
|
1211 |
+
**kwargs)
|
1212 |
+
return self
|
1213 |
+
|
1214 |
+
|
1215 |
+
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
1216 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1217 |
+
super().__init__(config)
|
1218 |
+
|
1219 |
+
self.num_labels = config.num_labels
|
1220 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1221 |
+
|
1222 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
1223 |
+
if config.classifier_dropout is not None:
|
1224 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
1225 |
+
else:
|
1226 |
+
self.dropout = None
|
1227 |
+
self.config = config
|
1228 |
+
|
1229 |
+
if self.config.quantization_bit:
|
1230 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
1231 |
+
|
1232 |
+
def forward(
|
1233 |
+
self,
|
1234 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1235 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1236 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1237 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1238 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1239 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
1240 |
+
labels: Optional[torch.LongTensor] = None,
|
1241 |
+
use_cache: Optional[bool] = None,
|
1242 |
+
output_hidden_states: Optional[bool] = None,
|
1243 |
+
return_dict: Optional[bool] = None,
|
1244 |
+
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
1245 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1246 |
+
|
1247 |
+
transformer_outputs = self.transformer(
|
1248 |
+
input_ids=input_ids,
|
1249 |
+
position_ids=position_ids,
|
1250 |
+
attention_mask=attention_mask,
|
1251 |
+
full_attention_mask=full_attention_mask,
|
1252 |
+
past_key_values=past_key_values,
|
1253 |
+
inputs_embeds=inputs_embeds,
|
1254 |
+
use_cache=use_cache,
|
1255 |
+
output_hidden_states=output_hidden_states,
|
1256 |
+
return_dict=return_dict,
|
1257 |
+
)
|
1258 |
+
|
1259 |
+
hidden_states = transformer_outputs[0]
|
1260 |
+
pooled_hidden_states = hidden_states[-1]
|
1261 |
+
if self.dropout is not None:
|
1262 |
+
pooled_hidden_states = self.dropout(pooled_hidden_states)
|
1263 |
+
logits = self.classifier_head(pooled_hidden_states)
|
1264 |
+
|
1265 |
+
loss = None
|
1266 |
+
if labels is not None:
|
1267 |
+
if self.config.problem_type is None:
|
1268 |
+
if self.num_labels == 1:
|
1269 |
+
self.config.problem_type = "regression"
|
1270 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1271 |
+
self.config.problem_type = "single_label_classification"
|
1272 |
+
else:
|
1273 |
+
self.config.problem_type = "multi_label_classification"
|
1274 |
+
|
1275 |
+
if self.config.problem_type == "regression":
|
1276 |
+
loss_fct = MSELoss()
|
1277 |
+
if self.num_labels == 1:
|
1278 |
+
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
|
1279 |
+
else:
|
1280 |
+
loss = loss_fct(logits.float(), labels)
|
1281 |
+
elif self.config.problem_type == "single_label_classification":
|
1282 |
+
loss_fct = CrossEntropyLoss()
|
1283 |
+
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1284 |
+
elif self.config.problem_type == "multi_label_classification":
|
1285 |
+
loss_fct = BCEWithLogitsLoss()
|
1286 |
+
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
|
1287 |
+
|
1288 |
+
if not return_dict:
|
1289 |
+
output = (logits,) + transformer_outputs[1:]
|
1290 |
+
return ((loss,) + output) if loss is not None else output
|
1291 |
+
|
1292 |
+
return SequenceClassifierOutputWithPast(
|
1293 |
+
loss=loss,
|
1294 |
+
logits=logits,
|
1295 |
+
past_key_values=transformer_outputs.past_key_values,
|
1296 |
+
hidden_states=transformer_outputs.hidden_states,
|
1297 |
+
attentions=transformer_outputs.attentions,
|
1298 |
+
)
|
text_encoder/pytorch_model-00001-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6a6388dae55b598efe76c704e7f017bd84e6f6213466b7686a8f8326f78ab05
|
3 |
+
size 1827781090
|
text_encoder/pytorch_model-00002-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f96bef324acb5c3fe06b7a80f84272fe064d0327cbf14eddfae7af0d665a6ac
|
3 |
+
size 1968299480
|
text_encoder/pytorch_model-00003-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2400101255213250d9df716f778b7d2325f2fa4a8acaedee788338fceee5b27e
|
3 |
+
size 1927415036
|
text_encoder/pytorch_model-00004-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:472567c1b0e448a19171fbb5b3dab5670426d0a5dfdfd2c3a87a60bb1f96037d
|
3 |
+
size 1815225998
|
text_encoder/pytorch_model-00005-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef2aea78fa386168958e5ba42ecf09cbb567ed3e77ce2be990d556b84081e2b9
|
3 |
+
size 1968299544
|
text_encoder/pytorch_model-00006-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:35191adf21a1ab632c2b175fcbb6c27601150026cb1ed5d602938d825954526f
|
3 |
+
size 1927415036
|
text_encoder/pytorch_model-00007-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7cdaa9b8ed183284905c49d19bf42360037fdf2f95acb3093039d3c3a459261
|
3 |
+
size 1052808542
|
text_encoder/pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 12487168064
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00007.bin",
|
7 |
+
"transformer.encoder.final_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
8 |
+
"transformer.encoder.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
9 |
+
"transformer.encoder.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
10 |
+
"transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
11 |
+
"transformer.encoder.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
12 |
+
"transformer.encoder.layers.0.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
13 |
+
"transformer.encoder.layers.0.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
14 |
+
"transformer.encoder.layers.0.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
15 |
+
"transformer.encoder.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
16 |
+
"transformer.encoder.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
17 |
+
"transformer.encoder.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
18 |
+
"transformer.encoder.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
19 |
+
"transformer.encoder.layers.1.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
20 |
+
"transformer.encoder.layers.1.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
21 |
+
"transformer.encoder.layers.1.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
22 |
+
"transformer.encoder.layers.10.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
23 |
+
"transformer.encoder.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
24 |
+
"transformer.encoder.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
25 |
+
"transformer.encoder.layers.10.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
26 |
+
"transformer.encoder.layers.10.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
27 |
+
"transformer.encoder.layers.10.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
28 |
+
"transformer.encoder.layers.10.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
29 |
+
"transformer.encoder.layers.11.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
30 |
+
"transformer.encoder.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
31 |
+
"transformer.encoder.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
32 |
+
"transformer.encoder.layers.11.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
33 |
+
"transformer.encoder.layers.11.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
34 |
+
"transformer.encoder.layers.11.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
35 |
+
"transformer.encoder.layers.11.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
36 |
+
"transformer.encoder.layers.12.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
37 |
+
"transformer.encoder.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
38 |
+
"transformer.encoder.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
39 |
+
"transformer.encoder.layers.12.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
40 |
+
"transformer.encoder.layers.12.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
41 |
+
"transformer.encoder.layers.12.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
42 |
+
"transformer.encoder.layers.12.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
43 |
+
"transformer.encoder.layers.13.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
44 |
+
"transformer.encoder.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
45 |
+
"transformer.encoder.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
46 |
+
"transformer.encoder.layers.13.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
47 |
+
"transformer.encoder.layers.13.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
48 |
+
"transformer.encoder.layers.13.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
49 |
+
"transformer.encoder.layers.13.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
50 |
+
"transformer.encoder.layers.14.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
51 |
+
"transformer.encoder.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
52 |
+
"transformer.encoder.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
53 |
+
"transformer.encoder.layers.14.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
54 |
+
"transformer.encoder.layers.14.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
55 |
+
"transformer.encoder.layers.14.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
56 |
+
"transformer.encoder.layers.14.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
57 |
+
"transformer.encoder.layers.15.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
58 |
+
"transformer.encoder.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
59 |
+
"transformer.encoder.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
60 |
+
"transformer.encoder.layers.15.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
61 |
+
"transformer.encoder.layers.15.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
62 |
+
"transformer.encoder.layers.15.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
63 |
+
"transformer.encoder.layers.15.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
64 |
+
"transformer.encoder.layers.16.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
65 |
+
"transformer.encoder.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00004-of-00007.bin",
|
66 |
+
"transformer.encoder.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00004-of-00007.bin",
|
67 |
+
"transformer.encoder.layers.16.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
68 |
+
"transformer.encoder.layers.16.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
69 |
+
"transformer.encoder.layers.16.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
70 |
+
"transformer.encoder.layers.16.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
71 |
+
"transformer.encoder.layers.17.input_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
72 |
+
"transformer.encoder.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
73 |
+
"transformer.encoder.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
74 |
+
"transformer.encoder.layers.17.post_attention_layernorm.weight": "pytorch_model-00004-of-00007.bin",
|
75 |
+
"transformer.encoder.layers.17.self_attention.dense.weight": "pytorch_model-00004-of-00007.bin",
|
76 |
+
"transformer.encoder.layers.17.self_attention.query_key_value.bias": "pytorch_model-00004-of-00007.bin",
|
77 |
+
"transformer.encoder.layers.17.self_attention.query_key_value.weight": "pytorch_model-00004-of-00007.bin",
|
78 |
+
"transformer.encoder.layers.18.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
79 |
+
"transformer.encoder.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
80 |
+
"transformer.encoder.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
81 |
+
"transformer.encoder.layers.18.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
82 |
+
"transformer.encoder.layers.18.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
83 |
+
"transformer.encoder.layers.18.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
84 |
+
"transformer.encoder.layers.18.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
85 |
+
"transformer.encoder.layers.19.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
86 |
+
"transformer.encoder.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
87 |
+
"transformer.encoder.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
88 |
+
"transformer.encoder.layers.19.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
89 |
+
"transformer.encoder.layers.19.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
90 |
+
"transformer.encoder.layers.19.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
91 |
+
"transformer.encoder.layers.19.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
92 |
+
"transformer.encoder.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
93 |
+
"transformer.encoder.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00007.bin",
|
94 |
+
"transformer.encoder.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00007.bin",
|
95 |
+
"transformer.encoder.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
96 |
+
"transformer.encoder.layers.2.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
97 |
+
"transformer.encoder.layers.2.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
98 |
+
"transformer.encoder.layers.2.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
99 |
+
"transformer.encoder.layers.20.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
100 |
+
"transformer.encoder.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
101 |
+
"transformer.encoder.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
102 |
+
"transformer.encoder.layers.20.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
103 |
+
"transformer.encoder.layers.20.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
104 |
+
"transformer.encoder.layers.20.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
105 |
+
"transformer.encoder.layers.20.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
106 |
+
"transformer.encoder.layers.21.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
107 |
+
"transformer.encoder.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00005-of-00007.bin",
|
108 |
+
"transformer.encoder.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00005-of-00007.bin",
|
109 |
+
"transformer.encoder.layers.21.post_attention_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
110 |
+
"transformer.encoder.layers.21.self_attention.dense.weight": "pytorch_model-00005-of-00007.bin",
|
111 |
+
"transformer.encoder.layers.21.self_attention.query_key_value.bias": "pytorch_model-00005-of-00007.bin",
|
112 |
+
"transformer.encoder.layers.21.self_attention.query_key_value.weight": "pytorch_model-00005-of-00007.bin",
|
113 |
+
"transformer.encoder.layers.22.input_layernorm.weight": "pytorch_model-00005-of-00007.bin",
|
114 |
+
"transformer.encoder.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
115 |
+
"transformer.encoder.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
116 |
+
"transformer.encoder.layers.22.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
117 |
+
"transformer.encoder.layers.22.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
118 |
+
"transformer.encoder.layers.22.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
119 |
+
"transformer.encoder.layers.22.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
120 |
+
"transformer.encoder.layers.23.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
121 |
+
"transformer.encoder.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
122 |
+
"transformer.encoder.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
123 |
+
"transformer.encoder.layers.23.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
124 |
+
"transformer.encoder.layers.23.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
125 |
+
"transformer.encoder.layers.23.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
126 |
+
"transformer.encoder.layers.23.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
127 |
+
"transformer.encoder.layers.24.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
128 |
+
"transformer.encoder.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
129 |
+
"transformer.encoder.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
130 |
+
"transformer.encoder.layers.24.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
131 |
+
"transformer.encoder.layers.24.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
132 |
+
"transformer.encoder.layers.24.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
133 |
+
"transformer.encoder.layers.24.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
134 |
+
"transformer.encoder.layers.25.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
135 |
+
"transformer.encoder.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00006-of-00007.bin",
|
136 |
+
"transformer.encoder.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
137 |
+
"transformer.encoder.layers.25.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
138 |
+
"transformer.encoder.layers.25.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
139 |
+
"transformer.encoder.layers.25.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
140 |
+
"transformer.encoder.layers.25.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
141 |
+
"transformer.encoder.layers.26.input_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
142 |
+
"transformer.encoder.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00007.bin",
|
143 |
+
"transformer.encoder.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00006-of-00007.bin",
|
144 |
+
"transformer.encoder.layers.26.post_attention_layernorm.weight": "pytorch_model-00006-of-00007.bin",
|
145 |
+
"transformer.encoder.layers.26.self_attention.dense.weight": "pytorch_model-00006-of-00007.bin",
|
146 |
+
"transformer.encoder.layers.26.self_attention.query_key_value.bias": "pytorch_model-00006-of-00007.bin",
|
147 |
+
"transformer.encoder.layers.26.self_attention.query_key_value.weight": "pytorch_model-00006-of-00007.bin",
|
148 |
+
"transformer.encoder.layers.27.input_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
149 |
+
"transformer.encoder.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00007-of-00007.bin",
|
150 |
+
"transformer.encoder.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00007-of-00007.bin",
|
151 |
+
"transformer.encoder.layers.27.post_attention_layernorm.weight": "pytorch_model-00007-of-00007.bin",
|
152 |
+
"transformer.encoder.layers.27.self_attention.dense.weight": "pytorch_model-00007-of-00007.bin",
|
153 |
+
"transformer.encoder.layers.27.self_attention.query_key_value.bias": "pytorch_model-00007-of-00007.bin",
|
154 |
+
"transformer.encoder.layers.27.self_attention.query_key_value.weight": "pytorch_model-00007-of-00007.bin",
|
155 |
+
"transformer.encoder.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
156 |
+
"transformer.encoder.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
157 |
+
"transformer.encoder.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
158 |
+
"transformer.encoder.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00007.bin",
|
159 |
+
"transformer.encoder.layers.3.self_attention.dense.weight": "pytorch_model-00001-of-00007.bin",
|
160 |
+
"transformer.encoder.layers.3.self_attention.query_key_value.bias": "pytorch_model-00001-of-00007.bin",
|
161 |
+
"transformer.encoder.layers.3.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
|
162 |
+
"transformer.encoder.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
163 |
+
"transformer.encoder.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
164 |
+
"transformer.encoder.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
165 |
+
"transformer.encoder.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
166 |
+
"transformer.encoder.layers.4.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
167 |
+
"transformer.encoder.layers.4.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
168 |
+
"transformer.encoder.layers.4.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
169 |
+
"transformer.encoder.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
170 |
+
"transformer.encoder.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
171 |
+
"transformer.encoder.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
172 |
+
"transformer.encoder.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
173 |
+
"transformer.encoder.layers.5.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
174 |
+
"transformer.encoder.layers.5.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
175 |
+
"transformer.encoder.layers.5.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
176 |
+
"transformer.encoder.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
177 |
+
"transformer.encoder.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
178 |
+
"transformer.encoder.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
179 |
+
"transformer.encoder.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
180 |
+
"transformer.encoder.layers.6.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
181 |
+
"transformer.encoder.layers.6.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
182 |
+
"transformer.encoder.layers.6.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
183 |
+
"transformer.encoder.layers.7.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
184 |
+
"transformer.encoder.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00007.bin",
|
185 |
+
"transformer.encoder.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00007.bin",
|
186 |
+
"transformer.encoder.layers.7.post_attention_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
187 |
+
"transformer.encoder.layers.7.self_attention.dense.weight": "pytorch_model-00002-of-00007.bin",
|
188 |
+
"transformer.encoder.layers.7.self_attention.query_key_value.bias": "pytorch_model-00002-of-00007.bin",
|
189 |
+
"transformer.encoder.layers.7.self_attention.query_key_value.weight": "pytorch_model-00002-of-00007.bin",
|
190 |
+
"transformer.encoder.layers.8.input_layernorm.weight": "pytorch_model-00002-of-00007.bin",
|
191 |
+
"transformer.encoder.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
192 |
+
"transformer.encoder.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
193 |
+
"transformer.encoder.layers.8.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
194 |
+
"transformer.encoder.layers.8.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
195 |
+
"transformer.encoder.layers.8.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
196 |
+
"transformer.encoder.layers.8.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
197 |
+
"transformer.encoder.layers.9.input_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
198 |
+
"transformer.encoder.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00003-of-00007.bin",
|
199 |
+
"transformer.encoder.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00007.bin",
|
200 |
+
"transformer.encoder.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00007.bin",
|
201 |
+
"transformer.encoder.layers.9.self_attention.dense.weight": "pytorch_model-00003-of-00007.bin",
|
202 |
+
"transformer.encoder.layers.9.self_attention.query_key_value.bias": "pytorch_model-00003-of-00007.bin",
|
203 |
+
"transformer.encoder.layers.9.self_attention.query_key_value.weight": "pytorch_model-00003-of-00007.bin",
|
204 |
+
"transformer.output_layer.weight": "pytorch_model-00007-of-00007.bin",
|
205 |
+
"transformer.rotary_pos_emb.inv_freq": "pytorch_model-00001-of-00007.bin"
|
206 |
+
}
|
207 |
+
}
|
text_encoder/quantization.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear
|
2 |
+
from torch.nn.parameter import Parameter
|
3 |
+
|
4 |
+
import bz2
|
5 |
+
import torch
|
6 |
+
import base64
|
7 |
+
import ctypes
|
8 |
+
from transformers.utils import logging
|
9 |
+
|
10 |
+
from typing import List
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
logger = logging.get_logger(__name__)
|
14 |
+
|
15 |
+
try:
|
16 |
+
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
|
17 |
+
|
18 |
+
class Kernel:
|
19 |
+
def __init__(self, code: bytes, function_names: List[str]):
|
20 |
+
self.code = code
|
21 |
+
self._function_names = function_names
|
22 |
+
self._cmodule = LazyKernelCModule(self.code)
|
23 |
+
|
24 |
+
for name in self._function_names:
|
25 |
+
setattr(self, name, KernelFunction(self._cmodule, name))
|
26 |
+
|
27 |
+
quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
|
28 |
+
|
29 |
+
kernels = Kernel(
|
30 |
+
bz2.decompress(base64.b64decode(quantization_code)),
|
31 |
+
[
|
32 |
+
"int4WeightCompression",
|
33 |
+
"int4WeightExtractionFloat",
|
34 |
+
"int4WeightExtractionHalf",
|
35 |
+
"int8WeightExtractionFloat",
|
36 |
+
"int8WeightExtractionHalf",
|
37 |
+
],
|
38 |
+
)
|
39 |
+
except Exception as exception:
|
40 |
+
kernels = None
|
41 |
+
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
42 |
+
|
43 |
+
|
44 |
+
class W8A16Linear(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
47 |
+
ctx.inp_shape = inp.size()
|
48 |
+
ctx.weight_bit_width = weight_bit_width
|
49 |
+
out_features = quant_w.size(0)
|
50 |
+
inp = inp.contiguous().view(-1, inp.size(-1))
|
51 |
+
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
|
52 |
+
ctx.weight_shape = weight.size()
|
53 |
+
output = inp.mm(weight.t())
|
54 |
+
ctx.save_for_backward(inp, quant_w, scale_w)
|
55 |
+
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def backward(ctx, grad_output: torch.Tensor):
|
59 |
+
inp, quant_w, scale_w = ctx.saved_tensors
|
60 |
+
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
|
61 |
+
grad_output = grad_output.contiguous().view(-1, weight.size(0))
|
62 |
+
grad_input = grad_output.mm(weight)
|
63 |
+
grad_weight = grad_output.t().mm(inp)
|
64 |
+
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
|
65 |
+
|
66 |
+
|
67 |
+
def compress_int4_weight(weight: torch.Tensor): # (n, m)
|
68 |
+
with torch.cuda.device(weight.device):
|
69 |
+
n, m = weight.size(0), weight.size(1)
|
70 |
+
assert m % 2 == 0
|
71 |
+
m = m // 2
|
72 |
+
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
|
73 |
+
stream = torch.cuda.current_stream()
|
74 |
+
|
75 |
+
gridDim = (n, 1, 1)
|
76 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
77 |
+
|
78 |
+
kernels.int4WeightCompression(
|
79 |
+
gridDim,
|
80 |
+
blockDim,
|
81 |
+
0,
|
82 |
+
stream,
|
83 |
+
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
|
84 |
+
)
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
|
89 |
+
assert scale_list.dtype in [torch.half, torch.bfloat16]
|
90 |
+
assert weight.dtype in [torch.int8]
|
91 |
+
if source_bit_width == 8:
|
92 |
+
return weight.to(scale_list.dtype) * scale_list[:, None]
|
93 |
+
elif source_bit_width == 4:
|
94 |
+
func = (
|
95 |
+
kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
assert False, "Unsupported bit-width"
|
99 |
+
|
100 |
+
with torch.cuda.device(weight.device):
|
101 |
+
n, m = weight.size(0), weight.size(1)
|
102 |
+
out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
|
103 |
+
stream = torch.cuda.current_stream()
|
104 |
+
|
105 |
+
gridDim = (n, 1, 1)
|
106 |
+
blockDim = (min(round_up(m, 32), 1024), 1, 1)
|
107 |
+
|
108 |
+
func(
|
109 |
+
gridDim,
|
110 |
+
blockDim,
|
111 |
+
0,
|
112 |
+
stream,
|
113 |
+
[
|
114 |
+
ctypes.c_void_p(weight.data_ptr()),
|
115 |
+
ctypes.c_void_p(scale_list.data_ptr()),
|
116 |
+
ctypes.c_void_p(out.data_ptr()),
|
117 |
+
ctypes.c_int32(n),
|
118 |
+
ctypes.c_int32(m),
|
119 |
+
],
|
120 |
+
)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class QuantizedLinear(torch.nn.Module):
|
125 |
+
def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
|
126 |
+
**kwargs):
|
127 |
+
super().__init__()
|
128 |
+
self.weight_bit_width = weight_bit_width
|
129 |
+
|
130 |
+
shape = weight.shape
|
131 |
+
|
132 |
+
if weight is None or empty_init:
|
133 |
+
self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
|
134 |
+
self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
|
135 |
+
else:
|
136 |
+
self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
|
137 |
+
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
|
138 |
+
if weight_bit_width == 4:
|
139 |
+
self.weight = compress_int4_weight(self.weight)
|
140 |
+
|
141 |
+
self.weight = Parameter(self.weight.to(device), requires_grad=False)
|
142 |
+
self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
|
143 |
+
self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
|
144 |
+
|
145 |
+
def forward(self, input):
|
146 |
+
output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
|
147 |
+
if self.bias is not None:
|
148 |
+
output = output + self.bias
|
149 |
+
return output
|
150 |
+
|
151 |
+
|
152 |
+
def quantize(model, weight_bit_width, empty_init=False, device=None):
|
153 |
+
"""Replace fp16 linear with quantized linear"""
|
154 |
+
for layer in model.layers:
|
155 |
+
layer.self_attention.query_key_value = QuantizedLinear(
|
156 |
+
weight_bit_width=weight_bit_width,
|
157 |
+
weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),
|
158 |
+
bias=layer.self_attention.query_key_value.bias,
|
159 |
+
dtype=layer.self_attention.query_key_value.weight.dtype,
|
160 |
+
device=layer.self_attention.query_key_value.weight.device if device is None else device,
|
161 |
+
empty_init=empty_init
|
162 |
+
)
|
163 |
+
layer.self_attention.dense = QuantizedLinear(
|
164 |
+
weight_bit_width=weight_bit_width,
|
165 |
+
weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),
|
166 |
+
bias=layer.self_attention.dense.bias,
|
167 |
+
dtype=layer.self_attention.dense.weight.dtype,
|
168 |
+
device=layer.self_attention.dense.weight.device if device is None else device,
|
169 |
+
empty_init=empty_init
|
170 |
+
)
|
171 |
+
layer.mlp.dense_h_to_4h = QuantizedLinear(
|
172 |
+
weight_bit_width=weight_bit_width,
|
173 |
+
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
|
174 |
+
bias=layer.mlp.dense_h_to_4h.bias,
|
175 |
+
dtype=layer.mlp.dense_h_to_4h.weight.dtype,
|
176 |
+
device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
|
177 |
+
empty_init=empty_init
|
178 |
+
)
|
179 |
+
layer.mlp.dense_4h_to_h = QuantizedLinear(
|
180 |
+
weight_bit_width=weight_bit_width,
|
181 |
+
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
|
182 |
+
bias=layer.mlp.dense_4h_to_h.bias,
|
183 |
+
dtype=layer.mlp.dense_4h_to_h.weight.dtype,
|
184 |
+
device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
|
185 |
+
empty_init=empty_init
|
186 |
+
)
|
187 |
+
|
188 |
+
return model
|
text_encoder/tokenization_chatglm.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Union, Dict
|
5 |
+
from sentencepiece import SentencePieceProcessor
|
6 |
+
from transformers import PreTrainedTokenizer
|
7 |
+
from transformers.utils import logging, PaddingStrategy
|
8 |
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
9 |
+
|
10 |
+
|
11 |
+
class SPTokenizer:
|
12 |
+
def __init__(self, model_path: str):
|
13 |
+
# reload tokenizer
|
14 |
+
assert os.path.isfile(model_path), model_path
|
15 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
16 |
+
|
17 |
+
# BOS / EOS token IDs
|
18 |
+
self.n_words: int = self.sp_model.vocab_size()
|
19 |
+
self.bos_id: int = self.sp_model.bos_id()
|
20 |
+
self.eos_id: int = self.sp_model.eos_id()
|
21 |
+
self.pad_id: int = self.sp_model.unk_id()
|
22 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
23 |
+
|
24 |
+
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
25 |
+
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
26 |
+
self.special_tokens = {}
|
27 |
+
self.index_special_tokens = {}
|
28 |
+
for token in special_tokens:
|
29 |
+
self.special_tokens[token] = self.n_words
|
30 |
+
self.index_special_tokens[self.n_words] = token
|
31 |
+
self.n_words += 1
|
32 |
+
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
33 |
+
|
34 |
+
def tokenize(self, s: str, encode_special_tokens=False):
|
35 |
+
if encode_special_tokens:
|
36 |
+
last_index = 0
|
37 |
+
t = []
|
38 |
+
for match in re.finditer(self.role_special_token_expression, s):
|
39 |
+
if last_index < match.start():
|
40 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
41 |
+
t.append(s[match.start():match.end()])
|
42 |
+
last_index = match.end()
|
43 |
+
if last_index < len(s):
|
44 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
45 |
+
return t
|
46 |
+
else:
|
47 |
+
return self.sp_model.EncodeAsPieces(s)
|
48 |
+
|
49 |
+
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
50 |
+
assert type(s) is str
|
51 |
+
t = self.sp_model.encode(s)
|
52 |
+
if bos:
|
53 |
+
t = [self.bos_id] + t
|
54 |
+
if eos:
|
55 |
+
t = t + [self.eos_id]
|
56 |
+
return t
|
57 |
+
|
58 |
+
def decode(self, t: List[int]) -> str:
|
59 |
+
text, buffer = "", []
|
60 |
+
for token in t:
|
61 |
+
if token in self.index_special_tokens:
|
62 |
+
if buffer:
|
63 |
+
text += self.sp_model.decode(buffer)
|
64 |
+
buffer = []
|
65 |
+
text += self.index_special_tokens[token]
|
66 |
+
else:
|
67 |
+
buffer.append(token)
|
68 |
+
if buffer:
|
69 |
+
text += self.sp_model.decode(buffer)
|
70 |
+
return text
|
71 |
+
|
72 |
+
def decode_tokens(self, tokens: List[str]) -> str:
|
73 |
+
text = self.sp_model.DecodePieces(tokens)
|
74 |
+
return text
|
75 |
+
|
76 |
+
def convert_token_to_id(self, token):
|
77 |
+
""" Converts a token (str) in an id using the vocab. """
|
78 |
+
if token in self.special_tokens:
|
79 |
+
return self.special_tokens[token]
|
80 |
+
return self.sp_model.PieceToId(token)
|
81 |
+
|
82 |
+
def convert_id_to_token(self, index):
|
83 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
84 |
+
if index in self.index_special_tokens:
|
85 |
+
return self.index_special_tokens[index]
|
86 |
+
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
87 |
+
return ""
|
88 |
+
return self.sp_model.IdToPiece(index)
|
89 |
+
|
90 |
+
|
91 |
+
class ChatGLMTokenizer(PreTrainedTokenizer):
|
92 |
+
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
93 |
+
|
94 |
+
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
95 |
+
|
96 |
+
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
97 |
+
**kwargs):
|
98 |
+
self.name = "GLMTokenizer"
|
99 |
+
|
100 |
+
self.vocab_file = vocab_file
|
101 |
+
self.tokenizer = SPTokenizer(vocab_file)
|
102 |
+
self.special_tokens = {
|
103 |
+
"<bos>": self.tokenizer.bos_id,
|
104 |
+
"<eos>": self.tokenizer.eos_id,
|
105 |
+
"<pad>": self.tokenizer.pad_id
|
106 |
+
}
|
107 |
+
self.encode_special_tokens = encode_special_tokens
|
108 |
+
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
109 |
+
encode_special_tokens=encode_special_tokens,
|
110 |
+
**kwargs)
|
111 |
+
|
112 |
+
def get_command(self, token):
|
113 |
+
if token in self.special_tokens:
|
114 |
+
return self.special_tokens[token]
|
115 |
+
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
116 |
+
return self.tokenizer.special_tokens[token]
|
117 |
+
|
118 |
+
@property
|
119 |
+
def unk_token(self) -> str:
|
120 |
+
return "<unk>"
|
121 |
+
|
122 |
+
@property
|
123 |
+
def pad_token(self) -> str:
|
124 |
+
return "<unk>"
|
125 |
+
|
126 |
+
@property
|
127 |
+
def pad_token_id(self):
|
128 |
+
return self.get_command("<pad>")
|
129 |
+
|
130 |
+
@property
|
131 |
+
def eos_token(self) -> str:
|
132 |
+
return "</s>"
|
133 |
+
|
134 |
+
@property
|
135 |
+
def eos_token_id(self):
|
136 |
+
return self.get_command("<eos>")
|
137 |
+
|
138 |
+
@property
|
139 |
+
def vocab_size(self):
|
140 |
+
return self.tokenizer.n_words
|
141 |
+
|
142 |
+
def get_vocab(self):
|
143 |
+
""" Returns vocab as a dict """
|
144 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
145 |
+
vocab.update(self.added_tokens_encoder)
|
146 |
+
return vocab
|
147 |
+
|
148 |
+
def _tokenize(self, text, **kwargs):
|
149 |
+
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
150 |
+
|
151 |
+
def _convert_token_to_id(self, token):
|
152 |
+
""" Converts a token (str) in an id using the vocab. """
|
153 |
+
return self.tokenizer.convert_token_to_id(token)
|
154 |
+
|
155 |
+
def _convert_id_to_token(self, index):
|
156 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
157 |
+
return self.tokenizer.convert_id_to_token(index)
|
158 |
+
|
159 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
160 |
+
return self.tokenizer.decode_tokens(tokens)
|
161 |
+
|
162 |
+
def save_vocabulary(self, save_directory, filename_prefix=None):
|
163 |
+
"""
|
164 |
+
Save the vocabulary and special tokens file to a directory.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
save_directory (`str`):
|
168 |
+
The directory in which to save the vocabulary.
|
169 |
+
filename_prefix (`str`, *optional*):
|
170 |
+
An optional prefix to add to the named of the saved files.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
`Tuple(str)`: Paths to the files saved.
|
174 |
+
"""
|
175 |
+
if os.path.isdir(save_directory):
|
176 |
+
vocab_file = os.path.join(
|
177 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
vocab_file = save_directory
|
181 |
+
|
182 |
+
with open(self.vocab_file, 'rb') as fin:
|
183 |
+
proto_str = fin.read()
|
184 |
+
|
185 |
+
with open(vocab_file, "wb") as writer:
|
186 |
+
writer.write(proto_str)
|
187 |
+
|
188 |
+
return (vocab_file,)
|
189 |
+
|
190 |
+
def get_prefix_tokens(self):
|
191 |
+
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
192 |
+
return prefix_tokens
|
193 |
+
|
194 |
+
def build_single_message(self, role, metadata, message):
|
195 |
+
assert role in ["system", "user", "assistant", "observation"], role
|
196 |
+
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
197 |
+
message_tokens = self.tokenizer.encode(message)
|
198 |
+
tokens = role_tokens + message_tokens
|
199 |
+
return tokens
|
200 |
+
|
201 |
+
def build_chat_input(self, query, history=None, role="user"):
|
202 |
+
if history is None:
|
203 |
+
history = []
|
204 |
+
input_ids = []
|
205 |
+
for item in history:
|
206 |
+
content = item["content"]
|
207 |
+
if item["role"] == "system" and "tools" in item:
|
208 |
+
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
209 |
+
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
210 |
+
input_ids.extend(self.build_single_message(role, "", query))
|
211 |
+
input_ids.extend([self.get_command("<|assistant|>")])
|
212 |
+
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
213 |
+
|
214 |
+
def build_inputs_with_special_tokens(
|
215 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
216 |
+
) -> List[int]:
|
217 |
+
"""
|
218 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
219 |
+
adding special tokens. A BERT sequence has the following format:
|
220 |
+
|
221 |
+
- single sequence: `[CLS] X [SEP]`
|
222 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
223 |
+
|
224 |
+
Args:
|
225 |
+
token_ids_0 (`List[int]`):
|
226 |
+
List of IDs to which the special tokens will be added.
|
227 |
+
token_ids_1 (`List[int]`, *optional*):
|
228 |
+
Optional second list of IDs for sequence pairs.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
232 |
+
"""
|
233 |
+
prefix_tokens = self.get_prefix_tokens()
|
234 |
+
token_ids_0 = prefix_tokens + token_ids_0
|
235 |
+
if token_ids_1 is not None:
|
236 |
+
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
237 |
+
return token_ids_0
|
238 |
+
|
239 |
+
def _pad(
|
240 |
+
self,
|
241 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
242 |
+
max_length: Optional[int] = None,
|
243 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
244 |
+
pad_to_multiple_of: Optional[int] = None,
|
245 |
+
return_attention_mask: Optional[bool] = None,
|
246 |
+
) -> dict:
|
247 |
+
"""
|
248 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
249 |
+
|
250 |
+
Args:
|
251 |
+
encoded_inputs:
|
252 |
+
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
253 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
254 |
+
Will truncate by taking into account the special tokens.
|
255 |
+
padding_strategy: PaddingStrategy to use for padding.
|
256 |
+
|
257 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
258 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
259 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
260 |
+
The tokenizer padding sides are defined in self.padding_side:
|
261 |
+
|
262 |
+
- 'left': pads on the left of the sequences
|
263 |
+
- 'right': pads on the right of the sequences
|
264 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
265 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
266 |
+
`>= 7.5` (Volta).
|
267 |
+
return_attention_mask:
|
268 |
+
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
269 |
+
"""
|
270 |
+
# Load from model defaults
|
271 |
+
assert self.padding_side == "left"
|
272 |
+
|
273 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
274 |
+
seq_length = len(required_input)
|
275 |
+
|
276 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
277 |
+
max_length = len(required_input)
|
278 |
+
|
279 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
280 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
281 |
+
|
282 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
283 |
+
|
284 |
+
# Initialize attention mask if not present.
|
285 |
+
if "attention_mask" not in encoded_inputs:
|
286 |
+
encoded_inputs["attention_mask"] = [1] * seq_length
|
287 |
+
|
288 |
+
if "position_ids" not in encoded_inputs:
|
289 |
+
encoded_inputs["position_ids"] = list(range(seq_length))
|
290 |
+
|
291 |
+
if needs_to_be_padded:
|
292 |
+
difference = max_length - len(required_input)
|
293 |
+
|
294 |
+
if "attention_mask" in encoded_inputs:
|
295 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
296 |
+
if "position_ids" in encoded_inputs:
|
297 |
+
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
298 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
299 |
+
|
300 |
+
return encoded_inputs
|
text_encoder/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
|
3 |
+
size 1018370
|
text_encoder/tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name_or_path": "THUDM/chatglm3-6b-base",
|
3 |
+
"remove_space": false,
|
4 |
+
"do_lower_case": false,
|
5 |
+
"tokenizer_class": "ChatGLMTokenizer",
|
6 |
+
"auto_map": {
|
7 |
+
"AutoTokenizer": [
|
8 |
+
"tokenization_chatglm.ChatGLMTokenizer",
|
9 |
+
null
|
10 |
+
]
|
11 |
+
}
|
12 |
+
}
|
text_encoder/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
|
3 |
+
size 1018370
|
tokenizer/tokenization_chatglm.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Union, Dict
|
5 |
+
from sentencepiece import SentencePieceProcessor
|
6 |
+
from transformers import PreTrainedTokenizer
|
7 |
+
from transformers.utils import logging, PaddingStrategy
|
8 |
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
9 |
+
|
10 |
+
|
11 |
+
class SPTokenizer:
|
12 |
+
def __init__(self, model_path: str):
|
13 |
+
# reload tokenizer
|
14 |
+
assert os.path.isfile(model_path), model_path
|
15 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
16 |
+
|
17 |
+
# BOS / EOS token IDs
|
18 |
+
self.n_words: int = self.sp_model.vocab_size()
|
19 |
+
self.bos_id: int = self.sp_model.bos_id()
|
20 |
+
self.eos_id: int = self.sp_model.eos_id()
|
21 |
+
self.pad_id: int = self.sp_model.unk_id()
|
22 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
23 |
+
|
24 |
+
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
25 |
+
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
26 |
+
self.special_tokens = {}
|
27 |
+
self.index_special_tokens = {}
|
28 |
+
for token in special_tokens:
|
29 |
+
self.special_tokens[token] = self.n_words
|
30 |
+
self.index_special_tokens[self.n_words] = token
|
31 |
+
self.n_words += 1
|
32 |
+
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
33 |
+
|
34 |
+
def tokenize(self, s: str, encode_special_tokens=False):
|
35 |
+
if encode_special_tokens:
|
36 |
+
last_index = 0
|
37 |
+
t = []
|
38 |
+
for match in re.finditer(self.role_special_token_expression, s):
|
39 |
+
if last_index < match.start():
|
40 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
41 |
+
t.append(s[match.start():match.end()])
|
42 |
+
last_index = match.end()
|
43 |
+
if last_index < len(s):
|
44 |
+
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
45 |
+
return t
|
46 |
+
else:
|
47 |
+
return self.sp_model.EncodeAsPieces(s)
|
48 |
+
|
49 |
+
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
50 |
+
assert type(s) is str
|
51 |
+
t = self.sp_model.encode(s)
|
52 |
+
if bos:
|
53 |
+
t = [self.bos_id] + t
|
54 |
+
if eos:
|
55 |
+
t = t + [self.eos_id]
|
56 |
+
return t
|
57 |
+
|
58 |
+
def decode(self, t: List[int]) -> str:
|
59 |
+
text, buffer = "", []
|
60 |
+
for token in t:
|
61 |
+
if token in self.index_special_tokens:
|
62 |
+
if buffer:
|
63 |
+
text += self.sp_model.decode(buffer)
|
64 |
+
buffer = []
|
65 |
+
text += self.index_special_tokens[token]
|
66 |
+
else:
|
67 |
+
buffer.append(token)
|
68 |
+
if buffer:
|
69 |
+
text += self.sp_model.decode(buffer)
|
70 |
+
return text
|
71 |
+
|
72 |
+
def decode_tokens(self, tokens: List[str]) -> str:
|
73 |
+
text = self.sp_model.DecodePieces(tokens)
|
74 |
+
return text
|
75 |
+
|
76 |
+
def convert_token_to_id(self, token):
|
77 |
+
""" Converts a token (str) in an id using the vocab. """
|
78 |
+
if token in self.special_tokens:
|
79 |
+
return self.special_tokens[token]
|
80 |
+
return self.sp_model.PieceToId(token)
|
81 |
+
|
82 |
+
def convert_id_to_token(self, index):
|
83 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
84 |
+
if index in self.index_special_tokens:
|
85 |
+
return self.index_special_tokens[index]
|
86 |
+
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
87 |
+
return ""
|
88 |
+
return self.sp_model.IdToPiece(index)
|
89 |
+
|
90 |
+
|
91 |
+
class ChatGLMTokenizer(PreTrainedTokenizer):
|
92 |
+
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
93 |
+
|
94 |
+
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
95 |
+
|
96 |
+
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
97 |
+
**kwargs):
|
98 |
+
self.name = "GLMTokenizer"
|
99 |
+
|
100 |
+
self.vocab_file = vocab_file
|
101 |
+
self.tokenizer = SPTokenizer(vocab_file)
|
102 |
+
self.special_tokens = {
|
103 |
+
"<bos>": self.tokenizer.bos_id,
|
104 |
+
"<eos>": self.tokenizer.eos_id,
|
105 |
+
"<pad>": self.tokenizer.pad_id
|
106 |
+
}
|
107 |
+
self.encode_special_tokens = encode_special_tokens
|
108 |
+
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
109 |
+
encode_special_tokens=encode_special_tokens,
|
110 |
+
**kwargs)
|
111 |
+
|
112 |
+
def get_command(self, token):
|
113 |
+
if token in self.special_tokens:
|
114 |
+
return self.special_tokens[token]
|
115 |
+
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
116 |
+
return self.tokenizer.special_tokens[token]
|
117 |
+
|
118 |
+
@property
|
119 |
+
def unk_token(self) -> str:
|
120 |
+
return "<unk>"
|
121 |
+
|
122 |
+
@property
|
123 |
+
def pad_token(self) -> str:
|
124 |
+
return "<unk>"
|
125 |
+
|
126 |
+
@property
|
127 |
+
def pad_token_id(self):
|
128 |
+
return self.get_command("<pad>")
|
129 |
+
|
130 |
+
@property
|
131 |
+
def eos_token(self) -> str:
|
132 |
+
return "</s>"
|
133 |
+
|
134 |
+
@property
|
135 |
+
def eos_token_id(self):
|
136 |
+
return self.get_command("<eos>")
|
137 |
+
|
138 |
+
@property
|
139 |
+
def vocab_size(self):
|
140 |
+
return self.tokenizer.n_words
|
141 |
+
|
142 |
+
def get_vocab(self):
|
143 |
+
""" Returns vocab as a dict """
|
144 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
145 |
+
vocab.update(self.added_tokens_encoder)
|
146 |
+
return vocab
|
147 |
+
|
148 |
+
def _tokenize(self, text, **kwargs):
|
149 |
+
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
150 |
+
|
151 |
+
def _convert_token_to_id(self, token):
|
152 |
+
""" Converts a token (str) in an id using the vocab. """
|
153 |
+
return self.tokenizer.convert_token_to_id(token)
|
154 |
+
|
155 |
+
def _convert_id_to_token(self, index):
|
156 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
157 |
+
return self.tokenizer.convert_id_to_token(index)
|
158 |
+
|
159 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
160 |
+
return self.tokenizer.decode_tokens(tokens)
|
161 |
+
|
162 |
+
def save_vocabulary(self, save_directory, filename_prefix=None):
|
163 |
+
"""
|
164 |
+
Save the vocabulary and special tokens file to a directory.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
save_directory (`str`):
|
168 |
+
The directory in which to save the vocabulary.
|
169 |
+
filename_prefix (`str`, *optional*):
|
170 |
+
An optional prefix to add to the named of the saved files.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
`Tuple(str)`: Paths to the files saved.
|
174 |
+
"""
|
175 |
+
if os.path.isdir(save_directory):
|
176 |
+
vocab_file = os.path.join(
|
177 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
178 |
+
)
|
179 |
+
else:
|
180 |
+
vocab_file = save_directory
|
181 |
+
|
182 |
+
with open(self.vocab_file, 'rb') as fin:
|
183 |
+
proto_str = fin.read()
|
184 |
+
|
185 |
+
with open(vocab_file, "wb") as writer:
|
186 |
+
writer.write(proto_str)
|
187 |
+
|
188 |
+
return (vocab_file,)
|
189 |
+
|
190 |
+
def get_prefix_tokens(self):
|
191 |
+
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
192 |
+
return prefix_tokens
|
193 |
+
|
194 |
+
def build_single_message(self, role, metadata, message):
|
195 |
+
assert role in ["system", "user", "assistant", "observation"], role
|
196 |
+
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
197 |
+
message_tokens = self.tokenizer.encode(message)
|
198 |
+
tokens = role_tokens + message_tokens
|
199 |
+
return tokens
|
200 |
+
|
201 |
+
def build_chat_input(self, query, history=None, role="user"):
|
202 |
+
if history is None:
|
203 |
+
history = []
|
204 |
+
input_ids = []
|
205 |
+
for item in history:
|
206 |
+
content = item["content"]
|
207 |
+
if item["role"] == "system" and "tools" in item:
|
208 |
+
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
209 |
+
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
210 |
+
input_ids.extend(self.build_single_message(role, "", query))
|
211 |
+
input_ids.extend([self.get_command("<|assistant|>")])
|
212 |
+
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
213 |
+
|
214 |
+
def build_inputs_with_special_tokens(
|
215 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
216 |
+
) -> List[int]:
|
217 |
+
"""
|
218 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
219 |
+
adding special tokens. A BERT sequence has the following format:
|
220 |
+
|
221 |
+
- single sequence: `[CLS] X [SEP]`
|
222 |
+
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
223 |
+
|
224 |
+
Args:
|
225 |
+
token_ids_0 (`List[int]`):
|
226 |
+
List of IDs to which the special tokens will be added.
|
227 |
+
token_ids_1 (`List[int]`, *optional*):
|
228 |
+
Optional second list of IDs for sequence pairs.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
232 |
+
"""
|
233 |
+
prefix_tokens = self.get_prefix_tokens()
|
234 |
+
token_ids_0 = prefix_tokens + token_ids_0
|
235 |
+
if token_ids_1 is not None:
|
236 |
+
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
237 |
+
return token_ids_0
|
238 |
+
|
239 |
+
def _pad(
|
240 |
+
self,
|
241 |
+
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
242 |
+
max_length: Optional[int] = None,
|
243 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
244 |
+
pad_to_multiple_of: Optional[int] = None,
|
245 |
+
return_attention_mask: Optional[bool] = None,
|
246 |
+
) -> dict:
|
247 |
+
"""
|
248 |
+
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
249 |
+
|
250 |
+
Args:
|
251 |
+
encoded_inputs:
|
252 |
+
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
253 |
+
max_length: maximum length of the returned list and optionally padding length (see below).
|
254 |
+
Will truncate by taking into account the special tokens.
|
255 |
+
padding_strategy: PaddingStrategy to use for padding.
|
256 |
+
|
257 |
+
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
258 |
+
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
259 |
+
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
260 |
+
The tokenizer padding sides are defined in self.padding_side:
|
261 |
+
|
262 |
+
- 'left': pads on the left of the sequences
|
263 |
+
- 'right': pads on the right of the sequences
|
264 |
+
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
265 |
+
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
266 |
+
`>= 7.5` (Volta).
|
267 |
+
return_attention_mask:
|
268 |
+
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
269 |
+
"""
|
270 |
+
# Load from model defaults
|
271 |
+
assert self.padding_side == "left"
|
272 |
+
|
273 |
+
required_input = encoded_inputs[self.model_input_names[0]]
|
274 |
+
seq_length = len(required_input)
|
275 |
+
|
276 |
+
if padding_strategy == PaddingStrategy.LONGEST:
|
277 |
+
max_length = len(required_input)
|
278 |
+
|
279 |
+
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
280 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
281 |
+
|
282 |
+
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
283 |
+
|
284 |
+
# Initialize attention mask if not present.
|
285 |
+
if "attention_mask" not in encoded_inputs:
|
286 |
+
encoded_inputs["attention_mask"] = [1] * seq_length
|
287 |
+
|
288 |
+
if "position_ids" not in encoded_inputs:
|
289 |
+
encoded_inputs["position_ids"] = list(range(seq_length))
|
290 |
+
|
291 |
+
if needs_to_be_padded:
|
292 |
+
difference = max_length - len(required_input)
|
293 |
+
|
294 |
+
if "attention_mask" in encoded_inputs:
|
295 |
+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
296 |
+
if "position_ids" in encoded_inputs:
|
297 |
+
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
298 |
+
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
299 |
+
|
300 |
+
return encoded_inputs
|
tokenizer/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
|
3 |
+
size 1018370
|
tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name_or_path": "THUDM/chatglm3-6b-base",
|
3 |
+
"remove_space": false,
|
4 |
+
"do_lower_case": false,
|
5 |
+
"tokenizer_class": "ChatGLMTokenizer",
|
6 |
+
"auto_map": {
|
7 |
+
"AutoTokenizer": [
|
8 |
+
"tokenization_chatglm.ChatGLMTokenizer",
|
9 |
+
null
|
10 |
+
]
|
11 |
+
}
|
12 |
+
}
|
tokenizer/vocab.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
|
3 |
+
size 1018370
|
unet/config.json
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.21.0.dev0",
|
4 |
+
"_name_or_path": "/mmu-vcg/wujunqiang/models/2_inpaint_sdxl_v4_controlNet_v0_1kw_inout_c9_norm_align_1024/checkpoint-22000/unet",
|
5 |
+
"act_fn": "silu",
|
6 |
+
"addition_embed_type": "text_time",
|
7 |
+
"addition_embed_type_num_heads": 64,
|
8 |
+
"addition_time_embed_dim": 256,
|
9 |
+
"attention_head_dim": [
|
10 |
+
5,
|
11 |
+
10,
|
12 |
+
20
|
13 |
+
],
|
14 |
+
"attention_type": "default",
|
15 |
+
"block_out_channels": [
|
16 |
+
320,
|
17 |
+
640,
|
18 |
+
1280
|
19 |
+
],
|
20 |
+
"center_input_sample": false,
|
21 |
+
"class_embed_type": null,
|
22 |
+
"class_embeddings_concat": false,
|
23 |
+
"conv_in_kernel": 3,
|
24 |
+
"conv_out_kernel": 3,
|
25 |
+
"cross_attention_dim": 2048,
|
26 |
+
"cross_attention_norm": null,
|
27 |
+
"down_block_types": [
|
28 |
+
"DownBlock2D",
|
29 |
+
"CrossAttnDownBlock2D",
|
30 |
+
"CrossAttnDownBlock2D"
|
31 |
+
],
|
32 |
+
"downsample_padding": 1,
|
33 |
+
"dropout": 0.0,
|
34 |
+
"dual_cross_attention": false,
|
35 |
+
"encoder_hid_dim": 4096,
|
36 |
+
"encoder_hid_dim_type": "text_proj",
|
37 |
+
"flip_sin_to_cos": true,
|
38 |
+
"freq_shift": 0,
|
39 |
+
"in_channels": 9,
|
40 |
+
"layers_per_block": 2,
|
41 |
+
"mid_block_only_cross_attention": null,
|
42 |
+
"mid_block_scale_factor": 1,
|
43 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
44 |
+
"norm_eps": 1e-05,
|
45 |
+
"norm_num_groups": 32,
|
46 |
+
"num_attention_heads": null,
|
47 |
+
"num_class_embeds": null,
|
48 |
+
"only_cross_attention": false,
|
49 |
+
"out_channels": 4,
|
50 |
+
"projection_class_embeddings_input_dim": 5632,
|
51 |
+
"resnet_out_scale_factor": 1.0,
|
52 |
+
"resnet_skip_time_act": false,
|
53 |
+
"resnet_time_scale_shift": "default",
|
54 |
+
"reverse_transformer_layers_per_block": null,
|
55 |
+
"sample_size": 128,
|
56 |
+
"time_cond_proj_dim": null,
|
57 |
+
"time_embedding_act_fn": null,
|
58 |
+
"time_embedding_dim": null,
|
59 |
+
"time_embedding_type": "positional",
|
60 |
+
"timestep_post_act": null,
|
61 |
+
"transformer_layers_per_block": [
|
62 |
+
1,
|
63 |
+
2,
|
64 |
+
10
|
65 |
+
],
|
66 |
+
"up_block_types": [
|
67 |
+
"CrossAttnUpBlock2D",
|
68 |
+
"CrossAttnUpBlock2D",
|
69 |
+
"UpBlock2D"
|
70 |
+
],
|
71 |
+
"upcast_attention": false,
|
72 |
+
"use_linear_projection": true
|
73 |
+
}
|
unet/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1661388ebb2ed8d265d19725004f3b2c0e4ff384bf9125b9216028dce883a063
|
3 |
+
size 5159169040
|
vae/config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.18.0.dev0",
|
4 |
+
"_name_or_path": "./vae",
|
5 |
+
"act_fn": "silu",
|
6 |
+
"block_out_channels": [
|
7 |
+
128,
|
8 |
+
256,
|
9 |
+
512,
|
10 |
+
512
|
11 |
+
],
|
12 |
+
"down_block_types": [
|
13 |
+
"DownEncoderBlock2D",
|
14 |
+
"DownEncoderBlock2D",
|
15 |
+
"DownEncoderBlock2D",
|
16 |
+
"DownEncoderBlock2D"
|
17 |
+
],
|
18 |
+
"in_channels": 3,
|
19 |
+
"latent_channels": 4,
|
20 |
+
"layers_per_block": 2,
|
21 |
+
"norm_num_groups": 32,
|
22 |
+
"out_channels": 3,
|
23 |
+
"sample_size": 1024,
|
24 |
+
"scaling_factor": 0.13025,
|
25 |
+
"up_block_types": [
|
26 |
+
"UpDecoderBlock2D",
|
27 |
+
"UpDecoderBlock2D",
|
28 |
+
"UpDecoderBlock2D",
|
29 |
+
"UpDecoderBlock2D"
|
30 |
+
]
|
31 |
+
}
|
vae/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48b7606461b0646c5a578e686e2fddbb54ae789823ab82ca07f4b898261f6383
|
3 |
+
size 334712113
|
vae/diffusion_pytorch_model.fp16.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ce744db8ec41697eaecabe3508566aa76e53d71f79e595b0d0f56c9f07405ce
|
3 |
+
size 167405651
|
vae/diffusion_pytorch_model.fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bcb60880a46b63dea58e9bc591abe15f8350bde47b405f9c38f4be70c6161e68
|
3 |
+
size 167335342
|
vae/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1598f3d24932bcfe6634e8b618ea1e30ab1d57f5aad13a6d2de446d2199f2341
|
3 |
+
size 334643268
|