"
+labels: []
+body:
+ - type: checkboxes
+ attributes:
+ label: 是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this?
+ description: |
+ 请先搜索您遇到的错误是否在已有的issues或讨论中提到过。
+ Please search to see if an issue / discussion already exists for the bug you encountered.
+ [Issues](https://github.com/QwenLM/Qwen-7B/issues)
+ [Discussions](https://github.com/QwenLM/Qwen-7B/discussions)
+ options:
+ - label: 我已经搜索过已有的issues和讨论 | I have searched the existing issues / discussions
+ required: true
+ - type: checkboxes
+ attributes:
+ label: 该问题是否在FAQ中有解答? | Is there an existing answer for this in FAQ?
+ description: |
+ 请先搜索您遇到的错误是否已在FAQ中有相关解答。
+ Please search to see if an answer already exists in FAQ for the bug you encountered.
+ [FAQ-en](https://github.com/QwenLM/Qwen-7B/blob/main/FAQ.md)
+ [FAQ-zh](https://github.com/QwenLM/Qwen-7B/blob/main/FAQ_zh.md)
+ options:
+ - label: 我已经搜索过FAQ | I have searched FAQ
+ required: true
+ - type: textarea
+ attributes:
+ label: 当前行为 | Current Behavior
+ description: |
+ 准确描述遇到的行为。
+ A concise description of what you're experiencing.
+ validations:
+ required: false
+ - type: textarea
+ attributes:
+ label: 期望行为 | Expected Behavior
+ description: |
+ 准确描述预期的行为。
+ A concise description of what you expected to happen.
+ validations:
+ required: false
+ - type: textarea
+ attributes:
+ label: 复现方法 | Steps To Reproduce
+ description: |
+ 复现当前行为的详细步骤。
+ Steps to reproduce the behavior.
+ placeholder: |
+ 1. In this environment...
+ 2. With this config...
+ 3. Run '...'
+ 4. See error...
+ validations:
+ required: false
+ - type: textarea
+ attributes:
+ label: 运行环境 | Environment
+ description: |
+ examples:
+ - **OS**: Ubuntu 20.04
+ - **Python**: 3.8
+ - **Transformers**: 4.31.0
+ - **PyTorch**: 2.0.1
+ - **CUDA**: 11.4
+ value: |
+ - OS:
+ - Python:
+ - Transformers:
+ - PyTorch:
+ - CUDA (`python -c 'import torch; print(torch.version.cuda)'`):
+ render: Markdown
+ validations:
+ required: false
+ - type: textarea
+ attributes:
+ label: 备注 | Anything else?
+ description: |
+ 您可以在这里补充其他关于该问题背景信息的描述、链接或引用等。
+
+ 您可以通过点击高亮此区域然后拖动文件的方式上传图片或日志文件。
+
+ Links? References? Anything that will give us more context about the issue you are encountering!
+
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
+ validations:
+ required: false
diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yaml
@@ -0,0 +1 @@
+blank_issues_enabled: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e677af83ac00162afab9318e10e5fc1b9c229fd6
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yaml
@@ -0,0 +1,78 @@
+name: "💡 Feature Request"
+description: 创建新功能请求 | Create a new ticket for a new feature request
+title: "💡 [REQUEST] - "
+labels: [
+ "question"
+]
+body:
+ - type: input
+ id: start_date
+ attributes:
+ label: "起始日期 | Start Date"
+ description: |
+ 起始开发日期
+ Start of development
+ placeholder: "month/day/year"
+ validations:
+ required: false
+ - type: textarea
+ id: implementation_pr
+ attributes:
+ label: "实现PR | Implementation PR"
+ description: |
+ 实现该功能的Pull request
+ Pull request used
+ placeholder: "#Pull Request ID"
+ validations:
+ required: false
+ - type: textarea
+ id: reference_issues
+ attributes:
+ label: "相关Issues | Reference Issues"
+ description: |
+ 与该功能相关的issues
+ Common issues
+ placeholder: "#Issues IDs"
+ validations:
+ required: false
+ - type: textarea
+ id: summary
+ attributes:
+ label: "摘要 | Summary"
+ description: |
+ 简要描述新功能的特点
+ Provide a brief explanation of the feature
+ placeholder: |
+ Describe in a few lines your feature request
+ validations:
+ required: true
+ - type: textarea
+ id: basic_example
+ attributes:
+ label: "基本示例 | Basic Example"
+ description: Indicate here some basic examples of your feature.
+ placeholder: A few specific words about your feature request.
+ validations:
+ required: true
+ - type: textarea
+ id: drawbacks
+ attributes:
+ label: "缺陷 | Drawbacks"
+ description: |
+ 该新功能有哪些缺陷/可能造成哪些影响?
+ What are the drawbacks/impacts of your feature request ?
+ placeholder: |
+ Identify the drawbacks and impacts while being neutral on your feature request
+ validations:
+ required: true
+ - type: textarea
+ id: unresolved_question
+ attributes:
+ label: "未解决问题 | Unresolved questions"
+ description: |
+ 有哪些尚未解决的问题?
+ What questions still remain unresolved ?
+ placeholder: |
+ Identify any unresolved issues.
+ validations:
+ required: false
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..39e90665d32fb065ef7a679d9e02769b1ef8fabe
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,11 @@
+__pycache__
+*.so
+build
+.coverage_*
+*.egg-info
+*~
+.vscode/
+.idea/
+.DS_Store
+
+/private/
diff --git a/FAQ.md b/FAQ.md
new file mode 100644
index 0000000000000000000000000000000000000000..c6452860350302bcd98d15b0fdc57c807c8a04db
--- /dev/null
+++ b/FAQ.md
@@ -0,0 +1,85 @@
+# FAQ
+
+## Installation & Environment
+
+#### Failure in installing flash attention
+
+Flash attention is an option for accelerating training and inference. Only NVIDIA GPUs of Turing, Ampere, Ada, and Hopper architecture, e.g., H100, A100, RTX 3090, T4, RTX 2080, can support flash attention. You can use our models without installing it.
+
+#### Which version of transformers should I use?
+
+4.31.0 is preferred.
+
+#### I downloaded the codes and checkpoints but I can't load the model locally. What should I do?
+
+Please check if you have updated the code to the latest, and correctly downloaded all the sharded checkpoint files.
+
+#### `qwen.tiktoken` is not found. What is it?
+
+This is the merge file of the tokenizer. You have to download it. Note that if you just git clone the repo without [git-lfs](https://git-lfs.com), you cannot download this file.
+
+#### transformers_stream_generator/tiktoken/accelerate not found
+
+Run the command `pip install -r requirements.txt`. You can find the file at [https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt](https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt).
+
+
+
+
+## Demo & Inference
+
+#### Is there any demo? CLI demo and Web UI demo?
+
+Yes, see `web_demo.py` for web demo and `cli_demo.py` for CLI demo. See README for more information.
+
+
+
+#### Can I use CPU only?
+
+Yes, run `python cli_demo.py --cpu_only` will load the model and inference on CPU only.
+
+#### Can Qwen support streaming?
+
+Yes. See the function `chat_stream` in `modeling_qwen.py`.
+
+#### Gibberish in result when using chat_stream().
+
+This is because tokens represent bytes and a single token may be a meaningless string. We have updated the default setting of our tokenizer to avoid such decoding results. Please update the code to the latest version.
+
+#### It seems that the generation is not related to the instruction...
+
+Please check if you are loading Qwen-7B-Chat instead of Qwen-7B. Qwen-7B is the base model without alignment, which behaves differently from the SFT/Chat model.
+
+#### Is quantization supported?
+
+Yes, the quantization is supported by `bitsandbytes`. We are working on an improved version and will release the quantized model checkpoints.
+
+#### Errors in running quantized models: `importlib.metadata.PackageNotFoundError: No package metadata was found for bitsandbytes`
+
+For Linux users,running `pip install bitsandbytes` directly can solve the problem. For Windows users, you can run `python -m pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui`·
+
+#### Slow when processing long sequences
+
+We solved this problem. Updating the code to the latest version can help.
+
+#### Unsatisfactory performance in processing long sequences
+
+Please ensure that NTK is applied. `use_dynamc_ntk` and `use_logn_attn` in `config.json` should be set to `true` (`true` by default).
+
+
+
+
+## Finetuning
+
+#### Can Qwen support SFT or even RLHF?
+
+We do not provide finetuning or RLHF codes for now. However, some projects have supported finetuning, see [FastChat](**[https://github.com/lm-sys/FastChat](https://github.com/lm-sys/FastChat)), [Firefly]([https://github.com/yangjianxin1/Firefly](https://github.com/yangjianxin1/Firefly)), [**LLaMA Efficient Tuning**]([https://github.com/hiyouga/LLaMA-Efficient-Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning)), etc. We will soon update the relevant codes.
+
+
+
+
+## Tokenizer
+
+#### bos_id/eos_id/pad_id not found
+
+In our training, we only use `<|endoftext|>` as the separator and padding token. You can set bos_id, eos_id, and pad_id to tokenizer.eod_id. Learn more about our tokenizer from our documents about the tokenizer.
+
diff --git a/FAQ_zh.md b/FAQ_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..174ae69d81c067721e365c11fee752e904fdff1b
--- /dev/null
+++ b/FAQ_zh.md
@@ -0,0 +1,80 @@
+# FAQ
+
+## 安装&环境
+
+#### flash attention 安装失败
+
+flash attention是一个用于加速模型训练推理的可选项,且仅适用于Turing、Ampere、Ada、Hopper架构的Nvidia GPU显卡(如H100、A100、RTX 3090、T4、RTX 2080),您可以在不安装flash attention的情况下正常使用模型进行推理。
+
+#### 我应该用哪个transformers版本?
+
+建议使用4.31.0。
+
+#### 我把模型和代码下到本地,按照教程无法使用,该怎么办?
+
+答:别着急,先检查你的代码是不是更新到最新版本,然后确认你是否完整地将模型checkpoint下到本地。
+
+#### `qwen.tiktoken`这个文件找不到,怎么办?
+
+这个是我们的tokenizer的merge文件,你必须下载它才能使用我们的tokenizer。注意,如果你使用git clone却没有使用git-lfs,这个文件不会被下载。如果你不了解git-lfs,可点击[官网](https://git-lfs.com/)了解。
+
+#### transformers_stream_generator/tiktoken/accelerate,这几个库提示找不到,怎么办?
+
+运行如下命令:`pip install -r requirements.txt`。相关依赖库在[https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt](https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt) 可以找到。
+
+
+
+## Tokenizer
+
+#### bos_id/eos_id/pad_id,这些token id不存在,为什么?
+
+在训练过程中,我们仅使用<|endoftext|>这一token作为sample/document之间的分隔符及padding位置占位符,你可以将bos_id, eos_id, pad_id均指向tokenizer.eod_id。请阅读我们关于tokenizer的文档,了解如何设置这些id。
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d69279ec2cbac39574f08ab39efe628b63c883ef
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,53 @@
+Tongyi Qianwen LICENSE AGREEMENT
+
+Tongyi Qianwen Release Date: August 3, 2023
+
+By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+
+1. Definitions
+ a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
+ b. "We"(or "Us") shall mean Alibaba Cloud.
+ c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
+ d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
+ e. "Tongyi Qianwen" shall mean the large language models (including Qwen-7B model and Qwen-7B-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
+ f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
+ g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
+ h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+2. Grant of Rights
+You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.
+
+3. Redistribution
+You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+ a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
+ b. You shall cause any modified files to carry prominent notices stating that You changed the files;
+ c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
+ d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.
+
+4. Restrictions
+If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.
+
+5. Rules of use
+ a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
+ b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).
+
+6. Intellectual Property
+ a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
+ b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
+ c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.
+
+7. Disclaimer of Warranty and Limitation of Liability
+
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
+ b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
+ c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
+ d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
+
+8. Survival and Termination.
+ a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+ b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.
+
+9. Governing Law and Jurisdiction.
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+ b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
\ No newline at end of file
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000000000000000000000000000000000000..22c063e728d032f6578e28eca18542ebd0c4610b
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,52 @@
+------------- LICENSE FOR NVIDIA Megatron-LM code --------------
+
+Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of NVIDIA CORPORATION nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+------------- LICENSE FOR OpenAI tiktoken code --------------
+
+MIT License
+
+Copyright (c) 2022 OpenAI, Shantanu Jain
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index eaf0ff7cbf8c3afb1360aae80ee77f87cea436e6..21b1d660244040ff160c605895d6a8987dc8704e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,437 @@
---
-title: Qwen 7B Main
-emoji: 👀
-colorFrom: yellow
-colorTo: green
+title: Qwen-7B-main
+app_file: web_demo.py
sdk: gradio
sdk_version: 3.40.1
-app_file: app.py
-pinned: false
---
+
+
+
+
+We opensource **Qwen-7B** and **Qwen-7B-Chat** on both **🤖 ModelScope** and **🤗 Hugging Face** (Click the logos on top to the repos with codes and checkpoints). This repo includes the brief introduction to Qwen-7B, the usage guidance, and also a technical memo [link](tech_memo.md) that provides more information.
+
+Qwen-7B is the 7B-parameter version of the large language model series, Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-7B is a Transformer-based large language model, which is pretrained on a large volume of data, including web texts, books, codes, etc. Additionally, based on the pretrained Qwen-7B, we release Qwen-7B-Chat, a large-model-based AI assistant, which is trained with alignment techniques. The features of the Qwen-7B series include:
+
+1. **Trained with high-quality pretraining data**. We have pretrained Qwen-7B on a self-constructed large-scale high-quality dataset of over 2.2 trillion tokens. The dataset includes plain texts and codes, and it covers a wide range of domains, including general domain data and professional domain data.
+2. **Strong performance**. In comparison with the models of the similar model size, we outperform the competitors on a series of benchmark datasets, which evaluates natural language understanding, mathematics, coding, etc.
+3. **Better support of languages**. Our tokenizer, based on a large vocabulary of over 150K tokens, is a more efficient one compared with other tokenizers. It is friendly to many languages, and it is helpful for users to further finetune Qwen-7B for the extension of understanding a certain language.
+4. **Support of 8K Context Length**. Both Qwen-7B and Qwen-7B-Chat support the context length of 8K, which allows inputs with long contexts.
+5. **Support of Plugins**. Qwen-7B-Chat is trained with plugin-related alignment data, and thus it is capable of using tools, including APIs, models, databases, etc., and it is capable of playing as an agent.
+
+The following sections include information that you might find it helpful. Specifically, we advise you to read the FAQ section before you launch issues.
+
+## News
+
+* 2023.8.3 We release both Qwen-7B and Qwen-7B-Chat on ModelScope and Hugging Face. We also provide a technical memo for more details about the model, including training details and model performance.
+
+## Performance
+
+In general, Qwen-7B outperforms the baseline models of a similar model size, and even outperforms larger models of around 13B parameters, on a series of benchmark datasets, e.g., MMLU, C-Eval, GSM8K, HumanEval, and WMT22, CMMLU, etc., which evaluate the models' capabilities on natural language understanding, mathematic problem solving, coding, etc. See the results below.
+
+| Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
+| :---------------- | :------------: | :------------: | :------------: | :------------: | :------------: |:------------: |
+| LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
+| LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
+| Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
+| ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
+| InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
+| Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
+| LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
+| LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
+| ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
+| **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
+
+
+
+
+
+
+Additionally, according to the third-party evaluation of large language models, conducted by [OpenCompass](https://opencompass.org.cn/leaderboard-llm), Qwen-7B and Qwen-7B-Chat are the top 7B-parameter models. This evaluation consists of a large amount of public benchmarks for the evaluation of language understanding and generation, coding, mathematics, reasoning, etc.
+
+For more experimental results (detailed model performance on more benchmark datasets) and details, please refer to our technical memo by clicking [here](tech_memo.md).
+
+## Requirements
+
+* python 3.8 and above
+* pytorch 1.12 and above, 2.0 and above are recommended
+* CUDA 11.4 and above are recommended (this is for GPU users, flash-attention users, etc.)
+
+## Quickstart
+
+Below, we provide simple examples to show how to use Qwen-7B with 🤖 ModelScope and 🤗 Transformers.
+
+Before running the code, make sure you have setup the environment and installed the required packages. Make sure you meet the above requirements, and then install the dependent libraries.
+
+```bash
+pip install -r requirements.txt
+```
+
+If your device supports fp16 or bf16, we recommend installing [flash-attention](https://github.com/Dao-AILab/flash-attention) for higher efficiency and lower memory usage. (**flash-attention is optional and the project can run normally without installing it**)
+
+```bash
+git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
+cd flash-attention && pip install .
+# Below are optional. Installing them might be slow.
+# pip install csrc/layer_norm
+# pip install csrc/rotary
+```
+
+Now you can start with ModelScope or Transformers.
+
+#### 🤗 Transformers
+
+To use Qwen-7B-Chat for the inference, all you need to do is to input a few lines of codes as demonstrated below. However, **please make sure that you are using the latest code.**
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+# Note: The default behavior now has injection attack prevention off.
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
+# use auto mode, automatically select precision based on the device.
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
+
+# Specify hyperparameters for generation
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+
+# 第一轮对话 1st dialogue turn
+response, history = model.chat(tokenizer, "你好", history=None)
+print(response)
+# 你好!很高兴为你提供帮助。
+
+# 第二轮对话 2nd dialogue turn
+response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
+print(response)
+# 这是一个关于一个年轻人奋斗创业最终取得成功的故事。
+# 故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。从小,李明就立下了一个目标:要成为一名成功的企业家。
+# 为了实现这个目标,李明勤奋学习,考上了大学。在大学期间,他积极参加各种创业比赛,获得了不少奖项。他还利用课余时间去实习,积累了宝贵的经验。
+# 毕业后,李明决定开始自己的创业之路。他开始寻找投资机会,但多次都被拒绝了。然而,他并没有放弃。他继续努力,不断改进自己的创业计划,并寻找新的投资机会。
+# 最终,李明成功地获得了一笔投资,开始了自己的创业之路。他成立了一家科技公司,专注于开发新型软件。在他的领导下,公司迅速发展起来,成为了一家成功的科技企业。
+# 李明的成功并不是偶然的。他勤奋、坚韧、勇于冒险,不断学习和改进自己。他的成功也证明了,只要努力奋斗,任何人都有可能取得成功。
+
+# 第三轮对话 3rd dialogue turn
+response, history = model.chat(tokenizer, "给这个故事起一个标题", history=history)
+print(response)
+# 《奋斗创业:一个年轻人的成功之路》
+```
+
+Running Qwen-7B pretrained base model is also simple.
+
+
+ Running Qwen-7B
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="cpu", trust_remote_code=True).eval()
+# use auto mode, automatically select precision based on the device.
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True).eval()
+
+# Specify hyperparameters for generation
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+
+inputs = tokenizer('蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是', return_tensors='pt')
+inputs = inputs.to(model.device)
+pred = model.generate(**inputs)
+print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
+# 蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是亚的斯亚贝巴(Addis Ababa)...
+```
+
+
+
+#### 🤖 ModelScope
+
+ModelScope is an opensource platform for Model-as-a-Service (MaaS), which provides flexible and cost-effective model service to AI developers. Similarly, you can run the models with ModelScope as shown below:
+
+```python
+import os
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope import snapshot_download
+
+model_id = 'QWen/qwen-7b-chat'
+revision = 'v1.0.0'
+
+model_dir = snapshot_download(model_id, revision)
+
+pipe = pipeline(
+task=Tasks.chat, model=model_dir, device_map='auto')
+history = None
+
+text = '浙江的省会在哪里?'
+results = pipe(text, history=history)
+response, history = results['response'], results['history']
+print(f'Response: {response}')
+text = '它有什么好玩的地方呢?'
+results = pipe(text, history=history)
+response, history = results['response'], results['history']
+print(f'Response: {response}')
+```
+
+## Tokenizer
+
+Our tokenizer based on tiktoken is different from other tokenizers, e.g., sentencepiece tokenizer. You need to pay attention to special tokens, especially in finetuning. For more detailed information on the tokenizer and related use in fine-tuning, please refer to the [documentation](tokenization_note.md).
+
+## Quantization
+
+We provide examples to show how to load models in `NF4` and `Int8`. For starters, make sure you have implemented `bitsandbytes`. Note that the requirements for `bitsandbytes` are:
+
+```
+**Requirements** Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
+```
+
+Then run the following command to install `bitsandbytes`:
+
+```
+pip install bitsandbytes
+```
+
+Windows users should find another option, which might be [bitsandbytes-windows-webui](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels).
+
+Then you only need to add your quantization configuration to `AutoModelForCausalLM.from_pretrained`. See the example below:
+
+```python
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+# quantization configuration for NF4 (4 bits)
+quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type='nf4',
+ bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+# quantization configuration for Int8 (8 bits)
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint_path,
+ device_map="cuda:0",
+ quantization_config=quantization_config,
+ max_memory=max_memory,
+ trust_remote_code=True,
+).eval()
+```
+
+With this method, it is available to load Qwen-7B in `NF4` and `Int8`, which saves you memory usage. We provide related statistics of model performance below. We find that the quantization downgrades the effectiveness slightly but significantly reduces memory costs.
+
+| Precision | MMLU | GPU Memory for Loading Model |
+| ----------- | :------: | :---------------------------: |
+| BF16 | 56.7 | 16.38G |
+| Int8 | 52.8 | 10.44G |
+| NF4 | 48.9 | 7.79G |
+
+Note: The GPU memory usage profiling in the above table is performed on single A100-SXM4-80G GPU, PyTorch 2.0.1 and CUDA 11.8, with flash attention used.
+
+## Inference Efficiency
+
+### Inference Speed
+
+We measured the average inference speed of generating 2K tokens under BF16 precision and Int8 or NF4 quantization levels, respectively.
+
+| Quantization Level | Inference Speed with flash_attn (tokens/s) | Inference Speed w/o flash_attn (tokens/s) |
+| ---------------------- | :----------------------------------------: | :---------------------------------------: |
+| BF16 (no quantization) | 30.06 | 27.55 |
+| Int8 (bnb) | 7.94 | 7.86 |
+| NF4 (bnb) | 21.43 | 20.37 |
+
+In detail, the setting of profiling is generating 2048 new tokens with 1 context token. The profiling runs on single A100-SXM4-80G GPU with PyTorch 2.0.1 and CUDA 11.8. The inference speed is averaged over the generated 2048 tokens.
+
+### GPU Memory Usage
+
+We also profile the peak GPU memory usage for encoding 2048 tokens as context (and generating single token) and generating 8192 tokens (with single token as context) under BF16 or Int8/NF4 quantization levels, respectively. The results are shown below.
+
+When using flash attention, the memory usage is:
+
+| Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
+| ------------------ | :---------------------------------: | :-----------------------------------: |
+| BF16 | 18.11GB | 23.52GB |
+| Int8 | 12.17GB | 17.60GB |
+| NF4 | 9.52GB | 14.93GB |
+
+When not using flash attention, the memory usage is:
+
+| Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
+| ------------------ | :---------------------------------: | :-----------------------------------: |
+| BF16 | 18.11GB | 24.40GB |
+| Int8 | 12.18GB | 18.47GB |
+| NF4 | 9.52GB | 15.81GB |
+
+The above speed and memory profiling are conducted using [this script](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py).
+
+## Demo
+
+
+### Web UI
+
+We provide code for users to build a web UI demo (thanks to @wysaid). Before you start, make sure you install the following packages:
+
+```
+pip install -r requirements_web_demo.txt
+```
+
+Then run the command below and click on the generated link:
+
+```
+python web_demo.py
+```
+
+
+
+
+
+
+
+### CLI Demo
+
+We provide a CLI demo example in `cli_demo.py`, which supports streaming output for the generation. Users can interact with Qwen-7B-Chat by inputting prompts, and the model returns model outputs in the streaming mode. Run the command below:
+
+```
+python cli_demo.py
+```
+
+
+
+
+
+
+
+## API
+
+We provide methods to deploy local API based on OpenAI API (thanks to @hanpenggit). Before you start, install the required packages:
+
+```bash
+pip install fastapi uvicorn openai pydantic sse_starlette
+```
+
+Then run the command to deploy your API:
+
+```bash
+python openai_api.py
+```
+
+You can change your arguments, e.g., `-c` for checkpoint name or path, `--cpu-only` for CPU deployment, etc. If you meet problems launching your API deployment, updating the packages to the latest version can probably solve them.
+
+Using the API is also simple. See the example below:
+
+```python
+import openai
+openai.api_base = "http://localhost:8000/v1"
+openai.api_key = "none"
+
+# create a request activating streaming response
+for chunk in openai.ChatCompletion.create(
+ model="Qwen-7B",
+ messages=[
+ {"role": "user", "content": "你好"}
+ ],
+ stream=True
+):
+ if hasattr(chunk.choices[0].delta, "content"):
+ print(chunk.choices[0].delta.content, end="", flush=True)
+
+# create a request not activating streaming response
+response = openai.ChatCompletion.create(
+ model="Qwen-7B",
+ messages=[
+ {"role": "user", "content": "你好"}
+ ],
+ stream=False
+)
+print(response.choices[0].message.content)
+```
+
+
+
+
+
+
+
+## Tool Usage
+
+Qwen-7B-Chat is specifically optimized for tool usage, including API, database, models, etc., so that users can build their own Qwen-7B-based LangChain, Agent, and Code Interpreter. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, we find that Qwen-7B reaches stable performance.
+
+| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
+|:------------|:----------------------:|:----------------------:|:----------------------:|
+| GPT-4 | 95% | **0.90** | 15% |
+| GPT-3.5 | 85% | 0.88 | 75% |
+| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
+
+For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks.
+
+Additionally, we provide experimental results to show its capabilities of playing as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows:
+
+| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
+|:---------------|:---------------:|:-----------:|:---------:|
+|GPT-4 | **100** | **100** | **97.41** |
+|GPT-3.5 | 95.37 | 96.30 | 87.04 |
+|StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
+| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
+
+## Long-Context Understanding
+
+To extend the context length and break the bottleneck of training sequence length, we introduce several techniques, including NTK-aware interpolation, window attention, and LogN attention scaling, to extend the context length to over 8K tokens. We conduct language modeling experiments on the arXiv dataset with the PPL evaluation and find that Qwen-7B can reach outstanding performance in the scenario of long context. Results are demonstrated below:
+
+
+
+
Model
Sequence Length
+
+
+
1024
2048
4096
8192
16384
+
+
+
Qwen-7B
4.23
3.78
39.35
469.81
2645.09
+
+
+
+ dynamic_ntk
4.23
3.78
3.59
3.66
5.71
+
+
+
+ dynamic_ntk + logn
4.23
3.78
3.58
3.56
4.62
+
+
+
+ dynamic_ntk + logn + window_attn
4.23
3.78
3.58
3.49
4.32
+
+
+
+## Reproduction
+
+For your reproduction of the model performance on benchmark datasets, we provide scripts for you to reproduce the results. Check [eval/EVALUATION.md](eval/EVALUATION.md) for more information. Note that the reproduction may lead to slight differences from our reported results.
+
+## FAQ
+
+If you meet problems, please refer to [FAQ](FAQ.md) and the issues first to search a solution before you launch a new issue.
+
+## License Agreement
+
+Researchers and developers are free to use the codes and model weights of both Qwen-7B and Qwen-7B-Chat. We also allow their commercial use. Check our license at [LICENSE](LICENSE) for more details. If you have requirements for commercial use, please fill out the [form](https://dashscope.console.aliyun.com/openModelApply/qianwen) to apply.
+
+## Contact Us
+
+If you are interested to leave a message to either our research team or product team, feel free to send an email to qianwen_opensource@alibabacloud.com.
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/README_CN.md b/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..af4d8f9750a836ad7217c136c7d5432bd7018cb3
--- /dev/null
+++ b/README_CN.md
@@ -0,0 +1,436 @@
+
+
+
+
+## 再現
+
+ベンチマークデータセットでのモデル性能の再現のために、結果を再現するスクリプトを提供しています。詳しくは [eval/EVALUATION.md](eval/EVALUATION.md) を確認してください。なお、再現の結果、我々の報告結果と若干異なる場合がある。
+
+## FAQ
+
+問題が発生した場合は、[FAQ](FAQ.md)やissueを参照し、新しいissueを立ち上げる前に解決策を探してください。
+
+## ライセンス契約
+
+Qwen-7B と Qwen-7B-Chat のコードとモデルウェイトは、研究者や開発者が自由に使用することができます。また、商用利用も可能です。詳しくは [LICENSE](LICENSE) をご覧ください。商用利用を希望される方は、[リクエストフォーム](https://dashscope.console.aliyun.com/openModelApply/qianwen)に必要事項をご記入の上、お申し込みください。
+
+## お問い合わせ
+
+研究チームまたは製品チームへのメッセージは、qianwen_opensource@alibabacloud.com までお気軽にお送りください。
+
diff --git a/assets/cli_demo.gif b/assets/cli_demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..4d7be88b219cd0f0416a3f3480978188109151f6
--- /dev/null
+++ b/assets/cli_demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2502b56784d9e2ba70094c040f80b3571d4969a24d27b126ce22ed489b3e31f1
+size 1981045
diff --git a/assets/hfagent_chat_1.png b/assets/hfagent_chat_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..bcd64182f575c6068b0c5dd31aa624ee0fbe4757
--- /dev/null
+++ b/assets/hfagent_chat_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:356ea19c2c4a656cae9d55e2d727d1651d1955ec67385615c6582b394478e889
+size 1708738
diff --git a/assets/hfagent_chat_2.png b/assets/hfagent_chat_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..527c2421784005ac7d5978ca1d6418424ffaec04
--- /dev/null
+++ b/assets/hfagent_chat_2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7db53a1a77dfc19072ce418db6df56fd89f9e7cb2e30430ac8320f10fc8a8bc0
+size 1927640
diff --git a/assets/hfagent_run.png b/assets/hfagent_run.png
new file mode 100644
index 0000000000000000000000000000000000000000..02937cb95aae0a8fc8bc2717daf64ee972044493
--- /dev/null
+++ b/assets/hfagent_run.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbf4c1232c86e334b5425aacdcc9e7a878100f80d6d70725060cb312bae7d701
+size 2770957
diff --git a/assets/logo.jpg b/assets/logo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6d11f3ae1c69da9f82f632b55b641d0836c86437
Binary files /dev/null and b/assets/logo.jpg differ
diff --git a/assets/openai_api.gif b/assets/openai_api.gif
new file mode 100644
index 0000000000000000000000000000000000000000..3152fd50795c05586f7707bc9e7945c8d721aa6e
--- /dev/null
+++ b/assets/openai_api.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b457ed0497eba0dff8e2a11093662a548df73c09b506f590a94cab9535a6b83b
+size 1201656
diff --git a/assets/performance.png b/assets/performance.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f10765d135c2273458e4eebc8803fb77a7dd5f1
Binary files /dev/null and b/assets/performance.png differ
diff --git a/assets/qwen_tokenizer.png b/assets/qwen_tokenizer.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6b0366bfac8b3825ebec80de7a0b112bc275d5e
Binary files /dev/null and b/assets/qwen_tokenizer.png differ
diff --git a/assets/react_showcase_001.png b/assets/react_showcase_001.png
new file mode 100644
index 0000000000000000000000000000000000000000..474c59fb3fa7ecd51e5f51f9ef5d55cd00f7be84
Binary files /dev/null and b/assets/react_showcase_001.png differ
diff --git a/assets/react_showcase_002.png b/assets/react_showcase_002.png
new file mode 100644
index 0000000000000000000000000000000000000000..eef8ce6a71250b086d98ce5b3f4fb42d875f2b85
Binary files /dev/null and b/assets/react_showcase_002.png differ
diff --git a/assets/react_tutorial_001.png b/assets/react_tutorial_001.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9629be9c9c4c7f294013cb3b7db6e915b460f40
Binary files /dev/null and b/assets/react_tutorial_001.png differ
diff --git a/assets/react_tutorial_002.png b/assets/react_tutorial_002.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d9ede6217d5b9e415df414e1f3199b6f13f37f1
Binary files /dev/null and b/assets/react_tutorial_002.png differ
diff --git a/assets/tokenizer.pdf b/assets/tokenizer.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..f33e7e55ebf76c5a2d9457270608bbd0a257c9dd
Binary files /dev/null and b/assets/tokenizer.pdf differ
diff --git a/assets/tokenizer.png b/assets/tokenizer.png
new file mode 100644
index 0000000000000000000000000000000000000000..b16c0cdee53be705226d57b4c50d65138a2a0dbf
Binary files /dev/null and b/assets/tokenizer.png differ
diff --git a/assets/wanx_colorful_black.png b/assets/wanx_colorful_black.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5db29b910322dc2253ef71be65fb05b43a20d9a
--- /dev/null
+++ b/assets/wanx_colorful_black.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:650a5431b1a3b4411fc4c2fd44dea3066a4ec67b03b684721086265698d738c4
+size 1326970
diff --git a/assets/web_demo.gif b/assets/web_demo.gif
new file mode 100644
index 0000000000000000000000000000000000000000..535a4b6863cda291e24000a4d1c31bc136e0bff1
--- /dev/null
+++ b/assets/web_demo.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a721165a571d1b8a22861d0c489f5b6ce5bb1df44470fff957bc8704e2bf996
+size 18786391
diff --git a/cli_demo.py b/cli_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a095add560f228a0c15a10e28e80f4db1acb114
--- /dev/null
+++ b/cli_demo.py
@@ -0,0 +1,194 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A simple command-line interactive chat demo."""
+
+import argparse
+import os
+import platform
+import shutil
+from copy import deepcopy
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+from transformers.trainer_utils import set_seed
+
+DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'
+
+_WELCOME_MSG = '''\
+Welcome to use Qwen-7B-Chat model, type text to start chat, type :h to show command help
+欢迎使用 Qwen-7B 模型,输入内容即可进行对话,:h 显示命令帮助
+'''
+_HELP_MSG = '''\
+Commands:
+ :help / :h Show this help message 显示帮助信息
+ :exit / :quit / :q Exit the demo 退出Demo
+ :clear / :cl Clear screen 清屏
+ :clear-his / :clh Clear history 清除对话历史
+ :history / :his Show history 显示对话历史
+ :seed Show current random seed 显示当前随机种子
+ :seed Set random seed to 设置随机种子
+ :conf Show current generation config 显示生成配置
+ :conf = Change generation config 修改生成配置
+ :reset-conf Reset generation config 重置生成配置
+'''
+
+
+def _load_model_tokenizer(args):
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+
+ if args.cpu_only:
+ device_map = "cpu"
+ else:
+ device_map = "auto"
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint_path,
+ device_map=device_map,
+ trust_remote_code=True,
+ resume_download=True,
+ ).eval()
+ model.generation_config = GenerationConfig.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+ return model, tokenizer
+
+
+def _clear_screen():
+ if platform.system() == "Windows":
+ os.system("cls")
+ else:
+ os.system("clear")
+
+
+def _print_history(history):
+ terminal_width = shutil.get_terminal_size()[0]
+ print(f'History ({len(history)})'.center(terminal_width, '='))
+ for index, (query, response) in enumerate(history):
+ print(f'User[{index}]: {query}')
+ print(f'QWen[{index}]: {response}')
+ print('=' * terminal_width)
+
+
+def _get_input() -> str:
+ while True:
+ try:
+ message = input('User> ').strip()
+ except UnicodeDecodeError:
+ print('[ERROR] Encoding error in input')
+ continue
+ except KeyboardInterrupt:
+ exit(1)
+ if message:
+ return message
+ print('[ERROR] Query is empty')
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='QWen-7B-Chat command-line interactive chat demo.')
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
+ help="Checkpoint name or path, default to %(default)r")
+ parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
+ args = parser.parse_args()
+
+ history, response = [], ''
+
+ model, tokenizer = _load_model_tokenizer(args)
+ orig_gen_config = deepcopy(model.generation_config)
+
+ _clear_screen()
+ print(_WELCOME_MSG)
+
+ seed = args.seed
+
+ while True:
+ query = _get_input()
+
+ # Process commands.
+ if query.startswith(':'):
+ command_words = query[1:].strip().split()
+ if not command_words:
+ command = ''
+ else:
+ command = command_words[0]
+
+ if command in ['exit', 'quit', 'q']:
+ break
+ elif command in ['clear', 'cl']:
+ _clear_screen()
+ print(_WELCOME_MSG)
+ continue
+ elif command in ['clear-history', 'clh']:
+ print(f'[INFO] All {len(history)} history cleared')
+ history.clear()
+ continue
+ elif command in ['help', 'h']:
+ print(_HELP_MSG)
+ continue
+ elif command in ['history', 'his']:
+ _print_history(history)
+ continue
+ elif command in ['seed']:
+ if len(command_words) == 1:
+ print(f'[INFO] Current random seed: {seed}')
+ continue
+ else:
+ new_seed_s = command_words[1]
+ try:
+ new_seed = int(new_seed_s)
+ except ValueError:
+ print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
+ else:
+ print(f'[INFO] Random seed changed to {new_seed}')
+ seed = new_seed
+ continue
+ elif command in ['conf']:
+ if len(command_words) == 1:
+ print(model.generation_config)
+ else:
+ for key_value_pairs_str in command_words[1:]:
+ eq_idx = key_value_pairs_str.find('=')
+ if eq_idx == -1:
+ print('[WARNING] format: =')
+ continue
+ conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
+ try:
+ conf_value = eval(conf_value_str)
+ except Exception as e:
+ print(e)
+ continue
+ else:
+ print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
+ setattr(model.generation_config, conf_key, conf_value)
+ continue
+ elif command in ['reset-conf']:
+ print('[INFO] Reset generation config')
+ model.generation_config = deepcopy(orig_gen_config)
+ print(model.generation_config)
+ continue
+ else:
+ # As normal query.
+ pass
+
+ # Run chat.
+ set_seed(seed)
+ try:
+ for response in model.chat_stream(tokenizer, query, history=history):
+ _clear_screen()
+ print(f"\nUser: {query}")
+ print(f"\nQwen-7B: {response}")
+ except KeyboardInterrupt:
+ print('[WARNING] Generation interrupted')
+ continue
+
+ history.append((query, response))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/eval/EVALUATION.md b/eval/EVALUATION.md
new file mode 100644
index 0000000000000000000000000000000000000000..44e0af62a0e01674702b564b43b41c17b3835eeb
--- /dev/null
+++ b/eval/EVALUATION.md
@@ -0,0 +1,83 @@
+## 评测复现
+
+- CEVAL
+
+```Shell
+wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
+mkdir data/ceval
+mv ceval-exam.zip data/ceval
+cd data/ceval; unzip ceval-exam.zip
+cd ../../
+
+# Qwen-7B
+python evaluate_ceval.py -d data/ceval/
+
+# Qwen-7B-Chat
+pip install thefuzz
+python evaluate_chat_ceval.py -d data/ceval/
+```
+
+- MMLU
+
+```Shell
+wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
+mkdir data/mmlu
+mv data.tar data/mmlu
+cd data/mmlu; tar xf data.tar
+cd ../../
+
+# Qwen-7B
+python evaluate_mmlu.py -d data/mmlu/data/
+
+# Qwen-7B-Chat
+pip install thefuzz
+python evaluate_chat_mmlu.py -d data/mmlu/data/
+```
+
+- HumanEval
+
+Get the HumanEval.jsonl file from [here](https://github.com/openai/human-eval/tree/master/data)
+
+```Shell
+git clone https://github.com/openai/human-eval
+pip install -e human-eval
+
+# Qwen-7B
+python evaluate_humaneval.py -f HumanEval.jsonl -o HumanEval_res.jsonl
+evaluate_functional_correctness HumanEval_res.jsonl
+# Qwen-7B-Chat
+python evaluate_chat_mmlu.py -f HumanEval.jsonl -o HumanEval_res_chat.jsonl
+evaluate_functional_correctness HumanEval_res_chat.jsonl
+```
+
+When installing package human-eval, please note its following disclaimer:
+
+This program exists to run untrusted model-generated code. Users are strongly encouraged not to do so outside of a robust security sandbox. The execution call in execution.py is deliberately commented out to ensure users read this disclaimer before running code in a potentially unsafe manner. See the comment in execution.py for more information and instructions.
+
+- GSM8K
+
+```Shell
+# Qwen-7B
+python evaluate_gsm8k.py
+
+# Qwen-7B-Chat
+python evaluate_chat_gsm8k.py # zeroshot
+python evaluate_chat_gsm8k.py --use-fewshot # fewshot
+```
+
+- PLUGIN
+
+This script is used to reproduce the results of the ReAct and Hugging Face Agent in the Tool Usage section of the README document.
+
+```Shell
+# Qwen-7B-Chat
+mkdir data;
+cd data;
+wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_positive.jsonl;
+wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_negative.jsonl;
+cd ..;
+pip install json5;
+pip install jsonlines;
+pip install rouge_score;
+python evaluate_plugin.py --eval-react-positive --eval-react-negative --eval-hfagent
+```
diff --git a/eval/evaluate_ceval.py b/eval/evaluate_ceval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1616a530ca812eecd580737df498cd2a912d27c
--- /dev/null
+++ b/eval/evaluate_ceval.py
@@ -0,0 +1,263 @@
+import os
+import pandas as pd
+import numpy as np
+import argparse
+import datasets
+import torch
+
+from typing import List
+from tqdm import tqdm
+from transformers.trainer_utils import set_seed
+
+
+'''
+wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
+mkdir data/ceval
+mv ceval-exam.zip data/ceval
+cd data/ceval; unzip ceval-exam.zip
+cd ../../
+python evaluate_ceval.py -d data/ceval/
+'''
+
+def load_models_tokenizer(args):
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ return model, tokenizer
+
+
+def format_example(line, include_answer=True):
+ example = '问题:' + line['question']
+ for choice in choices:
+ example += f'\n{choice}. {line[f"{choice}"]}'
+
+ if include_answer:
+ example += '\n答案:' + line["answer"] + '\n\n'
+ else:
+ example += '\n答案:'
+ return example
+
+
+def generate_few_shot_prompt(k, subject, dev_df):
+ prompt = ''
+ if k == -1:
+ k = dev_df.shape[0]
+ for i in range(k):
+ prompt += format_example(
+ dev_df.iloc[i, :],
+ include_answer=True,
+ )
+ return prompt
+
+
+def get_logits(tokenizer, model, inputs: List[str]):
+ input_ids = tokenizer(inputs, padding=False)['input_ids']
+ input_ids = torch.tensor(input_ids, device=model.device)
+ tokens = {'input_ids': input_ids}
+
+ outputs = model(input_ids)['logits']
+ logits = outputs[:, -1, :]
+ log_probs = torch.nn.functional.softmax(logits, dim=-1)
+ return log_probs, {'tokens': tokens}
+
+
+@torch.no_grad()
+def eval_subject(
+ model,
+ tokenizer,
+ subject_name,
+ test_df,
+ k=5,
+ dev_df=None,
+ few_shot=False,
+ save_result_dir=None,
+ **kwargs
+):
+ result = []
+ score = []
+
+ few_shot_prompt = generate_few_shot_prompt(
+ k, subject_name, dev_df) if few_shot else ''
+ all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
+ if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
+
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
+ question = format_example(row, include_answer=False)
+ full_prompt = few_shot_prompt + question
+
+ output, input_info = get_logits(tokenizer, model, [full_prompt])
+ assert output.shape[0] == 1
+ logits = output.flatten()
+
+ softval = torch.nn.functional.softmax(
+ torch.tensor(
+ [
+ logits[tokenizer("A")['input_ids']],
+ logits[tokenizer("B")['input_ids']],
+ logits[tokenizer("C")['input_ids']],
+ logits[tokenizer("D")['input_ids']],
+ ]
+ ),
+ dim=0,
+ )
+ if softval.dtype in {torch.bfloat16, torch.float16}:
+ softval = softval.to(dtype=torch.float32)
+ probs = softval.detach().cpu().numpy()
+
+ for i, choice in enumerate(choices):
+ all_probs[f'prob_{choice}'].append(probs[i])
+ pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
+
+ if 'answer' in row:
+ correct = 1 if pred == row['answer'] else 0
+ score.append(correct)
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
+ result.append(pred)
+
+ if score:
+ correct_ratio = 100 * sum(score) / len(score)
+ if args.debug: print(subject_name, correct_ratio)
+ else:
+ correct_ratio = 0
+ if save_result_dir:
+ test_df['model_output'] = result
+ for i, choice in enumerate(choices):
+ test_df[f'prob_{choice}'] = (all_probs[f'prob_{choice}'])
+ if score:
+ test_df["correctness"] = score
+ os.makedirs(save_result_dir, exist_ok=True)
+ test_df.to_csv(os.path.join(
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
+
+ return correct_ratio
+
+
+def cal_ceval(res):
+ acc_sum_dict = dict()
+ acc_norm_sum_dict = dict()
+ cnt_dict = dict()
+ acc_sum = 0.
+ cnt = 0
+ hard_cnt = 0
+ hard_acc_sum = 0.
+ for tt in res.keys():
+ name = tt.split('-')[-1]
+ acc_sum += float(res[tt])
+ cnt += 1
+ class_ = TASK_NAME_MAPPING[name][2]
+ if class_ not in acc_sum_dict:
+ acc_sum_dict[class_] = 0.
+ acc_norm_sum_dict[class_] = 0.
+ cnt_dict[class_] = 0.
+ if name in hard_list:
+ hard_cnt += 1
+ hard_acc_sum += float(res[tt])
+ acc_sum_dict[class_] += float(res[tt])
+ cnt_dict[class_] += 1
+ print('\n\n\n')
+ for k in ['STEM', 'Social Science', 'Humanities', 'Other']:
+ if k in cnt_dict:
+ print('%s acc: %.2f ' % (
+ k, acc_sum_dict[k] / cnt_dict[k]))
+ if hard_cnt > 0:
+ print('Hard acc:%.2f ' % (hard_acc_sum / hard_cnt))
+ print('AVERAGE acc:%.2f ' % (acc_sum / cnt))
+
+
+TASK_NAME_MAPPING = {
+ "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
+ "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
+ "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
+ "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
+ "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
+ "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
+ "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
+ "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
+ "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
+ "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
+ "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
+ "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
+ "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
+ "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
+ "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
+ "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
+ "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
+ "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
+ "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
+ "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
+ "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
+ "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
+ "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
+ "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
+ "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
+ "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
+ "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
+ "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
+ "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
+ "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
+ "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
+ "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
+ "law": ["Law", "\u6cd5\u5b66", "Humanities"],
+ "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
+ "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
+ "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
+ "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
+ "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
+ "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
+ "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
+ "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
+ "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
+ "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
+ "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
+ "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
+ "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
+ "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
+ "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
+ "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
+ "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"]
+}
+hard_list = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_physics', 'college_chemistry', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry']
+choices = ["A", "B", "C", "D"]
+
+
+def main(args):
+ model, tokenizer = load_models_tokenizer(args)
+
+ dev_result = {}
+ for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
+ val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
+ dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
+ # test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
+ val_df = pd.read_csv(val_file_path)
+ dev_df = pd.read_csv(dev_file_path)
+ # test_df = pd.read_csv(test_file_path)
+
+ score = eval_subject(model, tokenizer, subject_name, val_df, dev_df=dev_df, k=5, few_shot=True,
+ save_result_dir=f"outs/ceval_eval_result")
+ dev_result[subject_name] = score
+ cal_ceval(dev_result)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
+
+ """Provide extra arguments required for tasks."""
+ group = parser.add_argument_group(title='Evaluation options')
+ group.add_argument('-d', '--eval_data_path', type=str, required=True,
+ help='Path to eval data')
+ group.add_argument("--max-seq-len", type=int, default=2048,
+ help='Size of the output generated text.')
+ group.add_argument("--debug", action='store_true', default=False,
+ help='Print infos.')
+
+ args = parser.parse_args()
+ set_seed(args.seed)
+
+ main(args)
\ No newline at end of file
diff --git a/eval/evaluate_chat_ceval.py b/eval/evaluate_chat_ceval.py
new file mode 100644
index 0000000000000000000000000000000000000000..93434d16988fe4203004810a1a89fc5fd8f395b9
--- /dev/null
+++ b/eval/evaluate_chat_ceval.py
@@ -0,0 +1,290 @@
+import os
+import pandas as pd
+import numpy as np
+import argparse
+import datasets
+import torch
+import re
+from thefuzz import process
+from typing import List
+from tqdm import tqdm
+from transformers.trainer_utils import set_seed
+
+'''
+wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
+mkdir data/ceval
+mv ceval-exam.zip data/ceval
+cd data/ceval; unzip ceval-exam.zip
+cd ../../
+
+pip install thefuzz
+python eval/evaluate_chat_ceval.py -d data/ceval
+'''
+
+def load_models_tokenizer(args):
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False # use greedy decoding
+ return model, tokenizer
+
+def process_before_extraction(gen, question, choice_dict):
+ # Example Prompt:
+ # 关于传输层的面向连接服务的特性是____。
+ # A. 既不保证可靠,也不保证按序交付
+ # B. 不保证可靠,但保证按序交付
+ # C. 保证可靠,但不保证按序交付
+ # D. 既保证可靠,也保证按序交付
+ # Example Model Output:
+ # 关于传输层的面向连接服务的特性是既保证可靠,也保证按序交付
+ # Processed Output:
+ # 答案是D
+
+ question_split = question.rstrip("。").split("。")[-1].split("_")
+
+ # replacing the question
+ if len(question_split[0].strip()) > 4:
+ gen = gen.replace(question_split[0], "答案是")
+ if len(question_split[-1].strip()) > 4:
+ gen = gen.replace(question_split[-1], "")
+
+ # replace the choice by letter in the generated sentence
+ # from longest one to shortest one
+ for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
+ gen = gen.replace(val.rstrip("。"), key)
+ return gen
+
+def count_substr(gen, pattern):
+ return len(re.findall(pattern, gen))
+
+def extract_choice(gen, prompt, choice_list):
+ # 答案是A | 选项是A | 应该选A选项
+ res = re.search(r"(?:(?:选|选择|选定)|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$)", gen)
+
+ # A选项正确 | A选项符合题意
+ if res is None:
+ res = re.search(r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对|符合))[^ABCD]{0,4}?(?:正确|对|符合)", gen)
+
+ # 直接输出 A
+ if res is None:
+ res = re.search(r"^(A|B|C|D)(?:。|\.|,|,|.|$)", gen)
+
+ # 获取第一个出现的字母
+ if res is None:
+ res = re.search(r"(? 0:
+ print('Hard acc:%.2f ' % (hard_acc_sum / hard_cnt))
+ print('AVERAGE acc:%.2f ' % (acc_sum / cnt))
+
+
+TASK_NAME_MAPPING = {
+ "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
+ "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
+ "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
+ "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
+ "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
+ "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
+ "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
+ "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
+ "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
+ "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
+ "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
+ "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
+ "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
+ "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
+ "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
+ "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
+ "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
+ "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
+ "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
+ "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
+ "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
+ "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
+ "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
+ "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
+ "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
+ "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
+ "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
+ "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
+ "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
+ "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
+ "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
+ "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
+ "law": ["Law", "\u6cd5\u5b66", "Humanities"],
+ "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
+ "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
+ "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
+ "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
+ "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
+ "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
+ "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
+ "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
+ "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
+ "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
+ "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
+ "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
+ "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
+ "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
+ "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
+ "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
+ "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"]
+}
+hard_list = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_physics', 'college_chemistry', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry']
+choices = ["A", "B", "C", "D"]
+
+
+def main(args):
+ print("loading model weights")
+ if args.checkpoint_path:
+ model, tokenizer = load_models_tokenizer(args)
+ else:
+ model, tokenizer = None, None
+ print("model loaded")
+ dev_result = {}
+ for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
+ val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
+ # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
+ # test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
+ val_df = pd.read_csv(val_file_path)
+ # dev_df = pd.read_csv(dev_file_path)
+ # test_df = pd.read_csv(test_file_path)
+
+ score = eval_subject(model, tokenizer, subject_name, val_df,
+ save_result_dir=f"outs_chat/ceval_eval_result", overwrite=args.overwrite)
+ dev_result[subject_name] = score
+ cal_ceval(dev_result)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
+
+ """Provide extra arguments required for tasks."""
+ group = parser.add_argument_group(title='Evaluation options')
+ group.add_argument('-d', '--eval_data_path', type=str, required=True,
+ help='Path to eval data')
+ group.add_argument("--debug", action='store_true', default=False,
+ help='Print infos.')
+ group.add_argument("--overwrite", action='store_true', default=False,
+ help='Overwrite existed results')
+
+ args = parser.parse_args()
+ set_seed(args.seed)
+
+ main(args)
\ No newline at end of file
diff --git a/eval/evaluate_chat_gsm8k.py b/eval/evaluate_chat_gsm8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..41829388ca432c99ae610502c5944bd217ce074b
--- /dev/null
+++ b/eval/evaluate_chat_gsm8k.py
@@ -0,0 +1,137 @@
+import random
+import tqdm
+import os
+import re
+import sys
+import torch
+import numpy as np
+import jsonlines
+import argparse
+import json
+from pathlib import Path
+from datasets import load_from_disk,load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+'''
+python eval/evaluate_chat_gsm8k.py [--use-fewshot]
+'''
+
+INVALID_ANS = "[invalid]"
+DEVICE = "cuda:0"
+
+def doc_to_text(doc, use_fewshot):
+ if use_fewshot:
+ context = "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\n" \
+ "Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n\n" \
+ "Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\n" \
+ "Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n\n" \
+ "Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\n" \
+ "When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n\n" \
+ "Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\n" \
+ "For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n\n" \
+ f"Question: {doc['question']}\nLet's think step by step"
+ else:
+ context = doc['question']
+ return context
+
+def decode(tokens_list, tokenizer, raw_text_len):
+ sents = []
+ # print(len(tokens_list))
+ for tokens in tokens_list:
+ tokens = tokens.cpu().numpy().tolist()
+ sent = tokenizer.tokenizer.decode(
+ tokens[raw_text_len:])
+ sent = sent.split('<|endoftext|>')[0]
+ sent = sent.split('\n\n\n')[0]
+ sent = sent.split("\n\n")[0]
+ sent = sent.split("Question:")[0]
+ sents.append(sent)
+ return sents
+
+def generate_sample(model, tokenizer, question):
+ response, history = model.chat(
+ tokenizer,
+ question,
+ history=None,
+ )
+ print(question)
+ print("-------------")
+ print(response)
+ print("=============")
+ return response
+
+
+def extract_answer_hf(completion):
+ def _get_last_digit(s):
+ _PAT_LAST_DIGIT = re.compile(r"(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))")
+ match = list(_PAT_LAST_DIGIT.finditer(s))
+ if match:
+ last_digit = match[-1].group().replace(",", "").replace("+", "")
+ # print(f"The last digit in {s} is {last_digit}")
+ else:
+ last_digit = None
+ print(f"No digits found in {s!r}")
+ return last_digit
+
+ job_gen = completion.strip('.').replace('\n', '\\n')
+ last_digit = _get_last_digit(job_gen)
+ if last_digit is not None:
+ return eval(last_digit)
+ else:
+ return INVALID_ANS
+
+def extract_answer(completion):
+ try:
+ last_number = re.findall(r'\d+', completion)[-1]
+ return eval(last_number)
+ except:
+ return INVALID_ANS
+
+def is_correct( completion, answer):
+ gold = extract_answer(answer)
+ assert gold != INVALID_ANS, "No ground truth answer found in the document."
+ return extract_answer(completion) == gold
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument("-c", "--checkpoint-path", type=Path, help="Checkpoint path", default="Qwen/Qwen-7B-Chat")
+ parser.add_argument("-f","--sample-input-file", type=str, default=None)
+ parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl")
+ parser.add_argument("--use-fewshot", action="store_true")
+
+ args = parser.parse_args()
+
+ if args.sample_input_file is not None:
+ dataset = load_from_disk(args.sample_input_file)# or:
+ else:
+ dataset = load_dataset("gsm8k", "main")
+
+ print('Loading tokenizer ...')
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True, bf16=True, use_flash_attn=True)
+
+ print('Loading model ...')
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False # use greedy decoding
+
+ test = dataset["test"]
+
+ f_output = open(args.sample_output_file, 'w', encoding='utf-8')
+ tot_length = test.num_rows
+ acc_res = []
+ for doc in tqdm.tqdm(test):
+ context = doc_to_text(doc, args.use_fewshot)
+ print(context)
+ completion = generate_sample(model, tokenizer, context)
+ answer = doc["answer"]
+ acc = is_correct(completion, answer)
+ doc["completion"] = completion
+ doc["acc"] = acc
+ f_output.write(json.dumps(doc, ensure_ascii=False) + "\n")
+ f_output.flush()
+ acc_res.append(acc)
+
+ f_output.close()
+ print("4-shot Acc: " if args.use_fewshot else "Zero-shot Acc", np.mean(acc_res))
diff --git a/eval/evaluate_chat_humaneval.py b/eval/evaluate_chat_humaneval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5bace4eb7e348714c995a13b15fbe9732d4a0c2
--- /dev/null
+++ b/eval/evaluate_chat_humaneval.py
@@ -0,0 +1,82 @@
+import random
+import tqdm
+import os
+import sys
+import torch
+import jsonlines
+import argparse
+import jsonlines
+from pathlib import Path
+import re
+import textwrap
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+"""
+Get the HumanEval.jsonl file from [here](https://github.com/openai/human-eval/tree/master/data)
+
+python eval/evaluate_chat_humaneval.py -f HumanEval.jsonl -o HumanEval_res.jsonl
+git clone https://github.com/openai/human-eval
+pip install -e human-eval
+evaluate_functional_correctness HumanEval_res.jsonl
+"""
+
+DEVICE = "cuda:0"
+
+def extract_code(text, entry_point):
+
+ # 正则表达式匹配代码块
+ code_block_pattern = re.compile(rf"```(?:[Pp]ython\n)?.*?def\s+{entry_point}.*?:\n(.*?)\n```", re.DOTALL)
+ code_block = code_block_pattern.search(text)
+ if code_block is None:
+ code_block_pattern = re.compile(rf"def\s+{entry_point}.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
+ code_block = code_block_pattern.search(text)
+ if code_block is None:
+ code_block_pattern = re.compile(rf"def.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
+ code_block = code_block_pattern.search(text)
+
+ if code_block is not None:
+ return code_block.group(1)
+ else:
+ # if no code block is found, assume the LM is simply filling the code
+ return textwrap.indent(text, ' ' * 4)
+
+def generate_sample(model, tokenizer, question, entry_point):
+ response, history = model.chat(
+ tokenizer,
+ question,
+ history=None,
+ )
+ print(question)
+ print(response)
+ answer = extract_code(response, entry_point)
+ return answer, response
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument("-c", "--checkpoint-path", type=Path, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
+ parser.add_argument("-f","--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
+ parser.add_argument("-o","--sample-output-file", type=str, default="HumanEval_res.jsonl")
+
+
+ args = parser.parse_args()
+ print('Loading tokenizer ...')
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+
+ print('Loading model ...')
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False # use greedy decoding
+
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
+
+ f = jsonlines.open(args.sample_input_file)
+ with f_output as output:
+ for jobj in tqdm.tqdm(f, desc='task_idx'):
+ prompt = "Help me fill the following code.\n" + jobj['prompt']
+ task_id = jobj['task_id']
+ answer, response = generate_sample(model, tokenizer, prompt, jobj['entry_point'])
+ gen_jobjs = {'task_id': task_id, "completion": answer, 'response': response}
+ output.write(gen_jobjs)
+ f_output.close()
diff --git a/eval/evaluate_chat_mmlu.py b/eval/evaluate_chat_mmlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a070228516c966ba64bd754dd45e641179a2bde4
--- /dev/null
+++ b/eval/evaluate_chat_mmlu.py
@@ -0,0 +1,207 @@
+import os
+import pandas as pd
+import numpy as np
+import argparse
+import datasets
+import torch
+import re
+from thefuzz import process
+from typing import List
+from tqdm import tqdm
+from transformers.trainer_utils import set_seed
+
+'''
+wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
+mkdir data/mmlu
+mv data.tar data/mmlu
+cd data/mmlu; tar xf data.tar
+cd ../../
+
+pip install thefuzz
+python eval/evaluate_chat_mmlu.py -d data/mmlu/data/
+'''
+
+def load_models_tokenizer(args):
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False # use greedy decoding
+ return model, tokenizer
+
+
+def format_example(line):
+ example = 'The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n' + line['question'] + "\n"
+ for choice in choices:
+ example += f'{choice}. {line[f"{choice}"]}\n'
+ return example
+
+
+def process_before_extraction(gen, choice_dict):
+ # replace the choice by letter in the generated sentence
+ # from longest one to shortest one
+ for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
+ pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE)
+ gen = pattern.sub(key, gen)
+ return gen
+
+def extract_choice(gen, choice_list):
+ # answer is A | choice is A | choose A
+ res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen)
+
+ # A is correct | A is right
+ if res is None:
+ res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen)
+
+ # straight answer: A
+ if res is None:
+ res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
+
+ # simply extract the first appearred letter
+ if res is None:
+ res = re.search(r"(?')[0]
+ sent = sent.split('\n\n\n')[0]
+ sent = sent.split("\n\n")[0]
+ sent = sent.split("Question:")[0]
+ sents.append(sent)
+ return sents
+
+def generate_sample(model, tokenizer, input_txt):
+ input_ids = tokenizer.tokenizer.encode(input_txt)
+ raw_text_len = len(input_ids)
+ context_enc = torch.tensor(
+ [input_ids]).to(model.device)
+ print(f"Input text: {input_txt}\n")
+ outputs = model.generate(context_enc)
+ output_text = decode(outputs,tokenizer,raw_text_len)[0]
+ print(f"\nOutput text: {output_text}\n")
+ return output_text
+
+
+def extract_answer_hf(completion):
+ match = ANS_RE.search(completion)
+ if match:
+ match_str = match.group(1).strip()
+ match_str = match_str.replace(",", "")
+ return eval(match_str)
+ else:
+ return INVALID_ANS
+
+def extract_answer(completion):
+ try:
+ last_number = re.findall(r'\d+', completion)[-1]
+ return eval(last_number)
+ except:
+ return INVALID_ANS
+
+def is_correct( completion, answer):
+ gold = extract_answer_hf(answer)
+ assert gold != INVALID_ANS, "No ground truth answer found in the document."
+ return extract_answer(completion) == gold
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path", default="Qwen/Qwen-7B")
+ parser.add_argument("-f","--sample-input-file", type=str, default=None)
+ parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl")
+
+ args = parser.parse_args()
+
+ fewshot_prompt = open("gsm8k_prompt.txt").read()
+ if args.sample_input_file is not None:
+ dataset = load_from_disk(args.sample_input_file)
+ else:
+ config = datasets.DownloadConfig(resume_download=True, max_retries=100)
+ dataset = load_dataset("gsm8k", 'main', download_config=config)
+
+ test = dataset["test"]
+
+ print('Loading tokenizer ...')
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+
+ print('Loading model ...')
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False
+
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
+ tot_length = test.num_rows
+ acc_res = []
+ for doc in test:
+ context = doc_to_text(doc)
+ completion = generate_sample(model, tokenizer, context)
+ answer= doc["answer"]
+ acc = is_correct(completion, answer)
+ doc["completion"]=completion
+ doc["acc"]=acc
+ f_output.write(doc)
+ acc_res.append(acc)
+
+ f_output.close()
+ print("Acc: ",np.mean(acc_res))
\ No newline at end of file
diff --git a/eval/evaluate_humaneval.py b/eval/evaluate_humaneval.py
new file mode 100644
index 0000000000000000000000000000000000000000..af78319f3fa7c4756ddb6b7b7deb43a00ec7e690
--- /dev/null
+++ b/eval/evaluate_humaneval.py
@@ -0,0 +1,70 @@
+import random
+import tqdm
+import os
+import sys
+import torch
+import jsonlines
+import argparse
+import jsonlines
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+"""
+git clone https://github.com/openai/human-eval
+$ pip install -e human-eval
+evaluate_functional_correctness sample-output-file
+"""
+
+def decode(tokens_list, tokenizer, raw_text_len):
+ sents = []
+ # print(len(tokens_list))
+ for tokens in tokens_list:
+ tokens = tokens.cpu().numpy().tolist()
+ sent = tokenizer.tokenizer.decode(
+ tokens[raw_text_len:])
+ sent = sent.split('<|endoftext|>')[0]
+ sent = sent.split('\n\n\n')[0]
+ sent = sent.split("\n\n")[0]
+ sent = sent.split("def ")[0]
+ sents.append(sent)
+ return sents
+
+def generate_sample(model, tokenizer, input_txt):
+ input_ids = tokenizer.tokenizer.encode(input_txt)
+ raw_text_len = len(input_ids)
+ context_enc = torch.tensor([input_ids] ).to(model.device)
+ print(f"Input text: {input_txt}\n")
+ outputs = model.generate(context_enc)
+ output_text = decode(outputs,tokenizer,raw_text_len)[0]
+ print(f"\nOutput text: \n{output_text}\n")
+ return output_text
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument("-c", "--checkpoint-path", type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
+ parser.add_argument("-f","--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
+ parser.add_argument("-o","--sample-output-file", type=str, default="HumanEval_res.jsonl")
+
+
+ args = parser.parse_args()
+ print('Loading tokenizer ...')
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+
+ print('Loading model ...')
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False
+
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
+
+ f = jsonlines.open(args.sample_input_file)
+ with f_output as output:
+ for jobj in tqdm.tqdm(f, desc='task_idx'):
+ prompt = jobj['prompt']
+ task_id = jobj['task_id']
+ gen_sents = generate_sample(model, tokenizer, prompt)
+ gen_jobjs = {'task_id': task_id, "completion": gen_sents}
+ output.write(gen_jobjs)
+ f_output.close()
\ No newline at end of file
diff --git a/eval/evaluate_mmlu.py b/eval/evaluate_mmlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b6970c6b74d6297b6d504862fb15d735298097f
--- /dev/null
+++ b/eval/evaluate_mmlu.py
@@ -0,0 +1,218 @@
+import os
+import pandas as pd
+import numpy as np
+import argparse
+import datasets
+import torch
+
+from typing import List
+from tqdm import tqdm
+from transformers.trainer_utils import set_seed
+
+
+'''
+wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
+mkdir data/mmlu
+mv data.tar data/mmlu
+cd data/mmlu; tar xf data.tar
+cd ../../
+python eval/evaluate_mmlu.py -d data/mmlu/data/
+'''
+
+
+def load_models_tokenizer(args):
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
+ return model, tokenizer
+
+
+def format_example(line, include_answer=True):
+ example = 'Question: ' + line['question']
+ for choice in choices:
+ example += f'\n{choice}. {line[f"{choice}"]}'
+
+ if include_answer:
+ example += '\nAnswer: ' + line["answer"] + '\n\n'
+ else:
+ example += '\nAnswer:'
+ return example
+
+
+def generate_few_shot_prompt(k, subject, dev_df):
+
+ def format_subject(subject):
+ l = subject.split("_")
+ s = ""
+ for entry in l:
+ s += " " + entry
+ return s.strip()
+
+ prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
+
+ if k == -1:
+ k = dev_df.shape[0]
+ for i in range(k):
+ prompt += format_example(
+ dev_df.iloc[i, :],
+ include_answer=True,
+ )
+ return prompt
+
+
+def get_logits(tokenizer, model, inputs: List[str]):
+ input_ids = tokenizer(inputs, padding=False)['input_ids']
+ input_ids = torch.tensor(input_ids, device=model.device)
+
+ if input_ids.shape[1] > args.max_seq_len:
+ input_ids = input_ids[:, input_ids.shape[1]-args.max_seq_len+1:]
+ tokens = {'input_ids': input_ids}
+
+ outputs = model(input_ids)['logits']
+ logits = outputs[:, -1, :]
+ log_probs = torch.nn.functional.softmax(logits, dim=-1)
+ return log_probs, {'tokens': tokens}
+
+
+@torch.no_grad()
+def eval_subject(
+ model,
+ tokenizer,
+ subject_name,
+ test_df,
+ k=5,
+ dev_df=None,
+ few_shot=False,
+ save_result_dir=None,
+ **kwargs
+):
+ result = []
+ score = []
+
+ few_shot_prompt = generate_few_shot_prompt(
+ k, subject_name, dev_df) if few_shot else []
+ all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
+ if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
+
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
+ question = format_example(row, include_answer=False)
+ full_prompt = few_shot_prompt + question
+
+ output, input_info = get_logits(tokenizer, model, [full_prompt])
+ assert output.shape[0] == 1
+ logits = output.flatten()
+
+ softval = torch.nn.functional.softmax(
+ torch.tensor(
+ [
+ logits[tokenizer(" A")['input_ids']],
+ logits[tokenizer(" B")['input_ids']],
+ logits[tokenizer(" C")['input_ids']],
+ logits[tokenizer(" D")['input_ids']],
+ ]
+ ),
+ dim=0,
+ )
+ if softval.dtype in {torch.bfloat16, torch.float16}:
+ softval = softval.to(dtype=torch.float32)
+ probs = softval.detach().cpu().numpy()
+
+ for i, choice in enumerate(choices):
+ all_probs[f'prob_{choice}'].append(probs[i])
+ pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
+
+ if 'answer' in row:
+ correct = 1 if pred == row['answer'] else 0
+ score.append(correct)
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
+ result.append(pred)
+
+ if save_result_dir:
+ test_df['model_output'] = result
+ for i, choice in enumerate(choices):
+ test_df[f'prob_{choice}'] = (all_probs[f'prob_{choice}'])
+ if score:
+ test_df["correctness"] = score
+ os.makedirs(save_result_dir, exist_ok=True)
+ test_df.to_csv(os.path.join(
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
+
+ return score
+
+
+def cal_mmlu(res):
+ acc_sum_dict = dict()
+ acc_norm_sum_dict = dict()
+ cnt_dict = dict()
+ acc_sum = 0.
+ cnt = 0
+ hard_cnt = 0
+ hard_acc_sum = 0.
+
+ for class_ in TASK_NAME_MAPPING.keys():
+ acc_sum_dict[class_] = 0.
+ acc_norm_sum_dict[class_] = 0.
+ cnt_dict[class_] = 0.
+
+ for tt in TASK_NAME_MAPPING[class_]:
+ acc_sum += sum(res[tt])
+ cnt += len(res[tt])
+
+ acc_sum_dict[class_] += sum(res[tt])
+ cnt_dict[class_] += len(res[tt])
+
+ print('\n\n\n', 'total cnt:', cnt, '\n')
+ for k in TASK_NAME_MAPPING.keys():
+ if k in cnt_dict:
+ print('%s ACC: %.2f ' % (
+ k, acc_sum_dict[k] / cnt_dict[k] * 100))
+ print('AVERAGE ACC:%.2f ' % (acc_sum / cnt * 100))
+
+
+def main(args):
+ model, tokenizer = load_models_tokenizer(args)
+
+ dev_result = {}
+ for subject_name in tqdm(SUBJECTS):
+ # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
+ dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
+ test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
+ # val_df = pd.read_csv(val_file_path, names=['question','A','B','C','D','answer'])
+ dev_df = pd.read_csv(dev_file_path, names=['question','A','B','C','D','answer'])
+ test_df = pd.read_csv(test_file_path, names=['question','A','B','C','D','answer'])
+
+ score = eval_subject(model, tokenizer, subject_name, test_df, dev_df=dev_df, k=5, few_shot=True,
+ save_result_dir=f"outs/mmlu_eval_result")
+ dev_result[subject_name] = score
+ cal_mmlu(dev_result)
+
+
+TASK_NAME_MAPPING = {'stem': ['abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics', 'high_school_statistics', 'machine_learning'],
+ 'Humanities': ['formal_logic', 'high_school_european_history', 'high_school_us_history', 'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law', 'world_religions'],
+ 'other': ['business_ethics', 'college_medicine', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology', 'global_facts', 'clinical_knowledge'],
+ 'social': ['econometrics', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology', 'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy']}
+SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
+choices = ["A", "B", "C", "D"]
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
+ parser.add_argument('--gpu', type=int, default=0, help='gpu id')
+
+ """Provide extra arguments required for tasks."""
+ group = parser.add_argument_group(title='Evaluation options')
+ group.add_argument('-d', '--eval_data_path', type=str,
+ help='Path to eval data')
+ group.add_argument("--max-seq-len", type=int, default=2048,
+ help='Size of the output generated text.')
+ group.add_argument("--debug", action='store_true', default=False,
+ help='Print infos.')
+
+ args = parser.parse_args()
+ set_seed(args.seed)
+
+ main(args)
\ No newline at end of file
diff --git a/eval/evaluate_plugin.py b/eval/evaluate_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..89974ad7b644d24488f5cae22c09b7bd104f5c6d
--- /dev/null
+++ b/eval/evaluate_plugin.py
@@ -0,0 +1,308 @@
+import argparse
+import json
+import os
+import pprint
+
+import json5
+import jsonlines
+from rouge_score import rouge_scorer
+from tqdm import tqdm
+from transformers import Agent, AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+from transformers.tools.evaluate_agent import evaluate_agent
+from transformers.trainer_utils import set_seed
+
+data_root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+ 'data')
+
+
+def is_callable(response, golden):
+ return response['action'].strip().lower() == golden['action'].strip(
+ ).lower()
+
+
+def process_res(response):
+ # parse response
+ response += '\n' # fix not-find bug
+ thought = response[:response.find('Action:')].strip()
+ action = response[response.find('Action:') +
+ len('Action:'):response.find('Action Input:')].strip()
+ action_input = response[response.find('Action Input:') +
+ len('Action Input:'):response.find('Observation:'
+ )].strip()
+ #TODO: This parsing result is incorrect if the response contains multiple Actions. To be fixed in the future.
+ observation = response[response.find('Observation:') +
+ len('Observation:'):response.rfind('Thought:'
+ )].strip()
+ thought_last = response[response.rfind('Thought:') +
+ len('Thought:'):response.find('Final Answer:'
+ )].strip()
+ final_answer = response[response.find('Final Answer:') +
+ len('Final Answer:'):].strip()
+ try:
+ action_input = json.dumps(json5.loads(action_input),
+ ensure_ascii=False,
+ sort_keys=True)
+ except:
+ # print("JSON Load Error:", action_input)
+ pass
+ res_dict = {
+ 'thought': thought,
+ 'action': action,
+ 'action_input': action_input,
+ 'observation': observation,
+ 'thought_last': thought_last,
+ 'final_answer': final_answer
+ }
+ return res_dict
+
+
+class _DummyTokenizer:
+ def tokenize(self, text: str):
+ return text.split()
+
+
+def _get_tokenized_string(tokenizer, text_list):
+ token_ids_list, tokenized_string_list = [], []
+ for text in text_list:
+ assert tokenizer is not None
+ token_ids = tokenizer.encode(text)
+ tokens_bytes = tokenizer.convert_ids_to_tokens(token_ids)
+ tokens = [
+ token.decode('utf-8', errors='replace') for token in tokens_bytes
+ ]
+ tokenized_string = ' '.join(tokens)
+ token_ids_list.append(token_ids)
+ tokenized_string_list.append(tokenized_string)
+ return token_ids_list, tokenized_string_list
+
+
+def eval_action(job):
+ response = job['gen'][0]
+ golden = job['response']
+
+ if 'Action:' in response:
+ response, golden = process_res(response), process_res(golden)
+ if is_callable(response, golden):
+ return True
+ return False
+
+
+def eval_action_input(job, tokenizer):
+ response = job['gen'][0]
+ golden = job['response']
+ response, golden = process_res(response), process_res(golden)
+ query = job['prompt']
+
+ job = {}
+ job['prompt'] = query
+ job['gen'] = response['action_input']
+ job['response'] = golden['action_input']
+
+ job['_gen_tok'], job['_gen_tok_str'] = _get_tokenized_string(
+ tokenizer, [response['action_input']])
+ job['_reference_tok'], job['_reference_tok_str'] = _get_tokenized_string(
+ tokenizer, [golden['action_input']])
+
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
+ tokenizer=_DummyTokenizer())
+ score = scorer.score(job['_reference_tok_str'][0], job['_gen_tok_str'][0])
+
+ rouge = score['rougeL'].fmeasure
+
+ return rouge
+
+
+class QWenAgent(Agent):
+ """
+ Agent that uses QWen model and tokenizer to generate code.
+
+ Example:
+
+ ```py
+ agent = QWenAgent()
+ agent.run("Draw me a picture of rivers and lakes.")
+ ```
+ """
+ def __init__(self,
+ chat_prompt_template=None,
+ run_prompt_template=None,
+ additional_tools=None,
+ tokenizer=None,
+ model=None):
+ if tokenizer and model:
+ self.tokenizer = tokenizer
+ self.model = model
+ else:
+ checkpoint = 'Qwen/Qwen-7B-Chat'
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ checkpoint, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ checkpoint, device_map='auto',
+ trust_remote_code=True).cuda().eval()
+ self.model.generation_config = GenerationConfig.from_pretrained(
+ checkpoint, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参
+ self.model.generation_config.do_sample = False # greedy
+
+ super().__init__(
+ chat_prompt_template=chat_prompt_template,
+ run_prompt_template=run_prompt_template,
+ additional_tools=additional_tools,
+ )
+
+ def generate_one(self, prompt, stop):
+ # "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。
+ prompt = prompt.replace('Human:',
+ '_HUMAN_:').replace('Assistant:',
+ '_ASSISTANT_:')
+ stop = [
+ item.replace('Human:', '_HUMAN_:').replace('Assistant:',
+ '_ASSISTANT_:')
+ for item in stop
+ ]
+
+ result, _ = self.model.chat(self.tokenizer, prompt, history=None)
+ for stop_seq in stop:
+ if result.endswith(stop_seq):
+ result = result[:-len(stop_seq)]
+
+ result = result.replace('_HUMAN_:',
+ 'Human:').replace('_ASSISTANT_:', 'Assistant:')
+ return result
+
+
+def load_models_tokenizer(args):
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path,
+ trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path,
+ device_map='auto',
+ trust_remote_code=True,
+ bf16=True,
+ use_flash_attn=True).eval()
+ model.generation_config = GenerationConfig.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True)
+ model.generation_config.do_sample = False # use greedy decoding
+ return model, tokenizer
+
+
+def load_jobs(filename):
+ jobs = []
+ with jsonlines.open(os.path.join(data_root_path, filename),
+ mode='r') as reader:
+ for job in reader:
+ jobs.append(job)
+ return jobs
+
+
+def react_inference(filename, model, tokenizer):
+ filename_cache = filename + '.cache'
+ if os.path.exists(os.path.join(data_root_path, filename_cache)):
+ jobs = load_jobs(filename=filename_cache)
+ print('Loaded from', filename_cache)
+ else:
+ with open(os.path.join(data_root_path, filename_cache), 'w') as f:
+ jobs = load_jobs(filename=filename)
+ print('Inference:', filename)
+ for job in tqdm(jobs):
+ response, history = model.chat(tokenizer,
+ job['prompt'],
+ history=None)
+ job['gen'] = [response]
+ f.writelines(json.dumps(job, ensure_ascii=False) + '\n')
+ print(filename_cache, 'is saved.')
+ return jobs
+
+
+def main(args):
+ print('loading model weights')
+ if args.checkpoint_path is not None:
+ model, tokenizer = load_models_tokenizer(args)
+ else:
+ model, tokenizer = None, None
+ print('model loaded')
+
+ result = {}
+ # eval react positive
+ if args.eval_react_positive:
+ print('eval react positive ...')
+ acc_count = 0
+ rouge_mean = 0
+ jobs = react_inference(filename=args.eval_react_positive_filename,
+ model=model,
+ tokenizer=tokenizer)
+ for job in jobs:
+ if eval_action(job):
+ acc_count += 1
+ rouge = eval_action_input(job, tokenizer)
+ rouge_mean += (rouge / len(jobs))
+
+ scores = {
+ 'action_right_rate': acc_count / len(jobs),
+ 'action_input_rouge': rouge_mean,
+ }
+
+ result.update({'react_positive': scores})
+
+ # eval react negative
+ if args.eval_react_negative:
+ print('eval react negative ...')
+ bad_count = 0
+ jobs = react_inference(filename=args.eval_react_negative_filename,
+ model=model,
+ tokenizer=tokenizer)
+ for job in jobs:
+ if '\nAction:' in job['gen'][0]:
+ bad_count += 1
+ scores = {'bad_rate': bad_count / len(jobs)}
+ result.update({'react_negative': scores})
+
+ # eval hfagent
+ if args.eval_hfagent:
+ print('eval hfagent ...')
+ agent = QWenAgent(model=model, tokenizer=tokenizer)
+ scores = evaluate_agent(agent, verbose=False, return_errors=False)
+ result.update({'hfagent': scores})
+
+ pp = pprint.PrettyPrinter(indent=4)
+ pp.pprint(result)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
+ parser.add_argument('-c',
+ '--checkpoint-path',
+ type=str,
+ help='Checkpoint path',
+ default='Qwen/Qwen-7B-Chat')
+ parser.add_argument('-s',
+ '--seed',
+ type=int,
+ default=1234,
+ help='Random seed')
+ """Provide extra arguments required for tasks."""
+ group = parser.add_argument_group(title='Evaluation options')
+ group.add_argument('--eval-react-positive',
+ action='store_true',
+ default=False,
+ help='Eval react positive.')
+ group.add_argument('--eval-react-positive-filename',
+ type=str,
+ default='exam_plugin_v1_react_positive.jsonl',
+ help='Eval react positive filename.')
+ group.add_argument('--eval-react-negative',
+ action='store_true',
+ default=False,
+ help='Eval react negative.')
+ group.add_argument('--eval-react-negative-filename',
+ type=str,
+ default='exam_plugin_v1_react_negative.jsonl',
+ help='Eval react negative filename.')
+ group.add_argument('--eval-hfagent',
+ action='store_true',
+ default=False,
+ help='Eval hfagent.')
+
+ args = parser.parse_args()
+ set_seed(args.seed)
+
+ main(args)
diff --git a/eval/gsm8k_prompt.txt b/eval/gsm8k_prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eea39e11963f68b7cef30aad1dacbb4c1df017af
--- /dev/null
+++ b/eval/gsm8k_prompt.txt
@@ -0,0 +1,59 @@
+Question: In 2004, there were 60 kids at a cookout. In 2005, half the number of kids came to the cookout as compared to 2004. In 2006, 2/3 as many kids came to the cookout as in 2005. How many kids came to the cookout in 2006?
+Let's think step by step
+In 2005, 60/2=30 kids came to the cookout.
+In 2006, 30/3*2=20 kids came to the cookout.
+The answer is 20
+
+Question: Zilla spent 7% of her monthly earnings on rent, half of it on her other monthly expenses, and put the rest in her savings. If she spent $133 on her rent, how much does she deposit into her savings account in a month?
+Let's think step by step
+Since $133 is equal to 7% of her earnings, then 1% is equal to $133/7 = $19.
+The total monthly earning of Zilla is represented by 100%, so $19 x 100 = $1900 is her monthly earnings.
+So, $1900/2 = $950 is spent on her other monthly expenses.
+The total amount spent on the rent and other monthly expenses is $133 + $950 = $1083.
+Hence, she saves $1900 - $1083 = $817 per month.
+The answer is 817
+
+Question: If Buzz bought a pizza with 78 slices at a restaurant and then decided to share it with the waiter in the ratio of 5:8, with Buzz's ratio being 5, what's twenty less the number of slices of pizza that the waiter ate?
+Let's think step by step
+The total ratio representing the slices of pizza that Buzz bought is 5+8=13
+If he shared the slices of pizza with the waiter, the waiter received a fraction of 8/13 of the total number of slices, which totals 8/13 * 78 = 48 slices
+Twenty less the number of slices of pizza that the waiter ate is 48-20 = 28
+The answer is 28
+
+Question: Jame gets a raise to $20 per hour and works 40 hours a week. His old job was $16 an hour for 25 hours per week. How much more money does he make per year in his new job than the old job if he works 52 weeks a year?
+Let's think step by step
+He makes 20*40=$800 per week
+He used to make 16*25=$400 per week
+So his raise was 800-400=$400 per week
+So he makes 400*52=$20,800 per year more
+The answer is 20800
+
+Question: Mr. Gardner bakes 20 cookies, 25 cupcakes, and 35 brownies for his second-grade class of 20 students. If he wants to give each student an equal amount of sweet treats, how many sweet treats will each student receive?
+Let's think step by step
+Mr. Gardner bakes a total of 20 + 25 + 35 = 80 sweet treats
+Each student will receive 80 / 20 = 4 sweet treats
+The answer is 4
+
+Question: A used car lot has 24 cars and motorcycles (in total) for sale. A third of the vehicles are motorcycles, and a quarter of the cars have a spare tire included. How many tires are on the used car lot’s vehicles in all?
+Let's think step by step
+The used car lot has 24 / 3 = 8 motorcycles with 2 tires each.
+The lot has 24 - 8 = 16 cars for sale
+There are 16 / 4 = 4 cars with a spare tire with 5 tires each.
+The lot has 16 - 4 = 12 cars with 4 tires each.
+Thus, the used car lot’s vehicles have 8 * 2 + 4 * 5 + 12 * 4 = 16 + 20 + 48 = 84 tires in all.
+The answer is 84
+
+Question: Norma takes her clothes to the laundry. She leaves 9 T-shirts and twice as many sweaters as T-shirts in the washer. When she returns she finds 3 sweaters and triple the number of T-shirts. How many items are missing?
+Let's think step by step
+Norma left 9 T-shirts And twice as many sweaters, she took 9 * 2= 18 sweaters
+Adding the T-shirts and sweaters, Norma left 9 + 18 = 27 clothes
+When she came back, she found 3 sweaters And triple the number of T-shirts, she found 3 * 3 = 9 T-shirts
+Adding the T-shirts and sweaters, Norma found 3 + 9 = 12 clothes
+Subtracting the clothes she left from the clothes she found, 27 - 12 = 15 clothes are missing
+The answer is 15
+
+Question: Adam has an orchard. Every day for 30 days he picks 4 apples from his orchard. After a month, Adam has collected all the remaining apples, which were 230. How many apples in total has Adam collected from his orchard?
+Let's think step by step
+During 30 days Adam picked 4 * 30 = 120 apples.
+So in total with all the remaining apples, he picked 120 + 230 = 350 apples from his orchard.
+The answer is 350
diff --git a/examples/langchain_tooluse.ipynb b/examples/langchain_tooluse.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..de91bceb09b5910812b3ffcbb1e6c30920bf50c7
--- /dev/null
+++ b/examples/langchain_tooluse.ipynb
@@ -0,0 +1,708 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "30e24ef3",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# 如何让 Qwen-7b 使用 Langchain 中的 工具\n",
+ "\n",
+ "本文档主要介绍如何让千问调用 [LangChain](https://python.langchain.com/docs/get_started/introduction.html) 框架中实现好的谷歌搜索、 WolframAlpha 等工具。将主要基于 [ReAct Prompting](https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_prompt.md) 技术,一种特殊的链式思考(Chain-of-Thought,简称 CoT)提示技巧,来实现这一目的。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "212979ec",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## 安装依赖"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e21c6728",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 安装千问的依赖\n",
+ "!cd Qwen-7b\n",
+ "!pip install -r requirements.txt\n",
+ "\n",
+ "# 安装 langchain 相关依赖\n",
+ "!pip install langchain google-search-results wolframalpha arxiv;"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3b5e6ef9",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## 第零步 - 导入 LangChain 的工具"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "af7d0058",
+ "metadata": {},
+ "source": [
+ "以下引入几个常用 APIs 作为演示:\n",
+ " - [谷歌搜索API](https://serper.dev/?gclid=EAIaIQobChMIj9eqof7OgAMV44VbCh1F3QZoEAAYASABEgIh3fD_BwE#google-search-api)\n",
+ " - [WolframAlpha](https://products.wolframalpha.com/api/)\n",
+ " - arxiv论文搜索\n",
+ " - python shell (需升级python至3.9以上使用)\n",
+ "\n",
+ "注1:此处推荐模仿此案例,细致地构造给千问看的工具描述。\n",
+ "\n",
+ "注2:谷歌搜索(SERPAPI), WolframAlpha 需自行申请它们的 API_KEY 后才能使用。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "07e49b98-9d6c-41f2-9b18-f043f2d13e1a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain import SerpAPIWrapper\n",
+ "from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper\n",
+ "from langchain.utilities import ArxivAPIWrapper\n",
+ "from langchain.tools.python.tool import PythonAstREPLTool\n",
+ "\n",
+ "from typing import Dict, Tuple\n",
+ "import os\n",
+ "import json\n",
+ "\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from transformers.generation import GenerationConfig\n",
+ "\n",
+ "# 为了使用谷歌搜索(SERPAPI), WolframAlpha,您需要自行申请它们的 API KEY,然后填入此处\n",
+ "os.environ['SERPAPI_API_KEY'] = '重要!请在这里填入您的 SERPAPI_API_KEY!'\n",
+ "os.environ['WOLFRAM_ALPHA_APPID'] = '重要!请在这里填入您的 WOLFRAM_ALPHA_APPID!'\n",
+ "\n",
+ "search = SerpAPIWrapper()\n",
+ "WolframAlpha = WolframAlphaAPIWrapper()\n",
+ "arxiv = ArxivAPIWrapper()\n",
+ "python=PythonAstREPLTool()\n",
+ "\n",
+ "def tool_wrapper_for_qwen(tool):\n",
+ " def tool_(query):\n",
+ " query = json.loads(query)[\"query\"]\n",
+ " return tool.run(query)\n",
+ " return tool_\n",
+ "\n",
+ "# 以下是给千问看的工具描述:\n",
+ "TOOLS = [\n",
+ " {\n",
+ " 'name_for_human':\n",
+ " 'google search',\n",
+ " 'name_for_model':\n",
+ " 'Search',\n",
+ " 'description_for_model':\n",
+ " 'useful for when you need to answer questions about current events.',\n",
+ " 'parameters': [{\n",
+ " \"name\": \"query\",\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"search query of google\",\n",
+ " 'required': True\n",
+ " }], \n",
+ " 'tool_api': tool_wrapper_for_qwen(search)\n",
+ " },\n",
+ " {\n",
+ " 'name_for_human':\n",
+ " 'Wolfram Alpha',\n",
+ " 'name_for_model':\n",
+ " 'Math',\n",
+ " 'description_for_model':\n",
+ " 'Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life.',\n",
+ " 'parameters': [{\n",
+ " \"name\": \"query\",\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"the problem to solved by Wolfram Alpha\",\n",
+ " 'required': True\n",
+ " }], \n",
+ " 'tool_api': tool_wrapper_for_qwen(WolframAlpha)\n",
+ " }, \n",
+ " {\n",
+ " 'name_for_human':\n",
+ " 'arxiv',\n",
+ " 'name_for_model':\n",
+ " 'Arxiv',\n",
+ " 'description_for_model':\n",
+ " 'A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org.',\n",
+ " 'parameters': [{\n",
+ " \"name\": \"query\",\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"the document id of arxiv to search\",\n",
+ " 'required': True\n",
+ " }], \n",
+ " 'tool_api': tool_wrapper_for_qwen(arxiv)\n",
+ " },\n",
+ " {\n",
+ " 'name_for_human':\n",
+ " 'python',\n",
+ " 'name_for_model':\n",
+ " 'python',\n",
+ " 'description_for_model':\n",
+ " \"A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. \"\n",
+ " \"Don't add comments to your python code.\",\n",
+ " 'parameters': [{\n",
+ " \"name\": \"query\",\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"a valid python command.\",\n",
+ " 'required': True\n",
+ " }],\n",
+ " 'tool_api': tool_wrapper_for_qwen(python)\n",
+ " }\n",
+ "\n",
+ "]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b7ec2027",
+ "metadata": {},
+ "source": [
+ "## 第一步:让千问判断调用什么工具,生成工具入参"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7a50d676",
+ "metadata": {},
+ "source": [
+ "根据prompt模版、query、工具的信息构建prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "4a8feb0e-22f7-4184-9ea0-b864812c9b09",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 加拿大2023年人口统计数字是多少?\n"
+ ]
+ }
+ ],
+ "source": [
+ "TOOL_DESC = \"\"\"{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object.\"\"\"\n",
+ "\n",
+ "REACT_PROMPT = \"\"\"Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "{tool_descs}\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [{tool_names}]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: {query}\"\"\"\n",
+ "\n",
+ "def build_planning_prompt(TOOLS, query):\n",
+ " tool_descs = []\n",
+ " tool_names = []\n",
+ " for info in TOOLS:\n",
+ " tool_descs.append(\n",
+ " TOOL_DESC.format(\n",
+ " name_for_model=info['name_for_model'],\n",
+ " name_for_human=info['name_for_human'],\n",
+ " description_for_model=info['description_for_model'],\n",
+ " parameters=json.dumps(\n",
+ " info['parameters'], ensure_ascii=False),\n",
+ " )\n",
+ " )\n",
+ " tool_names.append(info['name_for_model'])\n",
+ " tool_descs = '\\n\\n'.join(tool_descs)\n",
+ " tool_names = ','.join(tool_names)\n",
+ "\n",
+ " prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query)\n",
+ " return prompt\n",
+ "\n",
+ "prompt_1 = build_planning_prompt(TOOLS[0:1], query=\"加拿大2023年人口统计数字是多少?\")\n",
+ "print(prompt_1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f22b002",
+ "metadata": {},
+ "source": [
+ "将prompt作为输入获得response"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f71b2577-118c-4ce2-a0ed-a45ec59ea35b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
+ "- tokenization_qwen.py\n",
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
+ "- configuration_qwen.py\n",
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
+ "- qwen_generation_utils.py\n",
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
+ "- modeling_qwen.py\n",
+ "- qwen_generation_utils.py\n",
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "23435445dded44d6951aa6a7b771a963",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading shards: 0%| | 0/8 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".\n",
+ "Try importing flash-attention for faster inference...\n",
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n",
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "728a1c13c2884291ade4cb4a1edfaaf2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/8 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# 国内连 hugginface 网络不好,这段代码可能需要多重试\n",
+ "checkpoint = \"Qwen/Qwen-7B-Chat\"\n",
+ "TOKENIZER = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)\n",
+ "MODEL = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=\"auto\", trust_remote_code=True).eval()\n",
+ "MODEL.generation_config = GenerationConfig.from_pretrained(checkpoint, trust_remote_code=True)\n",
+ "MODEL.generation_config.do_sample = False # greedy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "dc0dbd6c-5a0f-44c9-a019-0ec0283ca92d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Thought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
+ "Action: Search\n",
+ "Action Input: {\"query\": \"加拿大 2023年人口统计数字\"}\n",
+ "Observation:\n"
+ ]
+ }
+ ],
+ "source": [
+ "stop = [\"Observation:\", \"Observation:\\n\"]\n",
+ "react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
+ "response_1, _ = MODEL.chat(TOKENIZER, prompt_1, history=None, stop_words_ids=react_stop_words_tokens)\n",
+ "print(response_1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ebf47ac",
+ "metadata": {},
+ "source": [
+ "## 第二步:从千问的输出中解析需要使用的工具和入参,并调用对应工具"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "1a431670-a1f6-4afd-972f-1cfd6d06e8c9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "根据加拿大统计局预测,加拿大人口今天(2023年6月16日)预计将超过4000万。 联邦统计局使用模型来实时估计加拿大的人口,该计数模型预计加拿大人口将在北美东部时间今天下午3点前达到4000万。 加拿大的人口增长率目前为2.7%。\n"
+ ]
+ }
+ ],
+ "source": [
+ "def parse_latest_plugin_call(text: str) -> Tuple[str, str]:\n",
+ " i = text.rfind('\\nAction:')\n",
+ " j = text.rfind('\\nAction Input:')\n",
+ " k = text.rfind('\\nObservation:')\n",
+ " if 0 <= i < j: # If the text has `Action` and `Action input`,\n",
+ " if k < j: # but does not contain `Observation`,\n",
+ " # then it is likely that `Observation` is ommited by the LLM,\n",
+ " # because the output text may have discarded the stop word.\n",
+ " text = text.rstrip() + '\\nObservation:' # Add it back.\n",
+ " k = text.rfind('\\nObservation:')\n",
+ " if 0 <= i < j < k:\n",
+ " plugin_name = text[i + len('\\nAction:'):j].strip()\n",
+ " plugin_args = text[j + len('\\nAction Input:'):k].strip()\n",
+ " return plugin_name, plugin_args\n",
+ " return '', ''\n",
+ "\n",
+ "def use_api(tools, response):\n",
+ " use_toolname, action_input = parse_latest_plugin_call(response)\n",
+ " if use_toolname == \"\":\n",
+ " return \"no tool founds\"\n",
+ "\n",
+ " used_tool_meta = list(filter(lambda x: x[\"name_for_model\"] == use_toolname, tools))\n",
+ " if len(used_tool_meta) == 0:\n",
+ " return \"no tool founds\"\n",
+ " \n",
+ " api_output = used_tool_meta[0][\"tool_api\"](action_input)\n",
+ " return api_output\n",
+ "\n",
+ "api_output = use_api(TOOLS, response_1)\n",
+ "print(api_output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "106a4ba0",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## 第三步:让千问根据工具返回结果继续作答\n",
+ "拼接上述返回答案,形成新的prompt,并获得生成最终结果"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "a9d4d42d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 加拿大2023年人口统计数字是多少?Thought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
+ "Action: Search\n",
+ "Action Input: {\"query\": \"加拿大 2023年人口统计数字\"}\n",
+ "Observation: 根据加拿大统计局预测,加拿大人口今天(2023年6月16日)预计将超过4000万。 联邦统计局使用模型来实时估计加拿大的人口,该计数模型预计加拿大人口将在北美东部时间今天下午3点前达到4000万。 加拿大的人口增长率目前为2.7%。 Thought: I now know the final answer.\n",
+ "Final Answer: 加拿大2023年人口统计数字预计为4000万。\n"
+ ]
+ }
+ ],
+ "source": [
+ "prompt_2 = prompt_1 + response_1 + ' ' + api_output\n",
+ "stop = [\"Observation:\", \"Observation:\\n\"]\n",
+ "react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
+ "response_2, _ = MODEL.chat(TOKENIZER, prompt_2, history=None, stop_words_ids=react_stop_words_tokens)\n",
+ "print(prompt_2, response_2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0b8da9fd",
+ "metadata": {},
+ "source": [
+ "## 总结 - 串联起整个流程"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "1e51a8ea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def main(query, choose_tools):\n",
+ " prompt = build_planning_prompt(choose_tools, query) # 组织prompt\n",
+ " print(prompt)\n",
+ " stop = [\"Observation:\", \"Observation:\\n\"]\n",
+ " react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
+ " response, _ = MODEL.chat(TOKENIZER, prompt, history=None, stop_words_ids=react_stop_words_tokens)\n",
+ "\n",
+ " while \"Final Answer:\" not in response: # 出现final Answer时结束\n",
+ " api_output = use_api(choose_tools, response) # 抽取入参并执行api\n",
+ " api_output = str(api_output) # 部分api工具返回结果非字符串格式需进行转化后输出\n",
+ " if \"no tool founds\" == api_output:\n",
+ " break\n",
+ " print(\"\\033[32m\" + response + \"\\033[0m\" + \"\\033[34m\" + ' ' + api_output + \"\\033[0m\")\n",
+ " prompt = prompt + response + ' ' + api_output # 合并api输出\n",
+ " response, _ = MODEL.chat(TOKENIZER, prompt, history=None, stop_words_ids=react_stop_words_tokens) # 继续生成\n",
+ "\n",
+ " print(\"\\033[32m\" + response + \"\\033[0m\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "6dc38a34",
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==========\n",
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 加拿大2022年的人口数量有多少?\n",
+ "\u001B[32mThought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
+ "Action: Search\n",
+ "Action Input: {\"query\": \"加拿大 2022年人口数量\"}\n",
+ "Observation:\u001B[0m\u001B[34m 中新社多伦多3月22日电(记者余瑞冬)加拿大统计局3月22日公布的人口统计数据显示,截至今年1月1日,该国估算总人口约为3956.62万人,且2022年的人口增长数创纪录地突破100万人。 加统计局估算,该国人口在2022年增长105.011万人,年增长2.7%,创1957年以来最大增幅。\u001B[0m\n",
+ "\u001B[32mThought: I now know the final answer.\n",
+ "Final Answer: 加拿大2022年的人口数量约为3956.62万人。\u001B[0m\n",
+ "==========\n",
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 求解方程 2x+5 = -3x + 7\n",
+ "\u001B[32mThought: 我应该使用数学工具帮助我完成任务。Wolfram Alpha API应该能完成这项任务。\n",
+ "Action: Math\n",
+ "Action Input: {\"query\": \"2x+5 = -3x + 7\"}\n",
+ "Observation:\u001B[0m\u001B[34m Assumption: 2 x + 5 = -3 x + 7 \n",
+ "Answer: x = 2/5\u001B[0m\n",
+ "\u001B[32mThought: I now know the final answer.\n",
+ "Final Answer: x = 2/5\u001B[0m\n",
+ "==========\n",
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 编号是1605.08386的论文讲了些什么?\n",
+ "\u001B[32mThought: 我需要使用Arxiv API来搜索这篇论文。\n",
+ "Action: Arxiv\n",
+ "Action Input: {\"query\": \"1605.08386\"}\n",
+ "Observation:\u001B[0m\u001B[34m Published: 2016-05-26\n",
+ "Title: Heat-bath random walks with Markov bases\n",
+ "Authors: Caprice Stanley, Tobias Windisch\n",
+ "Summary: Graphs on lattice points are studied whose edges come from a finite set of\n",
+ "allowed moves of arbitrary length. We show that the diameter of these graphs on\n",
+ "fibers of a fixed integer matrix can be bounded from above by a constant. We\n",
+ "then study the mixing behaviour of heat-bath random walks on these graphs. We\n",
+ "also state explicit conditions on the set of moves so that the heat-bath random\n",
+ "walk, a generalization of the Glauber dynamics, is an expander in fixed\n",
+ "dimension.\u001B[0m\n",
+ "\u001B[32mThought: I now know the final answer.\n",
+ "Final Answer: 这篇论文的题目是《热浴随机游走的马尔可夫基》,作者是Caprice Stanley和Tobias Windisch。摘要中提到,该论文研究了在有限的允许移动集合中,由任意长度的边构成的图的边。我们证明了这些图在固定整数矩阵纤维上的直径可以被一个常数所限制。然后,我们研究了热浴随机游走在这类图上的混合行为。我们还给出了一个明确的条件,使得热浴随机游走(一个Glauber动力学的推广)在固定维度下是一个扩张。\u001B[0m\n",
+ "==========\n",
+ "Answer the following questions as best you can. You have access to the following tools:\n",
+ "\n",
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
+ "\n",
+ "Use the following format:\n",
+ "\n",
+ "Question: the input question you must answer\n",
+ "Thought: you should always think about what to do\n",
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
+ "Action Input: the input to the action\n",
+ "Observation: the result of the action\n",
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
+ "Thought: I now know the final answer\n",
+ "Final Answer: the final answer to the original input question\n",
+ "\n",
+ "Begin!\n",
+ "\n",
+ "Question: 使用python对下面的列表进行排序: [2, 4135, 523, 2, 3]\n",
+ "\u001B[32mThought: 我应该使用python API来执行python命令。\n",
+ "Action: python\n",
+ "Action Input: {\"query\": \"sorted([2, 4135, 523, 2, 3])\"}\n",
+ "Observation:\u001B[0m\u001B[34m [2, 2, 3, 523, 4135]\u001B[0m\n",
+ "\u001B[32mThought: I now know the final answer.\n",
+ "Final Answer: 使用python对给定的列表进行排序,结果为 [2, 2, 3, 523, 4135]。\u001B[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 请尽可能控制备选工具数量\n",
+ "query = \"加拿大2022年的人口数量有多少?\" # 所提问题\n",
+ "choose_tools = TOOLS # 选择备选工具\n",
+ "print(\"=\" * 10)\n",
+ "main(query, choose_tools)\n",
+ "\n",
+ "query = \"求解方程 2x+5 = -3x + 7\" # 所提问题\n",
+ "choose_tools = TOOLS # 选择备选工具\n",
+ "print(\"=\" * 10)\n",
+ "main(query, choose_tools)\n",
+ "\n",
+ "query = \"编号是1605.08386的论文讲了些什么?\" # 所提问题\n",
+ "choose_tools = TOOLS # 选择备选工具\n",
+ "print(\"=\" * 10)\n",
+ "main(query, choose_tools)\n",
+ "\n",
+ "query =\"使用python对下面的列表进行排序: [2, 4135, 523, 2, 3]\"\n",
+ "choose_tools = TOOLS # 选择备选工具\n",
+ "print(\"=\" * 10)\n",
+ "main(query, choose_tools)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/react_demo.py b/examples/react_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f9f452034484d057e0b9c020dfdcaf09fb2d155
--- /dev/null
+++ b/examples/react_demo.py
@@ -0,0 +1,288 @@
+#
+# 相关材料:
+# ReAct Prompting 原理简要介绍,不包含代码实现:
+# https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_prompt.md
+# 基于 model.chat 接口(对话模式)的 ReAct Prompting 实现(含接入 LangChain 的工具实现):
+# https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb
+# 基于 model.generate 接口(续写模式)的 ReAct Prompting 实现,比 chat 模式的实现更复杂些:
+# https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_demo.py(本文件)
+#
+
+import json
+import os
+
+import json5
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+for _ in range(10): # 网络不稳定,多试几次
+ try:
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+ generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
+ ).eval()
+ model.generation_config = generation_config
+ model.generation_config.do_sample = False
+ break
+ except Exception:
+ pass
+
+# 将一个插件的关键信息拼接成一段文本的模版。
+TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
+
+# ReAct prompting 的 instruction 模版,将包含插件的详细信息。
+PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:
+
+{tools_text}
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tools_name_text}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Question: {query}"""
+
+
+#
+# 本示例代码的入口函数。
+#
+# 输入:
+# prompt: 用户的最新一个问题。
+# history: 用户与模型的对话历史,是一个 list,
+# list 中的每个元素为 {"user": "用户输入", "bot": "模型输出"} 的一轮对话。
+# 最新的一轮对话放 list 末尾。不包含最新一个问题。
+# list_of_plugin_info: 候选插件列表,是一个 list,list 中的每个元素为一个插件的关键信息。
+# 比如 list_of_plugin_info = [plugin_info_0, plugin_info_1, plugin_info_2],
+# 其中 plugin_info_0, plugin_info_1, plugin_info_2 这几个样例见本文档前文。
+#
+# 输出:
+# 模型对用户最新一个问题的回答。
+#
+def llm_with_plugin(prompt: str, history, list_of_plugin_info=()):
+ chat_history = [(x['user'], x['bot']) for x in history] + [(prompt, '')]
+
+ # 需要让模型进行续写的初始文本
+ planning_prompt = build_input_text(chat_history, list_of_plugin_info)
+
+ text = ''
+ while True:
+ output = text_completion(planning_prompt + text, stop_words=['Observation:', 'Observation:\n'])
+ action, action_input, output = parse_latest_plugin_call(output)
+ if action: # 需要调用插件
+ # action、action_input 分别为需要调用的插件代号、输入参数
+ # observation是插件返回的结果,为字符串
+ observation = call_plugin(action, action_input)
+ output += f'\nObservation: {observation}\nThought:'
+ text += output
+ else: # 生成结束,并且不再需要调用插件
+ text += output
+ break
+
+ new_history = []
+ new_history.extend(history)
+ new_history.append({'user': prompt, 'bot': text})
+ return text, new_history
+
+
+# 将对话历史、插件信息聚合成一段初始文本
+def build_input_text(chat_history, list_of_plugin_info) -> str:
+ # 候选插件的详细信息
+ tools_text = []
+ for plugin_info in list_of_plugin_info:
+ tool = TOOL_DESC.format(
+ name_for_model=plugin_info["name_for_model"],
+ name_for_human=plugin_info["name_for_human"],
+ description_for_model=plugin_info["description_for_model"],
+ parameters=json.dumps(plugin_info["parameters"], ensure_ascii=False),
+ )
+ if plugin_info.get('args_format', 'json') == 'json':
+ tool += " Format the arguments as a JSON object."
+ elif plugin_info['args_format'] == 'code':
+ tool += ' Enclose the code within triple backticks (`) at the beginning and end of the code.'
+ else:
+ raise NotImplementedError
+ tools_text.append(tool)
+ tools_text = '\n\n'.join(tools_text)
+
+ # 候选插件的代号
+ tools_name_text = ', '.join([plugin_info["name_for_model"] for plugin_info in list_of_plugin_info])
+
+ im_start = '<|im_start|>'
+ im_end = '<|im_end|>'
+ prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
+ for i, (query, response) in enumerate(chat_history):
+ if list_of_plugin_info: # 如果有候选插件
+ # 倒数第一轮或倒数第二轮对话填入详细的插件信息,但具体什么位置填可以自行判断
+ if (len(chat_history) == 1) or (i == len(chat_history) - 2):
+ query = PROMPT_REACT.format(
+ tools_text=tools_text,
+ tools_name_text=tools_name_text,
+ query=query,
+ )
+ query = query.lstrip('\n').rstrip() # 重要!若不 strip 会与训练时数据的构造方式产生差异。
+ response = response.lstrip('\n').rstrip() # 重要!若不 strip 会与训练时数据的构造方式产生差异。
+ # 使用续写模式(text completion)时,需要用如下格式区分用户和AI:
+ prompt += f"\n{im_start}user\n{query}{im_end}"
+ prompt += f"\n{im_start}assistant\n{response}{im_end}"
+
+ assert prompt.endswith(f"\n{im_start}assistant\n{im_end}")
+ prompt = prompt[: -len(f'{im_end}')]
+ return prompt
+
+
+def text_completion(input_text: str, stop_words) -> str: # 作为一个文本续写模型来使用
+ im_end = '<|im_end|>'
+ if im_end not in stop_words:
+ stop_words = stop_words + [im_end]
+ stop_words_ids = [tokenizer.encode(w) for w in stop_words]
+
+ # TODO: 增加流式输出的样例实现
+ input_ids = torch.tensor([tokenizer.encode(input_text)]).to(model.device)
+ output = model.generate(input_ids, stop_words_ids=stop_words_ids)
+ output = output.tolist()[0]
+ output = tokenizer.decode(output, errors="ignore")
+ assert output.startswith(input_text)
+ output = output[len(input_text) :].replace('<|endoftext|>', '').replace(im_end, '')
+
+ for stop_str in stop_words:
+ idx = output.find(stop_str)
+ if idx != -1:
+ output = output[: idx + len(stop_str)]
+ return output # 续写 input_text 的结果,不包含 input_text 的内容
+
+
+def parse_latest_plugin_call(text):
+ plugin_name, plugin_args = '', ''
+ i = text.rfind('\nAction:')
+ j = text.rfind('\nAction Input:')
+ k = text.rfind('\nObservation:')
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
+ if k < j: # but does not contain `Observation`,
+ # then it is likely that `Observation` is ommited by the LLM,
+ # because the output text may have discarded the stop word.
+ text = text.rstrip() + '\nObservation:' # Add it back.
+ k = text.rfind('\nObservation:')
+ plugin_name = text[i + len('\nAction:') : j].strip()
+ plugin_args = text[j + len('\nAction Input:') : k].strip()
+ text = text[:k]
+ return plugin_name, plugin_args, text
+
+
+#
+# 输入:
+# plugin_name: 需要调用的插件代号,对应 name_for_model。
+# plugin_args:插件的输入参数,是一个 dict,dict 的 key、value 分别为参数名、参数值。
+# 输出:
+# 插件的返回结果,需要是字符串。
+# 即使原本是 JSON 输出,也请 json.dumps(..., ensure_ascii=False) 成字符串。
+#
+def call_plugin(plugin_name: str, plugin_args: str) -> str:
+ #
+ # 请开发者自行完善这部分内容。这里的参考实现仅是 demo 用途,非生产用途。
+ #
+ if plugin_name == 'google_search':
+ # 使用 SerpAPI 需要在这里填入您的 SERPAPI_API_KEY!
+ os.environ["SERPAPI_API_KEY"] = os.getenv("SERPAPI_API_KEY", default='')
+ from langchain import SerpAPIWrapper
+
+ return SerpAPIWrapper().run(json5.loads(plugin_args)['search_query'])
+ elif plugin_name == 'image_gen':
+ import urllib.parse
+
+ prompt = json5.loads(plugin_args)["prompt"]
+ prompt = urllib.parse.quote(prompt)
+ return json.dumps({'image_url': f'https://image.pollinations.ai/prompt/{prompt}'}, ensure_ascii=False)
+ else:
+ raise NotImplementedError
+
+
+def test():
+ tools = [
+ {
+ 'name_for_human': '谷歌搜索',
+ 'name_for_model': 'google_search',
+ 'description_for_model': '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。',
+ 'parameters': [
+ {
+ 'name': 'search_query',
+ 'description': '搜索关键词或短语',
+ 'required': True,
+ 'schema': {'type': 'string'},
+ }
+ ],
+ },
+ {
+ 'name_for_human': '文生图',
+ 'name_for_model': 'image_gen',
+ 'description_for_model': '文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL',
+ 'parameters': [
+ {
+ 'name': 'prompt',
+ 'description': '英文关键词,描述了希望图像具有什么内容',
+ 'required': True,
+ 'schema': {'type': 'string'},
+ }
+ ],
+ },
+ ]
+ history = []
+ for query in ['你好', '谁是周杰伦', '他老婆是谁', '给我画个可爱的小猫吧,最好是黑猫']:
+ print(f"User's Query:\n{query}\n")
+ response, history = llm_with_plugin(prompt=query, history=history, list_of_plugin_info=tools)
+ print(f"Qwen's Response:\n{response}\n")
+
+
+if __name__ == "__main__":
+ test()
+
+"""如果执行成功,在终端下应当能看到如下输出:
+User's Query:
+你好
+
+Qwen's Response:
+Thought: 提供的工具对回答该问题帮助较小,我将不使用工具直接作答。
+Final Answer: 你好!很高兴见到你。有什么我可以帮忙的吗?
+
+User's Query:
+谁是周杰伦
+
+Qwen's Response:
+Thought: 我应该使用Google搜索查找相关信息。
+Action: google_search
+Action Input: {"search_query": "周杰伦"}
+Observation: Jay Chou is a Taiwanese singer, songwriter, record producer, rapper, actor, television personality, and businessman.
+Thought: I now know the final answer.
+Final Answer: 周杰伦(Jay Chou)是一位来自台湾的歌手、词曲创作人、音乐制作人、说唱歌手、演员、电视节目主持人和企业家。他以其独特的音乐风格和才华在华语乐坛享有很高的声誉。
+
+User's Query:
+他老婆是谁
+
+Qwen's Response:
+Thought: 我应该使用Google搜索查找相关信息。
+Action: google_search
+Action Input: {"search_query": "周杰伦 老婆"}
+Observation: Hannah Quinlivan
+Thought: I now know the final answer.
+Final Answer: 周杰伦的老婆是Hannah Quinlivan,她是一位澳大利亚籍的模特和演员。两人于2015年结婚,并育有一子。
+
+User's Query:
+给我画个可爱的小猫吧,最好是黑猫
+
+Qwen's Response:
+Thought: 我应该使用文生图API来生成一张可爱的小猫图片。
+Action: image_gen
+Action Input: {"prompt": "cute black cat"}
+Observation: {"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}
+Thought: I now know the final answer.
+Final Answer: 生成的可爱小猫图片的URL为https://image.pollinations.ai/prompt/cute%20black%20cat。你可以点击这个链接查看图片。
+"""
diff --git a/examples/react_prompt.md b/examples/react_prompt.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ffd7924b2e16addfed2cf0bb975eb8501cecdfa
--- /dev/null
+++ b/examples/react_prompt.md
@@ -0,0 +1,249 @@
+# ReAct Prompting 示例
+
+本文档将介绍如何用 ReAct Prompting 技术命令千问使用工具。
+
+本文档主要基本的原理概念介绍,并在文末附上了一些具体实现相关的 FAQ,但不含被调用插件的实际实现。如果您更喜欢一边调试实际可执行的代码、一边理解原理,可以转而阅读整合了 LangChain 常用工具的这个 [ipython notebook](https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb)。
+
+此外,本文档和前述的 ipython notebook 都仅介绍单轮对话的实现。如果想了解多轮对话下的实现,可参见 [react_demo.py](https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_demo.py)。
+
+## 准备工作一:样例问题、样例工具
+
+假设我们有如下的一个适合用工具处理的 query,以及有夸克搜索、通义万相文生图这两个工具:
+
+```py
+query = '我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。'
+
+TOOLS = [
+ {
+ 'name_for_human':
+ '夸克搜索',
+ 'name_for_model':
+ 'quark_search',
+ 'description_for_model':
+ '夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。',
+ 'parameters': [{
+ 'name': 'search_query',
+ 'description': '搜索关键词或短语',
+ 'required': True,
+ 'schema': {
+ 'type': 'string'
+ },
+ }],
+ },
+ {
+ 'name_for_human':
+ '通义万相',
+ 'name_for_model':
+ 'image_gen',
+ 'description_for_model':
+ '通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL',
+ 'parameters': [{
+ 'name': 'query',
+ 'description': '中文关键词,描述了希望图像具有什么内容',
+ 'required': True,
+ 'schema': {
+ 'type': 'string'
+ },
+ }],
+ },
+]
+```
+
+## 准备工作二:ReAct 模版
+
+我们将使用如下的 ReAct prompt 模版来激发千问使用工具的能力。
+
+```py
+TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object."""
+
+REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools:
+
+{tool_descs}
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Question: {query}"""
+```
+
+## 步骤一:让千问判断要调用什么工具、生成工具入参
+
+首先我们需要根据 ReAct prompt 模版、query、工具的信息构建 prompt:
+
+```py
+tool_descs = []
+tool_names = []
+for info in TOOLS:
+ tool_descs.append(
+ TOOL_DESC.format(
+ name_for_model=info['name_for_model'],
+ name_for_human=info['name_for_human'],
+ description_for_model=info['description_for_model'],
+ parameters=json.dumps(
+ info['parameters'], ensure_ascii=False),
+ )
+ )
+ tool_names.append(info['name_for_model'])
+tool_descs = '\n\n'.join(tool_descs)
+tool_names = ','.join(tool_names)
+
+prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query)
+print(prompt)
+```
+
+打印出来的、构建好的 prompt 如下:
+
+```
+Answer the following questions as best you can. You have access to the following tools:
+
+quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
+
+image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [quark_search,image_gen]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Question: 我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。
+```
+
+将这个 prompt 送入千问,并记得设置 "Observation" 为 stop word (见本文末尾的 FAQ)—— 即让千问在预测到要生成的下一个词是 "Observation" 时马上停止生成 —— 则千问在得到这个 prompt 后会生成如下的结果:
+
+![](../assets/react_tutorial_001.png)
+
+```
+Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。
+Action: image_gen
+Action Input: {"query": "五彩斑斓的黑"}
+```
+
+在得到这个结果后,调用千问的开发者可以通过简单的解析提取出 `{"query": "五彩斑斓的黑"}` 并基于这个解析结果调用文生图服务 —— 这部分逻辑需要开发者自行实现,或者也可以使用千问商业版,商业版本将内部集成相关逻辑。
+
+## 步骤二:让千问根据插件返回结果继续作答
+
+让我们假设文生图插件返回了如下结果:
+
+```
+{"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}}
+```
+
+![](../assets/wanx_colorful_black.png)
+
+接下来,我们可以将之前首次请求千问时用的 prompt 和 调用文生图插件的结果拼接成如下的新 prompt:
+
+```
+Answer the following questions as best you can. You have access to the following tools:
+
+quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
+
+image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
+
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [quark_search,image_gen]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Question: 我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。
+Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。
+Action: image_gen
+Action Input: {"query": "五彩斑斓的黑"}
+Observation: {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}}
+```
+
+用这个新的拼接了文生图插件结果的新 prompt 去调用千问,将得到如下的最终回复:
+
+![](../assets/react_tutorial_002.png)
+
+```
+Thought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。
+Final Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。
+```
+
+虽然对于文生图来说,这个第二次调用千问的步骤显得多余。但是对于搜索插件、代码执行插件、计算器插件等别的插件来说,这个第二次调用千问的步骤给了千问提炼、总结插件返回结果的机会。
+
+## FAQ
+
+**怎么配置 "Observation" 这个 stop word?**
+
+通过 chat 接口的 stop_words_ids 指定:
+```py
+react_stop_words = [
+ # tokenizer.encode('Observation'), # [37763, 367]
+ tokenizer.encode('Observation:'), # [37763, 367, 25]
+ tokenizer.encode('Observation:\n'), # [37763, 367, 510]
+]
+response, history = model.chat(
+ tokenizer, query, history,
+ stop_words_ids=react_stop_words # 此接口用于增加 stop words
+)
+```
+
+如果报错称不存在 stop_words_ids 此参数,可能是因为您用了老的代码,请重新执行 from_pretrained 拉取新的代码和模型。
+
+需要注意的是,当前的 tokenizer 对 `\n` 有一系列较复杂的聚合操作。比如例子中的`:\n`这两个字符便被聚合成了一个 token。因此配置 stop words 需要非常细致地预估 tokenizer 的行为。
+
+**对 top_p 等推理参数有调参建议吗?**
+
+通常来讲,较低的 top_p 会有更高的准确度,但会牺牲回答的多样性、且更易出现重复某个词句的现象。
+
+可以按如下方式调整 top_p 为 0.5:
+```py
+model.generation_config.top_p = 0.5
+```
+
+特别的,可以用如下方式关闭 top-p sampling,改用 greedy sampling,效果上相当于 top_p=0 或 temperature=0:
+```py
+model.generation_config.do_sample = False # greedy decoding
+```
+
+此外,我们在 `model.chat()` 接口也提供了调整 top_p 等参数的接口。
+
+**有解析Action、Action Input的参考代码吗?**
+
+有的,可以参考:
+```py
+def parse_latest_plugin_call(text: str) -> Tuple[str, str]:
+ i = text.rfind('\nAction:')
+ j = text.rfind('\nAction Input:')
+ k = text.rfind('\nObservation:')
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
+ if k < j: # but does not contain `Observation`,
+ # then it is likely that `Observation` is ommited by the LLM,
+ # because the output text may have discarded the stop word.
+ text = text.rstrip() + '\nObservation:' # Add it back.
+ k = text.rfind('\nObservation:')
+ if 0 <= i < j < k:
+ plugin_name = text[i + len('\nAction:'):j].strip()
+ plugin_args = text[j + len('\nAction Input:'):k].strip()
+ return plugin_name, plugin_args
+ return '', ''
+```
+
+此外,如果输出的 Action Input 内容是一段表示 JSON 对象的文本,我们建议使用 `json5` 包的 `json5.loads(...)` 方法加载。
diff --git a/examples/tokenizer_showcase.ipynb b/examples/tokenizer_showcase.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..dca893baf9fbd88c2ed78355fced6146615e28e0
--- /dev/null
+++ b/examples/tokenizer_showcase.ipynb
@@ -0,0 +1,441 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Encode and Decode"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[1350, 492, 151643, 863, 151643]"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# treat surface forms of special tokens as actual special tokens\n",
+ "# the default, but unsafe (to be compatible with other projects)\n",
+ "# the same as tokenizer.encode(\"print('<|endoftext|>')<|endoftext|>\", allowed_special='all', disallowed_special=())\n",
+ "tokenizer.encode(\"print('<|endoftext|>')<|endoftext|>\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"print('<|endoftext|>')<|endoftext|>\""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([1350, 492, 151643, 863, 151643])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# treat texts just as texts, avoid injection attacks\n",
+ "tokenizer.encode(\"print('<|endoftext|>')\", allowed_special=set(), disallowed_special=()) + [tokenizer.eod_id]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"print('<|endoftext|>')<|endoftext|>\""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ValueError",
+ "evalue": "Encountered text corresponding to disallowed special token '<|endoftext|>'.\nIf you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endoftext|>', ...}`.\nIf you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endoftext|>'})`.\nTo disable this check for all special tokens, pass `disallowed_special=()`.\n",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[1;32mIn[7], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[39m# treat texts just as texts, avoid injection attacks, and raise error if surface forms of special tokens are ever encountered\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m tokenizer\u001b[39m.\u001b[39;49mencode(\u001b[39m\"\u001b[39;49m\u001b[39mprint(\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39m<|endoftext|>\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39m)\u001b[39;49m\u001b[39m\"\u001b[39;49m, allowed_special\u001b[39m=\u001b[39;49m\u001b[39mset\u001b[39;49m(), disallowed_special\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mall\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39m+\u001b[39m [tokenizer\u001b[39m.\u001b[39meod_id]\n",
+ "File \u001b[1;32mtransformers\\tokenization_utils_base.py:2348\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.encode\u001b[1;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, return_tensors, **kwargs)\u001b[0m\n\u001b[0;32m 2311\u001b[0m \u001b[39m@add_end_docstrings\u001b[39m(\n\u001b[0;32m 2312\u001b[0m ENCODE_KWARGS_DOCSTRING,\n\u001b[0;32m 2313\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2331\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2332\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[\u001b[39mint\u001b[39m]:\n\u001b[0;32m 2333\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 2334\u001b[0m \u001b[39m Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.\u001b[39;00m\n\u001b[0;32m 2335\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2346\u001b[0m \u001b[39m method).\u001b[39;00m\n\u001b[0;32m 2347\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 2348\u001b[0m encoded_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mencode_plus(\n\u001b[0;32m 2349\u001b[0m text,\n\u001b[0;32m 2350\u001b[0m text_pair\u001b[39m=\u001b[39mtext_pair,\n\u001b[0;32m 2351\u001b[0m add_special_tokens\u001b[39m=\u001b[39madd_special_tokens,\n\u001b[0;32m 2352\u001b[0m padding\u001b[39m=\u001b[39mpadding,\n\u001b[0;32m 2353\u001b[0m truncation\u001b[39m=\u001b[39mtruncation,\n\u001b[0;32m 2354\u001b[0m max_length\u001b[39m=\u001b[39mmax_length,\n\u001b[0;32m 2355\u001b[0m stride\u001b[39m=\u001b[39mstride,\n\u001b[0;32m 2356\u001b[0m return_tensors\u001b[39m=\u001b[39mreturn_tensors,\n\u001b[0;32m 2357\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2358\u001b[0m )\n\u001b[0;32m 2360\u001b[0m \u001b[39mreturn\u001b[39;00m encoded_inputs[\u001b[39m\"\u001b[39m\u001b[39minput_ids\u001b[39m\u001b[39m\"\u001b[39m]\n",
+ "File \u001b[1;32mtransformers\\tokenization_utils_base.py:2756\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 2746\u001b[0m \u001b[39m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\u001b[39;00m\n\u001b[0;32m 2747\u001b[0m padding_strategy, truncation_strategy, max_length, kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_padding_truncation_strategies(\n\u001b[0;32m 2748\u001b[0m padding\u001b[39m=\u001b[39mpadding,\n\u001b[0;32m 2749\u001b[0m truncation\u001b[39m=\u001b[39mtruncation,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2753\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2754\u001b[0m )\n\u001b[1;32m-> 2756\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_encode_plus(\n\u001b[0;32m 2757\u001b[0m text\u001b[39m=\u001b[39mtext,\n\u001b[0;32m 2758\u001b[0m text_pair\u001b[39m=\u001b[39mtext_pair,\n\u001b[0;32m 2759\u001b[0m add_special_tokens\u001b[39m=\u001b[39madd_special_tokens,\n\u001b[0;32m 2760\u001b[0m padding_strategy\u001b[39m=\u001b[39mpadding_strategy,\n\u001b[0;32m 2761\u001b[0m truncation_strategy\u001b[39m=\u001b[39mtruncation_strategy,\n\u001b[0;32m 2762\u001b[0m max_length\u001b[39m=\u001b[39mmax_length,\n\u001b[0;32m 2763\u001b[0m stride\u001b[39m=\u001b[39mstride,\n\u001b[0;32m 2764\u001b[0m is_split_into_words\u001b[39m=\u001b[39mis_split_into_words,\n\u001b[0;32m 2765\u001b[0m pad_to_multiple_of\u001b[39m=\u001b[39mpad_to_multiple_of,\n\u001b[0;32m 2766\u001b[0m return_tensors\u001b[39m=\u001b[39mreturn_tensors,\n\u001b[0;32m 2767\u001b[0m return_token_type_ids\u001b[39m=\u001b[39mreturn_token_type_ids,\n\u001b[0;32m 2768\u001b[0m return_attention_mask\u001b[39m=\u001b[39mreturn_attention_mask,\n\u001b[0;32m 2769\u001b[0m return_overflowing_tokens\u001b[39m=\u001b[39mreturn_overflowing_tokens,\n\u001b[0;32m 2770\u001b[0m return_special_tokens_mask\u001b[39m=\u001b[39mreturn_special_tokens_mask,\n\u001b[0;32m 2771\u001b[0m return_offsets_mapping\u001b[39m=\u001b[39mreturn_offsets_mapping,\n\u001b[0;32m 2772\u001b[0m return_length\u001b[39m=\u001b[39mreturn_length,\n\u001b[0;32m 2773\u001b[0m verbose\u001b[39m=\u001b[39mverbose,\n\u001b[0;32m 2774\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2775\u001b[0m )\n",
+ "File \u001b[1;32mtransformers\\tokenization_utils.py:649\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 640\u001b[0m \u001b[39mif\u001b[39;00m return_offsets_mapping:\n\u001b[0;32m 641\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\n\u001b[0;32m 642\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mreturn_offset_mapping is not available when using Python tokenizers. \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 643\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mTo use this feature, change your tokenizer to one deriving from \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 646\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mhttps://github.com/huggingface/transformers/pull/2674\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 647\u001b[0m )\n\u001b[1;32m--> 649\u001b[0m first_ids \u001b[39m=\u001b[39m get_input_ids(text)\n\u001b[0;32m 650\u001b[0m second_ids \u001b[39m=\u001b[39m get_input_ids(text_pair) \u001b[39mif\u001b[39;00m text_pair \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m 652\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprepare_for_model(\n\u001b[0;32m 653\u001b[0m first_ids,\n\u001b[0;32m 654\u001b[0m pair_ids\u001b[39m=\u001b[39msecond_ids,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 668\u001b[0m verbose\u001b[39m=\u001b[39mverbose,\n\u001b[0;32m 669\u001b[0m )\n",
+ "File \u001b[1;32mtransformers\\tokenization_utils.py:616\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus..get_input_ids\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m 614\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_input_ids\u001b[39m(text):\n\u001b[0;32m 615\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(text, \u001b[39mstr\u001b[39m):\n\u001b[1;32m--> 616\u001b[0m tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtokenize(text, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m 617\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconvert_tokens_to_ids(tokens)\n\u001b[0;32m 618\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(text, (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)) \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(text) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(text[\u001b[39m0\u001b[39m], \u001b[39mstr\u001b[39m):\n",
+ "File \u001b[1;32mtokenization_qwen.py:155\u001b[0m, in \u001b[0;36mQWenTokenizer.tokenize\u001b[1;34m(self, text, allowed_special, disallowed_special, **kwargs)\u001b[0m\n\u001b[0;32m 152\u001b[0m text \u001b[39m=\u001b[39m unicodedata\u001b[39m.\u001b[39mnormalize(\u001b[39m\"\u001b[39m\u001b[39mNFC\u001b[39m\u001b[39m\"\u001b[39m, text)\n\u001b[0;32m 154\u001b[0m \u001b[39m# this implementation takes a detour: text -> token id -> token surface forms\u001b[39;00m\n\u001b[1;32m--> 155\u001b[0m \u001b[39mfor\u001b[39;00m t \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtokenizer\u001b[39m.\u001b[39;49mencode(\n\u001b[0;32m 156\u001b[0m text, allowed_special\u001b[39m=\u001b[39;49mallowed_special, disallowed_special\u001b[39m=\u001b[39;49mdisallowed_special\n\u001b[0;32m 157\u001b[0m ):\n\u001b[0;32m 158\u001b[0m tokens\u001b[39m.\u001b[39mappend(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdecoder[t])\n\u001b[0;32m 159\u001b[0m \u001b[39mreturn\u001b[39;00m tokens\n",
+ "File \u001b[1;32mtiktoken\\core.py:117\u001b[0m, in \u001b[0;36mEncoding.encode\u001b[1;34m(self, text, allowed_special, disallowed_special)\u001b[0m\n\u001b[0;32m 115\u001b[0m disallowed_special \u001b[39m=\u001b[39m \u001b[39mfrozenset\u001b[39m(disallowed_special)\n\u001b[0;32m 116\u001b[0m \u001b[39mif\u001b[39;00m match \u001b[39m:=\u001b[39m _special_token_regex(disallowed_special)\u001b[39m.\u001b[39msearch(text):\n\u001b[1;32m--> 117\u001b[0m raise_disallowed_special_token(match\u001b[39m.\u001b[39;49mgroup())\n\u001b[0;32m 119\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m 120\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_core_bpe\u001b[39m.\u001b[39mencode(text, allowed_special)\n",
+ "File \u001b[1;32mtiktoken\\core.py:337\u001b[0m, in \u001b[0;36mraise_disallowed_special_token\u001b[1;34m(token)\u001b[0m\n\u001b[0;32m 336\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mraise_disallowed_special_token\u001b[39m(token: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m NoReturn:\n\u001b[1;32m--> 337\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[0;32m 338\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEncountered text corresponding to disallowed special token \u001b[39m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 339\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mIf you want this text to be encoded as a special token, \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 340\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpass it to `allowed_special`, e.g. `allowed_special=\u001b[39m\u001b[39m{{\u001b[39;00m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m, ...\u001b[39m\u001b[39m}}\u001b[39;00m\u001b[39m`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 341\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIf you want this text to be encoded as normal text, disable the check for this token \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 342\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mby passing `disallowed_special=(enc.special_tokens_set - \u001b[39m\u001b[39m{{\u001b[39;00m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m}}\u001b[39;00m\u001b[39m)`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 343\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mTo disable this check for all special tokens, pass `disallowed_special=()`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 344\u001b[0m )\n",
+ "\u001b[1;31mValueError\u001b[0m: Encountered text corresponding to disallowed special token '<|endoftext|>'.\nIf you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endoftext|>', ...}`.\nIf you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endoftext|>'})`.\nTo disable this check for all special tokens, pass `disallowed_special=()`.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# treat texts just as texts, avoid injection attacks, and raise error if surface forms of special tokens are ever encountered\n",
+ "tokenizer.encode(\"print('<|endoftext|>')\", allowed_special=set(), disallowed_special='all') + [tokenizer.eod_id]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[151644, 1350, 11146, 91, 15460, 62, 15, 91, 79865, 151645, 151643]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# fine-grained control, just keep mind of this:\n",
+ "# allowed_special is treated as special tokens\n",
+ "# disallowed_special raise errors\n",
+ "# allowed_special has higher priority than disallowed_special\n",
+ "tokenizer.encode(\"<|im_start|>print('<|extra_0|>')<|im_end|>\", \n",
+ " allowed_special={'<|im_start|>', '<|im_end|>'}, \n",
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[151644, 1350, 492, 151646, 863, 151645, 151643]"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(\"<|im_start|>print('<|extra_0|>')<|im_end|>\", \n",
+ " allowed_special={'<|im_start|>', '<|im_end|>', '<|extra_0|>'}, \n",
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Special Token Management"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using unk_token, but it is not set yet.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# huggingface tokenizer has its own special token mechanism, so does tiktoken\n",
+ "# we only use the tiktoken mechanism for special tokens, which means many property of huggingface tokenizer will be None\n",
+ "tokenizer.unk_token"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer.eos_token_id # use tokenizer.eod_id instead"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer.pad_token_id "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "151646"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# use one of the extras such as <|extra_0|>\n",
+ "tokenizer.special_tokens['<|extra_0|>']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Utility Methods"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[b'print', b\"('<\", b'|', b'endo', b'ft', b'ext', b'|', b\">')\", '<|endoftext|>']"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# special tokens are str, tokens are bytes (since tiktoken operates on the bytes level)\n",
+ "ids = [1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643]\n",
+ "tokenizer.convert_ids_to_tokens(ids)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"print('<|endoftext|>')<|endoftext|>\""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ids = tokenizer.encode(\"<|im_start|>print('我是一只猫<|extra_0|>')\\n#喵喵喵<|im_end|>\", \n",
+ " allowed_special={'<|im_start|>', '<|im_end|>', '<|extra_0|>'}, \n",
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['<|im_start|>',\n",
+ " b'print',\n",
+ " b\"('\",\n",
+ " b'\\xe6\\x88\\x91',\n",
+ " b'\\xe6\\x98\\xaf\\xe4\\xb8\\x80',\n",
+ " b'\\xe5\\x8f\\xaa',\n",
+ " b'\\xe7\\x8c\\xab',\n",
+ " '<|extra_0|>',\n",
+ " b\"')\\n\",\n",
+ " b'#',\n",
+ " b'\\xe5\\x96\\xb5',\n",
+ " b'\\xe5\\x96\\xb5',\n",
+ " b'\\xe5\\x96\\xb5',\n",
+ " '<|im_end|>',\n",
+ " '<|endoftext|>']"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_ids_to_tokens(ids)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"<|im_start|>print('我是一只猫<|extra_0|>')\\n#喵喵喵<|im_end|><|endoftext|>\""
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'<|extra_204|>'"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer._convert_id_to_token(len(tokenizer)-1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "151850"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer._convert_token_to_id('<|extra_204|>')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/transformers_agent.md b/examples/transformers_agent.md
new file mode 100644
index 0000000000000000000000000000000000000000..17165a4e9d159712fa6ae2f07ddadddbc61a721e
--- /dev/null
+++ b/examples/transformers_agent.md
@@ -0,0 +1,108 @@
+## 什么是HuggingFace Agent
+使用大模型作为Agent,仅需自然语言就可调用HuggingFace中的模型,目前支持两种模式:
+
+- run模式:单轮对话,没有上下文,单个prompt多tool组合调用能力好
+- chat模式:多轮对话,有上下文,单次调用能力好,可能需要多次prompt实现多tool组合调用
+> 详见官方文档:[Transformers Agents](https://huggingface.co/docs/transformers/transformers_agents)
+
+## 使用通义千问作为Agent
+### 安装依赖
+```
+pip install transformers
+```
+### 构建QWenAgent
+以下代码便可实现QWenAgent:
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, Agent
+from transformers.generation import GenerationConfig
+
+
+class QWenAgent(Agent):
+ """
+ Agent that uses QWen model and tokenizer to generate code.
+
+ Args:
+ chat_prompt_template (`str`, *optional*):
+ Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
+ `chat_prompt_template.txt` in this repo in this case.
+ run_prompt_template (`str`, *optional*):
+ Pass along your own prompt if you want to override the default template for the `run` method. Can be the
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
+ `run_prompt_template.txt` in this repo in this case.
+ additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
+ Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
+ one of the default tools, that default tool will be overridden.
+
+ Example:
+
+ ```py
+ agent = QWenAgent()
+ agent.run("Draw me a picture of rivers and lakes.")
+ ```
+ """
+ def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
+ checkpoint = "Qwen/Qwen-7B-Chat"
+ self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", trust_remote_code=True).cuda().eval()
+ self.model.generation_config = GenerationConfig.from_pretrained(checkpoint, trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参
+ self.model.generation_config.do_sample = False # greedy
+
+ super().__init__(
+ chat_prompt_template=chat_prompt_template,
+ run_prompt_template=run_prompt_template,
+ additional_tools=additional_tools,
+ )
+
+ def generate_one(self, prompt, stop):
+ # "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。
+ prompt = prompt.replace("Human:", "_HUMAN_:").replace("Assistant:", "_ASSISTANT_:")
+ stop = [item.replace("Human:", "_HUMAN_:").replace("Assistant:", "_ASSISTANT_:") for item in stop]
+
+ result, _ = self.model.chat(self.tokenizer, prompt, history=None)
+ for stop_seq in stop:
+ if result.endswith(stop_seq):
+ result = result[: -len(stop_seq)]
+
+ result = result.replace("_HUMAN_:", "Human:").replace("_ASSISTANT_:", "Assistant:")
+ return result
+
+
+agent = QWenAgent()
+agent.run("Draw me a picture of rivers and lakes.")
+```
+### 使用示例
+```python
+agent = QWenAgent()
+agent.run("generate an image of panda", remote=True)
+```
+![](../assets/hfagent_run.png)
+![](../assets/hfagent_chat_1.png)
+![](../assets/hfagent_chat_2.png)
+> 更多玩法参考HuggingFace官方文档[Transformers Agents](https://huggingface.co/docs/transformers/transformers_agents)
+
+## Tools
+### Tools支持
+HuggingFace Agent官方14个tool:
+
+- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document (Donut)
+- **Text question answering**: given a long text and a question, answer the question in the text (Flan-T5)
+- **Unconditional image captioning**: Caption the image! (BLIP)
+- **Image question answering**: given an image, answer a question on this image (VILT)
+- **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt (CLIPSeg)
+- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text (Whisper)
+- **Text to speech**: convert text to speech (SpeechT5)
+- **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most (BART)
+- **Text summarization**: summarize a long text in one or a few sentences (BART)
+- **Translation**: translate the text into a given language (NLLB)
+- **Text downloader**: to download a text from a web URL
+- **Text to image**: generate an image according to a prompt, leveraging stable diffusion
+- **Image transformation**: transforms an image
+- **Text to video**: generate a small video according to a prompt, leveraging damo-vilab
+### Tools模型部署
+部分工具涉及的模型HuggingFace已进行在线部署,仅需设置remote=True便可实现在线调用:
+> agent.run(xxx, remote=True)
+
+HuggingFace没有在线部署的模型会自动下载checkpoint进行本地inference
+网络原因偶尔连不上HuggingFace,请多次尝试
diff --git a/openai_api.py b/openai_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..da105f3eb7973eee2b3161bdc5636125773b6742
--- /dev/null
+++ b/openai_api.py
@@ -0,0 +1,211 @@
+# coding=utf-8
+# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
+# Usage: python openai_api.py
+# Visit http://localhost:8000/docs for documents.
+
+from argparse import ArgumentParser
+import time
+import torch
+import uvicorn
+from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
+from typing import Any, Dict, List, Literal, Optional, Union
+from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+from transformers.generation import GenerationConfig
+from sse_starlette.sse import ServerSentEvent, EventSourceResponse
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+ yield
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+app = FastAPI(lifespan=lifespan)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+class ModelCard(BaseModel):
+ id: str
+ object: str = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: str = "owner"
+ root: Optional[str] = None
+ parent: Optional[str] = None
+ permission: Optional[list] = None
+
+
+class ModelList(BaseModel):
+ object: str = "list"
+ data: List[ModelCard] = []
+
+
+class ChatMessage(BaseModel):
+ role: Literal["user", "assistant", "system"]
+ content: str
+
+
+class DeltaMessage(BaseModel):
+ role: Optional[Literal["user", "assistant", "system"]] = None
+ content: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+ model: str
+ messages: List[ChatMessage]
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ max_length: Optional[int] = None
+ stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatMessage
+ finish_reason: Literal["stop", "length"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ finish_reason: Optional[Literal["stop", "length"]]
+
+
+class ChatCompletionResponse(BaseModel):
+ model: str
+ object: Literal["chat.completion", "chat.completion.chunk"]
+ choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+
+
+@app.get("/v1/models", response_model=ModelList)
+async def list_models():
+ global model_args
+ model_card = ModelCard(id="gpt-3.5-turbo")
+ return ModelList(data=[model_card])
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+ global model, tokenizer
+
+ if request.messages[-1].role != "user":
+ raise HTTPException(status_code=400, detail="Invalid request")
+ query = request.messages[-1].content
+
+ prev_messages = request.messages[:-1]
+ # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
+ # if len(prev_messages) > 0 and prev_messages[0].role == "system":
+ # query = prev_messages.pop(0).content + query
+
+ history = []
+ if len(prev_messages) % 2 == 0:
+ for i in range(0, len(prev_messages), 2):
+ if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+ history.append([prev_messages[i].content, prev_messages[i+1].content])
+ else:
+ raise HTTPException(status_code=400, detail="Invalid request.")
+ else:
+ raise HTTPException(status_code=400, detail="Invalid request.")
+
+ if request.stream:
+ generate = predict(query, history, request.model)
+ return EventSourceResponse(generate, media_type="text/event-stream")
+
+ response, _ = model.chat(tokenizer, query, history=history)
+ choice_data = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role="assistant", content=response),
+ finish_reason="stop"
+ )
+
+ return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
+
+
+async def predict(query: str, history: List[List[str]], model_id: str):
+ global model, tokenizer
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(role="assistant"),
+ finish_reason=None
+ )
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+ current_length = 0
+
+ for new_response in model.chat_stream(tokenizer, query, history):
+ if len(new_response) == current_length:
+ continue
+
+ new_text = new_response[current_length:]
+ current_length = len(new_response)
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(content=new_text),
+ finish_reason=None
+ )
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(),
+ finish_reason="stop"
+ )
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+ yield '[DONE]'
+
+def _get_args():
+ parser = ArgumentParser()
+ parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
+ help="Checkpoint name or path, default to %(default)r")
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
+ parser.add_argument("--server-port", type=int, default=8000,
+ help="Demo server port.")
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
+ help="Demo server name.")
+
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = _get_args()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+
+ if args.cpu_only:
+ device_map = "cpu"
+ else:
+ device_map = "auto"
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint_path,
+ device_map=device_map,
+ trust_remote_code=True,
+ resume_download=True,
+ ).eval()
+
+ model.generation_config = GenerationConfig.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+
+ uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e90017d85033bea0ac77a66b6d9c15f7437e7527
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+transformers==4.31.0
+accelerate
+tiktoken
+einops
+transformers_stream_generator==0.0.4
+scipy
\ No newline at end of file
diff --git a/requirements_web_demo.txt b/requirements_web_demo.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b648f501633c8ac08ca4a6b3499458d1e6c1d805
--- /dev/null
+++ b/requirements_web_demo.txt
@@ -0,0 +1,2 @@
+gradio
+mdtex2html
diff --git a/tech_memo.md b/tech_memo.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c8d73349b3f8474a9a731cee4535efc9e9715a9
--- /dev/null
+++ b/tech_memo.md
@@ -0,0 +1,341 @@
+# Introducing Qwen-7B: Open foundation and human-aligned models (of the state-of-the-arts)
+
+Large language models have recently attracted an extremely large amount of
+attention.
+The boom of [ChatGPT](https://openai.com/blog/chatgpt) rocketed the development of artificial general intelligence and indicates that large language models compress world knowledge into neural networks, and the alignment to human cognition can lead to powerful conversational agents that can provide assistance by interacting with human users.
+Now, the latest version of ChatGPT based on [GPT-4](https://arxiv.org/abs/2303.08774) demonstrates tremendously exciting performance across unlimited capabilities, say, language understanding, logical reasoning, planning, etc., and its incorporation with external tools, including tools and models, releases the power of an agent capable of understanding instructions, executing code, using tools, and so on, to reach the objectives set up by human users.
+
+These significant progresses indicate the importance of large language models as _the foundation of AI services_.
+
+We are happy to release the 7B-parameter models of our large pretrained model series Qwen (abbr. Tongyi Qianwen), Qwen-7B.
+This release includes model weights and codes for pretrained and human-aligned language models of 7B parameters:
+
+- `Qwen-7B` is the pretrained language model, and `Qwen-7B-Chat` is fine-tuned to align with human intent.
+- `Qwen-7B` is pretrained on over 2.2 trillion tokens with a context length of 2048. On the series of benchmarks we tested, Qwen-7B generally performs better than existing open models of similar scales and appears to be on par with some of the larger models.
+- `Qwen-7B-Chat` is fine-tuned on curated data, including not only task-oriented data but also specific security- and service-oriented data, which seems insufficient in existing open models.
+- Example codes for fine-tuning, evaluation, and inference are included. There are also guides on long-context and tool use in inference.
+
+**Goal of release**:
+We believe that while the recent waves of releases of LLMs may have deepened our understanding of model behaviors under standard regimes, it is yet to be revealed how the accompanied techniques of nowadays LLMs, such as 1) quantization and fine-tuning after quantization, 2) training-free long-context inference, and 3) fine-tuning with service-oriented data, including search and tool uses, affect the models as a whole.
+The open release of Qwen-7B marks our first step towards fully understanding the real-world application of such techniques.
+It is our hope that it will enable the community to analyze and continue to improve the safety of those models, striving to establish responsible development and deployment of LLMs.
+
+> **Disclaimer**:
+> We must note that even though the weights and codes are released in an open manner and commercial use is not prohibited, similar to other pretrained language models, Qwen-7B comes with potential risks influenced by complex factors, including but not limited to over-diversified, inaccurate, or misleading generation.
+> Developers and stakeholders should perform their own red teaming and provide related security measures before deployment, and they must abide by and comply with local governance and regulations.
+> In no event shall the authors be held liable for any claim, damages, or other liability arising from the use of the released weights or codes.
+
+The remainder of this document describes our pretraining and fine-tuning methodology.
+
+## Pretraining
+
+Qwen-7B is a transformer-based decoder-only language model with an architecture similar to the [LLaMA](https://github.com/facebookresearch/llama) series of models.
+It is pretrained on over 2.2 trillion tokens with 2048 context length from publicly available data, covering general and professional fields with a focus on the English and Chinese languages.
+
+### Data
+
+**Pretraining data**:
+Our training data includes a mix of data from publicly available sources, consisting mainly of web documents and code files.
+Besides, the data are multilingual, with most of them in English and Chinese.
+We made an effort and employed an ensemble of models to exclude data of low quality or deemed unfit for pretraining, such as NSFW content.
+For math reasoning, we include RFT data from [gsm8k-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel).
+The final data underwent global fuzzy deduplication.
+The mix of pretraining corpora has been optimized through numerous ablation experiments.
+
+**Tokenization**:
+Compared to the current mainstream open models based on Chinese and English vocabularies, we use a vocabulary of 151,851 tokens.
+It first considers efficient encoding of Chinese, English, and code data, and is also more friendly to multilingual languages, enabling users to directly enhance the capability of some languages without expanding the vocabulary.
+It segments numbers by single digits and calls the [tiktoken](https://github.com/openai/tiktoken) tokenizer library for efficient tokenization.
+After tokenization, the data amounts to over 2.2 trillion tokens.
+
+
+
+### Model
+
+**Model architecture**:
+Qwen-7B is built with architecture similar to LLaMA.
+The following are the main differences from the standard transformer: 1) using untied embedding, 2) using rotary positional embedding, 3) no biases except for QKV in attention, 4) RMSNorm instead of LayerNorm, 5) SwiGLU instead of ReLU, and 6) adopting flash attention to accelerate training.
+The model has 32 layers, the embedding dimension is 4096, and the number of attention heads is 32.
+
+**Training details**:
+The model is trained using the AdamW optimizer, with $\beta_1=0.9, \beta_2=0.95, \epsilon=10^{-6}$.
+The sequence length is 2048, and the batch size is 2048, which means each optimization step accumulates over 4 million tokens.
+We use a cosine learning rate schedule, with a warm-up of 2000 steps, a peak learning rate of $3 \times 10^{-4}$, and a minimum learning rate of 10% of the peak learning rate.
+We use a weight decay of 0.1 and gradient clipping of 1.0.
+The training adopts mixed precision training with `bfloat16`.
+
+
+### Evaluation
+
+We report results of Qwen-7B on standard benchmarks.
+
+#### World knowledge
+
+[C-Eval](https://arxiv.org/abs/2305.08322) is a common evaluation benchmark for testing the common-sense capability of pretrained models in Chinese. It covers 52 subjects in four major directions: humanities, social sciences, STEM, and other specialties. According to standard practice, we use the development set samples as the source of few-shot prompts to evaluate the 5-shot validation set and test set accuracy of the Qwen-7B pretrained model.
+
+The accuracy comparison of the Qwen-7B model and other models on the C-Eval validation set is as follows:
+
+| Model | Average |
+| :---------- | -------: |
+| Alpaca-7B | 28.9 |
+| Vicuna-7B | 31.2 |
+| ChatGLM-6B | 37.1 |
+| Baichuan-7B | 42.7 |
+| ChatGLM2-6B | 50.9 |
+| InternLM-7B | 53.4 |
+| ChatGPT | 53.5 |
+| Claude-v1.3 | 55.5 |
+| **Qwen-7B** | **60.8** |
+
+The performance comparison of the Qwen-7B pretrained model and other models on the C-Eval test set is shown in the following table:
+
+| Model | Avg. | Avg. (Hard) | STEM | Social Sciences | Humanities | Others |
+| :---------------------- | -------- | ----------: | ---: | --------------: | ---------: | -----: |
+| ChatGLM-6B | 38.9 | 29.2 | 33.3 | 48.3 | 41.3 | 38.0 |
+| Chinese-Alpaca-Plus-13B | 41.5 | 30.5 | 36.6 | 49.7 | 43.1 | 41.2 |
+| Baichuan-7B | 42.8 | 31.5 | 38.2 | 52.0 | 46.2 | 39.3 |
+| WestlakeLM-19B | 44.6 | 34.9 | 41.6 | 51.0 | 44.3 | 44.5 |
+| AndesLM-13B | 46.0 | 29.7 | 38.1 | 61.0 | 51.0 | 41.9 |
+| BatGPT-15B-sirius | 47.0 | 31.9 | 42.7 | 57.5 | 48.6 | 43.6 |
+| ChatGLM2-6B | 51.7 | 37.1 | 48.6 | 60.5 | 51.3 | 49.8 |
+| InternLM-7B | 52.8 | 37.1 | 48.0 | 67.4 | 55.4 | 45.8 |
+| Baichuan-13B | 53.6 | 36.7 | 47.0 | 66.8 | 57.3 | 49.8 |
+| Claude-v1.3 | 54.2 | 39.0 | 51.9 | 61.7 | 52.1 | 53.7 |
+| ChatGPT | 54.4 | 41.4 | 52.9 | 61.8 | 50.9 | 53.6 |
+| **Qwen-7B** | **59.6** | 41.0 | 52.8 | 74.1 | 63.1 | 55.2 |
+
+As can be seen, Qwen-7B achieves the best performance out of all existing models of similar scale and even surpasses larger-scale models.
+
+MMLU is currently one of the most recognized benchmarks for evaluating English comprehension abilities, covering 57 subtasks across different academic fields and difficulty levels. The MMLU 5-shot accuracy performance of the Qwen-7B is shown in the following table:
+
+| Model | Average | STEM | Social Sciences | Humanities | Others |
+| :----------- | -------: | ---: | --------------: | ---------: | -----: |
+| LLaMA-7B | 35.1 | 30.5 | 38.3 | 34.0 | 38.1 |
+| Baichuan-7B | 42.3 | 35.6 | 48.9 | 38.4 | 48.1 |
+| LLaMA2-7B | 45.3 | 36.4 | 51.2 | 42.9 | 52.2 |
+| LLaMA-13B | 46.9 | 35.8 | 53.8 | 45.0 | 53.3 |
+| ChatGLM2-6B | 47.9 | 41.2 | 54.4 | 43.7 | 54.5 |
+| InternLM-7B | 51.0 | - | - | - | - |
+| Baichuan-13B | 51.6 | 41.6 | 60.9 | 47.4 | 58.5 |
+| LLaMA2-13B | 54.8 | 44.1 | 62.6 | 52.8 | 61.1 |
+| ChatGLM2-12B | 56.2 | 48.2 | 65.1 | 52.6 | 60.9 |
+| **Qwen-7B** | **56.7** | 47.6 | 65.9 | 51.5 | 64.7 |
+
+In terms of English, Qwen-7B also surpasses other similar open pretrained models, and is competitive when compared to larger versions of other models.
+
+#### Coding
+
+We compared the code capabilities of pretrained models on [HumanEval](https://github.com/openai/human-eval), and the results are as follows:
+
+| Model | Pass@1 |
+| :----------- | -------: |
+| Baichuan-7B | 9.2 |
+| ChatGLM2-6B | 9.2 |
+| InternLM-7B | 10.4 |
+| LLaMA-7B | 10.5 |
+| LLaMA2-7B | 12.8 |
+| Baichuan-13B | 12.8 |
+| LLaMA-13B | 15.8 |
+| MPT-7B | 18.3 |
+| LLaMA2-13B | 18.3 |
+| **Qwen-7B** | **24.4** |
+
+#### Math
+
+We compared the math capabilities of pretrained models on [GSM8K](https://github.com/openai/grade-school-math) (8-shot), and the results are as follows:
+
+| Model | Accuracy |
+| :----------- | -------: |
+| MPT-7B | 6.8 |
+| Falcon-7B | 6.8 |
+| Baichuan-7B | 9.7 |
+| LLaMA-7B | 11.0 |
+| LLaMA2-7B | 14.6 |
+| LLaMA-13B | 17.8 |
+| Baichuan-13B | 26.6 |
+| LLaMA2-13B | 28.7 |
+| InternLM-7B | 31.2 |
+| ChatGLM2-6B | 32.4 |
+| ChatGLM2-12B | 40.9 |
+| **Qwen-7B** | **51.6** |
+
+#### Natural language processing
+
+We compared the translation capabilities of pre-trained models on WMT22 zh-en and en-zh (5-shot BLEU), and the results are as follows:
+
+| Model | Average | zh-en | en-zh |
+| :---------- | -------: | -------: | -------: |
+| InternLM-7B | 11.8 | 9.0 | 14.5 |
+| LLaMA-7B | 12.7 | 16.7 | 8.7 |
+| LLaMA-13B | 15.8 | 19.5 | 12.0 |
+| LLaMA2-7B | 19.9 | 21.9 | 17.9 |
+| Bloom-7B | 20.3 | 19.1 | 21.4 |
+| LLaMA2-13B | 23.3 | 22.4 | 24.2 |
+| PolyLM-13B | 23.6 | 20.2 | 27.0 |
+| Baichuan-7B | 24.6 | 22.6 | 26.6 |
+| **Qwen-7B** | **27.5** | **24.3** | **30.6** |
+
+#### Long-context inference
+
+We include support for training-free long-context inference based on ntk-aware interpolation, LogN attention scaling, and local window attention.
+The context can be expanded from 2048 to over 8192.
+The following are the test results on arXiv in terms of perplexity (PPL).
+
+
+
+
Model
Sequence Length
+
+
+
1024
2048
4096
8192
16384
+
+
+
Qwen-7B
4.23
3.78
39.35
469.81
2645.09
+
+
+
+ dynamic_ntk
4.23
3.78
3.59
3.66
5.71
+
+
+
+ dynamic_ntk + logn
4.23
3.78
3.58
3.56
4.62
+
+
+
+ dynamic_ntk + logn + local_attn
4.23
3.78
3.58
3.49
4.32
+
+
+
+## Fine-tuning
+
+`Qwen-7B-Chat` embodies our practice in alignment with human intents, ensuring internalized safety, and building intelligent agents for services.
+
+### Data
+
+**Alignment data**:
+The data includes common instruction-style conversations, and security- and service-oriented data, which involves substantial annotation efforts.
+Instruction data covers broad abilities, such as writing, question answering, brainstorming and planning, content understanding, summarization, natural language processing, and coding.
+Security data tries to prevent the model from generating harmful and inappropriate content.
+Service data tries to enhance the model with specific conversation patterns that can be parsed to invoke and incorporate external systems.
+
+**Data formatting**:
+Since the data consists of conversation turns, we arrange them into texts using the [ChatML](https://github.com/openai/openai-python/blob/main/chatml.md) format, which is a meta language that can describe both the metadata (e.g., roles) and the content of a turn.
+Currently, existing roles include system, user, and assistant.
+
+### Model
+
+**Training details**:
+The causal language modeling objective is used to fine-tune the model, except for the tokens in the content of user's turns.
+The model is trained using the AdamW optimizer, with $\beta_1=0.9, \beta_2=0.95, \epsilon=10^{-6}$.
+The sequence length is limited to 2048, and the batch size is 128.
+The model is trained for 4000 steps, and over the first 1430 steps, the learning rate is warmed up to $1 \times 10^{-5}$.
+We use weight decay of 0.1, dropout of 0.1, and gradient clipping of 1.0.
+
+### Evaluation
+
+Evaluation of human-aligned models is non-trivial and often non-standardized, since such models often target specific applications.
+We evaluate Qwen-7B-Chat from multiple perspectives.
+
+#### World knowledge
+
+As fine-tuning uses a much smaller dataset than pretraining and humans' understanding of world knowledge may be limited, we also evaluate the world knowledge of Qwen-7B-Chat using C-Eval and MMLU in a zero-shot and generative manner.
+
+We demonstrate the zero-shot accuracy of Qwen-7B-Chat on the C-Eval validation set.
+
+| Model | Avg. Acc. |
+| :---------------------- | --------: |
+| LLaMA2-7B-Chat | 31.9 |
+| LLaMA2-13B-Chat | 40.6 |
+| Chinese-Alpaca-2-7B | 41.3 |
+| Chinese-Alpaca-Plus-13B | 43.3 |
+| Baichuan-13B-Chat | 50.4 |
+| ChatGLM2-6B-Chat | 50.7 |
+| InternLM-7B-Chat | 53.2 |
+| **Qwen-7B-Chat** | **54.2** |
+
+The zero-shot accuracy of Qwen-7B-Chat on C-Eval testing set is provided below
+
+| Model | Avg. | STEM | Social Sciences | Humanities | Others |
+| :---------------------- | -------: | ---: | --------------: | ---------: | -----: |
+| Chinese-Alpaca-Plus-13B | 41.5 | 36.6 | 49.7 | 43.1 | 41.2 |
+| Chinese-Alpaca-2-7B | 40.3 | - | - | - | - |
+| ChatGLM2-6B-Chat | 50.1 | 46.4 | 60.4 | 50.6 | 46.9 |
+| Baichuan-13B-Chat | 51.5 | 43.7 | 64.6 | 56.2 | 49.2 |
+| **Qwen-7B-Chat** | **54.6** | 47.8 | 67.6 | 59.3 | 50.6 |
+
+Compared with other models with comparable model sizes, the human-aligned Qwen-7B-Chat performs well in C-Eval accuracy.
+
+The zero-shot accuracy of Qwen-7B-Chat on MMLU is provided below.
+The performance of Qwen-7B-Chat is still on top among other human-aligned models with comparable size.
+
+| Model | Avg. Acc. |
+| :---------------- | --------: |
+| ChatGLM2-6B-Chat | 45.5 |
+| LLaMA2-7B-Chat | 47.0 |
+| InternLM-7B-Chat | 50.8 |
+| Baichuan-13B-Chat | 52.1 |
+| ChatGLM2-12B-Chat | 52.1 |
+| **Qwen-7B-Chat** | **53.9** |
+
+#### Coding
+
+The zero-shot Pass@1 of Qwen-7B-Chat on [HumanEval](https://github.com/openai/human-eval) is demonstrated below
+
+| Model | Pass@1 |
+| :---------------- | -------: |
+| LLaMA2-7B-Chat | 12.2 |
+| InternLM-7B-Chat | 14.0 |
+| Baichuan-13B-Chat | 16.5 |
+| LLaMA2-13B-Chat | 18.9 |
+| **Qwen-7B-Chat** | **24.4** |
+
+#### Math
+
+The accuracy of Qwen-7B-Chat on GSM8K is shown below
+
+| Model | Zero-shot Acc. | 4-shot Acc. |
+| :---------------- | -------------: | ----------: |
+| ChatGLM2-6B-Chat | - | 28.0 |
+| LLaMA2-7B-Chat | 20.4 | 28.2 |
+| LLaMA2-13B-Chat | 29.4 | 36.7 |
+| InternLM-7B-Chat | 32.6 | 34.5 |
+| Baichuan-13B-Chat | - | 36.3 |
+| ChatGLM2-12B-Chat | - | 38.1 |
+| **Qwen-7B-Chat** | **41.1** | **43.5** |
+
+#### Service
+
+LLMs have shown capability in coordinating multiple external systems to achieve the given instructions, which creates new opportunities in traditional online services, the most notable being web search.
+
+Qwen supports calling plugins/tools/APIs through [ReAct Prompting](https://arxiv.org/abs/2210.03629).
+ReAct is also one of the main approaches used by the [LangChain](https://python.langchain.com/) framework.
+For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md).
+In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, Qwen's performance is as follows:
+
+| Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
+| :---------- | --------------------------: | -------------------------: | -------------------------: |
+| GPT-4 | 95% | **0.90** | 15.0% |
+| GPT-3.5 | 85% | 0.88 | 75.0% |
+| **Qwen-7B** | **99%** | 0.89 | **9.7%** |
+
+> The plugins that appear in the evaluation set do not appear in the training set of Qwen.
+> This benchmark evaluates the accuracy of the model in selecting the correct plugin from multiple candidate plugins, the rationality of the parameters passed into the plugin, and the false positive rate.
+> False Positive: Incorrectly invoking a plugin when it should not have been called when responding to a query.
+
+Qwen also has the capability to be used as a [HuggingFace Agent](https://huggingface.co/docs/transformers/transformers_agents).
+Its performance on the benchmark provided by HuggingFace is as follows:
+
+| Model | Tool Selection↑ | Tool Used↑ | Code↑ |
+| :-------------- | -------------------: | --------------: | ---------: |
+| GPT-4 | **100.00** | **100.00** | **97.41** |
+| GPT-3.5 | 95.37 | 96.30 | 87.04 |
+| StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
+| **Qwen-7B** | 90.74 | 92.59 | 74.07 |
+
+## Conclusion
+
+In this document, we describe Qwen-7B, including a pretrained model and a human-aligned model.
+These models have demonstrated exciting performance compared to existing open models of similar or even larger scales.
+As part of our ongoing commitment to the concept of Model as a Service, the release also includes practical pieces such as long context inference and external system integration, which we hope would facilitate developers realizing their own ideas and concepts.
+We believe that the open release of Qwen-7B models would further our understanding of variables and techniques introduced in realistic settings and help to drive progress in this important area together with the community.
diff --git a/tokenization_note.md b/tokenization_note.md
new file mode 100644
index 0000000000000000000000000000000000000000..cb6468d0dab7236cc4c282475975908e77399fb6
--- /dev/null
+++ b/tokenization_note.md
@@ -0,0 +1,127 @@
+# Tokenization
+
+Qwen-7B uses BPE tokenization on UTF-8 bytes using the `tiktoken` package.
+There are two types of tokens in Qwen-7B, i.e., the regular tokens (of type `bytes`) in BPE and the special/control tokens (of type `str`).
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True)
+```
+
+## Regular tokens
+
+The regular tokens are BPE tokens learned from byte sequences of texts encoded using the UTF-8 encoding.
+While this allows tokenization of all texts and no unknown token exists, it may fall back to using single bytes when tokenizing uncommon texts.
+You may encounter UTF-8 decoding errors and as the errors are default to `replace`, thus the replacement character (�) in incomplete generation.
+You can change this behavior by passing `errors="ignore"` to the `decode` function for once or to the `from_pretrained` function forever.
+For more options of `errors`, please refer to [the Python documentation](https://docs.python.org/3/library/stdtypes.html#bytes.decode).
+
+```python
+>>> tokenizer.decode([51461])
+' �'
+
+>>> tokenizer.convert_ids_to_tokens([51461])
+[b' \xe6\xa0']
+
+>>> b' \xe6\xa0'.decode("utf-8", errors='replace')
+' �'
+
+>>> tokenizer.decode([51461, 117])
+' 根'
+
+>>> tokenizer.convert_ids_to_tokens([51461, 117])
+[b' \xe6\xa0', b'\xb9']
+
+>>> b' \xe6\xa0\xb9'.decode("utf-8", errors='replace')
+' 根'
+```
+
+The mapping from regular tokens (in `bytes`) to its ID can be retrieved from `tokenizer.get_vocab()`.
+We do not support or recommended adding regular tokens to the vocabulary.
+
+## Special tokens
+
+The special tokens signify special functions to the model, e.g., reaching the end of a document.
+In theory, they do not exist in the input texts and only appear after the input texts are processed.
+Their surface forms, e.g., `<|endoftext|>` for the end of a document, are only meant for ease of reference.
+Currently, used special tokens are `<|endoftext|>` in Qwen-7B, and `<|endoftext|>`, `<|im_start|>`, and `<|im_end|>` in Qwen-7B-Chat, which means they have determined meanings to the corresponding model, and should not be used otherwise.
+For other purposes, we keep extra special tokens from `<|extra_0|>` to `<|extra_204|>`, and you can use them as you wish.
+The mapping from surface forms of the special tokens (in `str`) to its ID can be retrieved from `tokenizer.special_tokens`.
+
+The concepts of `bos`, `eos`, `unk`, `pad`, `mask`, `sep` and such are not appliable to our pretrained models (Qwen-7B and Qwen-7B-Chat).
+The `pad` token, however, is a different story, as in theory, the model never sees or computes this token, so you may use any known token.
+But to be safe, we limit the value of special tokens specified in the initialization of the tokenizer to the known special tokens.
+You may specify special tokens in fine-tuning or in any other frameworks that necessitate them like this
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True, pad_token='<|endoftext|>')
+```
+
+> WARNING: For our pretrained models, setting `bos`, `eos`, `unk`, and such makes no sense.
+> Unknown behavior may be introduced if you set them without fine-tuning that designates their meanings to the model.
+> Especially, you should not use `<|endoftext|>` as `eos`, unless you are sure that the end of a sentence and the end of a document, which may contain many sentences, are the same in your scenario.
+
+## Injection attack prevention
+
+As special tokens are different from regular tokens, what will happen if the surface forms of a control token appear in the input texts?
+For example, note that a piece of text like this
+
+```
+print("<|endoftext|>")
+```
+
+should be tokenized as
+
+```
+ids:[1350, 9639, 91, 8691, 723, 427, 91, 82598]
+tokens: [b'print', b'("<', b'|', b'endo', b'ft', b'ext', b'|', b'>")']
+```
+
+not
+
+```
+ids: [1350, 445, 151643, 899]
+tokens: [b'print', b'("', '<|endoftext|>', b'")']
+```
+
+Our default used to be the correct one, that is, treating the surface forms of special tokens just like regular texts, and special tokens should be taken cared of by developers after tokenization of the texts.
+However, this conflicts with (albeit unsafe) practice in the community, and adds another step for developers to reuse their wheels.
+
+The default behavior has been changed to parse the surface forms of all the known special tokens as special tokens.
+To enable injection prevention, pass `allowed_special=set()` to the calls of the tokenizer:
+
+```python
+>>> tokenizer('print("<|endoftext|>")', allowed_special=set())
+{'input_ids': [1350, 9639, 91, 8691, 723, 427, 91, 82598], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+You can control the behavior in a fine-grained manner by passing a set of `str` as `allowed_special`
+
+```python
+>>> tokenizer('print("<|extra_0|>")<|endoftext|>', allowed_special={'<|endoftext|>'})
+{'input_ids': [1350, 9639, 91, 15460, 62, 15, 91, 82598, 151643], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+You can also make the tokenizer raise errors if the surface forms of certain special tokens are encountered in the input texts by passing a collection of `str` as `disallowed_special`
+
+```python
+>>> tokenizer('print("<|extra_0|>")<|endoftext|>', allowed_special={'<|endoftext|>'}, disallowed_special=('<|extra_0|>', ))
+...
+ValueError: Encountered text corresponding to disallowed special token '<|extra_0|>'.
+If you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|extra_0|>', ...}`.
+If you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|extra_0|>'})`.
+To disable this check for all special tokens, pass `disallowed_special=()`.
+```
+
+For more information on `allowed_special` and `disallowed_special`, please refer to [the `tiktoken` documentation](https://github.com/openai/tiktoken/blob/095924e02c85617df6889698d94515f91666c7ea/tiktoken/core.py#L75).
+
+The new default is the same as
+
+```python
+>>> tokenizer('print("<|endoftext|>")', allowed_special="all", disallowed_special=())
+{'input_ids': [1350, 445, 151643, 899], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
+```
+
diff --git a/tokenization_note_zh.md b/tokenization_note_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..73580301101447416dd31cc862dd91b494cd89d5
--- /dev/null
+++ b/tokenization_note_zh.md
@@ -0,0 +1,130 @@
+# Tokenization
+
+> 注:作为术语的“tokenization”在中文中尚无共识的概念对应,本文档采用英文表达以利说明。
+
+Qwen-7B采用UTF-8字节级别的BPE tokenization方式,并依赖`tiktoken`这一高效的软件包执行分词。
+Qwen-7B中有两类token,即源于BPE、`bytes`类型的普通token和特殊指定、`str`类型的特殊token。
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True)
+```
+
+## 普通token
+
+普通token源于BPE,是在UTF-8编码的文本字节序列上学习得到的。
+尽管基于字节序列的方式保证了所有文本均可被tokenize且没有未登录token问题,但处理罕见文本时有可能回退到字节级别的编码。
+由于从字节序列解码为文本时,`errors`参数设为`replace`,处理不完整的token序列可能会遇到UTF-8解码错误,表象是生成中包含“替换字符”(�)。
+这一行为可以通过将`errors`参数设为`ignore`来规避。
+一次性修改可以传入tokenizer的`decode`函数,持久性修改可以传入tokenizer的初始化函数,请注意`decode`的配置优先级更高。
+`errors`的可选值,请参阅[Python文档](https://docs.python.org/3/library/stdtypes.html#bytes.decode).
+
+```python
+>>> tokenizer.decode([51461])
+' �'
+
+>>> tokenizer.convert_ids_to_tokens([51461])
+[b' \xe6\xa0']
+
+>>> b' \xe6\xa0'.decode("utf-8", errors='replace')
+' �'
+
+>>> tokenizer.decode([51461, 117])
+' 根'
+
+>>> tokenizer.convert_ids_to_tokens([51461, 117])
+[b' \xe6\xa0', b'\xb9']
+
+>>> b' \xe6\xa0\xb9'.decode("utf-8", errors='replace')
+' 根'
+```
+
+`bytes`类型的普通token到id的映射可以通过`tokenizer.get_vocab()`获取。
+尚不支持也不推荐向tokenizer增加普通token。
+
+## 特殊token
+
+特殊token用以给模型传递特殊信号,如到达文本末尾。
+理论上,输入文本中不包含特殊token,它们仅在tokenization后由开发者手动加入。
+特殊token的字面表达,如表示文本结束的`<|endoftext|>`,仅便于指代特殊token,不意味着它们在输入文本空间中。
+目前,训练中使用的、已经有固定含义的、不应做它用的特殊token,Qwen-7B中有`<|endoftext|>`,Qwen-7B-Chat中有`<|endoftext|>`、`<|im_start|>`以及`<|im_end|>`。
+但词表中也留有供扩展的特殊token位,可用`<|extra_0|>`到`<|extra_204|>`来指代。
+`str`类型的特殊token字面表达到id的映射,可以通过`tokenizer.special_tokens`获取。
+
+对于提供的模型参数(Qwen-7B和Qwen-7B-Chat)而言,诸如`bos`、`eos`、`unk`、`pad`、`mask`、`sep`等的特殊token的概念并不适用。
+特例是`pad`,由于这个token理论上并不参与模型计算,所以可以使用任意token表达这一概念。
+但保险起见,目前可在tokenizer初始化时设定的特殊token,仅可使用已知的特殊token字面表达,即`<|endoftext|>`、`<|im_start|>`、`<|im_end|>`和`<|extra_0|>`到`<|extra_204|>`。
+对于微调或者其它需要这些token才能运行的框架,可以如下配置
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True, pad_token='<|endoftext|>')
+```
+
+> 注意: 对于提供的训练好的模型,设置诸如`bos`、`eos`、`unk`之类的没有意义,即模型不需要这些概念。
+> 如果设置了这些token,但没有相应的微调这些token以让模型理解其含义,未知行为可能被触发。
+> 特别时,不应混淆`<|endoftext|>`和`eos`的概念,除非应用场景中它们的实际含义是一致的,即句子末尾等价于文本末尾。
+
+**注入攻击防御**
+
+由于特殊token和普通token概念上的差异,如果输入文本中含有特殊token的字面表达该如何处理?
+以下面文本为例
+
+```
+print("<|endoftext|>")
+```
+
+其正确的tokenization为
+
+```
+ids:[1350, 9639, 91, 8691, 723, 427, 91, 82598]
+tokens: [b'print', b'("<', b'|', b'endo', b'ft', b'ext', b'|', b'>")']
+```
+
+不是
+
+```
+ids: [1350, 445, 151643, 899]
+tokens: [b'print', b'("', '<|endoftext|>', b'")']
+```
+
+默认行为曾是正确的,即输入文本中任何字符一律按普通token处理,特殊token应由开发者在tokenization人工处理。
+然后,这与社区中的实践似有差异,为开发者复用代码增加了额外适配步骤。
+
+默认行为已被调整为从输入文本中解析特殊token的字面表达。
+如需启用注入攻击防御,请传入参数`allowed_special=set()`:
+
+```python
+>>> tokenizer('print("<|endoftext|>")', allowed_special=set())
+{'input_ids': [1350, 9639, 91, 8691, 723, 427, 91, 82598], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+这一行为可以更精细的调控,将`allowed_special`设计为`str`的集合即可:
+
+```python
+>>> tokenizer('print("<|extra_0|>")<|endoftext|>', allowed_special={'<|endoftext|>'})
+{'input_ids': [1350, 9639, 91, 15460, 62, 15, 91, 82598, 151643], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
+```
+
+如果希望输入中遇到特殊token的字面表达时,获得更直接的提醒,通过配置`disallowed_special`可以让tokenizer直接触发异常:
+
+```python
+>>> tokenizer('print("<|extra_0|>")<|endoftext|>', allowed_special={'<|endoftext|>'}, disallowed_special=('<|extra_0|>', ))
+...
+ValueError: Encountered text corresponding to disallowed special token '<|extra_0|>'.
+If you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|extra_0|>', ...}`.
+If you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|extra_0|>'})`.
+To disable this check for all special tokens, pass `disallowed_special=()`.
+```
+
+更多关于`allowed_special`和`disallowed_special`的信息, 请参阅[`tiktoken`代码](https://github.com/openai/tiktoken/blob/095924e02c85617df6889698d94515f91666c7ea/tiktoken/core.py#L75).
+
+新的默认行为与以下设定等价
+
+```python
+>>> tokenizer('print("<|endoftext|>")', allowed_special="all", disallowed_special=())
+{'input_ids': [1350, 445, 151643, 899], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}
+```
+
diff --git a/web_demo.py b/web_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd25f7f9a9c5a1390c4ea5fbfe8fb3cef58d76eb
--- /dev/null
+++ b/web_demo.py
@@ -0,0 +1,192 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A simple web interactive chat demo based on gradio."""
+
+from argparse import ArgumentParser
+
+import gradio as gr
+import mdtex2html
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'
+
+
+def _get_args():
+ parser = ArgumentParser()
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
+ help="Checkpoint name or path, default to %(default)r")
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
+
+ parser.add_argument("--share", action="store_true", default=False,
+ help="Create a publicly shareable link for the interface.")
+ parser.add_argument("--inbrowser", action="store_true", default=False,
+ help="Automatically launch the interface in a new tab on the default browser.")
+ parser.add_argument("--server-port", type=int, default=8000,
+ help="Demo server port.")
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
+ help="Demo server name.")
+
+ args = parser.parse_args()
+ return args
+
+
+def _load_model_tokenizer(args):
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+
+ if args.cpu_only:
+ device_map = "cpu"
+ else:
+ device_map = "auto"
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint_path,
+ device_map=device_map,
+ trust_remote_code=True,
+ resume_download=True,
+ ).eval()
+ model.generation_config = GenerationConfig.from_pretrained(
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
+ )
+
+ return model, tokenizer
+
+
+def postprocess(self, y):
+ if y is None:
+ return []
+ for i, (message, response) in enumerate(y):
+ y[i] = (
+ None if message is None else mdtex2html.convert(message),
+ None if response is None else mdtex2html.convert(response),
+ )
+ return y
+
+
+gr.Chatbot.postprocess = postprocess
+
+
+def _parse_text(text):
+ lines = text.split("\n")
+ lines = [line for line in lines if line != ""]
+ count = 0
+ for i, line in enumerate(lines):
+ if "```" in line:
+ count += 1
+ items = line.split("`")
+ if count % 2 == 1:
+ lines[i] = f'
'
+ else:
+ lines[i] = f"
"
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", r"\`")
+ line = line.replace("<", "<")
+ line = line.replace(">", ">")
+ line = line.replace(" ", " ")
+ line = line.replace("*", "*")
+ line = line.replace("_", "_")
+ line = line.replace("-", "-")
+ line = line.replace(".", ".")
+ line = line.replace("!", "!")
+ line = line.replace("(", "(")
+ line = line.replace(")", ")")
+ line = line.replace("$", "$")
+ lines[i] = " " + line
+ text = "".join(lines)
+ return text
+
+
+def _launch_demo(args, model, tokenizer):
+
+ def predict(_query, _chatbot, _task_history):
+ print(f"User: {_parse_text(_query)}")
+ _chatbot.append((_parse_text(_query), ""))
+ full_response = ""
+
+ for response in model.chat_stream(tokenizer, _query, history=_task_history):
+ _chatbot[-1] = (_parse_text(_query), _parse_text(response))
+
+ yield _chatbot
+ full_response = _parse_text(response)
+
+ print(f"History: {_task_history}")
+ _task_history.append((_query, full_response))
+ print(f"Qwen-7B-Chat: {_parse_text(full_response)}")
+
+ def regenerate(_chatbot, _task_history):
+ if not _task_history:
+ yield _chatbot
+ return
+ item = _task_history.pop(-1)
+ _chatbot.pop(-1)
+ yield from predict(item[0], _chatbot, _task_history)
+
+ def reset_user_input():
+ return gr.update(value="")
+
+ def reset_state(_task_history):
+ _task_history.clear()
+ return []
+
+ with gr.Blocks() as demo:
+ gr.Markdown("""\
+
""")
+ gr.Markdown("""
Qwen-7B-Chat Bot
""")
+ gr.Markdown(
+ """\
+
This WebUI is based on Qwen-7B-Chat, developed by Alibaba Cloud. \
+(本WebUI基于Qwen-7B-Chat打造,实现聊天机器人功能。)