znskiss committed on
Commit
ade0520
1 Parent(s): 81bb514

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +7 -0
  2. .github/ISSUE_TEMPLATE/bug_report.yaml +88 -0
  3. .github/ISSUE_TEMPLATE/config.yaml +1 -0
  4. .github/ISSUE_TEMPLATE/feature_request.yaml +78 -0
  5. .gitignore +11 -0
  6. FAQ.md +85 -0
  7. FAQ_zh.md +80 -0
  8. LICENSE +53 -0
  9. NOTICE +52 -0
  10. README.md +432 -7
  11. README_CN.md +436 -0
  12. README_JA.md +435 -0
  13. assets/cli_demo.gif +3 -0
  14. assets/hfagent_chat_1.png +3 -0
  15. assets/hfagent_chat_2.png +3 -0
  16. assets/hfagent_run.png +3 -0
  17. assets/logo.jpg +0 -0
  18. assets/openai_api.gif +3 -0
  19. assets/performance.png +0 -0
  20. assets/qwen_tokenizer.png +0 -0
  21. assets/react_showcase_001.png +0 -0
  22. assets/react_showcase_002.png +0 -0
  23. assets/react_tutorial_001.png +0 -0
  24. assets/react_tutorial_002.png +0 -0
  25. assets/tokenizer.pdf +0 -0
  26. assets/tokenizer.png +0 -0
  27. assets/wanx_colorful_black.png +3 -0
  28. assets/web_demo.gif +3 -0
  29. cli_demo.py +194 -0
  30. eval/EVALUATION.md +83 -0
  31. eval/evaluate_ceval.py +263 -0
  32. eval/evaluate_chat_ceval.py +290 -0
  33. eval/evaluate_chat_gsm8k.py +137 -0
  34. eval/evaluate_chat_humaneval.py +82 -0
  35. eval/evaluate_chat_mmlu.py +207 -0
  36. eval/evaluate_cmmlu.py +271 -0
  37. eval/evaluate_gsm8k.py +110 -0
  38. eval/evaluate_humaneval.py +70 -0
  39. eval/evaluate_mmlu.py +218 -0
  40. eval/evaluate_plugin.py +308 -0
  41. eval/gsm8k_prompt.txt +59 -0
  42. examples/langchain_tooluse.ipynb +708 -0
  43. examples/react_demo.py +288 -0
  44. examples/react_prompt.md +249 -0
  45. examples/tokenizer_showcase.ipynb +441 -0
  46. examples/transformers_agent.md +108 -0
  47. openai_api.py +211 -0
  48. requirements.txt +6 -0
  49. requirements_web_demo.txt +2 -0
  50. tech_memo.md +341 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cli_demo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/hfagent_chat_1.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/hfagent_chat_2.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/hfagent_run.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/openai_api.gif filter=lfs diff=lfs merge=lfs -text
41
+ assets/wanx_colorful_black.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/web_demo.gif filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.yaml ADDED
@@ -0,0 +1,88 @@
1
+ name: 🐞 Bug
2
+ description: 提交错误报告 | File a bug/issue
3
+ title: "[BUG] <title>"
4
+ labels: []
5
+ body:
6
+ - type: checkboxes
7
+ attributes:
8
+ label: 是否已有关于该错误的issue或讨论? | Is there an existing issue / discussion for this?
9
+ description: |
10
+ 请先搜索您遇到的错误是否在已有的issues或讨论中提到过。
11
+ Please search to see if an issue / discussion already exists for the bug you encountered.
12
+ [Issues](https://github.com/QwenLM/Qwen-7B/issues)
13
+ [Discussions](https://github.com/QwenLM/Qwen-7B/discussions)
14
+ options:
15
+ - label: 我已经搜索过已有的issues和讨论 | I have searched the existing issues / discussions
16
+ required: true
17
+ - type: checkboxes
18
+ attributes:
19
+ label: 该问题是否在FAQ中有解答? | Is there an existing answer for this in FAQ?
20
+ description: |
21
+ 请先搜索您遇到的错误是否已在FAQ中有相关解答。
22
+ Please search to see if an answer already exists in FAQ for the bug you encountered.
23
+ [FAQ-en](https://github.com/QwenLM/Qwen-7B/blob/main/FAQ.md)
24
+ [FAQ-zh](https://github.com/QwenLM/Qwen-7B/blob/main/FAQ_zh.md)
25
+ options:
26
+ - label: 我已经搜索过FAQ | I have searched FAQ
27
+ required: true
28
+ - type: textarea
29
+ attributes:
30
+ label: 当前行为 | Current Behavior
31
+ description: |
32
+ 准确描述遇到的行为。
33
+ A concise description of what you're experiencing.
34
+ validations:
35
+ required: false
36
+ - type: textarea
37
+ attributes:
38
+ label: 期望行为 | Expected Behavior
39
+ description: |
40
+ 准确描述预期的行为。
41
+ A concise description of what you expected to happen.
42
+ validations:
43
+ required: false
44
+ - type: textarea
45
+ attributes:
46
+ label: 复现方法 | Steps To Reproduce
47
+ description: |
48
+ 复现当前行为的详细步骤。
49
+ Steps to reproduce the behavior.
50
+ placeholder: |
51
+ 1. In this environment...
52
+ 2. With this config...
53
+ 3. Run '...'
54
+ 4. See error...
55
+ validations:
56
+ required: false
57
+ - type: textarea
58
+ attributes:
59
+ label: 运行环境 | Environment
60
+ description: |
61
+ examples:
62
+ - **OS**: Ubuntu 20.04
63
+ - **Python**: 3.8
64
+ - **Transformers**: 4.31.0
65
+ - **PyTorch**: 2.0.1
66
+ - **CUDA**: 11.4
67
+ value: |
68
+ - OS:
69
+ - Python:
70
+ - Transformers:
71
+ - PyTorch:
72
+ - CUDA (`python -c 'import torch; print(torch.version.cuda)'`):
73
+ render: Markdown
74
+ validations:
75
+ required: false
76
+ - type: textarea
77
+ attributes:
78
+ label: 备注 | Anything else?
79
+ description: |
80
+ 您可以在这里补充其他关于该问题背景信息的描述、链接或引用等。
81
+
82
+ 您可以通过点击高亮此区域然后拖动文件的方式上传图片或日志文件。
83
+
84
+ Links? References? Anything that will give us more context about the issue you are encountering!
85
+
86
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
87
+ validations:
88
+ required: false
.github/ISSUE_TEMPLATE/config.yaml ADDED
@@ -0,0 +1 @@
1
+ blank_issues_enabled: true
.github/ISSUE_TEMPLATE/feature_request.yaml ADDED
@@ -0,0 +1,78 @@
1
+ name: "💡 Feature Request"
2
+ description: 创建新功能请求 | Create a new ticket for a new feature request
3
+ title: "💡 [REQUEST] - <title>"
4
+ labels: [
5
+ "question"
6
+ ]
7
+ body:
8
+ - type: input
9
+ id: start_date
10
+ attributes:
11
+ label: "起始日期 | Start Date"
12
+ description: |
13
+ 起始开发日期
14
+ Start of development
15
+ placeholder: "month/day/year"
16
+ validations:
17
+ required: false
18
+ - type: textarea
19
+ id: implementation_pr
20
+ attributes:
21
+ label: "实现PR | Implementation PR"
22
+ description: |
23
+ 实现该功能的Pull request
24
+ Pull request used
25
+ placeholder: "#Pull Request ID"
26
+ validations:
27
+ required: false
28
+ - type: textarea
29
+ id: reference_issues
30
+ attributes:
31
+ label: "相关Issues | Reference Issues"
32
+ description: |
33
+ 与该功能相关的issues
34
+ Common issues
35
+ placeholder: "#Issues IDs"
36
+ validations:
37
+ required: false
38
+ - type: textarea
39
+ id: summary
40
+ attributes:
41
+ label: "摘要 | Summary"
42
+ description: |
43
+ 简要描述新功能的特点
44
+ Provide a brief explanation of the feature
45
+ placeholder: |
46
+ Describe in a few lines your feature request
47
+ validations:
48
+ required: true
49
+ - type: textarea
50
+ id: basic_example
51
+ attributes:
52
+ label: "基本示例 | Basic Example"
53
+ description: Indicate here some basic examples of your feature.
54
+ placeholder: A few specific words about your feature request.
55
+ validations:
56
+ required: true
57
+ - type: textarea
58
+ id: drawbacks
59
+ attributes:
60
+ label: "缺陷 | Drawbacks"
61
+ description: |
62
+ 该新功能有哪些缺陷/可能造成哪些影响?
63
+ What are the drawbacks/impacts of your feature request?
64
+ placeholder: |
65
+ Identify the drawbacks and impacts while being neutral on your feature request
66
+ validations:
67
+ required: true
68
+ - type: textarea
69
+ id: unresolved_question
70
+ attributes:
71
+ label: "未解决问题 | Unresolved questions"
72
+ description: |
73
+ 有哪些尚未解决的问题?
74
+ What questions still remain unresolved?
75
+ placeholder: |
76
+ Identify any unresolved issues.
77
+ validations:
78
+ required: false
.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ __pycache__
2
+ *.so
3
+ build
4
+ .coverage_*
5
+ *.egg-info
6
+ *~
7
+ .vscode/
8
+ .idea/
9
+ .DS_Store
10
+
11
+ /private/
FAQ.md ADDED
@@ -0,0 +1,85 @@
1
+ # FAQ
2
+
3
+ ## Installation & Environment
4
+
5
+ #### Failure in installing flash attention
6
+
7
+ Flash attention is an option for accelerating training and inference. Only NVIDIA GPUs of Turing, Ampere, Ada, and Hopper architecture, e.g., H100, A100, RTX 3090, T4, RTX 2080, can support flash attention. You can use our models without installing it.
8
+
9
+ #### Which version of transformers should I use?
10
+
11
+ 4.31.0 is preferred.
12
+
13
+ #### I downloaded the codes and checkpoints but I can't load the model locally. What should I do?
14
+
15
+ Please check that you have updated the code to the latest version and correctly downloaded all the sharded checkpoint files.
16
+
17
+ #### `qwen.tiktoken` is not found. What is it?
18
+
19
+ This is the merge file of the tokenizer. You have to download it. Note that if you just `git clone` the repo without [git-lfs](https://git-lfs.com), this file will not be downloaded.
20
+
21
+ #### transformers_stream_generator/tiktoken/accelerate not found
22
+
23
+ Run the command `pip install -r requirements.txt`. You can find the file at [https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt](https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt).
24
+ <br><br>
25
+
26
+
27
+
28
+ ## Demo & Inference
29
+
30
+ #### Is there any demo? CLI demo and Web UI demo?
31
+
32
+ Yes, see `web_demo.py` for the web demo and `cli_demo.py` for the CLI demo. See the README for more information.
33
+
34
+
35
+
36
+ #### Can I use CPU only?
37
+
38
+ Yes. Running `python cli_demo.py --cpu_only` loads the model and runs inference on CPU only.
39
+
40
+ #### Can Qwen support streaming?
41
+
42
+ Yes. See the function `chat_stream` in `modeling_qwen.py`.
43
+
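+ Below is a minimal sketch of consuming the stream, assuming `chat_stream` yields progressively longer partial responses; adapt it to your own setup:
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
+
+ printed = ""
+ for partial in model.chat_stream(tokenizer, "你好", history=None):
+     # print only the newly generated suffix of the partial response
+     print(partial[len(printed):], end="", flush=True)
+     printed = partial
+ print()
+ ```
+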
44
+ #### Gibberish in result when using chat_stream().
45
+
46
+ This is because tokens represent bytes and a single token may be a meaningless string. We have updated the default setting of our tokenizer to avoid such decoding results. Please update the code to the latest version.
47
+
48
+ #### It seems that the generation is not related to the instruction...
49
+
50
+ Please check if you are loading Qwen-7B-Chat instead of Qwen-7B. Qwen-7B is the base model without alignment, which behaves differently from the SFT/Chat model.
51
+
52
+ #### Is quantization supported?
53
+
54
+ Yes, the quantization is supported by `bitsandbytes`. We are working on an improved version and will release the quantized model checkpoints.
55
+
56
+ #### Errors in running quantized models: `importlib.metadata.PackageNotFoundError: No package metadata was found for bitsandbytes`
57
+
58
+ For Linux users, running `pip install bitsandbytes` directly can solve the problem. For Windows users, you can run `python -m pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui`.
59
+
60
+ #### Slow when processing long sequences
61
+
62
+ We solved this problem. Updating the code to the latest version can help.
63
+
64
+ #### Unsatisfactory performance in processing long sequences
65
+
66
+ Please ensure that NTK is applied. `use_dynamic_ntk` and `use_logn_attn` in `config.json` should be set to `true` (`true` by default).
67
+ <br><br>
68
+
69
+
70
+
71
+ ## Finetuning
72
+
73
+ #### Can Qwen support SFT or even RLHF?
74
+
75
+ We do not provide finetuning or RLHF code for now. However, some projects have supported finetuning, see [FastChat](https://github.com/lm-sys/FastChat), [Firefly](https://github.com/yangjianxin1/Firefly), [LLaMA Efficient Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning), etc. We will update the relevant code soon.
76
+ <br><br>
77
+
78
+
79
+
80
+ ## Tokenizer
81
+
82
+ #### bos_id/eos_id/pad_id not found
83
+
84
+ In our training, we only use `<|endoftext|>` as the separator and padding token. You can set bos_id, eos_id, and pad_id to tokenizer.eod_id. Learn more from our documentation about the tokenizer.
85
+
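+ For illustration, a minimal sketch of manual right-padding with the eod token (the batching code here is hypothetical, not from the repo):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+ pad_id = tokenizer.eod_id  # id of <|endoftext|>, used as separator and padding during training
+
+ batch = tokenizer(["first query", "a somewhat longer second query"])
+ max_len = max(len(ids) for ids in batch["input_ids"])
+ input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch["input_ids"]]
+ ```
+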
FAQ_zh.md ADDED
@@ -0,0 +1,80 @@
1
+ # FAQ
2
+
3
+ ## 安装&环境
4
+
5
+ #### flash attention 安装失败
6
+
7
+ flash attention是一个用于加速模型训练推理的可选项,且仅适用于Turing、Ampere、Ada、Hopper架构的Nvidia GPU显卡(如H100、A100、RTX 3090、T4、RTX 2080),您可以在不安装flash attention的情况下正常使用模型进行推理。
8
+
9
+ #### 我应该用哪个transformers版本?
10
+
11
+ 建议使用4.31.0。
12
+
13
+ #### 我把模型和代码下到本地,按照教程无法使用,该怎么办?
14
+
15
+ 答:别着急,先检查你的代码是不是更新到最新版本,然后确认你是否完整地将模型checkpoint下到本地。
16
+
17
+ #### `qwen.tiktoken`这个文件找不到,怎么办?
18
+
19
+ 这个是我们的tokenizer的merge文件,你必须下载它才能使用我们的tokenizer。注意,如果你使用git clone却没有使用git-lfs,这个文件不会被下载。如果你不了解git-lfs,可点击[官网](https://git-lfs.com/)了解。
20
+
21
+ #### transformers_stream_generator/tiktoken/accelerate,这几个库提示找不到,怎么办?
22
+
23
+ 运行如下命令:`pip install -r requirements.txt`。相关依赖库在[https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt](https://github.com/QwenLM/Qwen-7B/blob/main/requirements.txt) 可以找到。
24
+ <br><br>
25
+
26
+
27
+ ## Demo & 推理
28
+
29
+ #### 是否提供Demo?CLI Demo及Web UI Demo?
30
+
31
+ `web_demo.py`和`cli_demo.py`分别提供了Web UI以及CLI的Demo。请查看README相关内容了解更多。
32
+
33
+ #### 我没有GPU,只用CPU运行CLI demo可以吗?
34
+
35
+ 可以的,运行`python cli_demo.py --cpu_only`命令即可将模型读取到CPU并使用CPU进行推理。
36
+
37
+ #### Qwen支持流式推理吗?
38
+
39
+ Qwen当前支持流式推理。见位于`modeling_qwen.py`的`chat_stream`函数。
40
+
41
+ #### 使用`chat_stream()`生成混乱的内容及乱码,为什么?
42
+
43
+ 这是由于模型生成过程中输出的部分token需要与后续token一起解码才能输出正常文本,单个token解码结果是无意义字符串,我们已经更新了tokenizer解码时的默认设置,避免这些字符串在生成结果中出现,如果仍有类似问题请更新模型至最新版本。
44
+
45
+ #### 模型的输出看起来与输入无关/没有遵循指令/看起来呆呆的
46
+
47
+ 请检查是否加载的是Qwen-7B-Chat模型进行推理,Qwen-7B模型是未经align的预训练基模型,不期望具备响应用户指令的能力。我们在模型最新版本已经对`chat`及`chat_stream`接口内进行了检查,避免您误将预训练模型作为SFT/Chat模型使用。
48
+
49
+ #### 是否有量化版本模型
50
+
51
+ 目前Qwen支持基于`bitsandbytes`的8-bit和4-bit的量化推理。后续我们将进一步更新提供更加高效的量化推理实现,并提供对应的量化模型。
52
+
53
+ #### 运行量化推理报错:`importlib.metadata.PackageNotFoundError: No package metadata was found for bitsandbytes`
54
+
55
+ 对于Linux用户,直接`pip install bitsandbytes`即可。对于Windows用户,可以运行`python -m pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui`。
56
+
57
+ #### 生成序列较长后速度显著变慢
58
+
59
+ 这一问题已经在最新版本中修复。请更新到最新代码。
60
+
61
+ #### 处理长序列时效果有问题
62
+
63
+ 请确认是否开启NTK。若要启用这些技巧,请将`config.json`里的`use_dynamic_ntk`和`use_logn_attn`设置为`true`。最新代码默认为`true`。
64
+ <br><br>
65
+
66
+
67
+ ## 微调
68
+
69
+ #### 当前是否支持SFT和RLHF?
70
+
71
+ 我们目前未提供SFT和RLHF代码。当前有多个外部项目已实现支持,如[FastChat](https://github.com/lm-sys/FastChat)、[Firefly](https://github.com/yangjianxin1/Firefly)、[LLaMA Efficient Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning)等。我们会尽快更新这部分代码和说明。
72
+ <br><br>
73
+
74
+
75
+ ## Tokenizer
76
+
77
+ #### bos_id/eos_id/pad_id,这些token id不存在,为什么?
78
+
79
+ 在训练过程中,我们仅使用<|endoftext|>这一token作为sample/document之间的分隔符及padding位置占位符,你可以将bos_id, eos_id, pad_id均指向tokenizer.eod_id。请阅读我们关于tokenizer的文档,了解如何设置这些id。
80
+
LICENSE ADDED
@@ -0,0 +1,53 @@
1
+ Tongyi Qianwen LICENSE AGREEMENT
2
+
3
+ Tongyi Qianwen Release Date: August 3, 2023
4
+
5
+ By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
6
+
7
+ 1. Definitions
8
+ a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
9
+ b. "We"(or "Us") shall mean Alibaba Cloud.
10
+ c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
11
+ d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
12
+ e. "Tongyi Qianwen" shall mean the large language models (including Qwen-7B model and Qwen-7B-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
13
+ f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
14
+ g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
15
+ h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
16
+ and conversions to other media types.
17
+
18
+ 2. Grant of Rights
19
+ You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.
20
+
21
+ 3. Redistribution
22
+ You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
23
+ a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
24
+ b. You shall cause any modified files to carry prominent notices stating that You changed the files;
25
+ c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
26
+ d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.
27
+
28
+ 4. Restrictions
29
+ If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.
30
+
31
+ 5. Rules of use
32
+ a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
33
+ b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).
34
+
35
+ 6. Intellectual Property
36
+ a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
37
+ b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
38
+ c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.
39
+
40
+ 7. Disclaimer of Warranty and Limitation of Liability
41
+
42
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
43
+ b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
44
+ c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
45
+ d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
46
+
47
+ 8. Survival and Termination.
48
+ a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
49
+ b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.
50
+
51
+ 9. Governing Law and Jurisdiction.
52
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
53
+ b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
NOTICE ADDED
@@ -0,0 +1,52 @@
1
+ ------------- LICENSE FOR NVIDIA Megatron-LM code --------------
2
+
3
+ Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions
7
+ are met:
8
+ * Redistributions of source code must retain the above copyright
9
+ notice, this list of conditions and the following disclaimer.
10
+ * Redistributions in binary form must reproduce the above copyright
11
+ notice, this list of conditions and the following disclaimer in the
12
+ documentation and/or other materials provided with the distribution.
13
+ * Neither the name of NVIDIA CORPORATION nor the names of its
14
+ contributors may be used to endorse or promote products derived
15
+ from this software without specific prior written permission.
16
+
17
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18
+ EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20
+ PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21
+ CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25
+ OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+
30
+ ------------- LICENSE FOR OpenAI tiktoken code --------------
31
+
32
+ MIT License
33
+
34
+ Copyright (c) 2022 OpenAI, Shantanu Jain
35
+
36
+ Permission is hereby granted, free of charge, to any person obtaining a copy
37
+ of this software and associated documentation files (the "Software"), to deal
38
+ in the Software without restriction, including without limitation the rights
39
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
40
+ copies of the Software, and to permit persons to whom the Software is
41
+ furnished to do so, subject to the following conditions:
42
+
43
+ The above copyright notice and this permission notice shall be included in all
44
+ copies or substantial portions of the Software.
45
+
46
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
47
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
48
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
49
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
50
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
51
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
52
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,437 @@
1
  ---
2
- title: Qwen 7B Main
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 3.40.1
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Qwen-7B-main
3
+ app_file: web_demo.py
4
  sdk: gradio
5
  sdk_version: 3.40.1
6
  ---
7
+ <br>
8
+
9
+ <p align="center">
10
+ <img src="assets/logo.jpg" width="400"/>
11
+ <p>
12
+ <br>
13
+
14
+ <p align="center">
15
+ Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 <a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 <a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp | &nbsp<a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>&nbsp | &nbsp<a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/9bjvspyu">Discord</a>
16
+ </p>
17
+ <br>
18
+
19
+ <p align="center">
20
+ <a href="README_CN.md">中文</a>&nbsp | &nbspEnglish&nbsp | &nbsp<a href="README_JA.md">日本語</a>
21
+ </p>
22
+ <br><br>
23
+
24
+ We open-source **Qwen-7B** and **Qwen-7B-Chat** on both **🤖 ModelScope** and **🤗 Hugging Face** (click the logos above to visit the repos with code and checkpoints). This repo includes a brief introduction to Qwen-7B, usage guidance, and a technical memo ([link](tech_memo.md)) that provides more information.
25
+
26
+ Qwen-7B is the 7B-parameter version of the large language model series, Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-7B is a Transformer-based large language model, which is pretrained on a large volume of data, including web texts, books, code, etc. Additionally, based on the pretrained Qwen-7B, we release Qwen-7B-Chat, a large-model-based AI assistant, which is trained with alignment techniques. The features of the Qwen-7B series include:
27
+
28
+ 1. **Trained with high-quality pretraining data**. We have pretrained Qwen-7B on a self-constructed large-scale high-quality dataset of over 2.2 trillion tokens. The dataset includes plain text and code, and it covers a wide range of domains, including general-domain and professional-domain data.
29
+ 2. **Strong performance**. In comparison with models of a similar size, Qwen-7B outperforms the competitors on a series of benchmark datasets that evaluate natural language understanding, mathematics, coding, etc.
30
+ 3. **Better support of languages**. Our tokenizer, based on a large vocabulary of over 150K tokens, is more efficient than many alternatives. It is friendly to many languages, which makes it easier for users to further finetune Qwen-7B to extend its understanding of a particular language.
31
+ 4. **Support of 8K Context Length**. Both Qwen-7B and Qwen-7B-Chat support the context length of 8K, which allows inputs with long contexts.
32
+ 5. **Support of Plugins**. Qwen-7B-Chat is trained with plugin-related alignment data, and thus it is capable of using tools, including APIs, models, databases, etc., and of acting as an agent.
33
+
34
+ The following sections include information that you might find helpful. Specifically, we advise you to read the FAQ section before you open issues.
35
+
36
+ ## News
37
+
38
+ * 2023.8.3 We release both Qwen-7B and Qwen-7B-Chat on ModelScope and Hugging Face. We also provide a technical memo for more details about the model, including training details and model performance.
39
+
40
+ ## Performance
41
+
42
+ In general, Qwen-7B outperforms the baseline models of a similar size, and even outperforms larger models of around 13B parameters, on a series of benchmark datasets, e.g., MMLU, C-Eval, GSM8K, HumanEval, WMT22, CMMLU, etc., which evaluate the models' capabilities in natural language understanding, mathematical problem solving, coding, etc. See the results below.
43
+
44
+ | Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
45
+ | :---------------- | :------------: | :------------: | :------------: | :------------: | :------------: |:------------: |
46
+ | LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
47
+ | LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
48
+ | Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
49
+ | ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
50
+ | InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
51
+ | Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
52
+ | LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
53
+ | LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
54
+ | ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
55
+ | **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
56
+
57
+ <p align="center">
58
+ <img src="assets/performance.png" width="1000"/>
59
+ <p>
60
+ <br>
61
+
62
+ Additionally, according to the third-party evaluation of large language models conducted by [OpenCompass](https://opencompass.org.cn/leaderboard-llm), Qwen-7B and Qwen-7B-Chat are the top 7B-parameter models. This evaluation consists of a large number of public benchmarks for the evaluation of language understanding and generation, coding, mathematics, reasoning, etc.
63
+
64
+ For more experimental results (detailed model performance on more benchmark datasets) and details, please refer to our technical memo by clicking [here](tech_memo.md).
65
+
66
+ ## Requirements
67
+
68
+ * python 3.8 and above
69
+ * pytorch 1.12 and above, 2.0 and above are recommended
70
+ * CUDA 11.4 and above are recommended (this is for GPU users, flash-attention users, etc.)
71
+
72
+ ## Quickstart
73
+
74
+ Below, we provide simple examples to show how to use Qwen-7B with 🤖 ModelScope and 🤗 Transformers.
75
+
76
+ Before running the code, make sure you have set up the environment and installed the required packages. Make sure you meet the above requirements, and then install the dependent libraries.
77
+
78
+ ```bash
79
+ pip install -r requirements.txt
80
+ ```
81
+
82
+ If your device supports fp16 or bf16, we recommend installing [flash-attention](https://github.com/Dao-AILab/flash-attention) for higher efficiency and lower memory usage. (**flash-attention is optional and the project can run normally without installing it**)
83
+
84
+ ```bash
85
+ git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
86
+ cd flash-attention && pip install .
87
+ # Below are optional. Installing them might be slow.
88
+ # pip install csrc/layer_norm
89
+ # pip install csrc/rotary
90
+ ```
91
+
92
+ Now you can start with ModelScope or Transformers.
93
+
94
+ #### 🤗 Transformers
95
+
96
+ To use Qwen-7B-Chat for inference, all you need to do is input a few lines of code as demonstrated below. However, **please make sure that you are using the latest code.**
97
+
98
+ ```python
99
+ from transformers import AutoModelForCausalLM, AutoTokenizer
100
+ from transformers.generation import GenerationConfig
101
+
102
+ # Note: The default behavior now has injection attack prevention off.
103
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
104
+
105
+ # use bf16
106
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
107
+ # use fp16
108
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
109
+ # use cpu only
110
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
111
+ # use auto mode, automatically select precision based on the device.
112
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
113
+
114
+ # Specify hyperparameters for generation
115
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
116
+
117
+ # 第一轮对话 1st dialogue turn
118
+ response, history = model.chat(tokenizer, "你好", history=None)
119
+ print(response)
120
+ # 你好!很高兴为你提供帮助。
121
+
122
+ # 第二轮对话 2nd dialogue turn
123
+ response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
124
+ print(response)
125
+ # 这是一个关于一个年轻人奋斗创业最终取得成功的故事。
126
+ # 故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。从小,李明就立下了一个目标:要成为一名成功的企业家。
127
+ # 为了实现这个目标,李明勤奋学习,考上了大学。在大学期间,他积极参加各种创业比赛,获得了不少奖项。他还利用课余时间去实习,积累了宝贵的经验。
128
+ # 毕业后,李明决定开始自己的创业之路。他开始寻找投资机会,但多次都被拒绝了。然而,他并没有放弃。他继续努力,不断改进自己的创业计划,并寻找新的投资机会。
129
+ # 最终,李明成功地获得了一笔投资,开始了自己的创业之路。他成立了一家科技公司,专注于开发新型软件。在他的领导下,公司迅速发展起来,成为了一家成功的科技企业。
130
+ # 李明的成功并不是偶然的。他勤奋、坚韧、勇于冒险,不断学习和改进自己。他的成功也证明了,只要努力奋斗,任何人都有可能取得成功。
131
+
132
+ # 第三轮对话 3rd dialogue turn
133
+ response, history = model.chat(tokenizer, "给这个故事起一个标题", history=history)
134
+ print(response)
135
+ # 《奋斗创业:一个年轻人的成功之路》
136
+ ```
137
+
138
+ Running the Qwen-7B pretrained base model is also simple.
139
+
140
+ <details>
141
+ <summary>Running Qwen-7B</summary>
142
+
143
+ ```python
144
+ from transformers import AutoModelForCausalLM, AutoTokenizer
145
+ from transformers.generation import GenerationConfig
146
+
147
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
148
+ # use bf16
149
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, bf16=True).eval()
150
+ # use fp16
151
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, fp16=True).eval()
152
+ # use cpu only
153
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="cpu", trust_remote_code=True).eval()
154
+ # use auto mode, automatically select precision based on the device.
155
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True).eval()
156
+
157
+ # Specify hyperparameters for generation
158
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
159
+
160
+ inputs = tokenizer('蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是', return_tensors='pt')
161
+ inputs = inputs.to(model.device)
162
+ pred = model.generate(**inputs)
163
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
164
+ # 蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是亚的斯亚贝巴(Addis Ababa)...
165
+ ```
166
+
167
+ </details>
168
+
169
+ #### 🤖 ModelScope
170
+
171
+ ModelScope is an open-source platform for Model-as-a-Service (MaaS), which provides flexible and cost-effective model services to AI developers. Similarly, you can run the models with ModelScope as shown below:
172
+
173
+ ```python
174
+ import os
175
+ from modelscope.pipelines import pipeline
176
+ from modelscope.utils.constant import Tasks
177
+ from modelscope import snapshot_download
178
+
179
+ model_id = 'QWen/qwen-7b-chat'
180
+ revision = 'v1.0.0'
181
+
182
+ model_dir = snapshot_download(model_id, revision)
183
+
184
+ pipe = pipeline(
185
+ task=Tasks.chat, model=model_dir, device_map='auto')
186
+ history = None
187
+
188
+ text = '浙江的省会在哪里?'
189
+ results = pipe(text, history=history)
190
+ response, history = results['response'], results['history']
191
+ print(f'Response: {response}')
192
+ text = '它有什么好玩的地方呢?'
193
+ results = pipe(text, history=history)
194
+ response, history = results['response'], results['history']
195
+ print(f'Response: {response}')
196
+ ```
197
+
198
+ ## Tokenizer
199
+
200
+ Our tokenizer, based on tiktoken, is different from other tokenizers, e.g., the sentencepiece tokenizer. You need to pay attention to special tokens, especially in finetuning. For more detailed information on the tokenizer and its use in finetuning, please refer to the [documentation](tokenization_note.md).
201
+
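+ As a quick sanity check, the sketch below encodes a few strings and prints the token counts (a simple illustration of the tokenizer's behavior, not part of the official examples):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+ for text in ["Hello, world!", "你好,世界!", "def add(a, b): return a + b"]:
+     ids = tokenizer.encode(text)
+     print(len(ids), ids)
+ print(tokenizer.decode(tokenizer.encode("Hello, world!")))
+ ```
+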
202
+ ## Quantization
203
+
204
+ We provide examples to show how to load models in `NF4` and `Int8`. For starters, make sure you have installed `bitsandbytes`. Note that the requirements for `bitsandbytes` are:
205
+
206
+ ```
207
+ **Requirements** Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
208
+ ```
209
+
210
+ Then run the following command to install `bitsandbytes`:
211
+
212
+ ```
213
+ pip install bitsandbytes
214
+ ```
215
+
216
+ Windows users should find another option, which might be [bitsandbytes-windows-webui](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels).
217
+
218
+ Then you only need to add your quantization configuration to `AutoModelForCausalLM.from_pretrained`. See the example below:
219
+
220
+ ```python
221
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
222
+
223
+ # quantization configuration for NF4 (4 bits)
224
+ quantization_config = BitsAndBytesConfig(
225
+ load_in_4bit=True,
226
+ bnb_4bit_quant_type='nf4',
227
+ bnb_4bit_compute_dtype=torch.bfloat16
228
+ )
229
+
230
+ # quantization configuration for Int8 (8 bits)
231
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
232
+
233
+ model = AutoModelForCausalLM.from_pretrained(
234
+ args.checkpoint_path,  # e.g. "Qwen/Qwen-7B-Chat" or a local checkpoint directory
235
+ device_map="cuda:0",
236
+ quantization_config=quantization_config,
237
+ max_memory=max_memory,  # optional dict capping per-device memory, e.g. {0: "20GiB"}
238
+ trust_remote_code=True,
239
+ ).eval()
240
+ ```
241
+
242
+ With this method, you can load Qwen-7B in `NF4` or `Int8`, which reduces memory usage. We provide the related model performance statistics below. We find that quantization slightly degrades effectiveness but significantly reduces memory costs.
243
+
244
+ | Precision | MMLU | GPU Memory for Loading Model |
245
+ | ----------- | :------: | :---------------------------: |
246
+ | BF16 | 56.7 | 16.38G |
247
+ | Int8 | 52.8 | 10.44G |
248
+ | NF4 | 48.9 | 7.79G |
249
+
250
+ Note: The GPU memory usage profiling in the above table is performed on a single A100-SXM4-80G GPU, with PyTorch 2.0.1, CUDA 11.8, and flash attention used.
251
+
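+ A minimal sketch of how such a loading-memory figure can be measured (an illustration, not the exact script we used for the table):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ torch.cuda.reset_peak_memory_stats()
+ model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen-7B",
+     device_map="cuda:0",
+     quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # or the NF4 config above
+     trust_remote_code=True,
+ ).eval()
+ print(f"Peak GPU memory after loading: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+ ```
+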
252
+ ## Inference Efficiency
253
+
254
+ ### Inference Speed
255
+
256
+ We measured the average inference speed of generating 2K tokens under BF16 precision and Int8 or NF4 quantization levels, respectively.
257
+
258
+ | Quantization Level | Inference Speed with flash_attn (tokens/s) | Inference Speed w/o flash_attn (tokens/s) |
259
+ | ---------------------- | :----------------------------------------: | :---------------------------------------: |
260
+ | BF16 (no quantization) | 30.06 | 27.55 |
261
+ | Int8 (bnb) | 7.94 | 7.86 |
262
+ | NF4 (bnb) | 21.43 | 20.37 |
263
+
264
+ In detail, the profiling setting is generating 2048 new tokens with 1 context token. The profiling runs on a single A100-SXM4-80G GPU with PyTorch 2.0.1 and CUDA 11.8. The inference speed is averaged over the 2048 generated tokens.
265
+
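+ A minimal sketch of this kind of timing loop (an illustration of the setting above, not the profiling script we used):
+
+ ```python
+ import time
+ import torch
+
+ # assumes `model` and `tokenizer` are already loaded as in the Quickstart section
+ inputs = tokenizer("蒙古国的首都是", return_tensors="pt").to(model.device)
+ torch.cuda.synchronize()
+ start = time.time()
+ out = model.generate(**inputs, max_new_tokens=2048, min_new_tokens=2048, do_sample=False)
+ torch.cuda.synchronize()
+ new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
+ print(f"{new_tokens / (time.time() - start):.2f} tokens/s")
+ ```
+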
266
+ ### GPU Memory Usage
267
+
268
+ We also profile the peak GPU memory usage for encoding 2048 tokens as context (and generating a single token) and for generating 8192 tokens (with a single token as context) under BF16 or Int8/NF4 quantization levels, respectively. The results are shown below.
269
+
270
+ When using flash attention, the memory usage is:
271
+
272
+ | Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
273
+ | ------------------ | :---------------------------------: | :-----------------------------------: |
274
+ | BF16 | 18.11GB | 23.52GB |
275
+ | Int8 | 12.17GB | 17.60GB |
276
+ | NF4 | 9.52GB | 14.93GB |
277
+
278
+ When not using flash attention, the memory usage is:
279
+
280
+ | Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
281
+ | ------------------ | :---------------------------------: | :-----------------------------------: |
282
+ | BF16 | 18.11GB | 24.40GB |
283
+ | Int8 | 12.18GB | 18.47GB |
284
+ | NF4 | 9.52GB | 15.81GB |
285
+
286
+ The above speed and memory profiling are conducted using [this script](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py).
287
+
288
+ ## Demo
289
+
290
+
291
+ ### Web UI
292
+
293
+ We provide code for users to build a web UI demo (thanks to @wysaid). Before you start, make sure you install the following packages:
294
+
295
+ ```
296
+ pip install -r requirements_web_demo.txt
297
+ ```
298
+
299
+ Then run the command below and click on the generated link:
300
+
301
+ ```
302
+ python web_demo.py
303
+ ```
304
+
305
+ <p align="center">
306
+ <br>
307
+ <img src="assets/web_demo.gif" width="600" />
308
+ <br>
309
+ <p>
310
+
311
+ ### CLI Demo
312
+
313
+ We provide a CLI demo example in `cli_demo.py`, which supports streaming output for the generation. Users can interact with Qwen-7B-Chat by inputting prompts, and the model returns outputs in streaming mode. Run the command below:
314
+
315
+ ```
316
+ python cli_demo.py
317
+ ```
318
+
319
+ <p align="center">
320
+ <br>
321
+ <img src="assets/cli_demo.gif" width="600" />
322
+ <br>
323
+ <p>
324
+
325
+ ## API
326
+
327
+ We provide methods to deploy a local API based on the OpenAI API (thanks to @hanpenggit). Before you start, install the required packages:
328
+
329
+ ```bash
330
+ pip install fastapi uvicorn openai pydantic sse_starlette
331
+ ```
332
+
333
+ Then run the command to deploy your API:
334
+
335
+ ```bash
336
+ python openai_api.py
337
+ ```
338
+
339
+ You can change your arguments, e.g., `-c` for checkpoint name or path, `--cpu-only` for CPU deployment, etc. If you meet problems launching your API deployment, updating the packages to the latest version can probably solve them.
340
+
341
+ Using the API is also simple. See the example below:
342
+
343
+ ```python
344
+ import openai
345
+ openai.api_base = "http://localhost:8000/v1"
346
+ openai.api_key = "none"
347
+
348
+ # create a request activating streaming response
349
+ for chunk in openai.ChatCompletion.create(
350
+ model="Qwen-7B",
351
+ messages=[
352
+ {"role": "user", "content": "你好"}
353
+ ],
354
+ stream=True
355
+ ):
356
+ if hasattr(chunk.choices[0].delta, "content"):
357
+ print(chunk.choices[0].delta.content, end="", flush=True)
358
+
359
+ # create a request not activating streaming response
360
+ response = openai.ChatCompletion.create(
361
+ model="Qwen-7B",
362
+ messages=[
363
+ {"role": "user", "content": "你好"}
364
+ ],
365
+ stream=False
366
+ )
367
+ print(response.choices[0].message.content)
368
+ ```
369
+
370
+ <p align="center">
371
+ <br>
372
+ <img src="assets/openai_api.gif" width="600" />
373
+ <br>
374
+ <p>
375
+
376
+ ## Tool Usage
377
+
378
+ Qwen-7B-Chat is specifically optimized for tool usage, including APIs, databases, models, etc., so that users can build their own Qwen-7B-based LangChain applications, agents, and code interpreters. In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool-usage capabilities, we find that Qwen-7B reaches stable performance.
379
+
380
+ | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
381
+ |:------------|:----------------------:|:----------------------:|:----------------------:|
382
+ | GPT-4 | 95% | **0.90** | 15% |
383
+ | GPT-3.5 | 85% | 0.88 | 75% |
384
+ | **Qwen-7B** | **99%** | 0.89 | **9.7%** |
385
+
386
+ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md). The use of tools can enable the model to better perform tasks.
387
+
388
+ Additionally, we provide experimental results to show its capability of acting as an agent. See [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) for more information. Its performance on the run-mode benchmark provided by Hugging Face is as follows:
389
+
390
+ | Model | Tool Selection↑ | Tool Used↑ | Code↑ |
391
+ |:---------------|:---------------:|:-----------:|:---------:|
392
+ |GPT-4 | **100** | **100** | **97.41** |
393
+ |GPT-3.5 | 95.37 | 96.30 | 87.04 |
394
+ |StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
395
+ | **Qwen-7B** | 90.74 | 92.59 | 74.07 |
396
+
397
+ ## Long-Context Understanding
398
+
399
+ To break the bottleneck of the training sequence length, we introduce several techniques, including NTK-aware interpolation, window attention, and LogN attention scaling, to extend the context length to over 8K tokens. We conduct language modeling experiments on the arXiv dataset with the PPL evaluation and find that Qwen-7B reaches outstanding performance in the long-context scenario. Results are demonstrated below:
400
+
401
+ <table>
402
+ <tr>
403
+ <th rowspan="2">Model</th><th colspan="5" align="center">Sequence Length</th>
404
+ </tr>
405
+ <tr>
406
+ <th align="center">1024</th><th align="center">2048</th><th align="center">4096</th><th align="center">8192</th><th align="center">16384</th>
407
+ </tr>
408
+ <tr>
409
+ <td>Qwen-7B</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">39.35</td><td align="center">469.81</td><td align="center">2645.09</td>
410
+ </tr>
411
+ <tr>
412
+ <td>+ dynamic_ntk</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">3.59</td><td align="center">3.66</td><td align="center">5.71</td>
413
+ </tr>
414
+ <tr>
415
+ <td>+ dynamic_ntk + logn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center">3.56</td><td align="center">4.62</td>
416
+ </tr>
417
+ <tr>
418
+ <td>+ dynamic_ntk + logn + window_attn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center"><b>3.49</b></td><td align="center"><b>4.32</b></td>
419
+ </tr>
420
+ </table>
421
+
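+ A minimal sketch of toggling these options when loading the model, assuming the remote code reads `use_dynamic_ntk` and `use_logn_attn` from the config as described in the FAQ:
+
+ ```python
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ config = AutoConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+ # Both flags default to true in recent checkpoints; set them explicitly here.
+ config.use_dynamic_ntk = True
+ config.use_logn_attn = True
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen-7B", config=config, device_map="auto", trust_remote_code=True
+ ).eval()
+ ```
+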
422
+ ## Reproduction
423
+
424
+ To reproduce the model performance on benchmark datasets, we provide scripts; check [eval/EVALUATION.md](eval/EVALUATION.md) for more information. Note that the reproduction may lead to slight differences from our reported results.
425
+
426
+ ## FAQ
427
+
428
+ If you meet problems, please refer to the [FAQ](FAQ.md) and existing issues first to search for a solution before you open a new issue.
429
+
430
+ ## License Agreement
431
+
432
+ Researchers and developers are free to use the code and model weights of both Qwen-7B and Qwen-7B-Chat. We also allow their commercial use. Check our license at [LICENSE](LICENSE) for more details. If you have requirements for commercial use, please fill out the [form](https://dashscope.console.aliyun.com/openModelApply/qianwen) to apply.
433
+
434
+ ## Contact Us
435
+
436
+ If you would like to leave a message for either our research team or our product team, feel free to send an email to qianwen_opensource@alibabacloud.com.
437
 
 
README_CN.md ADDED
@@ -0,0 +1,436 @@
1
+ <br>
2
+
3
+ <p align="center">
4
+ <img src="assets/logo.jpg" width="400"/>
5
+ <p>
6
+ <br>
7
+
8
+ <p align="center">
9
+ Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 <a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 <a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp | &nbsp<a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>&nbsp | &nbsp<a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/9bjvspyu">Discord</a>
10
+ </p>
11
+ <br>
12
+
13
+ <p align="center">
14
+ 中文</a>&nbsp | &nbsp<a href="README.md">English</a>&nbsp | &nbsp<a href="README_JA.md">日本語</a>
15
+ </p>
16
+ <br><br>
17
+
18
+ 我们在🤖 **ModelScope**以及🤗 **Hugging Face**均开源了**Qwen-7B**系列模型。请在本文档顶部点击相关链接查看仓库信息。本仓库主要包括Qwen-7B的简介、使用指南、技术备忘等内容。想了解更多关于模型的信息,请点击[链接](tech_memo.md)查看我们的技术备忘录。
19
+
20
+ 通义千问-7B(Qwen-7B) 是阿里云研发的通义千问大模型系列的70亿参数规模的模型。Qwen-7B是基于Transformer的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样,覆盖广泛,包括大量网络文本、专业书籍、代码等。同时,在Qwen-7B的基础上,我们使用对齐机制打造了基于大语言模型的AI助手Qwen-7B-Chat。Qwen-7B系列模型的特点包括:
21
+
22
+ 1. **大规模高质量预训练数据**:我们使用了超过2.2万亿token的自建大规模预训练数据集进行语言模型的预训练。数据集包括文本和代码等多种数据类型,覆盖通用领域和专业领域。
23
+ 2. **优秀的模型性能**:相比同规模的开源模型,Qwen-7B在多个评测数据集上具有显著优势,甚至超出12-13B等更大规模的模型。评测评估的能力范围包括自然语言理解与生成、数学运算解题、代码生成等。
24
+ 3. **更好地支持多语言**:基于更大词表的分词器在分词上更高效,同时它对其他语言表现更加友好。用户可以在Qwen-7B的基础上更方便地训练特定语言的7B语言模型。
25
+ 4. **8K的上下文长度**:Qwen-7B及Qwen-7B-Chat均能支持8K的上下文长度, 允许用户输入更长的prompt。
26
+ 5. **支持插件调用**:Qwen-7B-Chat针对插件调用相关的对齐数据做了特定优化,当前模型能有效调用插件以及升级为Agent。
27
+
28
+ 以下章节的信息可能对你有帮助,建议阅读。如果你在使用过程遇到问题,建议先查询FAQ,如仍无法解决再提交issue。
29
+
30
+ ## 新闻
31
+
32
+ * 2023年8月3日 在魔搭社区(ModelScope)和Hugging Face同步推出Qwen-7B和Qwen-7B-Chat模型。同时,我们发布了技术备忘录,介绍了相关的训练细节和模型表现。
33
+
34
+ ## 评测表现
35
+
36
+ Qwen-7B在多个全面评估自然语言理解与生成、数学运算解题、代码生成等能力的评测数据集上,包括MMLU、C-Eval、GSM8K、HumanEval、WMT22、CMMLU等,均超出了同规模大语言模型的表现,甚至超出了如12-13B参数等更大规模的语言模型。
37
+
38
+ | Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
39
+ | :---------------- | :------------: | :------------: | :------------: | :------------: | :------------: |:------------: |
40
+ | LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
41
+ | LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
42
+ | Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
43
+ | ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
44
+ | InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
45
+ | Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
46
+ | LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
47
+ | LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
48
+ | ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
49
+ | **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
50
+
51
+ <p align="center">
52
+ <img src="assets/performance.png" width="1000"/>
53
+ <p>
54
+ <br>
55
+
56
+ 此外,根据[OpenCompass](https://opencompass.org.cn/leaderboard-llm)进行的大型语言模型第三方评估,Qwen-7B 和 Qwen-7B-Chat 是其中表现最优的7B参数模型。该评估由大量公开基准组成,用于评估语言理解和生成、代码生成、数学、推理等。
57
+
58
+ 更多的实验结果和细节请查看我们的技术备忘录。点击[这里](tech_memo.md)。
59
+
60
+ ## 要求
61
+
62
+ * python 3.8及以上版本
63
+ * pytorch 1.12及以上版本,推荐2.0及以上版本
64
+ * 建议使用CUDA 11.4及以上(GPU用户、flash-attention用户等需考虑此选项)
65
+
66
+ ## 快速使用
67
+
68
+ 我们提供简单的示例来说明如何利用🤖 ModelScope和🤗 Transformers快速使用Qwen-7B和Qwen-7B-Chat。
69
+
70
+ 在开始前,请确保你已经配置好环境并安装好相关的代码包。最重要的是,确保你满足上述要求,然后安装相关的依赖库。
71
+
72
+ ```bash
73
+ pip install -r requirements.txt
74
+ ```
75
+
76
+ 如果你的显卡支持fp16或bf16精度,我们还推荐安装[flash-attention](https://github.com/Dao-AILab/flash-attention)来提高你的运行效率以及降低显存占用。(**flash-attention只是可选项,不安装也可正常运行该项目**)
77
+
78
+ ```bash
79
+ git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
80
+ cd flash-attention && pip install .
81
+ # 下方安装可选,安装可能比较缓慢。
82
+ # pip install csrc/layer_norm
83
+ # pip install csrc/rotary
84
+ ```
85
+
86
+ 接下来你可以开始使用Transformers或者ModelScope来使用我们的模型。
87
+
88
+ #### 🤗 Transformers
89
+
90
+ 如希望使用Qwen-7B-chat进行推理,所需要写的只是如下所示的数行代码。**请确保你使用的是最新代码。**
91
+
92
+ ```python
93
+ from transformers import AutoModelForCausalLM, AutoTokenizer
94
+ from transformers.generation import GenerationConfig
95
+
96
+ # 请注意:分词器默认行为已更改为默认关闭特殊token攻击防护。
97
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
98
+
99
+ # 打开bf16精度,A100、H100、RTX3060、RTX3070等显卡建议启用以节省显存
100
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
101
+ # 打开fp16精度,V100、P100、T4等显卡建议启用以节省显存
102
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
103
+ # 使用CPU进行推理,需要约32GB内存
104
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
105
+ # 默认使用自动模式,根据设备自动选择精度
106
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
107
+
108
+ # 可指定不同的生成长度、top_p等相关超参
109
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
110
+
111
+ # 第一轮对话 1st dialogue turn
112
+ response, history = model.chat(tokenizer, "你好", history=None)
113
+ print(response)
114
+ # 你好!很高兴为你提供帮助。
115
+
116
+ # 第二轮对话 2nd dialogue turn
117
+ response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
118
+ print(response)
119
+ # 这是一个关于一个年轻人奋斗创业最终取得成功的故事。
120
+ # 故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。从小,李明就立下了一个目标:要成为一名成功的企业家。
121
+ # 为了实现这个目标,李明勤奋学习,考上了大学。在大学期间,他积极参加各种创业比赛,获得了不少奖项。他还利用课余时间去实习,积累了宝贵的经验。
122
+ # 毕业后,李明决定开始自己的创业之路。他开始寻找投资机会,但多次都被拒绝了。然而,他并没有放弃。他继续努力,不断改进自己的创业计划,并寻找新的投资机会。
123
+ # 最终,李明成功地获得了一笔投资,开始了自己的创业之路。他成立了一家科技公司,专注于开发新型软件。在他的领导下,公司迅速发展起来,成为了一家成功的科技企业。
124
+ # 李明的成功并不是偶然的。他勤奋、坚韧、勇于冒险,不断学习和改进自己。他的成功也证明了,只要努力奋斗,任何人都有可能取得成功。
125
+
126
+ # 第三轮对话 3rd dialogue turn
127
+ response, history = model.chat(tokenizer, "给这个故事起一个标题", history=history)
128
+ print(response)
129
+ # 《奋斗创业:一个年轻人的成功之路》
130
+ ```
131
+
132
+ 运行Qwen-7B同样非常简单。
133
+
134
+ <details>
135
+ <summary>运行Qwen-7B</summary>
136
+
137
+ ```python
138
+ from transformers import AutoModelForCausalLM, AutoTokenizer
139
+ from transformers.generation import GenerationConfig
140
+
141
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
142
+
143
+ # 打开bf16精度,A100、H100、RTX3060、RTX3070等显卡建议启用以节省显存
144
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, bf16=True).eval()
145
+ # 打开fp16精度,V100、P100、T4等显卡建议启用以节省显存
146
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, fp16=True).eval()
147
+ # 使用CPU进行推理,需要约32GB内存
148
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="cpu", trust_remote_code=True).eval()
149
+ # 默认使用自动模式,根据设备自动选择精度
150
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True).eval()
151
+
152
+ # 可指定不同的生成长度、top_p等相关超参
153
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
154
+
155
+ inputs = tokenizer('蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是', return_tensors='pt')
156
+ inputs = inputs.to(model.device)
157
+ pred = model.generate(**inputs)
158
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
159
+ # 蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是亚的斯亚贝巴(Addis Ababa)...
160
+ ```
161
+
162
+ </details>
163
+
164
+ #### 🤖 ModelScope
165
+
166
+ 魔搭(ModelScope)是开源的模型即服务共享平台,为泛AI开发者提供灵活、易用、低成本的一站式模型服务产品。使用ModelScope同样非常简单,代码如下所示:
167
+
168
+ ```python
169
+ import os
170
+ from modelscope.pipelines import pipeline
171
+ from modelscope.utils.constant import Tasks
172
+ from modelscope import snapshot_download
173
+
174
+ model_id = 'QWen/qwen-7b-chat'
175
+ revision = 'v1.0.0'
176
+
177
+ model_dir = snapshot_download(model_id, revision)
178
+
179
+ pipe = pipeline(
180
+ task=Tasks.chat, model=model_dir, device_map='auto')
181
+ history = None
182
+
183
+ text = '浙江的省会在哪里?'
184
+ results = pipe(text, history=history)
185
+ response, history = results['response'], results['history']
186
+ print(f'Response: {response}')
187
+ text = '它有什么好玩的地方呢?'
188
+ results = pipe(text, history=history)
189
+ response, history = results['response'], results['history']
190
+ print(f'Response: {response}')
191
+ ```
192
+
193
+ ## Tokenization
194
+
195
+ > 注:作为术语的“tokenization”在中文中尚无共识的概念对应,本文档采用英文表达以利说明。
196
+
197
+ 基于tiktoken的tokenizer有别于其他分词器,比如sentencepiece tokenizer。尤其在微调阶段,需要特别注意特殊token的使用。关于tokenizer的更多信息,以及微调时涉及的相关使用,请参阅[文档](tokenization_note_zh.md)。
198
+
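+ 下面是一个最简单的使用示意(仅作演示,输出的token id以实际运行结果为准):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+
+ ids = tokenizer.encode("你好,Qwen!")
+ print(ids)                    # 文本对应的token id列表
+ print(tokenizer.decode(ids))  # 解码还原出原始文本
+ # 特殊token(如<|endoftext|>、<|im_start|>、<|im_end|>)的处理方式请以上述文档为准
+ ```
+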
199
+ ## 量化
200
+
201
+ 如希望使用更低精度的量化模型,如4比特和8比特的模型,我们提供了简单的示例来说明如何快速使用量化模型。在开始前,确保你已经安装了`bitsandbytes`。请注意,`bitsandbytes`的安装要求是:
202
+
203
+ ```
204
+ **Requirements** Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
205
+ ```
206
+
207
+ 随后运行如下命令安装`bitsandbytes`:
208
+
209
+ ```
210
+ pip install bitsandbytes
211
+ ```
212
+
213
+ Windows用户需安装特定版本的`bitsandbytes`,可选项包括[bitsandbytes-windows-webui](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
214
+
215
+ 你只需要在`AutoModelForCausalLM.from_pretrained`中添加你的量化配置,即可使用量化模型。如下所示:
216
+
217
+ ```python
218
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
219
+
220
+ # quantization configuration for NF4 (4 bits)
221
+ quantization_config = BitsAndBytesConfig(
222
+ load_in_4bit=True,
223
+ bnb_4bit_quant_type='nf4',
224
+ bnb_4bit_compute_dtype=torch.bfloat16
225
+ )
226
+
227
+ # quantization configuration for Int8 (8 bits)
228
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
229
+
230
+ model = AutoModelForCausalLM.from_pretrained(
231
+ args.checkpoint_path,  # 替换为你的模型名称或路径,例如 "Qwen/Qwen-7B-Chat"
232
+ device_map="cuda:0",
233
+ quantization_config=quantization_config,
234
+ max_memory=max_memory,  # 可选:用于限制各设备可用显存;不需要时可删去此参数
235
+ trust_remote_code=True,
236
+ ).eval()
237
+ ```
238
+
239
+ 上述方法可以让我们以`NF4`和`Int8`精度读取模型,从而节省显存开销。我们也提供了相关性能数据:尽管量化后模型效果略有损失,但显存开销大幅降低。
240
+
241
+ | Precision | MMLU | GPU Memory for Loading Model |
242
+ | ----------- | :------: | :---------------------------: |
243
+ | BF16 | 56.7 | 16.38G |
244
+ | Int8 | 52.8 | 10.44G |
245
+ | NF4 | 48.9 | 7.79G |
246
+
247
+ 注:表中显存占用的测试环境为A100-SXM4-80G单卡,PyTorch 2.0.1,CUDA 11.8,开启flash attention
248
+
249
+ ## 推理性能
250
+
251
+ ### 推理速度
252
+
253
+ 我们分别测试了BF16和量化条件下,模型生成2K tokens的平均推理速度,结果如下
254
+
255
+ | 量化等级 | 开flash_attn的推理速度 (tokens/秒) | 关flash_attn的推理速度 (tokens/秒) |
256
+ | ------ | :---------------------------: | :---------------------------: |
257
+ | BF16 (无量化) | 30.06 | 27.55 |
258
+ | Int8 (bnb) | 7.94 | 7.86 |
259
+ | NF4 (bnb) | 21.43 | 20.37 |
260
+
261
+ 具体的评测方式为:指定输入context长度为1,生成长度为2048;测试硬件为A100-SXM4-80G单卡,软件环境为PyTorch 2.0.1、CUDA 11.8,统计生成这2048个token的平均速度。
262
+
263
+ ### 显存占用
264
+
265
+ 在BF16和不同量化条件下,我们分别测算了模型编码2048长度序列(并生成1个token),和生成8192长度序列(编码1个token作为context)的峰值显存占用。结果如下
266
+
267
+ 打开flash attention时
268
+
269
+ | 量化等级 | 编码 2048 长度的峰值显存 | 生成 8192 长度的峰值显存 |
270
+ | --- | :---: | :---: |
271
+ | BF16 | 18.11GB | 23.52GB |
272
+ | Int8 | 12.17GB | 17.60GB |
273
+ | NF4 | 9.52GB | 14.93GB |
274
+
275
+ 关闭flash attention时
276
+
277
+ | 量化等级 | 编码 2048 长度的峰值显存 | 生成 8192 长度的峰值显存 |
278
+ | --- | :---: | :---: |
279
+ | BF16 | 18.11GB | 24.40GB |
280
+ | Int8 | 12.18GB | 18.47GB |
281
+ | NF4 | 9.52GB | 15.81GB |
282
+
283
+ 以上测速和显存占用情况,均可通过该[评测脚本](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py)测算得到。
284
+
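+ 若希望自行粗测,也可以参考如下思路(仅为示意脚本,非官方评测代码;只有在与上文相同的硬件与参数设置下,结果才具可比性):
+
+ ```python
+ import time
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
+
+ inputs = tokenizer("你好", return_tensors="pt").to(model.device)
+ torch.cuda.reset_peak_memory_stats()
+ start = time.time()
+ out = model.generate(**inputs, max_new_tokens=2048, min_new_tokens=2048, do_sample=False)
+ elapsed = time.time() - start
+
+ new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
+ print(f"生成速度: {new_tokens / elapsed:.2f} tokens/s")
+ print(f"峰值显存: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
+ ```
+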
285
+ ## Demo
286
+
287
+ ### Web UI
288
+
289
+ 我们提供了Web UI的demo供用户使用 (感谢 @wysaid 支持)。在开始前,确保已经安装如下代码库:
290
+
291
+ ```
292
+ pip install -r requirements_web_demo.txt
293
+ ```
294
+
295
+ 随后运行如下命令,并点击生成链接:
296
+
297
+ ```
298
+ python web_demo.py
299
+ ```
300
+
301
+ <p align="center">
302
+ <br>
303
+ <img src="assets/web_demo.gif" width="600" />
304
+ <br>
305
+ <p>
306
+
307
+
308
+ ### 交互式Demo
309
+
310
+ 我们提供了一个简单的交互式Demo示例,请查看`cli_demo.py`。当前模型已经支持流式输出,用户可通过输入文字的方式和Qwen-7B-Chat交互,模型将流式输出返回结果。运行如下命令:
311
+
312
+ ```
313
+ python cli_demo.py
314
+ ```
315
+
316
+ <p align="center">
317
+ <br>
318
+ <img src="assets/cli_demo.gif" width="600" />
319
+ <br>
320
+ <p>
321
+
322
+ ## API
323
+
324
+ 我们提供了OpenAI API格式的本地API部署方法(感谢@hanpenggit)。在开始之前先安装必要的代码库:
325
+
326
+ ```bash
327
+ pip install fastapi uvicorn openai pydantic sse_starlette
328
+ ```
329
+
330
+ 随后即可运行以下命令部署你的本地API:
331
+
332
+ ```bash
333
+ python openai_api.py
334
+ ```
335
+
336
+ 你也可以修改参数,比如用`-c`指定模型名称或路径、用`--cpu-only`改为CPU部署等等。如果部署出现问题,更新上述代码库往往可以解决大多数问题。
337
+
338
+ 使用API同样非常简单,示例如下:
339
+
340
+ ```python
341
+ import openai
342
+ openai.api_base = "http://localhost:8000/v1"
343
+ openai.api_key = "none"
344
+
345
+ # 使用流式回复的请求
346
+ for chunk in openai.ChatCompletion.create(
347
+ model="Qwen-7B",
348
+ messages=[
349
+ {"role": "user", "content": "你好"}
350
+ ],
351
+ stream=True
352
+ ):
353
+ if hasattr(chunk.choices[0].delta, "content"):
354
+ print(chunk.choices[0].delta.content, end="", flush=True)
355
+
356
+ # 不使用流式回复的请求
357
+ response = openai.ChatCompletion.create(
358
+ model="Qwen-7B",
359
+ messages=[
360
+ {"role": "user", "content": "你好"}
361
+ ],
362
+ stream=False
363
+ )
364
+ print(response.choices[0].message.content)
365
+ ```
366
+
367
+ <p align="center">
368
+ <br>
369
+ <img src="assets/openai_api.gif" width="600" />
370
+ <br>
371
+ <p>
372
+
373
+ ## 工具调用
374
+
375
+ Qwen-7B-Chat针对包括API、数据库、模型等工具在内的调用进行了优化。用户可以基于Qwen-7B开发LangChain应用、Agent甚至Code Interpreter。我们在开源的[评测数据集](eval/EVALUATION.md)上测试了模型的工具调用能力,发现Qwen-7B-Chat能够取得稳定的表现。
376
+
377
+ | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
378
+ |:------------|:----------------------:|:----------------------:|:----------------------:|
379
+ | GPT-4 | 95% | **0.90** | 15% |
380
+ | GPT-3.5 | 85% | 0.88 | 75% |
381
+ | **Qwen-7B** | **99%** | 0.89 | **9.7%** |
382
+
383
+ 我们提供了文档说明如何根据ReAct Prompting的原则写作你的prompt。
384
+
385
+ 关于ReAct Prompting的具体写法与使用方法,请参考[ReAct示例](examples/react_prompt.md),下方也给出一个简化的调用示意。
386
+
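+ 下面是一个极简的ReAct调用示意(假设已按前文加载好`model`与`tokenizer`;其中工具名称与描述均为虚构示例,完整、准确的模板请以上述ReAct文档为准):
+
+ ```python
+ # 假设的工具描述,仅用于演示ReAct提示词的结构
+ TOOL_DESC = 'search: 调用搜索引擎查询实时信息。Parameters: {"query": "搜索关键词"}'
+
+ prompt = f"""Answer the following questions as best you can. You have access to the following tools:
+
+ {TOOL_DESC}
+
+ Use the following format:
+
+ Question: the input question you must answer
+ Thought: you should always think about what to do
+ Action: the action to take, should be one of [search]
+ Action Input: the input to the action
+ Observation: the result of the action
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+ Thought: I now know the final answer
+ Final Answer: the final answer to the original input question
+
+ Begin!
+
+ Question: 今天北京的天气怎么样?"""
+
+ response, _ = model.chat(tokenizer, prompt, history=None)
+ print(response)
+ # 从response中解析出Action与Action Input,调用真实工具得到Observation后,
+ # 将其拼接回对话继续生成,直到模型给出Final Answer
+ ```
+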
387
+ 此外,我们还提供了实验结果,展示模型作为Agent(智能体)的能力。更多信息请阅读[Hugging Face Agent文档](https://huggingface.co/docs/transformers/transformers_agents)。模型在Hugging Face提供的评测数据集上表现如下:
388
+
389
+ | Model | Tool Selection↑ | Tool Used↑ | Code↑ |
390
+ |:---------------|:---------------:|:-----------:|:---------:|
391
+ |GPT-4 | **100** | **100** | **97.41** |
392
+ |GPT-3.5 | 95.37 | 96.30 | 87.04 |
393
+ |StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
394
+ | **Qwen-7B** | 90.74 | 92.59 | 74.07 |
395
+
396
+ ## 长文本理解
397
+
398
+ 我们引入了NTK插值、窗口注意力、LogN注意力缩放等技术,提升模型的上下文长度并突破训练序列长度的限制,目前模型可处理的序列长度已突破8K。通过arXiv数据集上的语言模型(困惑度,PPL)实验,我们发现Qwen-7B能够在长序列的设置下取得不错的表现,结果见下表(表后附一段配置示意):
399
+
400
+ <table>
401
+ <tr>
402
+ <th rowspan="2">Model</th><th colspan="5" align="center">Sequence Length</th>
403
+ </tr>
404
+ <tr>
405
+ <th align="center">1024</th><th align="center">2048</th><th align="center">4096</th><th align="center">8192</th><th align="center">16384</th>
406
+ </tr>
407
+ <tr>
408
+ <td>Qwen-7B</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">39.35</td><td align="center">469.81</td><td align="center">2645.09</td>
409
+ </tr>
410
+ <tr>
411
+ <td>+ dynamic_ntk</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">3.59</td><td align="center">3.66</td><td align="center">5.71</td>
412
+ </tr>
413
+ <tr>
414
+ <td>+ dynamic_ntk + logn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center">3.56</td><td align="center">4.62</td>
415
+ </tr>
416
+ <tr>
417
+ <td>+ dynamic_ntk + logn + local_attn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center"><b>3.49</b></td><td align="center"><b>4.32</b></td>
418
+ </tr>
419
+ </table>
420
+
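+ 表中的dynamic_ntk、logn等能力通常通过模型配置项开启。下面是一段示意代码(字段名请以模型仓库中的`config.json`与模型代码为准,此处仅供参考):
+
+ ```python
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ config = AutoConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+ config.use_dynamic_ntk = True  # 动态NTK插值,用于外推更长上下文
+ config.use_logn_attn = True    # LogN注意力缩放
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen-7B", config=config, device_map="auto", trust_remote_code=True
+ ).eval()
+ ```
+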
421
+ ## 复现
422
+
423
+ 我们提供了评测脚本以供复现我们的实验结果。注意,由于内部代码和开源代码存在少许差异,评测结果可能与汇报结果存在细微差异。请阅读[eval/EVALUATION.md](eval/EVALUATION.md)了解更多信息。
424
+
425
+ ## FAQ
426
+
427
+ 如遇到问题,敬请查阅[FAQ](FAQ_zh.md)以及issue区,如仍无法解决再提交issue。
428
+
429
+ ## 使用协议
430
+
431
+ 研究人员与开发者可自由使用Qwen-7B和Qwen-7B-Chat,并可在其基础上进行二次开发。我们同样允许商业使用,具体细节请查看[LICENSE](LICENSE)。如需商用,请填写[问卷](https://dashscope.console.aliyun.com/openModelApply/qianwen)申请。
432
+
433
+ ## 联系我们
434
+
435
+ 如果你想给我们的研发团队和产品团队留言,请通过邮件(qianwen_opensource@alibabacloud.com)联系我们。
436
+
README_JA.md ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <br>
2
+
3
+ <p align="center">
4
+ <img src="assets/logo.jpg" width="400"/>
5
+ <p>
6
+ <br>
7
+
8
+ <p align="center">
9
+ Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 <a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 <a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp | &nbsp<a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>&nbsp | &nbsp<a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/9bjvspyu">Discord</a>
10
+ </p>
11
+ <br>
12
+
13
+ <p align="center">
14
+ <a href="README_CN.md">中文</a>&nbsp | &nbsp<a href="README.md">English</a>&nbsp | &nbsp日本語
15
+ </p>
16
+ <br><br>
17
+ <p align="right">
18
+ Japanese document maintainer: Ikko Eltociear Ashimine
19
+ </p>
20
+ <br><br>
21
+
22
+ 私たちは、**Qwen-7B** と **Qwen-7B-Chat** を **🤖 ModelScope** と **🤗 Hugging Face** の両方でオープンソース化しています(上部のロゴをクリックすると、コードとチェックポイントのあるリポジトリに移動します)。このレポには、Qwen-7B の簡単な紹介と、使い方の手引き、さらに詳しい情報を提供する技術メモ [link](tech_memo.md) が含まれています。
23
+
24
+ Qwen-7Bは、アリババクラウドが提唱する大規模言語モデルシリーズQwen(略称:Tongyi Qianwen)の7Bパラメータ版です。Qwen-7BはTransformerベースの大規模言語モデルであり、ウェブテキスト、書籍、コードなどを含む大量のデータで事前学習される。さらに、事前学習されたQwen-7Bをベースに、アライメント技術で学習された大規模モデルベースのAIアシスタントであるQwen-7B-Chatをリリースする。Qwen-7Bシリーズの特徴は以下の通りです:
25
+
26
+ 1. **高品質な事前トレーニングデータでトレーニング**。Qwen-7B は 2.2 兆以上のトークンを含む大規模で高品質なデータセットに対して事前学習を行った。このデータセットには平文とコードが含まれ、一般的なドメインデータと専門的なドメインデータを含む幅広いドメインをカバーしている。
27
+ 2. **強いパフォーマンス**。自然言語理解、数学、コーディングなどを評価する一連のベンチマークデータセットにおいて、同程度のモデルサイズのモデルと比較して、競合他社を凌駕しています。
28
+ 3. **言語サポートの向上**。Qwen-7B のトークナイザは、15 万以上のトークンの語彙をベースにしており、他のトークナイザに比べて効率的です。多くの言語に対応しており、ユーザが特定の言語を理解するために Qwen-7B をさらに微調整するのに役立ちます。
29
+ 4. **8K コンテキスト長をサポート**。Qwen-7B と Qwen-7B-Chat はともに 8K のコンテキスト長をサポートしており、長いコンテキストでの入力を可能にしている。
30
+ 5. **プラグインのサポート**。Qwen-7B-Chat は、プラグイン関連のアライメントデータでトレーニングされているため、API、モデル、データベースなどのツールを使用することができ、エージェントとしてプレイすることができる。
31
+
32
+ 以下のセクションには、参考になる情報が記載されています。特に、issueを立ち上げる前にFAQセクションをお読みになることをお勧めします。
33
+
34
+ ## ニュース
35
+
36
+ * 2023.8.3 Qwen-7B と Qwen-7B-Chat を ModelScope と Hugging Face で公開。また、トレーニングの詳細やモデルの性能など、モデルの詳細についてはテクニカルメモを提供しています。
37
+
38
+ ## パフォーマンス
39
+
40
+ 一般的に、Qwen-7B は、MMLU、C-Eval、GSM8K、HumanEval、WMT22、CMMLU などの自然言語理解、数学的問題解決、コーディングなどに関するモデルの能力を評価する一連のベンチマークデータセットにおいて、同程度のモデルサイズのベースラインモデルを凌駕し、さらには 13B 程度のパラメータを持つより大規模なモデルをも凌駕している。以下の結果をご覧ください。
41
+
42
+ | Model | MMLU | C-Eval | GSM8K | HumanEval | WMT22 (en-zh) | CMMLU |
43
+ | :---------------- | :------------: | :------------: | :------------: | :------------: | :------------: |:------------: |
44
+ | LLaMA-7B | 35.1 | - | 11.0 | 10.5 | 8.7 | - |
45
+ | LLaMA 2-7B | 45.3 | - | 14.6 | 12.8 | 17.9 | - |
46
+ | Baichuan-7B | 42.3 | 42.8 | 9.7 | 9.2 | 26.6 | 44.4 |
47
+ | ChatGLM2-6B | 47.9 | 51.7 | 32.4 | 9.2 | - | 48.8 |
48
+ | InternLM-7B | 51.0 | 52.8 | 31.2 | 10.4 | 14.8 | - |
49
+ | Baichuan-13B | 51.6 | 53.6 | 26.6 | 12.8 | 30.0 | 55.8 |
50
+ | LLaMA-13B | 46.9 | 35.5 | 17.8 | 15.8 | 12.0 | - |
51
+ | LLaMA 2-13B | 54.8 | - | 28.7 | 18.3 | 24.2 | - |
52
+ | ChatGLM2-12B | 56.2 | **61.6** | 40.9 | - | - | - |
53
+ | **Qwen-7B** | **56.7** | 59.6 | **51.6** | **24.4** | **30.6** | **58.8** |
54
+
55
+ <p align="center">
56
+ <img src="assets/performance.png" width="1000"/>
57
+ <p>
58
+ <br>
59
+
60
+ さらに、[OpenCompass](https://opencompass.org.cn/leaderboard-llm)が実施した大規模言語モデルの第三者評価によると、Qwen-7BとQwen-7B-Chatは7Bパラメータモデルのトップである。この評価は、言語理解・生成、コーディング、数学、推論などの評価のための大量の公開ベンチマークで構成されている。
61
+
62
+ より詳細な実験結果(より多くのベンチマークデータセットでの詳細なモデル性能)や詳細については、[こちら](tech_memo.md)をクリックして技術メモを参照してください。
63
+
64
+ ## 必要条件
65
+
66
+ * python 3.8 以上
67
+ * pytorch 1.12 以上、2.0 以上を推奨
68
+ * CUDA 11.4 以上を推奨(GPU ユーザー、フラッシュアテンションユーザー向けなど)
69
+
70
+ ## クイックスタート
71
+
72
+ 以下では、Qwen-7B と 🤖 ModelScope と 🤗 Transformers の簡単な使用例を示します。
73
+
74
+ コードを実行する前に、環境のセットアップと必要なパッケージのインストールが済んでいることを確認してください。上記の要件を満たしていることを確認してから、依存するライブラリをインストールしてください。
75
+
76
+ ```bash
77
+ pip install -r requirements.txt
78
+ ```
79
+
80
+ お使いのデバイスが fp16 または bf16 をサポートしている場合、[flash-attention](https://github.com/Dao-AILab/flash-attention) をインストールすることで、より高い効率とメモリ使用量を抑えることができます。(**flash-attention はオプションであり、インストールしなくてもプロジェクトは正常に実行できます**)
81
+
82
+ ```bash
83
+ git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
84
+ cd flash-attention && pip install .
85
+ # 以下はオプションです。インストールに時間がかかる場合があります。
86
+ # pip install csrc/layer_norm
87
+ # pip install csrc/rotary
88
+ ```
89
+
90
+ これで ModelScope か Transformers で始めることができます。
91
+
92
+ #### 🤗 Transformers
93
+
94
+ Qwen-7B-Chat を推論に使用するには、以下のように数行のコードを入力するだけです。**最新のコードを使用していることを確認してください。**
95
+
96
+ ```python
97
+ from transformers import AutoModelForCausalLM, AutoTokenizer
98
+ from transformers.generation import GenerationConfig
99
+
100
+ # 注: デフォルトの動作では、インジェクション攻撃防止機能がオフになっています。
101
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
102
+
103
+ # bf16 を使用
104
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
105
+ # fp16 を使用
106
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
107
+ # CPU のみ使用
108
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True).eval()
109
+ # オートモードを使用すると、デバイスに応じて自動的に精度が選択されます。
110
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
111
+
112
+ # 生成のためのハイパーパラメータを指定
113
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
114
+
115
+ # 第一轮对话 第一回対話ターン
116
+ response, history = model.chat(tokenizer, "你好", history=None)
117
+ print(response)
118
+ # こんにちは! お役に立ててうれしいです。
119
+
120
+ # 第二轮对话 第二回対話ターン
121
+ response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
122
+ print(response)
123
+ # これは、自分のビジネスを始めようと奮闘し、やがて成功する若者の物語である。
124
+ # この物語の主人公は、平凡な家庭に生まれ、平凡な労働者である両親を持つ李明である。 李明は子供の頃から起業家として成功することを目標としていた。
125
+ # この目標を達成するため、李明は猛勉強して大学に入った。 大学時代には、さまざまな起業家コンテストに積極的に参加し、多くの賞を獲得した。 また、余暇を利用してインターンシップにも参加し、貴重な経験を積んだ。
126
+ # 卒業後、李明は起業を決意した。 投資先を探し始めたが、何度も断られた。 しかし、彼はあきらめなかった。 彼は懸命に働き続け、ビジネスプランを改善し、新たな投資機会を探した。
127
+ # やがて李明は投資を受けることに成功し、自分のビジネスを始めた。 彼は新しいタイプのソフトウェアの開発に焦点を当てたテクノロジー会社を設立した。 彼のリーダーシップの下、会社は急速に成長し、テクノロジー企業として成功を収めた。
128
+ # 李明の成功は偶然ではない。 彼は勤勉で、たくましく、冒険好きで、常に学び、自分を高めている。 彼の成功はまた、努力すれば誰でも成功できることを証明している。
129
+
130
+ # 第三轮对话 第三回対話ターン
131
+ response, history = model.chat(tokenizer, "给这个故事起一个标题", history=history)
132
+ print(response)
133
+ # 《起業への奮闘:ある若者の成功への道》
134
+ ```
135
+
136
+ Qwen-7B の学習済みベースモデルの実行も簡単です。
137
+
138
+ <details>
139
+ <summary>Qwen-7B の実行</summary>
140
+
141
+ ```python
142
+ from transformers import AutoModelForCausalLM, AutoTokenizer
143
+ from transformers.generation import GenerationConfig
144
+
145
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
146
+ # bf16 を使用
147
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, bf16=True).eval()
148
+ # fp16 を使用
149
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True, fp16=True).eval()
150
+ # CPU のみ使用
151
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="cpu", trust_remote_code=True).eval()
152
+ # オートモードを使用すると、デバイスに応じて自動的に精度が選択されます。
153
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B", device_map="auto", trust_remote_code=True).eval()
154
+
155
+ # 生成のためのハイパーパラメータを指定
156
+ model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
157
+
158
+ inputs = tokenizer('モンゴルの首都はウランバートル(Ulaanbaatar)\nアイスランドの首都はレイキャビク(Reykjavik)\nエチオピアの首都は', return_tensors='pt')
159
+ inputs = inputs.to(model.device)
160
+ pred = model.generate(**inputs)
161
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
162
+ # モンゴルの首都はウランバートル(Ulaanbaatar)\nアイスランドの首都はレイキャビク(Reykjavik)\nエチオピアの首都はアディスアベバ(Addis Ababa)...
163
+ ```
164
+
165
+ </details>
166
+
167
+ #### 🤖 ModelScope
168
+
169
+ ModelScope は、MaaS(Model-as-a-Service) のためのオープンソースプラットフォームであり、AI 開発者に柔軟で費用対効果の高いモデルサービスを提供します。同様に、以下のように ModelScope でモデルを実行することができます:
170
+
171
+ ```python
172
+ import os
173
+ from modelscope.pipelines import pipeline
174
+ from modelscope.utils.constant import Tasks
175
+ from modelscope import snapshot_download
176
+
177
+ model_id = 'QWen/qwen-7b-chat'
178
+ revision = 'v1.0.0'
179
+
180
+ model_dir = snapshot_download(model_id, revision)
181
+
182
+ pipe = pipeline(
183
+ task=Tasks.chat, model=model_dir, device_map='auto')
184
+ history = None
185
+
186
+ text = '浙江省の省都はどこですか?'
187
+ results = pipe(text, history=history)
188
+ response, history = results['response'], results['history']
189
+ print(f'Response: {response}')
190
+ text = '何がそんなに面白いのか?'
191
+ results = pipe(text, history=history)
192
+ response, history = results['response'], results['history']
193
+ print(f'Response: {response}')
194
+ ```
195
+
196
+ ## トークナイザー
197
+
198
+ tiktoken に基づくトークナイザーは、他のトークナイザー、例えばセンテンスピーストークナイザーとは異なります。特にファインチューニングの際には、特殊なトークンに注意を払う必要があります。トークナイザに関する詳細な情報や、ファインチューニングにおける使用方法については、[ドキュメント](tokenization_note.md)を参照してください。
199
+
200
+ ## 量子化
201
+
202
+ `NF4` と `Int8` のモデルをロードする方法を示す例を提供します。手始めに、`bitsandbytes` が実装されていることを確認して下さい。`bitsandbytes` の要件は以下の通りになります:
203
+
204
+ ```
205
+ **必要条件** Python >= 3.8。Linux ディストリビューション(Ubuntu、MacOS など)+ CUDA > 10.0。
206
+ ```
207
+
208
+ そして、以下のコマンドを実行して `bitsandbytes` をインストールする:
209
+
210
+ ```
211
+ pip install bitsandbytes
212
+ ```
213
+
214
+ Windows ユーザは、[bitsandbytes-windows-webui](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) という別のオプションを見つける必要があります。
215
+
216
+ そして、量子化の設定を `AutoModelForCausalLM.from_pretrained` に追加するだけとなります。以下の例を参照してください:
217
+
218
+ ```python
219
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
220
+
221
+ # NF4(4ビット)の量子化設定
222
+ quantization_config = BitsAndBytesConfig(
223
+ load_in_4bit=True,
224
+ bnb_4bit_quant_type='nf4',
225
+ bnb_4bit_compute_dtype=torch.bfloat16
226
+ )
227
+
228
+ # Int8(8ビット)の量子化設定
229
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
230
+
231
+ model = AutoModelForCausalLM.from_pretrained(
232
+ args.checkpoint_path,
233
+ device_map="cuda:0",
234
+ quantization_config=quantization_config,
235
+ max_memory=max_memory,
236
+ trust_remote_code=True,
237
+ ).eval()
238
+ ```
239
+
240
+ この方法では、Qwen-7B を `NF4` と `Int8` でロードすることができ、メモリ使用量を節約できる。以下にモデル性能の関連統計量を示します。量子化により、有効性は若干低下するが、推論効率は大幅に向上し、メモリコストが削減されることがわかります。
241
+
242
+ | Precision | MMLU | GPU Memory for Loading Model |
243
+ | ----------- | :------: | :---------------------------: |
244
+ | BF16 | 56.7 | 16.38G |
245
+ | Int8 | 52.8 | 10.44G |
246
+ | NF4 | 48.9 | 7.79G |
247
+
248
+ 注:上表のGPUメモリ使用量プロファイリングは、シングルA100-SXM4-80G GPU、PyTorch 2.0.1、CUDA 11.8、フラッシュアテンション使用で実行されています。
249
+
250
+ ## 推論効率
251
+
252
+ ### 推論スピード
253
+
254
+ BF16精度、量子化レベルInt8またはNF4で、それぞれ2Kトークンを生成する平均推論速度を測定した。
255
+
256
+ | Quantization Level | Inference Speed with flash_attn (tokens/s) | Inference Speed w/o flash_attn (tokens/s) |
257
+ | ------ | :---------------------------: | :---------------------------: |
258
+ | BF16 (no quantization) | 30.06 | 27.55 |
259
+ | Int8 (bnb) | 7.94 | 7.86 |
260
+ | NF4 (bnb) | 21.43 | 20.37 |
261
+
262
+ 詳細には、プロファイリングの設定は、1コンテクスト・トークンで2048の新しいトークンを生成している。プロファイリングは、PyTorch 2.0.1とCUDA 11.8を搭載したシングルA100-SXM4-80G GPUで実行される。推論速度は生成された2048個のトークンの平均です。
263
+
264
+ ### GPUメモリ使用量
265
+
266
+ また、BF16またはInt8/NF4量子化レベルの下で、2048個のトークンをコンテキストとしてエンコードした場合(および単一のトークンを生成した場合)と、8192個のトークンを生成した場合(単一のトークンをコンテキストとして生成した場合)のGPUメモリ使用量のピーク値をそれぞれプロファイリングしました。結果を以下に示す。
267
+
268
+ Flash attentionを使用した場合のメモリ使用量は以下の通りである:
269
+
270
+ | Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
271
+ | --- | :---: | :---: |
272
+ | BF16 | 18.11GB | 23.52GB |
273
+ | Int8 | 12.17GB | 17.60GB |
274
+ | NF4 | 9.52GB | 14.93GB |
275
+
276
+ Flash attentionを使用しない場合、メモリ使用量は次のようになる:
277
+
278
+ | Quantization Level | Peak Usage for Encoding 2048 Tokens | Peak Usage for Generating 8192 Tokens |
279
+ | --- | :---: | :---: |
280
+ | BF16 | 18.11GB | 24.40GB |
281
+ | Int8 | 12.18GB | 18.47GB |
282
+ | NF4 | 9.52GB | 15.81GB |
283
+
284
+ 上記のスピードとメモリーのプロファイリングは、[このスクリプト](https://qianwen-res.oss-cn-beijing.aliyuncs.com/profile.py)を使って行われた。
285
+
286
+ ## デモ
287
+
288
+ ### ウェブ UI
289
+
290
+ ウェブUIデモを構築するためのコードを提供します(@wysaidに感謝)。始める前に、以下のパッケージがインストールされていることを確認してください:
291
+
292
+ ```
293
+ pip install -r requirements_web_demo.txt
294
+ ```
295
+
296
+ そして、以下のコマンドを実行し、生成されたリンクをクリックする:
297
+
298
+ ```
299
+ python web_demo.py
300
+ ```
301
+
302
+ <p align="center">
303
+ <br>
304
+ <img src="assets/web_demo.gif" width="600" />
305
+ <br>
306
+ <p>
307
+
308
+ ### CLI デモ
309
+
310
+ `cli_demo.py` に CLI のデモ例を用意しています。ユーザはプロンプトを入力することで Qwen-7B-Chat と対話することができ、モデルはストリーミングモードでモデルの出力を返します。以下のコマンドを実行する:
311
+
312
+ ```
313
+ python cli_demo.py
314
+ ```
315
+
316
+ <p align="center">
317
+ <br>
318
+ <img src="assets/cli_demo.gif" width="600" />
319
+ <br>
320
+ <p>
321
+
322
+ ## API
323
+
324
+ OpenAI APIをベースにローカルAPIをデプロイする方法を提供する(@hanpenggitに感謝)。始める前に、必要なパッケージをインストールしてください:
325
+
326
+ ```bash
327
+ pip install fastapi uvicorn openai pydantic sse_starlette
328
+ ```
329
+
330
+ それから、APIをデプロイするコマンドを実行する:
331
+
332
+ ```bash
333
+ python openai_api.py
334
+ ```
335
+
336
+ チェックポイント名やパスには `-c` 、CPU デプロイメントには `--cpu-only` など、引数を変更できます。APIデプロイメントを起動する際に問題が発生した場合は、パッケージを最新バージョンに更新することで解決できる可能性があります。
337
+
338
+ APIの使い方も簡単だ。以下の例をご覧ください:
339
+
340
+ ```python
341
+ import openai
342
+ openai.api_base = "http://localhost:8000/v1"
343
+ openai.api_key = "none"
344
+
345
+ # create a request activating streaming response
346
+ for chunk in openai.ChatCompletion.create(
347
+ model="Qwen-7B",
348
+ messages=[
349
+ {"role": "user", "content": "你好"}
350
+ ],
351
+ stream=True
352
+ ):
353
+ if hasattr(chunk.choices[0].delta, "content"):
354
+ print(chunk.choices[0].delta.content, end="", flush=True)
355
+
356
+ # create a request not activating streaming response
357
+ response = openai.ChatCompletion.create(
358
+ model="Qwen-7B",
359
+ messages=[
360
+ {"role": "user", "content": "你好"}
361
+ ],
362
+ stream=False
363
+ )
364
+ print(response.choices[0].message.content)
365
+ ```
366
+
367
+ <p align="center">
368
+ <br>
369
+ <img src="assets/openai_api.gif" width="600" />
370
+ <br>
371
+ <p>
372
+
373
+ ## ツールの使用
374
+
375
+ Qwen-7B-Chat は、API、データベース、モデルなど、ツールの利用に特化して最適化されており、ユーザは独自の Qwen-7B ベースの LangChain、エージェント、コードインタプリタを構築することができます。ツール利用能力を評価するための評価[ベンチマーク](eval/EVALUATION.md)では、Qwen-7B は安定した性能に達しています。
376
377
+
378
+ | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
379
+ |:------------|:----------------------:|:----------------------:|:----------------------:|
380
+ | GPT-4 | 95% | **0.90** | 15% |
381
+ | GPT-3.5 | 85% | 0.88 | 75% |
382
+ | **Qwen-7B** | **99%** | 0.89 | **9.7%** |
383
+
384
+ ReAct プロンプトの書き方や使い方については、[ReAct の例](examples/react_prompt.md)を参照してください。ツールを使用することで、モデルがよりよいタスクを実行できるようになります。
385
+
386
+ さらに、エージェントとしての能力を示す実験結果を提供する。詳細は [Hugging Face Agent](https://huggingface.co/docs/transformers/transformers_agents) を参照。Hugging Face が提供するランモードベンチマークでの性能は以下の通りです:
387
+
388
+ | Model | Tool Selection↑ | Tool Used↑ | Code↑ |
389
+ |:---------------|:---------------:|:-----------:|:---------:|
390
+ |GPT-4 | **100** | **100** | **97.41** |
391
+ |GPT-3.5 | 95.37 | 96.30 | 87.04 |
392
+ |StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
393
+ | **Qwen-7B** | 90.74 | 92.59 | 74.07 |
394
+
395
+ ## 長い文脈の理解
396
+
397
+ コンテキストの長さを拡張し、訓練シーケンスの長さのボトルネックを解消するために、NTK を考慮した補間、ウィンドウアテンション、LogN アテンションスケーリングなどの技術を導入し、コンテキストの長さを 8K トークン以上に拡張する。arXiv データセットを用いて PPL 評価による言語モデリング実験を行い、Qwen-7B が長いコンテキストのシナリオにおいて卓越した性能を達成できることを見出した。以下に結果を示します:
398
+
399
+ <table>
400
+ <tr>
401
+ <th rowspan="2">Model</th><th colspan="5" align="center">Sequence Length</th>
402
+ </tr>
403
+ <tr>
404
+ <th align="center">1024</th><th align="center">2048</th><th align="center">4096</th><th align="center">8192</th><th align="center">16384</th>
405
+ </tr>
406
+ <tr>
407
+ <td>Qwen-7B</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">39.35</td><td align="center">469.81</td><td align="center">2645.09</td>
408
+ </tr>
409
+ <tr>
410
+ <td>+ dynamic_ntk</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center">3.59</td><td align="center">3.66</td><td align="center">5.71</td>
411
+ </tr>
412
+ <tr>
413
+ <td>+ dynamic_ntk + logn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center">3.56</td><td align="center">4.62</td>
414
+ </tr>
415
+ <tr>
416
+ <td>+ dynamic_ntk + logn + window_attn</td><td align="center"><b>4.23</b></td><td align="center"><b>3.78</b></td><td align="center"><b>3.58</b></td><td align="center"><b>3.49</b></td><td align="center"><b>4.32</b></td>
417
+ </tr>
418
+ </table>
419
+
420
+ ## 再現
421
+
422
+ ベンチマークデータセットでのモデル性能の再現のために、結果を再現するスクリプトを提供しています。詳しくは [eval/EVALUATION.md](eval/EVALUATION.md) を確認してください。なお、再現の結果、我々の報告結果と若干異なる場合がある。
423
+
424
+ ## FAQ
425
+
426
+ 問題が発生した場合は、[FAQ](FAQ.md)やissueを参照し、新しいissueを立ち上げる前に解決策を探してください。
427
+
428
+ ## ライセンス契約
429
+
430
+ Qwen-7B と Qwen-7B-Chat のコードとモデルウェイトは、研究者や開発者が自由に使用することができます。また、商用利用も可能です。詳しくは [LICENSE](LICENSE) をご覧ください。商用利用を希望される方は、[リクエストフォーム](https://dashscope.console.aliyun.com/openModelApply/qianwen)に必要事項をご記入の上、お申し込みください。
431
+
432
+ ## お問い合わせ
433
+
434
+ 研究チームまたは製品チームへのメッセージは、qianwen_opensource@alibabacloud.com までお気軽にお送りください。
435
+
assets/cli_demo.gif ADDED

Git LFS Details

  • SHA256: 2502b56784d9e2ba70094c040f80b3571d4969a24d27b126ce22ed489b3e31f1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.98 MB
assets/hfagent_chat_1.png ADDED

Git LFS Details

  • SHA256: 356ea19c2c4a656cae9d55e2d727d1651d1955ec67385615c6582b394478e889
  • Pointer size: 132 Bytes
  • Size of remote file: 1.71 MB
assets/hfagent_chat_2.png ADDED

Git LFS Details

  • SHA256: 7db53a1a77dfc19072ce418db6df56fd89f9e7cb2e30430ac8320f10fc8a8bc0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
assets/hfagent_run.png ADDED

Git LFS Details

  • SHA256: fbf4c1232c86e334b5425aacdcc9e7a878100f80d6d70725060cb312bae7d701
  • Pointer size: 132 Bytes
  • Size of remote file: 2.77 MB
assets/logo.jpg ADDED
assets/openai_api.gif ADDED

Git LFS Details

  • SHA256: b457ed0497eba0dff8e2a11093662a548df73c09b506f590a94cab9535a6b83b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
assets/performance.png ADDED
assets/qwen_tokenizer.png ADDED
assets/react_showcase_001.png ADDED
assets/react_showcase_002.png ADDED
assets/react_tutorial_001.png ADDED
assets/react_tutorial_002.png ADDED
assets/tokenizer.pdf ADDED
Binary file (24.7 kB). View file
 
assets/tokenizer.png ADDED
assets/wanx_colorful_black.png ADDED

Git LFS Details

  • SHA256: 650a5431b1a3b4411fc4c2fd44dea3066a4ec67b03b684721086265698d738c4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
assets/web_demo.gif ADDED

Git LFS Details

  • SHA256: 4a721165a571d1b8a22861d0c489f5b6ce5bb1df44470fff957bc8704e2bf996
  • Pointer size: 133 Bytes
  • Size of remote file: 18.8 MB
cli_demo.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """A simple command-line interactive chat demo."""
7
+
8
+ import argparse
9
+ import os
10
+ import platform
11
+ import shutil
12
+ from copy import deepcopy
13
+
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+ from transformers.generation import GenerationConfig
16
+ from transformers.trainer_utils import set_seed
17
+
18
+ DEFAULT_CKPT_PATH = 'QWen/QWen-7B-Chat'
19
+
20
+ _WELCOME_MSG = '''\
21
+ Welcome to the Qwen-7B-Chat model, type text to start chatting, or :h to show command help
22
+ 欢迎使用 Qwen-7B-Chat 模型,输入内容即可进行对话,:h 显示命令帮助
23
+ '''
24
+ _HELP_MSG = '''\
25
+ Commands:
26
+ :help / :h Show this help message 显示帮助信息
27
+ :exit / :quit / :q Exit the demo 退出Demo
28
+ :clear / :cl Clear screen 清屏
29
+ :clear-his / :clh Clear history 清除对话历史
30
+ :history / :his Show history 显示对话历史
31
+ :seed Show current random seed 显示当前随机种子
32
+ :seed <N> Set random seed to <N> 设置随机种子
33
+ :conf Show current generation config 显示生成配置
34
+ :conf <key>=<value> Change generation config 修改生成配置
35
+ :reset-conf Reset generation config 重置生成配置
36
+ '''
37
+
38
+
39
+ def _load_model_tokenizer(args):
40
+ tokenizer = AutoTokenizer.from_pretrained(
41
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
42
+ )
43
+
44
+ if args.cpu_only:
45
+ device_map = "cpu"
46
+ else:
47
+ device_map = "auto"
48
+
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ args.checkpoint_path,
51
+ device_map=device_map,
52
+ trust_remote_code=True,
53
+ resume_download=True,
54
+ ).eval()
55
+ model.generation_config = GenerationConfig.from_pretrained(
56
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
57
+ )
58
+ return model, tokenizer
59
+
60
+
61
+ def _clear_screen():
62
+ if platform.system() == "Windows":
63
+ os.system("cls")
64
+ else:
65
+ os.system("clear")
66
+
67
+
68
+ def _print_history(history):
69
+ terminal_width = shutil.get_terminal_size()[0]
70
+ print(f'History ({len(history)})'.center(terminal_width, '='))
71
+ for index, (query, response) in enumerate(history):
72
+ print(f'User[{index}]: {query}')
73
+ print(f'QWen[{index}]: {response}')
74
+ print('=' * terminal_width)
75
+
76
+
77
+ def _get_input() -> str:
78
+ while True:
79
+ try:
80
+ message = input('User> ').strip()
81
+ except UnicodeDecodeError:
82
+ print('[ERROR] Encoding error in input')
83
+ continue
84
+ except KeyboardInterrupt:
85
+ exit(1)
86
+ if message:
87
+ return message
88
+ print('[ERROR] Query is empty')
89
+
90
+
91
+ def main():
92
+ parser = argparse.ArgumentParser(
93
+ description='QWen-7B-Chat command-line interactive chat demo.')
94
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
95
+ help="Checkpoint name or path, default to %(default)r")
96
+ parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
97
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
98
+ args = parser.parse_args()
99
+
100
+ history, response = [], ''
101
+
102
+ model, tokenizer = _load_model_tokenizer(args)
103
+ orig_gen_config = deepcopy(model.generation_config)
104
+
105
+ _clear_screen()
106
+ print(_WELCOME_MSG)
107
+
108
+ seed = args.seed
109
+
110
+ while True:
111
+ query = _get_input()
112
+
113
+ # Process commands.
114
+ if query.startswith(':'):
115
+ command_words = query[1:].strip().split()
116
+ if not command_words:
117
+ command = ''
118
+ else:
119
+ command = command_words[0]
120
+
121
+ if command in ['exit', 'quit', 'q']:
122
+ break
123
+ elif command in ['clear', 'cl']:
124
+ _clear_screen()
125
+ print(_WELCOME_MSG)
126
+ continue
127
+ elif command in ['clear-history', 'clear-his', 'clh']:
128
+ print(f'[INFO] All {len(history)} history cleared')
129
+ history.clear()
130
+ continue
131
+ elif command in ['help', 'h']:
132
+ print(_HELP_MSG)
133
+ continue
134
+ elif command in ['history', 'his']:
135
+ _print_history(history)
136
+ continue
137
+ elif command in ['seed']:
138
+ if len(command_words) == 1:
139
+ print(f'[INFO] Current random seed: {seed}')
140
+ continue
141
+ else:
142
+ new_seed_s = command_words[1]
143
+ try:
144
+ new_seed = int(new_seed_s)
145
+ except ValueError:
146
+ print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
147
+ else:
148
+ print(f'[INFO] Random seed changed to {new_seed}')
149
+ seed = new_seed
150
+ continue
151
+ elif command in ['conf']:
152
+ if len(command_words) == 1:
153
+ print(model.generation_config)
154
+ else:
155
+ for key_value_pairs_str in command_words[1:]:
156
+ eq_idx = key_value_pairs_str.find('=')
157
+ if eq_idx == -1:
158
+ print('[WARNING] format: <key>=<value>')
159
+ continue
160
+ conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
161
+ try:
162
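+ # 将输入的值按Python字面量解析,例如 :conf top_p=0.8 max_new_tokens=512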
+ conf_value = eval(conf_value_str)
163
+ except Exception as e:
164
+ print(e)
165
+ continue
166
+ else:
167
+ print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
168
+ setattr(model.generation_config, conf_key, conf_value)
169
+ continue
170
+ elif command in ['reset-conf']:
171
+ print('[INFO] Reset generation config')
172
+ model.generation_config = deepcopy(orig_gen_config)
173
+ print(model.generation_config)
174
+ continue
175
+ else:
176
+ # As normal query.
177
+ pass
178
+
179
+ # Run chat.
180
+ set_seed(seed)
181
+ try:
182
+ for response in model.chat_stream(tokenizer, query, history=history):
183
+ _clear_screen()
184
+ print(f"\nUser: {query}")
185
+ print(f"\nQwen-7B: {response}")
186
+ except KeyboardInterrupt:
187
+ print('[WARNING] Generation interrupted')
188
+ continue
189
+
190
+ history.append((query, response))
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
eval/EVALUATION.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 评测复现
2
+
3
+ - CEVAL
4
+
5
+ ```Shell
6
+ wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
7
+ mkdir data/ceval
8
+ mv ceval-exam.zip data/ceval
9
+ cd data/ceval; unzip ceval-exam.zip
10
+ cd ../../
11
+
12
+ # Qwen-7B
13
+ python evaluate_ceval.py -d data/ceval/
14
+
15
+ # Qwen-7B-Chat
16
+ pip install thefuzz
17
+ python evaluate_chat_ceval.py -d data/ceval/
18
+ ```
19
+
20
+ - MMLU
21
+
22
+ ```Shell
23
+ wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
24
+ mkdir data/mmlu
25
+ mv data.tar data/mmlu
26
+ cd data/mmlu; tar xf data.tar
27
+ cd ../../
28
+
29
+ # Qwen-7B
30
+ python evaluate_mmlu.py -d data/mmlu/data/
31
+
32
+ # Qwen-7B-Chat
33
+ pip install thefuzz
34
+ python evaluate_chat_mmlu.py -d data/mmlu/data/
35
+ ```
36
+
37
+ - HumanEval
38
+
39
+ Get the HumanEval.jsonl file from [here](https://github.com/openai/human-eval/tree/master/data)
40
+
41
+ ```Shell
42
+ git clone https://github.com/openai/human-eval
43
+ pip install -e human-eval
44
+
45
+ # Qwen-7B
46
+ python evaluate_humaneval.py -f HumanEval.jsonl -o HumanEval_res.jsonl
47
+ evaluate_functional_correctness HumanEval_res.jsonl
48
+ # Qwen-7B-Chat
49
+ python evaluate_chat_humaneval.py -f HumanEval.jsonl -o HumanEval_res_chat.jsonl
50
+ evaluate_functional_correctness HumanEval_res_chat.jsonl
51
+ ```
52
+
53
+ When installing package human-eval, please note its following disclaimer:
54
+
55
+ This program exists to run untrusted model-generated code. Users are strongly encouraged not to do so outside of a robust security sandbox. The execution call in execution.py is deliberately commented out to ensure users read this disclaimer before running code in a potentially unsafe manner. See the comment in execution.py for more information and instructions.
56
+
57
+ - GSM8K
58
+
59
+ ```Shell
60
+ # Qwen-7B
61
+ python evaluate_gsm8k.py
62
+
63
+ # Qwen-7B-Chat
64
+ python evaluate_chat_gsm8k.py # zeroshot
65
+ python evaluate_chat_gsm8k.py --use-fewshot # fewshot
66
+ ```
67
+
68
+ - PLUGIN
69
+
70
+ This script is used to reproduce the results of the ReAct and Hugging Face Agent in the Tool Usage section of the README document.
71
+
72
+ ```Shell
73
+ # Qwen-7B-Chat
74
+ mkdir data;
75
+ cd data;
76
+ wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_positive.jsonl;
77
+ wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/opensource_data/exam_plugin_v1/exam_plugin_v1_react_negative.jsonl;
78
+ cd ..;
79
+ pip install json5;
80
+ pip install jsonlines;
81
+ pip install rouge_score;
82
+ python evaluate_plugin.py --eval-react-positive --eval-react-negative --eval-hfagent
83
+ ```
eval/evaluate_ceval.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+
8
+ from typing import List
9
+ from tqdm import tqdm
10
+ from transformers.trainer_utils import set_seed
11
+
12
+
13
+ '''
14
+ wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
15
+ mkdir data/ceval
16
+ mv ceval-exam.zip data/ceval
17
+ cd data/ceval; unzip ceval-exam.zip
18
+ cd ../../
19
+ python evaluate_ceval.py -d data/ceval/
20
+ '''
21
+
22
+ def load_models_tokenizer(args):
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+ from transformers.generation import GenerationConfig
25
+
26
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
27
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
28
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
29
+ return model, tokenizer
30
+
31
+
32
+ def format_example(line, include_answer=True):
33
+ example = '问题:' + line['question']
34
+ for choice in choices:
35
+ example += f'\n{choice}. {line[f"{choice}"]}'
36
+
37
+ if include_answer:
38
+ example += '\n答案:' + line["answer"] + '\n\n'
39
+ else:
40
+ example += '\n答案:'
41
+ return example
42
+
43
+
44
+ def generate_few_shot_prompt(k, subject, dev_df):
45
+ prompt = ''
46
+ if k == -1:
47
+ k = dev_df.shape[0]
48
+ for i in range(k):
49
+ prompt += format_example(
50
+ dev_df.iloc[i, :],
51
+ include_answer=True,
52
+ )
53
+ return prompt
54
+
55
+
56
+ def get_logits(tokenizer, model, inputs: List[str]):
57
+ input_ids = tokenizer(inputs, padding=False)['input_ids']
58
+ input_ids = torch.tensor(input_ids, device=model.device)
59
+ tokens = {'input_ids': input_ids}
60
+
61
+ outputs = model(input_ids)['logits']
62
+ logits = outputs[:, -1, :]
63
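+ # 注:此处返回的是softmax归一化后的概率(变量名沿用log_probs,并非对数概率)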
+ log_probs = torch.nn.functional.softmax(logits, dim=-1)
64
+ return log_probs, {'tokens': tokens}
65
+
66
+
67
+ @torch.no_grad()
68
+ def eval_subject(
69
+ model,
70
+ tokenizer,
71
+ subject_name,
72
+ test_df,
73
+ k=5,
74
+ dev_df=None,
75
+ few_shot=False,
76
+ save_result_dir=None,
77
+ **kwargs
78
+ ):
79
+ result = []
80
+ score = []
81
+
82
+ few_shot_prompt = generate_few_shot_prompt(
83
+ k, subject_name, dev_df) if few_shot else ''
84
+ all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
85
+ if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
86
+
87
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
88
+ question = format_example(row, include_answer=False)
89
+ full_prompt = few_shot_prompt + question
90
+
91
+ output, input_info = get_logits(tokenizer, model, [full_prompt])
92
+ assert output.shape[0] == 1
93
+ logits = output.flatten()
94
+
95
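+ # 分别取选项A/B/C/D对应token的概率并归一化,取概率最大者作为预测答案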
+ softval = torch.nn.functional.softmax(
96
+ torch.tensor(
97
+ [
98
+ logits[tokenizer("A")['input_ids']],
99
+ logits[tokenizer("B")['input_ids']],
100
+ logits[tokenizer("C")['input_ids']],
101
+ logits[tokenizer("D")['input_ids']],
102
+ ]
103
+ ),
104
+ dim=0,
105
+ )
106
+ if softval.dtype in {torch.bfloat16, torch.float16}:
107
+ softval = softval.to(dtype=torch.float32)
108
+ probs = softval.detach().cpu().numpy()
109
+
110
+ for i, choice in enumerate(choices):
111
+ all_probs[f'prob_{choice}'].append(probs[i])
112
+ pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
113
+
114
+ if 'answer' in row:
115
+ correct = 1 if pred == row['answer'] else 0
116
+ score.append(correct)
117
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
118
+ result.append(pred)
119
+
120
+ if score:
121
+ correct_ratio = 100 * sum(score) / len(score)
122
+ if args.debug: print(subject_name, correct_ratio)
123
+ else:
124
+ correct_ratio = 0
125
+ if save_result_dir:
126
+ test_df['model_output'] = result
127
+ for i, choice in enumerate(choices):
128
+ test_df[f'prob_{choice}'] = (all_probs[f'prob_{choice}'])
129
+ if score:
130
+ test_df["correctness"] = score
131
+ os.makedirs(save_result_dir, exist_ok=True)
132
+ test_df.to_csv(os.path.join(
133
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
134
+
135
+ return correct_ratio
136
+
137
+
138
+ def cal_ceval(res):
139
+ acc_sum_dict = dict()
140
+ acc_norm_sum_dict = dict()
141
+ cnt_dict = dict()
142
+ acc_sum = 0.
143
+ cnt = 0
144
+ hard_cnt = 0
145
+ hard_acc_sum = 0.
146
+ for tt in res.keys():
147
+ name = tt.split('-')[-1]
148
+ acc_sum += float(res[tt])
149
+ cnt += 1
150
+ class_ = TASK_NAME_MAPPING[name][2]
151
+ if class_ not in acc_sum_dict:
152
+ acc_sum_dict[class_] = 0.
153
+ acc_norm_sum_dict[class_] = 0.
154
+ cnt_dict[class_] = 0.
155
+ if name in hard_list:
156
+ hard_cnt += 1
157
+ hard_acc_sum += float(res[tt])
158
+ acc_sum_dict[class_] += float(res[tt])
159
+ cnt_dict[class_] += 1
160
+ print('\n\n\n')
161
+ for k in ['STEM', 'Social Science', 'Humanities', 'Other']:
162
+ if k in cnt_dict:
163
+ print('%s acc: %.2f ' % (
164
+ k, acc_sum_dict[k] / cnt_dict[k]))
165
+ if hard_cnt > 0:
166
+ print('Hard acc:%.2f ' % (hard_acc_sum / hard_cnt))
167
+ print('AVERAGE acc:%.2f ' % (acc_sum / cnt))
168
+
169
+
170
+ TASK_NAME_MAPPING = {
171
+ "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
172
+ "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
173
+ "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
174
+ "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
175
+ "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
176
+ "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
177
+ "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
178
+ "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
179
+ "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
180
+ "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
181
+ "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
182
+ "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
183
+ "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
184
+ "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
185
+ "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
186
+ "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
187
+ "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
188
+ "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
189
+ "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
190
+ "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
191
+ "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
192
+ "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
193
+ "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
194
+ "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
195
+ "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
196
+ "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
197
+ "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
198
+ "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
199
+ "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
200
+ "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
201
+ "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
202
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
203
+ "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
204
+ "law": ["Law", "\u6cd5\u5b66", "Humanities"],
205
+ "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
206
+ "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
207
+ "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
208
+ "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
209
+ "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
210
+ "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
211
+ "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
212
+ "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
213
+ "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
214
+ "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
215
+ "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
216
+ "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
217
+ "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
218
+ "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
219
+ "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
220
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
221
+ "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
222
+ "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"]
223
+ }
224
+ hard_list = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_physics', 'college_chemistry', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry']
225
+ choices = ["A", "B", "C", "D"]
226
+
227
+
228
+ def main(args):
229
+ model, tokenizer = load_models_tokenizer(args)
230
+
231
+ dev_result = {}
232
+ for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
233
+ val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
234
+ dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
235
+ # test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
236
+ val_df = pd.read_csv(val_file_path)
237
+ dev_df = pd.read_csv(dev_file_path)
238
+ # test_df = pd.read_csv(test_file_path)
239
+
240
+ score = eval_subject(model, tokenizer, subject_name, val_df, dev_df=dev_df, k=5, few_shot=True,
241
+ save_result_dir=f"outs/ceval_eval_result")
242
+ dev_result[subject_name] = score
243
+ cal_ceval(dev_result)
244
+
245
+
246
+ if __name__ == '__main__':
247
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
248
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
249
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
250
+
251
+ """Provide extra arguments required for tasks."""
252
+ group = parser.add_argument_group(title='Evaluation options')
253
+ group.add_argument('-d', '--eval_data_path', type=str, required=True,
254
+ help='Path to eval data')
255
+ group.add_argument("--max-seq-len", type=int, default=2048,
256
+ help='Size of the output generated text.')
257
+ group.add_argument("--debug", action='store_true', default=False,
258
+ help='Print infos.')
259
+
260
+ args = parser.parse_args()
261
+ set_seed(args.seed)
262
+
263
+ main(args)
eval/evaluate_chat_ceval.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+ import re
8
+ from thefuzz import process
9
+ from typing import List
10
+ from tqdm import tqdm
11
+ from transformers.trainer_utils import set_seed
12
+
13
+ '''
14
+ wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
15
+ mkdir data/ceval
16
+ mv ceval-exam.zip data/ceval
17
+ cd data/ceval; unzip ceval-exam.zip
18
+ cd ../../
19
+
20
+ pip install thefuzz
21
+ python eval/evaluate_chat_ceval.py -d data/ceval
22
+ '''
23
+
24
+ def load_models_tokenizer(args):
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+ from transformers.generation import GenerationConfig
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
29
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
30
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
31
+ model.generation_config.do_sample = False # use greedy decoding
32
+ return model, tokenizer
33
+
34
+ def process_before_extraction(gen, question, choice_dict):
35
+ # Example Prompt:
36
+ # 关于传输层的面向连接服务的特性是____。
37
+ # A. 既不保证可靠,也不保证按序交付
38
+ # B. 不保证可靠,但保证按序交付
39
+ # C. 保证可靠,但不保证按序交付
40
+ # D. 既保证可靠,也保证按序交付
41
+ # Example Model Output:
42
+ # 关于传输层的面向连接服务的特性是既保证可靠,也保证按序交付
43
+ # Processed Output:
44
+ # 答案是D
45
+
46
+ question_split = question.rstrip("。").split("。")[-1].split("_")
47
+
48
+ # replacing the question
49
+ if len(question_split[0].strip()) > 4:
50
+ gen = gen.replace(question_split[0], "答案是")
51
+ if len(question_split[-1].strip()) > 4:
52
+ gen = gen.replace(question_split[-1], "")
53
+
54
+ # replace the choice by letter in the generated sentence
55
+ # from longest one to shortest one
56
+ for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
57
+ gen = gen.replace(val.rstrip("。"), key)
58
+ return gen
59
+
60
+ def count_substr(gen, pattern):
61
+ return len(re.findall(pattern, gen))
62
+
63
+ def extract_choice(gen, prompt, choice_list):
64
+ # 答案是A | 选项是A | 应该选A选项
65
+ res = re.search(r"(?:(?:选|选择|选定)|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$)", gen)
66
+
67
+ # A选项正确 | A选项符合题意
68
+ if res is None:
69
+ res = re.search(r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对|符合))[^ABCD]{0,4}?(?:正确|对|符合)", gen)
70
+
71
+ # 直接输出 A
72
+ if res is None:
73
+ res = re.search(r"^(A|B|C|D)(?:。|\.|,|,|.|$)", gen)
74
+
75
+ # 获取第一个出现的字母
76
+ if res is None:
77
+ res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
78
+
79
+ if res is None:
80
+ return choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
81
+ else:
82
+ return res.group(1)
83
+
84
+ def format_example(line):
85
+ example = line['question'] + "\n\n"
86
+ for choice in choices:
87
+ example += f'{choice}. {line[f"{choice}"]}\n'
88
+ return example
89
+
90
+ def extract_answer(response, row):
91
+ prompt = row['question']
92
+ gen = process_before_extraction(response, prompt, {choice: row[choice] for choice in choices})
93
+ if not isinstance(prompt, str):
94
+ prompt = prompt[0]
95
+ pred = extract_choice(gen, prompt, [row[choice] for choice in choices])
96
+ return pred
97
+
98
+ @torch.no_grad()
99
+ def eval_subject(
100
+ model,
101
+ tokenizer,
102
+ subject_name,
103
+ test_df,
104
+ save_result_dir=None,
105
+ overwrite=False,
106
+ **kwargs
107
+ ):
108
+
109
+ result_path = os.path.join(save_result_dir, f'{subject_name}_result.csv')
110
+ if not overwrite and os.path.exists(result_path):
111
+ print(f"{result_path} existed, skip!")
112
+ score = []
113
+ for (_, datarow), (_, resultrow) in zip(test_df.iterrows(), pd.read_csv(result_path).iterrows()):
114
+ pred = extract_answer(resultrow['model_response'], datarow)
115
+ correct = 1 if pred == datarow['answer'] else 0
116
+ score.append(correct)
117
+ correct_ratio = 100 * sum(score) / len(score)
118
+ return correct_ratio
119
+
120
+ responses = []
121
+ result = []
122
+ score = []
123
+
124
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
125
+ question = format_example(row)
126
+
127
+ response, history = model.chat(
128
+ tokenizer,
129
+ question,
130
+ history=None,
131
+ )
132
+ print(question)
133
+ print(response)
134
+ pred = extract_answer(response, row)
135
+ print(pred)
136
+ print("======================")
137
+
138
+ if 'answer' in row:
139
+ correct = 1 if pred == row['answer'] else 0
140
+ score.append(correct)
141
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
142
+ responses.append(response)
143
+ result.append(pred)
144
+
145
+ if score:
146
+ correct_ratio = 100 * sum(score) / len(score)
147
+ if args.debug: print(subject_name, correct_ratio)
148
+ else:
149
+ correct_ratio = 0
150
+ if save_result_dir:
151
+ test_df['model_response'] = responses
152
+ test_df['model_output'] = result
153
+ if score:
154
+ test_df["correctness"] = score
155
+ os.makedirs(save_result_dir, exist_ok=True)
156
+ test_df.to_csv(result_path, encoding="utf-8", index=False)
157
+
158
+ return correct_ratio
159
+
160
+
161
+ def cal_ceval(res):
162
+ acc_sum_dict = dict()
163
+ acc_norm_sum_dict = dict()
164
+ cnt_dict = dict()
165
+ acc_sum = 0.
166
+ cnt = 0
167
+ hard_cnt = 0
168
+ hard_acc_sum = 0.
169
+ for tt in res.keys():
170
+ name = tt.split('-')[-1]
171
+ acc_sum += float(res[tt])
172
+ cnt += 1
173
+ class_ = TASK_NAME_MAPPING[name][2]
174
+ if class_ not in acc_sum_dict:
175
+ acc_sum_dict[class_] = 0.
176
+ acc_norm_sum_dict[class_] = 0.
177
+ cnt_dict[class_] = 0.
178
+ if name in hard_list:
179
+ hard_cnt += 1
180
+ hard_acc_sum += float(res[tt])
181
+ acc_sum_dict[class_] += float(res[tt])
182
+ cnt_dict[class_] += 1
183
+ print('\n\n\n')
184
+ for k in ['STEM', 'Social Science', 'Humanities', 'Other']:
185
+ if k in cnt_dict:
186
+ print('%s acc: %.2f ' % (
187
+ k, acc_sum_dict[k] / cnt_dict[k]))
188
+ if hard_cnt > 0:
189
+ print('Hard acc:%.2f ' % (hard_acc_sum / hard_cnt))
190
+ print('AVERAGE acc:%.2f ' % (acc_sum / cnt))
191
+
192
+
193
+ TASK_NAME_MAPPING = {
194
+ "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
195
+ "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
196
+ "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
197
+ "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
198
+ "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
199
+ "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
200
+ "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
201
+ "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
202
+ "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
203
+ "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
204
+ "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
205
+ "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
206
+ "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
207
+ "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
208
+ "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
209
+ "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
210
+ "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
211
+ "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
212
+ "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
213
+ "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
214
+ "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
215
+ "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
216
+ "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
217
+ "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
218
+ "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
219
+ "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
220
+ "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
221
+ "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
222
+ "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
223
+ "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
224
+ "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
225
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
226
+ "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
227
+ "law": ["Law", "\u6cd5\u5b66", "Humanities"],
228
+ "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
229
+ "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
230
+ "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
231
+ "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
232
+ "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
233
+ "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
234
+ "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
235
+ "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
236
+ "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
237
+ "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
238
+ "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
239
+ "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
240
+ "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
241
+ "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
242
+ "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
243
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
244
+ "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
245
+ "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"]
246
+ }
247
+ hard_list = ['advanced_mathematics', 'discrete_mathematics', 'probability_and_statistics', 'college_physics', 'college_chemistry', 'high_school_mathematics', 'high_school_physics', 'high_school_chemistry']
248
+ choices = ["A", "B", "C", "D"]
249
+
250
+
251
+ def main(args):
252
+ print("loading model weights")
253
+ if args.checkpoint_path:
254
+ model, tokenizer = load_models_tokenizer(args)
255
+ else:
256
+ model, tokenizer = None, None
257
+ print("model loaded")
258
+ dev_result = {}
259
+ for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
260
+ val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
261
+ # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
262
+ # test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
263
+ val_df = pd.read_csv(val_file_path)
264
+ # dev_df = pd.read_csv(dev_file_path)
265
+ # test_df = pd.read_csv(test_file_path)
266
+
267
+ score = eval_subject(model, tokenizer, subject_name, val_df,
268
+ save_result_dir=f"outs_chat/ceval_eval_result", overwrite=args.overwrite)
269
+ dev_result[subject_name] = score
270
+ cal_ceval(dev_result)
271
+
272
+
273
+ if __name__ == '__main__':
274
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
275
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
276
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
277
+
278
+ """Provide extra arguments required for tasks."""
279
+ group = parser.add_argument_group(title='Evaluation options')
280
+ group.add_argument('-d', '--eval_data_path', type=str, required=True,
281
+ help='Path to eval data')
282
+ group.add_argument("--debug", action='store_true', default=False,
283
+ help='Print infos.')
284
+ group.add_argument("--overwrite", action='store_true', default=False,
285
+ help='Overwrite existed results')
286
+
287
+ args = parser.parse_args()
288
+ set_seed(args.seed)
289
+
290
+ main(args)
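A minimal usage sketch of the extraction helpers defined above, reusing the worked example from the script's inline comments (it assumes process_before_extraction and extract_choice are in scope; the data is only for illustration):

question = "关于传输层的面向连接服务的特性是____。"
choice_dict = {
    "A": "既不保证可靠,也不保证按序交付",
    "B": "不保证可靠,但保证按序交付",
    "C": "保证可靠,但不保证按序交付",
    "D": "既保证可靠,也保证按序交付",
}
response = "关于传输层的面向连接服务的特性是既保证可靠,也保证按序交付"
gen = process_before_extraction(response, question, choice_dict)  # -> "答案是D"
pred = extract_choice(gen, question, list(choice_dict.values()))  # -> "D"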
eval/evaluate_chat_gsm8k.py ADDED
@@ -0,0 +1,137 @@
1
+ import random
2
+ import tqdm
3
+ import os
4
+ import re
5
+ import sys
6
+ import torch
7
+ import numpy as np
8
+ import jsonlines
9
+ import argparse
10
+ import json
11
+ from pathlib import Path
12
+ from datasets import load_from_disk,load_dataset
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from transformers.generation import GenerationConfig
15
+
16
+ '''
17
+ python eval/evaluate_chat_gsm8k.py [--use-fewshot]
18
+ '''
19
+
20
+ INVALID_ANS = "[invalid]"
21
+ DEVICE = "cuda:0"
22
+
23
+ def doc_to_text(doc, use_fewshot):
24
+ if use_fewshot:
25
+ context = "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\n" \
26
+ "Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n\n" \
27
+ "Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\n" \
28
+ "Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n\n" \
29
+ "Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\n" \
30
+ "When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n\n" \
31
+ "Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\n" \
32
+ "For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n\n" \
33
+ f"Question: {doc['question']}\nLet's think step by step"
34
+ else:
35
+ context = doc['question']
36
+ return context
37
+
38
+ def decode(tokens_list, tokenizer, raw_text_len):
39
+ sents = []
40
+ # print(len(tokens_list))
41
+ for tokens in tokens_list:
42
+ tokens = tokens.cpu().numpy().tolist()
43
+ sent = tokenizer.tokenizer.decode(
44
+ tokens[raw_text_len:])
45
+ sent = sent.split('<|endoftext|>')[0]
46
+ sent = sent.split('\n\n\n')[0]
47
+ sent = sent.split("\n\n")[0]
48
+ sent = sent.split("Question:")[0]
49
+ sents.append(sent)
50
+ return sents
51
+
52
+ def generate_sample(model, tokenizer, question):
53
+ response, history = model.chat(
54
+ tokenizer,
55
+ question,
56
+ history=None,
57
+ )
58
+ print(question)
59
+ print("-------------")
60
+ print(response)
61
+ print("=============")
62
+ return response
63
+
64
+
65
+ def extract_answer_hf(completion):
66
+ def _get_last_digit(s):
67
+ _PAT_LAST_DIGIT = re.compile(r"(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))")
68
+ match = list(_PAT_LAST_DIGIT.finditer(s))
69
+ if match:
70
+ last_digit = match[-1].group().replace(",", "").replace("+", "")
71
+ # print(f"The last digit in {s} is {last_digit}")
72
+ else:
73
+ last_digit = None
74
+ print(f"No digits found in {s!r}")
75
+ return last_digit
76
+
77
+ job_gen = completion.strip('.').replace('\n', '\\n')
78
+ last_digit = _get_last_digit(job_gen)
79
+ if last_digit is not None:
80
+ return eval(last_digit)
81
+ else:
82
+ return INVALID_ANS
83
+
84
+ def extract_answer(completion):
85
+ try:
86
+ last_number = re.findall(r'\d+', completion)[-1]
87
+ return eval(last_number)
88
+ except:
89
+ return INVALID_ANS
90
+
91
+ def is_correct(completion, answer):
92
+ gold = extract_answer(answer)
93
+ assert gold != INVALID_ANS, "No ground truth answer found in the document."
94
+ return extract_answer(completion) == gold
95
+
96
+ if __name__ == '__main__':
97
+
98
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
99
+ parser.add_argument("-c", "--checkpoint-path", type=Path, help="Checkpoint path", default="Qwen/Qwen-7B-Chat")
100
+ parser.add_argument("-f","--sample-input-file", type=str, default=None)
101
+ parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl")
102
+ parser.add_argument("--use-fewshot", action="store_true")
103
+
104
+ args = parser.parse_args()
105
+
106
+ if args.sample_input_file is not None:
107
+ dataset = load_from_disk(args.sample_input_file)
108
+ else:
109
+ dataset = load_dataset("gsm8k", "main")
110
+
111
+ print('Loading tokenizer ...')
112
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
113
+
114
+ print('Loading model ...')
115
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
116
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
117
+ model.generation_config.do_sample = False # use greedy decoding
118
+
119
+ test = dataset["test"]
120
+
121
+ f_output = open(args.sample_output_file, 'w', encoding='utf-8')
122
+ tot_length = test.num_rows
123
+ acc_res = []
124
+ for doc in tqdm.tqdm(test):
125
+ context = doc_to_text(doc, args.use_fewshot)
126
+ print(context)
127
+ completion = generate_sample(model, tokenizer, context)
128
+ answer = doc["answer"]
129
+ acc = is_correct(completion, answer)
130
+ doc["completion"] = completion
131
+ doc["acc"] = acc
132
+ f_output.write(json.dumps(doc, ensure_ascii=False) + "\n")
133
+ f_output.flush()
134
+ acc_res.append(acc)
135
+
136
+ f_output.close()
137
+ print("4-shot Acc: " if args.use_fewshot else "Zero-shot Acc", np.mean(acc_res))
eval/evaluate_chat_humaneval.py ADDED
@@ -0,0 +1,82 @@
1
+ import random
2
+ import tqdm
3
+ import os
4
+ import sys
5
+ import torch
6
+ import jsonlines
7
+ import argparse
9
+ from pathlib import Path
10
+ import re
11
+ import textwrap
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ from transformers.generation import GenerationConfig
14
+
15
+ """
16
+ Get the HumanEval.jsonl file from [here](https://github.com/openai/human-eval/tree/master/data)
17
+
18
+ python eval/evaluate_chat_humaneval.py -f HumanEval.jsonl -o HumanEval_res.jsonl
19
+ git clone https://github.com/openai/human-eval
20
+ pip install -e human-eval
21
+ evaluate_functional_correctness HumanEval_res.jsonl
22
+ """
23
+
24
+ DEVICE = "cuda:0"
25
+
26
+ def extract_code(text, entry_point):
27
+
28
+ # 正则表达式匹配代码块
29
+ code_block_pattern = re.compile(rf"```(?:[Pp]ython\n)?.*?def\s+{entry_point}.*?:\n(.*?)\n```", re.DOTALL)
30
+ code_block = code_block_pattern.search(text)
31
+ if code_block is None:
32
+ code_block_pattern = re.compile(rf"def\s+{entry_point}.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
33
+ code_block = code_block_pattern.search(text)
34
+ if code_block is None:
35
+ code_block_pattern = re.compile(rf"def.*?:\n(.*?)(?:\n(?!\n*(?: |\t))|$)", re.DOTALL)
36
+ code_block = code_block_pattern.search(text)
37
+
38
+ if code_block is not None:
39
+ return code_block.group(1)
40
+ else:
41
+ # if no code block is found, assume the LM is simply filling the code
42
+ return textwrap.indent(text, ' ' * 4)
43
+
44
+ def generate_sample(model, tokenizer, question, entry_point):
45
+ response, history = model.chat(
46
+ tokenizer,
47
+ question,
48
+ history=None,
49
+ )
50
+ print(question)
51
+ print(response)
52
+ answer = extract_code(response, entry_point)
53
+ return answer, response
54
+
55
+ if __name__ == '__main__':
56
+
57
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
58
+ parser.add_argument("-c", "--checkpoint-path", type=Path, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
59
+ parser.add_argument("-f","--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
60
+ parser.add_argument("-o","--sample-output-file", type=str, default="HumanEval_res.jsonl")
61
+
62
+
63
+ args = parser.parse_args()
64
+ print('Loading tokenizer ...')
65
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
66
+
67
+ print('Loading model ...')
68
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
69
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
70
+ model.generation_config.do_sample = False # use greedy decoding
71
+
72
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
73
+
74
+ f = jsonlines.open(args.sample_input_file)
75
+ with f_output as output:
76
+ for jobj in tqdm.tqdm(f, desc='task_idx'):
77
+ prompt = "Help me fill the following code.\n" + jobj['prompt']
78
+ task_id = jobj['task_id']
79
+ answer, response = generate_sample(model, tokenizer, prompt, jobj['entry_point'])
80
+ gen_jobjs = {'task_id': task_id, "completion": answer, 'response': response}
81
+ output.write(gen_jobjs)
82
+ f_output.close()
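A minimal sketch of what extract_code above recovers from a typical chat response (hypothetical response text; assumes extract_code is in scope):

response = (
    "Here is the completed function:\n"
    "```python\n"
    "def add(a, b):\n"
    "    return a + b\n"
    "```"
)
assert extract_code(response, "add") == "    return a + b"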
eval/evaluate_chat_mmlu.py ADDED
@@ -0,0 +1,207 @@
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+ import re
8
+ from thefuzz import process
9
+ from typing import List
10
+ from tqdm import tqdm
11
+ from transformers.trainer_utils import set_seed
12
+
13
+ '''
14
+ wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
15
+ mkdir data/mmlu
16
+ mv data.tar data/mmlu
17
+ cd data/mmlu; tar xf data.tar
18
+ cd ../../
19
+
20
+ pip install thefuzz
21
+ python eval/evaluate_chat_mmlu.py -d data/mmlu/data/
22
+ '''
23
+
24
+ def load_models_tokenizer(args):
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+ from transformers.generation import GenerationConfig
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
29
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True, bf16=True, use_flash_attn=True).eval()
30
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
31
+ model.generation_config.do_sample = False # use greedy decoding
32
+ return model, tokenizer
33
+
34
+
35
+ def format_example(line):
36
+ example = 'The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n' + line['question'] + "\n"
37
+ for choice in choices:
38
+ example += f'{choice}. {line[f"{choice}"]}\n'
39
+ return example
40
+
41
+
42
+ def process_before_extraction(gen, choice_dict):
43
+ # replace the choice by letter in the generated sentence
44
+ # from longest one to shortest one
45
+ for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
46
+ pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE)
47
+ gen = pattern.sub(key, gen)
48
+ return gen
49
+
50
+ def extract_choice(gen, choice_list):
51
+ # answer is A | choice is A | choose A
52
+ res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen)
53
+
54
+ # A is correct | A is right
55
+ if res is None:
56
+ res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen)
57
+
58
+ # straight answer: A
59
+ if res is None:
60
+ res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
61
+
62
+ # simply extract the first appearred letter
63
+ if res is None:
64
+ res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
65
+
66
+ if res is None:
67
+ return choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
68
+ else:
69
+ return res.group(1)
70
+
71
+ def extract_answer(response, row):
72
+ gen = process_before_extraction(response, {choice: row[choice] for choice in choices})
73
+ pred = extract_choice(gen, [row[choice] for choice in choices])
74
+ return pred
75
+
76
+ @torch.no_grad()
77
+ def eval_subject(
78
+ model,
79
+ tokenizer,
80
+ subject_name,
81
+ test_df,
82
+ save_result_dir=None,
83
+ overwrite=False,
84
+ **kwargs
85
+ ):
86
+ result_path = os.path.join(save_result_dir, f'{subject_name}_result.csv')
87
+ if not overwrite and os.path.exists(result_path):
88
+ print(f"{result_path} existed, skip!")
89
+ score = []
90
+ for (_, datarow), (_, resultrow) in zip(test_df.iterrows(), pd.read_csv(result_path).iterrows()):
91
+ # pred = extract_answer(resultrow['model_response'], datarow)
92
+ pred = resultrow['model_output']
93
+ correct = 1 if pred == datarow['answer'] else 0
94
+ score.append(correct)
95
+ return score
96
+
97
+ result = []
98
+ score = []
99
+
100
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
101
+ question = format_example(row)
102
+
103
+ response, history = model.chat(
104
+ tokenizer,
105
+ question,
106
+ history=None,
107
+ )
108
+ print(question)
109
+ print(response)
110
+ pred = extract_answer(response, row)
111
+ print(pred)
112
+ print("======================")
113
+
114
+ if 'answer' in row:
115
+ correct = 1 if pred == row['answer'] else 0
116
+ score.append(correct)
117
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
118
+ responses.append(response)
+ result.append(pred)
119
+
120
+ if save_result_dir:
121
+ test_df['model_output'] = result
122
+ test_df['model_response'] = responses
123
+ if score:
124
+ test_df["correctness"] = score
125
+ os.makedirs(save_result_dir, exist_ok=True)
126
+ test_df.to_csv(os.path.join(
127
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
128
+
129
+ return score
130
+
131
+
132
+ def cal_mmlu(res):
133
+ acc_sum_dict = dict()
134
+ acc_norm_sum_dict = dict()
135
+ cnt_dict = dict()
136
+ acc_sum = 0.
137
+ cnt = 0
138
+ hard_cnt = 0
139
+ hard_acc_sum = 0.
140
+
141
+ for class_ in TASK_NAME_MAPPING.keys():
142
+ acc_sum_dict[class_] = 0.
143
+ acc_norm_sum_dict[class_] = 0.
144
+ cnt_dict[class_] = 0.
145
+
146
+ for tt in TASK_NAME_MAPPING[class_]:
147
+ acc_sum += sum(res[tt])
148
+ cnt += len(res[tt])
149
+
150
+ acc_sum_dict[class_] += sum(res[tt])
151
+ cnt_dict[class_] += len(res[tt])
152
+
153
+ print('\n\n\n')
154
+ for k in TASK_NAME_MAPPING.keys():
155
+ if k in cnt_dict:
156
+ print('%s ACC: %.2f ' % (
157
+ k, acc_sum_dict[k] * 100 / cnt_dict[k]))
158
+ print('AVERAGE ACC:%.2f ' % (acc_sum *100 / cnt))
159
+
160
+
161
+ def main(args):
162
+ print("loading model weights")
163
+ if args.checkpoint_path is not None:
164
+ model, tokenizer = load_models_tokenizer(args)
165
+ else:
166
+ model, tokenizer = None, None
167
+ print("model loaded")
168
+
169
+ dev_result = {}
170
+ for subject_name in tqdm(SUBJECTS):
171
+ # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
172
+ # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
173
+ test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
174
+ # val_df = pd.read_csv(val_file_path, names=['question','A','B','C','D','answer'])
175
+ # dev_df = pd.read_csv(dev_file_path, names=['question','A','B','C','D','answer'])
176
+ test_df = pd.read_csv(test_file_path, names=['question','A','B','C','D','answer'])
177
+
178
+ score = eval_subject(model, tokenizer, subject_name, test_df, save_result_dir=f"outs_chat/mmlu_eval_result", overwrite=args.overwrite)
179
+ dev_result[subject_name] = score
180
+ cal_mmlu(dev_result)
181
+
182
+
183
+ TASK_NAME_MAPPING = {'stem': ['abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics', 'high_school_statistics', 'machine_learning'],
184
+ 'Humanities': ['formal_logic', 'high_school_european_history', 'high_school_us_history', 'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law', 'world_religions'],
185
+ 'other': ['business_ethics', 'college_medicine', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology', 'global_facts', 'clinical_knowledge'],
186
+ 'social': ['econometrics', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology', 'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy']}
187
+ SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
188
+ choices = ["A", "B", "C", "D"]
189
+
190
+ if __name__ == '__main__':
191
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
192
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
193
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
194
+
195
+ """Provide extra arguments required for tasks."""
196
+ group = parser.add_argument_group(title='Evaluation options')
197
+ group.add_argument('-d', '--eval_data_path', type=str,
198
+ help='Path to eval data')
199
+ group.add_argument("--debug", action='store_true', default=False,
200
+ help='Print infos.')
201
+ group.add_argument("--overwrite", action='store_true', default=False,
202
+ help='Overwrite existed results')
203
+
204
+ args = parser.parse_args()
205
+ set_seed(args.seed)
206
+
207
+ main(args)
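A minimal sketch of the English extraction path above (hypothetical question and response; assumes process_before_extraction and extract_choice are in scope):

row = {"A": "Paris", "B": "London", "C": "Berlin", "D": "Madrid"}
response = "The capital of France is Paris, so the answer is Paris."
gen = process_before_extraction(response, row)  # choice text replaced by its letter
pred = extract_choice(gen, list(row.values()))  # -> "A"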
eval/evaluate_cmmlu.py ADDED
@@ -0,0 +1,271 @@
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+ from collections import defaultdict
8
+
9
+ from typing import List
10
+ from tqdm import tqdm
11
+ from transformers.trainer_utils import set_seed
12
+
13
+
14
+ '''
15
+ wget https://huggingface.co/datasets/haonan-li/cmmlu/resolve/main/cmmlu_v1_0_1.zip
16
+ mkdir data/cmmlu
17
+ mv cmmlu_v1_0_1.zip data/cmmlu
18
+ cd data/cmmlu; unzip cmmlu_v1_0_1.zip
19
+ cd ../../
20
+ python eval/evaluate_cmmlu.py -d data/cmmlu/
21
+ '''
22
+
23
+ def load_models_tokenizer(args):
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+ from transformers.generation import GenerationConfig
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
28
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
29
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
30
+ return model, tokenizer
31
+
32
+
33
+ def format_example(line, include_answer=True):
34
+ example = '问题:' + line['Question']
35
+ for choice in choices:
36
+ example += f'\n{choice}. {line[f"{choice}"]}'
37
+
38
+ if include_answer:
39
+ example += '\n答案:' + line["Answer"] + '\n\n'
40
+ else:
41
+ example += '\n答案:'
42
+ return example
43
+
44
+
45
+ def generate_few_shot_prompt(k, subject, dev_df):
46
+ prompt = ''
47
+ if k == -1:
48
+ k = dev_df.shape[0]
49
+ for i in range(k):
50
+ prompt += format_example(
51
+ dev_df.iloc[i, :],
52
+ include_answer=True,
53
+ )
54
+ return prompt
55
+
56
+
57
+ def get_logits(tokenizer, model, inputs: List[str]):
58
+ input_ids = tokenizer(inputs, padding=False)['input_ids']
59
+ input_ids = torch.tensor(input_ids, device=model.device)
60
+ tokens = {'input_ids': input_ids}
61
+
62
+ outputs = model(input_ids)['logits']
63
+ logits = outputs[:, -1, :]
64
+ log_probs = torch.nn.functional.softmax(logits, dim=-1)
65
+ return log_probs, {'tokens': tokens}
66
+
67
+
68
+ @torch.no_grad()
69
+ def eval_subject(
70
+ model,
71
+ tokenizer,
72
+ subject_name,
73
+ test_df,
74
+ k=5,
75
+ dev_df=None,
76
+ few_shot=False,
77
+ save_result_dir=None,
78
+ **kwargs
79
+ ):
80
+ result = []
81
+ score = []
82
+
83
+ few_shot_prompt = generate_few_shot_prompt(
84
+ k, subject_name, dev_df) if few_shot else []
85
+ all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
86
+ if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
87
+
88
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
89
+ question = format_example(row, include_answer=False)
90
+ full_prompt = few_shot_prompt + question
91
+
92
+ output, input_info = get_logits(tokenizer, model, [full_prompt])
93
+ assert output.shape[0] == 1
94
+ logits = output.flatten()
95
+
96
+ softval = torch.nn.functional.softmax(
97
+ torch.tensor(
98
+ [
99
+ logits[tokenizer("A")['input_ids']],
100
+ logits[tokenizer("B")['input_ids']],
101
+ logits[tokenizer("C")['input_ids']],
102
+ logits[tokenizer("D")['input_ids']],
103
+ ]
104
+ ),
105
+ dim=0,
106
+ )
107
+ if softval.dtype in {torch.bfloat16, torch.float16}:
108
+ softval = softval.to(dtype=torch.float32)
109
+ probs = softval.detach().cpu().numpy()
110
+
111
+ for i, choice in enumerate(choices):
112
+ all_probs[f'prob_{choice}'].append(probs[i])
113
+ pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
114
+
115
+ if 'Answer' in row:
116
+ correct = 1 if pred == row['Answer'] else 0
117
+ score.append(correct)
118
+ if args.debug: print(f'{question} pred: {pred} ref: {row["Answer"]}')
119
+ result.append(pred)
120
+
121
+ if score:
122
+ correct_ratio = 100 * sum(score) / len(score)
123
+ if args.debug: print(subject_name, correct_ratio)
124
+ else:
125
+ correct_ratio = 0
126
+ if save_result_dir:
127
+ test_df['model_output'] = result
128
+ for i, choice in enumerate(choices):
129
+ test_df[f'prob_{choice}'] = (all_probs[f'prob_{choice}'])
130
+ if score:
131
+ test_df["correctness"] = score
132
+ os.makedirs(save_result_dir, exist_ok=True)
133
+ test_df.to_csv(os.path.join(
134
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
135
+
136
+ return correct_ratio
137
+
138
+
139
+ def cal_cmmlu(res):
140
+ print('\n\n\n')
141
+ res = {k.split('-')[-1]:float(v) for k,v in res.items()}
142
+ for k, v in TASK_NAME_MAPPING.items():
143
+ avg_acc = np.mean(list(map(lambda x: res[x], v)))
144
+ print(f"{k} acc: {avg_acc:.2f}")
145
+ avg_all_acc = np.mean(list(res.values()))
146
+ print(f"AVERAGE acc: {avg_all_acc:.2f}")
147
+
148
+
149
+ subcategories = {
150
+ "agronomy": ['other'],
151
+ "anatomy": ['biology'],
152
+ "ancient_chinese": ['linguistics','china specific'],
153
+ "arts": ['arts'],
154
+ "astronomy": ['physics'],
155
+ "business_ethics": ['business'],
156
+ "chinese_civil_service_exam": ['politics','china specific'],
157
+ "chinese_driving_rule": ['other','china specific'],
158
+ "chinese_food_culture": ['culture','china specific'],
159
+ "chinese_foreign_policy": ['politics','china specific'],
160
+ "chinese_history":['history','china specific'],
161
+ "chinese_literature": ['literature','china specific'],
162
+ "chinese_teacher_qualification": ['education','china specific'],
163
+ "college_actuarial_science":['math'],
164
+ "college_education":['education'],
165
+ "college_engineering_hydrology": ['engineering'],
166
+ "college_law": ['law'],
167
+ "college_mathematics": ['math'],
168
+ "college_medical_statistics":['statistics'],
169
+ "clinical_knowledge": ['other'],
170
+ "college_medicine": ['other'],
171
+ "computer_science": ['computer science'],
172
+ "computer_security": ['other'],
173
+ "conceptual_physics": ['physics'],
174
+ "construction_project_management": ['other','china specific'],
175
+ "economics": ['economics'],
176
+ "education": ['education'],
177
+ "elementary_chinese":['linguistics','china specific'],
178
+ "elementary_commonsense":['other','china specific'],
179
+ "elementary_information_and_technology": ['other'],
180
+ "electrical_engineering": ['engineering'],
181
+ "elementary_mathematics": ['math'],
182
+ "ethnology": ['culture','china specific'],
183
+ "food_science": ['other'],
184
+ "genetics": ['biology'],
185
+ "global_facts": ['global'],
186
+ "high_school_biology": ['biology'],
187
+ "high_school_chemistry": ['chemistry'],
188
+ "high_school_geography": ['geography'],
189
+ "high_school_mathematics": ['math'],
190
+ "high_school_physics": ['physics'],
191
+ "high_school_politics": ['politics','china specific'],
192
+ "human_sexuality": ['other'],
193
+ "international_law": ['law'],
194
+ "journalism": ['sociology'],
195
+ "jurisprudence": ['law'],
196
+ "legal_and_moral_basis": ['other'],
197
+ "logical": ['philosophy'],
198
+ "machine_learning": ['computer science'],
199
+ "management": ['business'],
200
+ "marketing": ['business'],
201
+ "marxist_theory": ['philosophy'],
202
+ "modern_chinese": ['linguistics','china specific'],
203
+ "nutrition": ['other'],
204
+ "philosophy": ['philosophy'],
205
+ "professional_accounting": ['business'],
206
+ "professional_law": ['law'],
207
+ "professional_medicine": ['other'],
208
+ "professional_psychology": ['psychology'],
209
+ "public_relations": ['politics'],
210
+ "security_study": ['politics'],
211
+ "sociology": ['culture'],
212
+ "sports_science": ['other'],
213
+ "traditional_chinese_medicine": ['other','china specific'],
214
+ "virology": ['biology'],
215
+ "world_history":['history'],
216
+ "world_religions": ['global'],
217
+ }
218
+
219
+ categories = {
220
+ "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
221
+ "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
222
+ "Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
223
+ "Other":["other"],
224
+ "China specific": ["china specific"],
225
+ }
226
+
227
+ TASK_NAME_MAPPING = defaultdict(list)
228
+ for k,v in categories.items():
229
+ for subject, subcat in subcategories.items():
230
+ for c in subcat:
231
+ if c in v:
232
+ TASK_NAME_MAPPING[k].append(subject)
233
+
234
+
235
+ choices = ["A", "B", "C", "D"]
236
+
237
+
238
+ def main(args):
239
+ model, tokenizer = load_models_tokenizer(args)
240
+
241
+ test_result = {}
242
+ for subject_name in tqdm(subcategories.keys()):
243
+ dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}.csv')
244
+ test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}.csv')
245
+ dev_df = pd.read_csv(dev_file_path)
246
+ test_df = pd.read_csv(test_file_path)
247
+
248
+ score = eval_subject(model, tokenizer, subject_name, dev_df=dev_df, test_df=test_df, k=5, few_shot=True,
249
+ save_result_dir=f"outs/cmmlu_eval_result")
250
+ test_result[subject_name] = score
251
+ cal_cmmlu(test_result)
252
+
253
+
254
+ if __name__ == '__main__':
255
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
256
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
257
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
258
+
259
+ """Provide extra arguments required for tasks."""
260
+ group = parser.add_argument_group(title='Evaluation options')
261
+ group.add_argument('-d', '--eval_data_path', type=str, required=True,
262
+ help='Path to eval data')
263
+ group.add_argument("--max-seq-len", type=int, default=2048,
264
+ help='Size of the output generated text.')
265
+ group.add_argument("--debug", action='store_true', default=False,
266
+ help='Print infos.')
267
+
268
+ args = parser.parse_args()
269
+ set_seed(args.seed)
270
+
271
+ main(args)
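Unlike the chat evaluations above, this script never generates text: the few-shot prompt ends with "答案:" and the prediction is whichever choice letter receives the highest next-token probability. A condensed sketch of that scoring step (assumes model, tokenizer, numpy as np and a prompt string are in scope, and that each letter encodes to a single token):

probs, _ = get_logits(tokenizer, model, [prompt])
letter_scores = [probs.flatten()[tokenizer(c)["input_ids"]].item() for c in "ABCD"]
pred = "ABCD"[int(np.argmax(letter_scores))]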
eval/evaluate_gsm8k.py ADDED
@@ -0,0 +1,110 @@
1
+ import random
2
+ import tqdm
3
+ import os
4
+ import re
5
+ import sys
6
+ import torch
7
+ import numpy as np
8
+ import jsonlines
9
+ import argparse
11
+ import datasets
12
+ from datasets import load_from_disk,load_dataset
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from transformers.generation import GenerationConfig
15
+
16
+
17
+ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
18
+ INVALID_ANS = "[invalid]"
19
+
20
+ def doc_to_text(doc):
21
+ return fewshot_prompt + "\nQuestion: " + doc["question"] + "\nLet's think step by step\n"
22
+
23
+ def decode(tokens_list, tokenizer, raw_text_len):
24
+ sents = []
25
+ # print(len(tokens_list))
26
+ for tokens in tokens_list:
27
+ tokens = tokens.cpu().numpy().tolist()
28
+ sent = tokenizer.tokenizer.decode(
29
+ tokens[raw_text_len:])
30
+ sent = sent.split('<|endoftext|>')[0]
31
+ sent = sent.split('\n\n\n')[0]
32
+ sent = sent.split("\n\n")[0]
33
+ sent = sent.split("Question:")[0]
34
+ sents.append(sent)
35
+ return sents
36
+
37
+ def generate_sample(model, tokenizer, input_txt):
38
+ input_ids = tokenizer.tokenizer.encode(input_txt)
39
+ raw_text_len = len(input_ids)
40
+ context_enc = torch.tensor(
41
+ [input_ids]).to(model.device)
42
+ print(f"Input text: {input_txt}\n")
43
+ outputs = model.generate(context_enc)
44
+ output_text = decode(outputs,tokenizer,raw_text_len)[0]
45
+ print(f"\nOutput text: {output_text}\n")
46
+ return output_text
47
+
48
+
49
+ def extract_answer_hf(completion):
50
+ match = ANS_RE.search(completion)
51
+ if match:
52
+ match_str = match.group(1).strip()
53
+ match_str = match_str.replace(",", "")
54
+ return eval(match_str)
55
+ else:
56
+ return INVALID_ANS
57
+
58
+ def extract_answer(completion):
59
+ try:
60
+ last_number = re.findall(r'\d+', completion)[-1]
61
+ return eval(last_number)
62
+ except:
63
+ return INVALID_ANS
64
+
65
+ def is_correct(completion, answer):
66
+ gold = extract_answer_hf(answer)
67
+ assert gold != INVALID_ANS, "No ground truth answer found in the document."
68
+ return extract_answer(completion) == gold
69
+
70
+ if __name__ == '__main__':
71
+
72
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
73
+ parser.add_argument("-c", "--checkpoint-path", type=str, help="Checkpoint path", default="Qwen/Qwen-7B")
74
+ parser.add_argument("-f","--sample-input-file", type=str, default=None)
75
+ parser.add_argument("-o","--sample-output-file", type=str, default="gsm8k_res.jsonl")
76
+
77
+ args = parser.parse_args()
78
+
79
+ fewshot_prompt = open("gsm8k_prompt.txt").read()
80
+ if args.sample_input_file is not None:
81
+ dataset = load_from_disk(args.sample_input_file)
82
+ else:
83
+ config = datasets.DownloadConfig(resume_download=True, max_retries=100)
84
+ dataset = load_dataset("gsm8k", 'main', download_config=config)
85
+
86
+ test = dataset["test"]
87
+
88
+ print('Loading tokenizer ...')
89
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
90
+
91
+ print('Loading model ...')
92
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
93
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
94
+ model.generation_config.do_sample = False
95
+
96
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
97
+ tot_length = test.num_rows
98
+ acc_res = []
99
+ for doc in test:
100
+ context = doc_to_text(doc)
101
+ completion = generate_sample(model, tokenizer, context)
102
+ answer= doc["answer"]
103
+ acc = is_correct(completion, answer)
104
+ doc["completion"]=completion
105
+ doc["acc"]=acc
106
+ f_output.write(doc)
107
+ acc_res.append(acc)
108
+
109
+ f_output.close()
110
+ print("Acc: ",np.mean(acc_res))
eval/evaluate_humaneval.py ADDED
@@ -0,0 +1,70 @@
1
+ import random
2
+ import tqdm
3
+ import os
4
+ import sys
5
+ import torch
6
+ import jsonlines
7
+ import argparse
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from transformers.generation import GenerationConfig
11
+
12
+ """
13
+ git clone https://github.com/openai/human-eval
14
+ $ pip install -e human-eval
15
+ evaluate_functional_correctness sample-output-file
16
+ """
17
+
18
+ def decode(tokens_list, tokenizer, raw_text_len):
19
+ sents = []
20
+ # print(len(tokens_list))
21
+ for tokens in tokens_list:
22
+ tokens = tokens.cpu().numpy().tolist()
23
+ sent = tokenizer.tokenizer.decode(
24
+ tokens[raw_text_len:])
25
+ sent = sent.split('<|endoftext|>')[0]
26
+ sent = sent.split('\n\n\n')[0]
27
+ sent = sent.split("\n\n")[0]
28
+ sent = sent.split("def ")[0]
29
+ sents.append(sent)
30
+ return sents
31
+
32
+ def generate_sample(model, tokenizer, input_txt):
33
+ input_ids = tokenizer.tokenizer.encode(input_txt)
34
+ raw_text_len = len(input_ids)
35
+ context_enc = torch.tensor([input_ids] ).to(model.device)
36
+ print(f"Input text: {input_txt}\n")
37
+ outputs = model.generate(context_enc)
38
+ output_text = decode(outputs,tokenizer,raw_text_len)[0]
39
+ print(f"\nOutput text: \n{output_text}\n")
40
+ return output_text
41
+
42
+
43
+ if __name__ == '__main__':
44
+
45
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
46
+ parser.add_argument("-c", "--checkpoint-path", type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
47
+ parser.add_argument("-f","--sample-input-file", type=str, default=None, help="data path to HumanEval.jsonl")
48
+ parser.add_argument("-o","--sample-output-file", type=str, default="HumanEval_res.jsonl")
49
+
50
+
51
+ args = parser.parse_args()
52
+ print('Loading tokenizer ...')
53
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
54
+
55
+ print('Loading model ...')
56
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
57
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
58
+ model.generation_config.do_sample = False
59
+
60
+ f_output = jsonlines.Writer(open(args.sample_output_file, 'w', encoding='utf-8'))
61
+
62
+ f = jsonlines.open(args.sample_input_file)
63
+ with f_output as output:
64
+ for jobj in tqdm.tqdm(f, desc='task_idx'):
65
+ prompt = jobj['prompt']
66
+ task_id = jobj['task_id']
67
+ gen_sents = generate_sample(model, tokenizer, prompt)
68
+ gen_jobjs = {'task_id': task_id, "completion": gen_sents}
69
+ output.write(gen_jobjs)
70
+ f_output.close()
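The chained splits in decode() above truncate the raw continuation at the end-of-text token, at the first blank-line break, and before any follow-up function definition. A small illustration with a made-up continuation:

raw = "    return sorted(numbers)\n\n\ndef next_problem():"
sent = raw.split("<|endoftext|>")[0].split("\n\n\n")[0].split("\n\n")[0].split("def ")[0]
# sent == "    return sorted(numbers)"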
eval/evaluate_mmlu.py ADDED
@@ -0,0 +1,218 @@
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+
8
+ from typing import List
9
+ from tqdm import tqdm
10
+ from transformers.trainer_utils import set_seed
11
+
12
+
13
+ '''
14
+ wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
15
+ mkdir data/mmlu
16
+ mv data.tar data/mmlu
17
+ cd data/mmlu; tar xf data.tar
18
+ cd ../../
19
+ python eval/evaluate_mmlu.py -d data/mmlu/data/
20
+ '''
21
+
22
+
23
+ def load_models_tokenizer(args):
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+ from transformers.generation import GenerationConfig
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
28
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
29
+ model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
30
+ return model, tokenizer
31
+
32
+
33
+ def format_example(line, include_answer=True):
34
+ example = 'Question: ' + line['question']
35
+ for choice in choices:
36
+ example += f'\n{choice}. {line[f"{choice}"]}'
37
+
38
+ if include_answer:
39
+ example += '\nAnswer: ' + line["answer"] + '\n\n'
40
+ else:
41
+ example += '\nAnswer:'
42
+ return example
43
+
44
+
45
+ def generate_few_shot_prompt(k, subject, dev_df):
46
+
47
+ def format_subject(subject):
48
+ l = subject.split("_")
49
+ s = ""
50
+ for entry in l:
51
+ s += " " + entry
52
+ return s.strip()
53
+
54
+ prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
55
+
56
+ if k == -1:
57
+ k = dev_df.shape[0]
58
+ for i in range(k):
59
+ prompt += format_example(
60
+ dev_df.iloc[i, :],
61
+ include_answer=True,
62
+ )
63
+ return prompt
64
+
65
+
66
+ def get_logits(tokenizer, model, inputs: List[str]):
67
+ input_ids = tokenizer(inputs, padding=False)['input_ids']
68
+ input_ids = torch.tensor(input_ids, device=model.device)
69
+
70
+ if input_ids.shape[1] > args.max_seq_len:
71
+ input_ids = input_ids[:, input_ids.shape[1]-args.max_seq_len+1:]
72
+ tokens = {'input_ids': input_ids}
73
+
74
+ outputs = model(input_ids)['logits']
75
+ logits = outputs[:, -1, :]
76
+ log_probs = torch.nn.functional.softmax(logits, dim=-1)
77
+ return log_probs, {'tokens': tokens}
78
+
79
+
80
+ @torch.no_grad()
81
+ def eval_subject(
82
+ model,
83
+ tokenizer,
84
+ subject_name,
85
+ test_df,
86
+ k=5,
87
+ dev_df=None,
88
+ few_shot=False,
89
+ save_result_dir=None,
90
+ **kwargs
91
+ ):
92
+ result = []
93
+ score = []
94
+
95
+ few_shot_prompt = generate_few_shot_prompt(
96
+ k, subject_name, dev_df) if few_shot else []
97
+ all_probs = {'prob_A': [], 'prob_B': [], 'prob_C': [], 'prob_D': []}
98
+ if args.debug: print(f"few_shot_prompt: {few_shot_prompt}")
99
+
100
+ for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
101
+ question = format_example(row, include_answer=False)
102
+ full_prompt = few_shot_prompt + question
103
+
104
+ output, input_info = get_logits(tokenizer, model, [full_prompt])
105
+ assert output.shape[0] == 1
106
+ logits = output.flatten()
107
+
108
+ softval = torch.nn.functional.softmax(
109
+ torch.tensor(
110
+ [
111
+ logits[tokenizer(" A")['input_ids']],
112
+ logits[tokenizer(" B")['input_ids']],
113
+ logits[tokenizer(" C")['input_ids']],
114
+ logits[tokenizer(" D")['input_ids']],
115
+ ]
116
+ ),
117
+ dim=0,
118
+ )
119
+ if softval.dtype in {torch.bfloat16, torch.float16}:
120
+ softval = softval.to(dtype=torch.float32)
121
+ probs = softval.detach().cpu().numpy()
122
+
123
+ for i, choice in enumerate(choices):
124
+ all_probs[f'prob_{choice}'].append(probs[i])
125
+ pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]
126
+
127
+ if 'answer' in row:
128
+ correct = 1 if pred == row['answer'] else 0
129
+ score.append(correct)
130
+ if args.debug: print(f'{question} pred: {pred} ref: {row["answer"]}')
131
+ result.append(pred)
132
+
133
+ if save_result_dir:
134
+ test_df['model_output'] = result
135
+ for i, choice in enumerate(choices):
136
+ test_df[f'prob_{choice}'] = (all_probs[f'prob_{choice}'])
137
+ if score:
138
+ test_df["correctness"] = score
139
+ os.makedirs(save_result_dir, exist_ok=True)
140
+ test_df.to_csv(os.path.join(
141
+ save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)
142
+
143
+ return score
144
+
145
+
146
+ def cal_mmlu(res):
147
+ acc_sum_dict = dict()
148
+ acc_norm_sum_dict = dict()
149
+ cnt_dict = dict()
150
+ acc_sum = 0.
151
+ cnt = 0
152
+ hard_cnt = 0
153
+ hard_acc_sum = 0.
154
+
155
+ for class_ in TASK_NAME_MAPPING.keys():
156
+ acc_sum_dict[class_] = 0.
157
+ acc_norm_sum_dict[class_] = 0.
158
+ cnt_dict[class_] = 0.
159
+
160
+ for tt in TASK_NAME_MAPPING[class_]:
161
+ acc_sum += sum(res[tt])
162
+ cnt += len(res[tt])
163
+
164
+ acc_sum_dict[class_] += sum(res[tt])
165
+ cnt_dict[class_] += len(res[tt])
166
+
167
+ print('\n\n\n', 'total cnt:', cnt, '\n')
168
+ for k in TASK_NAME_MAPPING.keys():
169
+ if k in cnt_dict:
170
+ print('%s ACC: %.2f ' % (
171
+ k, acc_sum_dict[k] / cnt_dict[k] * 100))
172
+ print('AVERAGE ACC:%.2f ' % (acc_sum / cnt * 100))
173
+
174
+
175
+ def main(args):
176
+ model, tokenizer = load_models_tokenizer(args)
177
+
178
+ dev_result = {}
179
+ for subject_name in tqdm(SUBJECTS):
180
+ # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
181
+ dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
182
+ test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
183
+ # val_df = pd.read_csv(val_file_path, names=['question','A','B','C','D','answer'])
184
+ dev_df = pd.read_csv(dev_file_path, names=['question','A','B','C','D','answer'])
185
+ test_df = pd.read_csv(test_file_path, names=['question','A','B','C','D','answer'])
186
+
187
+ score = eval_subject(model, tokenizer, subject_name, test_df, dev_df=dev_df, k=5, few_shot=True,
188
+ save_result_dir=f"outs/mmlu_eval_result")
189
+ dev_result[subject_name] = score
190
+ cal_mmlu(dev_result)
191
+
192
+
193
+ TASK_NAME_MAPPING = {'stem': ['abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics', 'high_school_statistics', 'machine_learning'],
194
+ 'Humanities': ['formal_logic', 'high_school_european_history', 'high_school_us_history', 'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law', 'world_religions'],
195
+ 'other': ['business_ethics', 'college_medicine', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology', 'global_facts', 'clinical_knowledge'],
196
+ 'social': ['econometrics', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology', 'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy']}
197
+ SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
198
+ choices = ["A", "B", "C", "D"]
199
+
200
+ if __name__ == '__main__':
201
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
202
+ parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B")
203
+ parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')
204
+ parser.add_argument('--gpu', type=int, default=0, help='gpu id')
205
+
206
+ """Provide extra arguments required for tasks."""
207
+ group = parser.add_argument_group(title='Evaluation options')
208
+ group.add_argument('-d', '--eval_data_path', type=str,
209
+ help='Path to eval data')
210
+ group.add_argument("--max-seq-len", type=int, default=2048,
211
+ help='Maximum sequence length.')
212
+ group.add_argument("--debug", action='store_true', default=False,
213
+ help='Print debug information.')
214
+
215
+ args = parser.parse_args()
216
+ set_seed(args.seed)
217
+
218
+ main(args)
eval/evaluate_plugin.py ADDED
@@ -0,0 +1,308 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import pprint
5
+
6
+ import json5
7
+ import jsonlines
8
+ from rouge_score import rouge_scorer
9
+ from tqdm import tqdm
10
+ from transformers import Agent, AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.generation import GenerationConfig
12
+ from transformers.tools.evaluate_agent import evaluate_agent
13
+ from transformers.trainer_utils import set_seed
14
+
15
+ data_root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
16
+ 'data')
17
+
18
+
19
+ def is_callable(response, golden):
20
+ return response['action'].strip().lower() == golden['action'].strip(
21
+ ).lower()
22
+
23
+
24
+ def process_res(response):
25
+ # parse response
26
+ response += '\n' # fix not-find bug
27
+ thought = response[:response.find('Action:')].strip()
28
+ action = response[response.find('Action:') +
29
+ len('Action:'):response.find('Action Input:')].strip()
30
+ action_input = response[response.find('Action Input:') +
31
+ len('Action Input:'):response.find('Observation:'
32
+ )].strip()
33
+ #TODO: This parsing result is incorrect if the response contains multiple Actions. To be fixed in the future.
34
+ observation = response[response.find('Observation:') +
35
+ len('Observation:'):response.rfind('Thought:'
36
+ )].strip()
37
+ thought_last = response[response.rfind('Thought:') +
38
+ len('Thought:'):response.find('Final Answer:'
39
+ )].strip()
40
+ final_answer = response[response.find('Final Answer:') +
41
+ len('Final Answer:'):].strip()
42
+ try:
43
+ action_input = json.dumps(json5.loads(action_input),
44
+ ensure_ascii=False,
45
+ sort_keys=True)
46
+ except:
47
+ # print("JSON Load Error:", action_input)
48
+ pass
49
+ res_dict = {
50
+ 'thought': thought,
51
+ 'action': action,
52
+ 'action_input': action_input,
53
+ 'observation': observation,
54
+ 'thought_last': thought_last,
55
+ 'final_answer': final_answer
56
+ }
57
+ return res_dict
58
+
59
+
60
+ class _DummyTokenizer:
61
+ def tokenize(self, text: str):
62
+ return text.split()
63
+
64
+
65
+ def _get_tokenized_string(tokenizer, text_list):
66
+ token_ids_list, tokenized_string_list = [], []
67
+ for text in text_list:
68
+ assert tokenizer is not None
69
+ token_ids = tokenizer.encode(text)
70
+ tokens_bytes = tokenizer.convert_ids_to_tokens(token_ids)
71
+ tokens = [
72
+ token.decode('utf-8', errors='replace') for token in tokens_bytes
73
+ ]
74
+ tokenized_string = ' '.join(tokens)
75
+ token_ids_list.append(token_ids)
76
+ tokenized_string_list.append(tokenized_string)
77
+ return token_ids_list, tokenized_string_list
78
+
79
+
80
+ def eval_action(job):
81
+ response = job['gen'][0]
82
+ golden = job['response']
83
+
84
+ if 'Action:' in response:
85
+ response, golden = process_res(response), process_res(golden)
86
+ if is_callable(response, golden):
87
+ return True
88
+ return False
89
+
90
+
91
+ def eval_action_input(job, tokenizer):
92
+ response = job['gen'][0]
93
+ golden = job['response']
94
+ response, golden = process_res(response), process_res(golden)
95
+ query = job['prompt']
96
+
97
+ job = {}
98
+ job['prompt'] = query
99
+ job['gen'] = response['action_input']
100
+ job['response'] = golden['action_input']
101
+
102
+ job['_gen_tok'], job['_gen_tok_str'] = _get_tokenized_string(
103
+ tokenizer, [response['action_input']])
104
+ job['_reference_tok'], job['_reference_tok_str'] = _get_tokenized_string(
105
+ tokenizer, [golden['action_input']])
106
+
107
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
108
+ tokenizer=_DummyTokenizer())
109
+ score = scorer.score(job['_reference_tok_str'][0], job['_gen_tok_str'][0])
110
+
111
+ rouge = score['rougeL'].fmeasure
112
+
113
+ return rouge
114
+
115
+
116
+ class QWenAgent(Agent):
117
+ """
118
+ Agent that uses QWen model and tokenizer to generate code.
119
+
120
+ Example:
121
+
122
+ ```py
123
+ agent = QWenAgent()
124
+ agent.run("Draw me a picture of rivers and lakes.")
125
+ ```
126
+ """
127
+ def __init__(self,
128
+ chat_prompt_template=None,
129
+ run_prompt_template=None,
130
+ additional_tools=None,
131
+ tokenizer=None,
132
+ model=None):
133
+ if tokenizer and model:
134
+ self.tokenizer = tokenizer
135
+ self.model = model
136
+ else:
137
+ checkpoint = 'Qwen/Qwen-7B-Chat'
138
+ self.tokenizer = AutoTokenizer.from_pretrained(
139
+ checkpoint, trust_remote_code=True)
140
+ self.model = AutoModelForCausalLM.from_pretrained(
141
+ checkpoint, device_map='auto',
142
+ trust_remote_code=True).cuda().eval()
143
+ self.model.generation_config = GenerationConfig.from_pretrained(
144
+ checkpoint, trust_remote_code=True) # generation hyperparameters such as max length and top_p can be configured here
145
+ self.model.generation_config.do_sample = False # greedy
146
+
147
+ super().__init__(
148
+ chat_prompt_template=chat_prompt_template,
149
+ run_prompt_template=run_prompt_template,
150
+ additional_tools=additional_tools,
151
+ )
152
+
153
+ def generate_one(self, prompt, stop):
154
+ # "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。
155
+ prompt = prompt.replace('Human:',
156
+ '_HUMAN_:').replace('Assistant:',
157
+ '_ASSISTANT_:')
158
+ stop = [
159
+ item.replace('Human:', '_HUMAN_:').replace('Assistant:',
160
+ '_ASSISTANT_:')
161
+ for item in stop
162
+ ]
163
+
164
+ result, _ = self.model.chat(self.tokenizer, prompt, history=None)
165
+ for stop_seq in stop:
166
+ if result.endswith(stop_seq):
167
+ result = result[:-len(stop_seq)]
168
+
169
+ result = result.replace('_HUMAN_:',
170
+ 'Human:').replace('_ASSISTANT_:', 'Assistant:')
171
+ return result
172
+
173
+
174
+ def load_models_tokenizer(args):
175
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path,
176
+ trust_remote_code=True)
177
+ model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path,
178
+ device_map='auto',
179
+ trust_remote_code=True,
180
+ bf16=True,
181
+ use_flash_attn=True).eval()
182
+ model.generation_config = GenerationConfig.from_pretrained(
183
+ args.checkpoint_path, trust_remote_code=True)
184
+ model.generation_config.do_sample = False # use greedy decoding
185
+ return model, tokenizer
186
+
187
+
188
+ def load_jobs(filename):
189
+ jobs = []
190
+ with jsonlines.open(os.path.join(data_root_path, filename),
191
+ mode='r') as reader:
192
+ for job in reader:
193
+ jobs.append(job)
194
+ return jobs
195
+
196
+
197
+ def react_inference(filename, model, tokenizer):
198
+ filename_cache = filename + '.cache'
199
+ if os.path.exists(os.path.join(data_root_path, filename_cache)):
200
+ jobs = load_jobs(filename=filename_cache)
201
+ print('Loaded from', filename_cache)
202
+ else:
203
+ with open(os.path.join(data_root_path, filename_cache), 'w') as f:
204
+ jobs = load_jobs(filename=filename)
205
+ print('Inference:', filename)
206
+ for job in tqdm(jobs):
207
+ response, history = model.chat(tokenizer,
208
+ job['prompt'],
209
+ history=None)
210
+ job['gen'] = [response]
211
+ f.writelines(json.dumps(job, ensure_ascii=False) + '\n')
212
+ print(filename_cache, 'is saved.')
213
+ return jobs
214
+
215
+
216
+ def main(args):
217
+ print('loading model weights')
218
+ if args.checkpoint_path is not None:
219
+ model, tokenizer = load_models_tokenizer(args)
220
+ else:
221
+ model, tokenizer = None, None
222
+ print('model loaded')
223
+
224
+ result = {}
225
+ # eval react positive
226
+ if args.eval_react_positive:
227
+ print('eval react positive ...')
228
+ acc_count = 0
229
+ rouge_mean = 0
230
+ jobs = react_inference(filename=args.eval_react_positive_filename,
231
+ model=model,
232
+ tokenizer=tokenizer)
233
+ for job in jobs:
234
+ if eval_action(job):
235
+ acc_count += 1
236
+ rouge = eval_action_input(job, tokenizer)
237
+ rouge_mean += (rouge / len(jobs))
238
+
239
+ scores = {
240
+ 'action_right_rate': acc_count / len(jobs),
241
+ 'action_input_rouge': rouge_mean,
242
+ }
243
+
244
+ result.update({'react_positive': scores})
245
+
246
+ # eval react negative
247
+ if args.eval_react_negative:
248
+ print('eval react negative ...')
249
+ bad_count = 0
250
+ jobs = react_inference(filename=args.eval_react_negative_filename,
251
+ model=model,
252
+ tokenizer=tokenizer)
253
+ for job in jobs:
254
+ if '\nAction:' in job['gen'][0]:
255
+ bad_count += 1
256
+ scores = {'bad_rate': bad_count / len(jobs)}
257
+ result.update({'react_negative': scores})
258
+
259
+ # eval hfagent
260
+ if args.eval_hfagent:
261
+ print('eval hfagent ...')
262
+ agent = QWenAgent(model=model, tokenizer=tokenizer)
263
+ scores = evaluate_agent(agent, verbose=False, return_errors=False)
264
+ result.update({'hfagent': scores})
265
+
266
+ pp = pprint.PrettyPrinter(indent=4)
267
+ pp.pprint(result)
268
+
269
+
270
+ if __name__ == '__main__':
271
+ parser = argparse.ArgumentParser(description='Test HF checkpoint.')
272
+ parser.add_argument('-c',
273
+ '--checkpoint-path',
274
+ type=str,
275
+ help='Checkpoint path',
276
+ default='Qwen/Qwen-7B-Chat')
277
+ parser.add_argument('-s',
278
+ '--seed',
279
+ type=int,
280
+ default=1234,
281
+ help='Random seed')
282
+ """Provide extra arguments required for tasks."""
283
+ group = parser.add_argument_group(title='Evaluation options')
284
+ group.add_argument('--eval-react-positive',
285
+ action='store_true',
286
+ default=False,
287
+ help='Eval react positive.')
288
+ group.add_argument('--eval-react-positive-filename',
289
+ type=str,
290
+ default='exam_plugin_v1_react_positive.jsonl',
291
+ help='Eval react positive filename.')
292
+ group.add_argument('--eval-react-negative',
293
+ action='store_true',
294
+ default=False,
295
+ help='Eval react negative.')
296
+ group.add_argument('--eval-react-negative-filename',
297
+ type=str,
298
+ default='exam_plugin_v1_react_negative.jsonl',
299
+ help='Eval react negative filename.')
300
+ group.add_argument('--eval-hfagent',
301
+ action='store_true',
302
+ default=False,
303
+ help='Eval hfagent.')
304
+
305
+ args = parser.parse_args()
306
+ set_seed(args.seed)
307
+
308
+ main(args)
eval/gsm8k_prompt.txt ADDED
@@ -0,0 +1,59 @@
1
+ Question: In 2004, there were 60 kids at a cookout. In 2005, half the number of kids came to the cookout as compared to 2004. In 2006, 2/3 as many kids came to the cookout as in 2005. How many kids came to the cookout in 2006?
2
+ Let's think step by step
3
+ In 2005, 60/2=30 kids came to the cookout.
4
+ In 2006, 30/3*2=20 kids came to the cookout.
5
+ The answer is 20
6
+
7
+ Question: Zilla spent 7% of her monthly earnings on rent, half of it on her other monthly expenses, and put the rest in her savings. If she spent $133 on her rent, how much does she deposit into her savings account in a month?
8
+ Let's think step by step
9
+ Since $133 is equal to 7% of her earnings, then 1% is equal to $133/7 = $19.
10
+ The total monthly earning of Zilla is represented by 100%, so $19 x 100 = $1900 is her monthly earnings.
11
+ So, $1900/2 = $950 is spent on her other monthly expenses.
12
+ The total amount spent on the rent and other monthly expenses is $133 + $950 = $1083.
13
+ Hence, she saves $1900 - $1083 = $817 per month.
14
+ The answer is 817
15
+
16
+ Question: If Buzz bought a pizza with 78 slices at a restaurant and then decided to share it with the waiter in the ratio of 5:8, with Buzz's ratio being 5, what's twenty less the number of slices of pizza that the waiter ate?
17
+ Let's think step by step
18
+ The total ratio representing the slices of pizza that Buzz bought is 5+8=13
19
+ If he shared the slices of pizza with the waiter, the waiter received a fraction of 8/13 of the total number of slices, which totals 8/13 * 78 = 48 slices
20
+ Twenty less the number of slices of pizza that the waiter ate is 48-20 = 28
21
+ The answer is 28
22
+
23
+ Question: Jame gets a raise to $20 per hour and works 40 hours a week. His old job was $16 an hour for 25 hours per week. How much more money does he make per year in his new job than the old job if he works 52 weeks a year?
24
+ Let's think step by step
25
+ He makes 20*40=$800 per week
26
+ He used to make 16*25=$400 per week
27
+ So his raise was 800-400=$400 per week
28
+ So he makes 400*52=$20,800 per year more
29
+ The answer is 20800
30
+
31
+ Question: Mr. Gardner bakes 20 cookies, 25 cupcakes, and 35 brownies for his second-grade class of 20 students. If he wants to give each student an equal amount of sweet treats, how many sweet treats will each student receive?
32
+ Let's think step by step
33
+ Mr. Gardner bakes a total of 20 + 25 + 35 = 80 sweet treats
34
+ Each student will receive 80 / 20 = 4 sweet treats
35
+ The answer is 4
36
+
37
+ Question: A used car lot has 24 cars and motorcycles (in total) for sale. A third of the vehicles are motorcycles, and a quarter of the cars have a spare tire included. How many tires are on the used car lot’s vehicles in all?
38
+ Let's think step by step
39
+ The used car lot has 24 / 3 = 8 motorcycles with 2 tires each.
40
+ The lot has 24 - 8 = 16 cars for sale
41
+ There are 16 / 4 = 4 cars with a spare tire with 5 tires each.
42
+ The lot has 16 - 4 = 12 cars with 4 tires each.
43
+ Thus, the used car lot’s vehicles have 8 * 2 + 4 * 5 + 12 * 4 = 16 + 20 + 48 = 84 tires in all.
44
+ The answer is 84
45
+
46
+ Question: Norma takes her clothes to the laundry. She leaves 9 T-shirts and twice as many sweaters as T-shirts in the washer. When she returns she finds 3 sweaters and triple the number of T-shirts. How many items are missing?
47
+ Let's think step by step
48
+ Norma left 9 T-shirts And twice as many sweaters, she took 9 * 2= 18 sweaters
49
+ Adding the T-shirts and sweaters, Norma left 9 + 18 = 27 clothes
50
+ When she came back, she found 3 sweaters And triple the number of T-shirts, she found 3 * 3 = 9 T-shirts
51
+ Adding the T-shirts and sweaters, Norma found 3 + 9 = 12 clothes
52
+ Subtracting the clothes she left from the clothes she found, 27 - 12 = 15 clothes are missing
53
+ The answer is 15
54
+
55
+ Question: Adam has an orchard. Every day for 30 days he picks 4 apples from his orchard. After a month, Adam has collected all the remaining apples, which were 230. How many apples in total has Adam collected from his orchard?
56
+ Let's think step by step
57
+ During 30 days Adam picked 4 * 30 = 120 apples.
58
+ So in total with all the remaining apples, he picked 120 + 230 = 350 apples from his orchard.
59
+ The answer is 350
examples/langchain_tooluse.ipynb ADDED
@@ -0,0 +1,708 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "30e24ef3",
6
+ "metadata": {
7
+ "tags": []
8
+ },
9
+ "source": [
10
+ "# 如何让 Qwen-7b 使用 Langchain 中的 工具\n",
11
+ "\n",
12
+ "本文档主要介绍如何让千问调用 [LangChain](https://python.langchain.com/docs/get_started/introduction.html) 框架中实现好的谷歌搜索、 WolframAlpha 等工具。将主要基于 [ReAct Prompting](https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_prompt.md) 技术,一种特殊的链式思考(Chain-of-Thought,简称 CoT)提示技巧,来实现这一目的。"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "212979ec",
18
+ "metadata": {
19
+ "tags": []
20
+ },
21
+ "source": [
22
+ "## 安装依赖"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 1,
28
+ "id": "e21c6728",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# 安装千问的依赖\n",
33
+ "!cd Qwen-7b\n",
34
+ "!pip install -r requirements.txt\n",
35
+ "\n",
36
+ "# 安装 langchain 相关依赖\n",
37
+ "!pip install langchain google-search-results wolframalpha arxiv;"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "id": "3b5e6ef9",
43
+ "metadata": {
44
+ "tags": []
45
+ },
46
+ "source": [
47
+ "## 第零步 - 导入 LangChain 的工具"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "id": "af7d0058",
53
+ "metadata": {},
54
+ "source": [
55
+ "以下引入几个常用 APIs 作为演示:\n",
56
+ " - [谷歌搜索API](https://serper.dev/?gclid=EAIaIQobChMIj9eqof7OgAMV44VbCh1F3QZoEAAYASABEgIh3fD_BwE#google-search-api)\n",
57
+ " - [WolframAlpha](https://products.wolframalpha.com/api/)\n",
58
+ " - arxiv论文搜索\n",
59
+ " - python shell (需升级python至3.9以上使用)\n",
60
+ "\n",
61
+ "注1:此处推荐模仿此案例,细致地构造给千问看的工具描述。\n",
62
+ "\n",
63
+ "注2:谷歌搜索(SERPAPI), WolframAlpha 需自行申请它们的 API_KEY 后才能使用。"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 2,
69
+ "id": "07e49b98-9d6c-41f2-9b18-f043f2d13e1a",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "from langchain import SerpAPIWrapper\n",
74
+ "from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper\n",
75
+ "from langchain.utilities import ArxivAPIWrapper\n",
76
+ "from langchain.tools.python.tool import PythonAstREPLTool\n",
77
+ "\n",
78
+ "from typing import Dict, Tuple\n",
79
+ "import os\n",
80
+ "import json\n",
81
+ "\n",
82
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
83
+ "from transformers.generation import GenerationConfig\n",
84
+ "\n",
85
+ "# 为了使用谷歌搜索(SERPAPI), WolframAlpha,您需要自行申请它们的 API KEY,然后填入此处\n",
86
+ "os.environ['SERPAPI_API_KEY'] = '重要!请在这里填入您的 SERPAPI_API_KEY!'\n",
87
+ "os.environ['WOLFRAM_ALPHA_APPID'] = '重要!请在这里填入您的 WOLFRAM_ALPHA_APPID!'\n",
88
+ "\n",
89
+ "search = SerpAPIWrapper()\n",
90
+ "WolframAlpha = WolframAlphaAPIWrapper()\n",
91
+ "arxiv = ArxivAPIWrapper()\n",
92
+ "python=PythonAstREPLTool()\n",
93
+ "\n",
94
+ "def tool_wrapper_for_qwen(tool):\n",
95
+ " def tool_(query):\n",
96
+ " query = json.loads(query)[\"query\"]\n",
97
+ " return tool.run(query)\n",
98
+ " return tool_\n",
99
+ "\n",
100
+ "# 以下是给千问看的工具描述:\n",
101
+ "TOOLS = [\n",
102
+ " {\n",
103
+ " 'name_for_human':\n",
104
+ " 'google search',\n",
105
+ " 'name_for_model':\n",
106
+ " 'Search',\n",
107
+ " 'description_for_model':\n",
108
+ " 'useful for when you need to answer questions about current events.',\n",
109
+ " 'parameters': [{\n",
110
+ " \"name\": \"query\",\n",
111
+ " \"type\": \"string\",\n",
112
+ " \"description\": \"search query of google\",\n",
113
+ " 'required': True\n",
114
+ " }], \n",
115
+ " 'tool_api': tool_wrapper_for_qwen(search)\n",
116
+ " },\n",
117
+ " {\n",
118
+ " 'name_for_human':\n",
119
+ " 'Wolfram Alpha',\n",
120
+ " 'name_for_model':\n",
121
+ " 'Math',\n",
122
+ " 'description_for_model':\n",
123
+ " 'Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life.',\n",
124
+ " 'parameters': [{\n",
125
+ " \"name\": \"query\",\n",
126
+ " \"type\": \"string\",\n",
127
+ " \"description\": \"the problem to solved by Wolfram Alpha\",\n",
128
+ " 'required': True\n",
129
+ " }], \n",
130
+ " 'tool_api': tool_wrapper_for_qwen(WolframAlpha)\n",
131
+ " }, \n",
132
+ " {\n",
133
+ " 'name_for_human':\n",
134
+ " 'arxiv',\n",
135
+ " 'name_for_model':\n",
136
+ " 'Arxiv',\n",
137
+ " 'description_for_model':\n",
138
+ " 'A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org.',\n",
139
+ " 'parameters': [{\n",
140
+ " \"name\": \"query\",\n",
141
+ " \"type\": \"string\",\n",
142
+ " \"description\": \"the document id of arxiv to search\",\n",
143
+ " 'required': True\n",
144
+ " }], \n",
145
+ " 'tool_api': tool_wrapper_for_qwen(arxiv)\n",
146
+ " },\n",
147
+ " {\n",
148
+ " 'name_for_human':\n",
149
+ " 'python',\n",
150
+ " 'name_for_model':\n",
151
+ " 'python',\n",
152
+ " 'description_for_model':\n",
153
+ " \"A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. \"\n",
154
+ " \"Don't add comments to your python code.\",\n",
155
+ " 'parameters': [{\n",
156
+ " \"name\": \"query\",\n",
157
+ " \"type\": \"string\",\n",
158
+ " \"description\": \"a valid python command.\",\n",
159
+ " 'required': True\n",
160
+ " }],\n",
161
+ " 'tool_api': tool_wrapper_for_qwen(python)\n",
162
+ " }\n",
163
+ "\n",
164
+ "]\n"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "id": "b7ec2027",
170
+ "metadata": {},
171
+ "source": [
172
+ "## 第一步:让千问判断调用什么工具,生成工具入参"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "id": "7a50d676",
178
+ "metadata": {},
179
+ "source": [
180
+ "根据prompt模版、query、工具的信息构建prompt"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 3,
186
+ "id": "4a8feb0e-22f7-4184-9ea0-b864812c9b09",
187
+ "metadata": {
188
+ "scrolled": true
189
+ },
190
+ "outputs": [
191
+ {
192
+ "name": "stdout",
193
+ "output_type": "stream",
194
+ "text": [
195
+ "Answer the following questions as best you can. You have access to the following tools:\n",
196
+ "\n",
197
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
198
+ "\n",
199
+ "Use the following format:\n",
200
+ "\n",
201
+ "Question: the input question you must answer\n",
202
+ "Thought: you should always think about what to do\n",
203
+ "Action: the action to take, should be one of [Search]\n",
204
+ "Action Input: the input to the action\n",
205
+ "Observation: the result of the action\n",
206
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
207
+ "Thought: I now know the final answer\n",
208
+ "Final Answer: the final answer to the original input question\n",
209
+ "\n",
210
+ "Begin!\n",
211
+ "\n",
212
+ "Question: 加拿大2023年人口统计数字是多少?\n"
213
+ ]
214
+ }
215
+ ],
216
+ "source": [
217
+ "TOOL_DESC = \"\"\"{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object.\"\"\"\n",
218
+ "\n",
219
+ "REACT_PROMPT = \"\"\"Answer the following questions as best you can. You have access to the following tools:\n",
220
+ "\n",
221
+ "{tool_descs}\n",
222
+ "\n",
223
+ "Use the following format:\n",
224
+ "\n",
225
+ "Question: the input question you must answer\n",
226
+ "Thought: you should always think about what to do\n",
227
+ "Action: the action to take, should be one of [{tool_names}]\n",
228
+ "Action Input: the input to the action\n",
229
+ "Observation: the result of the action\n",
230
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
231
+ "Thought: I now know the final answer\n",
232
+ "Final Answer: the final answer to the original input question\n",
233
+ "\n",
234
+ "Begin!\n",
235
+ "\n",
236
+ "Question: {query}\"\"\"\n",
237
+ "\n",
238
+ "def build_planning_prompt(TOOLS, query):\n",
239
+ " tool_descs = []\n",
240
+ " tool_names = []\n",
241
+ " for info in TOOLS:\n",
242
+ " tool_descs.append(\n",
243
+ " TOOL_DESC.format(\n",
244
+ " name_for_model=info['name_for_model'],\n",
245
+ " name_for_human=info['name_for_human'],\n",
246
+ " description_for_model=info['description_for_model'],\n",
247
+ " parameters=json.dumps(\n",
248
+ " info['parameters'], ensure_ascii=False),\n",
249
+ " )\n",
250
+ " )\n",
251
+ " tool_names.append(info['name_for_model'])\n",
252
+ " tool_descs = '\\n\\n'.join(tool_descs)\n",
253
+ " tool_names = ','.join(tool_names)\n",
254
+ "\n",
255
+ " prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query)\n",
256
+ " return prompt\n",
257
+ "\n",
258
+ "prompt_1 = build_planning_prompt(TOOLS[0:1], query=\"加拿大2023年人口统计数字是多少���\")\n",
259
+ "print(prompt_1)"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "id": "6f22b002",
265
+ "metadata": {},
266
+ "source": [
267
+ "将prompt作为输入获得response"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 4,
273
+ "id": "f71b2577-118c-4ce2-a0ed-a45ec59ea35b",
274
+ "metadata": {},
275
+ "outputs": [
276
+ {
277
+ "name": "stderr",
278
+ "output_type": "stream",
279
+ "text": [
280
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
281
+ "- tokenization_qwen.py\n",
282
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
283
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
284
+ "- configuration_qwen.py\n",
285
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
286
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
287
+ "- qwen_generation_utils.py\n",
288
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
289
+ "A new version of the following files was downloaded from https://huggingface.co/Qwen/Qwen-7B-Chat:\n",
290
+ "- modeling_qwen.py\n",
291
+ "- qwen_generation_utils.py\n",
292
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
293
+ ]
294
+ },
295
+ {
296
+ "data": {
297
+ "application/vnd.jupyter.widget-view+json": {
298
+ "model_id": "23435445dded44d6951aa6a7b771a963",
299
+ "version_major": 2,
300
+ "version_minor": 0
301
+ },
302
+ "text/plain": [
303
+ "Downloading shards: 0%| | 0/8 [00:00<?, ?it/s]"
304
+ ]
305
+ },
306
+ "metadata": {},
307
+ "output_type": "display_data"
308
+ },
309
+ {
310
+ "name": "stderr",
311
+ "output_type": "stream",
312
+ "text": [
313
+ "The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".\n",
314
+ "Try importing flash-attention for faster inference...\n",
315
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n",
316
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n",
317
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n"
318
+ ]
319
+ },
320
+ {
321
+ "data": {
322
+ "application/vnd.jupyter.widget-view+json": {
323
+ "model_id": "728a1c13c2884291ade4cb4a1edfaaf2",
324
+ "version_major": 2,
325
+ "version_minor": 0
326
+ },
327
+ "text/plain": [
328
+ "Loading checkpoint shards: 0%| | 0/8 [00:00<?, ?it/s]"
329
+ ]
330
+ },
331
+ "metadata": {},
332
+ "output_type": "display_data"
333
+ }
334
+ ],
335
+ "source": [
336
+ "# 国内连 hugginface 网络不好,这段代码可能需要多重试\n",
337
+ "checkpoint = \"Qwen/Qwen-7B-Chat\"\n",
338
+ "TOKENIZER = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)\n",
339
+ "MODEL = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=\"auto\", trust_remote_code=True).eval()\n",
340
+ "MODEL.generation_config = GenerationConfig.from_pretrained(checkpoint, trust_remote_code=True)\n",
341
+ "MODEL.generation_config.do_sample = False # greedy"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": 5,
347
+ "id": "dc0dbd6c-5a0f-44c9-a019-0ec0283ca92d",
348
+ "metadata": {},
349
+ "outputs": [
350
+ {
351
+ "name": "stdout",
352
+ "output_type": "stream",
353
+ "text": [
354
+ "Thought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
355
+ "Action: Search\n",
356
+ "Action Input: {\"query\": \"加拿大 2023年人口统计数字\"}\n",
357
+ "Observation:\n"
358
+ ]
359
+ }
360
+ ],
361
+ "source": [
362
+ "stop = [\"Observation:\", \"Observation:\\n\"]\n",
363
+ "react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
364
+ "response_1, _ = MODEL.chat(TOKENIZER, prompt_1, history=None, stop_words_ids=react_stop_words_tokens)\n",
365
+ "print(response_1)"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "markdown",
370
+ "id": "1ebf47ac",
371
+ "metadata": {},
372
+ "source": [
373
+ "## 第二步:从千问的输出中解析需要使用的工具和入参,并调用对应工具"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": 6,
379
+ "id": "1a431670-a1f6-4afd-972f-1cfd6d06e8c9",
380
+ "metadata": {},
381
+ "outputs": [
382
+ {
383
+ "name": "stdout",
384
+ "output_type": "stream",
385
+ "text": [
386
+ "根据加拿大统计局��测,加拿大人口今天(2023年6月16日)预计将超过4000万。 联邦统计局使用模型来实时估计加拿大的人口,该计数模型预计加拿大人口将在北美东部时间今天下午3点前达到4000万。 加拿大的人口增长率目前为2.7%。\n"
387
+ ]
388
+ }
389
+ ],
390
+ "source": [
391
+ "def parse_latest_plugin_call(text: str) -> Tuple[str, str]:\n",
392
+ " i = text.rfind('\\nAction:')\n",
393
+ " j = text.rfind('\\nAction Input:')\n",
394
+ " k = text.rfind('\\nObservation:')\n",
395
+ " if 0 <= i < j: # If the text has `Action` and `Action input`,\n",
396
+ " if k < j: # but does not contain `Observation`,\n",
397
+ " # then it is likely that `Observation` is ommited by the LLM,\n",
398
+ " # because the output text may have discarded the stop word.\n",
399
+ " text = text.rstrip() + '\\nObservation:' # Add it back.\n",
400
+ " k = text.rfind('\\nObservation:')\n",
401
+ " if 0 <= i < j < k:\n",
402
+ " plugin_name = text[i + len('\\nAction:'):j].strip()\n",
403
+ " plugin_args = text[j + len('\\nAction Input:'):k].strip()\n",
404
+ " return plugin_name, plugin_args\n",
405
+ " return '', ''\n",
406
+ "\n",
407
+ "def use_api(tools, response):\n",
408
+ " use_toolname, action_input = parse_latest_plugin_call(response)\n",
409
+ " if use_toolname == \"\":\n",
410
+ " return \"no tool founds\"\n",
411
+ "\n",
412
+ " used_tool_meta = list(filter(lambda x: x[\"name_for_model\"] == use_toolname, tools))\n",
413
+ " if len(used_tool_meta) == 0:\n",
414
+ " return \"no tool founds\"\n",
415
+ " \n",
416
+ " api_output = used_tool_meta[0][\"tool_api\"](action_input)\n",
417
+ " return api_output\n",
418
+ "\n",
419
+ "api_output = use_api(TOOLS, response_1)\n",
420
+ "print(api_output)"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "id": "106a4ba0",
426
+ "metadata": {
427
+ "tags": []
428
+ },
429
+ "source": [
430
+ "## 第三步:让千问根据工具返回结果继续作答\n",
431
+ "拼接上述返回答案,形成新的prompt,并获得生成最终结果"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": 7,
437
+ "id": "a9d4d42d",
438
+ "metadata": {},
439
+ "outputs": [
440
+ {
441
+ "name": "stdout",
442
+ "output_type": "stream",
443
+ "text": [
444
+ "Answer the following questions as best you can. You have access to the following tools:\n",
445
+ "\n",
446
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
447
+ "\n",
448
+ "Use the following format:\n",
449
+ "\n",
450
+ "Question: the input question you must answer\n",
451
+ "Thought: you should always think about what to do\n",
452
+ "Action: the action to take, should be one of [Search]\n",
453
+ "Action Input: the input to the action\n",
454
+ "Observation: the result of the action\n",
455
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
456
+ "Thought: I now know the final answer\n",
457
+ "Final Answer: the final answer to the original input question\n",
458
+ "\n",
459
+ "Begin!\n",
460
+ "\n",
461
+ "Question: 加拿大2023年人口统计数字是多少?Thought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
462
+ "Action: Search\n",
463
+ "Action Input: {\"query\": \"加拿大 2023年人口统计数字\"}\n",
464
+ "Observation: 根据加拿大统计局预测,加拿大人口今天(2023年6月16日)预计将超过4000万。 联邦统计局使用模型来实时估计加拿大的人口,该计数模型预计加拿大人口将在北美东部时间今天下午3点前达到4000万。 加拿大的人口增长率目前为2.7%。 Thought: I now know the final answer.\n",
465
+ "Final Answer: 加拿大2023年人口统计数字预计为4000万。\n"
466
+ ]
467
+ }
468
+ ],
469
+ "source": [
470
+ "prompt_2 = prompt_1 + response_1 + ' ' + api_output\n",
471
+ "stop = [\"Observation:\", \"Observation:\\n\"]\n",
472
+ "react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
473
+ "response_2, _ = MODEL.chat(TOKENIZER, prompt_2, history=None, stop_words_ids=react_stop_words_tokens)\n",
474
+ "print(prompt_2, response_2)"
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "markdown",
479
+ "id": "0b8da9fd",
480
+ "metadata": {},
481
+ "source": [
482
+ "## 总结 - 串联起整个流程"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "code",
487
+ "execution_count": 8,
488
+ "id": "1e51a8ea",
489
+ "metadata": {},
490
+ "outputs": [],
491
+ "source": [
492
+ "def main(query, choose_tools):\n",
493
+ " prompt = build_planning_prompt(choose_tools, query) # 组织prompt\n",
494
+ " print(prompt)\n",
495
+ " stop = [\"Observation:\", \"Observation:\\n\"]\n",
496
+ " react_stop_words_tokens = [TOKENIZER.encode(stop_) for stop_ in stop]\n",
497
+ " response, _ = MODEL.chat(TOKENIZER, prompt, history=None, stop_words_ids=react_stop_words_tokens)\n",
498
+ "\n",
499
+ " while \"Final Answer:\" not in response: # 出现final Answer时结束\n",
500
+ " api_output = use_api(choose_tools, response) # 抽取入参并执行api\n",
501
+ " api_output = str(api_output) # 部分api工具返回结果非字符串格式需进行转化后输出\n",
502
+ " if \"no tool founds\" == api_output:\n",
503
+ " break\n",
504
+ " print(\"\\033[32m\" + response + \"\\033[0m\" + \"\\033[34m\" + ' ' + api_output + \"\\033[0m\")\n",
505
+ " prompt = prompt + response + ' ' + api_output # 合并api输出\n",
506
+ " response, _ = MODEL.chat(TOKENIZER, prompt, history=None, stop_words_ids=react_stop_words_tokens) # 继续生成\n",
507
+ "\n",
508
+ " print(\"\\033[32m\" + response + \"\\033[0m\")"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 9,
514
+ "id": "6dc38a34",
515
+ "metadata": {
516
+ "collapsed": false,
517
+ "jupyter": {
518
+ "outputs_hidden": false
519
+ }
520
+ },
521
+ "outputs": [
522
+ {
523
+ "name": "stdout",
524
+ "output_type": "stream",
525
+ "text": [
526
+ "==========\n",
527
+ "Answer the following questions as best you can. You have access to the following tools:\n",
528
+ "\n",
529
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
530
+ "\n",
531
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
532
+ "\n",
533
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
534
+ "\n",
535
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
536
+ "\n",
537
+ "Use the following format:\n",
538
+ "\n",
539
+ "Question: the input question you must answer\n",
540
+ "Thought: you should always think about what to do\n",
541
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
542
+ "Action Input: the input to the action\n",
543
+ "Observation: the result of the action\n",
544
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
545
+ "Thought: I now know the final answer\n",
546
+ "Final Answer: the final answer to the original input question\n",
547
+ "\n",
548
+ "Begin!\n",
549
+ "\n",
550
+ "Question: 加拿大2022年的人口数量有多少?\n",
551
+ "\u001B[32mThought: 我应该使用搜索工具帮助我完成任务。search api能完成搜索的任务。\n",
552
+ "Action: Search\n",
553
+ "Action Input: {\"query\": \"加拿大 2022年人口数量\"}\n",
554
+ "Observation:\u001B[0m\u001B[34m 中新社多伦多3月22日电(记者余瑞冬)加拿大统计局3月22日公布的人口统计数据显示,截至今年1月1日,该国估算总人口约为3956.62万人,且2022年的人口增长数创纪录地突破100万人。 加统计局估算,该国人口在2022年增长105.011万人,年增长2.7%,创1957年以来最大增幅。\u001B[0m\n",
555
+ "\u001B[32mThought: I now know the final answer.\n",
556
+ "Final Answer: 加拿大2022年的人口数量约为3956.62万人。\u001B[0m\n",
557
+ "==========\n",
558
+ "Answer the following questions as best you can. You have access to the following tools:\n",
559
+ "\n",
560
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
561
+ "\n",
562
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
563
+ "\n",
564
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
565
+ "\n",
566
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
567
+ "\n",
568
+ "Use the following format:\n",
569
+ "\n",
570
+ "Question: the input question you must answer\n",
571
+ "Thought: you should always think about what to do\n",
572
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
573
+ "Action Input: the input to the action\n",
574
+ "Observation: the result of the action\n",
575
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
576
+ "Thought: I now know the final answer\n",
577
+ "Final Answer: the final answer to the original input question\n",
578
+ "\n",
579
+ "Begin!\n",
580
+ "\n",
581
+ "Question: 求解方程 2x+5 = -3x + 7\n",
582
+ "\u001B[32mThought: 我应该使用数学工具帮助我完成任务。Wolfram Alpha API应该能完成这项任务。\n",
583
+ "Action: Math\n",
584
+ "Action Input: {\"query\": \"2x+5 = -3x + 7\"}\n",
585
+ "Observation:\u001B[0m\u001B[34m Assumption: 2 x + 5 = -3 x + 7 \n",
586
+ "Answer: x = 2/5\u001B[0m\n",
587
+ "\u001B[32mThought: I now know the final answer.\n",
588
+ "Final Answer: x = 2/5\u001B[0m\n",
589
+ "==========\n",
590
+ "Answer the following questions as best you can. You have access to the following tools:\n",
591
+ "\n",
592
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
593
+ "\n",
594
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
595
+ "\n",
596
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
597
+ "\n",
598
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
599
+ "\n",
600
+ "Use the following format:\n",
601
+ "\n",
602
+ "Question: the input question you must answer\n",
603
+ "Thought: you should always think about what to do\n",
604
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
605
+ "Action Input: the input to the action\n",
606
+ "Observation: the result of the action\n",
607
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
608
+ "Thought: I now know the final answer\n",
609
+ "Final Answer: the final answer to the original input question\n",
610
+ "\n",
611
+ "Begin!\n",
612
+ "\n",
613
+ "Question: 编号是1605.08386的论���讲了些什么?\n",
614
+ "\u001B[32mThought: 我需要使用Arxiv API来搜索这篇论文。\n",
615
+ "Action: Arxiv\n",
616
+ "Action Input: {\"query\": \"1605.08386\"}\n",
617
+ "Observation:\u001B[0m\u001B[34m Published: 2016-05-26\n",
618
+ "Title: Heat-bath random walks with Markov bases\n",
619
+ "Authors: Caprice Stanley, Tobias Windisch\n",
620
+ "Summary: Graphs on lattice points are studied whose edges come from a finite set of\n",
621
+ "allowed moves of arbitrary length. We show that the diameter of these graphs on\n",
622
+ "fibers of a fixed integer matrix can be bounded from above by a constant. We\n",
623
+ "then study the mixing behaviour of heat-bath random walks on these graphs. We\n",
624
+ "also state explicit conditions on the set of moves so that the heat-bath random\n",
625
+ "walk, a generalization of the Glauber dynamics, is an expander in fixed\n",
626
+ "dimension.\u001B[0m\n",
627
+ "\u001B[32mThought: I now know the final answer.\n",
628
+ "Final Answer: 这篇论文的题目是《热浴随机游走的马尔可夫基》,作者是Caprice Stanley和Tobias Windisch。摘要中提到,该论文研究了在有限的允许移动集合中,由任意长度的边构成的图的边。我们证明了这些图在固定整数矩阵纤维上的直径可以被一个常数所限制。然后,我们研究了热浴随机游走在这类图上的混合行为。我们还给出了一个明确的条件,使得热浴随机游走(一个Glauber动力学的推广)在固定维度下是一个扩张。\u001B[0m\n",
629
+ "==========\n",
630
+ "Answer the following questions as best you can. You have access to the following tools:\n",
631
+ "\n",
632
+ "Search: Call this tool to interact with the google search API. What is the google search API useful for? useful for when you need to answer questions about current events. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"search query of google\", \"required\": true}] Format the arguments as a JSON object.\n",
633
+ "\n",
634
+ "Math: Call this tool to interact with the Wolfram Alpha API. What is the Wolfram Alpha API useful for? Useful for when you need to answer questions about Math, Science, Technology, Culture, Society and Everyday Life. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the problem to solved by Wolfram Alpha\", \"required\": true}] Format the arguments as a JSON object.\n",
635
+ "\n",
636
+ "Arxiv: Call this tool to interact with the arxiv API. What is the arxiv API useful for? A wrapper around Arxiv.org Useful for when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics from scientific articles on arxiv.org. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"the document id of arxiv to search\", \"required\": true}] Format the arguments as a JSON object.\n",
637
+ "\n",
638
+ "python: Call this tool to interact with the python API. What is the python API useful for? A Python shell. Use this to execute python commands. When using this tool, sometimes output is abbreviated - Make sure it does not look abbreviated before using it in your answer. Don't add comments to your python code. Parameters: [{\"name\": \"query\", \"type\": \"string\", \"description\": \"a valid python command.\", \"required\": true}] Format the arguments as a JSON object.\n",
639
+ "\n",
640
+ "Use the following format:\n",
641
+ "\n",
642
+ "Question: the input question you must answer\n",
643
+ "Thought: you should always think about what to do\n",
644
+ "Action: the action to take, should be one of [Search,Math,Arxiv,python]\n",
645
+ "Action Input: the input to the action\n",
646
+ "Observation: the result of the action\n",
647
+ "... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n",
648
+ "Thought: I now know the final answer\n",
649
+ "Final Answer: the final answer to the original input question\n",
650
+ "\n",
651
+ "Begin!\n",
652
+ "\n",
653
+ "Question: 使用python对下面的列表进行排序: [2, 4135, 523, 2, 3]\n",
654
+ "\u001B[32mThought: 我应该使用python API来执行python命令。\n",
655
+ "Action: python\n",
656
+ "Action Input: {\"query\": \"sorted([2, 4135, 523, 2, 3])\"}\n",
657
+ "Observation:\u001B[0m\u001B[34m [2, 2, 3, 523, 4135]\u001B[0m\n",
658
+ "\u001B[32mThought: I now know the final answer.\n",
659
+ "Final Answer: 使用python对给定的列表进行排序,结果为 [2, 2, 3, 523, 4135]。\u001B[0m\n"
660
+ ]
661
+ }
662
+ ],
663
+ "source": [
664
+ "# 请尽可能控制备选工具数量\n",
665
+ "query = \"加拿大2022年的人口数量有多少?\" # 所提问题\n",
666
+ "choose_tools = TOOLS # 选择备选工具\n",
667
+ "print(\"=\" * 10)\n",
668
+ "main(query, choose_tools)\n",
669
+ "\n",
670
+ "query = \"求解方程 2x+5 = -3x + 7\" # 所提问题\n",
671
+ "choose_tools = TOOLS # 选择备选工具\n",
672
+ "print(\"=\" * 10)\n",
673
+ "main(query, choose_tools)\n",
674
+ "\n",
675
+ "query = \"编号是1605.08386的论文讲了些什么?\" # 所提问题\n",
676
+ "choose_tools = TOOLS # 选择备选工具\n",
677
+ "print(\"=\" * 10)\n",
678
+ "main(query, choose_tools)\n",
679
+ "\n",
680
+ "query =\"使用python对下面的列表进行排序: [2, 4135, 523, 2, 3]\"\n",
681
+ "choose_tools = TOOLS # 选择备选工具\n",
682
+ "print(\"=\" * 10)\n",
683
+ "main(query, choose_tools)"
684
+ ]
685
+ }
686
+ ],
687
+ "metadata": {
688
+ "kernelspec": {
689
+ "display_name": "Python 3 (ipykernel)",
690
+ "language": "python",
691
+ "name": "python3"
692
+ },
693
+ "language_info": {
694
+ "codemirror_mode": {
695
+ "name": "ipython",
696
+ "version": 3
697
+ },
698
+ "file_extension": ".py",
699
+ "mimetype": "text/x-python",
700
+ "name": "python",
701
+ "nbconvert_exporter": "python",
702
+ "pygments_lexer": "ipython3",
703
+ "version": "3.9.16"
704
+ }
705
+ },
706
+ "nbformat": 4,
707
+ "nbformat_minor": 5
708
+ }
examples/react_demo.py ADDED
@@ -0,0 +1,288 @@
1
+ #
2
+ # Related materials:
3
+ # A brief introduction to the idea behind ReAct Prompting, without code:
4
+ # https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_prompt.md
5
+ # A ReAct Prompting implementation based on the model.chat interface (chat mode), including LangChain tool integration:
6
+ # https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb
7
+ # A ReAct Prompting implementation based on the model.generate interface (text-completion mode), somewhat more involved than the chat-mode version:
8
+ # https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_demo.py (this file)
9
+ #
10
+
11
+ import json
12
+ import os
13
+
14
+ import json5
15
+ import torch
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+ from transformers.generation import GenerationConfig
18
+
19
+ for _ in range(10): # the network can be unstable, so retry a few times
20
+ try:
21
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
22
+ generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
25
+ ).eval()
26
+ model.generation_config = generation_config
27
+ model.generation_config.do_sample = False
28
+ break
29
+ except Exception:
30
+ pass
31
+
32
+ # Template that renders the key information of a single plugin into a text snippet.
33
+ TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
34
+
35
+ # Instruction template for ReAct prompting; the detailed plugin information is filled into it.
36
+ PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:
37
+
38
+ {tools_text}
39
+
40
+ Use the following format:
41
+
42
+ Question: the input question you must answer
43
+ Thought: you should always think about what to do
44
+ Action: the action to take, should be one of [{tools_name_text}]
45
+ Action Input: the input to the action
46
+ Observation: the result of the action
47
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
48
+ Thought: I now know the final answer
49
+ Final Answer: the final answer to the original input question
50
+
51
+ Begin!
52
+
53
+ Question: {query}"""
54
+
55
+
56
+ #
57
+ # Entry function of this example.
59
+ #
60
+ # Inputs:
61
+ # prompt: the user's latest question.
62
+ # history: the conversation history between the user and the model, given as a list
63
+ # in which each element is one round of dialogue of the form {"user": "user input", "bot": "model output"}.
64
+ # The most recent round is placed at the end of the list; the latest question is not included.
65
+ # list_of_plugin_info: the list of candidate plugins; each element holds the key information of one plugin.
66
+ # For example, list_of_plugin_info = [plugin_info_0, plugin_info_1, plugin_info_2],
67
+ # where samples of plugin_info_0, plugin_info_1 and plugin_info_2 are shown earlier in this document.
68
+ #
69
+ # Output:
70
+ # The model's answer to the user's latest question.
70
+ #
71
+ def llm_with_plugin(prompt: str, history, list_of_plugin_info=()):
72
+ chat_history = [(x['user'], x['bot']) for x in history] + [(prompt, '')]
73
+
74
+ # The initial text that the model will be asked to continue
75
+ planning_prompt = build_input_text(chat_history, list_of_plugin_info)
76
+
77
+ text = ''
78
+ while True:
79
+ output = text_completion(planning_prompt + text, stop_words=['Observation:', 'Observation:\n'])
80
+ action, action_input, output = parse_latest_plugin_call(output)
81
+ if action: # a plugin call is needed
82
+ # action and action_input are the plugin name to call and its input arguments
83
+ # observation is the result returned by the plugin, as a string
84
+ observation = call_plugin(action, action_input)
85
+ output += f'\nObservation: {observation}\nThought:'
86
+ text += output
87
+ else: # generation has finished and no further plugin call is needed
88
+ text += output
89
+ break
90
+
91
+ new_history = []
92
+ new_history.extend(history)
93
+ new_history.append({'user': prompt, 'bot': text})
94
+ return text, new_history
95
+
96
+
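+ # A minimal usage sketch of llm_with_plugin, kept as a comment for illustration only.
+ # It assumes a `tools` list like the one defined in test() below and any required API keys:
+ #
+ # history = []
+ # answer, history = llm_with_plugin(
+ # prompt='What is the weather like in Beijing today?',
+ # history=history,
+ # list_of_plugin_info=tools,
+ # )
+ # `answer` contains the generated Thought/Action/Action Input/Observation/Final Answer trace,
+ # and the returned `history` can be passed back in for the next round of the conversation.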
97
+ # Assemble the chat history and plugin information into the initial prompt text
98
+ def build_input_text(chat_history, list_of_plugin_info) -> str:
99
+ # Detailed descriptions of the candidate plugins
100
+ tools_text = []
101
+ for plugin_info in list_of_plugin_info:
102
+ tool = TOOL_DESC.format(
103
+ name_for_model=plugin_info["name_for_model"],
104
+ name_for_human=plugin_info["name_for_human"],
105
+ description_for_model=plugin_info["description_for_model"],
106
+ parameters=json.dumps(plugin_info["parameters"], ensure_ascii=False),
107
+ )
108
+ if plugin_info.get('args_format', 'json') == 'json':
109
+ tool += " Format the arguments as a JSON object."
110
+ elif plugin_info['args_format'] == 'code':
111
+ tool += ' Enclose the code within triple backticks (`) at the beginning and end of the code.'
112
+ else:
113
+ raise NotImplementedError
114
+ tools_text.append(tool)
115
+ tools_text = '\n\n'.join(tools_text)
116
+
117
+ # Names of the candidate plugins
118
+ tools_name_text = ', '.join([plugin_info["name_for_model"] for plugin_info in list_of_plugin_info])
119
+
120
+ im_start = '<|im_start|>'
121
+ im_end = '<|im_end|>'
122
+ prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
123
+ for i, (query, response) in enumerate(chat_history):
124
+ if list_of_plugin_info: # if there are candidate plugins
125
+ # Fill the detailed plugin information into the last or the second-to-last round of dialogue; exactly where to put it is up to you
126
+ if (len(chat_history) == 1) or (i == len(chat_history) - 2):
127
+ query = PROMPT_REACT.format(
128
+ tools_text=tools_text,
129
+ tools_name_text=tools_name_text,
130
+ query=query,
131
+ )
132
+ query = query.lstrip('\n').rstrip() # Important! Without stripping, the input would differ from how the training data was constructed.
133
+ response = response.lstrip('\n').rstrip() # Important! Without stripping, the input would differ from how the training data was constructed.
134
+ # In text-completion mode, the user and the AI must be distinguished with the following format:
135
+ prompt += f"\n{im_start}user\n{query}{im_end}"
136
+ prompt += f"\n{im_start}assistant\n{response}{im_end}"
137
+
138
+ assert prompt.endswith(f"\n{im_start}assistant\n{im_end}")
139
+ prompt = prompt[: -len(f'{im_end}')]
140
+ return prompt
141
+
142
+
143
+ def text_completion(input_text: str, stop_words) -> str: # use the model as a text-completion model
144
+ im_end = '<|im_end|>'
145
+ if im_end not in stop_words:
146
+ stop_words = stop_words + [im_end]
147
+ stop_words_ids = [tokenizer.encode(w) for w in stop_words]
148
+
149
+ # TODO: add an example implementation of streaming output
150
+ input_ids = torch.tensor([tokenizer.encode(input_text)]).to(model.device)
151
+ output = model.generate(input_ids, stop_words_ids=stop_words_ids)
152
+ output = output.tolist()[0]
153
+ output = tokenizer.decode(output, errors="ignore")
154
+ assert output.startswith(input_text)
155
+ output = output[len(input_text) :].replace('<|endoftext|>', '').replace(im_end, '')
156
+
157
+ for stop_str in stop_words:
158
+ idx = output.find(stop_str)
159
+ if idx != -1:
160
+ output = output[: idx + len(stop_str)]
161
+ return output # the continuation of input_text, not including input_text itself
162
+
163
+
164
+ def parse_latest_plugin_call(text):
165
+ plugin_name, plugin_args = '', ''
166
+ i = text.rfind('\nAction:')
167
+ j = text.rfind('\nAction Input:')
168
+ k = text.rfind('\nObservation:')
169
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
170
+ if k < j: # but does not contain `Observation`,
171
+ # then it is likely that `Observation` is omitted by the LLM,
172
+ # because the output text may have discarded the stop word.
173
+ text = text.rstrip() + '\nObservation:' # Add it back.
174
+ k = text.rfind('\nObservation:')
175
+ plugin_name = text[i + len('\nAction:') : j].strip()
176
+ plugin_args = text[j + len('\nAction Input:') : k].strip()
177
+ text = text[:k]
178
+ return plugin_name, plugin_args, text
179
+
180
+
181
+ #
182
+ # Inputs:
183
+ # plugin_name: the identifier of the plugin to call, corresponding to name_for_model.
184
+ # plugin_args: the plugin's input arguments, a string holding a JSON object whose keys and values are the parameter names and values.
185
+ # Output:
186
+ # The result returned by the plugin, which must be a string.
187
+ # Even if the raw result is JSON, please serialize it into a string with json.dumps(..., ensure_ascii=False).
188
+ #
189
+ def call_plugin(plugin_name: str, plugin_args: str) -> str:
190
+ #
191
+ # Developers should flesh out this part themselves. The reference implementation here is for demo purposes only, not for production.
192
+ #
193
+ if plugin_name == 'google_search':
194
+ # To use SerpAPI, you need to fill in your SERPAPI_API_KEY here!
195
+ os.environ["SERPAPI_API_KEY"] = os.getenv("SERPAPI_API_KEY", default='')
196
+ from langchain import SerpAPIWrapper
197
+
198
+ return SerpAPIWrapper().run(json5.loads(plugin_args)['search_query'])
199
+ elif plugin_name == 'image_gen':
200
+ import urllib.parse
201
+
202
+ prompt = json5.loads(plugin_args)["prompt"]
203
+ prompt = urllib.parse.quote(prompt)
204
+ return json.dumps({'image_url': f'https://image.pollinations.ai/prompt/{prompt}'}, ensure_ascii=False)
205
+ else:
206
+ raise NotImplementedError
207
+
208
+
209
+ def test():
210
+ tools = [
211
+ {
212
+ 'name_for_human': '谷歌搜索',
213
+ 'name_for_model': 'google_search',
214
+ 'description_for_model': '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。',
215
+ 'parameters': [
216
+ {
217
+ 'name': 'search_query',
218
+ 'description': '搜索关键词或短语',
219
+ 'required': True,
220
+ 'schema': {'type': 'string'},
221
+ }
222
+ ],
223
+ },
224
+ {
225
+ 'name_for_human': '文生图',
226
+ 'name_for_model': 'image_gen',
227
+ 'description_for_model': '文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL',
228
+ 'parameters': [
229
+ {
230
+ 'name': 'prompt',
231
+ 'description': '英文关键词,描述了希望图像具有什么内容',
232
+ 'required': True,
233
+ 'schema': {'type': 'string'},
234
+ }
235
+ ],
236
+ },
237
+ ]
238
+ history = []
239
+ for query in ['你好', '谁是周杰伦', '他老婆是谁', '给我画个可爱的小猫吧,最好是黑猫']:
240
+ print(f"User's Query:\n{query}\n")
241
+ response, history = llm_with_plugin(prompt=query, history=history, list_of_plugin_info=tools)
242
+ print(f"Qwen's Response:\n{response}\n")
243
+
244
+
245
+ if __name__ == "__main__":
246
+ test()
247
+
248
+ """如果执行成功,在终端下应当能看到如下输出:
249
+ User's Query:
250
+ 你好
251
+
252
+ Qwen's Response:
253
+ Thought: 提供的工具对回答该问题帮助较小,我将不使用工具直接作答。
254
+ Final Answer: 你好!很高兴见到你。有什么我可以帮忙的吗?
255
+
256
+ User's Query:
257
+ 谁是周杰伦
258
+
259
+ Qwen's Response:
260
+ Thought: 我应该使用Google搜索查找相关信息。
261
+ Action: google_search
262
+ Action Input: {"search_query": "周杰伦"}
263
+ Observation: Jay Chou is a Taiwanese singer, songwriter, record producer, rapper, actor, television personality, and businessman.
264
+ Thought: I now know the final answer.
265
+ Final Answer: 周杰伦(Jay Chou)是一位来自台湾的歌手、词曲创作人、音乐制作人、说唱歌手、演员、电视节目主持人和企业家。他以其独特的音乐风格和才华在华语乐坛享有很高的声誉。
266
+
267
+ User's Query:
268
+ 他老婆是谁
269
+
270
+ Qwen's Response:
271
+ Thought: 我应该使用Google搜索查找相关信息。
272
+ Action: google_search
273
+ Action Input: {"search_query": "周杰伦 老婆"}
274
+ Observation: Hannah Quinlivan
275
+ Thought: I now know the final answer.
276
+ Final Answer: 周杰伦的老婆是Hannah Quinlivan,她是一位澳大利亚籍的模特和演员。两人于2015年结婚,并育有一子。
277
+
278
+ User's Query:
279
+ 给我画个可爱的小猫吧,最好是黑猫
280
+
281
+ Qwen's Response:
282
+ Thought: 我应该使用文生图API来生成一张可爱的小猫图片。
283
+ Action: image_gen
284
+ Action Input: {"prompt": "cute black cat"}
285
+ Observation: {"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}
286
+ Thought: I now know the final answer.
287
+ Final Answer: 生成的可爱小猫图片的URL为https://image.pollinations.ai/prompt/cute%20black%20cat。你可以点击这个链接查看图片。
288
+ """
examples/react_prompt.md ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ReAct Prompting Examples
2
+
3
+ This document describes how to use the ReAct prompting technique to instruct Qwen to use tools.
4
+
5
+ It mainly introduces the basic concepts and principles, and appends an FAQ on implementation details at the end, but it does not include the actual implementation of the plugins being called. If you prefer to understand the principles while debugging runnable code, you can instead read this [ipython notebook](https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb), which integrates commonly used LangChain tools.
6
+
7
+ In addition, both this document and the aforementioned ipython notebook only cover the single-turn case. For a multi-turn implementation, see [react_demo.py](https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_demo.py).
8
+
9
+ ## Preparation, part 1: a sample query and sample tools
10
+
11
+ Suppose we have the following query that is well suited to tool use, together with two tools: Quark Search (quark_search) and the Tongyi Wanxiang text-to-image service (image_gen):
12
+
13
+ ```py
14
+ query = '我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。'
15
+
16
+ TOOLS = [
17
+ {
18
+ 'name_for_human':
19
+ '夸克搜索',
20
+ 'name_for_model':
21
+ 'quark_search',
22
+ 'description_for_model':
23
+ '夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。',
24
+ 'parameters': [{
25
+ 'name': 'search_query',
26
+ 'description': '搜索关键词或短语',
27
+ 'required': True,
28
+ 'schema': {
29
+ 'type': 'string'
30
+ },
31
+ }],
32
+ },
33
+ {
34
+ 'name_for_human':
35
+ '通义万相',
36
+ 'name_for_model':
37
+ 'image_gen',
38
+ 'description_for_model':
39
+ '通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL',
40
+ 'parameters': [{
41
+ 'name': 'query',
42
+ 'description': '中文关键词,描述了希望图像具有什么内容',
43
+ 'required': True,
44
+ 'schema': {
45
+ 'type': 'string'
46
+ },
47
+ }],
48
+ },
49
+ ]
50
+ ```
51
+
52
+ ## Preparation, part 2: the ReAct template
53
+
54
+ We will use the following ReAct prompt template to elicit Qwen's tool-use ability.
55
+
56
+ ```py
57
+ TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object."""
58
+
59
+ REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools:
60
+
61
+ {tool_descs}
62
+
63
+ Use the following format:
64
+
65
+ Question: the input question you must answer
66
+ Thought: you should always think about what to do
67
+ Action: the action to take, should be one of [{tool_names}]
68
+ Action Input: the input to the action
69
+ Observation: the result of the action
70
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
71
+ Thought: I now know the final answer
72
+ Final Answer: the final answer to the original input question
73
+
74
+ Begin!
75
+
76
+ Question: {query}"""
77
+ ```
78
+
79
+ ## Step 1: let Qwen decide which tool to call and generate the tool arguments
80
+
81
+ First, we build the prompt from the ReAct prompt template, the query, and the tool information:
82
+
83
+ ```py
84
+ tool_descs = []
85
+ tool_names = []
86
+ for info in TOOLS:
87
+ tool_descs.append(
88
+ TOOL_DESC.format(
89
+ name_for_model=info['name_for_model'],
90
+ name_for_human=info['name_for_human'],
91
+ description_for_model=info['description_for_model'],
92
+ parameters=json.dumps(
93
+ info['parameters'], ensure_ascii=False),
94
+ )
95
+ )
96
+ tool_names.append(info['name_for_model'])
97
+ tool_descs = '\n\n'.join(tool_descs)
98
+ tool_names = ','.join(tool_names)
99
+
100
+ prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query)
101
+ print(prompt)
102
+ ```
103
+
104
+ The constructed prompt, printed out, looks like this:
105
+
106
+ ```
107
+ Answer the following questions as best you can. You have access to the following tools:
108
+
109
+ quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
110
+
111
+ image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
112
+
113
+ Use the following format:
114
+
115
+ Question: the input question you must answer
116
+ Thought: you should always think about what to do
117
+ Action: the action to take, should be one of [quark_search,image_gen]
118
+ Action Input: the input to the action
119
+ Observation: the result of the action
120
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
121
+ Thought: I now know the final answer
122
+ Final Answer: the final answer to the original input question
123
+
124
+ Begin!
125
+
126
+ Question: 我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。
127
+ ```
128
+
129
+ Feed this prompt to Qwen, and remember to set "Observation" as a stop word (see the FAQ at the end of this document), i.e., make Qwen stop generating as soon as the next word to be generated is "Observation". Given this prompt, Qwen will generate the following result:
130
+
131
+ ![](../assets/react_tutorial_001.png)
132
+
133
+ ```
134
+ Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。
135
+ Action: image_gen
136
+ Action Input: {"query": "五彩斑斓的黑"}
137
+ ```
138
+
139
+ After obtaining this result, the developer calling Qwen can extract `{"query": "五彩斑斓的黑"}` with simple parsing and call the text-to-image service based on it. This logic has to be implemented by the developer (a minimal sketch is shown below), or you can use the commercial version of Qwen, which integrates it internally.
140
+
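+ Below is a minimal sketch of that parsing step. It assumes `response` holds the text Qwen generated in step 1 and that the `json5` package is installed; `my_text_to_image_service` is a made-up placeholder for whatever image-generation backend you actually call.
+
+ ```py
+ import json5
+
+ def extract_action_input(response: str) -> dict:
+     # Take the text after the last "Action Input:" and before any "Observation:" marker
+     # (the marker may or may not remain in the output, depending on how stop words are handled).
+     args_text = response.split('Action Input:')[-1].split('Observation:')[0].strip()
+     return json5.loads(args_text)
+
+ args = extract_action_input(response)                    # e.g. {"query": "五彩斑斓的黑"}
+ image_result = my_text_to_image_service(args['query'])  # hypothetical plugin call
+ ```
+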
141
+ ## Step 2: let Qwen continue answering based on the plugin's result
142
+
143
+ Let us assume the text-to-image plugin returned the following result:
144
+
145
+ ```
146
+ {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}}
147
+ ```
148
+
149
+ ![](../assets/wanx_colorful_black.png)
150
+
151
+ Next, we concatenate the prompt used in the first request to Qwen with the result of the text-to-image plugin call, forming the following new prompt:
152
+
153
+ ```
154
+ Answer the following questions as best you can. You have access to the following tools:
155
+
156
+ quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
157
+
158
+ image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object.
159
+
160
+ Use the following format:
161
+
162
+ Question: the input question you must answer
163
+ Thought: you should always think about what to do
164
+ Action: the action to take, should be one of [quark_search,image_gen]
165
+ Action Input: the input to the action
166
+ Observation: the result of the action
167
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
168
+ Thought: I now know the final answer
169
+ Final Answer: the final answer to the original input question
170
+
171
+ Begin!
172
+
173
+ Question: 我是老板,我说啥你做啥。现在给我画个五彩斑斓的黑。
174
+ Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。
175
+ Action: image_gen
176
+ Action Input: {"query": "五彩斑斓的黑"}
177
+ Observation: {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}}
178
+ ```
179
+
180
+ Calling Qwen with this new prompt, which now contains the text-to-image plugin's result, yields the following final reply:
181
+
182
+ ![](../assets/react_tutorial_002.png)
183
+
184
+ ```
185
+ Thought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。
186
+ Final Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。
187
+ ```
188
+
189
+ For text-to-image this second call to Qwen may look redundant, but for other plugins such as search, code execution, or a calculator, this second call gives Qwen the chance to distill and summarize the plugin's result. A sketch of the whole two-step flow follows.
190
+
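+ The following is a minimal sketch of the two-step flow. It assumes `model`, `tokenizer`, and the `react_stop_words` list from the FAQ below are already set up, that `parse_latest_plugin_call` is the reference parser also given in the FAQ, and that `call_plugin` is a developer-provided function that executes the chosen tool; it is an illustration, not the reference implementation (for that, see react_demo.py).
+
+ ```py
+ # Step 1: let Qwen pick a tool and its arguments; generation stops at "Observation".
+ response, _ = model.chat(tokenizer, prompt, history=None, stop_words_ids=react_stop_words)
+
+ # Parse Action / Action Input and run the plugin ourselves.
+ plugin_name, plugin_args = parse_latest_plugin_call(response)
+ observation = call_plugin(plugin_name, plugin_args)  # developer-provided tool execution
+
+ # Step 2: append the Observation and ask Qwen to continue from there.
+ response = response.rstrip()
+ if response.endswith('Observation:'):
+     response = response[: -len('Observation:')].rstrip()
+ prompt2 = prompt + '\n' + response + f'\nObservation: {observation}'
+ final_answer, _ = model.chat(tokenizer, prompt2, history=None, stop_words_ids=react_stop_words)
+ ```
+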
191
+ ## FAQ
192
+
193
+ **How do I configure "Observation" as a stop word?**
194
+
195
+ Specify it through the stop_words_ids argument of the chat interface:
196
+ ```py
197
+ react_stop_words = [
198
+ # tokenizer.encode('Observation'), # [37763, 367]
199
+ tokenizer.encode('Observation:'), # [37763, 367, 25]
200
+ tokenizer.encode('Observation:\n'), # [37763, 367, 510]
201
+ ]
202
+ response, history = model.chat(
203
+ tokenizer, query, history,
204
+ stop_words_ids=react_stop_words # 此接口用于增加 stop words
205
+ )
206
+ ```
207
+
208
+ If you get an error saying that the stop_words_ids argument does not exist, you are probably running old code; please re-run from_pretrained to pull the latest code and model.
209
+
210
+ Note that the current tokenizer applies a series of fairly complex merge rules around `\n`. For example, the two characters `:\n` in the snippet above are merged into a single token, so configuring stop words requires carefully anticipating the tokenizer's behavior; a quick check is sketched below.
211
+
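+ For example, assuming the tokenizer has been loaded as above, you can print the ids of the candidate stop strings before registering them:
+
+ ```py
+ # Inspect how candidate stop strings are tokenized before using them as stop words.
+ for s in ['Observation:', 'Observation:\n', ':\n']:
+     print(repr(s), tokenizer.encode(s))
+ # Per the comments above, 'Observation:' should give [37763, 367, 25] and 'Observation:\n' should give [37763, 367, 510].
+ ```
+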
212
+ **Any advice on tuning inference parameters such as top_p?**
213
+
214
+ Generally speaking, a lower top_p gives higher accuracy, at the cost of less diverse answers and a stronger tendency to repeat words or phrases.
215
+
216
+ You can set top_p to 0.5 as follows:
217
+ ```py
218
+ model.generation_config.top_p = 0.5
219
+ ```
220
+
221
+ In particular, you can turn off top-p sampling and switch to greedy decoding as follows, which is effectively equivalent to top_p=0 or temperature=0:
222
+ ```py
223
+ model.generation_config.do_sample = False # greedy decoding
224
+ ```
225
+
226
+ In addition, the `model.chat()` interface also provides a way to adjust top_p and other parameters per call.
227
+
228
+ **Is there reference code for parsing Action and Action Input?**
229
+
230
+ Yes, for example:
231
+ ```py
232
+ def parse_latest_plugin_call(text: str) -> Tuple[str, str]:
233
+ i = text.rfind('\nAction:')
234
+ j = text.rfind('\nAction Input:')
235
+ k = text.rfind('\nObservation:')
236
+ if 0 <= i < j: # If the text has `Action` and `Action input`,
237
+ if k < j: # but does not contain `Observation`,
238
+ # then it is likely that `Observation` is omitted by the LLM,
239
+ # because the output text may have discarded the stop word.
240
+ text = text.rstrip() + '\nObservation:' # Add it back.
241
+ k = text.rfind('\nObservation:')
242
+ if 0 <= i < j < k:
243
+ plugin_name = text[i + len('\nAction:'):j].strip()
244
+ plugin_args = text[j + len('\nAction Input:'):k].strip()
245
+ return plugin_name, plugin_args
246
+ return '', ''
247
+ ```
248
+
249
+ In addition, if the generated Action Input is a piece of text representing a JSON object, we recommend loading it with the `json5.loads(...)` method from the `json5` package, as illustrated below.
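+
+ As a small illustration of why `json5` is more forgiving than the standard `json` module (the example string is made up):
+
+ ```py
+ import json, json5
+
+ text = "{'search_query': '周杰伦',}"  # single quotes and a trailing comma, as an LLM might emit
+ print(json5.loads(text))             # -> {'search_query': '周杰伦'}
+ # json.loads(text) would raise json.JSONDecodeError on the same input.
+ ```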
examples/tokenizer_showcase.ipynb ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "from transformers import AutoTokenizer"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B', trust_remote_code=True)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "# Encode and Decode"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "data": {
45
+ "text/plain": [
46
+ "[1350, 492, 151643, 863, 151643]"
47
+ ]
48
+ },
49
+ "execution_count": 3,
50
+ "metadata": {},
51
+ "output_type": "execute_result"
52
+ }
53
+ ],
54
+ "source": [
55
+ "# treat surface forms of special tokens as actual special tokens\n",
56
+ "# the default, but unsafe (to be compatible with other projects)\n",
57
+ "# the same as tokenizer.encode(\"print('<|endoftext|>')<|endoftext|>\", allowed_special='all', disallowed_special=())\n",
58
+ "tokenizer.encode(\"print('<|endoftext|>')<|endoftext|>\")"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 4,
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "data": {
68
+ "text/plain": [
69
+ "\"print('<|endoftext|>')<|endoftext|>\""
70
+ ]
71
+ },
72
+ "execution_count": 4,
73
+ "metadata": {},
74
+ "output_type": "execute_result"
75
+ }
76
+ ],
77
+ "source": [
78
+ "tokenizer.decode([1350, 492, 151643, 863, 151643])"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 5,
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "[1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643]"
90
+ ]
91
+ },
92
+ "execution_count": 5,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "# treat texts just as texts, avoid injection attacks\n",
99
+ "tokenizer.encode(\"print('<|endoftext|>')\", allowed_special=set(), disallowed_special=()) + [tokenizer.eod_id]"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 6,
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "data": {
109
+ "text/plain": [
110
+ "\"print('<|endoftext|>')<|endoftext|>\""
111
+ ]
112
+ },
113
+ "execution_count": 6,
114
+ "metadata": {},
115
+ "output_type": "execute_result"
116
+ }
117
+ ],
118
+ "source": [
119
+ "tokenizer.decode([1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643])"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 7,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "ename": "ValueError",
129
+ "evalue": "Encountered text corresponding to disallowed special token '<|endoftext|>'.\nIf you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endoftext|>', ...}`.\nIf you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endoftext|>'})`.\nTo disable this check for all special tokens, pass `disallowed_special=()`.\n",
130
+ "output_type": "error",
131
+ "traceback": [
132
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
133
+ "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
134
+ "Cell \u001b[1;32mIn[7], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[39m# treat texts just as texts, avoid injection attacks, and raise error if surface forms of special tokens are ever encountered\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m tokenizer\u001b[39m.\u001b[39;49mencode(\u001b[39m\"\u001b[39;49m\u001b[39mprint(\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39m<|endoftext|>\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39m)\u001b[39;49m\u001b[39m\"\u001b[39;49m, allowed_special\u001b[39m=\u001b[39;49m\u001b[39mset\u001b[39;49m(), disallowed_special\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mall\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39m+\u001b[39m [tokenizer\u001b[39m.\u001b[39meod_id]\n",
135
+ "File \u001b[1;32mtransformers\\tokenization_utils_base.py:2348\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.encode\u001b[1;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, return_tensors, **kwargs)\u001b[0m\n\u001b[0;32m 2311\u001b[0m \u001b[39m@add_end_docstrings\u001b[39m(\n\u001b[0;32m 2312\u001b[0m ENCODE_KWARGS_DOCSTRING,\n\u001b[0;32m 2313\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2331\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2332\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[\u001b[39mint\u001b[39m]:\n\u001b[0;32m 2333\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 2334\u001b[0m \u001b[39m Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.\u001b[39;00m\n\u001b[0;32m 2335\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2346\u001b[0m \u001b[39m method).\u001b[39;00m\n\u001b[0;32m 2347\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 2348\u001b[0m encoded_inputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mencode_plus(\n\u001b[0;32m 2349\u001b[0m text,\n\u001b[0;32m 2350\u001b[0m text_pair\u001b[39m=\u001b[39mtext_pair,\n\u001b[0;32m 2351\u001b[0m add_special_tokens\u001b[39m=\u001b[39madd_special_tokens,\n\u001b[0;32m 2352\u001b[0m padding\u001b[39m=\u001b[39mpadding,\n\u001b[0;32m 2353\u001b[0m truncation\u001b[39m=\u001b[39mtruncation,\n\u001b[0;32m 2354\u001b[0m max_length\u001b[39m=\u001b[39mmax_length,\n\u001b[0;32m 2355\u001b[0m stride\u001b[39m=\u001b[39mstride,\n\u001b[0;32m 2356\u001b[0m return_tensors\u001b[39m=\u001b[39mreturn_tensors,\n\u001b[0;32m 2357\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2358\u001b[0m )\n\u001b[0;32m 2360\u001b[0m \u001b[39mreturn\u001b[39;00m encoded_inputs[\u001b[39m\"\u001b[39m\u001b[39minput_ids\u001b[39m\u001b[39m\"\u001b[39m]\n",
136
+ "File \u001b[1;32mtransformers\\tokenization_utils_base.py:2756\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 2746\u001b[0m \u001b[39m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\u001b[39;00m\n\u001b[0;32m 2747\u001b[0m padding_strategy, truncation_strategy, max_length, kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_padding_truncation_strategies(\n\u001b[0;32m 2748\u001b[0m padding\u001b[39m=\u001b[39mpadding,\n\u001b[0;32m 2749\u001b[0m truncation\u001b[39m=\u001b[39mtruncation,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2753\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2754\u001b[0m )\n\u001b[1;32m-> 2756\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_encode_plus(\n\u001b[0;32m 2757\u001b[0m text\u001b[39m=\u001b[39mtext,\n\u001b[0;32m 2758\u001b[0m text_pair\u001b[39m=\u001b[39mtext_pair,\n\u001b[0;32m 2759\u001b[0m add_special_tokens\u001b[39m=\u001b[39madd_special_tokens,\n\u001b[0;32m 2760\u001b[0m padding_strategy\u001b[39m=\u001b[39mpadding_strategy,\n\u001b[0;32m 2761\u001b[0m truncation_strategy\u001b[39m=\u001b[39mtruncation_strategy,\n\u001b[0;32m 2762\u001b[0m max_length\u001b[39m=\u001b[39mmax_length,\n\u001b[0;32m 2763\u001b[0m stride\u001b[39m=\u001b[39mstride,\n\u001b[0;32m 2764\u001b[0m is_split_into_words\u001b[39m=\u001b[39mis_split_into_words,\n\u001b[0;32m 2765\u001b[0m pad_to_multiple_of\u001b[39m=\u001b[39mpad_to_multiple_of,\n\u001b[0;32m 2766\u001b[0m return_tensors\u001b[39m=\u001b[39mreturn_tensors,\n\u001b[0;32m 2767\u001b[0m return_token_type_ids\u001b[39m=\u001b[39mreturn_token_type_ids,\n\u001b[0;32m 2768\u001b[0m return_attention_mask\u001b[39m=\u001b[39mreturn_attention_mask,\n\u001b[0;32m 2769\u001b[0m return_overflowing_tokens\u001b[39m=\u001b[39mreturn_overflowing_tokens,\n\u001b[0;32m 2770\u001b[0m return_special_tokens_mask\u001b[39m=\u001b[39mreturn_special_tokens_mask,\n\u001b[0;32m 2771\u001b[0m return_offsets_mapping\u001b[39m=\u001b[39mreturn_offsets_mapping,\n\u001b[0;32m 2772\u001b[0m return_length\u001b[39m=\u001b[39mreturn_length,\n\u001b[0;32m 2773\u001b[0m verbose\u001b[39m=\u001b[39mverbose,\n\u001b[0;32m 2774\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[0;32m 2775\u001b[0m )\n",
137
+ "File \u001b[1;32mtransformers\\tokenization_utils.py:649\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus\u001b[1;34m(self, text, text_pair, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 640\u001b[0m \u001b[39mif\u001b[39;00m return_offsets_mapping:\n\u001b[0;32m 641\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\n\u001b[0;32m 642\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mreturn_offset_mapping is not available when using Python tokenizers. \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 643\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mTo use this feature, change your tokenizer to one deriving from \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 646\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mhttps://github.com/huggingface/transformers/pull/2674\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 647\u001b[0m )\n\u001b[1;32m--> 649\u001b[0m first_ids \u001b[39m=\u001b[39m get_input_ids(text)\n\u001b[0;32m 650\u001b[0m second_ids \u001b[39m=\u001b[39m get_input_ids(text_pair) \u001b[39mif\u001b[39;00m text_pair \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m 652\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprepare_for_model(\n\u001b[0;32m 653\u001b[0m first_ids,\n\u001b[0;32m 654\u001b[0m pair_ids\u001b[39m=\u001b[39msecond_ids,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 668\u001b[0m verbose\u001b[39m=\u001b[39mverbose,\n\u001b[0;32m 669\u001b[0m )\n",
138
+ "File \u001b[1;32mtransformers\\tokenization_utils.py:616\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._encode_plus.<locals>.get_input_ids\u001b[1;34m(text)\u001b[0m\n\u001b[0;32m 614\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mget_input_ids\u001b[39m(text):\n\u001b[0;32m 615\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(text, \u001b[39mstr\u001b[39m):\n\u001b[1;32m--> 616\u001b[0m tokens \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtokenize(text, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m 617\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconvert_tokens_to_ids(tokens)\n\u001b[0;32m 618\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(text, (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)) \u001b[39mand\u001b[39;00m \u001b[39mlen\u001b[39m(text) \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(text[\u001b[39m0\u001b[39m], \u001b[39mstr\u001b[39m):\n",
139
+ "File \u001b[1;32mtokenization_qwen.py:155\u001b[0m, in \u001b[0;36mQWenTokenizer.tokenize\u001b[1;34m(self, text, allowed_special, disallowed_special, **kwargs)\u001b[0m\n\u001b[0;32m 152\u001b[0m text \u001b[39m=\u001b[39m unicodedata\u001b[39m.\u001b[39mnormalize(\u001b[39m\"\u001b[39m\u001b[39mNFC\u001b[39m\u001b[39m\"\u001b[39m, text)\n\u001b[0;32m 154\u001b[0m \u001b[39m# this implementation takes a detour: text -> token id -> token surface forms\u001b[39;00m\n\u001b[1;32m--> 155\u001b[0m \u001b[39mfor\u001b[39;00m t \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtokenizer\u001b[39m.\u001b[39;49mencode(\n\u001b[0;32m 156\u001b[0m text, allowed_special\u001b[39m=\u001b[39;49mallowed_special, disallowed_special\u001b[39m=\u001b[39;49mdisallowed_special\n\u001b[0;32m 157\u001b[0m ):\n\u001b[0;32m 158\u001b[0m tokens\u001b[39m.\u001b[39mappend(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdecoder[t])\n\u001b[0;32m 159\u001b[0m \u001b[39mreturn\u001b[39;00m tokens\n",
140
+ "File \u001b[1;32mtiktoken\\core.py:117\u001b[0m, in \u001b[0;36mEncoding.encode\u001b[1;34m(self, text, allowed_special, disallowed_special)\u001b[0m\n\u001b[0;32m 115\u001b[0m disallowed_special \u001b[39m=\u001b[39m \u001b[39mfrozenset\u001b[39m(disallowed_special)\n\u001b[0;32m 116\u001b[0m \u001b[39mif\u001b[39;00m match \u001b[39m:=\u001b[39m _special_token_regex(disallowed_special)\u001b[39m.\u001b[39msearch(text):\n\u001b[1;32m--> 117\u001b[0m raise_disallowed_special_token(match\u001b[39m.\u001b[39;49mgroup())\n\u001b[0;32m 119\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m 120\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_core_bpe\u001b[39m.\u001b[39mencode(text, allowed_special)\n",
141
+ "File \u001b[1;32mtiktoken\\core.py:337\u001b[0m, in \u001b[0;36mraise_disallowed_special_token\u001b[1;34m(token)\u001b[0m\n\u001b[0;32m 336\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mraise_disallowed_special_token\u001b[39m(token: \u001b[39mstr\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m NoReturn:\n\u001b[1;32m--> 337\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[0;32m 338\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEncountered text corresponding to disallowed special token \u001b[39m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 339\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mIf you want this text to be encoded as a special token, \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 340\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mpass it to `allowed_special`, e.g. `allowed_special=\u001b[39m\u001b[39m{{\u001b[39;00m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m, ...\u001b[39m\u001b[39m}}\u001b[39;00m\u001b[39m`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 341\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIf you want this text to be encoded as normal text, disable the check for this token \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m 342\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mby passing `disallowed_special=(enc.special_tokens_set - \u001b[39m\u001b[39m{{\u001b[39;00m\u001b[39m{\u001b[39;00mtoken\u001b[39m!r}\u001b[39;00m\u001b[39m}}\u001b[39;00m\u001b[39m)`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 343\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mTo disable this check for all special tokens, pass `disallowed_special=()`.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m 344\u001b[0m )\n",
142
+ "\u001b[1;31mValueError\u001b[0m: Encountered text corresponding to disallowed special token '<|endoftext|>'.\nIf you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endoftext|>', ...}`.\nIf you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endoftext|>'})`.\nTo disable this check for all special tokens, pass `disallowed_special=()`.\n"
143
+ ]
144
+ }
145
+ ],
146
+ "source": [
147
+ "# treat texts just as texts, avoid injection attacks, and raise error if surface forms of special tokens are ever encountered\n",
148
+ "tokenizer.encode(\"print('<|endoftext|>')\", allowed_special=set(), disallowed_special='all') + [tokenizer.eod_id]\n"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "data": {
158
+ "text/plain": [
159
+ "[151644, 1350, 11146, 91, 15460, 62, 15, 91, 79865, 151645, 151643]"
160
+ ]
161
+ },
162
+ "execution_count": 7,
163
+ "metadata": {},
164
+ "output_type": "execute_result"
165
+ }
166
+ ],
167
+ "source": [
168
+ "# fine-grained control, just keep mind of this:\n",
169
+ "# allowed_special is treated as special tokens\n",
170
+ "# disallowed_special raise errors\n",
171
+ "# allowed_special has higher priority than disallowed_special\n",
172
+ "tokenizer.encode(\"<|im_start|>print('<|extra_0|>')<|im_end|>\", \n",
173
+ " allowed_special={'<|im_start|>', '<|im_end|>'}, \n",
174
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "metadata": {},
181
+ "outputs": [
182
+ {
183
+ "data": {
184
+ "text/plain": [
185
+ "[151644, 1350, 492, 151646, 863, 151645, 151643]"
186
+ ]
187
+ },
188
+ "execution_count": 8,
189
+ "metadata": {},
190
+ "output_type": "execute_result"
191
+ }
192
+ ],
193
+ "source": [
194
+ "tokenizer.encode(\"<|im_start|>print('<|extra_0|>')<|im_end|>\", \n",
195
+ " allowed_special={'<|im_start|>', '<|im_end|>', '<|extra_0|>'}, \n",
196
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "metadata": {},
202
+ "source": [
203
+ "# Special Token Management"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 8,
209
+ "metadata": {},
210
+ "outputs": [
211
+ {
212
+ "name": "stderr",
213
+ "output_type": "stream",
214
+ "text": [
215
+ "Using unk_token, but it is not set yet.\n"
216
+ ]
217
+ }
218
+ ],
219
+ "source": [
220
+ "# huggingface tokenizer has its own special token mechanism, so does tiktoken\n",
221
+ "# we only use the tiktoken mechanism for special tokens, which means many property of huggingface tokenizer will be None\n",
222
+ "tokenizer.unk_token"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": 9,
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "tokenizer.eos_token_id # use tokenizer.eod_id instead"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 10,
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "tokenizer.pad_token_id "
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 11,
246
+ "metadata": {},
247
+ "outputs": [
248
+ {
249
+ "data": {
250
+ "text/plain": [
251
+ "151646"
252
+ ]
253
+ },
254
+ "execution_count": 11,
255
+ "metadata": {},
256
+ "output_type": "execute_result"
257
+ }
258
+ ],
259
+ "source": [
260
+ "# use one of the extras such as <|extra_0|>\n",
261
+ "tokenizer.special_tokens['<|extra_0|>']"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "markdown",
266
+ "metadata": {},
267
+ "source": [
268
+ "# Utility Methods"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 12,
274
+ "metadata": {},
275
+ "outputs": [
276
+ {
277
+ "data": {
278
+ "text/plain": [
279
+ "[b'print', b\"('<\", b'|', b'endo', b'ft', b'ext', b'|', b\">')\", '<|endoftext|>']"
280
+ ]
281
+ },
282
+ "execution_count": 12,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ }
286
+ ],
287
+ "source": [
288
+ "# special tokens are str, tokens are bytes (since tiktoken operates on the bytes level)\n",
289
+ "ids = [1350, 11146, 91, 8691, 723, 427, 91, 79865, 151643]\n",
290
+ "tokenizer.convert_ids_to_tokens(ids)"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 13,
296
+ "metadata": {},
297
+ "outputs": [
298
+ {
299
+ "data": {
300
+ "text/plain": [
301
+ "\"print('<|endoftext|>')<|endoftext|>\""
302
+ ]
303
+ },
304
+ "execution_count": 13,
305
+ "metadata": {},
306
+ "output_type": "execute_result"
307
+ }
308
+ ],
309
+ "source": [
310
+ "tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 14,
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "ids = tokenizer.encode(\"<|im_start|>print('我是一只猫<|extra_0|>')\\n#喵喵喵<|im_end|>\", \n",
320
+ " allowed_special={'<|im_start|>', '<|im_end|>', '<|extra_0|>'}, \n",
321
+ " disallowed_special=['<|endoftext|>']) + [tokenizer.eod_id]"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 15,
327
+ "metadata": {},
328
+ "outputs": [
329
+ {
330
+ "data": {
331
+ "text/plain": [
332
+ "['<|im_start|>',\n",
333
+ " b'print',\n",
334
+ " b\"('\",\n",
335
+ " b'\\xe6\\x88\\x91',\n",
336
+ " b'\\xe6\\x98\\xaf\\xe4\\xb8\\x80',\n",
337
+ " b'\\xe5\\x8f\\xaa',\n",
338
+ " b'\\xe7\\x8c\\xab',\n",
339
+ " '<|extra_0|>',\n",
340
+ " b\"')\\n\",\n",
341
+ " b'#',\n",
342
+ " b'\\xe5\\x96\\xb5',\n",
343
+ " b'\\xe5\\x96\\xb5',\n",
344
+ " b'\\xe5\\x96\\xb5',\n",
345
+ " '<|im_end|>',\n",
346
+ " '<|endoftext|>']"
347
+ ]
348
+ },
349
+ "execution_count": 15,
350
+ "metadata": {},
351
+ "output_type": "execute_result"
352
+ }
353
+ ],
354
+ "source": [
355
+ "tokenizer.convert_ids_to_tokens(ids)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": 16,
361
+ "metadata": {},
362
+ "outputs": [
363
+ {
364
+ "data": {
365
+ "text/plain": [
366
+ "\"<|im_start|>print('我是一只猫<|extra_0|>')\\n#喵喵喵<|im_end|><|endoftext|>\""
367
+ ]
368
+ },
369
+ "execution_count": 16,
370
+ "metadata": {},
371
+ "output_type": "execute_result"
372
+ }
373
+ ],
374
+ "source": [
375
+ "tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": 17,
381
+ "metadata": {},
382
+ "outputs": [
383
+ {
384
+ "data": {
385
+ "text/plain": [
386
+ "'<|extra_204|>'"
387
+ ]
388
+ },
389
+ "execution_count": 17,
390
+ "metadata": {},
391
+ "output_type": "execute_result"
392
+ }
393
+ ],
394
+ "source": [
395
+ "tokenizer._convert_id_to_token(len(tokenizer)-1)"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": 18,
401
+ "metadata": {},
402
+ "outputs": [
403
+ {
404
+ "data": {
405
+ "text/plain": [
406
+ "151850"
407
+ ]
408
+ },
409
+ "execution_count": 18,
410
+ "metadata": {},
411
+ "output_type": "execute_result"
412
+ }
413
+ ],
414
+ "source": [
415
+ "tokenizer._convert_token_to_id('<|extra_204|>')"
416
+ ]
417
+ }
418
+ ],
419
+ "metadata": {
420
+ "kernelspec": {
421
+ "display_name": "python3",
422
+ "language": "python",
423
+ "name": "python3"
424
+ },
425
+ "language_info": {
426
+ "codemirror_mode": {
427
+ "name": "ipython",
428
+ "version": 3
429
+ },
430
+ "file_extension": ".py",
431
+ "mimetype": "text/x-python",
432
+ "name": "python",
433
+ "nbconvert_exporter": "python",
434
+ "pygments_lexer": "ipython3",
435
+ "version": "3.10.9"
436
+ },
437
+ "orig_nbformat": 4
438
+ },
439
+ "nbformat": 4,
440
+ "nbformat_minor": 2
441
+ }
examples/transformers_agent.md ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## What is a HuggingFace Agent
2
+ By using a large model as an agent, you can call models on HuggingFace with nothing but natural language. Two modes are currently supported:
3
+
4
+ - run mode: single-turn, no context; good at combining multiple tools within a single prompt
5
+ - chat mode: multi-turn, with context; good at individual calls, but may need several prompts to combine multiple tools
6
+ > See the official documentation for details: [Transformers Agents](https://huggingface.co/docs/transformers/transformers_agents)
7
+
8
+ ## Using Qwen as an Agent
9
+ ### Install dependencies
10
+ ```
11
+ pip install transformers
12
+ ```
13
+ ### Building a QWenAgent
14
+ The following code is all that is needed to implement QWenAgent:
15
+ ```python
16
+ import torch
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Agent
18
+ from transformers.generation import GenerationConfig
19
+
20
+
21
+ class QWenAgent(Agent):
22
+ """
23
+ Agent that uses QWen model and tokenizer to generate code.
24
+
25
+ Args:
26
+ chat_prompt_template (`str`, *optional*):
27
+ Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
28
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
29
+ `chat_prompt_template.txt` in this repo in this case.
30
+ run_prompt_template (`str`, *optional*):
31
+ Pass along your own prompt if you want to override the default template for the `run` method. Can be the
32
+ actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
33
+ `run_prompt_template.txt` in this repo in this case.
34
+ additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
35
+ Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
36
+ one of the default tools, that default tool will be overridden.
37
+
38
+ Example:
39
+
40
+ ```py
41
+ agent = QWenAgent()
42
+ agent.run("Draw me a picture of rivers and lakes.")
43
+ ```
44
+ """
45
+ def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
46
+ checkpoint = "Qwen/Qwen-7B-Chat"
47
+ self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
48
+ self.model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", trust_remote_code=True).cuda().eval()
49
+ self.model.generation_config = GenerationConfig.from_pretrained(checkpoint, trust_remote_code=True) # generation length, top_p and other related hyperparameters can be configured here
50
+ self.model.generation_config.do_sample = False # greedy
51
+
52
+ super().__init__(
53
+ chat_prompt_template=chat_prompt_template,
54
+ run_prompt_template=run_prompt_template,
55
+ additional_tools=additional_tools,
56
+ )
57
+
58
+ def generate_one(self, prompt, stop):
59
+ # "Human:" 和 "Assistant:" 曾为通义千问的特殊保留字,需要替换为 "_HUMAN_:" 和 "_ASSISTANT_:"。这一问题将在未来版本修复。
60
+ prompt = prompt.replace("Human:", "_HUMAN_:").replace("Assistant:", "_ASSISTANT_:")
61
+ stop = [item.replace("Human:", "_HUMAN_:").replace("Assistant:", "_ASSISTANT_:") for item in stop]
62
+
63
+ result, _ = self.model.chat(self.tokenizer, prompt, history=None)
64
+ for stop_seq in stop:
65
+ if result.endswith(stop_seq):
66
+ result = result[: -len(stop_seq)]
67
+
68
+ result = result.replace("_HUMAN_:", "Human:").replace("_ASSISTANT_:", "Assistant:")
69
+ return result
70
+
71
+
72
+ agent = QWenAgent()
73
+ agent.run("Draw me a picture of rivers and lakes.")
74
+ ```
75
+ ### Usage example
76
+ ```python
77
+ agent = QWenAgent()
78
+ agent.run("generate an image of panda", remote=True)
79
+ ```
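+ Chat mode keeps context across calls. A sketch, assuming the `Agent.chat` interface of Transformers Agents behaves as documented:
+ ```python
+ agent = QWenAgent()
+ agent.chat("Generate an image of a panda.")
+ agent.chat("Transform the image so that the panda is wearing a hat.")  # reuses the previous context
+ ```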
80
+ ![](../assets/hfagent_run.png)
81
+ ![](../assets/hfagent_chat_1.png)
82
+ ![](../assets/hfagent_chat_2.png)
83
+ > For more ways to use it, see the official HuggingFace documentation: [Transformers Agents](https://huggingface.co/docs/transformers/transformers_agents)
84
+
85
+ ## Tools
86
+ ### Supported tools
87
+ The 14 official HuggingFace Agent tools:
88
+
89
+ - **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document (Donut)
90
+ - **Text question answering**: given a long text and a question, answer the question in the text (Flan-T5)
91
+ - **Unconditional image captioning**: Caption the image! (BLIP)
92
+ - **Image question answering**: given an image, answer a question on this image (VILT)
93
+ - **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt (CLIPSeg)
94
+ - **Speech to text**: given an audio recording of a person talking, transcribe the speech into text (Whisper)
95
+ - **Text to speech**: convert text to speech (SpeechT5)
96
+ - **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most (BART)
97
+ - **Text summarization**: summarize a long text in one or a few sentences (BART)
98
+ - **Translation**: translate the text into a given language (NLLB)
99
+ - **Text downloader**: to download a text from a web URL
100
+ - **Text to image**: generate an image according to a prompt, leveraging stable diffusion
101
+ - **Image transformation**: transforms an image
102
+ - **Text to video**: generate a small video according to a prompt, leveraging damo-vilab
103
+ ### Deploying the tool models
104
+ Some of the models behind these tools have already been deployed online by HuggingFace; simply set remote=True to call them remotely:
105
+ > agent.run(xxx, remote=True)
106
+
107
+ For models that HuggingFace has not deployed online, the checkpoint is downloaded automatically and inference runs locally.
108
+ The connection to HuggingFace occasionally fails for network reasons; please retry a few times.
openai_api.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
3
+ # Usage: python openai_api.py
4
+ # Visit http://localhost:8000/docs for documents.
5
+
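+ #
+ # Example client usage (a sketch, assuming the `openai` Python package < 1.0 and the default host/port;
+ # the prompt text is only an illustration):
+ #   import openai
+ #   openai.api_base = "http://localhost:8000/v1"
+ #   openai.api_key = "none"
+ #   resp = openai.ChatCompletion.create(model="gpt-3.5-turbo",
+ #                                       messages=[{"role": "user", "content": "你好"}])
+ #   print(resp.choices[0].message.content)
+ # Streaming is also supported by passing stream=True to ChatCompletion.create.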
6
+ from argparse import ArgumentParser
7
+ import time
8
+ import torch
9
+ import uvicorn
10
+ from pydantic import BaseModel, Field
11
+ from fastapi import FastAPI, HTTPException
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from contextlib import asynccontextmanager
14
+ from typing import Any, Dict, List, Literal, Optional, Union
15
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
16
+ from transformers.generation import GenerationConfig
17
+ from sse_starlette.sse import ServerSentEvent, EventSourceResponse
18
+
19
+
20
+ @asynccontextmanager
21
+ async def lifespan(app: FastAPI): # collects GPU memory
22
+ yield
23
+ if torch.cuda.is_available():
24
+ torch.cuda.empty_cache()
25
+ torch.cuda.ipc_collect()
26
+
27
+
28
+ app = FastAPI(lifespan=lifespan)
29
+
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=["*"],
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+
39
+ class ModelCard(BaseModel):
40
+ id: str
41
+ object: str = "model"
42
+ created: int = Field(default_factory=lambda: int(time.time()))
43
+ owned_by: str = "owner"
44
+ root: Optional[str] = None
45
+ parent: Optional[str] = None
46
+ permission: Optional[list] = None
47
+
48
+
49
+ class ModelList(BaseModel):
50
+ object: str = "list"
51
+ data: List[ModelCard] = []
52
+
53
+
54
+ class ChatMessage(BaseModel):
55
+ role: Literal["user", "assistant", "system"]
56
+ content: str
57
+
58
+
59
+ class DeltaMessage(BaseModel):
60
+ role: Optional[Literal["user", "assistant", "system"]] = None
61
+ content: Optional[str] = None
62
+
63
+
64
+ class ChatCompletionRequest(BaseModel):
65
+ model: str
66
+ messages: List[ChatMessage]
67
+ temperature: Optional[float] = None
68
+ top_p: Optional[float] = None
69
+ max_length: Optional[int] = None
70
+ stream: Optional[bool] = False
71
+
72
+
73
+ class ChatCompletionResponseChoice(BaseModel):
74
+ index: int
75
+ message: ChatMessage
76
+ finish_reason: Literal["stop", "length"]
77
+
78
+
79
+ class ChatCompletionResponseStreamChoice(BaseModel):
80
+ index: int
81
+ delta: DeltaMessage
82
+ finish_reason: Optional[Literal["stop", "length"]]
83
+
84
+
85
+ class ChatCompletionResponse(BaseModel):
86
+ model: str
87
+ object: Literal["chat.completion", "chat.completion.chunk"]
88
+ choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
89
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
90
+
91
+
92
+ @app.get("/v1/models", response_model=ModelList)
93
+ async def list_models():
94
+ global model_args
95
+ model_card = ModelCard(id="gpt-3.5-turbo")
96
+ return ModelList(data=[model_card])
97
+
98
+
99
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
100
+ async def create_chat_completion(request: ChatCompletionRequest):
101
+ global model, tokenizer
102
+
103
+ if request.messages[-1].role != "user":
104
+ raise HTTPException(status_code=400, detail="Invalid request")
105
+ query = request.messages[-1].content
106
+
107
+ prev_messages = request.messages[:-1]
108
+ # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
109
+ # if len(prev_messages) > 0 and prev_messages[0].role == "system":
110
+ # query = prev_messages.pop(0).content + query
111
+
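+ # The earlier messages must form alternating (user, assistant) pairs; they are folded into
+ # `history` below, and any other ordering is rejected with a 400 error.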
112
+ history = []
113
+ if len(prev_messages) % 2 == 0:
114
+ for i in range(0, len(prev_messages), 2):
115
+ if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
116
+ history.append([prev_messages[i].content, prev_messages[i+1].content])
117
+ else:
118
+ raise HTTPException(status_code=400, detail="Invalid request.")
119
+ else:
120
+ raise HTTPException(status_code=400, detail="Invalid request.")
121
+
122
+ if request.stream:
123
+ generate = predict(query, history, request.model)
124
+ return EventSourceResponse(generate, media_type="text/event-stream")
125
+
126
+ response, _ = model.chat(tokenizer, query, history=history)
127
+ choice_data = ChatCompletionResponseChoice(
128
+ index=0,
129
+ message=ChatMessage(role="assistant", content=response),
130
+ finish_reason="stop"
131
+ )
132
+
133
+ return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
134
+
135
+
136
+ async def predict(query: str, history: List[List[str]], model_id: str):
137
+ global model, tokenizer
138
+
139
+ choice_data = ChatCompletionResponseStreamChoice(
140
+ index=0,
141
+ delta=DeltaMessage(role="assistant"),
142
+ finish_reason=None
143
+ )
144
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
145
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
146
+
147
+ current_length = 0
148
+
149
+ for new_response in model.chat_stream(tokenizer, query, history):
150
+ if len(new_response) == current_length:
151
+ continue
152
+
153
+ new_text = new_response[current_length:]
154
+ current_length = len(new_response)
155
+
156
+ choice_data = ChatCompletionResponseStreamChoice(
157
+ index=0,
158
+ delta=DeltaMessage(content=new_text),
159
+ finish_reason=None
160
+ )
161
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
162
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
163
+
164
+
165
+ choice_data = ChatCompletionResponseStreamChoice(
166
+ index=0,
167
+ delta=DeltaMessage(),
168
+ finish_reason="stop"
169
+ )
170
+ chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
171
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
172
+ yield '[DONE]'
173
+
174
+ def _get_args():
175
+ parser = ArgumentParser()
176
+ parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
177
+ help="Checkpoint name or path, default to %(default)r")
178
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
179
+ parser.add_argument("--server-port", type=int, default=8000,
180
+ help="Demo server port.")
181
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
182
+ help="Demo server name.")
183
+
184
+ args = parser.parse_args()
185
+ return args
186
+
187
+
188
+ if __name__ == "__main__":
189
+ args = _get_args()
190
+
191
+ tokenizer = AutoTokenizer.from_pretrained(
192
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
193
+ )
194
+
195
+ if args.cpu_only:
196
+ device_map = "cpu"
197
+ else:
198
+ device_map = "auto"
199
+
200
+ model = AutoModelForCausalLM.from_pretrained(
201
+ args.checkpoint_path,
202
+ device_map=device_map,
203
+ trust_remote_code=True,
204
+ resume_download=True,
205
+ ).eval()
206
+
207
+ model.generation_config = GenerationConfig.from_pretrained(
208
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
209
+ )
210
+
211
+ uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers==4.31.0
2
+ accelerate
3
+ tiktoken
4
+ einops
5
+ transformers_stream_generator==0.0.4
6
+ scipy
requirements_web_demo.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ mdtex2html
tech_memo.md ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Introducing Qwen-7B: Open foundation and human-aligned models (of the state-of-the-arts)
2
+
3
+ Large language models have recently attracted an enormous amount of
4
+ attention.
5
+ The boom of [ChatGPT](https://openai.com/blog/chatgpt) has accelerated the pursuit of artificial general intelligence and shows that large language models compress world knowledge into neural networks; aligning them with human cognition yields powerful conversational agents that can assist human users through interaction.
6
+ Now, the latest version of ChatGPT based on [GPT-4](https://arxiv.org/abs/2303.08774) demonstrates exciting performance across a wide range of capabilities, such as language understanding, logical reasoning, and planning, and its integration with external tools and models unlocks agents that can understand instructions, execute code, and use tools to reach the objectives set by human users.
7
+
8
+ This significant progress underscores the importance of large language models as _the foundation of AI services_.
9
+
10
+ We are happy to release Qwen-7B, the 7B-parameter models of our large pretrained model series Qwen (abbr. Tongyi Qianwen).
11
+ This release includes model weights and code for the pretrained and human-aligned 7B-parameter language models:
12
+
13
+ - `Qwen-7B` is the pretrained language model, and `Qwen-7B-Chat` is fine-tuned to align with human intent.
14
+ - `Qwen-7B` is pretrained on over 2.2 trillion tokens with a context length of 2048. On the series of benchmarks we tested, Qwen-7B generally performs better than existing open models of similar scales and appears to be on par with some of the larger models.
15
+ - `Qwen-7B-Chat` is fine-tuned on curated data, including not only task-oriented data but also specific security- and service-oriented data, which existing open models appear to lack.
16
+ - Example code for fine-tuning, evaluation, and inference is included, along with guides on long-context and tool use in inference.
17
+
18
+ **Goal of release**:
19
+ We believe that while the recent waves of LLM releases may have deepened our understanding of model behaviors under standard regimes, it remains to be seen how the techniques accompanying today's LLMs, such as 1) quantization and fine-tuning after quantization, 2) training-free long-context inference, and 3) fine-tuning with service-oriented data, including search and tool use, affect the models as a whole.
20
+ The open release of Qwen-7B marks our first step towards fully understanding the real-world application of such techniques.
21
+ We hope that it will enable the community to analyze and continue to improve the safety of these models, striving toward responsible development and deployment of LLMs.
22
+
23
+ > **Disclaimer**:
24
+ > We must note that even though the weights and code are released openly and commercial use is not prohibited, similar to other pretrained language models, Qwen-7B comes with potential risks influenced by complex factors, including but not limited to over-diversified, inaccurate, or misleading generation.
25
+ > Developers and stakeholders should perform their own red teaming and put appropriate security measures in place before deployment, and they must comply with local governance and regulations.
26
+ > In no event shall the authors be held liable for any claim, damages, or other liability arising from the use of the released weights or code.
27
+
28
+ The remainder of this document describes our pretraining and fine-tuning methodology.
29
+
30
+ ## Pretraining
31
+
32
+ Qwen-7B is a transformer-based decoder-only language model with an architecture similar to the [LLaMA](https://github.com/facebookresearch/llama) series of models.
33
+ It is pretrained on over 2.2 trillion tokens of publicly available data with a context length of 2048, covering general and professional fields with a focus on the English and Chinese languages.
34
+
35
+ ### Data
36
+
37
+ **Pretraining data**:
38
+ Our training data includes a mix of data from publicly available sources, consisting mainly of web documents and code files.
39
+ In addition, the data is multilingual, with most of it in English and Chinese.
40
+ We employed an ensemble of models to exclude low-quality data and data deemed unfit for pretraining, such as NSFW content.
41
+ For math reasoning, we include RFT data from [gsm8k-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel).
42
+ The final data underwent global fuzzy deduplication.
43
+ The mix of pretraining corpora has been optimized through numerous ablation experiments.
44
+
45
+ **Tokenization**:
46
+ In contrast to current mainstream open models built around Chinese and English vocabularies, we use a vocabulary of 151,851 tokens.
47
+ It prioritizes efficient encoding of Chinese, English, and code data, and is also friendlier to other languages, enabling users to directly enhance the capability for those languages without expanding the vocabulary.
48
+ It splits numbers into single digits and uses the [tiktoken](https://github.com/openai/tiktoken) tokenizer library for efficient tokenization.
49
+ After tokenization, the data amounts to over 2.2 trillion tokens.
50
+
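+ To see the tokenizer behavior described above in practice, the sketch below loads the released tokenizer via `transformers` (with `trust_remote_code=True`, as the repository README does) and counts tokens for a mixed Chinese/English string. The repository id `Qwen/Qwen-7B` is an assumption here; adjust it to wherever the weights are hosted.
+
+ ```python
+ # Minimal sketch: inspect the Qwen tokenizer (repo id "Qwen/Qwen-7B" is assumed).
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+
+ text = "通义千问 means Qwen; numbers such as 12345 are split into single digits."
+ ids = tokenizer.encode(text)
+ print(len(ids))               # token count for this string
+ print(tokenizer.decode(ids))  # should round-trip to the original text
+ ```
+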
51
+ <figure>
52
+ <img src="assets/tokenizer.png"
53
+ alt="Tokenization efficiency"
54
+ width="1200px">
55
+ <figcaption>We randomly selected 1 million document corpora of each language to test and compare the encoding compression rates of different models (with XLM-R, which supports 100 languages, as the base value 1, not shown in the figure). As can be seen, while ensuring the efficient decoding of Chinese, English, and code, Qwen-7B also achieves a high compression rate for many other languages (such as th, he, ar, ko, vi, ja, tr, id, pl, ru, nl, pt, it, de, es, fr etc.), equipping the model with strong scalability as well as high training and inference efficiency in these languages.</figcaption>
56
+ </figure>
57
+
58
+ ### Model
59
+
60
+ **Model architecture**:
61
+ Qwen-7B is built with an architecture similar to LLaMA.
62
+ The following are the main differences from the standard transformer: 1) using untied embedding, 2) using rotary positional embedding, 3) no biases except for QKV in attention, 4) RMSNorm instead of LayerNorm, 5) SwiGLU instead of ReLU, and 6) adopting flash attention to accelerate training.
63
+ The model has 32 layers, the embedding dimension is 4096, and the number of attention heads is 32.
64
+
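+ To make the list above concrete, here is a schematic PyTorch sketch of two of the named components, RMSNorm and a SwiGLU feed-forward block. It is an illustration of the techniques, not the released modeling code, and the FFN width used below is arbitrary.
+
+ ```python
+ # Schematic sketch of RMSNorm and SwiGLU (illustrative; not the released modeling code).
+ import torch
+ import torch.nn as nn
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x):
+         # Scale by the root-mean-square of the activations; no mean subtraction as in LayerNorm.
+         rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+         return self.weight * x * rms
+
+ class SwiGLU(nn.Module):
+     def __init__(self, dim: int, hidden: int):
+         super().__init__()
+         self.w_gate = nn.Linear(dim, hidden, bias=False)
+         self.w_up = nn.Linear(dim, hidden, bias=False)
+         self.w_down = nn.Linear(hidden, dim, bias=False)
+
+     def forward(self, x):
+         # SiLU-gated linear unit replacing the plain ReLU MLP.
+         return self.w_down(nn.functional.silu(self.w_gate(x)) * self.w_up(x))
+
+ x = torch.randn(1, 8, 4096)                # embedding dimension 4096, as stated above
+ y = SwiGLU(4096, 11008)(RMSNorm(4096)(x))  # 11008 is an arbitrary FFN width for illustration
+ print(y.shape)                             # torch.Size([1, 8, 4096])
+ ```
+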
65
+ **Training details**:
66
+ The model is trained using the AdamW optimizer, with $\beta_1=0.9, \beta_2=0.95, \epsilon=10^{-6}$.
67
+ The sequence length is 2048, and the batch size is 2048, which means each optimization step accumulates over 4 million tokens.
68
+ We use a cosine learning rate schedule, with a warm-up of 2000 steps, a peak learning rate of $3 \times 10^{-4}$, and a minimum learning rate of 10% of the peak learning rate.
69
+ We use a weight decay of 0.1 and gradient clipping of 1.0.
70
+ The training adopts mixed precision training with `bfloat16`.
71
+
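+ For readers who want to mirror the recipe above in a generic PyTorch loop, the following is a minimal sketch of the AdamW settings and the warm-up/cosine schedule. It is illustrative only: the model, data, and total step count are placeholders, and the real training uses a distributed setup.
+
+ ```python
+ # Sketch of the optimizer and LR schedule described above (model/data/steps are placeholders).
+ import math
+ import torch
+
+ def cosine_lr(step, *, warmup=2000, total=100000, peak=3e-4, floor_ratio=0.1):
+     # Linear warm-up to the peak LR, then cosine decay to 10% of the peak.
+     if step < warmup:
+         return peak * step / warmup
+     progress = min(1.0, (step - warmup) / max(1, total - warmup))
+     floor = peak * floor_ratio
+     return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))
+
+ model = torch.nn.Linear(16, 16)  # stand-in for the real model
+ optimizer = torch.optim.AdamW(
+     model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-6, weight_decay=0.1
+ )
+
+ for step in range(10):
+     for group in optimizer.param_groups:
+         group["lr"] = cosine_lr(step, warmup=2, total=10)     # tiny numbers just to exercise the schedule
+     loss = model(torch.randn(4, 16)).pow(2).mean()            # dummy loss
+     loss.backward()
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # gradient clipping at 1.0
+     optimizer.step()
+     optimizer.zero_grad()
+ ```
+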
72
+
73
+ ### Evaluation
74
+
75
+ We report results of Qwen-7B on standard benchmarks.
76
+
77
+ #### World knowledge
78
+
79
+ [C-Eval](https://arxiv.org/abs/2305.08322) is a common evaluation benchmark for testing the common-sense capability of pretrained models in Chinese. It covers 52 subjects in four major directions: humanities, social sciences, STEM, and other specialties. According to standard practice, we use the development set samples as the source of few-shot prompts to evaluate the 5-shot validation set and test set accuracy of the Qwen-7B pretrained model.
80
+
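+ To make the few-shot protocol concrete, the sketch below assembles a 5-shot prompt from development-set examples for one multiple-choice item. The field names (`question`, `choices`, `answer`) are placeholders for illustration; the scripts under `eval/` are the authoritative implementation.
+
+ ```python
+ # Illustrative 5-shot prompt construction (field names are placeholders; see eval/evaluate_ceval.py).
+ def format_example(ex, with_answer=True):
+     prompt = ex["question"] + "\n"
+     for label, choice in zip("ABCD", ex["choices"]):
+         prompt += f"{label}. {choice}\n"
+     prompt += "Answer:"
+     if with_answer:
+         prompt += " " + ex["answer"] + "\n\n"
+     return prompt
+
+ def build_five_shot_prompt(dev_examples, test_example):
+     shots = "".join(format_example(ex) for ex in dev_examples[:5])
+     return shots + format_example(test_example, with_answer=False)
+ ```
+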
81
+ The accuracy comparison of the Qwen-7B model and other models on the C-Eval validation set is as follows:
82
+
83
+ | Model | Average |
84
+ | :---------- | -------: |
85
+ | Alpaca-7B | 28.9 |
86
+ | Vicuna-7B | 31.2 |
87
+ | ChatGLM-6B | 37.1 |
88
+ | Baichuan-7B | 42.7 |
89
+ | ChatGLM2-6B | 50.9 |
90
+ | InternLM-7B | 53.4 |
91
+ | ChatGPT | 53.5 |
92
+ | Claude-v1.3 | 55.5 |
93
+ | **Qwen-7B** | **60.8** |
94
+
95
+ The performance comparison of the Qwen-7B pretrained model and other models on the C-Eval test set is shown in the following table:
96
+
97
+ | Model | Avg. | Avg. (Hard) | STEM | Social Sciences | Humanities | Others |
98
+ | :---------------------- | -------- | ----------: | ---: | --------------: | ---------: | -----: |
99
+ | ChatGLM-6B | 38.9 | 29.2 | 33.3 | 48.3 | 41.3 | 38.0 |
100
+ | Chinese-Alpaca-Plus-13B | 41.5 | 30.5 | 36.6 | 49.7 | 43.1 | 41.2 |
101
+ | Baichuan-7B | 42.8 | 31.5 | 38.2 | 52.0 | 46.2 | 39.3 |
102
+ | WestlakeLM-19B | 44.6 | 34.9 | 41.6 | 51.0 | 44.3 | 44.5 |
103
+ | AndesLM-13B | 46.0 | 29.7 | 38.1 | 61.0 | 51.0 | 41.9 |
104
+ | BatGPT-15B-sirius | 47.0 | 31.9 | 42.7 | 57.5 | 48.6 | 43.6 |
105
+ | ChatGLM2-6B | 51.7 | 37.1 | 48.6 | 60.5 | 51.3 | 49.8 |
106
+ | InternLM-7B | 52.8 | 37.1 | 48.0 | 67.4 | 55.4 | 45.8 |
107
+ | Baichuan-13B | 53.6 | 36.7 | 47.0 | 66.8 | 57.3 | 49.8 |
108
+ | Claude-v1.3 | 54.2 | 39.0 | 51.9 | 61.7 | 52.1 | 53.7 |
109
+ | ChatGPT | 54.4 | 41.4 | 52.9 | 61.8 | 50.9 | 53.6 |
110
+ | **Qwen-7B** | **59.6** | 41.0 | 52.8 | 74.1 | 63.1 | 55.2 |
111
+
112
+ As can be seen, Qwen-7B achieves the best performance out of all existing models of similar scale and even surpasses larger-scale models.
113
+
114
+ MMLU is currently one of the most recognized benchmarks for evaluating English comprehension abilities, covering 57 subtasks across different academic fields and difficulty levels. The MMLU 5-shot accuracy performance of the Qwen-7B is shown in the following table:
115
+
116
+ | Model | Average | STEM | Social Sciences | Humanities | Others |
117
+ | :----------- | -------: | ---: | --------------: | ---------: | -----: |
118
+ | LLaMA-7B | 35.1 | 30.5 | 38.3 | 34.0 | 38.1 |
119
+ | Baichuan-7B | 42.3 | 35.6 | 48.9 | 38.4 | 48.1 |
120
+ | LLaMA2-7B | 45.3 | 36.4 | 51.2 | 42.9 | 52.2 |
121
+ | LLaMA-13B | 46.9 | 35.8 | 53.8 | 45.0 | 53.3 |
122
+ | ChatGLM2-6B | 47.9 | 41.2 | 54.4 | 43.7 | 54.5 |
123
+ | InternLM-7B | 51.0 | - | - | - | - |
124
+ | Baichuan-13B | 51.6 | 41.6 | 60.9 | 47.4 | 58.5 |
125
+ | LLaMA2-13B | 54.8 | 44.1 | 62.6 | 52.8 | 61.1 |
126
+ | ChatGLM2-12B | 56.2 | 48.2 | 65.1 | 52.6 | 60.9 |
127
+ | **Qwen-7B** | **56.7** | 47.6 | 65.9 | 51.5 | 64.7 |
128
+
129
+ On English tasks, Qwen-7B also surpasses other open pretrained models of similar scale and is competitive with larger versions of other models.
130
+
131
+ #### Coding
132
+
133
+ We compared the code capabilities of pretrained models on [HumanEval](https://github.com/openai/human-eval), and the results are as follows:
134
+
135
+ | Model | Pass@1 |
136
+ | :----------- | -------: |
137
+ | Baichuan-7B | 9.2 |
138
+ | ChatGLM2-6B | 9.2 |
139
+ | InternLM-7B | 10.4 |
140
+ | LLaMA-7B | 10.5 |
141
+ | LLaMA2-7B | 12.8 |
142
+ | Baichuan-13B | 12.8 |
143
+ | LLaMA-13B | 15.8 |
144
+ | MPT-7B | 18.3 |
145
+ | LLaMA2-13B | 18.3 |
146
+ | **Qwen-7B** | **24.4** |
147
+
148
+ #### Math
149
+
150
+ We compared the math capabilities of pretrained models on [GSM8K](https://github.com/openai/grade-school-math) (8-shot), and the results are as follows:
151
+
152
+ | Model | Accuracy |
153
+ | :----------- | -------: |
154
+ | MPT-7B | 6.8 |
155
+ | Falcon-7B | 6.8 |
156
+ | Baichuan-7B | 9.7 |
157
+ | LLaMA-7B | 11.0 |
158
+ | LLaMA2-7B | 14.6 |
159
+ | LLaMA-13B | 17.8 |
160
+ | Baichuan-13B | 26.6 |
161
+ | LLaMA2-13B | 28.7 |
162
+ | InternLM-7B | 31.2 |
163
+ | ChatGLM2-6B | 32.4 |
164
+ | ChatGLM2-12B | 40.9 |
165
+ | **Qwen-7B** | **51.6** |
166
+
167
+ #### Natural language processing
168
+
169
+ We compared the translation capabilities of pretrained models on WMT22 zh-en and en-zh (5-shot BLEU), and the results are as follows:
170
+
171
+ | Model | Average | zh-en | en-zh |
172
+ | :---------- | -------: | -------: | -------: |
173
+ | InternLM-7B | 11.8 | 9.0 | 14.5 |
174
+ | LLaMA-7B | 12.7 | 16.7 | 8.7 |
175
+ | LLaMA-13B | 15.8 | 19.5 | 12.0 |
176
+ | LLaMA2-7B | 19.9 | 21.9 | 17.9 |
177
+ | Bloom-7B | 20.3 | 19.1 | 21.4 |
178
+ | LLaMA2-13B | 23.3 | 22.4 | 24.2 |
179
+ | PolyLM-13B | 23.6 | 20.2 | 27.0 |
180
+ | Baichuan-7B | 24.6 | 22.6 | 26.6 |
181
+ | **Qwen-7B** | **27.5** | **24.3** | **30.6** |
182
+
183
+ #### Long-context inference
184
+
185
+ We include support for training-free long-context inference based on NTK-aware interpolation, LogN attention scaling, and local window attention.
186
+ The context length can be expanded from 2048 to over 8192 tokens.
187
+ The following are the test results on arXiv data in terms of perplexity (PPL); a minimal usage sketch follows the table.
188
+
189
+ <table>
190
+ <tr>
191
+ <th rowspan="2">Model</th><th colspan="5" align="center">Sequence Length</th>
192
+ </tr>
193
+ <tr>
194
+ <th align="center">1024</th><th align="center">2048</th><th align="center">4096</th><th align="center">8192</th><th align="center">16384</th>
195
+ </tr>
196
+ <tr>
197
+ <td>Qwen-7B</td><td align="right"><b>4.23</b></td><td align="right"><b>3.78</b></td><td align="right">39.35</td><td align="right">469.81</td><td align="right">2645.09</td>
198
+ </tr>
199
+ <tr>
200
+ <td>+ dynamic_ntk</td><td align="right"><b>4.23</b></td><td align="right"><b>3.78</b></td><td align="right">3.59</td><td align="right">3.66</td><td align="right">5.71</td>
201
+ </tr>
202
+ <tr>
203
+ <td>+ dynamic_ntk + logn</td><td align="right"><b>4.23</b></td><td align="right"><b>3.78</b></td><td align="right"><b>3.58</b></td><td align="right">3.56</td><td align="right">4.62</td>
204
+ </tr>
205
+ <tr>
206
+ <td>+ dynamic_ntk + logn + local_attn</td><td align="right"><b>4.23</b></td><td align="right"><b>3.78</b></td><td align="right"><b>3.58</b></td><td align="right"><b>3.49</b></td><td align="right"><b>4.32</b></td>
207
+ </tr>
208
+ </table>
209
+
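+ Following up on the note above the table, here is a rough sketch of enabling these training-free long-context options at load time. The flag names `use_dynamic_ntk` and `use_logn_attn` are assumed to match the released `config.json` and should be verified against the published model files before use.
+
+ ```python
+ # Sketch: enable the training-free long-context options when loading the model.
+ # Flag names are assumed to match the released config.json; verify before use.
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ config = AutoConfig.from_pretrained("Qwen/Qwen-7B", trust_remote_code=True)
+ config.use_dynamic_ntk = True  # NTK-aware interpolation of the rotary embeddings
+ config.use_logn_attn = True    # LogN attention scaling for long sequences
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen-7B", config=config, trust_remote_code=True, device_map="auto"
+ ).eval()
+ ```
+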
210
+ ## Fine-tuning
211
+
212
+ `Qwen-7B-Chat` embodies our practice of aligning with human intent, internalizing safety, and building intelligent agents for services.
213
+
214
+ ### Data
215
+
216
+ **Alignment data**:
217
+ The data includes common instruction-style conversations as well as security- and service-oriented data, all of which involve substantial annotation effort.
218
+ Instruction data covers broad abilities, such as writing, question answering, brainstorming and planning, content understanding, summarization, natural language processing, and coding.
219
+ Security data aims to prevent the model from generating harmful or inappropriate content.
220
+ Service data aims to equip the model with specific conversation patterns that can be parsed to invoke and incorporate external systems.
221
+
222
+ **Data formatting**:
223
+ Since the data consists of conversation turns, we arrange them into texts using the [ChatML](https://github.com/openai/openai-python/blob/main/chatml.md) format, which is a meta language that can describe both the metadata (e.g., roles) and the content of a turn.
224
+ Currently, existing roles include system, user, and assistant.
225
+
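+ For concreteness, the sketch below renders a short conversation in ChatML. The `<|im_start|>`/`<|im_end|>` delimiters follow the ChatML description linked above; the exact special tokens expected by the released chat model should be checked against its tokenizer and chat code.
+
+ ```python
+ # Sketch: render a conversation in ChatML (delimiters per the spec linked above).
+ def to_chatml(system, turns):
+     parts = [f"<|im_start|>system\n{system}<|im_end|>"]
+     for role, content in turns:
+         parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
+     # Leave the prompt open for the assistant's next turn.
+     return "\n".join(parts) + "\n<|im_start|>assistant\n"
+
+ print(to_chatml(
+     "You are a helpful assistant.",
+     [("user", "What is the capital of France?"),
+      ("assistant", "The capital of France is Paris."),
+      ("user", "And its population?")],
+ ))
+ ```
+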
226
+ ### Model
227
+
228
+ **Training details**:
229
+ The model is fine-tuned with the causal language modeling objective, except that tokens in the content of the user's turns are excluded from the loss.
230
+ The model is trained using the AdamW optimizer, with $\beta_1=0.9, \beta_2=0.95, \epsilon=10^{-6}$.
231
+ The sequence length is limited to 2048, and the batch size is 128.
232
+ The model is trained for 4000 steps, and over the first 1430 steps, the learning rate is warmed up to $1 \times 10^{-5}$.
233
+ We use weight decay of 0.1, dropout of 0.1, and gradient clipping of 1.0.
234
+
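+ The exclusion of user-turn tokens from the objective can be pictured with the usual `-100` label masking in `transformers`: positions labeled `-100` are ignored by the cross-entropy loss. The sketch below illustrates the idea with toy ids and an assumed user-turn span; it is not the repository's fine-tuning script.
+
+ ```python
+ # Sketch: mask user-turn tokens out of the causal-LM loss by setting their labels to -100.
+ import torch
+
+ input_ids = torch.tensor([[101, 9, 8, 7, 102, 55, 56, 57, 103]])  # toy token ids
+ labels = input_ids.clone()
+
+ user_spans = [(1, 4)]  # positions belonging to the user's turn (illustrative)
+ for start, end in user_spans:
+     labels[0, start:end] = -100  # these positions contribute nothing to the loss
+
+ # model(input_ids=input_ids, labels=labels).loss is then computed only on the
+ # remaining (system/assistant) tokens.
+ ```
+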
235
+ ### Evaluation
236
+
237
+ Evaluation of human-aligned models is non-trivial and often non-standardized, since such models often target specific applications.
238
+ We evaluate Qwen-7B-Chat from multiple perspectives.
239
+
240
+ #### World knowledge
241
+
242
+ As fine-tuning uses a much smaller dataset than pretraining, and human annotators' coverage of world knowledge is necessarily limited, we also evaluate the world knowledge of Qwen-7B-Chat using C-Eval and MMLU in a zero-shot, generative manner.
243
+
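+ Because the chat model answers in free-form text in this zero-shot, generative setting, scoring requires extracting a choice letter from the generation. The sketch below shows one such heuristic; the scripts under `eval/` (e.g., `evaluate_chat_ceval.py`) are the authoritative implementation.
+
+ ```python
+ # Rough sketch of extracting an A/B/C/D choice from a free-form generation.
+ import re
+
+ def extract_choice(generation):
+     # Prefer explicit statements like "答案是 C" / "the answer is C";
+     # otherwise fall back to the first stand-alone choice letter.
+     match = re.search(r"(?:答案是|answer is)\s*([ABCD])", generation, flags=re.IGNORECASE)
+     if match:
+         return match.group(1).upper()
+     match = re.search(r"\b([ABCD])\b", generation)
+     return match.group(1) if match else None
+
+ print(extract_choice("The answer is C because ..."))  # -> C
+ ```
+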
244
+ We demonstrate the zero-shot accuracy of Qwen-7B-Chat on the C-Eval validation set.
245
+
246
+ | Model | Avg. Acc. |
247
+ | :---------------------- | --------: |
248
+ | LLaMA2-7B-Chat | 31.9 |
249
+ | LLaMA2-13B-Chat | 40.6 |
250
+ | Chinese-Alpaca-2-7B | 41.3 |
251
+ | Chinese-Alpaca-Plus-13B | 43.3 |
252
+ | Baichuan-13B-Chat | 50.4 |
253
+ | ChatGLM2-6B-Chat | 50.7 |
254
+ | InternLM-7B-Chat | 53.2 |
255
+ | **Qwen-7B-Chat** | **54.2** |
256
+
257
+ The zero-shot accuracy of Qwen-7B-Chat on the C-Eval test set is provided below.
258
+
259
+ | Model | Avg. | STEM | Social Sciences | Humanities | Others |
260
+ | :---------------------- | -------: | ---: | --------------: | ---------: | -----: |
261
+ | Chinese-Alpaca-Plus-13B | 41.5 | 36.6 | 49.7 | 43.1 | 41.2 |
262
+ | Chinese-Alpaca-2-7B | 40.3 | - | - | - | - |
263
+ | ChatGLM2-6B-Chat | 50.1 | 46.4 | 60.4 | 50.6 | 46.9 |
264
+ | Baichuan-13B-Chat | 51.5 | 43.7 | 64.6 | 56.2 | 49.2 |
265
+ | **Qwen-7B-Chat** | **54.6** | 47.8 | 67.6 | 59.3 | 50.6 |
266
+
267
+ Compared with other models of comparable size, the human-aligned Qwen-7B-Chat performs well in C-Eval accuracy.
268
+
269
+ The zero-shot accuracy of Qwen-7B-Chat on MMLU is provided below.
270
+ Qwen-7B-Chat remains at the top among human-aligned models of comparable size.
271
+
272
+ | Model | Avg. Acc. |
273
+ | :---------------- | --------: |
274
+ | ChatGLM2-6B-Chat | 45.5 |
275
+ | LLaMA2-7B-Chat | 47.0 |
276
+ | InternLM-7B-Chat | 50.8 |
277
+ | Baichuan-13B-Chat | 52.1 |
278
+ | ChatGLM2-12B-Chat | 52.1 |
279
+ | **Qwen-7B-Chat** | **53.9** |
280
+
281
+ #### Coding
282
+
283
+ The zero-shot Pass@1 of Qwen-7B-Chat on [HumanEval](https://github.com/openai/human-eval) is demonstrated below
284
+
285
+ | Model | Pass@1 |
286
+ | :---------------- | -------: |
287
+ | LLaMA2-7B-Chat | 12.2 |
288
+ | InternLM-7B-Chat | 14.0 |
289
+ | Baichuan-13B-Chat | 16.5 |
290
+ | LLaMA2-13B-Chat | 18.9 |
291
+ | **Qwen-7B-Chat** | **24.4** |
292
+
293
+ #### Math
294
+
295
+ The accuracy of Qwen-7B-Chat on GSM8K is shown below
296
+
297
+ | Model | Zero-shot Acc. | 4-shot Acc. |
298
+ | :---------------- | -------------: | ----------: |
299
+ | ChatGLM2-6B-Chat | - | 28.0 |
300
+ | LLaMA2-7B-Chat | 20.4 | 28.2 |
301
+ | LLaMA2-13B-Chat | 29.4 | 36.7 |
302
+ | InternLM-7B-Chat | 32.6 | 34.5 |
303
+ | Baichuan-13B-Chat | - | 36.3 |
304
+ | ChatGLM2-12B-Chat | - | 38.1 |
305
+ | **Qwen-7B-Chat** | **41.1** | **43.5** |
306
+
307
+ #### Service
308
+
309
+ LLMs have shown the capability to coordinate multiple external systems to carry out given instructions, which creates new opportunities in traditional online services, most notably web search.
310
+
311
+ Qwen supports calling plugins/tools/APIs through [ReAct Prompting](https://arxiv.org/abs/2210.03629).
312
+ ReAct is also one of the main approaches used by the [LangChain](https://python.langchain.com/) framework.
313
+ For how to write and use prompts for ReAct Prompting, please refer to [the ReAct examples](examples/react_prompt.md).
314
+ In our evaluation [benchmark](eval/EVALUATION.md) for assessing tool usage capabilities, Qwen's performance is as follows:
315
+
316
+ | Model | Tool Selection (Acc.↑) | Tool Input (Rouge-L↑) | False Positive Error↓ |
317
+ | :---------- | --------------------------: | -------------------------: | -------------------------: |
318
+ | GPT-4 | 95% | **0.90** | 15.0% |
319
+ | GPT-3.5 | 85% | 0.88 | 75.0% |
320
+ | **Qwen-7B** | **99%** | 0.89 | **9.7%** |
321
+
322
+ > The plugins that appear in the evaluation set do not appear in the training set of Qwen.
323
+ > This benchmark evaluates the accuracy of the model in selecting the correct plugin from multiple candidate plugins, the rationality of the parameters passed into the plugin, and the false positive rate.
324
+ > False positive: incorrectly invoking a plugin when it should not have been called in response to a query.
325
+
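+ For readers who want a feel for the prompt format behind these numbers, the sketch below assembles a generic ReAct-style scaffold (Thought / Action / Action Input / Observation). It is a simplified illustration; `examples/react_prompt.md` remains the authoritative description of the format Qwen expects.
+
+ ```python
+ # Simplified ReAct-style prompt scaffold (illustrative; see examples/react_prompt.md).
+ TOOL_DESC = "{name}: {description} Parameters: {parameters}"
+
+ def build_react_prompt(query, tools):
+     tool_text = "\n".join(
+         TOOL_DESC.format(name=t["name"], description=t["description"], parameters=t["parameters"])
+         for t in tools
+     )
+     names = ", ".join(t["name"] for t in tools)
+     return (
+         "Answer the following questions as best you can. You have access to the following tools:\n\n"
+         f"{tool_text}\n\n"
+         "Use the following format:\n"
+         "Question: the input question\n"
+         "Thought: you should always think about what to do\n"
+         f"Action: the action to take, should be one of [{names}]\n"
+         "Action Input: the input to the action\n"
+         "Observation: the result of the action\n"
+         "... (this Thought/Action/Action Input/Observation can repeat)\n"
+         "Thought: I now know the final answer\n"
+         "Final Answer: the final answer to the original question\n\n"
+         f"Question: {query}\n"
+     )
+ ```
+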
326
+ Qwen can also be used as a [HuggingFace Agent](https://huggingface.co/docs/transformers/transformers_agents).
327
+ Its performance on the benchmark provided by HuggingFace is as follows:
328
+
329
+ | Model | Tool Selection↑ | Tool Used↑ | Code↑ |
330
+ | :-------------- | -------------------: | --------------: | ---------: |
331
+ | GPT-4 | **100.00** | **100.00** | **97.41** |
332
+ | GPT-3.5 | 95.37 | 96.30 | 87.04 |
333
+ | StarCoder-15.5B | 87.04 | 87.96 | 68.89 |
334
+ | **Qwen-7B** | 90.74 | 92.59 | 74.07 |
335
+
336
+ ## Conclusion
337
+
338
+ In this document, we describe Qwen-7B, including a pretrained model and a human-aligned model.
339
+ These models have demonstrated exciting performance compared to existing open models of similar or even larger scales.
340
+ As part of our ongoing commitment to the concept of Model as a Service, the release also includes practical pieces such as long-context inference and external system integration, which we hope will help developers realize their own ideas and concepts.
341
+ We believe that the open release of the Qwen-7B models will further our understanding of the variables and techniques introduced in realistic settings and help drive progress in this important area together with the community.